From c200fa75ab34efa2fa4b38eae59976a65db188dc Mon Sep 17 00:00:00 2001
From: llm
Date: Fri, 21 Nov 2025 21:58:51 +0100
Subject: [PATCH] Add new embed-multimodal-7b.py API for text and image
 embeddings, two embedding models

---
 .../python-apps/embed-multimodal-7b.py | 506 ++++++++++++++++++
 1 file changed, 506 insertions(+)
 create mode 100644 .local/share/pytorch_pod/python-apps/embed-multimodal-7b.py

diff --git a/.local/share/pytorch_pod/python-apps/embed-multimodal-7b.py b/.local/share/pytorch_pod/python-apps/embed-multimodal-7b.py
new file mode 100644
index 0000000..98e344b
--- /dev/null
+++ b/.local/share/pytorch_pod/python-apps/embed-multimodal-7b.py
@@ -0,0 +1,506 @@
+#!/usr/bin/env python
+import base64
+import gc
+import io
+import os
+import threading
+from typing import List, Optional
+
+import torch
+from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor
+from fastapi import FastAPI, HTTPException
+from PIL import Image, ImageFile
+from pydantic import BaseModel, Field
+from transformers.utils.import_utils import is_flash_attn_2_available
+
+# --------------------------------------------------------------------------------------
+# Configuration
+# --------------------------------------------------------------------------------------
+
+# Allowed models (strictly limited per user request)
+MODEL_ID_NOMIC = "nomic-ai/colnomic-embed-multimodal-7b"
+MODEL_ID_EVO_7B = "ApsaraStackMaaS/EvoQwen2.5-VL-Retriever-7B-v1"
+ALLOWED_MODEL_IDS = {MODEL_ID_NOMIC, MODEL_ID_EVO_7B}
+
+# Default selected model (must be one of ALLOWED_MODEL_IDS). If the env value
+# is not allowed, fall back to MODEL_ID_NOMIC.
+ENV_DEFAULT_MODEL = os.environ.get("HF_MODEL_ID", MODEL_ID_NOMIC)
+DEFAULT_MODEL_ID = (
+    ENV_DEFAULT_MODEL if ENV_DEFAULT_MODEL in ALLOWED_MODEL_IDS else MODEL_ID_NOMIC
+)
+
+HF_MODEL_URL = os.environ.get("HF_MODEL_URL")  # optional informational field
+API_PORT = int(os.environ.get("PYTORCH_CONTAINER_PORT", os.environ.get("PORT", "8000")))
+
+# Limits (env-overridable)
+MAX_TEXTS_PER_REQUEST = int(os.environ.get("TEXT_MAX_ITEMS", "32"))
+MAX_IMAGES_PER_REQUEST = int(os.environ.get("IMAGE_MAX_ITEMS", "8"))
+MAX_IMAGE_BASE64_BYTES = int(
+    os.environ.get("IMAGE_MAX_BASE64_BYTES", str(25 * 1024 * 1024))
+)  # 25MB per image b64 (approx)
+MAX_IMAGE_PIXELS = int(
+    os.environ.get("IMAGE_MAX_PIXELS", str(30_000_000))
+)  # ~30MP safety
+
+# PIL settings for large/truncated images
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+Image.MAX_IMAGE_PIXELS = MAX_IMAGE_PIXELS
+
+# --------------------------------------------------------------------------------------
+# App + Global State
+# --------------------------------------------------------------------------------------
+
+app = FastAPI(title="Colnomic Embed Multimodal API")
+
+_model_lock = threading.RLock()
+
+_model: Optional[torch.nn.Module] = None
+_processor: Optional[ColQwen2_5_Processor] = None
+_loaded_model_id: Optional[str] = None
+
+# For reporting
+_dtype_str: Optional[str] = None
+_device_str: str = "cuda:0"
+
+_active_model_id: str = DEFAULT_MODEL_ID
+
+
+# --------------------------------------------------------------------------------------
+# Pydantic Models
+# --------------------------------------------------------------------------------------
+
+
+class SelectModelRequest(BaseModel):
+    model_id: str = Field(..., description="One of the allowed model IDs")
+
+
+class EmbedTextsRequest(BaseModel):
+    # Validated in handler for min length
+    texts: List[str]
+
+
+class EmbedResponse(BaseModel):
+    model_id: str
+    # results[batch][tokens][dim]
+    results: List[List[List[float]]]
+
+
+class EmbedImagesRequest(BaseModel):
+    # Validated in handler for min length
+    images_b64: List[str]  # base64 encoded images only
+
+
+class ImageMetadata(BaseModel):
+    index: int
+    status: str  # "ok" | "error"
+    width: Optional[int] = None
+    height: Optional[int] = None
+    mode: Optional[str] = None
+    error: Optional[str] = None
+
+
+class EmbedImagesResponse(BaseModel):
+    model_id: str
+    results: List[List[List[float]]]
+    metadata: List[ImageMetadata]
+
+
+# --------------------------------------------------------------------------------------
+# Helpers
+# --------------------------------------------------------------------------------------
+
+
+def _torch_dtype_str(dtype: torch.dtype) -> str:
+    if dtype == torch.bfloat16:
+        return "bfloat16"
+    if dtype == torch.float16:
+        return "float16"
+    return str(dtype)
+
+
+def _hard_requirements_check():
+    # CUDA hard requirement
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is not available; a CUDA-capable GPU is required.")
+    # FlashAttention-2 hard requirement
+    if not is_flash_attn_2_available():
+        raise RuntimeError(
+            "flash_attn_2 is not available; this deployment requires FlashAttention-2."
+        )
+
+
+def _pick_dtype() -> torch.dtype:
+    return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+
+
+def _unload_model_locked():
+    global _model, _processor, _loaded_model_id, _dtype_str
+    # Assumes caller holds _model_lock
+    before_free, before_total = (0, 0)
+    if torch.cuda.is_available():
+        free, total = torch.cuda.mem_get_info(0)
+        before_free, before_total = free, total
+
+    _model = None
+    _processor = None
+    _loaded_model_id = None
+    _dtype_str = None
+
+    gc.collect()
+    if torch.cuda.is_available():
+        try:
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+        except Exception:
+            # Best-effort cleanup; continue
+            pass
+
+    after_free, after_total = (0, 0)
+    if torch.cuda.is_available():
+        free, total = torch.cuda.mem_get_info(0)
+        after_free, after_total = free, total
+
+    return {
+        "before": {"free": before_free, "total": before_total},
+        "after": {"free": after_free, "total": after_total},
+        "freed": max(0, (after_free - before_free)),
+    }
+
+
+def _load_model_locked(model_id: str):
+    # Assumes caller holds _model_lock
+    global _model, _processor, _loaded_model_id, _dtype_str, _device_str
+
+    _hard_requirements_check()
+
+    dtype = _pick_dtype()
+    device_map = "cuda:0"
+    attn_impl = "flash_attention_2"  # we ensured availability above
+
+    model = ColQwen2_5.from_pretrained(
+        model_id,
+        torch_dtype=dtype,
+        device_map=device_map,
+        attn_implementation=attn_impl,
+    ).eval()
+
+    processor = ColQwen2_5_Processor.from_pretrained(model_id)
+
+    _model = model
+    _processor = processor
+    _loaded_model_id = model_id
+    _dtype_str = _torch_dtype_str(dtype)
+    _device_str = device_map
+
+
+def _ensure_model_loaded():
+    with _model_lock:
+        if (
+            _model is not None
+            and _processor is not None
+            and _loaded_model_id == _active_model_id
+        ):
+            model, processor = _model, _processor
+            assert model is not None and processor is not None
+            return model, processor
+        # Different or missing model: (re)load
+        _unload_model_locked()
+        _load_model_locked(_active_model_id)
+        model, processor = _model, _processor
+        assert model is not None and processor is not None
+        return model, processor
+
+
+def _current_vram_info():
+    if not torch.cuda.is_available():
+        return {"free": None, "total": None, "used": None}
+    free, total = torch.cuda.mem_get_info(0)
+    used = total - free
+    return {"free": free, "total": total, "used": used}
+
+
+def _decode_base64_image(b64_data: str):
+    # Size guard
+    approx_bytes = int(len(b64_data) * 0.75)  # rough; base64 overhead is ~33%
+    if approx_bytes > MAX_IMAGE_BASE64_BYTES:
+        raise ValueError(
+            f"Image exceeds max base64 size of {MAX_IMAGE_BASE64_BYTES} bytes"
+        )
+
+    try:
+        raw = base64.b64decode(b64_data, validate=True)
+    except Exception as e:
+        raise ValueError(f"Invalid base64: {e}")
+
+    try:
+        img = Image.open(io.BytesIO(raw))
+        # Convert to RGB for model compatibility
+        if img.mode != "RGB":
+            img = img.convert("RGB")
+        img.load()  # ensure data is read
+        return img
+    except Exception as e:
+        raise ValueError(f"Unable to decode image: {e}")
+
+
+def _extract_embeddings(outputs) -> torch.Tensor:
+    # ColQwen2.5 returns either:
+    # - a tensor shaped (batch, tokens, dim), or
+    # - an object with .last_hidden_state
+    if isinstance(outputs, torch.Tensor):
+        embeddings = outputs
+    elif hasattr(outputs, "last_hidden_state"):
+        embeddings = outputs.last_hidden_state
+    else:
+        raise RuntimeError(f"Unexpected model output type: {type(outputs)}")
+
+    if embeddings.dim() == 2:  # (tokens, dim) -> single item
+        embeddings = embeddings.unsqueeze(0)
+    elif embeddings.dim() != 3:
+        raise RuntimeError(f"Unexpected embedding shape: {tuple(embeddings.shape)}")
+
+    return embeddings
+
+
+# --------------------------------------------------------------------------------------
+# Endpoints
+# --------------------------------------------------------------------------------------
+
+
+@app.get("/health")
+def health():
+    """
+    Health check with hard requirements:
+    - CUDA available
+    - FlashAttention-2 available
+    - Lazy-loads the active model once
+    - Includes dtype, device, and VRAM info
+    """
+    cuda_ok = bool(torch.cuda.is_available())
+    flash_ok = bool(is_flash_attn_2_available())
+
+    info = {
+        "status": "ok",
+        "model_id": _active_model_id,
+        "model_url": HF_MODEL_URL,
+        "cuda_available": cuda_ok,
+        "flash_attn_2_available": flash_ok,
+        "dtype": _dtype_str,
+        "device": _device_str,
+        "vram_bytes": _current_vram_info(),
+    }
+
+    # Hard failures
+    if not cuda_ok:
+        info["status"] = "error"
+        info["error"] = "CUDA is not available inside the container."
+        raise HTTPException(status_code=500, detail=info)
+
+    if not flash_ok:
+        info["status"] = "error"
+        info["error"] = (
+            "flash_attn_2 is not available; this deployment requires FlashAttention-2."
+        )
+        raise HTTPException(status_code=500, detail=info)
+
+    try:
+        _ensure_model_loaded()
+    except Exception as exc:
+        info["status"] = "error"
+        info["error"] = str(exc)
+        raise HTTPException(status_code=500, detail=info) from exc
+
+    # Ensure final dtype/device populated
+    info["dtype"] = _dtype_str
+    info["device"] = _device_str
+    info["vram_bytes"] = _current_vram_info()
+    return info
+
+
+@app.post("/select-model")
+def select_model(req: SelectModelRequest):
+    """
+    Switch the active model within the allowed set.
+    Fully unloads the current model (freeing VRAM), then loads the new one.
+    Blocks concurrent requests briefly via a lock.
+    """
+    global _active_model_id
+    target = req.model_id.strip()
+
+    if target not in ALLOWED_MODEL_IDS:
+        raise HTTPException(
+            status_code=400,
+            detail={
+                "error": "Unsupported model_id",
+                "allowed": sorted(list(ALLOWED_MODEL_IDS)),
+                "received": target,
+            },
+        )
+
+    with _model_lock:
+        if (
+            target == _active_model_id
+            and _model is not None
+            and _loaded_model_id == target
+        ):
+            # No-op
+            return {
+                "status": "ok",
+                "model_id": _active_model_id,
+                "message": "Model unchanged; already active.",
+            }
+
+        # Switch
+        _active_model_id = target
+        _unload_model_locked()
+        try:
+            _load_model_locked(_active_model_id)
+        except Exception as exc:
+            # Attempt to revert to a safe state: no model loaded
+            _unload_model_locked()
+            raise HTTPException(
+                status_code=500,
+                detail={"error": f"Failed to load model '{target}': {exc}"},
+            ) from exc
+
+    return {"status": "ok", "model_id": _active_model_id}
+
+
+@app.post("/embed-texts", response_model=EmbedResponse)
+def embed_texts(request: EmbedTextsRequest):
+    """
+    Compute multi-vector embeddings for a list of texts.
+    Result shape: results[batch][tokens][dim] (multi-vector per text).
+
+    Limits:
+    - Max texts per request: TEXT_MAX_ITEMS (default 32)
+    """
+    texts = request.texts
+    if not texts:
+        raise HTTPException(status_code=400, detail="texts must not be empty")
+    if len(texts) > MAX_TEXTS_PER_REQUEST:
+        raise HTTPException(
+            status_code=400, detail=f"Too many texts; max is {MAX_TEXTS_PER_REQUEST}"
+        )
+
+    with _model_lock:
+        model, processor = _ensure_model_loaded()
+
+    try:
+        with torch.inference_mode():
+            batch = processor.process_queries(texts).to(_device_str)
+            outputs = model(**batch)
+            embeddings = _extract_embeddings(outputs)
+    except Exception as exc:
+        raise HTTPException(
+            status_code=500, detail=f"Failed to compute text embeddings: {exc}"
+        ) from exc
+
+    embeddings = embeddings.detach().cpu().float()
+    results = embeddings.tolist()
+
+    return EmbedResponse(model_id=_active_model_id, results=results)
+
+
+@app.post("/embed-images", response_model=EmbedImagesResponse)
+def embed_images(request: EmbedImagesRequest):
+    """
+    Compute multi-vector embeddings for a list of base64-encoded images.
+    Returns results aligned with the input order: results[i] is [] if the i-th
+    image failed to decode.
+
+    Limits:
+    - Max images per request: IMAGE_MAX_ITEMS (default 8)
+    - Max base64 bytes per image: IMAGE_MAX_BASE64_BYTES (default ~25MB)
+    - Max image pixels (safety): IMAGE_MAX_PIXELS (default ~30MP)
+    """
+    b64_list = request.images_b64
+    if not b64_list:
+        raise HTTPException(status_code=400, detail="images_b64 must not be empty")
+    if len(b64_list) > MAX_IMAGES_PER_REQUEST:
+        raise HTTPException(
+            status_code=400, detail=f"Too many images; max is {MAX_IMAGES_PER_REQUEST}"
+        )
+
+    # Decode individually and track metadata
+    decoded_images: List[Optional[Image.Image]] = [None] * len(b64_list)
+    metadata: List[ImageMetadata] = []
+    ok_indices: List[int] = []
+
+    for idx, b64_img in enumerate(b64_list):
+        try:
+            img = _decode_base64_image(b64_img)
+            decoded_images[idx] = img
+            w, h = img.size
+            metadata.append(
+                ImageMetadata(index=idx, status="ok", width=w, height=h, mode=img.mode)
+            )
+            ok_indices.append(idx)
+        except Exception as exc:
+            metadata.append(ImageMetadata(index=idx, status="error", error=str(exc)))
+
+    if not ok_indices:
+        raise HTTPException(
+            status_code=400,
+            detail="All provided images failed to decode or were rejected by limits",
+        )
+
+    # Prepare only successful images for batching, but preserve order in output
+    images_ok: List[Image.Image] = []
+    for i in ok_indices:
+        img_i = decoded_images[i]
+        assert img_i is not None
+        images_ok.append(img_i)
+
+    with _model_lock:
+        model, processor = _ensure_model_loaded()
+
+    try:
+        with torch.inference_mode():
+            batch_images = processor.process_images(images_ok).to(_device_str)
+            outputs = model(**batch_images)
+            embeddings = _extract_embeddings(outputs)
+    except Exception as exc:
+        raise HTTPException(
+            status_code=500, detail=f"Failed to compute image embeddings: {exc}"
+        ) from exc
+
+    embeddings = embeddings.detach().cpu().float().tolist()
+
+    # Distribute embeddings back into results aligned with original indices.
+    # For failed entries, place an empty list [].
+    results: List[List[List[float]]] = [[] for _ in range(len(b64_list))]
+    for pos, idx in enumerate(ok_indices):
+        results[idx] = embeddings[pos]
+
+    return EmbedImagesResponse(
+        model_id=_active_model_id, results=results, metadata=metadata
+    )
+
+
+@app.post("/free-vram")
+def free_vram():
+    """
+    Frees GPU VRAM by unloading the model/processor and emptying CUDA caches.
+    The active model selection is preserved, but the next request will re-load
+    the model.
+    """
+    with _model_lock:
+        before = _current_vram_info()
+        stats = _unload_model_locked()
+        after = _current_vram_info()
+
+    return {
+        "status": "ok",
+        "active_model_id": _active_model_id,
+        "vram_bytes_before": before,
+        "vram_bytes_after": after,
+        "free_stats": stats,
+    }
+
+
+# --------------------------------------------------------------------------------------
+# Entrypoint
+# --------------------------------------------------------------------------------------
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=API_PORT)
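
A minimal client sketch for the two embedding endpoints above. It is illustrative only and assumes the service is reachable at http://localhost:8000 (the default API_PORT), that the requests package is installed, and that "sample.png" is a placeholder image path.

    import base64
    import requests

    BASE_URL = "http://localhost:8000"  # assumed host/port; API_PORT defaults to 8000

    # /embed-texts returns results[batch][tokens][dim] (one token matrix per text).
    resp = requests.post(
        f"{BASE_URL}/embed-texts",
        json={"texts": ["What does the revenue table on page 3 show?"]},
        timeout=300,
    )
    resp.raise_for_status()
    text_vectors = resp.json()["results"][0]  # [tokens][dim] for the first text

    # /embed-images takes base64-encoded images; a failed decode yields [] in
    # "results" plus an error entry in "metadata" at the same index.
    with open("sample.png", "rb") as fh:  # placeholder image path
        img_b64 = base64.b64encode(fh.read()).decode("ascii")

    resp = requests.post(
        f"{BASE_URL}/embed-images",
        json={"images_b64": [img_b64]},
        timeout=300,
    )
    resp.raise_for_status()
    payload = resp.json()
    image_vectors = payload["results"][0]  # [tokens][dim] for the first image
    print(payload["metadata"][0], "->", len(image_vectors), "image tokens")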
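
Because both endpoints return multi-vector (per-token) embeddings rather than a single pooled vector, a retrieval pipeline still needs a scoring step; a common choice for ColQwen-style models is ColBERT-style late interaction (MaxSim). The sketch below shows that computation over the JSON results from the client sketch above; the maxsim_score helper and the reuse of text_vectors/image_vectors are assumptions for illustration, not something this patch provides.

    import torch

    def maxsim_score(query_vectors, doc_vectors):
        # Late-interaction (MaxSim) score: for each query token, take the maximum
        # dot product over all document tokens, then sum across query tokens.
        q = torch.tensor(query_vectors, dtype=torch.float32)  # (q_tokens, dim)
        d = torch.tensor(doc_vectors, dtype=torch.float32)    # (d_tokens, dim)
        sim = q @ d.T                                          # (q_tokens, d_tokens)
        return sim.max(dim=1).values.sum().item()

    # Score the text query from the client sketch against the image embedding.
    score = maxsim_score(text_vectors, image_vectors)
    print("late-interaction score:", score)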