feat: VRAM manager with priority-based model eviction

Tracks GPU VRAM usage (16GB) and handles model loading/unloading with
priority-based eviction: LLM (lowest) -> TTS -> ASR (highest, protected).
Uses asyncio Lock for concurrency safety.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
tlg
2026-04-04 09:14:41 +02:00
parent 969bcb3292
commit d7a091df8c
2 changed files with 209 additions and 0 deletions

View File

@@ -0,0 +1,89 @@
import asyncio
import logging
from dataclasses import dataclass
logger = logging.getLogger(__name__)
_PRIORITY = {"llm": 0, "tts": 1, "asr": 2}
@dataclass
class ModelSlot:
model_id: str
model_type: str
vram_gb: float
backend: object
@staticmethod
def priority_rank(model_type: str) -> int:
return _PRIORITY[model_type]
@property
def priority(self) -> int:
return _PRIORITY[self.model_type]
class VRAMManager:
def __init__(self, total_vram_gb: float = 16.0):
self._total_vram_gb = total_vram_gb
self._loaded: dict[str, ModelSlot] = {}
self._lock = asyncio.Lock()
@property
def available_vram_gb(self) -> float:
used = sum(slot.vram_gb for slot in self._loaded.values())
return self._total_vram_gb - used
def is_loaded(self, model_id: str) -> bool:
return model_id in self._loaded
def get_loaded_models(self) -> dict[str, ModelSlot]:
return dict(self._loaded)
async def load_model(self, model_id, model_type, vram_gb, backend):
async with self._lock:
await self._load_model_locked(model_id, model_type, vram_gb, backend)
async def _load_model_locked(self, model_id, model_type, vram_gb, backend):
if model_id in self._loaded:
return
if self.available_vram_gb < vram_gb:
await self._evict_for(vram_gb, model_type)
if self.available_vram_gb < vram_gb:
raise RuntimeError(
f"Cannot free enough VRAM for {model_id} "
f"(need {vram_gb}GB, available {self.available_vram_gb}GB)"
)
logger.info(f"Loading {model_id} ({vram_gb}GB VRAM)")
await backend.load(model_id)
self._loaded[model_id] = ModelSlot(
model_id=model_id,
model_type=model_type,
vram_gb=vram_gb,
backend=backend,
)
async def _evict_for(self, needed_gb, requesting_type):
requesting_priority = _PRIORITY[requesting_type]
# Evict in priority order: lowest first (LLM=0, TTS=1, ASR=2).
# Rule: never evict highest-priority tier (ASR) for a lower-priority
# request. Same-priority replacement is always allowed (e.g., old LLM
# evicted for new LLM). Lower-priority models are fair game for any
# requester — cascade through them until enough VRAM is freed.
candidates = sorted(self._loaded.values(), key=lambda s: s.priority)
for slot in candidates:
if self.available_vram_gb >= needed_gb:
break
# Skip if this slot is the highest-priority tier and the requester
# is lower priority. (Protects ASR from eviction by TTS/LLM.)
if slot.priority > requesting_priority and slot.model_type == "asr":
continue
logger.info(
f"Evicting {slot.model_id} ({slot.model_type}, {slot.vram_gb}GB)"
)
await slot.backend.unload(slot.model_id)
del self._loaded[slot.model_id]

View File

