fix: Chatterbox uses separate classes per variant, remove turbo
ChatterboxTTS and ChatterboxMultilingualTTS are separate classes. The Turbo variant doesn't exist in chatterbox-tts 0.1.7. The multilingual generate() requires a language_id parameter.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -48,12 +48,6 @@ physical_models:
|
|||||||
estimated_vram_gb: 4
|
estimated_vram_gb: 4
|
||||||
default_language: "en"
|
default_language: "en"
|
||||||
|
|
||||||
chatterbox-turbo:
|
|
||||||
type: tts
|
|
||||||
backend: chatterbox
|
|
||||||
variant: "turbo"
|
|
||||||
estimated_vram_gb: 2
|
|
||||||
|
|
||||||
chatterbox-multilingual:
|
chatterbox-multilingual:
|
||||||
type: tts
|
type: tts
|
||||||
backend: chatterbox
|
backend: chatterbox
|
||||||
@@ -110,8 +104,6 @@ virtual_models:
|
|||||||
|
|
||||||
cohere-transcribe:
|
cohere-transcribe:
|
||||||
physical: cohere-transcribe
|
physical: cohere-transcribe
|
||||||
Chatterbox-Turbo:
|
|
||||||
physical: chatterbox-turbo
|
|
||||||
Chatterbox-Multilingual:
|
Chatterbox-Multilingual:
|
||||||
physical: chatterbox-multilingual
|
physical: chatterbox-multilingual
|
||||||
Chatterbox:
|
Chatterbox:
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import gc
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@@ -24,25 +25,26 @@ class ChatterboxTTSBackend(BaseBackend):
|
|||||||
logger.info(f"Loading Chatterbox {variant} to {device}")
|
logger.info(f"Loading Chatterbox {variant} to {device}")
|
||||||
|
|
||||||
def _load():
|
def _load():
|
||||||
from chatterbox.tts import ChatterboxTTS
|
if variant == "multilingual":
|
||||||
if variant == "turbo":
|
from chatterbox import ChatterboxMultilingualTTS
|
||||||
model = ChatterboxTTS.from_pretrained(device=device, variant="turbo")
|
return ChatterboxMultilingualTTS.from_pretrained(device=device)
|
||||||
elif variant == "multilingual":
|
|
||||||
model = ChatterboxTTS.from_pretrained(device=device, variant="multilingual")
|
|
||||||
else:
|
else:
|
||||||
model = ChatterboxTTS.from_pretrained(device=device)
|
from chatterbox.tts import ChatterboxTTS
|
||||||
return model
|
return ChatterboxTTS.from_pretrained(device=device)
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
model = await loop.run_in_executor(None, _load)
|
model = await loop.run_in_executor(None, _load)
|
||||||
self._loaded[model_id] = {"model": model, "device": device}
|
self._loaded[model_id] = {"model": model, "variant": variant, "device": device}
|
||||||
|
|
||||||
async def unload(self, model_id: str) -> None:
|
async def unload(self, model_id: str) -> None:
|
||||||
if model_id not in self._loaded:
|
if model_id not in self._loaded:
|
||||||
return
|
return
|
||||||
entry = self._loaded.pop(model_id)
|
entry = self._loaded.pop(model_id)
|
||||||
del entry["model"]
|
del entry["model"]
|
||||||
|
del entry
|
||||||
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
logger.info(f"Unloaded Chatterbox {model_id}")
|
||||||
|
|
||||||
async def generate(self, model_id, messages, params, stream=False, tools=None):
|
async def generate(self, model_id, messages, params, stream=False, tools=None):
|
||||||
raise NotImplementedError("TTS backend does not support chat generation")
|
raise NotImplementedError("TTS backend does not support chat generation")
|
||||||
@@ -50,9 +52,15 @@ class ChatterboxTTSBackend(BaseBackend):
|
|||||||
async def synthesize(self, model_id: str, text: str, voice: str = "default") -> bytes:
|
async def synthesize(self, model_id: str, text: str, voice: str = "default") -> bytes:
|
||||||
entry = self._loaded[model_id]
|
entry = self._loaded[model_id]
|
||||||
model = entry["model"]
|
model = entry["model"]
|
||||||
|
variant = entry["variant"]
|
||||||
|
|
||||||
def _synthesize():
|
def _synthesize():
|
||||||
wav = model.generate(text)
|
if variant == "multilingual":
|
||||||
|
# Default to English; any non-default voice value is passed through as the language_id
|
||||||
|
lang = "en" if voice == "default" else voice
|
||||||
|
wav = model.generate(text, language_id=lang)
|
||||||
|
else:
|
||||||
|
wav = model.generate(text)
|
||||||
buf = io.BytesIO()
|
buf = io.BytesIO()
|
||||||
sf.write(buf, wav.cpu().numpy().squeeze(), samplerate=24000, format="WAV")
|
sf.write(buf, wav.cpu().numpy().squeeze(), samplerate=24000, format="WAV")
|
||||||
buf.seek(0)
|
buf.seek(0)
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ def test_load_models_config_returns_physical_and_virtual():
|
|||||||
physical, virtual = load_models_config()
|
physical, virtual = load_models_config()
|
||||||
assert isinstance(physical, dict)
|
assert isinstance(physical, dict)
|
||||||
assert isinstance(virtual, dict)
|
assert isinstance(virtual, dict)
|
||||||
assert len(physical) == 9
|
assert len(physical) == 8
|
||||||
assert len(virtual) == 16
|
assert len(virtual) == 15
|
||||||
|
|
||||||
|
|
||||||
def test_physical_model_has_required_fields():
|
def test_physical_model_has_required_fields():
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ def registry():
|
|||||||
|
|
||||||
def test_list_virtual_models(registry):
|
def test_list_virtual_models(registry):
|
||||||
models = registry.list_virtual_models()
|
models = registry.list_virtual_models()
|
||||||
assert len(models) == 16
|
assert len(models) == 15
|
||||||
names = [m["id"] for m in models]
|
names = [m["id"] for m in models]
|
||||||
assert "Qwen3.5-9B-FP8-Thinking" in names
|
assert "Qwen3.5-9B-FP8-Thinking" in names
|
||||||
assert "GPT-OSS-20B-High" in names
|
assert "GPT-OSS-20B-High" in names
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ def test_list_models_returns_16(client, auth_headers):
|
|||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
body = resp.json()
|
body = resp.json()
|
||||||
assert body["object"] == "list"
|
assert body["object"] == "list"
|
||||||
assert len(body["data"]) == 16
|
assert len(body["data"]) == 15
|
||||||
|
|
||||||
|
|
||||||
def test_list_models_contains_expected_names(client, auth_headers):
|
def test_list_models_contains_expected_names(client, auth_headers):
|
||||||
|
|||||||
Reference in New Issue
Block a user