Add embed-multimodal-7b.py: FastAPI service for text and image embeddings with two selectable embedding models
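
The service exposes GET /health, POST /select-model, POST /embed-texts, POST /embed-images,
and POST /free-vram. A minimal client sketch, assuming the container listens on the default
port 8000; the host, model choice, and image path below are illustrative, not part of this commit:

    import base64
    import requests

    BASE = "http://localhost:8000"  # assumed host/port

    # Health check (lazy-loads the active model on first call)
    print(requests.get(f"{BASE}/health").json())

    # Switch between the two allowed checkpoints
    requests.post(
        f"{BASE}/select-model",
        json={"model_id": "ApsaraStackMaaS/EvoQwen2.5-VL-Retriever-7B-v1"},
    )

    # Multi-vector text embeddings: results[batch][tokens][dim]
    texts = requests.post(f"{BASE}/embed-texts", json={"texts": ["hello world"]}).json()

    # Multi-vector image embeddings from base64-encoded bytes
    with open("page.png", "rb") as f:  # illustrative local image
        b64 = base64.b64encode(f.read()).decode()
    images = requests.post(f"{BASE}/embed-images", json={"images_b64": [b64]}).json()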
.local/share/pytorch_pod/python-apps/embed-multimodal-7b.py | 506 ++++++++++ (new file)
@@ -0,0 +1,506 @@
#!/usr/bin/env python
import base64
import gc
import io
import os
import threading
from typing import List, Optional

import torch
from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor
from fastapi import FastAPI, HTTPException
from PIL import Image, ImageFile
from pydantic import BaseModel, Field
from transformers.utils.import_utils import is_flash_attn_2_available

# --------------------------------------------------------------------------------------
# Configuration
# --------------------------------------------------------------------------------------

# Allowed models (strictly limited per user request)
MODEL_ID_NOMIC = "nomic-ai/colnomic-embed-multimodal-7b"
MODEL_ID_EVO_7B = "ApsaraStackMaaS/EvoQwen2.5-VL-Retriever-7B-v1"
ALLOWED_MODEL_IDS = {MODEL_ID_NOMIC, MODEL_ID_EVO_7B}

# Default selected model (must be one of ALLOWED_MODEL_IDS). If env not allowed, fallback to MODEL_ID_NOMIC.
ENV_DEFAULT_MODEL = os.environ.get("HF_MODEL_ID", MODEL_ID_NOMIC)
DEFAULT_MODEL_ID = (
    ENV_DEFAULT_MODEL if ENV_DEFAULT_MODEL in ALLOWED_MODEL_IDS else MODEL_ID_NOMIC
)

HF_MODEL_URL = os.environ.get("HF_MODEL_URL")  # optional informational field
API_PORT = int(os.environ.get("PYTORCH_CONTAINER_PORT", os.environ.get("PORT", "8000")))

# Limits (env-overridable)
MAX_TEXTS_PER_REQUEST = int(os.environ.get("TEXT_MAX_ITEMS", "32"))
MAX_IMAGES_PER_REQUEST = int(os.environ.get("IMAGE_MAX_ITEMS", "8"))
MAX_IMAGE_BASE64_BYTES = int(
    os.environ.get("IMAGE_MAX_BASE64_BYTES", str(25 * 1024 * 1024))
)  # 25MB per image b64 (approx)
MAX_IMAGE_PIXELS = int(
    os.environ.get("IMAGE_MAX_PIXELS", str(30_000_000))
)  # ~30MP safety

# PIL safety for large images
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = MAX_IMAGE_PIXELS

# --------------------------------------------------------------------------------------
# App + Global State
# --------------------------------------------------------------------------------------

app = FastAPI(title="Colnomic Embed Multimodal API")

_model_lock = threading.RLock()

_model: Optional[torch.nn.Module] = None
_processor: Optional[ColQwen2_5_Processor] = None
_loaded_model_id: Optional[str] = None

# For reporting
_dtype_str: Optional[str] = None
_device_str: str = "cuda:0"

_active_model_id: str = DEFAULT_MODEL_ID


# --------------------------------------------------------------------------------------
# Pydantic Models
# --------------------------------------------------------------------------------------


class SelectModelRequest(BaseModel):
    model_id: str = Field(..., description="One of the allowed model IDs")


class EmbedTextsRequest(BaseModel):
    # Validated in handler for min length
    texts: List[str]


class EmbedResponse(BaseModel):
    model_id: str
    # results[batch][tokens][dim]
    results: List[List[List[float]]]


class EmbedImagesRequest(BaseModel):
    # Validated in handler for min length
    images_b64: List[str]  # base64 encoded images only


class ImageMetadata(BaseModel):
    index: int
    status: str  # "ok" | "error"
    width: Optional[int] = None
    height: Optional[int] = None
    mode: Optional[str] = None
    error: Optional[str] = None


class EmbedImagesResponse(BaseModel):
    model_id: str
    results: List[List[List[float]]]
    metadata: List[ImageMetadata]
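
# Illustrative payloads for the request/response models above (values are examples only):
#   POST /embed-texts   {"texts": ["first query", "second query"]}
#     -> {"model_id": "...", "results": [[[0.01, ...], ...], ...]}   # [batch][tokens][dim]
#   POST /embed-images  {"images_b64": ["<base64-encoded PNG/JPEG bytes>"]}
#     -> same "results" layout, plus "metadata": [{"index": 0, "status": "ok",
#        "width": 1024, "height": 768, "mode": "RGB"}]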


# --------------------------------------------------------------------------------------
# Helpers
# --------------------------------------------------------------------------------------


def _torch_dtype_str(dtype: torch.dtype) -> str:
    if dtype == torch.bfloat16:
        return "bfloat16"
    if dtype == torch.float16:
        return "float16"
    return str(dtype)


def _hard_requirements_check():
    # CUDA hard requirement
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available; a CUDA-capable GPU is required.")
    # FlashAttention-2 hard requirement
    if not is_flash_attn_2_available():
        raise RuntimeError(
            "flash_attn_2 is not available; this deployment requires FlashAttention-2."
        )


def _pick_dtype() -> torch.dtype:
    return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16


def _unload_model_locked():
    global _model, _processor, _loaded_model_id, _dtype_str
    # Assumes caller holds _model_lock
    before_free, before_total = (0, 0)
    if torch.cuda.is_available():
        free, total = torch.cuda.mem_get_info(0)
        before_free, before_total = free, total

    _model = None
    _processor = None
    _loaded_model_id = None
    _dtype_str = None

    gc.collect()
    if torch.cuda.is_available():
        try:
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        except Exception:
            # Best-effort cleanup; continue
            pass

    after_free, after_total = (0, 0)
    if torch.cuda.is_available():
        free, total = torch.cuda.mem_get_info(0)
        after_free, after_total = free, total

    return {
        "before": {"free": before_free, "total": before_total},
        "after": {"free": after_free, "total": after_total},
        "freed": max(0, (after_free - before_free)),
    }


def _load_model_locked(model_id: str):
    # Assumes caller holds _model_lock
    global _model, _processor, _loaded_model_id, _dtype_str, _device_str

    _hard_requirements_check()

    dtype = _pick_dtype()
    device_map = "cuda:0"
    attn_impl = "flash_attention_2"  # we ensured availability above

    model = ColQwen2_5.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map=device_map,
        attn_implementation=attn_impl,
    ).eval()

    processor = ColQwen2_5_Processor.from_pretrained(model_id)

    _model = model
    _processor = processor
    _loaded_model_id = model_id
    _dtype_str = _torch_dtype_str(dtype)
    _device_str = device_map


def _ensure_model_loaded():
    with _model_lock:
        if (
            _model is not None
            and _processor is not None
            and _loaded_model_id == _active_model_id
        ):
            model, processor = _model, _processor
            assert model is not None and processor is not None
            return model, processor
        # Different or missing model: (re)load
        _unload_model_locked()
        _load_model_locked(_active_model_id)
        model, processor = _model, _processor
        assert model is not None and processor is not None
        return model, processor


def _current_vram_info():
    if not torch.cuda.is_available():
        return {"free": None, "total": None, "used": None}
    free, total = torch.cuda.mem_get_info(0)
    used = total - free
    return {"free": free, "total": total, "used": used}


def _decode_base64_image(b64_data: str):
    # Size guard
    approx_bytes = int(len(b64_data) * 0.75)  # rough, base64 overhead ~33%
    if approx_bytes > MAX_IMAGE_BASE64_BYTES:
        raise ValueError(
            f"Image exceeds max base64 size of {MAX_IMAGE_BASE64_BYTES} bytes"
        )

    try:
        raw = base64.b64decode(b64_data, validate=True)
    except Exception as e:
        raise ValueError(f"Invalid base64: {e}")

    try:
        img = Image.open(io.BytesIO(raw))
        # Convert to RGB for model compatibility
        if img.mode != "RGB":
            img = img.convert("RGB")
        img.load()  # ensure data is read
        return img
    except Exception as e:
        raise ValueError(f"Unable to decode image: {e}")


def _extract_embeddings(outputs) -> torch.Tensor:
    # ColQwen2.5 returns either:
    # - a tensor shaped (batch, tokens, dim), or
    # - an object with .last_hidden_state
    if isinstance(outputs, torch.Tensor):
        embeddings = outputs
    elif hasattr(outputs, "last_hidden_state"):
        embeddings = outputs.last_hidden_state
    else:
        raise RuntimeError(f"Unexpected model output type: {type(outputs)}")

    if embeddings.dim() == 2:  # (tokens, dim) -> single item
        embeddings = embeddings.unsqueeze(0)
    elif embeddings.dim() != 3:
        raise RuntimeError(f"Unexpected embedding shape: {tuple(embeddings.shape)}")

    return embeddings
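
# Shape note (illustrative): for a batch of N inputs the extracted tensor is
# (N, tokens, dim), where `tokens` is the padded sequence length of that batch and
# `dim` is the checkpoint's multi-vector projection size (128 for ColPali-family
# models; treat the exact value as checkpoint-dependent).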


# --------------------------------------------------------------------------------------
# Endpoints
# --------------------------------------------------------------------------------------


@app.get("/health")
def health():
    """
    Health check with hard requirements:
    - CUDA available
    - FlashAttention-2 available
    - Lazy-loads the active model once
    - Includes dtype, device, and VRAM info
    """
    cuda_ok = bool(torch.cuda.is_available())
    flash_ok = bool(is_flash_attn_2_available())

    info = {
        "status": "ok",
        "model_id": _active_model_id,
        "model_url": HF_MODEL_URL,
        "cuda_available": cuda_ok,
        "flash_attn_2_available": flash_ok,
        "dtype": _dtype_str,
        "device": _device_str,
        "vram_bytes": _current_vram_info(),
    }

    # Hard failures
    if not cuda_ok:
        info["status"] = "error"
        info["error"] = "CUDA is not available inside the container."
        raise HTTPException(status_code=500, detail=info)

    if not flash_ok:
        info["status"] = "error"
        info["error"] = (
            "flash_attn_2 is not available; this deployment requires FlashAttention-2."
        )
        raise HTTPException(status_code=500, detail=info)

    try:
        _ensure_model_loaded()
    except Exception as exc:
        info["status"] = "error"
        info["error"] = str(exc)
        raise HTTPException(status_code=500, detail=info) from exc

    # Ensure final dtype/device populated
    info["dtype"] = _dtype_str
    info["device"] = _device_str
    info["vram_bytes"] = _current_vram_info()
    return info


@app.post("/select-model")
def select_model(req: SelectModelRequest):
    """
    Switch the active model between the allowed set.
    Fully unloads the current model (freeing VRAM), then loads the new one.
    Blocks concurrent requests briefly via a lock.
    """
    global _active_model_id
    target = req.model_id.strip()

    if target not in ALLOWED_MODEL_IDS:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Unsupported model_id",
                "allowed": sorted(list(ALLOWED_MODEL_IDS)),
                "received": target,
            },
        )

    with _model_lock:
        if (
            target == _active_model_id
            and _model is not None
            and _loaded_model_id == target
        ):
            # No-op
            return {
                "status": "ok",
                "model_id": _active_model_id,
                "message": "Model unchanged; already active.",
            }

        # Switch
        _active_model_id = target
        _unload_model_locked()
        try:
            _load_model_locked(_active_model_id)
        except Exception as exc:
            # Attempt to revert to a safe state: no model loaded
            _unload_model_locked()
            raise HTTPException(
                status_code=500,
                detail={"error": f"Failed to load model '{target}': {exc}"},
            ) from exc

    return {"status": "ok", "model_id": _active_model_id}


@app.post("/embed-texts", response_model=EmbedResponse)
def embed_texts(request: EmbedTextsRequest):
    """
    Compute multi-vector embeddings for a list of texts.
    Result shape: results[batch][tokens][dim] (multi-vector per text).

    Limits:
    - Max texts per request: TEXT_MAX_ITEMS (default 32)
    """
    texts = request.texts
    if not texts:
        raise HTTPException(status_code=400, detail="texts must not be empty")
    if len(texts) > MAX_TEXTS_PER_REQUEST:
        raise HTTPException(
            status_code=400, detail=f"Too many texts; max is {MAX_TEXTS_PER_REQUEST}"
        )

    with _model_lock:
        model, processor = _ensure_model_loaded()

        try:
            with torch.inference_mode():
                batch = processor.process_queries(texts).to(_device_str)
                outputs = model(**batch)
                embeddings = _extract_embeddings(outputs)
        except Exception as exc:
            raise HTTPException(
                status_code=500, detail=f"Failed to compute text embeddings: {exc}"
            ) from exc

        embeddings = embeddings.detach().cpu().float()
        results = embeddings.tolist()

    return EmbedResponse(model_id=_active_model_id, results=results)


@app.post("/embed-images", response_model=EmbedImagesResponse)
def embed_images(request: EmbedImagesRequest):
    """
    Compute multi-vector embeddings for a list of base64-encoded images.
    Results are aligned with the input order: results[i] is [] if the i-th image failed to decode.

    Limits:
    - Max images per request: IMAGE_MAX_ITEMS (default 8)
    - Max base64 bytes per image: IMAGE_MAX_BASE64_BYTES (default ~25MB)
    - Max image pixels (safety): IMAGE_MAX_PIXELS (default ~30MP)
    """
    b64_list = request.images_b64
    if not b64_list:
        raise HTTPException(status_code=400, detail="images_b64 must not be empty")
    if len(b64_list) > MAX_IMAGES_PER_REQUEST:
        raise HTTPException(
            status_code=400, detail=f"Too many images; max is {MAX_IMAGES_PER_REQUEST}"
        )

    # Decode individually and track metadata
    decoded_images: List[Optional[Image.Image]] = [None] * len(b64_list)
    metadata: List[ImageMetadata] = []
    ok_indices: List[int] = []

    for idx, b64_img in enumerate(b64_list):
        try:
            img = _decode_base64_image(b64_img)
            decoded_images[idx] = img
            w, h = img.size
            metadata.append(
                ImageMetadata(index=idx, status="ok", width=w, height=h, mode=img.mode)
            )
            ok_indices.append(idx)
        except Exception as exc:
            metadata.append(ImageMetadata(index=idx, status="error", error=str(exc)))

    if not ok_indices:
        raise HTTPException(
            status_code=400,
            detail="All provided images failed to decode or were rejected by limits",
        )

    # Prepare only successful images for batching, but preserve order in output
    images_ok: List[Image.Image] = []
    for i in ok_indices:
        img_i = decoded_images[i]
        assert img_i is not None
        images_ok.append(img_i)

    with _model_lock:
        model, processor = _ensure_model_loaded()

        try:
            with torch.inference_mode():
                batch_images = processor.process_images(images_ok).to(_device_str)
                outputs = model(**batch_images)
                embeddings = _extract_embeddings(outputs)
        except Exception as exc:
            raise HTTPException(
                status_code=500, detail=f"Failed to compute image embeddings: {exc}"
            ) from exc

        embeddings = embeddings.detach().cpu().float().tolist()

    # Distribute embeddings back into results aligned with the original indices;
    # failed entries get an empty list [].
    results: List[List[List[float]]] = [[] for _ in range(len(b64_list))]
    for pos, idx in enumerate(ok_indices):
        results[idx] = embeddings[pos]

    return EmbedImagesResponse(
        model_id=_active_model_id, results=results, metadata=metadata
    )


@app.post("/free-vram")
def free_vram():
    """
    Free GPU VRAM by unloading the model/processor and emptying CUDA caches.
    The active model selection is preserved; the next request re-loads the model.
    """
    with _model_lock:
        before = _current_vram_info()
        stats = _unload_model_locked()
        after = _current_vram_info()

    return {
        "status": "ok",
        "active_model_id": _active_model_id,
        "vram_bytes_before": before,
        "vram_bytes_after": after,
        "free_stats": stats,
    }


# --------------------------------------------------------------------------------------
# Entrypoint
# --------------------------------------------------------------------------------------

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=API_PORT)
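
Downstream, the multi-vector outputs are intended for late-interaction (MaxSim) scoring:
each query token is matched against its best document token and the per-token maxima are
summed. A minimal client-side sketch, assuming the service runs on localhost:8000; the
helper below is illustrative and not part of the committed API:

    import requests
    import torch

    BASE = "http://localhost:8000"  # assumed deployment address


    def maxsim_score(query_vecs, doc_vecs):
        # query_vecs: [q_tokens][dim], doc_vecs: [d_tokens][dim], as returned by the API
        q = torch.tensor(query_vecs)
        d = torch.tensor(doc_vecs)
        sim = q @ d.T  # (q_tokens, d_tokens) dot-product similarities
        return sim.max(dim=1).values.sum().item()  # best doc token per query token, summed


    query = requests.post(
        f"{BASE}/embed-texts", json={"texts": ["quarterly revenue chart"]}
    ).json()
    docs = requests.post(
        f"{BASE}/embed-texts", json={"texts": ["revenue by quarter, 2024", "a photo of a cat"]}
    ).json()

    scores = [maxsim_score(query["results"][0], d) for d in docs["results"]]
    print(scores)  # higher score = better match under late interaction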