feat: transformers ASR backend for cohere-transcribe
This commit is contained in:
73
kischdle/llmux/llmux/backends/transformers_asr.py
Normal file
73
kischdle/llmux/llmux/backends/transformers_asr.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
|
||||||
|
|
||||||
|
from llmux.backends.base import BaseBackend
|
||||||
|
from llmux.config import PhysicalModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformersASRBackend(BaseBackend):
    """Backend that serves speech-to-text models via HuggingFace transformers.

    Models are loaded lazily with ``load`` and cached in ``self._loaded``
    keyed by the llmux model id. Blocking HuggingFace calls run in the
    default executor so the event loop stays responsive.
    """

    def __init__(self, models_dir: str = "/models"):
        # Directory used as the HF ``cache_dir`` for downloaded weights.
        self._models_dir = models_dir
        # model_id -> {"model", "processor", "device"}
        self._loaded: dict[str, dict] = {}

    async def load(self, model_id: str, device: str = "cuda") -> None:
        """Load *model_id* onto *device*; a no-op if it is already loaded.

        Raises:
            KeyError: if *model_id* has no registered physical config.
        """
        if model_id in self._loaded:
            return
        physical = _get_physical_config(model_id)
        hf_id = physical.model_id
        logger.info("Loading ASR model %s to %s", hf_id, device)

        def _load():
            processor = AutoProcessor.from_pretrained(
                hf_id, cache_dir=self._models_dir, trust_remote_code=True
            )
            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                hf_id,
                cache_dir=self._models_dir,
                torch_dtype="auto",
                device_map=device,
                trust_remote_code=True,
            )
            return model, processor

        # get_running_loop() is the supported API inside a coroutine;
        # get_event_loop() is deprecated in this context since Python 3.10.
        loop = asyncio.get_running_loop()
        model, processor = await loop.run_in_executor(None, _load)
        self._loaded[model_id] = {"model": model, "processor": processor, "device": device}

    async def unload(self, model_id: str) -> None:
        """Drop all references to *model_id* and release cached GPU memory."""
        if model_id not in self._loaded:
            return
        entry = self._loaded.pop(model_id)
        del entry["model"]
        del entry["processor"]
        # Return freed CUDA blocks to the driver so other tenants can use them.
        torch.cuda.empty_cache()

    async def generate(self, model_id, messages, params, stream=False, tools=None):
        """Chat generation is not supported by the ASR backend."""
        raise NotImplementedError("ASR backend does not support chat generation")

    async def transcribe(self, model_id: str, audio_data: bytes, language: str = "en") -> dict:
        """Transcribe *audio_data* (an encoded audio file) with *model_id*.

        Returns:
            ``{"text": <transcription>}``.

        Raises:
            KeyError: if *model_id* was never loaded via ``load``.
        """
        # Imported lazily so the module imports without soundfile installed.
        import io

        import soundfile as sf

        try:
            entry = self._loaded[model_id]
        except KeyError:
            raise KeyError(f"ASR model {model_id!r} is not loaded") from None
        model = entry["model"]
        processor = entry["processor"]

        def _transcribe():
            audio_array, sample_rate = sf.read(io.BytesIO(audio_data))
            # NOTE(review): passing ``language`` to the processor assumes a
            # Whisper-style processor signature — confirm for other ASR
            # architectures loaded through this backend.
            inputs = processor(
                audio_array, sampling_rate=sample_rate, return_tensors="pt", language=language
            ).to(model.device)
            with torch.no_grad():
                predicted_ids = model.generate(**inputs)
            return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

        loop = asyncio.get_running_loop()
        text = await loop.run_in_executor(None, _transcribe)
        return {"text": text}
|
||||||
|
|
||||||
|
|
||||||
|
# Registry of physical model configs, injected by the application at startup
# (the backend cannot import app config directly without a cycle).
_physical_models: dict[str, PhysicalModel] = {}


def set_physical_models(models: dict[str, PhysicalModel]) -> None:
    """Install the physical-model registry used by this backend module."""
    global _physical_models
    _physical_models = models


