Timing analysis for the separate processing steps, including the GPU -> CPU transfer
@@ -9,41 +9,66 @@ import time
 MODEL_ID = "TomoroAI/tomoro-colqwen3-embed-4b"
 DTYPE = torch.bfloat16
 # DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-DEVICE = "cpu"
-# DEVICE = "cuda"
+# DEVICE = "cpu"
+DEVICE = "cuda"
+print(f"DEVICE: {DEVICE}")
-start_ts = time.perf_counter_ns()
 
 # Load Model & Processor
+start_ts = time.perf_counter_ns()
 processor = AutoProcessor.from_pretrained(
     MODEL_ID,
     trust_remote_code=True,
     max_num_visual_tokens=1280,
 )
+duration_ns = time.perf_counter_ns() - start_ts
+print(f"Duration Load Processor: {duration_ns:,} ns")
 
+start_ts = time.perf_counter_ns()
 model = AutoModel.from_pretrained(
     MODEL_ID,
     dtype=DTYPE,
-    # attn_implementation="flash_attention_2",
-    attn_implementation="sdpa",
+    attn_implementation="flash_attention_2",
+    # attn_implementation="sdpa",
     trust_remote_code=True,
     device_map=DEVICE,
 ).eval()
 
 duration_ns = time.perf_counter_ns() - start_ts
-print(f"Duration Load Model & Processor: {duration_ns:,} ns")
+print(f"Duration Load Model: {duration_ns:,} ns")
 total_params = sum(p.numel() for p in model.parameters())
-print(f"total_params: {total_params:,}")
+print(f"Model total_params: {total_params:,}")
 
 # Sample Data
 queries = [
-    "Retrieve the city of Singapore",
-    # "Retrieve the city of Beijing",
-    # "Retrieve the city of London",
+    "Retrieve a city of Singapore picture",
+    "Retrieve a city of Beijing picture",
+    "Retrieve a city of London picture",
+    "Retrieve a city of Frankfurt am Main picture",
+    "Retrieve a city of Berlin picture",
+
+    # "Retrieve a city of Madrid picture",
+    # "Retrieve a city of Budapest picture",
+    # "Retrieve a city of Dresden picture",
+    # "Retrieve a city of New York picture",
+    # "Retrieve a city of Sydney picture",
+    # "Retrieve a city of Toronto picture",
+    # "Retrieve a city of Asunción picture",
 ]
 docs = [
     "https://upload.wikimedia.org/wikipedia/commons/2/27/Singapore_skyline_2022.jpg",
-    # "https://upload.wikimedia.org/wikipedia/commons/6/61/Beijing_skyline_at_night.JPG",
-    # "https://upload.wikimedia.org/wikipedia/commons/4/49/London_skyline.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/6/61/Beijing_skyline_at_night.JPG",
+    "https://upload.wikimedia.org/wikipedia/commons/4/49/London_skyline.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/d/d7/Skyline_Frankfurt_am_Main_2015.jpg",
+    "https://upload.wikimedia.org/wikipedia/commons/8/83/Cityscape_Berlin.jpg",
+
+    # Decoding errors:
+    # "https://commons.wikimedia.org/wiki/File:Sydney_skyline_at_dusk_-_Dec_2008.jpg",
+    # "https://commons.wikimedia.org/wiki/File:Toronto_-_ON_-_Toronto_Skyline8.jpg",
+    # "https://commons.wikimedia.org/wiki/File:Asunci%C3%B3n_Paraguay.jpg",
+    # "https://commons.wikimedia.org/wiki/File:Madrid_ciudad.jpg",
+    # "https://commons.wikimedia.org/wiki/File:Budapest,_Hungary_(explored)_(14995308504).jpg",
+    # "https://commons.wikimedia.org/wiki/File:DD-canaletto-blick.jpg",
+    # "https://commons.wikimedia.org/wiki/File:Long_Island_City_New_York_May_2015_panorama_3.jpg",
+
 ]
 
 def load_image(url: str) -> Image.Image:
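The load timings above repeat the same start_ts / duration_ns / print pattern for every step. A small context manager could factor that out; the sketch below is not part of the commit (the timed helper and the ms conversion are assumptions), it only mirrors the inline pattern used in the diff:

import time
from contextlib import contextmanager

@contextmanager
def timed(label: str):
    # Hypothetical helper, not in the commit: times the wrapped block with
    # time.perf_counter_ns() and prints the result like the inline pattern above.
    start_ts = time.perf_counter_ns()
    try:
        yield
    finally:
        duration_ns = time.perf_counter_ns() - start_ts
        print(f"Duration {label}: {duration_ns:,} ns ({duration_ns / 1e6:.2f} ms)")

# Example usage, mirroring the processor-load step:
# with timed("Load Processor"):
#     processor = AutoProcessor.from_pretrained(
#         MODEL_ID, trust_remote_code=True, max_num_visual_tokens=1280
#     )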
@@ -76,10 +101,30 @@ def encode_docs(urls, batch_size=4):
     outputs = []
     for start in range(0, len(pil_images), batch_size):
         batch_imgs = pil_images[start : start + batch_size]
+
+        start_ts = time.perf_counter_ns()
         features = processor.process_images(images=batch_imgs)
-        features = {k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v for k, v in features.items()}
+        features = {
+            k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v
+            for k, v in features.items()
+        }
+        duration_ns = time.perf_counter_ns() - start_ts
+        print(f"Duration process_images: {duration_ns:,} ns")
+
         with torch.inference_mode():
+
+            start_ts = time.perf_counter_ns()
             out = model(**features)
+            vecs = out.embeddings.to(torch.bfloat16)
+            duration_ns = time.perf_counter_ns() - start_ts
+            print(f"Duration vecs generation (no .cpu): {duration_ns:,} ns")
+
+            start_ts = time.perf_counter_ns()
+            vecs = vecs.cpu()
+            duration_ns = time.perf_counter_ns() - start_ts
+            print(f"Duration vecs.cpu(): {duration_ns:,} ns")
+
+            if False:
                 print(f"type(out.embeddings) = {type(out.embeddings)}")
                 print(f"out.embeddings.shape = {out.embeddings.shape}")
                 print(f"out.embeddings.ndim = {out.embeddings.ndim}")
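A caveat on the per-step GPU timings added in this hunk: CUDA work is launched asynchronously, so with DEVICE = "cuda" the "vecs generation (no .cpu)" timer largely measures kernel-launch overhead, while the vecs.cpu() timer absorbs the remaining compute plus the actual device-to-host copy. A minimal sketch of a synchronized split, reusing model, features, and DEVICE from the script; the torch.cuda.synchronize() calls are an addition assumed here, not part of the commit:

with torch.inference_mode():
    start_ts = time.perf_counter_ns()
    out = model(**features)
    vecs = out.embeddings.to(torch.bfloat16)
    if DEVICE == "cuda":
        torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock
    duration_ns = time.perf_counter_ns() - start_ts
    print(f"Duration vecs generation (synchronized): {duration_ns:,} ns")

    start_ts = time.perf_counter_ns()
    vecs = vecs.cpu()  # with the kernels already finished, this mostly times the GPU -> CPU transfer
    duration_ns = time.perf_counter_ns() - start_ts
    print(f"Duration vecs.cpu(): {duration_ns:,} ns")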
@@ -89,31 +134,26 @@ def encode_docs(urls, batch_size=4):
                       f"{out.embeddings.element_size()}")
                 print("out.embeddings.numel() * out.embeddings.element_size() = "
                       f"{out.embeddings.numel() * out.embeddings.element_size()}")
-            vecs = out.embeddings.to(torch.bfloat16).cpu()
             outputs.extend(vecs)
     return outputs
 
 # Execution
 
 start_ts = time.perf_counter_ns()
 
 query_embeddings = encode_queries(queries)
 
 duration_ns = time.perf_counter_ns() - start_ts
 print(f"Duration encode_queries: {duration_ns:,} ns")
 
 start_ts = time.perf_counter_ns()
 
 doc_embeddings = encode_docs(docs)
 
 duration_ns = time.perf_counter_ns() - start_ts
 print(f"Duration encode_docs: {duration_ns:,} ns")
 
 # MaxSim Scoring
 
 start_ts = time.perf_counter_ns()
 
 scores = processor.score_multi_vector(query_embeddings, doc_embeddings)
 
 duration_ns = time.perf_counter_ns() - start_ts
 print(f"Duration score_multi_vector: {duration_ns:,} ns")
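For context on the last timed step: score_multi_vector produces one score per (query, document) pair from the multi-vector embeddings; in ColBERT-style late interaction this is the MaxSim score, i.e. for each query token take the maximum dot product over all document token vectors, then sum over query tokens. A small reference sketch under that assumption (the maxsim_scores name is illustrative, not from the commit or the library):

import torch

def maxsim_scores(query_embeddings, doc_embeddings):
    # query_embeddings: list of (num_query_tokens, dim) tensors
    # doc_embeddings:   list of (num_doc_tokens, dim) tensors
    # Returns a (num_queries, num_docs) matrix of MaxSim scores.
    scores = torch.empty(len(query_embeddings), len(doc_embeddings))
    for qi, q in enumerate(query_embeddings):
        for di, d in enumerate(doc_embeddings):
            sim = q.float() @ d.float().T                  # token-level similarities
            scores[qi, di] = sim.max(dim=1).values.sum()   # best doc token per query token, summed
    return scores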