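"""Inspect the batch kwargs produced by E1BatchPreparer.

Builds a small example batch (one single-sequence input and one
comma-separated multi-sequence input), then prints shape, dtype, and
padding statistics for every returned tensor, plus a per-sample breakdown
of sequence ids, label masking, and position ids.
"""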
import torch
from e1_fastplms.modeling_e1 import E1BatchPreparer


def analyze_batch_kwargs(batch_kwargs: dict, preparer: E1BatchPreparer, sequences: list[str]) -> None:
    print("==== Batch kwargs analysis ====")
    input_ids = batch_kwargs["input_ids"]
    within_seq_position_ids = batch_kwargs["within_seq_position_ids"]
    global_position_ids = batch_kwargs["global_position_ids"]
    sequence_ids = batch_kwargs["sequence_ids"]
    labels = batch_kwargs["labels"]
    context = batch_kwargs["context"]
    context_len = batch_kwargs["context_len"]
    pad_token_id = preparer.pad_token_id

    def _shortened_list(values: list[int], max_items: int = 8) -> str:
        if len(values) <= max_items:
            return str(values)
        return str(values[:max_items] + [f"... (+{len(values) - max_items} more)"])
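    # e.g. _shortened_list(list(range(12))) -> "[0, 1, 2, 3, 4, 5, 6, 7, '... (+4 more)']"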

    assert (
        input_ids.shape
        == within_seq_position_ids.shape
        == global_position_ids.shape
        == sequence_ids.shape
        == labels.shape
    )
    batch_size, max_len = input_ids.shape
    assert len(context) == batch_size == len(context_len) == len(sequences)
    print(f"batch_size: {batch_size}")
    print(f"max_length: {max_len}")
    print(f"pad_token_id: {pad_token_id}")
    print(f"kwargs keys: {list(batch_kwargs.keys())}")

    for name, tensor in (
        ("input_ids", input_ids),
        ("within_seq_position_ids", within_seq_position_ids),
        ("global_position_ids", global_position_ids),
        ("sequence_ids", sequence_ids),
        ("labels", labels),
    ):
        assert isinstance(tensor, torch.Tensor)
        # Entries equal to -1 are counted as padding/fill here.
        non_pad = (tensor != -1).sum().item()
        if tensor.numel() > 0 and tensor.dtype.is_floating_point:
            value_stats = f"min={tensor.min().item():.4f}, max={tensor.max().item():.4f}"
        else:
            value_stats = f"min={tensor.min().item()}, max={tensor.max().item()}"
        print()
        print(f"{name}:")
        print(f" shape={tuple(tensor.shape)} dtype={tensor.dtype} device={tensor.device}")
        first_index = tuple([0] * tensor.ndim)
        print(f" first_element={tensor[first_index].item()}")
        first_row = tensor[0, : min(8, tensor.shape[1])].tolist()
        print(f" first_row_prefix={_shortened_list([int(x) for x in first_row], max_items=8)}")
        print(f" non_padding_count={non_pad} / total={tensor.numel()} ({non_pad / tensor.numel() * 100:.2f}%)")
        print(f" {value_stats}")

    print()
    print("context tokens (metadata):")
    print(f" first_context: '{str(context[0])[:50]}'")
    print(f" first_context_len: {context_len[0]}")
    print(f" first_sequence: '{sequences[0]}'")

    for i, (raw_sequence, decoded_context, ctx_len, raw_ids) in enumerate(
        zip(sequences, context, context_len, sequence_ids)
    ):
        valid_len = int((raw_ids != -1).sum().item())
        ctx_len = int(ctx_len)
        print(f" sample[{i}] raw sequence input: {raw_sequence}")
        print(f" valid_length={valid_len}, context_len={ctx_len}, context='{decoded_context}'")
        row_input_ids = input_ids[i, :valid_len]
        row_sequence_ids = raw_ids[:valid_len]
        row_within = within_seq_position_ids[i, :valid_len]
        row_global = global_position_ids[i, :valid_len]
        row_labels = labels[i, :valid_len]
        print(f" decoded_input_ids: {preparer.tokenizer.decode(row_input_ids.tolist(), skip_special_tokens=False)}")
        print(f" input_id_pads: {int((row_input_ids == pad_token_id).sum().item())}")
        print(f" sequence_id_tail: {row_sequence_ids[-5:].tolist()}")
        # The non-(-1) sequence ids must form one contiguous block, i.e. no
        # padding gaps inside the valid region.
        non_pad_positions = torch.where(row_sequence_ids != -1)[0]
        assert torch.equal(
            row_sequence_ids[non_pad_positions[0] : non_pad_positions[-1] + 1],
            row_sequence_ids[row_sequence_ids != -1],
        )
        unique_sequence_ids = torch.unique(row_sequence_ids[row_sequence_ids != -1]).tolist()
        print(f" unique sequence_ids: {unique_sequence_ids}")
        # A change in sequence id marks the start of a new subsequence.
        seq_boundaries = torch.where(row_sequence_ids[1:] != row_sequence_ids[:-1])[0] + 1
        seq_breaks = seq_boundaries.tolist() + [valid_len]
        seq_lens = []
        start = 0
        for end in seq_breaks:
            seq_lens.append(end - start)
            start = end
        print(f" per-subsequence token counts (from concatenated encoding): {seq_lens}")
        # Labels at context positions should be masked out with pad_token_id;
        # positions after the context carry the actual targets.
        context_mask = torch.arange(valid_len) < ctx_len
        context_masked = int((row_labels[context_mask] == pad_token_id).sum().item())
        target_mask = torch.arange(valid_len) >= ctx_len
        target_tokens = int((row_labels[target_mask] != pad_token_id).sum().item())
        print(f" context tokens masked in labels: {context_masked} / {ctx_len}")
        print(f" non-context target tokens kept: {target_tokens}")
        # Position-id behavior check
        print(f" within_seq_position_ids unique: {torch.unique(row_within).tolist()}")
        print(f" global_position_ids max: {int(row_global.max().item())}, min: {int(row_global.min().item())}")
        print()


def main() -> None:
    # Example batch with single-seq and multi-seq inputs.
    sequences = [
        "ACDEFGHIKLMNPQRSTVWY",
        "MKTFFLILV,LKQMN",
    ]
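    # NOTE: the ',' in the second entry is assumed here to act as
    # E1BatchPreparer's subsequence separator, so analyze_batch_kwargs should
    # report two per-subsequence token blocks for that sample.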
    preparer = E1BatchPreparer()
    batch_kwargs = preparer.get_batch_kwargs(sequences, device=torch.device("cpu"))
    analyze_batch_kwargs(batch_kwargs, preparer, sequences)


if __name__ == "__main__":
    main()
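
# Usage sketch (assumes the e1_fastplms package is installed and provides the
# E1BatchPreparer API used above):
#   python exp.py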