| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291 |
- #!/usr/bin/env python3
- import argparse
- import asyncio
- import logging
- import os
- import sys
- import time
- from dataclasses import dataclass
- from functools import partial
- import aiohttp
- import requests
- from wyoming.audio import AudioChunk, AudioStart, AudioStop
- from wyoming.error import Error
- from wyoming.event import Event
- from wyoming.info import Attribution, TtsProgram, TtsVoice, TtsVoiceSpeaker, Describe, Info
- from wyoming.server import AsyncEventHandler
- from wyoming.server import AsyncServer
- from wyoming.tts import Synthesize
# Module-wide logger; output format and level are set up later by logging.basicConfig in main().
_LOGGER = logging.getLogger(name=__name__)
# Version string of this wrapper (not referenced elsewhere in the visible code).
VERSION: str = "0.2"
@dataclass
class KokoroVoice:
    """Plain record describing one Kokoro voice.

    NOTE(review): not referenced anywhere in the visible code; the field
    meanings below are inferred from their names — confirm before relying on them.
    """

    # Display name of the voice.
    name: str
    # Language code for the voice (presumably e.g. "en", "ja").
    language: str
    # Identifier for the voice on the Kokoro side (presumably the API voice id).
    kokoro_id: str
class KokoroEventHandler(AsyncEventHandler):
    """Wyoming event handler that relays synthesis requests to a Kokoro HTTP API.

    ``Describe`` events are answered with the service ``Info``. ``Synthesize``
    events are forwarded as a streaming POST to ``kokoro_endpoint`` (requesting
    ``"pcm"`` output), and the response body is sent back to the client as
    ``AudioStart`` / ``AudioChunk`` / ``AudioStop`` events.
    """

    def __init__(
        self,
        wyoming_info: Info,
        kokoro_endpoint,
        cli_args: argparse.Namespace,
        *args,
        **kwargs,
    ):
        """Create a handler for a single client connection.

        Args:
            wyoming_info: Service description returned for ``Describe`` events.
            kokoro_endpoint: Full URL that synthesis requests are POSTed to.
            cli_args: Parsed command-line options. ``cli_args.speed`` (when present
                and a positive finite number) sets the synthesis speed; otherwise 1.0.
            *args, **kwargs: Forwarded unchanged to ``AsyncEventHandler``.
        """
        super().__init__(*args, **kwargs)
        self.kokoro_endpoint = kokoro_endpoint
        self.cli_args = cli_args
        # Positional arguments forwarded to AsyncEventHandler, kept as an attribute for compatibility.
        self.args = args
        self.wyoming_info_event = wyoming_info.event()
        # Audio format announced to the client for Kokoro's raw "pcm" stream:
        # 24 kHz (Kokoro's known sample rate), mono, 2-byte samples.
        self.sample_rate = 24000
        self.channels = 1
        self.sample_width = 2
        # HTTP read size; 512 bytes = 256 samples at 16-bit mono. Small reads keep latency low.
        self.chunk_size = 512
        # Honor --speed / VOICE_SPEED instead of always synthesizing at 1.0.
        self.speed = self._parse_speed(getattr(cli_args, "speed", None))
        _LOGGER.info("Event Handler initialized and awaiting events")

    @staticmethod
    def _parse_speed(raw) -> float:
        """Convert a ``--speed`` value (str, number or None) to a float.

        Returns 1.0 when the value is None. Also falls back to 1.0, with a
        warning, when the value is not numeric or is not a positive finite number.
        """
        if raw is None:
            return 1.0
        try:
            speed = float(raw)
        except (TypeError, ValueError):
            _LOGGER.warning("Invalid voice speed %r; using 1.0", raw)
            return 1.0
        # The chained comparison also rejects NaN and infinity.
        if not 0 < speed < float("inf"):
            _LOGGER.warning("Voice speed must be a positive number, got %r; using 1.0", raw)
            return 1.0
        return speed

    async def handle_event(self, event: Event) -> bool:
        """Handle one Wyoming protocol event.

        Returns:
            ``True`` for ``Describe`` and unrecognized events; for ``Synthesize``
            events, the result of ``_handle_synthesize``.

        Raises:
            Any exception escaping synthesis. It is first reported to the
            client as an ``Error`` event.
        """
        _LOGGER.debug("Handling an Event: %s", event)

        if Describe.is_type(event.type):
            await self.write_event(self.wyoming_info_event)
            _LOGGER.debug("Sent info")
            return True

        if not Synthesize.is_type(event.type):
            _LOGGER.warning("Unexpected event: %s", event)
            return True

        try:
            return await self._handle_synthesize(event)
        except Exception as err:
            await self.write_event(
                Error(text=str(err), code=err.__class__.__name__).event()
            )
            raise

    async def _handle_synthesize(self, event: Event) -> bool:
        """Stream synthesized speech for a ``Synthesize`` event back to the client.

        Sends ``AudioStart`` before the first audio bytes, then ``AudioChunk``
        events while the Kokoro response streams in. Always finishes with
        ``AudioStop``, even on failure. HTTP/stream errors are logged (with a
        traceback) rather than raised.

        Returns:
            ``True`` if the whole stream was relayed, ``False`` otherwise.
        """
        synthesize = Synthesize.from_event(event)

        # Use the default voice unless the request names one explicitly
        # (``voice`` may be present with an empty or None ``name``).
        voice_name = "af_heart"
        if synthesize.voice and synthesize.voice.name:
            voice_name = synthesize.voice.name

        _LOGGER.info("Starting TTS stream request...")
        start_time = time.time()

        audio_started = False
        chunk_count = 0
        total_bytes = 0
        success = False
        # HTTP reads can end in the middle of a sample, so audio is forwarded
        # only in whole frames; an incomplete tail is carried to the next read.
        frame_size = self.sample_width * self.channels
        pending = b""

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    url=self.kokoro_endpoint,
                    json={
                        "model": "kokoro",
                        "input": synthesize.text,
                        "voice": voice_name,
                        "speed": self.speed,
                        "response_format": "pcm",
                        "stream": True,
                    },
                    # Generous overall limit so long texts can finish streaming.
                    timeout=aiohttp.ClientTimeout(total=1800),
                ) as response:
                    response.raise_for_status()
                    _LOGGER.debug(
                        f"Request started successfully after {time.time() - start_time:.2f}s"
                    )

                    async for chunk in response.content.iter_chunked(self.chunk_size):
                        if not chunk:
                            continue
                        chunk_count += 1
                        total_bytes += len(chunk)

                        if not audio_started:
                            _LOGGER.debug(
                                f"Received first chunk after {time.time() - start_time:.2f}s"
                            )
                            audio_started = True
                            _LOGGER.debug("Sending AudioStart")
                            await self.write_event(
                                AudioStart(
                                    rate=self.sample_rate,
                                    width=self.sample_width,
                                    channels=self.channels,
                                ).event()
                            )

                        pending += chunk
                        aligned = len(pending) - len(pending) % frame_size
                        if aligned:
                            await self.write_event(
                                AudioChunk(
                                    audio=pending[:aligned],
                                    rate=self.sample_rate,
                                    width=self.sample_width,
                                    channels=self.channels,
                                ).event()
                            )
                            pending = pending[aligned:]

                        # Log progress every 100 received chunks.
                        if chunk_count % 100 == 0:
                            elapsed = time.time() - start_time
                            _LOGGER.debug(
                                f"Progress: {chunk_count} chunks, {total_bytes / 1024:.1f}KB received, {elapsed:.1f}s elapsed"
                            )

                    if pending:
                        _LOGGER.warning(
                            "Discarding %d trailing byte(s) that do not form a complete sample",
                            len(pending),
                        )

            total_time = time.time() - start_time
            _LOGGER.info("Stream complete:")
            _LOGGER.debug(f"Total chunks: {chunk_count}")
            _LOGGER.debug(f"Total data: {total_bytes / 1024:.1f}KB")
            _LOGGER.info(f"Total time: {total_time:.2f}s")
            if total_time > 0:
                _LOGGER.info(f"Average speed: {(total_bytes / 1024) / total_time:.1f}KB/s")
            success = True
        except Exception as err:
            # Best effort: log with traceback and let the caller see False.
            _LOGGER.exception("Error during streaming: %s", err)
        finally:
            # Always terminate the audio stream for the client.
            _LOGGER.debug("Sending AudioStop")
            await self.write_event(AudioStop().event())

        return success
