"""feat: VRAM manager with priority-based model eviction.

Tracks GPU VRAM usage (16 GB by default) and handles model loading and
unloading with priority-based eviction: LLM (lowest) -> TTS -> ASR
(highest, protected). Uses an asyncio Lock for concurrency safety.

Reconstructed from a line-wrapped git patch; this listing contains both
``kischdle/llmux/llmux/vram_manager.py`` and
``kischdle/llmux/tests/test_vram_manager.py``.
"""

# ---------------------------------------------------------------------------
# kischdle/llmux/llmux/vram_manager.py
# ---------------------------------------------------------------------------
import asyncio
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)

# Eviction tiers: lower rank is evicted first. ASR is the highest tier and
# is protected from eviction by lower-priority requesters.
_PRIORITY = {"llm": 0, "tts": 1, "asr": 2}

__all__ = ["ModelSlot", "VRAMManager"]


@dataclass
class ModelSlot:
    """A loaded model occupying a chunk of VRAM.

    ``backend`` is duck-typed: it must provide async ``load(model_id)`` and
    ``unload(model_id)`` coroutines (see FakeBackend in the tests).
    """

    model_id: str
    model_type: str  # one of "llm", "tts", "asr"
    vram_gb: float
    backend: object

    @staticmethod
    def priority_rank(model_type: str) -> int:
        """Return the eviction rank for *model_type* (KeyError if unknown)."""
        return _PRIORITY[model_type]

    @property
    def priority(self) -> int:
        """Eviction rank of this slot (lower = evicted first)."""
        return _PRIORITY[self.model_type]


class VRAMManager:
    """Tracks a VRAM budget and loads/evicts models under an asyncio lock.

    Eviction policy: free VRAM by unloading models in ascending priority
    order (LLM first). The highest tier (ASR) is never evicted on behalf of
    a lower-priority request; same-priority replacement is always allowed
    (e.g. an old LLM is evicted for a new LLM).
    """

    def __init__(self, total_vram_gb: float = 16.0):
        self._total_vram_gb = total_vram_gb
        self._loaded: dict[str, ModelSlot] = {}
        # Serializes the whole check-evict-load sequence so concurrent
        # load_model() calls cannot both claim the same free VRAM.
        self._lock = asyncio.Lock()

    @property
    def available_vram_gb(self) -> float:
        """VRAM (GB) not currently claimed by loaded models."""
        used = sum(slot.vram_gb for slot in self._loaded.values())
        return self._total_vram_gb - used

    def is_loaded(self, model_id: str) -> bool:
        """Return True if *model_id* is currently resident."""
        return model_id in self._loaded

    def get_loaded_models(self) -> dict[str, ModelSlot]:
        """Return a snapshot copy of the loaded-model table."""
        return dict(self._loaded)

    async def load_model(self, model_id, model_type, vram_gb, backend):
        """Load *model_id*, evicting lower-priority models if VRAM is short.

        Idempotent for an already-loaded model_id. Raises RuntimeError when
        enough VRAM cannot be freed without violating the eviction policy.
        """
        async with self._lock:
            await self._load_model_locked(model_id, model_type, vram_gb, backend)

    async def _load_model_locked(self, model_id, model_type, vram_gb, backend):
        # Must be called with self._lock held.
        if model_id in self._loaded:
            return  # already resident: no-op (backend.load not re-invoked)

        if self.available_vram_gb < vram_gb:
            await self._evict_for(vram_gb, model_type)

        if self.available_vram_gb < vram_gb:
            raise RuntimeError(
                f"Cannot free enough VRAM for {model_id} "
                f"(need {vram_gb}GB, available {self.available_vram_gb}GB)"
            )

        logger.info("Loading %s (%sGB VRAM)", model_id, vram_gb)
        await backend.load(model_id)
        self._loaded[model_id] = ModelSlot(
            model_id=model_id,
            model_type=model_type,
            vram_gb=vram_gb,
            backend=backend,
        )

    async def _evict_for(self, needed_gb, requesting_type):
        """Free VRAM for a *requesting_type* load needing *needed_gb* GB.

        Evicts in priority order, lowest first (LLM=0, TTS=1, ASR=2).
        Rule: never evict the highest-priority tier (ASR) for a
        lower-priority request. Same-priority replacement is allowed, and
        lower-priority models are fair game for any requester — cascade
        through them until enough VRAM is freed.
        """
        requesting_priority = _PRIORITY[requesting_type]

        # Eligible victims, cheapest-to-evict first. The model_type check
        # protects the ASR tier from eviction by TTS/LLM requests.
        candidates = [
            slot
            for slot in sorted(self._loaded.values(), key=lambda s: s.priority)
            if not (slot.priority > requesting_priority and slot.model_type == "asr")
        ]

        # BUGFIX: bail out *before* evicting anything when the request can
        # never be satisfied. The original evicted every eligible model and
        # then the caller raised anyway, leaving the system gratuitously
        # degraded (models unloaded for nothing).
        freeable = sum(slot.vram_gb for slot in candidates)
        if self.available_vram_gb + freeable < needed_gb:
            return  # caller raises RuntimeError; loaded set left intact

        for slot in candidates:
            if self.available_vram_gb >= needed_gb:
                break
            logger.info(
                f"Evicting {slot.model_id} ({slot.model_type}, {slot.vram_gb}GB)"
            )
            await slot.backend.unload(slot.model_id)
            del self._loaded[slot.model_id]


