From 21d7ab4a2c6b475d570a1680cd2693ebbf778174 Mon Sep 17 00:00:00 2001
From: llm
Date: Fri, 28 Nov 2025 18:00:48 +0100
Subject: [PATCH] ai-model.py with LLMs and embedding models

---
 .../share/pytorch_pod/python-apps/ai-model.py | 601 ++++++++++++++++++
 1 file changed, 601 insertions(+)
 create mode 100755 .local/share/pytorch_pod/python-apps/ai-model.py

diff --git a/.local/share/pytorch_pod/python-apps/ai-model.py b/.local/share/pytorch_pod/python-apps/ai-model.py
new file mode 100755
index 0000000..52c806f
--- /dev/null
+++ b/.local/share/pytorch_pod/python-apps/ai-model.py
@@ -0,0 +1,601 @@
+#!/usr/bin/env python
+import base64
+import gc
+import io
+import os
+import threading
+import time
+import uuid
+from typing import List, Optional, Union, Dict, Any, Literal
+
+import torch
+from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor
+from fastapi import FastAPI, HTTPException, Request
+from PIL import Image, ImageFile
+from pydantic import BaseModel, Field
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    AutoProcessor,
+    AutoModel,
+)
+from transformers.utils.import_utils import is_flash_attn_2_available
+
+# -----------------------------------------------------------------------------
+# Configuration
+# -----------------------------------------------------------------------------
+
+# Embedding Models
+MODEL_ID_NOMIC = "nomic-ai/colnomic-embed-multimodal-7b"
+MODEL_ID_EVO_7B = "ApsaraStackMaaS/EvoQwen2.5-VL-Retriever-7B-v1"
+
+# Generation Models
+MODEL_ID_QWEN3_VL_8B_INSTRUCT = "Qwen/Qwen3-VL-8B-Instruct"
+MODEL_ID_QWEN3_VL_8B_INSTRUCT_FP8 = "Qwen/Qwen3-VL-8B-Instruct-FP8"
+MODEL_ID_QWEN3_VL_8B_THINKING = "Qwen/Qwen3-VL-8B-Thinking"
+MODEL_ID_QWEN3_VL_8B_THINKING_FP8 = "Qwen/Qwen3-VL-8B-Thinking-FP8"
+MODEL_ID_GPT_OSS_20B = "openai/gpt-oss-20b"
+
+ALLOWED_EMBEDDING_MODELS = {MODEL_ID_NOMIC, MODEL_ID_EVO_7B}
+ALLOWED_GENERATION_MODELS = {
+    MODEL_ID_QWEN3_VL_8B_INSTRUCT,
+    MODEL_ID_QWEN3_VL_8B_INSTRUCT_FP8,
+    MODEL_ID_QWEN3_VL_8B_THINKING,
+    MODEL_ID_QWEN3_VL_8B_THINKING_FP8,
+    MODEL_ID_GPT_OSS_20B,
+}
+
+ALLOWED_MODEL_IDS = ALLOWED_EMBEDDING_MODELS | ALLOWED_GENERATION_MODELS
+
+# Default selected model (must be one of ALLOWED_MODEL_IDS).
+# If the env value is not allowed, fall back to MODEL_ID_NOMIC.
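+# Note: nothing is loaded at startup; models are loaded lazily by the first
+# request that names them, so DEFAULT_MODEL_ID only records the configured
+# preference. Illustrative override: HF_MODEL_ID=Qwen/Qwen3-VL-8B-Instruct.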
+ENV_DEFAULT_MODEL = os.environ.get("HF_MODEL_ID", MODEL_ID_NOMIC)
+DEFAULT_MODEL_ID = (
+    ENV_DEFAULT_MODEL
+    if ENV_DEFAULT_MODEL in ALLOWED_MODEL_IDS
+    else MODEL_ID_NOMIC
+)
+
+HF_MODEL_URL = os.environ.get("HF_MODEL_URL")  # optional informational field
+API_PORT = int(
+    os.environ.get("PYTORCH_CONTAINER_PORT", os.environ.get("PORT", "8000"))
+)
+
+# Limits (env-overridable)
+MAX_TEXTS_PER_REQUEST = int(os.environ.get("TEXT_MAX_ITEMS", "32"))
+MAX_IMAGES_PER_REQUEST = int(os.environ.get("IMAGE_MAX_ITEMS", "8"))
+MAX_IMAGE_BASE64_BYTES = int(
+    os.environ.get("IMAGE_MAX_BASE64_BYTES", str(25 * 1024 * 1024))
+)  # 25MB per image b64 (approx)
+MAX_IMAGE_PIXELS = int(
+    os.environ.get("IMAGE_MAX_PIXELS", str(30_000_000))
+)  # ~30MP safety
+
+# PIL settings: tolerate truncated uploads and cap the decoded pixel count so
+# oversized images cannot exhaust memory.
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+Image.MAX_IMAGE_PIXELS = MAX_IMAGE_PIXELS
+
+# -----------------------------------------------------------------------------
+# App + Global State
+# -----------------------------------------------------------------------------
+
+app = FastAPI(title="AI Model Service")
+
+_model_lock = threading.RLock()
+
+# Unified model storage
+_model: Optional[torch.nn.Module] = None
+# Can be ColQwen2_5_Processor, AutoTokenizer, or AutoProcessor
+_processor: Optional[Any] = None
+_loaded_model_id: Optional[str] = None
+_loaded_model_type: Optional[str] = None  # "embedding" or "generation"
+
+# For reporting
+_dtype_str: Optional[str] = None
+_device_str: str = "cuda:0"
+
+# -----------------------------------------------------------------------------
+# Pydantic Models (OpenAI Compatible)
+# -----------------------------------------------------------------------------
+
+
+class ModelCard(BaseModel):
+    id: str
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    owned_by: str = "system"
+
+
+class ModelList(BaseModel):
+    object: str = "list"
+    data: List[ModelCard]
+
+
+class ChatMessage(BaseModel):
+    role: str
+    content: Union[str, List[Dict[str, Any]]]  # string or multimodal list
+    name: Optional[str] = None
+
+
+class ChatCompletionRequest(BaseModel):
+    model: str
+    messages: List[ChatMessage]
+    temperature: Optional[float] = 1.0
+    top_p: Optional[float] = 1.0
+    n: Optional[int] = 1
+    stream: Optional[bool] = False
+    stop: Optional[Union[str, List[str]]] = None
+    max_tokens: Optional[int] = None
+    presence_penalty: Optional[float] = 0.0
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    user: Optional[str] = None
+
+
+class ChatChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    finish_reason: Optional[str] = None
+
+
+class Usage(BaseModel):
+    prompt_tokens: int
+    completion_tokens: int
+    total_tokens: int
+
+
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: str = "chat.completion"
+    created: int
+    model: str
+    choices: List[ChatChoice]
+    usage: Optional[Usage] = None
+
+
+class EmbeddingRequest(BaseModel):
+    # OpenAI supports various inputs
+    input: Union[str, List[str], List[int], List[List[int]]]
+    model: str
+    encoding_format: Optional[str] = "float"  # float or base64
+    user: Optional[str] = None
+
+
+class EmbeddingObject(BaseModel):
+    object: str = "embedding"
+    index: int
+    # OpenAI embeddings are 1D vectors, but ColQwen is multi-vector.
+    # We return the raw multi-vector as the "embedding" field,
+    # which implies it's a list of lists.
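+    # For a ColQwen-style model each entry is therefore a matrix with one
+    # projection vector per input token (typically 128 dimensions per vector),
+    # rather than a single pooled vector.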
+    embedding: Any
+
+
+class EmbeddingResponse(BaseModel):
+    object: str = "list"
+    data: List[EmbeddingObject]
+    model: str
+    usage: Usage
+
+
+# -----------------------------------------------------------------------------
+# Helpers
+# -----------------------------------------------------------------------------
+
+
+def _torch_dtype_str(dtype: torch.dtype) -> str:
+    if dtype == torch.bfloat16:
+        return "bfloat16"
+    if dtype == torch.float16:
+        return "float16"
+    return str(dtype)
+
+
+def _hard_requirements_check():
+    if not torch.cuda.is_available():
+        raise RuntimeError(
+            "CUDA is not available; a CUDA-capable GPU is required."
+        )
+    if not is_flash_attn_2_available():
+        # FlashAttention-2 is preferred but not a hard requirement here:
+        # generation models fall back to SDPA below, and the ColQwen embedding
+        # path requests flash_attention_2 explicitly, so it fails at load time
+        # if the package is missing.
+        pass
+
+
+def _pick_dtype() -> torch.dtype:
+    return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+
+
+def _unload_model_locked():
+    global _model, _processor, _loaded_model_id, _loaded_model_type, _dtype_str
+    # Assumes caller holds _model_lock
+    _model = None
+    _processor = None
+    _loaded_model_id = None
+    _loaded_model_type = None
+    _dtype_str = None
+    gc.collect()
+    if torch.cuda.is_available():
+        try:
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+        except Exception:
+            pass
+
+
+def _load_model_locked(model_id: str):
+    global _model, _processor, _loaded_model_id, _loaded_model_type
+    global _dtype_str, _device_str
+
+    _hard_requirements_check()
+    dtype = _pick_dtype()
+    device_map = "cuda:0"
+    attn_impl = (
+        "flash_attention_2" if is_flash_attn_2_available() else "sdpa"
+    )
+
+    if model_id in ALLOWED_EMBEDDING_MODELS:
+        # Load Embedding Model
+        model = ColQwen2_5.from_pretrained(
+            model_id,
+            torch_dtype=dtype,
+            device_map=device_map,
+            attn_implementation="flash_attention_2",  # ColQwen mandates FA2
+        ).eval()
+        processor = ColQwen2_5_Processor.from_pretrained(model_id)
+        _loaded_model_type = "embedding"
+
+    elif model_id in ALLOWED_GENERATION_MODELS:
+        # Load Generation Model
+        # Check if it is a VL model
+        if "VL" in model_id:
+            # These are vision-language checkpoints. For broad compatibility we
+            # load them via AutoModelForCausalLM, which is enough for text-only
+            # chat; true image input needs the model's dedicated
+            # vision-to-text class in transformers.
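+            # Sketch (assumption, not verified against the installed
+            # transformers build): if this transformers version ships
+            # AutoModelForImageTextToText with support for these checkpoints,
+            # image-capable loading would look roughly like:
+            #
+            #   from transformers import AutoModelForImageTextToText
+            #   model = AutoModelForImageTextToText.from_pretrained(
+            #       model_id,
+            #       torch_dtype=dtype,
+            #       device_map=device_map,
+            #       attn_implementation=attn_impl,
+            #   ).eval()
+            #
+            # Until then, the AutoModelForCausalLM path below is text-only.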
+
+            model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                torch_dtype=dtype,
+                device_map=device_map,
+                attn_implementation=attn_impl,
+                trust_remote_code=True,  # Often needed for new architectures
+            ).eval()
+
+            # Processor/Tokenizer
+            try:
+                processor = AutoProcessor.from_pretrained(
+                    model_id, trust_remote_code=True
+                )
+            except Exception:
+                processor = AutoTokenizer.from_pretrained(
+                    model_id, trust_remote_code=True
+                )
+
+            _loaded_model_type = "generation"
+        else:
+            # Standard Text Model (GPT-OSS)
+            model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                torch_dtype=dtype,
+                device_map=device_map,
+                attn_implementation=attn_impl,
+                trust_remote_code=True,
+            ).eval()
+            processor = AutoTokenizer.from_pretrained(
+                model_id, trust_remote_code=True
+            )
+            _loaded_model_type = "generation"
+
+    else:
+        raise ValueError(f"Unknown model type for {model_id}")
+
+    _model = model
+    _processor = processor
+    _loaded_model_id = model_id
+    _dtype_str = _torch_dtype_str(dtype)
+    _device_str = device_map
+
+
+def _ensure_model_loaded(model_id: str):
+    with _model_lock:
+        if (
+            _model is not None
+            and _processor is not None
+            and _loaded_model_id == model_id
+        ):
+            return _model, _processor, _loaded_model_type
+
+        _unload_model_locked()
+        _load_model_locked(model_id)
+        return _model, _processor, _loaded_model_type
+
+
+def _current_vram_info():
+    if not torch.cuda.is_available():
+        return {"free": None, "total": None, "used": None}
+    free, total = torch.cuda.mem_get_info(0)
+    used = total - free
+    return {"free": free, "total": total, "used": used}
+
+
+def _decode_base64_image(b64_data: str):
+    approx_bytes = int(len(b64_data) * 0.75)
+    if approx_bytes > MAX_IMAGE_BASE64_BYTES:
+        raise ValueError(
+            f"Image exceeds max base64 size of {MAX_IMAGE_BASE64_BYTES} bytes"
+        )
+    try:
+        raw = base64.b64decode(b64_data, validate=True)
+        img = Image.open(io.BytesIO(raw))
+        if img.mode != "RGB":
+            img = img.convert("RGB")
+        img.load()
+        return img
+    except Exception as e:
+        raise ValueError(f"Unable to decode image: {e}")
+
+
+def _extract_embeddings(outputs) -> torch.Tensor:
+    if isinstance(outputs, torch.Tensor):
+        embeddings = outputs
+    elif hasattr(outputs, "last_hidden_state"):
+        embeddings = outputs.last_hidden_state
+    else:
+        raise RuntimeError(f"Unexpected model output type: {type(outputs)}")
+
+    if embeddings.dim() == 2:
+        embeddings = embeddings.unsqueeze(0)
+    elif embeddings.dim() != 3:
+        raise RuntimeError(
+            f"Unexpected embedding shape: {tuple(embeddings.shape)}"
+        )
+    return embeddings
+
+
+# -----------------------------------------------------------------------------
+# Endpoints
+# -----------------------------------------------------------------------------
+
+
+@app.get("/health")
+def health():
+    cuda_ok = bool(torch.cuda.is_available())
+    flash_ok = bool(is_flash_attn_2_available())
+
+    info = {
+        "status": "ok",
+        "loaded_model_id": _loaded_model_id,
+        "cuda_available": cuda_ok,
+        "flash_attn_2_available": flash_ok,
+        "vram_bytes": _current_vram_info(),
+    }
+
+    if not cuda_ok:
+        info["status"] = "error"
+        info["error"] = "CUDA is not available."
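+        # Raising turns the health probe into an HTTP 500 so container
+        # orchestration can mark the pod unhealthy; the detail payload carries
+        # the diagnostic fields collected above.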
+        raise HTTPException(status_code=500, detail=info)
+
+    return info
+
+
+@app.get("/v1/models", response_model=ModelList)
+def list_models():
+    models = []
+    for mid in ALLOWED_MODEL_IDS:
+        models.append(ModelCard(id=mid))
+    return ModelList(data=models)
+
+
+@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
+async def chat_completions(request: ChatCompletionRequest):
+    model_id = request.model
+    if model_id not in ALLOWED_GENERATION_MODELS:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Model {model_id} not supported or not a generation model."
+        )
+
+    with _model_lock:
+        try:
+            model, processor, mtype = _ensure_model_loaded(model_id)
+        except Exception as e:
+            raise HTTPException(
+                status_code=500, detail=f"Failed to load model: {e}"
+            )
+
+    # NOTE: the lock above only guards (re)loading; generation below runs
+    # outside it, so a concurrent request that swaps to a different model can
+    # race with an in-flight completion.
+    if mtype != "generation":
+        raise HTTPException(
+            status_code=500,
+            detail=(f"Model loaded as {mtype} "
+                    "but accessed via chat completion.")
+        )
+
+    # Prepare the prompt: apply the chat template when the processor provides
+    # one, otherwise fall back to naive message concatenation.
+    # images = []
+
+    # Check if we have apply_chat_template support (most modern tokenizers do)
+    has_template = hasattr(processor, "apply_chat_template")
+
+    if has_template:
+        # processor can be Tokenizer or Processor.
+        # If it is a Processor (for VL), it might expect specific format.
+        # We'll try passing the messages dict directly.
+        try:
+            # Convert Pydantic messages to dict
+            msgs = [
+                m.model_dump(exclude_none=True) for m in request.messages
+            ]
+
+            # Check for images in messages if VL model
+            # TODO: Extract base64 images from content if present
+
+            text_input = processor.apply_chat_template(
+                msgs, tokenize=False, add_generation_prompt=True
+            )
+        except Exception as e:
+            # Fallback to manual concatenation
+            print(f"Template application failed: {e}")
+            text_input = ""
+            for m in request.messages:
+                content = m.content
+                if isinstance(content, list):
+                    # Handle multimodal content list - extract text
+                    content = " ".join(
+                        [
+                            c.get("text", "")
+                            for c in content
+                            if c.get("type") == "text"
+                        ]
+                    )
+                text_input += f"<|im_start|>{m.role}\n"
+                text_input += f"{content}<|im_end|>\n"
+            text_input += "<|im_start|>assistant\n"
+    else:
+        text_input = ""
+        for m in request.messages:
+            content = m.content
+            if isinstance(content, list):
+                content = " ".join(
+                    [
+                        c.get("text", "")
+                        for c in content
+                        if c.get("type") == "text"
+                    ]
+                )
+            text_input += f"{m.role}: {content}\n"
+        text_input += "assistant: "
+
+    # Tokenize
+    inputs = None
+    if (
+        hasattr(processor, "process_images")
+        or "Processor" in processor.__class__.__name__
+    ):
+        # It's likely a VL processor.
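+        # Only text reaches the processor here: base64 images embedded in the
+        # message content are not extracted yet (see the TODO above), so VL
+        # models currently answer from the text parts of the messages alone.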
+        inputs = processor(
+            text=[text_input], return_tensors="pt", padding=True
+        ).to(_device_str)
+    else:
+        # Standard tokenizer
+        inputs = processor(text_input, return_tensors="pt").to(_device_str)
+
+    # Generate. Guard against temperature/top_p being null or 0: sampling
+    # parameters are only forwarded when sampling is actually enabled.
+    temperature = (
+        request.temperature if request.temperature is not None else 1.0
+    )
+    do_sample = temperature > 0
+    gen_kwargs: Dict[str, Any] = {
+        "max_new_tokens": request.max_tokens or 512,
+        "do_sample": do_sample,
+    }
+    if do_sample:
+        gen_kwargs["temperature"] = temperature
+        if request.top_p is not None:
+            gen_kwargs["top_p"] = request.top_p
+
+    with torch.inference_mode():
+        generated_ids = model.generate(**inputs, **gen_kwargs)
+
+    # Decode
+    input_len = inputs.input_ids.shape[1]
+    generated_ids = generated_ids[:, input_len:]
+    output_text = processor.decode(
+        generated_ids[0], skip_special_tokens=True
+    )
+
+    # Usage
+    usage = Usage(
+        prompt_tokens=input_len,
+        completion_tokens=generated_ids.shape[1],
+        total_tokens=input_len + generated_ids.shape[1],
+    )
+
+    choice = ChatChoice(
+        index=0,
+        message=ChatMessage(role="assistant", content=output_text),
+        finish_reason="stop",
+    )
+
+    return ChatCompletionResponse(
+        id=str(uuid.uuid4()),
+        created=int(time.time()),
+        model=model_id,
+        choices=[choice],
+        usage=usage,
+    )
+
+
+@app.post("/v1/embeddings", response_model=EmbeddingResponse)
+def create_embeddings(request: EmbeddingRequest):
+    model_id = request.model
+    # We check if model_id is allowed.
+    if model_id not in ALLOWED_EMBEDDING_MODELS:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Model {model_id} not supported or not an embedding model."
+        )
+
+    with _model_lock:
+        try:
+            model, processor, mtype = _ensure_model_loaded(model_id)
+        except Exception as e:
+            raise HTTPException(
+                status_code=500, detail=f"Failed to load model: {e}"
+            )
+
+    if mtype != "embedding":
+        raise HTTPException(
+            status_code=500, detail="Model is not an embedding model."
+        )
+
+    # Handle input
+    texts = request.input
+    if isinstance(texts, str):
+        texts = [texts]
+    # Token ID input (ints or lists of ints) can't be handled by the ColQwen
+    # processor, which expects raw text.
+    if texts and not all(isinstance(t, str) for t in texts):
+        raise HTTPException(
+            status_code=400,
+            detail="Token IDs input not supported, please provide text.",
+        )
+    if len(texts) > MAX_TEXTS_PER_REQUEST:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Too many inputs; max {MAX_TEXTS_PER_REQUEST} per request.",
+        )
+
+    try:
+        with torch.inference_mode():
+            # ColQwen processor handles queries/docs. Assume queries.
+            batch = processor.process_queries(texts).to(_device_str)
+            outputs = model(**batch)
+            embeddings = _extract_embeddings(outputs)
+    except Exception as exc:
+        raise HTTPException(
+            status_code=500, detail=f"Failed to compute embeddings: {exc}"
+        )
+
+    embeddings = embeddings.detach().cpu().float().tolist()
+
+    data = []
+    # One multi-vector matrix per input; count its rows as prompt tokens.
+    token_count = sum(len(vecs) for vecs in embeddings)
+    for i, emb in enumerate(embeddings):
+        data.append(EmbeddingObject(index=i, embedding=emb))
+
+    return EmbeddingResponse(
+        data=data,
+        model=model_id,
+        usage=Usage(
+            prompt_tokens=token_count,
+            completion_tokens=0,
+            total_tokens=token_count,
+        ),
+    )
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=API_PORT)
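+
+# Example client call (illustrative; assumes the service is reachable on
+# localhost at the default port 8000 and that the `requests` package is
+# installed in the calling environment):
+#
+#   import requests
+#
+#   resp = requests.post(
+#       "http://localhost:8000/v1/embeddings",
+#       json={
+#           "model": "nomic-ai/colnomic-embed-multimodal-7b",
+#           "input": ["What does the attached invoice show?"],
+#       },
+#       timeout=600,
+#   )
+#   resp.raise_for_status()
+#   # One multi-vector (list of per-token vectors) per input string.
+#   vectors = resp.json()["data"][0]["embedding"]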