#!/usr/bin/env python3
# main.py
  2. import argparse
  3. import asyncio
  4. import logging
  5. import os
  6. import sys
  7. import time
  8. from dataclasses import dataclass
  9. from functools import partial
  10. import aiohttp
  11. import requests
  12. from wyoming.audio import AudioChunk, AudioStart, AudioStop
  13. from wyoming.error import Error
  14. from wyoming.event import Event
  15. from wyoming.info import Attribution, TtsProgram, TtsVoice, TtsVoiceSpeaker, Describe, Info
  16. from wyoming.server import AsyncEventHandler
  17. from wyoming.server import AsyncServer
  18. from wyoming.tts import Synthesize
  19. _LOGGER = logging.getLogger(__name__)
  20. VERSION = "0.2"
  21. @dataclass
  22. class KokoroVoice:
  23. name: str
  24. language: str
  25. kokoro_id: str
  26. class KokoroEventHandler(AsyncEventHandler):
  27. def __init__(self,
  28. wyoming_info: Info,
  29. kokoro_endpoint,
  30. cli_args: argparse.Namespace,
  31. *args,
  32. **kwargs):
  33. super().__init__(*args, **kwargs)
  34. self.kokoro_endpoint = kokoro_endpoint
  35. self.cli_args = cli_args
  36. self.args = args
  37. self.wyoming_info_event = wyoming_info.event()
  38. self.sample_rate = 24000 # Known sample rate for Kokoro
  39. self.channels = 1
  40. self.sample_width = 2
  41. self.chunk_size = 512
  42. self.speed = 1.0
  43. # self.normalization_options = args["normalization"]
  44. _LOGGER.info("Event Handler initialized and awaiting events")
  45. async def handle_event(self, event: Event) -> bool:
  46. """Handle Wyoming protocol events."""
  47. _LOGGER.debug(f"Handling an Event: {event}")
  48. if Describe.is_type(event.type):
  49. await self.write_event(self.wyoming_info_event)
  50. _LOGGER.debug("Sent info")
  51. return True
  52. if not Synthesize.is_type(event.type):
  53. _LOGGER.warning("Unexpected event: %s", event)
  54. return True
  55. try:
  56. return await self._handle_synthesize(event)
  57. except Exception as err:
  58. await self.write_event(
  59. Error(text=str(err), code=err.__class__.__name__).event()
  60. )
  61. raise err
  62. async def _handle_synthesize(self, event: Event) -> bool:
  63. """Handle text to speech synthesis request."""
  64. synthesize = Synthesize.from_event(event)
  65. # Get voice settings
  66. voice_name = "af_heart" # default voice
  67. # lang_code = "en"
  68. if synthesize.voice:
  69. voice_name = synthesize.voice.name
  70. # lang_code = synthesize.voice.language
  71. _LOGGER.info("Starting TTS stream request...")
  72. start_time = time.time()
  73. # Initialize variables
  74. audio_started = False
  75. chunk_count = 0
  76. total_bytes = 0
  77. success = False
  78. # Make streaming request to API
  79. try:
  80. async with aiohttp.ClientSession() as session:
  81. async with session.post(
  82. url=self.kokoro_endpoint,
  83. json={
  84. "model": "kokoro",
  85. "input": synthesize.text,
  86. "voice": voice_name,
  87. # "lang_code": lang_code,
  88. "speed": self.speed,
  89. # "normalization_options": self.normalization_options,
  90. "response_format": "pcm",
  91. "stream": True,
  92. },
  93. timeout=1800,
  94. ) as response:
  95. response.raise_for_status()
  96. _LOGGER.debug(f"Request started successfully after {time.time() - start_time:.2f}s")
  97. # Process streaming response with smaller chunks for lower latency
  98. async for chunk in response.content.iter_chunked(
  99. self.chunk_size): # 512 bytes = 256 samples at 16-bit
  100. if chunk:
  101. chunk_count += 1
  102. total_bytes += len(chunk)
  103. # Handle first chunk
  104. if not audio_started:
  105. first_chunk_time = time.time()
  106. _LOGGER.debug(
  107. f"Received first chunk after {first_chunk_time - start_time:.2f}s"
  108. )
  109. # _LOGGER.debug(f"First chunk size: {len(chunk)} bytes")
  110. audio_started = True
  111. _LOGGER.debug("Sending AudioStart")
  112. await self.write_event(
  113. AudioStart(
  114. rate=self.sample_rate,
  115. width=self.sample_width,
  116. channels=self.channels,
  117. ).event()
  118. )
  119. # Send audio chunk
  120. await self.write_event(
  121. AudioChunk(
  122. audio=chunk,
  123. rate=self.sample_rate,
  124. width=self.sample_width,
  125. channels=self.channels,
  126. ).event()
  127. )
  128. # Log progress every 100 chunks
  129. if chunk_count % 100 == 0:
  130. elapsed = time.time() - start_time
  131. _LOGGER.debug(
  132. f"Progress: {chunk_count} chunks, {total_bytes / 1024:.1f}KB received, {elapsed:.1f}s elapsed"
  133. )
  134. # Final stats
  135. total_time = time.time() - start_time
  136. _LOGGER.info(f"Stream complete:")
  137. _LOGGER.debug(f"Total chunks: {chunk_count}")
  138. _LOGGER.debug(f"Total data: {total_bytes / 1024:.1f}KB")
  139. _LOGGER.info(f"Total time: {total_time:.2f}s")
  140. _LOGGER.info(f"Average speed: {(total_bytes / 1024) / total_time:.1f}KB/s")
  141. # Clean up
  142. success = True
  143. # except requests.exceptions.ConnectionError as e:
  144. # _LOGGER.error(f"Connection error - Is the server running? Error: {str(e)}")
  145. except Exception as e:
  146. _LOGGER.error(f"Error during streaming: {str(e)}")
  147. finally:
  148. # Send audio stop
  149. _LOGGER.debug("Sending AudioStop")
  150. await self.write_event(
  151. AudioStop().event()
  152. )
  153. return success
  154. async def main():
  155. """Main entry point."""
  156. kokoro_api_host = os.getenv("API_HOST", "http://localhost")
  157. kokoro_api_port = os.getenv("API_PORT", "8880")
  158. kokoro_endpoint = f"{kokoro_api_host}:{kokoro_api_port}/v1/audio"
  159. listen_host = os.getenv("LISTEN_HOST", "0.0.0.0")
  160. listen_port = os.getenv("LISTEN_PORT", 10200)
  161. parser = argparse.ArgumentParser()
  162. parser.add_argument(
  163. "--host",
  164. default=listen_host,
  165. help="Host to listen on"
  166. )
  167. parser.add_argument(
  168. "--port",
  169. type=int,
  170. default=listen_port,
  171. help="Port to listen on"
  172. )
  173. parser.add_argument(
  174. "--uri",
  175. default=f"{os.getenv('LISTEN_PROTOCOL', 'tcp')}://{listen_host}:{listen_port}",
  176. help="unix:// or tcp://[host]:[port]"
  177. )
  178. parser.add_argument(
  179. "--speed",
  180. default=os.getenv("VOICE_SPEED", 1),
  181. help="Voice speed"
  182. )
  183. # parser.add_argument(
  184. # "--normalization",
  185. # default=os.getenv("VOICE_NORMALIZATION_OPTIONS", ""),
  186. # help="Normalization options"
  187. # )
  188. parser.add_argument(
  189. "--debug",
  190. # default=(os.getenv("DEBUG", "false").lower() == "true"),
  191. action="store_false",
  192. help="Enable debug logging",
  193. )
  194. args = parser.parse_args()
  195. logging.basicConfig(
  196. level=logging.DEBUG if args.debug else logging.INFO,
  197. format="%(asctime)s %(levelname)s: %(message)s",
  198. stream=sys.stdout
  199. )
  200. _LOGGER.debug(args)
  201. _LOGGER.debug(f"using {kokoro_endpoint} as endpoint")
  202. # Get list of voices from Kokoro endpoint
  203. response = requests.get(f"{kokoro_endpoint}/voices")
  204. voice_names = response.json()["voices"]
  205. # TODO: Parameterize custom voice blends
  206. voice_names.append("af_heart(2)+af_bella(1)+af_nicole(1)")
  207. # Define available voices
  208. voices = [
  209. TtsVoice(
  210. name=voice,
  211. description=f"Kokoro voice {voice}",
  212. attribution=Attribution(
  213. name="hexgrad", url="https://github.com/hexgrad/kokoro"
  214. ),
  215. installed=True,
  216. version=None,
  217. languages=[
  218. "ja" if voice.startswith("j") else # japanese
  219. "zh" if voice.startswith("z") else # mandarin chinese
  220. "es" if voice.startswith("e") else # spanish
  221. "fr" if voice.startswith("f") else # french
  222. "hi" if voice.startswith("h") else # hindi
  223. "it" if voice.startswith("i") else # italian
  224. "pt" if voice.startswith("p") else # brazilian portuguese
  225. "en" # british and american english
  226. ],
  227. speakers=[
  228. TtsVoiceSpeaker(name=voice.split("_")[1])
  229. ]
  230. )
  231. for voice in voice_names
  232. ]
  233. wyoming_info = Info(
  234. tts=[TtsProgram(
  235. name="kokoro",
  236. description="A fast, local, kokoro-based tts engine",
  237. attribution=Attribution(
  238. name="Kokoro TTS",
  239. url="https://huggingface.co/hexgrad/Kokoro-82M",
  240. ),
  241. installed=True,
  242. voices=sorted(voices, key=lambda v: v.name),
  243. version="1.6.0"
  244. )]
  245. )
  246. server = AsyncServer.from_uri(args.uri)
  247. # Start server with kokoro instance
  248. await server.run(partial(KokoroEventHandler, wyoming_info, f"{kokoro_endpoint}/speech", args))
  249. if __name__ == "__main__":
  250. try:
  251. asyncio.run(main())
  252. except KeyboardInterrupt:
  253. pass