fix: Chatterbox uses separate classes per variant, remove turbo

ChatterboxTTS and ChatterboxMultilingualTTS are separate classes.
Turbo variant doesn't exist in chatterbox-tts 0.1.7.
Multilingual generate() requires language_id parameter.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
tlg
2026-04-05 21:43:40 +02:00
parent f24a225baf
commit d615bb4553
5 changed files with 21 additions and 21 deletions

View File

@@ -48,12 +48,6 @@ physical_models:
     estimated_vram_gb: 4
     default_language: "en"
-  chatterbox-turbo:
-    type: tts
-    backend: chatterbox
-    variant: "turbo"
-    estimated_vram_gb: 2
   chatterbox-multilingual:
     type: tts
     backend: chatterbox
@@ -110,8 +104,6 @@ virtual_models:
   cohere-transcribe:
     physical: cohere-transcribe
-  Chatterbox-Turbo:
-    physical: chatterbox-turbo
   Chatterbox-Multilingual:
     physical: chatterbox-multilingual
   Chatterbox:

View File

@@ -1,4 +1,5 @@
 import asyncio
+import gc
 import io
 import logging
@@ -24,25 +25,26 @@ class ChatterboxTTSBackend(BaseBackend):
         logger.info(f"Loading Chatterbox {variant} to {device}")
         def _load():
-            from chatterbox.tts import ChatterboxTTS
-            if variant == "turbo":
-                model = ChatterboxTTS.from_pretrained(device=device, variant="turbo")
-            elif variant == "multilingual":
-                model = ChatterboxTTS.from_pretrained(device=device, variant="multilingual")
+            if variant == "multilingual":
+                from chatterbox import ChatterboxMultilingualTTS
+                return ChatterboxMultilingualTTS.from_pretrained(device=device)
             else:
-                model = ChatterboxTTS.from_pretrained(device=device)
-            return model
+                from chatterbox.tts import ChatterboxTTS
+                return ChatterboxTTS.from_pretrained(device=device)
         loop = asyncio.get_event_loop()
         model = await loop.run_in_executor(None, _load)
-        self._loaded[model_id] = {"model": model, "device": device}
+        self._loaded[model_id] = {"model": model, "variant": variant, "device": device}
     async def unload(self, model_id: str) -> None:
         if model_id not in self._loaded:
             return
         entry = self._loaded.pop(model_id)
         del entry["model"]
+        del entry
+        gc.collect()
         torch.cuda.empty_cache()
+        logger.info(f"Unloaded Chatterbox {model_id}")
     async def generate(self, model_id, messages, params, stream=False, tools=None):
         raise NotImplementedError("TTS backend does not support chat generation")
@@ -50,9 +52,15 @@ class ChatterboxTTSBackend(BaseBackend):
     async def synthesize(self, model_id: str, text: str, voice: str = "default") -> bytes:
         entry = self._loaded[model_id]
         model = entry["model"]
+        variant = entry["variant"]
         def _synthesize():
-            wav = model.generate(text)
+            if variant == "multilingual":
+                # Default to English; voice param could encode language
+                lang = "en" if voice == "default" else voice
+                wav = model.generate(text, language_id=lang)
+            else:
+                wav = model.generate(text)
             buf = io.BytesIO()
             sf.write(buf, wav.cpu().numpy().squeeze(), samplerate=24000, format="WAV")
             buf.seek(0)

View File

@@ -5,8 +5,8 @@ def test_load_models_config_returns_physical_and_virtual():
     physical, virtual = load_models_config()
     assert isinstance(physical, dict)
    assert isinstance(virtual, dict)
-    assert len(physical) == 9
-    assert len(virtual) == 16
+    assert len(physical) == 8
+    assert len(virtual) == 15
 def test_physical_model_has_required_fields():

View File

@@ -10,7 +10,7 @@ def registry():
 def test_list_virtual_models(registry):
     models = registry.list_virtual_models()
-    assert len(models) == 16
+    assert len(models) == 15
     names = [m["id"] for m in models]
     assert "Qwen3.5-9B-FP8-Thinking" in names
     assert "GPT-OSS-20B-High" in names

View File

@@ -45,7 +45,7 @@ def test_list_models_returns_16(client, auth_headers):
     assert resp.status_code == 200
     body = resp.json()
     assert body["object"] == "list"
-    assert len(body["data"]) == 16
+    assert len(body["data"]) == 15
 def test_list_models_contains_expected_names(client, auth_headers):