diff --git a/python/timing_tomoro-colqwen3-embed-4b.py b/python/timing_tomoro-colqwen3-embed-4b.py
new file mode 100644
index 0000000..e81717b
--- /dev/null
+++ b/python/timing_tomoro-colqwen3-embed-4b.py
@@ -0,0 +1,91 @@
+import torch
+from transformers import AutoModel, AutoProcessor
+from PIL import Image, UnidentifiedImageError
+import requests
+from io import BytesIO
+import time
+
+# Configuration
+MODEL_ID = "TomoroAI/tomoro-colqwen3-embed-4b"
+DTYPE = torch.bfloat16
+# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+# DEVICE = "cuda"
+DEVICE = "cpu"
+
+# Load Model & Processor
+processor = AutoProcessor.from_pretrained(
+    MODEL_ID,
+    trust_remote_code=True,
+    max_num_visual_tokens=1280,
+)
+model = AutoModel.from_pretrained(
+    MODEL_ID,
+    dtype=DTYPE,
+    attn_implementation="flash_attention_2" if DEVICE == "cuda" else "sdpa",  # flash-attn requires CUDA
+    trust_remote_code=True,
+    device_map=DEVICE,
+).eval()
+
+# Sample Data
+queries = [
+    "Retrieve the city of Singapore",
+    "Retrieve the city of Beijing",
+    "Retrieve the city of London",
+]
+docs = [
+    "https://upload.wikimedia.org/wikipedia/commons/2/27/Singapore_skyline_2022.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/6/61/Beijing_skyline_at_night.JPG",
+    "https://upload.wikimedia.org/wikipedia/commons/4/49/London_skyline.jpg",
+]
+
+def load_image(url: str) -> Image.Image:
+    # Some CDNs (e.g., Wikimedia) expect a browser-like User-Agent and may return 403 otherwise.
+    for headers in ({}, {"User-Agent": "Mozilla/5.0 (compatible; ColQwen3-demo/1.0)"}):
+        resp = requests.get(url, headers=headers, timeout=10)
+        if resp.status_code == 403:
+            continue
+        resp.raise_for_status()
+        try:
+            return Image.open(BytesIO(resp.content)).convert("RGB")
+        except UnidentifiedImageError as e:
+            raise RuntimeError(f"Failed to decode image from {url}") from e
+    raise RuntimeError(f"Could not fetch image (HTTP 403) from {url}; try downloading it locally and loading from a file path.")
+
+# Helper Functions
+def encode_queries(texts, batch_size=8):
+    outputs = []
+    for start in range(0, len(texts), batch_size):
+        batch = processor.process_texts(texts=texts[start : start + batch_size])
+        batch = {k: v.to(DEVICE) for k, v in batch.items()}
+        with torch.inference_mode():
+            out = model(**batch)
+            vecs = out.embeddings.to(torch.bfloat16).cpu()
+        outputs.extend(vecs)
+    return outputs
+
+def encode_docs(urls):
+    outputs = []
+    for idx, url in enumerate(urls):
+        img = load_image(url)
+        features = processor.process_images(images=[img])
+        features = {k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v for k, v in features.items()}
+        start_ns = time.perf_counter_ns()
+        with torch.inference_mode():
+            out = model(**features)
+            vecs = out.embeddings.to(torch.bfloat16).cpu()
+        end_ns = time.perf_counter_ns()
+        # The first image serves as a warm-up; report timings only for the 2nd and 3rd.
+        if idx in (1, 2):
+            duration_ns = end_ns - start_ns
+            print(f"Duration encode_docs image {idx + 1}: {duration_ns:,} ns ({duration_ns / 1e6:.1f} ms)")
+        outputs.extend(vecs)
+    return outputs
+
+# Execution
+query_embeddings = encode_queries(queries)
+
+doc_embeddings = encode_docs(docs)
+
+# MaxSim Scoring
+scores = processor.score_multi_vector(query_embeddings, doc_embeddings)
+print(scores)
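
Note on scoring: the final step calls the processor's `score_multi_vector`, which (per the `# MaxSim Scoring` comment) computes ColBERT-style late-interaction scores. As a rough reference for interpreting the printed matrix, here is a minimal sketch of MaxSim for a single query/document pair, assuming each embedding is a `(num_tokens, dim)` tensor of token vectors; the `maxsim_score` helper is illustrative and not part of the patch:

```python
import torch

def maxsim_score(query_vecs: torch.Tensor, doc_vecs: torch.Tensor) -> torch.Tensor:
    # query_vecs: (num_query_tokens, dim); doc_vecs: (num_doc_tokens, dim)
    sim = query_vecs.float() @ doc_vecs.float().T  # token-to-token dot products
    # For each query token, keep its best-matching document token, then sum over query tokens.
    return sim.max(dim=1).values.sum()
```

Looping this over every query/document pair would produce a 3x3 matrix matching the shape of `scores` above; `score_multi_vector` presumably handles batching and padding internally.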