Add an /admin/clear-vram endpoint that unloads every loaded model through a
new VRAMManager.clear_all() and reports the resulting CUDA memory counters.
(Reconstructed from a newline-mangled patch; hunk counts re-verified:
admin.py 6 context + 13 added = 19, vram_manager.py 6 context + 21 added = 27.)

diff --git a/kischdle/llmux/llmux/routes/admin.py b/kischdle/llmux/llmux/routes/admin.py
index 01b5e46..69688b0 100644
--- a/kischdle/llmux/llmux/routes/admin.py
+++ b/kischdle/llmux/llmux/routes/admin.py
@@ -12,6 +12,19 @@ TEST_PROMPT = [{"role": "user", "content": "Say hello in one sentence."}]
 def create_admin_router(registry, vram_manager, backends, require_api_key):
     router = APIRouter()
 
+    @router.post("/admin/clear-vram")
+    async def clear_vram(api_key: str = Depends(require_api_key)):
+        """Unload all models and clear GPU VRAM."""
+        result = await vram_manager.clear_all()
+        import torch
+        gpu_info = {}
+        if torch.cuda.is_available():
+            gpu_info = {
+                "gpu_memory_used_mb": round(torch.cuda.memory_allocated() / 1024**2, 1),
+                "gpu_memory_reserved_mb": round(torch.cuda.memory_reserved() / 1024**2, 1),
+            }
+        return {**result, **gpu_info}
+
     @router.post("/admin/test/performance")
     async def test_performance(request: Request, api_key: str = Depends(require_api_key)):
         body = await request.json()
diff --git a/kischdle/llmux/llmux/vram_manager.py b/kischdle/llmux/llmux/vram_manager.py
index e3af2a6..17e6ff4 100644
--- a/kischdle/llmux/llmux/vram_manager.py
+++ b/kischdle/llmux/llmux/vram_manager.py
@@ -40,6 +40,27 @@ class VRAMManager:
     def get_loaded_models(self) -> dict[str, ModelSlot]:
         return dict(self._loaded)
 
+    async def clear_all(self) -> dict:
+        """Unload all models and clear CUDA cache. Returns what was unloaded."""
+        import gc
+        import torch
+
+        async with self._lock:
+            unloaded = []
+            for slot in list(self._loaded.values()):
+                logger.info(f"Clearing {slot.model_id} ({slot.model_type}, {slot.vram_gb}GB)")
+                await slot.backend.unload(slot.model_id)
+                unloaded.append(slot.model_id)
+            self._loaded.clear()
+            gc.collect()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+                torch.cuda.synchronize()
+            return {
+                "unloaded": unloaded,
+                "available_vram_gb": round(self.available_vram_gb, 1),
+            }
+
     async def load_model(self, model_id, model_type, vram_gb, backend):
         async with self._lock:
             await self._load_model_locked(model_id, model_type, vram_gb, backend)