feat: VRAM manager with priority-based model eviction
Tracks GPU VRAM usage (16GB) and handles model loading/unloading with priority-based eviction: LLM (lowest) -> TTS -> ASR (highest, protected). Uses asyncio Lock for concurrency safety. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
89
kischdle/llmux/llmux/vram_manager.py
Normal file
89
kischdle/llmux/llmux/vram_manager.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
# Module-scoped logger; handler/level configuration is left to the application.
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Eviction order by model type: lower rank is reclaimed first.
# ASR sits at the top and is shielded from eviction by lower-priority requests.
_PRIORITY = {"llm": 0, "tts": 1, "asr": 2}


@dataclass
class ModelSlot:
    """Bookkeeping record for one model currently resident in VRAM."""

    model_id: str    # unique identifier of the loaded model
    model_type: str  # one of "llm", "tts", "asr"
    vram_gb: float   # VRAM footprint reserved for this model, in GB
    backend: object  # engine exposing async load(model_id) / unload(model_id)

    @staticmethod
    def priority_rank(model_type: str) -> int:
        """Map a model type to its eviction rank (KeyError if unknown)."""
        return _PRIORITY[model_type]

    @property
    def priority(self) -> int:
        """Eviction rank of this slot's model type."""
        return ModelSlot.priority_rank(self.model_type)
