#!/usr/bin/env python
import os
from typing import List

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers.utils.import_utils import is_flash_attn_2_available

from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor

HF_MODEL_ID = os.environ.get("HF_MODEL_ID", "nomic-ai/colnomic-embed-multimodal-7b")
HF_MODEL_URL = os.environ.get("HF_MODEL_URL")
API_PORT = int(os.environ.get("PYTORCH_CONTAINER_PORT", os.environ.get("PORT", "8000")))

app = FastAPI(title="Colnomic Embed Multimodal 7B API")

_model = None
_processor = None
_device = None


def _ensure_model_loaded():
    """
    Lazy-load the ColNomic model and processor on the first request.

    Hard requirements for this deployment:
    - CUDA must be available.
    - FlashAttention-2 must be available (flash-attn successfully installed).

    If either is missing, an exception is raised and /health returns 500.
    """
    global _model, _processor, _device

    if _model is not None and _processor is not None:
        return _model, _processor, _device

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available; a CUDA-capable GPU is required.")
    if not is_flash_attn_2_available():
        raise RuntimeError(
            "flash_attn_2 is not available; install a flash-attn build compatible "
            "with your torch and CUDA versions."
        )

    # Choose dtype: BF16 if the GPU supports it, otherwise FP16.
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

    # Use a single GPU (cuda:0) for now.
    device_map = "cuda:0"

    # Force FlashAttention-2 (availability was already checked above).
    attn_impl = "flash_attention_2"

    model = ColQwen2_5.from_pretrained(
        HF_MODEL_ID,
        torch_dtype=dtype,
        device_map=device_map,
        attn_implementation=attn_impl,
    ).eval()
    processor = ColQwen2_5_Processor.from_pretrained(HF_MODEL_ID)

    _model = model
    _processor = processor
    _device = device_map
    return _model, _processor, _device


class EmbedRequest(BaseModel):
    texts: List[str]


class EmbedResponse(BaseModel):
    model_id: str
    # results[batch][tokens][dim]
    results: List[List[List[float]]]


@app.get("/health")
def health():
    """
    Health check:
    - Reports CUDA and FlashAttention-2 availability.
    - Tries to load the model once (lazy).
    - Returns 200 only if CUDA, FlashAttention-2, and model loading are OK.
    """
    cuda_ok = bool(torch.cuda.is_available())
    flash_ok = bool(is_flash_attn_2_available())
    info = {
        "status": "ok",
        "model_id": HF_MODEL_ID,
        "model_url": HF_MODEL_URL,
        "cuda_available": cuda_ok,
        "flash_attn_2_available": flash_ok,
    }

    # A missing CUDA runtime or FlashAttention-2 build is a hard failure.
    if not cuda_ok:
        info["status"] = "error"
        info["error"] = "CUDA is not available inside the container."
        raise HTTPException(status_code=500, detail=info)
    if not flash_ok:
        info["status"] = "error"
        info["error"] = "flash_attn_2 is not available; this deployment requires FlashAttention-2."
        raise HTTPException(status_code=500, detail=info)

    try:
        _ensure_model_loaded()
    except Exception as exc:  # noqa: BLE001
        info["status"] = "error"
        info["error"] = str(exc)
        raise HTTPException(status_code=500, detail=info) from exc

    return info


@app.post("/embed", response_model=EmbedResponse)
def embed(request: EmbedRequest):
    """
    Compute multi-vector embeddings for a list of texts.

    Result shape: results[batch][tokens][dim] (multi-vector per text).
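    Example (hypothetical request/response; the query text and vector values
    are placeholders, and the actual token count and dimension come from the
    model):

        POST /embed  {"texts": ["a photo of a cat"]}
        -> {"model_id": "nomic-ai/colnomic-embed-multimodal-7b",
            "results": [[[0.01, -0.23, ...], ...]]}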
""" if not request.texts: raise HTTPException(status_code=400, detail="texts must not be empty") model, processor, device = _ensure_model_loaded() # noqa: F841 - device kept for future use # For queries, use process_queries (as in ColQwen2.5 docs) with torch.inference_mode(): batch = processor.process_queries(request.texts).to(model.device) outputs = model(**batch) # ColQwen2.5 returns either: # - a tensor shaped (batch, tokens, dim), or # - an object with .last_hidden_state if isinstance(outputs, torch.Tensor): embeddings = outputs elif hasattr(outputs, "last_hidden_state"): embeddings = outputs.last_hidden_state else: raise HTTPException( status_code=500, detail=f"Unexpected model output type from ColQwen/ColPali: {type(outputs)}", ) if embeddings.dim() == 2: # (tokens, dim) -> single text embeddings = embeddings.unsqueeze(0) elif embeddings.dim() != 3: raise HTTPException( status_code=500, detail=f"Unexpected embedding shape: {tuple(embeddings.shape)}", ) embeddings = embeddings.detach().cpu().float() results = embeddings.tolist() return EmbedResponse(model_id=HF_MODEL_ID, results=results) if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=API_PORT)