Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 28 additions & 10 deletions doctr/models/recognition/vitstr/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,16 @@ def forward(
return_model_output: bool = False,
return_preds: bool = False,
) -> dict[str, Any]:
features = self.feat_extractor(x)["features"] # (batch_size, patches_seqlen, d_model)
features = self.feat_extractor(x)[
"features"
] # (batch_size, patches_seqlen, d_model)

if target is not None:
_gt, _seq_len = self.build_target(target)
gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long), torch.tensor(_seq_len)
gt, seq_len = (
torch.from_numpy(_gt).to(dtype=torch.long),
torch.tensor(_seq_len),
)
gt, seq_len = gt.to(x.device), seq_len.to(x.device)

if self.training and target is None:
Expand All @@ -104,7 +109,9 @@ def forward(
features = features[:, : self.max_length] # (batch_size, max_length, d_model)
B, N, E = features.size()
features = features.reshape(B * N, E)
logits = self.head(features).view(B, N, len(self.vocab) + 1) # (batch_size, max_length, vocab + 1)
logits = self.head(features).view(
B, N, len(self.vocab) + 1
) # (batch_size, max_length, vocab + 1)
decoded_features = _bf16_to_float32(logits[:, 1:]) # remove cls_token

out: dict[str, Any] = {}
Expand Down Expand Up @@ -149,15 +156,23 @@ def compute_loss(
# Input length : number of steps
input_len = model_output.shape[1]
# Add one for additional <eos> token (sos disappear in shift!)
seq_len = seq_len + 1
seq_len_ = seq_len + 1 # Do not mutate input tensor, use a temporary variable

# Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
# The "masked" first gt char is <sos>.
cce = F.cross_entropy(model_output.permute(0, 2, 1), gt[:, 1:], reduction="none")
cce = F.cross_entropy(
model_output.permute(0, 2, 1), gt[:, 1:], reduction="none"
)
# Compute mask
mask_2d = torch.arange(input_len, device=model_output.device)[None, :] >= seq_len[:, None]
cce[mask_2d] = 0
# The mask is True for positions to be masked (steps >= length)
# Scope performance: avoid repeated shape allocations by computing contiguous mask directly
row_range = torch.arange(input_len, device=model_output.device)
mask_2d = row_range.unsqueeze(0) >= seq_len_.unsqueeze(1)

# In-place zeroing masked entries reduces memory use vs. advanced assignment
cce.masked_fill_(mask_2d, 0)

ce_loss = cce.sum(1) / seq_len.to(dtype=model_output.dtype)
ce_loss = cce.sum(dim=1) / seq_len_.to(dtype=model_output.dtype)
return ce_loss.mean()


Expand All @@ -183,7 +198,8 @@ def __call__(
]
# compute probabilties for each word up to the EOS token
probs = [
preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0 for i, word in enumerate(word_values)
preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0
for i, word in enumerate(word_values)
]

return list(zip(word_values, probs))
Expand Down Expand Up @@ -222,7 +238,9 @@ def _vitstr(
if pretrained:
# The number of classes is not the same as the number of classes in the pretrained model =>
# remove the last layer weights
_ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
_ignore_keys = (
ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
)
model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

return model
Expand Down