import asyncio
import io
import logging

import soundfile as sf
import torch

from llmux.backends.base import BaseBackend
from llmux.config import PhysicalModel

logger = logging.getLogger(__name__)


class ChatterboxTTSBackend(BaseBackend):
    """Backend that loads Chatterbox TTS models and synthesizes speech.

    Chat-style generation is not supported; only ``synthesize`` produces
    output. Blocking model work runs in the default thread-pool executor
    so the event loop stays responsive.
    """

    def __init__(self, models_dir: str = "/models"):
        # Directory for model weights. NOTE(review): currently unused by
        # from_pretrained — kept for interface parity; confirm intent.
        self._models_dir = models_dir
        # model_id -> {"model": ChatterboxTTS instance, "device": str}
        self._loaded: dict[str, dict] = {}

    async def load(self, model_id: str, device: str = "cuda") -> None:
        """Load the Chatterbox variant configured for *model_id* onto *device*.

        No-op if the model is already loaded. Raises KeyError if *model_id*
        has no registered physical config (see ``set_physical_models``).
        """
        if model_id in self._loaded:
            return
        physical = _get_physical_config(model_id)
        variant = physical.variant
        # Lazy %-args: skip string formatting when INFO is disabled.
        logger.info("Loading Chatterbox %s to %s", variant, device)

        def _load():
            # Imported lazily — chatterbox pulls in heavy deps at import time.
            from chatterbox.tts import ChatterboxTTS

            # "turbo" and "multilingual" are named variants; anything else
            # falls back to the default checkpoint.
            if variant in ("turbo", "multilingual"):
                return ChatterboxTTS.from_pretrained(device=device, variant=variant)
            return ChatterboxTTS.from_pretrained(device=device)

        # get_running_loop() is the supported API inside a coroutine;
        # get_event_loop() is deprecated here since Python 3.10.
        loop = asyncio.get_running_loop()
        model = await loop.run_in_executor(None, _load)
        self._loaded[model_id] = {"model": model, "device": device}

    async def unload(self, model_id: str) -> None:
        """Drop the model and release cached CUDA memory (no-op if not loaded)."""
        entry = self._loaded.pop(model_id, None)
        if entry is None:
            return
        device = entry.get("device", "")
        del entry["model"]
        # Only touch the CUDA allocator when the model actually lived on a
        # CUDA device and CUDA is present; calling empty_cache() otherwise
        # is pointless (and unsafe on CPU-only builds).
        if device.startswith("cuda") and torch.cuda.is_available():
            torch.cuda.empty_cache()

    async def generate(self, model_id, messages, params, stream=False, tools=None):
        """Chat generation is not applicable to a TTS backend."""
        raise NotImplementedError("TTS backend does not support chat generation")

    async def synthesize(self, model_id: str, text: str, voice: str = "default") -> bytes:
        """Synthesize *text* and return the audio as WAV-encoded bytes.

        Raises KeyError with an actionable message if *model_id* was never
        loaded. NOTE(review): *voice* is currently ignored — confirm whether
        Chatterbox voice conditioning should be wired through here.
        """
        try:
            entry = self._loaded[model_id]
        except KeyError:
            # Same exception type callers already catch, but with context.
            raise KeyError(
                f"model {model_id!r} is not loaded; call load() first"
            ) from None
        model = entry["model"]

        def _synthesize():
            wav = model.generate(text)
            buf = io.BytesIO()
            # Prefer the model's own sample rate when it exposes one;
            # 24 kHz is the Chatterbox default and preserves old behavior.
            sr = int(getattr(model, "sr", 24000))
            sf.write(buf, wav.cpu().numpy().squeeze(), samplerate=sr, format="WAV")
            return buf.getvalue()

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _synthesize)


# Registry of physical model configs, injected by the host application
# via set_physical_models() before any load() call.
_physical_models: dict[str, PhysicalModel] = {}


def set_physical_models(models: dict[str, PhysicalModel]) -> None:
    """Replace the module-level physical-model registry."""
    global _physical_models
    _physical_models = models


def _get_physical_config(model_id: str) -> PhysicalModel:
    """Look up the PhysicalModel config for *model_id* (KeyError if absent)."""
    return _physical_models[model_id]