diff --git a/doctr/models/recognition/vitstr/pytorch.py b/doctr/models/recognition/vitstr/pytorch.py index 86ee9f6f1a..ba5bbf2568 100644 --- a/doctr/models/recognition/vitstr/pytorch.py +++ b/doctr/models/recognition/vitstr/pytorch.py @@ -90,11 +90,16 @@ def forward( return_model_output: bool = False, return_preds: bool = False, ) -> dict[str, Any]: - features = self.feat_extractor(x)["features"] # (batch_size, patches_seqlen, d_model) + features = self.feat_extractor(x)[ + "features" + ] # (batch_size, patches_seqlen, d_model) if target is not None: _gt, _seq_len = self.build_target(target) - gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long), torch.tensor(_seq_len) + gt, seq_len = ( + torch.from_numpy(_gt).to(dtype=torch.long), + torch.tensor(_seq_len), + ) gt, seq_len = gt.to(x.device), seq_len.to(x.device) if self.training and target is None: @@ -104,7 +109,9 @@ def forward( features = features[:, : self.max_length] # (batch_size, max_length, d_model) B, N, E = features.size() features = features.reshape(B * N, E) - logits = self.head(features).view(B, N, len(self.vocab) + 1) # (batch_size, max_length, vocab + 1) + logits = self.head(features).view( + B, N, len(self.vocab) + 1 + ) # (batch_size, max_length, vocab + 1) decoded_features = _bf16_to_float32(logits[:, 1:]) # remove cls_token out: dict[str, Any] = {} @@ -149,15 +156,23 @@ def compute_loss( # Input length : number of steps input_len = model_output.shape[1] # Add one for additional token (sos disappear in shift!) - seq_len = seq_len + 1 + seq_len_ = seq_len + 1 # Do not mutate input tensor, use a temporary variable + # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]! # The "masked" first gt char is . - cce = F.cross_entropy(model_output.permute(0, 2, 1), gt[:, 1:], reduction="none") + cce = F.cross_entropy( + model_output.permute(0, 2, 1), gt[:, 1:], reduction="none" + ) # Compute mask - mask_2d = torch.arange(input_len, device=model_output.device)[None, :] >= seq_len[:, None] - cce[mask_2d] = 0 + # The mask is True for positions to be masked (steps >= length) + # Scope performance: avoid repeated shape allocations by computing contiguous mask directly + row_range = torch.arange(input_len, device=model_output.device) + mask_2d = row_range.unsqueeze(0) >= seq_len_.unsqueeze(1) + + # In-place zeroing masked entries reduces memory use vs. advanced assignment + cce.masked_fill_(mask_2d, 0) - ce_loss = cce.sum(1) / seq_len.to(dtype=model_output.dtype) + ce_loss = cce.sum(dim=1) / seq_len_.to(dtype=model_output.dtype) return ce_loss.mean() @@ -183,7 +198,8 @@ def __call__( ] # compute probabilties for each word up to the EOS token probs = [ - preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0 for i, word in enumerate(word_values) + preds_prob[i, : len(word)].clip(0, 1).mean().item() if word else 0.0 + for i, word in enumerate(word_values) ] return list(zip(word_values, probs)) @@ -222,7 +238,9 @@ def _vitstr( if pretrained: # The number of classes is not the same as the number of classes in the pretrained model => # remove the last layer weights - _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None + _ignore_keys = ( + ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None + ) model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys) return model