From 8ffd5dd122618579557f8e437993c92f92ba3a47 Mon Sep 17 00:00:00 2001 From: llm Date: Sat, 13 Dec 2025 22:50:19 +0100 Subject: [PATCH] PyTorch experiments --- .../experiments_tomoro-colqwen3-embed-4b.py | 120 ++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 python/experiments_tomoro-colqwen3-embed-4b.py diff --git a/python/experiments_tomoro-colqwen3-embed-4b.py b/python/experiments_tomoro-colqwen3-embed-4b.py new file mode 100644 index 0000000..4626542 --- /dev/null +++ b/python/experiments_tomoro-colqwen3-embed-4b.py @@ -0,0 +1,120 @@ +import torch +from transformers import AutoModel, AutoProcessor +from PIL import Image, UnidentifiedImageError +import requests +from io import BytesIO +import time + +# Configuration +MODEL_ID = "TomoroAI/tomoro-colqwen3-embed-4b" +DTYPE = torch.bfloat16 +# DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +DEVICE = "cpu" +# DEVICE = "cuda" + +start_ts = time.perf_counter_ns() + +# Load Model & Processor +processor = AutoProcessor.from_pretrained( + MODEL_ID, + trust_remote_code=True, + max_num_visual_tokens=1280, +) +model = AutoModel.from_pretrained( + MODEL_ID, + dtype=DTYPE, + # attn_implementation="flash_attention_2", + attn_implementation="sdpa", + trust_remote_code=True, + device_map=DEVICE, +).eval() + +duration_ns = time.perf_counter_ns() - start_ts +print(f"Duration Load Model & Processor: {duration_ns:,} ns") +total_params = sum(p.numel() for p in model.parameters()) +print(f"total_params: {total_params:,}") + +# Sample Data +queries = [ + "Retrieve the city of Singapore", +# "Retrieve the city of Beijing", +# "Retrieve the city of London", +] +docs = [ + "https://upload.wikimedia.org/wikipedia/commons/2/27/Singapore_skyline_2022.jpg", +# "https://upload.wikimedia.org/wikipedia/commons/6/61/Beijing_skyline_at_night.JPG", +# "https://upload.wikimedia.org/wikipedia/commons/4/49/London_skyline.jpg", +] + +def load_image(url: str) -> Image.Image: + # Some CDNs (e.g., Wikimedia) expect a browser-like UA to avoid 403s. + for headers in ({}, {"User-Agent": "Mozilla/5.0 (compatible; ColQwen3-demo/1.0)"}): + resp = requests.get(url, headers=headers, timeout=10) + if resp.status_code == 403: + continue + resp.raise_for_status() + try: + return Image.open(BytesIO(resp.content)).convert("RGB") + except UnidentifiedImageError as e: + raise RuntimeError(f"Failed to decode image from {url}") from e + raise RuntimeError(f"Could not fetch image (HTTP 403) from {url}; try downloading locally and loading from file path.") + +# Helper Functions +def encode_queries(texts, batch_size=8): + outputs = [] + for start in range(0, len(texts), batch_size): + batch = processor.process_texts(texts=texts[start : start + batch_size]) + batch = {k: v.to(DEVICE) for k, v in batch.items()} + with torch.inference_mode(): + out = model(**batch) + vecs = out.embeddings.to(torch.bfloat16).cpu() + outputs.extend(vecs) + return outputs + +def encode_docs(urls, batch_size=4): + pil_images = [load_image(url) for url in urls] + outputs = [] + for start in range(0, len(pil_images), batch_size): + batch_imgs = pil_images[start : start + batch_size] + features = processor.process_images(images=batch_imgs) + features = {k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v for k, v in features.items()} + with torch.inference_mode(): + out = model(**features) + print(f"type(out.embeddings) = {type(out.embeddings)}") + print(f"out.embeddings.shape = {out.embeddings.shape}") + print(f"out.embeddings.ndim = {out.embeddings.ndim}") + print(f"out.embeddings.device = {out.embeddings.device}") + print(f"out.embeddings.numel() = {out.embeddings.numel()}") + print("out.embeddings.element_size() = " + f"{out.embeddings.element_size()}") + print("out.embeddings.numel() * out.embeddings.element_size() = " + f"{out.embeddings.numel() * out.embeddings.element_size()}") + vecs = out.embeddings.to(torch.bfloat16).cpu() + outputs.extend(vecs) + return outputs + +# Execution + +start_ts = time.perf_counter_ns() + +query_embeddings = encode_queries(queries) + +duration_ns = time.perf_counter_ns() - start_ts +print(f"Duration encode_queries: {duration_ns:,} ns") +start_ts = time.perf_counter_ns() + +doc_embeddings = encode_docs(docs) + +duration_ns = time.perf_counter_ns() - start_ts +print(f"Duration encode_docs: {duration_ns:,} ns") + +# MaxSim Scoring + +start_ts = time.perf_counter_ns() + +scores = processor.score_multi_vector(query_embeddings, doc_embeddings) + +duration_ns = time.perf_counter_ns() - start_ts +print(f"Duration score_multi_vector: {duration_ns:,} ns") + +print(scores)