# Source: DesTEngSsv006_swd/kischdle/llmux/llmux/backends/transformers_asr.py
"""Transformers-based ASR (speech-to-text) backend for llmux."""
import asyncio
import logging
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from llmux.backends.base import BaseBackend
from llmux.config import PhysicalModel

# Module-level logger named after this module, per logging convention.
logger = logging.getLogger(__name__)
class TransformersASRBackend(BaseBackend):
    """Backend serving HuggingFace seq2seq speech-recognition models.

    Models are cached per logical ``model_id`` in ``self._loaded``.  All
    blocking work (weight loading, inference) is pushed onto the default
    thread-pool executor so the event loop is never blocked.
    """

    def __init__(self, models_dir: str = "/models"):
        # Directory used as the HuggingFace cache for downloaded weights.
        self._models_dir = models_dir
        # model_id -> {"model": ..., "processor": ..., "device": ...}
        self._loaded: dict[str, dict] = {}

    async def load(self, model_id: str, device: str = "cuda") -> None:
        """Load *model_id* onto *device*; no-op if it is already loaded.

        Raises:
            KeyError: if *model_id* has no registered physical config.
        """
        if model_id in self._loaded:
            return
        physical = _get_physical_config(model_id)
        hf_id = physical.model_id
        # Lazy %-style args: no string formatting when INFO is disabled.
        logger.info("Loading ASR model %s to %s", hf_id, device)

        def _load():
            # Blocking download/deserialization — runs off the event loop.
            processor = AutoProcessor.from_pretrained(
                hf_id, cache_dir=self._models_dir, trust_remote_code=True
            )
            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                hf_id,
                cache_dir=self._models_dir,
                torch_dtype="auto",
                device_map=device,
                trust_remote_code=True,
            )
            return model, processor

        # get_running_loop() is the supported call inside a coroutine;
        # get_event_loop() here has been deprecated since Python 3.10.
        loop = asyncio.get_running_loop()
        model, processor = await loop.run_in_executor(None, _load)
        self._loaded[model_id] = {"model": model, "processor": processor, "device": device}

    async def unload(self, model_id: str) -> None:
        """Drop *model_id* from the cache and release its memory; no-op if absent."""
        if model_id not in self._loaded:
            return
        entry = self._loaded.pop(model_id)
        del entry["model"]
        del entry["processor"]
        # Only poke the CUDA allocator when CUDA exists — unload() must also
        # work for models loaded on CPU or on a CUDA-less host.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    async def generate(self, model_id, messages, params, stream=False, tools=None):
        """ASR models cannot chat; always raises NotImplementedError."""
        raise NotImplementedError("ASR backend does not support chat generation")

    async def transcribe(self, model_id: str, audio_data: bytes, language: str = "en") -> dict:
        """Transcribe *audio_data* (an encoded audio file) with *model_id*.

        Returns:
            dict: ``{"text": <transcription string>}``.

        Raises:
            KeyError: if *model_id* has not been loaded via :meth:`load`.
        """
        # Local imports: soundfile is an optional, transcribe-only dependency.
        import io
        import soundfile as sf

        if model_id not in self._loaded:
            # Same exception type as the previous bare dict lookup, but with
            # an actionable message.
            raise KeyError(f"Model {model_id!r} is not loaded; call load() first")
        entry = self._loaded[model_id]
        model = entry["model"]
        processor = entry["processor"]

        def _transcribe():
            # soundfile decodes any libsndfile-supported container in memory.
            audio_array, sample_rate = sf.read(io.BytesIO(audio_data))
            # NOTE(review): forwarding language= to the processor assumes a
            # processor (e.g. Whisper's) that accepts that kwarg — confirm;
            # some models expect the language hint on model.generate() instead.
            inputs = processor(
                audio_array, sampling_rate=sample_rate, return_tensors="pt", language=language
            ).to(model.device)
            with torch.no_grad():
                predicted_ids = model.generate(**inputs)
            return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

        loop = asyncio.get_running_loop()
        text = await loop.run_in_executor(None, _transcribe)
        return {"text": text}
# Registry mapping logical model ids to their physical (HuggingFace) config.
# Populated once at startup via set_physical_models().
_physical_models: dict[str, PhysicalModel] = {}


def set_physical_models(models: dict[str, PhysicalModel]) -> None:
    """Replace the module-level physical-model registry wholesale."""
    global _physical_models
    _physical_models = models


def _get_physical_config(model_id: str) -> PhysicalModel:
    """Look up the physical config registered for *model_id*.

    Raises:
        KeyError: listing the known ids, if *model_id* is unregistered —
            an actionable message instead of a bare dict KeyError.
    """
    try:
        return _physical_models[model_id]
    except KeyError:
        raise KeyError(
            f"Unknown model {model_id!r}; known models: {sorted(_physical_models)}"
        ) from None