adri 1 month ago
Commit
1b4cb53ac8

+ 48 - 0
docker-compose.yml

@@ -0,0 +1,48 @@
+services:
+  kokoro-fastapi:
+    image: ghcr.io/remsky/kokoro-fastapi-gpu:v0.2.4
+    container_name: kokoro-fastapi
+    restart: unless-stopped
+    volumes:
+      - ${MODELS_DIR}/models:/app/api/src/models
+      - ${MODELS_DIR}/voices:/app/api/src/voices
+    ports:
+      - "8880:8880"
+    environment:
+      - USE_GPU=true
+      - PYTHONUNBUFFERED=1
+      - REPO_ID=hexgrad/Kokoro-82M
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:8880/health"]
+      interval: 20s
+      timeout: 5s
+      retries: 5
+      start_period: 90s
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              device_ids: ['0']
+              capabilities: [gpu]
+             
+  kokoro-wyoming:
+    build: kokoro-wyoming
+    container_name: kokoro-wyoming
+    depends_on:
+      kokoro-fastapi:
+        condition: service_healthy
+    links:
+      - kokoro-fastapi
+    ports:
+      - "10200:10200"
+    restart: unless-stopped
+    tty: true
+    stdin_open: true
+    environment:
+      - API_HOST=http://kokoro-fastapi  # Set TTS service URL
+      - API_PORT=8880  # Set TTS service PORT
+      - VOICE_SPEED=${VOICE_SPEED}
+    labels:
+      com.centurylinklabs.watchtower.enable: "false"
+

+ 15 - 0
kokoro-wyoming/.dockerignore

@@ -0,0 +1,15 @@
+venv/
+.venv/
+__pycache__/
+*.pyc
+.git/
+.gitignore
+.env
+*.wav
+.pytest_cache/
+.coverage
+htmlcov/
+.idea/
+.vscode/
+*.json
+*.onnx

+ 41 - 0
kokoro-wyoming/.gitignore

@@ -0,0 +1,41 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Virtual Environment
+venv/
+env/
+ENV/
+.venv/
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+
+# Project specific
+*.wav
+
+*.bin
+*.json
+*.onnx
+*.env

+ 16 - 0
kokoro-wyoming/Dockerfile

