feat: transformers ASR backend for cohere-transcribe

This commit is contained in:
tlg
2026-04-04 09:40:39 +02:00
parent 449e37d318
commit de25b5e2a7
2 changed files with 207 additions and 0 deletions

View File

@@ -0,0 +1,73 @@
import asyncio
import logging
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from llmux.backends.base import BaseBackend
from llmux.config import PhysicalModel
logger = logging.getLogger(__name__)
class TransformersASRBackend(BaseBackend):
    """Speech-to-text backend built on Hugging Face ``transformers``.

    Models are loaded lazily into ``self._loaded`` keyed by the llmux model
    id. All heavy transformers calls run in the default executor so the
    event loop is never blocked.
    """

    def __init__(self, models_dir: str = "/models"):
        # Directory used as the HF cache (passed as ``cache_dir``).
        self._models_dir = models_dir
        # model_id -> {"model", "processor", "device"}
        self._loaded: dict[str, dict] = {}

    async def load(self, model_id: str, device: str = "cuda") -> None:
        """Load the model behind *model_id* onto *device*. Idempotent."""
        if model_id in self._loaded:
            return
        physical = _get_physical_config(model_id)
        hf_id = physical.model_id
        # Lazy %-args so the message is only formatted when emitted.
        logger.info("Loading ASR model %s to %s", hf_id, device)

        def _load():
            # trust_remote_code executes code shipped with the checkpoint;
            # NOTE(review): acceptable only for operator-curated model lists.
            processor = AutoProcessor.from_pretrained(
                hf_id, cache_dir=self._models_dir, trust_remote_code=True
            )
            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                hf_id,
                cache_dir=self._models_dir,
                torch_dtype="auto",
                device_map=device,
                trust_remote_code=True,
            )
            return model, processor

        # get_running_loop() is the supported call inside a coroutine;
        # get_event_loop() is deprecated for this use since Python 3.10.
        loop = asyncio.get_running_loop()
        model, processor = await loop.run_in_executor(None, _load)
        self._loaded[model_id] = {"model": model, "processor": processor, "device": device}

    async def unload(self, model_id: str) -> None:
        """Drop all references to *model_id* and release cached GPU memory."""
        if model_id not in self._loaded:
            return
        entry = self._loaded.pop(model_id)
        del entry["model"]
        del entry["processor"]
        # Guard: empty_cache() must not be called on CUDA-less builds,
        # otherwise CPU-only deployments cannot unload.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    async def generate(self, model_id, messages, params, stream=False, tools=None):
        """ASR models cannot chat; always raises NotImplementedError."""
        raise NotImplementedError("ASR backend does not support chat generation")

    async def transcribe(self, model_id: str, audio_data: bytes, language: str = "en") -> dict:
        """Transcribe *audio_data* (any soundfile-readable container).

        Returns ``{"text": <transcription>}``.

        Raises:
            KeyError: with a descriptive message when *model_id* was never
                loaded (previously a bare, cryptic KeyError).
        """
        import io
        import soundfile as sf

        if model_id not in self._loaded:
            raise KeyError(f"ASR model {model_id!r} is not loaded; call load() first")
        entry = self._loaded[model_id]
        model = entry["model"]
        processor = entry["processor"]

        def _transcribe():
            audio_array, sample_rate = sf.read(io.BytesIO(audio_data))
            # NOTE(review): ``language=`` is accepted by Whisper-style
            # processors; confirm for other checkpoint families.
            inputs = processor(
                audio_array, sampling_rate=sample_rate, return_tensors="pt", language=language
            ).to(model.device)
            with torch.no_grad():
                predicted_ids = model.generate(**inputs)
            return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

        loop = asyncio.get_running_loop()
        text = await loop.run_in_executor(None, _transcribe)
        return {"text": text}
# Registry of physical model configs, injected at startup via
# set_physical_models(); keeps this module decoupled from config loading.
_physical_models: dict[str, PhysicalModel] = {}