def _get_physical_config(model_id: str) -> PhysicalModel:
    """Return the physical config for *model_id*.

    Raises:
        KeyError: with a descriptive message when *model_id* is unknown.
    """
    try:
        return _physical_models[model_id]
    except KeyError:
        raise KeyError(f"No physical model config registered for {model_id!r}") from None
|
||||||
134
kischdle/llmux/llmux/backends/transformers_llm.py
Normal file
134
kischdle/llmux/llmux/backends/transformers_llm.py
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import AsyncIterator
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, TextIteratorStreamer
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
from llmux.backends.base import BaseBackend
|
||||||
|
from llmux.config import PhysicalModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformersLLMBackend(BaseBackend):
    """Backend that serves chat models via HuggingFace transformers.

    Models are loaded lazily with ``load`` and cached in ``self._loaded``;
    blocking HuggingFace work runs in the default executor so the event loop
    stays responsive. Responses use the OpenAI chat-completion wire format.
    """

    # Fallback generation budget when the request does not specify one.
    _DEFAULT_MAX_NEW_TOKENS = 4096

    def __init__(self, models_dir: str = "/models"):
        # Directory used as the HF ``cache_dir`` for downloaded weights.
        self._models_dir = models_dir
        # model_id -> {"model", "tokenizer", "processor", "device"}
        self._loaded: dict[str, dict] = {}

    async def load(self, model_id: str, device: str = "cuda") -> None:
        """Load *model_id* onto *device*; a no-op if it is already loaded.

        Raises:
            KeyError: if *model_id* has no registered physical config.
        """
        if model_id in self._loaded:
            return
        physical = _get_physical_config(model_id)
        hf_id = physical.model_id
        logger.info("Loading transformers model %s to %s", hf_id, device)

        def _load():
            tokenizer = AutoTokenizer.from_pretrained(
                hf_id, cache_dir=self._models_dir, trust_remote_code=True
            )
            model = AutoModelForCausalLM.from_pretrained(
                hf_id,
                cache_dir=self._models_dir,
                torch_dtype="auto",
                device_map=device,
                trust_remote_code=True,
            )
            processor = None
            if physical.supports_vision:
                try:
                    processor = AutoProcessor.from_pretrained(
                        hf_id, cache_dir=self._models_dir, trust_remote_code=True
                    )
                except Exception:
                    # Best-effort: vision is optional, text generation still works.
                    logger.warning("No processor found for %s, vision disabled", hf_id)
            return model, tokenizer, processor

        # get_running_loop() is the supported API inside a coroutine;
        # get_event_loop() is deprecated in this context since Python 3.10.
        loop = asyncio.get_running_loop()
        model, tokenizer, processor = await loop.run_in_executor(None, _load)
        self._loaded[model_id] = {
            "model": model,
            "tokenizer": tokenizer,
            "processor": processor,
            "device": device,
        }

    async def unload(self, model_id: str) -> None:
        """Drop all references to *model_id* and release cached GPU memory."""
        if model_id not in self._loaded:
            return
        entry = self._loaded.pop(model_id)
        del entry["model"]
        del entry["tokenizer"]
        if entry.get("processor"):
            del entry["processor"]
        torch.cuda.empty_cache()

    async def generate(self, model_id, messages, params, stream=False, tools=None):
        """Generate a chat completion for *messages* with *model_id*.

        Returns an OpenAI-style completion dict, or — when ``stream`` is
        true — a coroutine resolving to an async iterator of SSE lines.

        Raises:
            KeyError: if *model_id* was never loaded via ``load``.
        """
        entry = self._loaded[model_id]
        model = entry["model"]
        tokenizer = entry["tokenizer"]

        # Virtual-model parameters forwarded to the chat template.
        chat_params = {}
        if "enable_thinking" in params:
            chat_params["enable_thinking"] = params["enable_thinking"]

        effective_messages = self._apply_system_prefix(messages, params)

        text = tokenizer.apply_chat_template(
            effective_messages,
            tokenize=False,
            add_generation_prompt=True,
            tools=tools,
            **chat_params,
        )
        inputs = tokenizer(text, return_tensors="pt").to(model.device)

        # Honour the request's token budget when present (OpenAI-style
        # "max_tokens"); fall back to the previous hard-coded limit.
        max_new_tokens = params.get("max_tokens") or self._DEFAULT_MAX_NEW_TOKENS

        if stream:
            return self._stream_generate(model, tokenizer, inputs, model_id, max_new_tokens)
        return await self._full_generate(model, tokenizer, inputs, model_id, max_new_tokens)

    @staticmethod
    def _apply_system_prefix(messages, params):
        """Return *messages* with ``params["system_prompt_prefix"]`` applied.

        Used e.g. for gpt-oss reasoning levels. The prefix is prepended to an
        existing leading system message (string content only) or inserted as a
        new leading system message. The caller's list is never mutated.
        """
        effective = list(messages)
        if "system_prompt_prefix" not in params:
            return effective
        prefix = params["system_prompt_prefix"]
        if (
            effective
            and effective[0].get("role") == "system"
            and isinstance(effective[0].get("content"), str)
        ):
            head = dict(effective[0])
            head["content"] = prefix + "\n\n" + head["content"]
            effective[0] = head
        else:
            # Non-string (multimodal) system content is left untouched; the
            # prefix goes in as its own system message instead of crashing on
            # string concatenation.
            effective.insert(0, {"role": "system", "content": prefix})
        return effective

    async def _full_generate(self, model, tokenizer, inputs, model_id, max_new_tokens=4096):
        """Run blocking generation in the executor; wrap it OpenAI-style."""

        def _run():
            with torch.no_grad():
                output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
            # Everything past the prompt length is newly generated text.
            new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
            return tokenizer.decode(new_tokens, skip_special_tokens=True)

        loop = asyncio.get_running_loop()
        text = await loop.run_in_executor(None, _run)
        return {
            "id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model_id,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": text},
                    "finish_reason": "stop",
                }
            ],
            # TODO: report real token counts instead of zeros.
            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        }

    async def _stream_generate(self, model, tokenizer, inputs, model_id, max_new_tokens=4096):
        """Stream generation as OpenAI ``chat.completion.chunk`` SSE lines."""
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs = {**inputs, "max_new_tokens": max_new_tokens, "streamer": streamer}
        thread = Thread(target=lambda: model.generate(**gen_kwargs))
        thread.start()

        chat_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
        created = int(time.time())

        def _chunk(delta, finish_reason):
            payload = {
                "id": chat_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model_id,
                "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
            }
            return f"data: {json.dumps(payload)}\n\n"

        async def _iter():
            loop = asyncio.get_running_loop()
            try:
                while True:
                    # ``next`` blocks on the streamer's queue, so run it off-loop.
                    token = await loop.run_in_executor(None, lambda: next(streamer, None))
                    if token is None:
                        yield _chunk({}, "stop")
                        yield "data: [DONE]\n\n"
                        break
                    yield _chunk({"content": token}, None)
            finally:
                # Always reap the worker, even if the client abandons the stream.
                # NOTE(review): on abandonment this join waits for generation to
                # finish; consider a StoppingCriteria for early cancellation.
                thread.join()

        return _iter()
|
||||||
|
|
||||||
|
|
||||||
|
# Physical model config injection: the registry is populated by the
# application at startup so this module avoids importing app config directly.
_physical_models: dict[str, PhysicalModel] = {}


def set_physical_models(models: dict[str, PhysicalModel]) -> None:
    """Install the physical-model registry used by this backend module."""
    global _physical_models
    _physical_models = models


def _get_physical_config(model_id: str) -> PhysicalModel:
    """Return the physical config for *model_id*.

    Raises:
        KeyError: with a descriptive message when *model_id* is unknown.
    """
    try:
        return _physical_models[model_id]
    except KeyError:
        raise KeyError(f"No physical model config registered for {model_id!r}") from None
|
||||||
Reference in New Issue
Block a user