@@ -0,0 +1,16 @@
+FROM python:3.12-slim
+WORKDIR /app
+
+# Install curl
+RUN apt-get update && apt-get install -y curl && rm -rf /var/lib/apt/lists/*
+
+# Install required packages
+COPY requirements.txt .
+RUN pip install -r requirements.txt
+
+COPY src/ /app/src/
+ENV PYTHONPATH=/app
+WORKDIR /app/src
+# Debug logging is controlled via the DEBUG environment variable (see src/main.py)
+CMD ["python", "main.py"]
+

+ 8 - 0
kokoro-wyoming/README.md

@@ -0,0 +1,8 @@
+This fork of [@nordwestt](https://github.com/nordwestt)'s [excellent work](https://github.com/nordwestt/kokoro-wyoming) 
+brings [Kokoro FastAPI](https://github.com/remsky/Kokoro-FastAPI) to the [Wyoming Protocol](https://github.com/rhasspy/wyoming) 
+for [Home Assistant](https://home-assistant.io). Kokoro FastAPI is a blazing fast API implementation of the original 
+[Kokoro-82M](https://huggingface.co/hexgrad/Kokoro-82M) text-to-speech (TTS) model that can run with GPU acceleration 
+to achieve highly realistic speech synthesis.
+
+Kokoro-FastAPI runs in its own container, and kokoro-wyoming runs in a second container that provides the Wyoming 
+Protocol wrapper around it. In theory, kokoro-wyoming could easily be extended to work with any TTS engine API.

+ 4 - 0
kokoro-wyoming/requirements.txt

@@ -0,0 +1,4 @@
+wyoming~=1.6.0
+numpy~=2.2.4
+requests~=2.32.3
+aiohttp~=3.11.14

+ 291 - 0
kokoro-wyoming/src/main.py

@@ -0,0 +1,291 @@
+#!/usr/bin/env python3
+import argparse
+import asyncio
+import logging
+import os
+import sys
+import time
+from dataclasses import dataclass
+from functools import partial
+
+import aiohttp
+import requests
+from wyoming.audio import AudioChunk, AudioStart, AudioStop
+from wyoming.error import Error
+from wyoming.event import Event
+from wyoming.info import Attribution, TtsProgram, TtsVoice, TtsVoiceSpeaker, Describe, Info
+from wyoming.server import AsyncEventHandler
+from wyoming.server import AsyncServer
+from wyoming.tts import Synthesize
+
+_LOGGER = logging.getLogger(__name__)
+VERSION = "0.2"
+
+
+@dataclass
+class KokoroVoice:
+    name: str
+    language: str
+    kokoro_id: str
+
+
+class KokoroEventHandler(AsyncEventHandler):
+    def __init__(self,
+                 wyoming_info: Info,
+                 kokoro_endpoint,
+                 cli_args: argparse.Namespace,
+                 *args,
+                 **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.kokoro_endpoint = kokoro_endpoint
+        self.cli_args = cli_args
+        self.args = args
+        self.wyoming_info_event = wyoming_info.event()
+        self.sample_rate = 24000  # Known sample rate for Kokoro
+        self.channels = 1
+        self.sample_width = 2
+        self.chunk_size = 512
+        self.speed = float(cli_args.speed or 1.0)  # --speed / VOICE_SPEED, falling back to 1.0
+        # self.normalization_options = args["normalization"]
+        _LOGGER.info("Event Handler initialized and awaiting events")
+
+    async def handle_event(self, event: Event) -> bool:
+        """Handle Wyoming protocol events."""
+        _LOGGER.debug(f"Handling an Event: {event}")
+
+        if Describe.is_type(event.type):
+            await self.write_event(self.wyoming_info_event)
+            _LOGGER.debug("Sent info")
+            return True
+
+        if not Synthesize.is_type(event.type):
+            _LOGGER.warning("Unexpected event: %s", event)
+            return True
+
+        try:
+            return await self._handle_synthesize(event)
+        except Exception as err:
+            await self.write_event(
+                Error(text=str(err), code=err.__class__.__name__).event()
+            )
+            raise err
+
+    async def _handle_synthesize(self, event: Event) -> bool:
+        """Handle text to speech synthesis request."""
+        synthesize = Synthesize.from_event(event)
+
+        # Get voice settings
+        voice_name = "af_heart"  # default voice
+        # lang_code = "en"
+        if synthesize.voice:
+            voice_name = synthesize.voice.name
+            # lang_code = synthesize.voice.language
+
+        _LOGGER.info("Starting TTS stream request...")
+        start_time = time.time()
+
+        # Initialize variables
+        audio_started = False
+        chunk_count = 0
+        total_bytes = 0
+        success = False
+
+        # Make streaming request to API
+        try:
+            async with aiohttp.ClientSession() as session:
+                async with session.post(
+                        url=self.kokoro_endpoint,
+                        json={
+                            "model": "kokoro",
+                            "input": synthesize.text,
+                            "voice": voice_name,
+                            # "lang_code": lang_code,
+                            "speed": self.speed,
+                            # "normalization_options": self.normalization_options,
+                            "response_format": "pcm",
+                            "stream": True,
+                        },
+                        timeout=1800,
+                ) as response:
+
+                    response.raise_for_status()
+                    _LOGGER.debug(f"Request started successfully after {time.time() - start_time:.2f}s")
+
+                    # Process streaming response with smaller chunks for lower latency
+                    async for chunk in response.content.iter_chunked(
+                            self.chunk_size):  # 512 bytes = 256 samples at 16-bit
+                        if chunk:
+                            chunk_count += 1
+                            total_bytes += len(chunk)
+
+                            # Handle first chunk
+                            if not audio_started:
+                                first_chunk_time = time.time()
+                                _LOGGER.debug(
+                                    f"Received first chunk after {first_chunk_time - start_time:.2f}s"
+                                )
+                                # _LOGGER.debug(f"First chunk size: {len(chunk)} bytes")
+                                audio_started = True
+
+                                _LOGGER.debug("Sending AudioStart")
+                                await self.write_event(
+                                    AudioStart(
+                                        rate=self.sample_rate,
+                                        width=self.sample_width,
+                                        channels=self.channels,
+                                    ).event()
+                                )
+
+                            # Send audio chunk
+                            await self.write_event(
+                                AudioChunk(
+                                    audio=chunk,
+                                    rate=self.sample_rate,
+                                    width=self.sample_width,
+                                    channels=self.channels,
+                                ).event()
+                            )
+
+                            # Log progress every 100 chunks
+                            if chunk_count % 100 == 0:
+                                elapsed = time.time() - start_time
+                                _LOGGER.debug(
+                                    f"Progress: {chunk_count} chunks, {total_bytes / 1024:.1f}KB received, {elapsed:.1f}s elapsed"
+                                )
+
+            # Final stats
+            total_time = time.time() - start_time
+            _LOGGER.info(f"Stream complete:")
+            _LOGGER.debug(f"Total chunks: {chunk_count}")
+            _LOGGER.debug(f"Total data: {total_bytes / 1024:.1f}KB")
+            _LOGGER.info(f"Total time: {total_time:.2f}s")
+            _LOGGER.info(f"Average speed: {(total_bytes / 1024) / total_time:.1f}KB/s")
+
+            # Clean up
+            success = True
+
+        except aiohttp.ClientConnectorError as e:
+            _LOGGER.error(f"Connection error - is the Kokoro FastAPI server running? Error: {str(e)}")
+        except Exception as e:
+            _LOGGER.error(f"Error during streaming: {str(e)}")
+        finally:
+            # Always send audio stop, even when streaming failed
+            _LOGGER.debug("Sending AudioStop")
+            await self.write_event(
+                AudioStop().event()
+            )
+        # Return outside the finally block so in-flight exceptions are not swallowed
+        return success
+
+
+async def main():
+    """Main entry point."""
+    kokoro_api_host = os.getenv("API_HOST", "http://localhost")
+    kokoro_api_port = os.getenv("API_PORT", "8880")
+    kokoro_endpoint = f"{kokoro_api_host}:{kokoro_api_port}/v1/audio"
+
+    listen_host = os.getenv("LISTEN_HOST", "0.0.0.0")
+    listen_port = os.getenv("LISTEN_PORT", 10200)
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--host",
+        default=listen_host,
+        help="Host to listen on"
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=listen_port,
+        help="Port to listen on"
+    )
+    parser.add_argument(
+        "--uri",
+        default=f"{os.getenv('LISTEN_PROTOCOL', 'tcp')}://{listen_host}:{listen_port}",
+        help="unix:// or tcp://[host]:[port]"
+    )
+    parser.add_argument(
+        "--speed",
+        default=os.getenv("VOICE_SPEED", 1),
+        help="Voice speed"
+    )
+    # parser.add_argument(
+    #     "--normalization",
+    #     default=os.getenv("VOICE_NORMALIZATION_OPTIONS", ""),
+    #     help="Normalization options"
+    # )
+    parser.add_argument(
+        "--debug",
+        default=(os.getenv("DEBUG", "false").lower() == "true"),
+        action="store_true",
+        help="Enable debug logging",
+    )
+    args = parser.parse_args()
+
+    logging.basicConfig(
+        level=logging.DEBUG if args.debug else logging.INFO,
+        format="%(asctime)s %(levelname)s: %(message)s",
+        stream=sys.stdout
+    )
+
+    _LOGGER.debug(args)
+
+    _LOGGER.debug(f"using {kokoro_endpoint} as endpoint")
+
+    # Get list of voices from Kokoro endpoint
+    response = requests.get(f"{kokoro_endpoint}/voices", timeout=30)
+    voice_names = response.json()["voices"]
+    # TODO: Parameterize custom voice blends
+    voice_names.append("af_heart(2)+af_bella(1)+af_nicole(1)")
+
+    # Define available voices
+    voices = [
+        TtsVoice(
+            name=voice,
+            description=f"Kokoro voice {voice}",
+            attribution=Attribution(
+                name="hexgrad", url="https://github.com/hexgrad/kokoro"
+            ),
+            installed=True,
+            version=None,
+            languages=[
+                "ja" if voice.startswith("j") else  # japanese
+                "zh" if voice.startswith("z") else  # mandarin chinese
+                "es" if voice.startswith("e") else  # spanish
+                "fr" if voice.startswith("f") else  # french
+                "hi" if voice.startswith("h") else  # hindi
+                "it" if voice.startswith("i") else  # italian
+                "pt" if voice.startswith("p") else  # brazilian portuguese
+                "en"  # british and american english
+            ],
+            speakers=[
+                TtsVoiceSpeaker(name=voice.split("_")[1])
+            ]
+        )
+        for voice in voice_names
+    ]
+
+    wyoming_info = Info(
+        tts=[TtsProgram(
+            name="kokoro",
+            description="A fast, local, kokoro-based tts engine",
+            attribution=Attribution(
+                name="Kokoro TTS",
+                url="https://huggingface.co/hexgrad/Kokoro-82M",
+            ),
+            installed=True,
+            voices=sorted(voices, key=lambda v: v.name),
+            version=VERSION
+        )]
+    )
+
+    server = AsyncServer.from_uri(args.uri)
+
+    # Start server with kokoro instance
+    await server.run(partial(KokoroEventHandler, wyoming_info, f"{kokoro_endpoint}/speech", args))
+
+if __name__ == "__main__":
+    try:
+        asyncio.run(main())
+    except KeyboardInterrupt:
+        pass