-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: bench_combined.py
More file actions
59 lines (50 loc) · 2.09 KB
/
bench_combined.py
File metadata and controls
59 lines (50 loc) · 2.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
from sentence_transformers import CrossEncoder
from bench_utils import benchmark, test_model
# Allowed padded sequence lengths: 16, 32, ..., 496 (ascending multiples of 16).
BUCKETS = [16 * step for step in range(1, 32)]
class DynamicCrossEncoder(CrossEncoder):
    """CrossEncoder whose collate step pads each batch up to a bucketed length.

    Padding to a small, fixed set of lengths (multiples of 16) limits how many
    distinct sequence lengths downstream code observes — useful when the model
    forward is wrapped in torch.compile with dynamic shapes.
    """

    def smart_batching_collate_text_only(self, batch):
        """Tokenize a batch of text tuples and pad every field to the nearest bucket.

        Args:
            batch: iterable of same-length text tuples (e.g. (query, document) pairs).

        Returns:
            Dict of tensors on the model's device, whose sequence dimension is
            rounded up to the smallest bucket >= the tokenizer's padded length.
        """
        # Transpose the batch into per-field columns and strip surrounding whitespace.
        stripped_fields = [[piece.strip() for piece in column] for column in zip(*batch)]
        features = self.tokenizer(
            *stripped_fields,
            padding=True,
            truncation="longest_first",
            return_tensors="pt",
            max_length=self.max_length,
        )
        device = self.model.device
        features = {name: tensor.to(device) for name, tensor in features.items()}

        # Round the current padded length up to the next bucket; lengths beyond
        # the largest bucket are left as-is.
        seq_len = features["input_ids"].size(1)
        target_len = seq_len
        for bucket in BUCKETS:
            if bucket >= seq_len:
                target_len = bucket
                break
        extra = target_len - seq_len
        if extra > 0:
            pad_id = self.tokenizer.pad_token_id
            for name in features:
                # input_ids get the tokenizer's pad token; masks/type ids get 0.
                fill = pad_id if name == "input_ids" else 0
                features[name] = torch.nn.functional.pad(features[name], (0, extra), value=fill)
        return features
# Baseline: stock CrossEncoder (flash attention left at the model's default).
model = CrossEncoder(
    "jinaai/jina-reranker-v2-base-multilingual",
    trust_remote_code=True,
    device="cuda",
    max_length=512
)

# Compiled variant with bucketed padding. max_length=512 matches the baseline
# so both models truncate identically and the timings are comparable.
# NOTE(review): flash attention is disabled here — presumably so the forward
# pass traces cleanly under torch.compile/inductor; confirm against the model
# card for jina-reranker-v2.
model_compile = DynamicCrossEncoder(
    "jinaai/jina-reranker-v2-base-multilingual",
    trust_remote_code=True,
    device="cuda",
    max_length=512,
    config_args={"use_flash_attn": False}
)

# Compile only the forward pass; dynamic=True lets the varying (bucketed)
# sequence lengths reuse compiled graphs instead of recompiling per length.
model_compile.model.forward = torch.compile(
    model_compile.model.forward,
    backend="inductor",
    mode="max-autotune",
    dynamic=True
)

# Same seed so both models are timed on identical query/document batches.
benchmark(model, print_scores=True, on_sorted_inputs=True, seed=100)
benchmark(model_compile, print_scores=True, on_sorted_inputs=True, seed=100)
# Sanity-check that both variants still produce valid scores.
test_model(model)
test_model(model_compile)

# Recorded results from a previous run:
# Base (with flash attn) + Sorted Inputs - Mean time: 0.2658 ± 0.0119 seconds
# Base (with flash attn) + Unsorted Inputs - Mean time: 0.2961 ± 0.0089 seconds
# torch.compile (without flash attn) + Sorted Inputs - Mean time: 0.2089 ± 0.0196 seconds
# torch.compile (without flash attn) + Unsorted Inputs - Mean time: 0.2595 ± 0.0077 seconds