#!/usr/bin/env python
import base64
import gc
import io
import os
import threading
import time
import uuid
from typing import List, Optional, Union, Dict, Any, Literal

import torch
from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor
from fastapi import FastAPI, HTTPException, Request
from PIL import Image, ImageFile
from pydantic import BaseModel, Field
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoProcessor,
    AutoModel,
    AutoModelForVision2Seq,
)
from transformers.utils.import_utils import is_flash_attn_2_available

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------

# Embedding models
MODEL_ID_NOMIC = "nomic-ai/colnomic-embed-multimodal-7b"
MODEL_ID_EVO_7B = "ApsaraStackMaaS/EvoQwen2.5-VL-Retriever-7B-v1"

# Generation models
MODEL_ID_QWEN3_VL_8B_INSTRUCT = "Qwen/Qwen3-VL-8B-Instruct"
MODEL_ID_QWEN3_VL_8B_INSTRUCT_FP8 = "Qwen/Qwen3-VL-8B-Instruct-FP8"
MODEL_ID_QWEN3_VL_8B_THINKING = "Qwen/Qwen3-VL-8B-Thinking"
MODEL_ID_QWEN3_VL_8B_THINKING_FP8 = "Qwen/Qwen3-VL-8B-Thinking-FP8"
MODEL_ID_GPT_OSS_20B = "openai/gpt-oss-20b"

ALLOWED_EMBEDDING_MODELS = {MODEL_ID_NOMIC, MODEL_ID_EVO_7B}
ALLOWED_GENERATION_MODELS = {
    MODEL_ID_QWEN3_VL_8B_INSTRUCT,
    MODEL_ID_QWEN3_VL_8B_INSTRUCT_FP8,
    MODEL_ID_QWEN3_VL_8B_THINKING,
    MODEL_ID_QWEN3_VL_8B_THINKING_FP8,
    MODEL_ID_GPT_OSS_20B,
}
ALLOWED_MODEL_IDS = ALLOWED_EMBEDDING_MODELS | ALLOWED_GENERATION_MODELS

# Default selected model (must be one of ALLOWED_MODEL_IDS).
# If the env value is not an allowed model, fall back to MODEL_ID_NOMIC.
ENV_DEFAULT_MODEL = os.environ.get("HF_MODEL_ID", MODEL_ID_NOMIC)
DEFAULT_MODEL_ID = (
    ENV_DEFAULT_MODEL if ENV_DEFAULT_MODEL in ALLOWED_MODEL_IDS else MODEL_ID_NOMIC
)

HF_MODEL_URL = os.environ.get("HF_MODEL_URL")  # optional informational field
API_PORT = int(
    os.environ.get("PYTORCH_CONTAINER_PORT", os.environ.get("PORT", "8000"))
)

# Limits (env-overridable)
MAX_TEXTS_PER_REQUEST = int(os.environ.get("TEXT_MAX_ITEMS", "32"))
MAX_IMAGES_PER_REQUEST = int(os.environ.get("IMAGE_MAX_ITEMS", "8"))
MAX_IMAGE_BASE64_BYTES = int(
    os.environ.get("IMAGE_MAX_BASE64_BYTES", str(25 * 1024 * 1024))
)  # ~25 MB of base64 per image
MAX_IMAGE_PIXELS = int(
    os.environ.get("IMAGE_MAX_PIXELS", str(30_000_000))
)  # ~30 MP safety limit

# PIL safety for large images
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = MAX_IMAGE_PIXELS

# -----------------------------------------------------------------------------
# App + Global State
# -----------------------------------------------------------------------------
app = FastAPI(title="AI Model Service")

_model_lock = threading.RLock()

# Unified model storage
_model: Optional[torch.nn.Module] = None
# Can be ColQwen2_5_Processor, AutoTokenizer, or AutoProcessor
_processor: Optional[Any] = None
_loaded_model_id: Optional[str] = None
_loaded_model_type: Optional[str] = None  # "embedding" or "generation"

# For reporting
_dtype_str: Optional[str] = None
_device_str: str = "cuda:0"


# -----------------------------------------------------------------------------
# Pydantic Models (OpenAI Compatible)
# -----------------------------------------------------------------------------
class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = int(time.time())
    owned_by: str = "system"


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard]


class ChatMessage(BaseModel):
    role: str
    content: Union[str, List[Dict[str, Any]]]  # plain string or multimodal list
    name: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None


class ChatChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Optional[str] = None


class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatChoice]
    usage: Optional[Usage] = None


class EmbeddingRequest(BaseModel):
    # OpenAI accepts a string, a list of strings, or (lists of) token IDs.
    input: Union[str, List[str], List[int], List[List[int]]]
    model: str
    encoding_format: Optional[str] = "float"  # "float" or "base64"
    user: Optional[str] = None


class EmbeddingObject(BaseModel):
    object: str = "embedding"
    index: int
    # OpenAI embeddings are 1D vectors, but ColQwen models are multi-vector.
    # We return the raw multi-vector as the "embedding" field, so it is a
    # list of per-token vectors (a list of lists) rather than a flat list.
    embedding: Any


class EmbeddingResponse(BaseModel):
    object: str = "list"
    data: List[EmbeddingObject]
    model: str
    usage: Usage


class PreloadRequest(BaseModel):
    model: str


# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------
def _torch_dtype_str(dtype: torch.dtype) -> str:
    if dtype == torch.bfloat16:
        return "bfloat16"
    if dtype == torch.float16:
        return "float16"
    return str(dtype)


def _hard_requirements_check():
    if not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA is not available; a CUDA-capable GPU is required."
        )
    if not is_flash_attn_2_available():
        # FlashAttention 2 is preferred but not treated as a hard requirement
        # here: generation models fall back to SDPA in _load_model_locked, and
        # the ColQwen embedding path (which does require FA2) will fail loudly
        # on its own when loading.
        pass


def _pick_dtype() -> torch.dtype:
    return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16


def _unload_model_locked():
    global _model, _processor, _loaded_model_id, _loaded_model_type, _dtype_str
    # Assumes the caller holds _model_lock.
    _model = None
    _processor = None
    _loaded_model_id = None
    _loaded_model_type = None
    _dtype_str = None
    gc.collect()
    if torch.cuda.is_available():
        try:
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        except Exception:
            pass


def _load_model_locked(model_id: str):
    global _model, _processor, _loaded_model_id, _loaded_model_type
    global _dtype_str, _device_str

    _hard_requirements_check()
    dtype = _pick_dtype()
    device_map = "cuda:0"
    attn_impl = "flash_attention_2" if is_flash_attn_2_available() else "sdpa"

    if model_id in ALLOWED_EMBEDDING_MODELS:
        # Load embedding model
        model = ColQwen2_5.from_pretrained(
            model_id,
            torch_dtype=dtype,
            device_map=device_map,
            attn_implementation="flash_attention_2",  # ColQwen mandates FA2
        ).eval()
        processor = ColQwen2_5_Processor.from_pretrained(model_id)
        _loaded_model_type = "embedding"
    elif model_id in ALLOWED_GENERATION_MODELS:
        # Load generation model
        if "VL" in model_id:
            # Use AutoModelForVision2Seq for VL models; the Qwen3VLConfig
            # configuration class requires Vision2Seq or AutoModel.
            try:
                print(f"Loading {model_id} with AutoModelForVision2Seq...")
                model = AutoModelForVision2Seq.from_pretrained(
                    model_id,
                    torch_dtype=dtype,
                    device_map=device_map,
                    attn_implementation=attn_impl,
                    trust_remote_code=True,
                    low_cpu_mem_usage=True,
                ).eval()
            except Exception as e:
                print(f"Vision2Seq failed: {e}. Falling back to AutoModel...")
                # Fall back to the generic AutoModel if Vision2Seq fails.
                model = AutoModel.from_pretrained(
                    model_id,
                    torch_dtype=dtype,
                    device_map=device_map,
                    attn_implementation=attn_impl,
                    trust_remote_code=True,
                    low_cpu_mem_usage=True,
                ).eval()
            # Processor/tokenizer
            try:
                processor = AutoProcessor.from_pretrained(
                    model_id, trust_remote_code=True
                )
            except Exception:
                processor = AutoTokenizer.from_pretrained(
                    model_id, trust_remote_code=True
                )
            _loaded_model_type = "generation"
        else:
            # Standard text model (GPT-OSS)
            print(f"Loading {model_id} with AutoModelForCausalLM...")
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=dtype,
                device_map=device_map,
                attn_implementation=attn_impl,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            ).eval()
            processor = AutoTokenizer.from_pretrained(
                model_id, trust_remote_code=True
            )
            _loaded_model_type = "generation"
    else:
        raise ValueError(f"Unknown model type for {model_id}")

    _model = model
    _processor = processor
    _loaded_model_id = model_id
    _dtype_str = _torch_dtype_str(dtype)
    _device_str = device_map


def _ensure_model_loaded(model_id: str):
    with _model_lock:
        if (
            _model is not None
            and _processor is not None
            and _loaded_model_id == model_id
        ):
            return _model, _processor, _loaded_model_type
        _unload_model_locked()
        _load_model_locked(model_id)
        return _model, _processor, _loaded_model_type


def _current_vram_info():
    if not torch.cuda.is_available():
        return {"free": None, "total": None, "used": None}
    free, total = torch.cuda.mem_get_info(0)
    used = total - free
    return {"free": free, "total": total, "used": used}


def _decode_base64_image(b64_data: str):
    approx_bytes = int(len(b64_data) * 0.75)
    if approx_bytes > MAX_IMAGE_BASE64_BYTES:
        raise ValueError(
            f"Image exceeds max base64 size of {MAX_IMAGE_BASE64_BYTES} bytes"
        )
    try:
        raw = base64.b64decode(b64_data, validate=True)
        img = Image.open(io.BytesIO(raw))
        if img.mode != "RGB":
            img = img.convert("RGB")
        img.load()
        return img
    except Exception as e:
        raise ValueError(f"Unable to decode image: {e}")


def _extract_embeddings(outputs) -> torch.Tensor:
    if isinstance(outputs, torch.Tensor):
        embeddings = outputs
    elif hasattr(outputs, "last_hidden_state"):
        embeddings = outputs.last_hidden_state
    else:
        raise RuntimeError(f"Unexpected model output type: {type(outputs)}")
    if embeddings.dim() == 2:
        embeddings = embeddings.unsqueeze(0)
    elif embeddings.dim() != 3:
        raise RuntimeError(
            f"Unexpected embedding shape: {tuple(embeddings.shape)}"
        )
    return embeddings


# -----------------------------------------------------------------------------
# Endpoints
# -----------------------------------------------------------------------------
@app.post("/preload")
def preload_model(request: PreloadRequest):
    model_id = request.model.strip()
    if model_id not in ALLOWED_MODEL_IDS:
        raise HTTPException(
            status_code=400,
            detail=f"Model {model_id} not in allowed models.",
        )
    with _model_lock:
        try:
            _ensure_model_loaded(model_id)
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Failed to load model: {e}"
            )
    return {
        "status": "ok",
        "loaded_model_id": _loaded_model_id,
        "vram_bytes": _current_vram_info(),
    }


@app.post("/unload")
def unload_model():
    with _model_lock:
        _unload_model_locked()
    return {
        "status": "ok",
        "vram_bytes": _current_vram_info(),
    }


@app.get("/health")
def health():
    cuda_ok = bool(torch.cuda.is_available())
    flash_ok = bool(is_flash_attn_2_available())
    info = {
        "status": "ok",
        "loaded_model_id": _loaded_model_id,
        "cuda_available": cuda_ok,
        "flash_attn_2_available": flash_ok,
        "vram_bytes": _current_vram_info(),
    }
    if not cuda_ok:
        info["status"] = "error"
        info["error"] = "CUDA is not available."
        raise HTTPException(status_code=500, detail=info)
    return info


@app.get("/v1/models", response_model=ModelList)
def list_models():
    models = []
    for mid in ALLOWED_MODEL_IDS:
        models.append(ModelCard(id=mid))
    return ModelList(data=models)


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
def chat_completions(request: ChatCompletionRequest):
    # Synchronous handler on purpose: generation is blocking, so FastAPI runs
    # it in a worker thread instead of stalling the event loop.
    model_id = request.model
    if model_id not in ALLOWED_GENERATION_MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Model {model_id} not supported or not a generation model.",
        )
    if request.stream:
        raise HTTPException(
            status_code=400, detail="Streaming responses are not supported."
        )

    with _model_lock:
        try:
            model, processor, mtype = _ensure_model_loaded(model_id)
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Failed to load model: {e}"
            )
    if mtype != "generation":
        raise HTTPException(
            status_code=500,
            detail=f"Model loaded as {mtype} but accessed via chat completion.",
        )

    # Prepare input: apply the chat template when available (most modern
    # tokenizers and processors have one), otherwise fall back to naive
    # message concatenation.
    has_template = hasattr(processor, "apply_chat_template")
    if has_template:
        # The processor can be a tokenizer or a (VL) processor. A VL processor
        # may expect a specific multimodal format; we try passing the message
        # dicts directly.
        try:
            # Convert Pydantic messages to plain dicts.
            msgs = [
                m.model_dump(exclude_none=True) for m in request.messages
            ]
            # TODO: Extract base64 images from the content list for VL models;
            # currently only the text portions are used.
            text_input = processor.apply_chat_template(
                msgs, tokenize=False, add_generation_prompt=True
            )
        except Exception as e:
            # Fall back to manual ChatML-style concatenation.
            print(f"Template application failed: {e}")
            text_input = ""
            for m in request.messages:
                content = m.content
                if isinstance(content, list):
                    # Multimodal content list: keep only the text parts.
                    content = " ".join(
                        c.get("text", "")
                        for c in content
                        if c.get("type") == "text"
                    )
                text_input += f"<|im_start|>{m.role}\n"
                text_input += f"{content}<|im_end|>\n"
            text_input += "<|im_start|>assistant\n"
    else:
        text_input = ""
        for m in request.messages:
            content = m.content
            if isinstance(content, list):
                content = " ".join(
                    c.get("text", "")
                    for c in content
                    if c.get("type") == "text"
                )
            text_input += f"{m.role}: {content}\n"
        text_input += "assistant: "

    # Tokenize
    if (
        hasattr(processor, "process_images")
        or "Processor" in processor.__class__.__name__
    ):
        # Likely a VL processor.
        inputs = processor(
            text=[text_input], return_tensors="pt", padding=True
        ).to(_device_str)
    else:
        # Standard tokenizer
        inputs = processor(text_input, return_tensors="pt").to(_device_str)

    # Generate. Guard against explicit nulls in the request body and only pass
    # sampling parameters when sampling is actually enabled.
    temperature = request.temperature if request.temperature is not None else 1.0
    top_p = request.top_p if request.top_p is not None else 1.0
    do_sample = temperature > 0
    gen_kwargs: Dict[str, Any] = {
        "max_new_tokens": request.max_tokens or 512,
        "do_sample": do_sample,
    }
    if do_sample:
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = top_p
    with torch.inference_mode():
        generated_ids = model.generate(**inputs, **gen_kwargs)

    # Decode only the newly generated tokens.
    input_len = inputs.input_ids.shape[1]
    generated_ids = generated_ids[:, input_len:]
    output_text = processor.decode(
        generated_ids[0], skip_special_tokens=True
    )

    # Usage
    usage = Usage(
        prompt_tokens=input_len,
        completion_tokens=generated_ids.shape[1],
        total_tokens=input_len + generated_ids.shape[1],
    )

    choice = ChatChoice(
        index=0,
        message=ChatMessage(role="assistant", content=output_text),
        finish_reason="stop",
    )
    return ChatCompletionResponse(
        id=str(uuid.uuid4()),
        created=int(time.time()),
        model=model_id,
        choices=[choice],
        usage=usage,
    )


@app.post("/v1/embeddings", response_model=EmbeddingResponse)
def create_embeddings(request: EmbeddingRequest):
    model_id = request.model
    if model_id not in ALLOWED_EMBEDDING_MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Model {model_id} not supported or not an embedding model.",
        )

    with _model_lock:
        try:
            model, processor, mtype = _ensure_model_loaded(model_id)
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Failed to load model: {e}"
            )
    if mtype != "embedding":
        raise HTTPException(
            status_code=500, detail="Model is not an embedding model."
        )

    # Handle input
    texts = request.input
    if isinstance(texts, str):
        texts = [texts]
    # Lists of token IDs cannot be handled by the ColQwen processor.
    if texts and isinstance(texts[0], int):
        raise HTTPException(
            status_code=400,
            detail="Token IDs input not supported, please provide text.",
        )
    # Enforce the configured per-request text limit.
    if len(texts) > MAX_TEXTS_PER_REQUEST:
        raise HTTPException(
            status_code=400,
            detail=f"Too many inputs; max {MAX_TEXTS_PER_REQUEST} per request.",
        )

    try:
        with torch.inference_mode():
            # The ColQwen processor distinguishes queries from documents;
            # plain text inputs are treated as queries here.
            batch = processor.process_queries(texts).to(_device_str)
            outputs = model(**batch)
            embeddings = _extract_embeddings(outputs)
    except Exception as exc:
        raise HTTPException(
            status_code=500, detail=f"Failed to compute embeddings: {exc}"
        )

    embeddings = embeddings.detach().cpu().float().tolist()
    data = []
    # Token usage is not tracked for multi-vector embeddings; report zeros.
    token_count = 0
    for i, emb in enumerate(embeddings):
        data.append(EmbeddingObject(index=i, embedding=emb))

    return EmbeddingResponse(
        data=data,
        model=model_id,
        usage=Usage(
            prompt_tokens=token_count,
            completion_tokens=0,
            total_tokens=token_count,
        ),
    )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=API_PORT)
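

# -----------------------------------------------------------------------------
# Example requests (illustrative sketch only; nothing below is executed).
# These assume a local deployment on the default port 8000 and use model IDs
# defined above; adjust host, port, and model to match your setup.
# -----------------------------------------------------------------------------
#
# Preload the default multi-vector embedding model:
#   curl -X POST http://localhost:8000/preload \
#     -H "Content-Type: application/json" \
#     -d '{"model": "nomic-ai/colnomic-embed-multimodal-7b"}'
#
# Request embeddings (each "embedding" in the response is a list of per-token
# vectors, not a single flat vector):
#   curl -X POST http://localhost:8000/v1/embeddings \
#     -H "Content-Type: application/json" \
#     -d '{"model": "nomic-ai/colnomic-embed-multimodal-7b", "input": "hello world"}'
#
# Chat completion against a generation model:
#   curl -X POST http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "Qwen/Qwen3-VL-8B-Instruct", "messages": [{"role": "user", "content": "Hello!"}]}'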