@@ -0,0 +1,120 @@
import asyncio
import pytest
from llmux.vram_manager import VRAMManager, ModelSlot
class FakeBackend:
"""Simulates a backend that tracks load/unload calls."""
def __init__(self):
self.loaded = {}
self.load_count = 0
self.unload_count = 0
async def load(self, model_id: str):
self.loaded[model_id] = True
self.load_count += 1
async def unload(self, model_id: str):
self.loaded.pop(model_id, None)
self.unload_count += 1
@pytest.fixture
def manager():
return VRAMManager(total_vram_gb=16.0)
def test_priority_ordering():
assert ModelSlot.priority_rank("llm") == 0
assert ModelSlot.priority_rank("tts") == 1
assert ModelSlot.priority_rank("asr") == 2
@pytest.mark.asyncio
async def test_load_into_empty_vram(manager):
backend = FakeBackend()
await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=backend)
assert manager.is_loaded("qwen3.5-4b")
assert manager.available_vram_gb == pytest.approx(12.0)
@pytest.mark.asyncio
async def test_load_alongside_when_fits(manager):
backend = FakeBackend()
await manager.load_model("cohere-transcribe", model_type="asr", vram_gb=4.0, backend=backend)
await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=backend)
assert manager.is_loaded("cohere-transcribe")
assert manager.is_loaded("qwen3.5-4b")
assert manager.available_vram_gb == pytest.approx(8.0)
@pytest.mark.asyncio
async def test_evict_llm_first(manager):
backend = FakeBackend()
await manager.load_model("cohere-transcribe", model_type="asr", vram_gb=4.0, backend=backend)
await manager.load_model("chatterbox-multilingual", model_type="tts", vram_gb=2.0, backend=backend)
await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=backend)
# 10 GB used. Loading 9B (9GB). Evict LLM (4B), free=12. ASR+TTS+9B=15, fits.
await manager.load_model("qwen3.5-9b-fp8", model_type="llm", vram_gb=9.0, backend=backend)
assert not manager.is_loaded("qwen3.5-4b")
assert manager.is_loaded("cohere-transcribe")
assert manager.is_loaded("chatterbox-multilingual")
assert manager.is_loaded("qwen3.5-9b-fp8")
@pytest.mark.asyncio
async def test_evict_cascade_for_large_llm(manager):
backend = FakeBackend()
await manager.load_model("cohere-transcribe", model_type="asr", vram_gb=4.0, backend=backend)
await manager.load_model("chatterbox-multilingual", model_type="tts", vram_gb=2.0, backend=backend)
await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=backend)
# 10 GB used. gpt-oss-20b needs 12GB. Evict LLM(4)->free=10. Evict TTS(2)->free=12. Load.
await manager.load_model("gpt-oss-20b", model_type="llm", vram_gb=12.0, backend=backend)
assert not manager.is_loaded("qwen3.5-4b")
assert not manager.is_loaded("chatterbox-multilingual")
assert manager.is_loaded("cohere-transcribe") # ASR survives if possible
assert manager.is_loaded("gpt-oss-20b")
@pytest.mark.asyncio
async def test_asr_evicts_llm_not_reversed(manager):
"""When ASR request arrives and LLM is loaded, evict LLM (lower priority)."""
backend = FakeBackend()
await manager.load_model("gpt-oss-20b", model_type="llm", vram_gb=13.0, backend=backend)
# 13GB used, 3GB free. ASR needs 4GB. Must evict LLM.
await manager.load_model("cohere-transcribe", model_type="asr", vram_gb=4.0, backend=backend)
assert not manager.is_loaded("gpt-oss-20b")
assert manager.is_loaded("cohere-transcribe")
@pytest.mark.asyncio
async def test_already_loaded_is_noop(manager):
backend = FakeBackend()
await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=backend)
await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=backend)
assert backend.load_count == 1
@pytest.mark.asyncio
async def test_spec_scenario_switch_to_9b(manager):
backend = FakeBackend()
await manager.load_model("cohere-transcribe", model_type="asr", vram_gb=4.0, backend=backend)
await manager.load_model("chatterbox-multilingual", model_type="tts", vram_gb=2.0, backend=backend)
await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=backend)
await manager.load_model("qwen3.5-9b-fp8", model_type="llm", vram_gb=9.0, backend=backend)
assert manager.is_loaded("cohere-transcribe")
assert manager.is_loaded("chatterbox-multilingual")
assert manager.is_loaded("qwen3.5-9b-fp8")
assert not manager.is_loaded("qwen3.5-4b")
assert manager.available_vram_gb == pytest.approx(1.0)
@pytest.mark.asyncio
async def test_get_loaded_models(manager):
backend = FakeBackend()
await manager.load_model("cohere-transcribe", model_type="asr", vram_gb=4.0, backend=backend)
await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=backend)
loaded = manager.get_loaded_models()
assert set(loaded.keys()) == {"cohere-transcribe", "qwen3.5-4b"}