From de25b5e2a73215dcee698bd30b95cc7ad6e7d268da103c0795201617483318e9 Mon Sep 17 00:00:00 2001 From: tlg Date: Sat, 4 Apr 2026 09:40:39 +0200 Subject: [PATCH] feat: transformers ASR backend for cohere-transcribe --- .../llmux/llmux/backends/transformers_asr.py | 73 ++++++++++ .../llmux/llmux/backends/transformers_llm.py | 134 ++++++++++++++++++ 2 files changed, 207 insertions(+) create mode 100644 kischdle/llmux/llmux/backends/transformers_asr.py create mode 100644 kischdle/llmux/llmux/backends/transformers_llm.py diff --git a/kischdle/llmux/llmux/backends/transformers_asr.py b/kischdle/llmux/llmux/backends/transformers_asr.py new file mode 100644 index 0000000..ae0cdff --- /dev/null +++ b/kischdle/llmux/llmux/backends/transformers_asr.py @@ -0,0 +1,73 @@ +import asyncio +import logging + +import torch +from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor + +from llmux.backends.base import BaseBackend +from llmux.config import PhysicalModel + +logger = logging.getLogger(__name__) + + +class TransformersASRBackend(BaseBackend): + def __init__(self, models_dir: str = "/models"): + self._models_dir = models_dir + self._loaded: dict[str, dict] = {} + + async def load(self, model_id: str, device: str = "cuda") -> None: + if model_id in self._loaded: + return + physical = _get_physical_config(model_id) + hf_id = physical.model_id + logger.info(f"Loading ASR model {hf_id} to {device}") + + def _load(): + processor = AutoProcessor.from_pretrained(hf_id, cache_dir=self._models_dir, trust_remote_code=True) + model = AutoModelForSpeechSeq2Seq.from_pretrained(hf_id, cache_dir=self._models_dir, torch_dtype="auto", device_map=device, trust_remote_code=True) + return model, processor + + loop = asyncio.get_event_loop() + model, processor = await loop.run_in_executor(None, _load) + self._loaded[model_id] = {"model": model, "processor": processor, "device": device} + + async def unload(self, model_id: str) -> None: + if model_id not in self._loaded: + return + entry = self._loaded.pop(model_id) + del entry["model"] + del entry["processor"] + torch.cuda.empty_cache() + + async def generate(self, model_id, messages, params, stream=False, tools=None): + raise NotImplementedError("ASR backend does not support chat generation") + + async def transcribe(self, model_id: str, audio_data: bytes, language: str = "en") -> dict: + import io + import soundfile as sf + + entry = self._loaded[model_id] + model = entry["model"] + processor = entry["processor"] + + def _transcribe(): + audio_array, sample_rate = sf.read(io.BytesIO(audio_data)) + inputs = processor(audio_array, sampling_rate=sample_rate, return_tensors="pt", language=language).to(model.device) + with torch.no_grad(): + predicted_ids = model.generate(**inputs) + transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0] + return transcription + + loop = asyncio.get_event_loop() + text = await loop.run_in_executor(None, _transcribe) + return {"text": text} + + +_physical_models: dict[str, PhysicalModel] = {} + +def set_physical_models(models: dict[str, PhysicalModel]) -> None: + global _physical_models + _physical_models = models + +def _get_physical_config(model_id: str) -> PhysicalModel: + return _physical_models[model_id] diff --git a/kischdle/llmux/llmux/backends/transformers_llm.py b/kischdle/llmux/llmux/backends/transformers_llm.py new file mode 100644 index 0000000..ca1814f --- /dev/null +++ b/kischdle/llmux/llmux/backends/transformers_llm.py @@ -0,0 +1,134 @@ +import asyncio +import json +import logging +import time +import uuid +from typing import AsyncIterator + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, TextIteratorStreamer +from threading import Thread + +from llmux.backends.base import BaseBackend +from llmux.config import PhysicalModel + +logger = logging.getLogger(__name__) + + +class TransformersLLMBackend(BaseBackend): + def __init__(self, models_dir: str = "/models"): + self._models_dir = models_dir + self._loaded: dict[str, dict] = {} # model_id -> {"model", "tokenizer", "processor"} + + async def load(self, model_id: str, device: str = "cuda") -> None: + if model_id in self._loaded: + return + physical = _get_physical_config(model_id) + hf_id = physical.model_id + logger.info(f"Loading transformers model {hf_id} to {device}") + + def _load(): + tokenizer = AutoTokenizer.from_pretrained(hf_id, cache_dir=self._models_dir, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained(hf_id, cache_dir=self._models_dir, torch_dtype="auto", device_map=device, trust_remote_code=True) + processor = None + if physical.supports_vision: + try: + processor = AutoProcessor.from_pretrained(hf_id, cache_dir=self._models_dir, trust_remote_code=True) + except Exception: + logger.warning(f"No processor found for {hf_id}, vision disabled") + return model, tokenizer, processor + + loop = asyncio.get_event_loop() + model, tokenizer, processor = await loop.run_in_executor(None, _load) + self._loaded[model_id] = {"model": model, "tokenizer": tokenizer, "processor": processor, "device": device} + + async def unload(self, model_id: str) -> None: + if model_id not in self._loaded: + return + entry = self._loaded.pop(model_id) + del entry["model"] + del entry["tokenizer"] + if entry.get("processor"): + del entry["processor"] + torch.cuda.empty_cache() + + async def generate(self, model_id, messages, params, stream=False, tools=None): + entry = self._loaded[model_id] + model = entry["model"] + tokenizer = entry["tokenizer"] + + # Apply virtual model params + chat_params = {} + if "enable_thinking" in params: + chat_params["enable_thinking"] = params["enable_thinking"] + + # Inject system prompt prefix for gpt-oss reasoning levels + effective_messages = list(messages) + if "system_prompt_prefix" in params: + prefix = params["system_prompt_prefix"] + if effective_messages and effective_messages[0].get("role") == "system": + effective_messages[0] = dict(effective_messages[0]) + effective_messages[0]["content"] = prefix + "\n\n" + effective_messages[0]["content"] + else: + effective_messages.insert(0, {"role": "system", "content": prefix}) + + text = tokenizer.apply_chat_template(effective_messages, tokenize=False, add_generation_prompt=True, tools=tools, **chat_params) + inputs = tokenizer(text, return_tensors="pt").to(model.device) + + if stream: + return self._stream_generate(model, tokenizer, inputs, model_id) + else: + return await self._full_generate(model, tokenizer, inputs, model_id) + + async def _full_generate(self, model, tokenizer, inputs, model_id): + def _run(): + with torch.no_grad(): + output_ids = model.generate(**inputs, max_new_tokens=4096) + new_tokens = output_ids[0][inputs["input_ids"].shape[1]:] + return tokenizer.decode(new_tokens, skip_special_tokens=True) + + loop = asyncio.get_event_loop() + text = await loop.run_in_executor(None, _run) + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:12]}", + "object": "chat.completion", + "created": int(time.time()), + "model": model_id, + "choices": [{"index": 0, "message": {"role": "assistant", "content": text}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + } + + async def _stream_generate(self, model, tokenizer, inputs, model_id): + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) + gen_kwargs = {**inputs, "max_new_tokens": 4096, "streamer": streamer} + thread = Thread(target=lambda: model.generate(**gen_kwargs)) + thread.start() + + chat_id = f"chatcmpl-{uuid.uuid4().hex[:12]}" + created = int(time.time()) + + async def _iter(): + loop = asyncio.get_event_loop() + while True: + token = await loop.run_in_executor(None, lambda: next(streamer, None)) + if token is None: + chunk = {"id": chat_id, "object": "chat.completion.chunk", "created": created, "model": model_id, "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]} + yield f"data: {json.dumps(chunk)}\n\n" + yield "data: [DONE]\n\n" + break + chunk = {"id": chat_id, "object": "chat.completion.chunk", "created": created, "model": model_id, "choices": [{"index": 0, "delta": {"content": token}, "finish_reason": None}]} + yield f"data: {json.dumps(chunk)}\n\n" + thread.join() + + return _iter() + + +# Physical model config injection +_physical_models: dict[str, PhysicalModel] = {} + +def set_physical_models(models: dict[str, PhysicalModel]) -> None: + global _physical_models + _physical_models = models + +def _get_physical_config(model_id: str) -> PhysicalModel: + return _physical_models[model_id]