# ---------------------------------------------------------------------------
# kischdle/llmux/tests/test_vram_manager.py
# ---------------------------------------------------------------------------
import pytest

# In the real package layout this file does:
#   from llmux.vram_manager import VRAMManager, ModelSlot
# Both classes are defined above in this reconstructed listing.


class FakeBackend:
    """Simulates a backend that tracks load/unload calls."""

    def __init__(self):
        self.loaded = {}
        self.load_count = 0
        self.unload_count = 0

    async def load(self, model_id: str):
        self.loaded[model_id] = True
        self.load_count += 1

    async def unload(self, model_id: str):
        self.loaded.pop(model_id, None)
        self.unload_count += 1


@pytest.fixture
def manager():
    return VRAMManager(total_vram_gb=16.0)


def test_priority_ordering():
    assert ModelSlot.priority_rank("llm") == 0
    assert ModelSlot.priority_rank("tts") == 1
    assert ModelSlot.priority_rank("asr") == 2


@pytest.mark.asyncio
async def test_load_into_empty_vram(manager):
    backend = FakeBackend()
    await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=backend)
    assert manager.is_loaded("qwen3.5-4b")
    assert manager.available_vram_gb == pytest.approx(12.0)


@pytest.mark.asyncio
async def test_load_alongside_when_fits(manager):
    backend = FakeBackend()
    await manager.load_model("cohere-transcribe", model_type="asr", vram_gb=4.0, backend=backend)
    await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=backend)
    assert manager.is_loaded("cohere-transcribe")
    assert manager.is_loaded("qwen3.5-4b")
    assert manager.available_vram_gb == pytest.approx(8.0)


@pytest.mark.asyncio
async def test_evict_llm_first(manager):
    backend = FakeBackend()
    await manager.load_model("cohere-transcribe", model_type="asr", vram_gb=4.0, backend=backend)
    await manager.load_model("chatterbox-multilingual", model_type="tts", vram_gb=2.0, backend=backend)
    await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=backend)
    # 10 GB used. Loading 9B (9GB). Evict LLM (4B), free=12. ASR+TTS+9B=15, fits.
    await manager.load_model("qwen3.5-9b-fp8", model_type="llm", vram_gb=9.0, backend=backend)
    assert not manager.is_loaded("qwen3.5-4b")
    assert manager.is_loaded("cohere-transcribe")
    assert manager.is_loaded("chatterbox-multilingual")
    assert manager.is_loaded("qwen3.5-9b-fp8")


@pytest.mark.asyncio
async def test_evict_cascade_for_large_llm(manager):
    backend = FakeBackend()
    await manager.load_model("cohere-transcribe", model_type="asr", vram_gb=4.0, backend=backend)
    await manager.load_model("chatterbox-multilingual", model_type="tts", vram_gb=2.0, backend=backend)
    await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=backend)
    # 10 GB used. gpt-oss-20b needs 12GB. Evict LLM(4)->free=10. Evict TTS(2)->free=12. Load.
    await manager.load_model("gpt-oss-20b", model_type="llm", vram_gb=12.0, backend=backend)
    assert not manager.is_loaded("qwen3.5-4b")
    assert not manager.is_loaded("chatterbox-multilingual")
    assert manager.is_loaded("cohere-transcribe")  # ASR survives if possible
    assert manager.is_loaded("gpt-oss-20b")


@pytest.mark.asyncio
async def test_asr_evicts_llm_not_reversed(manager):
    """When ASR request arrives and LLM is loaded, evict LLM (lower priority)."""
    backend = FakeBackend()
    await manager.load_model("gpt-oss-20b", model_type="llm", vram_gb=13.0, backend=backend)
    # 13GB used, 3GB free. ASR needs 4GB. Must evict LLM.
    await manager.load_model("cohere-transcribe", model_type="asr", vram_gb=4.0, backend=backend)
    assert not manager.is_loaded("gpt-oss-20b")
    assert manager.is_loaded("cohere-transcribe")


@pytest.mark.asyncio
async def test_already_loaded_is_noop(manager):
    backend = FakeBackend()
    await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=backend)
    await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=backend)
    assert backend.load_count == 1


@pytest.mark.asyncio
async def test_spec_scenario_switch_to_9b(manager):
    backend = FakeBackend()
    await manager.load_model("cohere-transcribe", model_type="asr", vram_gb=4.0, backend=backend)
    await manager.load_model("chatterbox-multilingual", model_type="tts", vram_gb=2.0, backend=backend)
    await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=backend)
    await manager.load_model("qwen3.5-9b-fp8", model_type="llm", vram_gb=9.0, backend=backend)
    assert manager.is_loaded("cohere-transcribe")
    assert manager.is_loaded("chatterbox-multilingual")
    assert manager.is_loaded("qwen3.5-9b-fp8")
    assert not manager.is_loaded("qwen3.5-4b")
    assert manager.available_vram_gb == pytest.approx(1.0)


@pytest.mark.asyncio
async def test_get_loaded_models(manager):
    backend = FakeBackend()
    await manager.load_model("cohere-transcribe", model_type="asr", vram_gb=4.0, backend=backend)
    await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=backend)
    loaded = manager.get_loaded_models()
    assert set(loaded.keys()) == {"cohere-transcribe", "qwen3.5-4b"}