def set_physical_models(models: dict[str, PhysicalModel]) -> None:
    """Replace the module-level physical-model registry wholesale."""
    global _physical_models
    _physical_models = models


def _get_physical_config(model_id: str) -> PhysicalModel:
    """Return the PhysicalModel for *model_id*.

    Raises:
        KeyError: with the unknown id and the known ids, instead of the
            bare KeyError a plain dict lookup would produce.
    """
    try:
        return _physical_models[model_id]
    except KeyError:
        raise KeyError(
            f"Unknown physical model {model_id!r}; known: {sorted(_physical_models)}"
        ) from None

View File

@@ -0,0 +1,134 @@
import asyncio
import json
import logging
import time
import uuid
from typing import AsyncIterator
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, TextIteratorStreamer
from threading import Thread
from llmux.backends.base import BaseBackend
from llmux.config import PhysicalModel
logger = logging.getLogger(__name__)
class TransformersLLMBackend(BaseBackend):
    """Chat/completions backend built on Hugging Face ``transformers``.

    Supports blocking and SSE-style streaming generation. Heavy model work
    runs in the default executor (load, full generate) or a dedicated
    thread (streaming) so the event loop stays responsive.
    """

    def __init__(self, models_dir: str = "/models"):
        # Directory used as the HF cache (passed as ``cache_dir``).
        self._models_dir = models_dir
        # model_id -> {"model", "tokenizer", "processor", "device"}
        self._loaded: dict[str, dict] = {}

    async def load(self, model_id: str, device: str = "cuda") -> None:
        """Load the model behind *model_id* onto *device*. Idempotent."""
        if model_id in self._loaded:
            return
        physical = _get_physical_config(model_id)
        hf_id = physical.model_id
        logger.info("Loading transformers model %s to %s", hf_id, device)

        def _load():
            tokenizer = AutoTokenizer.from_pretrained(
                hf_id, cache_dir=self._models_dir, trust_remote_code=True
            )
            model = AutoModelForCausalLM.from_pretrained(
                hf_id,
                cache_dir=self._models_dir,
                torch_dtype="auto",
                device_map=device,
                trust_remote_code=True,
            )
            processor = None
            if physical.supports_vision:
                # Vision is best-effort: a missing processor downgrades to
                # text-only instead of failing the whole load.
                try:
                    processor = AutoProcessor.from_pretrained(
                        hf_id, cache_dir=self._models_dir, trust_remote_code=True
                    )
                except Exception:
                    logger.warning("No processor found for %s, vision disabled", hf_id)
            return model, tokenizer, processor

        # get_running_loop() is the supported call inside a coroutine;
        # get_event_loop() is deprecated for this use since Python 3.10.
        loop = asyncio.get_running_loop()
        model, tokenizer, processor = await loop.run_in_executor(None, _load)
        self._loaded[model_id] = {
            "model": model,
            "tokenizer": tokenizer,
            "processor": processor,
            "device": device,
        }

    async def unload(self, model_id: str) -> None:
        """Drop all references to *model_id* and release cached GPU memory."""
        if model_id not in self._loaded:
            return
        entry = self._loaded.pop(model_id)
        del entry["model"]
        del entry["tokenizer"]
        if entry.get("processor"):
            del entry["processor"]
        # Guard: empty_cache() must not run on CUDA-less builds.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    async def generate(self, model_id, messages, params, stream=False, tools=None):
        """Run chat generation for *model_id*.

        Returns an OpenAI-style completion dict, or an async iterator of
        SSE chunk strings when ``stream=True``.

        ``params`` may carry virtual-model knobs: ``enable_thinking``
        (forwarded to the chat template), ``system_prompt_prefix``
        (prepended to / inserted as the system message), and
        ``max_new_tokens`` (defaults to 4096).
        """
        entry = self._loaded[model_id]
        model = entry["model"]
        tokenizer = entry["tokenizer"]

        chat_params = {}
        if "enable_thinking" in params:
            chat_params["enable_thinking"] = params["enable_thinking"]

        # Inject system prompt prefix (e.g. gpt-oss reasoning levels)
        # without mutating the caller's message list.
        effective_messages = list(messages)
        if "system_prompt_prefix" in params:
            prefix = params["system_prompt_prefix"]
            if effective_messages and effective_messages[0].get("role") == "system":
                effective_messages[0] = dict(effective_messages[0])
                effective_messages[0]["content"] = prefix + "\n\n" + effective_messages[0]["content"]
            else:
                effective_messages.insert(0, {"role": "system", "content": prefix})

        text = tokenizer.apply_chat_template(
            effective_messages,
            tokenize=False,
            add_generation_prompt=True,
            tools=tools,
            **chat_params,
        )
        inputs = tokenizer(text, return_tensors="pt").to(model.device)
        max_new_tokens = params.get("max_new_tokens", 4096)
        if stream:
            # BUG FIX: the original returned the un-awaited coroutine, so
            # callers of generate() received a coroutine instead of the
            # async iterator (the non-stream branch was already awaited).
            return await self._stream_generate(model, tokenizer, inputs, model_id, max_new_tokens)
        return await self._full_generate(model, tokenizer, inputs, model_id, max_new_tokens)

    async def _full_generate(self, model, tokenizer, inputs, model_id, max_new_tokens=4096):
        """Blocking generation in the executor; returns a completion dict."""

        def _run():
            with torch.no_grad():
                output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
            # Strip the prompt tokens; decode only the newly generated tail.
            new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
            return tokenizer.decode(new_tokens, skip_special_tokens=True)

        loop = asyncio.get_running_loop()
        text = await loop.run_in_executor(None, _run)
        return {
            "id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model_id,
            "choices": [{"index": 0, "message": {"role": "assistant", "content": text}, "finish_reason": "stop"}],
            # Token accounting not wired up yet; zeros keep the schema valid.
            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        }

    async def _stream_generate(self, model, tokenizer, inputs, model_id, max_new_tokens=4096):
        """Start generation in a thread and return an async SSE-chunk iterator."""
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs = {**inputs, "max_new_tokens": max_new_tokens, "streamer": streamer}
        # Daemon thread: an abandoned stream must not block interpreter exit.
        thread = Thread(target=lambda: model.generate(**gen_kwargs), daemon=True)
        thread.start()
        chat_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
        created = int(time.time())

        async def _iter():
            loop = asyncio.get_running_loop()
            while True:
                # next() on the streamer blocks on the generation thread,
                # so it has to run in the executor.
                token = await loop.run_in_executor(None, lambda: next(streamer, None))
                if token is None:
                    chunk = {
                        "id": chat_id,
                        "object": "chat.completion.chunk",
                        "created": created,
                        "model": model_id,
                        "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
                    }
                    yield f"data: {json.dumps(chunk)}\n\n"
                    yield "data: [DONE]\n\n"
                    break
                chunk = {
                    "id": chat_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model_id,
                    "choices": [{"index": 0, "delta": {"content": token}, "finish_reason": None}],
                }
                yield f"data: {json.dumps(chunk)}\n\n"
            # BUG FIX: a bare thread.join() here blocked the event loop;
            # join in the executor instead.
            await loop.run_in_executor(None, thread.join)

        return _iter()
# Physical model config injection: the registry is populated at startup
# via set_physical_models(); keeps this module decoupled from config loading.
_physical_models: dict[str, PhysicalModel] = {}


def set_physical_models(models: dict[str, PhysicalModel]) -> None:
    """Replace the module-level physical-model registry wholesale."""
    global _physical_models
    _physical_models = models


def _get_physical_config(model_id: str) -> PhysicalModel:
    """Return the PhysicalModel for *model_id*.

    Raises:
        KeyError: with the unknown id and the known ids, instead of the
            bare KeyError a plain dict lookup would produce.
    """
    try:
        return _physical_models[model_id]
    except KeyError:
        raise KeyError(
            f"Unknown physical model {model_id!r}; known: {sorted(_physical_models)}"
        ) from None