|
||||||
|
|
||||||
|
|
||||||
|
class VRAMManager:
|
||||||
|
def __init__(self, total_vram_gb: float = 16.0):
|
||||||
|
self._total_vram_gb = total_vram_gb
|
||||||
|
self._loaded: dict[str, ModelSlot] = {}
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def available_vram_gb(self) -> float:
|
||||||
|
used = sum(slot.vram_gb for slot in self._loaded.values())
|
||||||
|
return self._total_vram_gb - used
|
||||||
|
|
||||||
|
def is_loaded(self, model_id: str) -> bool:
|
||||||
|
return model_id in self._loaded
|
||||||
|
|
||||||
|
def get_loaded_models(self) -> dict[str, ModelSlot]:
|
||||||
|
return dict(self._loaded)
|
||||||
|
|
||||||
|
async def load_model(self, model_id, model_type, vram_gb, backend):
|
||||||
|
async with self._lock:
|
||||||
|
await self._load_model_locked(model_id, model_type, vram_gb, backend)
|
||||||
|
|
||||||
|
async def _load_model_locked(self, model_id, model_type, vram_gb, backend):
|
||||||
|
if model_id in self._loaded:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.available_vram_gb < vram_gb:
|
||||||
|
await self._evict_for(vram_gb, model_type)
|
||||||
|
|
||||||
|
if self.available_vram_gb < vram_gb:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Cannot free enough VRAM for {model_id} "
|
||||||
|
f"(need {vram_gb}GB, available {self.available_vram_gb}GB)"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Loading {model_id} ({vram_gb}GB VRAM)")
|
||||||
|
await backend.load(model_id)
|
||||||
|
self._loaded[model_id] = ModelSlot(
|
||||||
|
model_id=model_id,
|
||||||
|
model_type=model_type,
|
||||||
|
vram_gb=vram_gb,
|
||||||
|
backend=backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _evict_for(self, needed_gb, requesting_type):
|
||||||
|
requesting_priority = _PRIORITY[requesting_type]
|
||||||
|
|
||||||
|
# Evict in priority order: lowest first (LLM=0, TTS=1, ASR=2).
|
||||||
|
# Rule: never evict highest-priority tier (ASR) for a lower-priority
|
||||||
|
# request. Same-priority replacement is always allowed (e.g., old LLM
|
||||||
|
# evicted for new LLM). Lower-priority models are fair game for any
|
||||||
|
# requester — cascade through them until enough VRAM is freed.
|
||||||
|
candidates = sorted(self._loaded.values(), key=lambda s: s.priority)
|
||||||
|
for slot in candidates:
|
||||||
|
if self.available_vram_gb >= needed_gb:
|
||||||
|
break
|
||||||
|
# Skip if this slot is the highest-priority tier and the requester
|
||||||
|
# is lower priority. (Protects ASR from eviction by TTS/LLM.)
|
||||||
|
if slot.priority > requesting_priority and slot.model_type == "asr":
|
||||||
|
continue
|
||||||
|
logger.info(
|
||||||
|
f"Evicting {slot.model_id} ({slot.model_type}, {slot.vram_gb}GB)"
|
||||||
|
)
|
||||||
|
await slot.backend.unload(slot.model_id)
|
||||||
|
del self._loaded[slot.model_id]
|
||||||
120
kischdle/llmux/tests/test_vram_manager.py
Normal file
120
kischdle/llmux/tests/test_vram_manager.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
import asyncio
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llmux.vram_manager import VRAMManager, ModelSlot
|
||||||
|
|
||||||
|
|
||||||
|
class FakeBackend:
    """In-memory stand-in for a model backend; records load/unload traffic."""

    def __init__(self):
        self.loaded = {}     # model_id -> True while "resident"
        self.load_count = 0      # total load() invocations
        self.unload_count = 0    # total unload() invocations

    async def load(self, model_id: str):
        self.load_count += 1
        self.loaded[model_id] = True

    async def unload(self, model_id: str):
        self.unload_count += 1
        self.loaded.pop(model_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def manager():
    # Fresh 16 GB manager per test — matches the VRAMManager default budget.
    return VRAMManager(total_vram_gb=16.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_priority_ordering():
    """Eviction ranks must order LLM < TTS < ASR."""
    for model_type, rank in (("llm", 0), ("tts", 1), ("asr", 2)):
        assert ModelSlot.priority_rank(model_type) == rank
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_load_into_empty_vram(manager):
    """A 4GB LLM on an empty 16GB card loads and leaves 12GB free."""
    fake = FakeBackend()
    await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=fake)
    assert manager.available_vram_gb == pytest.approx(12.0)
    assert manager.is_loaded("qwen3.5-4b")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_load_alongside_when_fits(manager):
    """Two models that fit together coexist with no eviction."""
    fake = FakeBackend()
    await manager.load_model("cohere-transcribe", model_type="asr", vram_gb=4.0, backend=fake)
    await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=fake)
    for model_id in ("cohere-transcribe", "qwen3.5-4b"):
        assert manager.is_loaded(model_id)
    assert manager.available_vram_gb == pytest.approx(8.0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_evict_llm_first(manager):
    """Swapping in a bigger LLM evicts only the old LLM, never ASR/TTS."""
    fake = FakeBackend()
    await manager.load_model("cohere-transcribe", model_type="asr", vram_gb=4.0, backend=fake)
    await manager.load_model("chatterbox-multilingual", model_type="tts", vram_gb=2.0, backend=fake)
    await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=fake)
    # 10 GB in use; the 9GB model fits once the 4GB LLM is evicted
    # (ASR 4 + TTS 2 + 9 = 15 <= 16).
    await manager.load_model("qwen3.5-9b-fp8", model_type="llm", vram_gb=9.0, backend=fake)
    assert not manager.is_loaded("qwen3.5-4b")
    for survivor in ("cohere-transcribe", "chatterbox-multilingual", "qwen3.5-9b-fp8"):
        assert manager.is_loaded(survivor)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_evict_cascade_for_large_llm(manager):
    """A 12GB LLM forces cascading eviction of LLM then TTS; ASR survives."""
    fake = FakeBackend()
    await manager.load_model("cohere-transcribe", model_type="asr", vram_gb=4.0, backend=fake)
    await manager.load_model("chatterbox-multilingual", model_type="tts", vram_gb=2.0, backend=fake)
    await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=fake)
    # 10 GB in use; freeing LLM (4GB) then TTS (2GB) yields the 12GB needed.
    await manager.load_model("gpt-oss-20b", model_type="llm", vram_gb=12.0, backend=fake)
    assert manager.is_loaded("gpt-oss-20b")
    assert manager.is_loaded("cohere-transcribe")  # ASR survives if possible
    assert not manager.is_loaded("qwen3.5-4b")
    assert not manager.is_loaded("chatterbox-multilingual")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_asr_evicts_llm_not_reversed(manager):
    """When ASR request arrives and LLM is loaded, evict LLM (lower priority)."""
    fake = FakeBackend()
    await manager.load_model("gpt-oss-20b", model_type="llm", vram_gb=13.0, backend=fake)
    # Only 3GB free; the 4GB ASR model can load only by evicting the LLM.
    await manager.load_model("cohere-transcribe", model_type="asr", vram_gb=4.0, backend=fake)
    assert manager.is_loaded("cohere-transcribe")
    assert not manager.is_loaded("gpt-oss-20b")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_already_loaded_is_noop(manager):
    """Re-requesting a resident model must not hit the backend again."""
    fake = FakeBackend()
    for _ in range(2):
        await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=fake)
    assert fake.load_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_spec_scenario_switch_to_9b(manager):
    """Spec scenario: ASR and TTS stay resident while the LLM is upgraded."""
    fake = FakeBackend()
    await manager.load_model("cohere-transcribe", model_type="asr", vram_gb=4.0, backend=fake)
    await manager.load_model("chatterbox-multilingual", model_type="tts", vram_gb=2.0, backend=fake)
    await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=fake)
    await manager.load_model("qwen3.5-9b-fp8", model_type="llm", vram_gb=9.0, backend=fake)
    assert not manager.is_loaded("qwen3.5-4b")
    for model_id in ("cohere-transcribe", "chatterbox-multilingual", "qwen3.5-9b-fp8"):
        assert manager.is_loaded(model_id)
    # 4 + 2 + 9 = 15 GB resident on a 16 GB card.
    assert manager.available_vram_gb == pytest.approx(1.0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_loaded_models(manager):
    """Snapshot lists exactly the resident model ids."""
    fake = FakeBackend()
    await manager.load_model("cohere-transcribe", model_type="asr", vram_gb=4.0, backend=fake)
    await manager.load_model("qwen3.5-4b", model_type="llm", vram_gb=4.0, backend=fake)
    snapshot = manager.get_loaded_models()
    assert set(snapshot) == {"cohere-transcribe", "qwen3.5-4b"}
|
||||||
Reference in New Issue
Block a user