fix: Chatterbox uses separate classes per variant, remove turbo
ChatterboxTTS and ChatterboxMultilingualTTS are separate classes. Turbo variant doesn't exist in chatterbox-tts 0.1.7. Multilingual generate() requires language_id parameter. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -48,12 +48,6 @@ physical_models:
|
||||
estimated_vram_gb: 4
|
||||
default_language: "en"
|
||||
|
||||
chatterbox-turbo:
|
||||
type: tts
|
||||
backend: chatterbox
|
||||
variant: "turbo"
|
||||
estimated_vram_gb: 2
|
||||
|
||||
chatterbox-multilingual:
|
||||
type: tts
|
||||
backend: chatterbox
|
||||
@@ -110,8 +104,6 @@ virtual_models:
|
||||
|
||||
cohere-transcribe:
|
||||
physical: cohere-transcribe
|
||||
Chatterbox-Turbo:
|
||||
physical: chatterbox-turbo
|
||||
Chatterbox-Multilingual:
|
||||
physical: chatterbox-multilingual
|
||||
Chatterbox:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import gc
|
||||
import io
|
||||
import logging
|
||||
|
||||
@@ -24,25 +25,26 @@ class ChatterboxTTSBackend(BaseBackend):
|
||||
logger.info(f"Loading Chatterbox {variant} to {device}")
|
||||
|
||||
def _load():
|
||||
from chatterbox.tts import ChatterboxTTS
|
||||
if variant == "turbo":
|
||||
model = ChatterboxTTS.from_pretrained(device=device, variant="turbo")
|
||||
elif variant == "multilingual":
|
||||
model = ChatterboxTTS.from_pretrained(device=device, variant="multilingual")
|
||||
if variant == "multilingual":
|
||||
from chatterbox import ChatterboxMultilingualTTS
|
||||
return ChatterboxMultilingualTTS.from_pretrained(device=device)
|
||||
else:
|
||||
model = ChatterboxTTS.from_pretrained(device=device)
|
||||
return model
|
||||
from chatterbox.tts import ChatterboxTTS
|
||||
return ChatterboxTTS.from_pretrained(device=device)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
model = await loop.run_in_executor(None, _load)
|
||||
self._loaded[model_id] = {"model": model, "device": device}
|
||||
self._loaded[model_id] = {"model": model, "variant": variant, "device": device}
|
||||
|
||||
async def unload(self, model_id: str) -> None:
|
||||
if model_id not in self._loaded:
|
||||
return
|
||||
entry = self._loaded.pop(model_id)
|
||||
del entry["model"]
|
||||
del entry
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
logger.info(f"Unloaded Chatterbox {model_id}")
|
||||
|
||||
async def generate(self, model_id, messages, params, stream=False, tools=None):
|
||||
raise NotImplementedError("TTS backend does not support chat generation")
|
||||
@@ -50,8 +52,14 @@ class ChatterboxTTSBackend(BaseBackend):
|
||||
async def synthesize(self, model_id: str, text: str, voice: str = "default") -> bytes:
|
||||
entry = self._loaded[model_id]
|
||||
model = entry["model"]
|
||||
variant = entry["variant"]
|
||||
|
||||
def _synthesize():
|
||||
if variant == "multilingual":
|
||||
# Default to English; voice param could encode language
|
||||
lang = "en" if voice == "default" else voice
|
||||
wav = model.generate(text, language_id=lang)
|
||||
else:
|
||||
wav = model.generate(text)
|
||||
buf = io.BytesIO()
|
||||
sf.write(buf, wav.cpu().numpy().squeeze(), samplerate=24000, format="WAV")
|
||||
|
||||
@@ -5,8 +5,8 @@ def test_load_models_config_returns_physical_and_virtual():
|
||||
physical, virtual = load_models_config()
|
||||
assert isinstance(physical, dict)
|
||||
assert isinstance(virtual, dict)
|
||||
assert len(physical) == 9
|
||||
assert len(virtual) == 16
|
||||
assert len(physical) == 8
|
||||
assert len(virtual) == 15
|
||||
|
||||
|
||||
def test_physical_model_has_required_fields():
|
||||
|
||||
@@ -10,7 +10,7 @@ def registry():
|
||||
|
||||
def test_list_virtual_models(registry):
|
||||
models = registry.list_virtual_models()
|
||||
assert len(models) == 16
|
||||
assert len(models) == 15
|
||||
names = [m["id"] for m in models]
|
||||
assert "Qwen3.5-9B-FP8-Thinking" in names
|
||||
assert "GPT-OSS-20B-High" in names
|
||||
|
||||
@@ -45,7 +45,7 @@ def test_list_models_returns_16(client, auth_headers):
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["object"] == "list"
|
||||
assert len(body["data"]) == 16
|
||||
assert len(body["data"]) == 15
|
||||
|
||||
|
||||
def test_list_models_contains_expected_names(client, auth_headers):
|
||||
|
||||
Reference in New Issue
Block a user