# Kokoro voice ids begin with a one-letter language code. "a"/"b" (American/British
# English) and any unknown prefix fall through to "en".
VOICE_LANGUAGES = {
    "j": "ja",  # Japanese
    "z": "zh",  # Mandarin Chinese
    "e": "es",  # Spanish
    "f": "fr",  # French
    "h": "hi",  # Hindi
    "i": "it",  # Italian
    "p": "pt",  # Brazilian Portuguese
}


def voice_language(voice):
    """Return the language code for a Kokoro voice id, based on its first letter ("en" by default)."""
    return VOICE_LANGUAGES.get(voice[:1], "en")


def voice_speaker(voice):
    """Return the speaker part of a voice id: the text after the first "_".

    The whole id is returned when there is no "_" or nothing follows it.
    This avoids an IndexError on unusual ids and keeps blend ids intact.
    """
    _, sep, speaker = voice.partition("_")
    return speaker if sep and speaker else voice


async def fetch_voice_names(kokoro_endpoint):
    """GET ``{kokoro_endpoint}/voices`` and return the list stored under its "voices" key.

    Raises:
        aiohttp.ClientError: on connection failure or a non-2xx status.
        asyncio.TimeoutError: if the server does not answer within 30 seconds.
        KeyError: if the JSON body has no "voices" key.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(
            f"{kokoro_endpoint}/voices",
            timeout=aiohttp.ClientTimeout(total=30),
        ) as response:
            response.raise_for_status()
            # content_type=None: parse the body as JSON regardless of the Content-Type header.
            payload = await response.json(content_type=None)
    return list(payload["voices"])


async def main():
    """Main entry point: read configuration, discover Kokoro voices, run the Wyoming server.

    Environment variables provide defaults: API_HOST, API_PORT, LISTEN_HOST,
    LISTEN_PORT, LISTEN_PROTOCOL, VOICE_SPEED and DEBUG. Command-line options
    override them.
    """
    kokoro_api_host = os.getenv("API_HOST", "http://localhost")
    kokoro_api_port = os.getenv("API_PORT", "8880")
    kokoro_endpoint = f"{kokoro_api_host}:{kokoro_api_port}/v1/audio"
    listen_host = os.getenv("LISTEN_HOST", "0.0.0.0")
    listen_port = os.getenv("LISTEN_PORT", "10200")
    listen_protocol = os.getenv("LISTEN_PROTOCOL", "tcp")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--host",
        default=listen_host,
        help="Host to listen on"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=listen_port,
        help="Port to listen on"
    )
    parser.add_argument(
        "--uri",
        default=None,
        help="unix:// or tcp://[host]:[port] (default: built from --host/--port and LISTEN_PROTOCOL)"
    )
    parser.add_argument(
        "--speed",
        type=float,
        # argparse runs string defaults through `type`, so this becomes a float.
        default=os.getenv("VOICE_SPEED", "1.0"),
        help="Voice speed"
    )
    parser.add_argument(
        "--debug",
        # store_true: passing --debug enables debug logging (store_false inverted it).
        action="store_true",
        default=(os.getenv("DEBUG", "false").lower() == "true"),
        help="Enable debug logging",
    )
    args = parser.parse_args()
    # Build the URI from the parsed --host/--port so those options take effect.
    if args.uri is None:
        args.uri = f"{listen_protocol}://{args.host}:{args.port}"

    logging.basicConfig(
        level=logging.DEBUG if args.debug else logging.INFO,
        format="%(asctime)s %(levelname)s: %(message)s",
        stream=sys.stdout
    )
    _LOGGER.debug(args)
    _LOGGER.debug(f"using {kokoro_endpoint} as endpoint")

    # Get list of voices from the Kokoro endpoint (async: do not block the event loop).
    voice_names = await fetch_voice_names(kokoro_endpoint)
    # TODO: Parameterize custom voice blends
    voice_names.append("af_heart(2)+af_bella(1)+af_nicole(1)")

    voices = [
        TtsVoice(
            name=voice,
            description=f"Kokoro voice {voice}",
            attribution=Attribution(
                name="hexgrad", url="https://github.com/hexgrad/kokoro"
            ),
            installed=True,
            version=None,
            languages=[voice_language(voice)],
            speakers=[TtsVoiceSpeaker(name=voice_speaker(voice))],
        )
        for voice in voice_names
    ]

    wyoming_info = Info(
        tts=[TtsProgram(
            name="kokoro",
            description="A fast, local, kokoro-based tts engine",
            attribution=Attribution(
                name="Kokoro TTS",
                url="https://huggingface.co/hexgrad/Kokoro-82M",
            ),
            installed=True,
            voices=sorted(voices, key=lambda v: v.name),
            version="1.6.0"
        )]
    )

    server = AsyncServer.from_uri(args.uri)
    # Start the server; each connection gets a KokoroEventHandler bound to the speech endpoint.
    await server.run(partial(KokoroEventHandler, wyoming_info, f"{kokoro_endpoint}/speech", args))
if __name__ == "__main__":
    # Script entry: run the async server until interrupted; Ctrl+C exits quietly without a traceback.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        ...
|