diff --git a/doctr/models/classification/vit/pytorch.py b/doctr/models/classification/vit/pytorch.py index 24f4bbde58..e779ee5368 100644 --- a/doctr/models/classification/vit/pytorch.py +++ b/doctr/models/classification/vit/pytorch.py @@ -55,7 +55,8 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: # (batch_size, num_classes) cls token - return self.head(x[:, 0]) + x0 = x.select(1, 0) + return self.head(x0) class VisionTransformer(nn.Sequential): @@ -90,7 +91,9 @@ def __init__( ) -> None: _layers: list[nn.Module] = [ PatchEmbedding(input_shape, d_model, patch_size), - EncoderBlock(num_layers, num_heads, d_model, d_model * ffd_ratio, dropout, nn.GELU()), + EncoderBlock( + num_layers, num_heads, d_model, d_model * ffd_ratio, dropout, nn.GELU() + ), ] if include_top: _layers.append(ClassifierHead(d_model, num_classes)) @@ -114,7 +117,9 @@ def _vit( ignore_keys: list[str] | None = None, **kwargs: Any, ) -> VisionTransformer: - kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"])) + kwargs["num_classes"] = kwargs.get( + "num_classes", len(default_cfgs[arch]["classes"]) + ) kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"]) kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"]) @@ -130,7 +135,11 @@ def _vit( if pretrained: # The number of classes is not the same as the number of classes in the pretrained model => # remove the last layer weights - _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None + _ignore_keys = ( + ignore_keys + if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) + else None + ) model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys) return model