From 7e4ae94b93b4a2407b950d98ce79cc79a3f67fde Mon Sep 17 00:00:00 2001
From: Anonymous
Date: Mon, 28 Jul 2025 23:13:19 -0400
Subject: [PATCH 1/2] Bug fixes in vqvae module: add bias to the LayerNorm in
 VanillaMultiHeadAttention.__init__, fix the codebook lookup in the
 quantizer, and make affine, affine_mask, and attention_mask required
 arguments

---
 src/vqvae/attention.py         |  2 +-
 src/vqvae/quantizer_module.py  |  2 +-
 src/vqvae/transformer_stack.py | 10 +++++-----
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/vqvae/attention.py b/src/vqvae/attention.py
index 4632267..f7f1a8c 100644
--- a/src/vqvae/attention.py
+++ b/src/vqvae/attention.py
@@ -23,7 +23,7 @@ def __init__(
         self.d_head = self.d_model // self.n_heads
 
         self.layernorm_qkv = nn.Sequential(
-            nn.LayerNorm(d_model), nn.Linear(d_model, d_model * 3, bias=bias)
+            nn.LayerNorm(d_model, bias=bias), nn.Linear(d_model, d_model * 3, bias=bias)
         )
         self.out_proj = nn.Linear(d_model, d_model, bias=bias)
 
diff --git a/src/vqvae/quantizer_module.py b/src/vqvae/quantizer_module.py
index cb5fa8c..c25629c 100644
--- a/src/vqvae/quantizer_module.py
+++ b/src/vqvae/quantizer_module.py
@@ -56,7 +56,7 @@ def get_codebook(self,):
         return self.codebook.weight
 
     def indices2embedding(self, indices: torch.IntTensor) -> torch.Tensor:
-        z_q = self.codebook[indices]
+        z_q = self.codebook(indices)
         return z_q
 
     def forward(self, z: torch.Tensor) -> (torch.Tensor, torch.IntTensor, float):
diff --git a/src/vqvae/transformer_stack.py b/src/vqvae/transformer_stack.py
index 41b058f..d2e56b3 100644
--- a/src/vqvae/transformer_stack.py
+++ b/src/vqvae/transformer_stack.py
@@ -63,10 +63,10 @@ def __init__(
     def forward(
         self,
         x: torch.Tensor,
-        attention_mask: torch.Tensor | None = None,
+        affine: Affine3D,
+        affine_mask: torch.Tensor,
+        attention_mask: torch.Tensor,
         sequence_id: torch.Tensor | None = None,
-        affine: Affine3D | None = None,
-        affine_mask: torch.Tensor | None = None,
         chain_id: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
@@ -75,8 +75,8 @@ def forward(
         Args:
             x (torch.Tensor): The input tensor of shape (batch_size, sequence_length, d_model).
             sequence_id (torch.Tensor): The sequence ID tensor of shape (batch_size, sequence_length).
-            affine (Affine3D | None): The affine transformation tensor or None.
-            affine_mask (torch.Tensor | None): The affine mask tensor or None.
+            affine (Affine3D): The affine transformation tensor or None.
+            affine_mask (torch.Tensor): The affine mask tensor or None.
             chain_id (torch.Tensor): The protein chain tensor of shape (batch_size, sequence_length).
                 Only used in geometric attention.
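
Note: a minimal sketch of the two behavioral fixes above, for review. It
assumes self.codebook is an nn.Embedding (which get_codebook returning
self.codebook.weight suggests) and PyTorch >= 2.1, where nn.LayerNorm
accepts a bias keyword:

    import torch
    import torch.nn as nn

    # LayerNorm fix: with bias=False the additive parameter is now removed
    # from the normalization as well as from the linear projection.
    d_model, bias = 64, False
    layernorm_qkv = nn.Sequential(
        nn.LayerNorm(d_model, bias=bias),
        nn.Linear(d_model, d_model * 3, bias=bias),
    )
    assert layernorm_qkv[0].bias is None and layernorm_qkv[1].bias is None

    # Quantizer fix: an nn.Embedding is not subscriptable, so
    # codebook[indices] raises TypeError; calling the module does the lookup.
    codebook = nn.Embedding(num_embeddings=512, embedding_dim=d_model)
    indices = torch.tensor([[3, 41, 7]])
    z_q = codebook(indices)  # shape (1, 3, 64)
    assert torch.equal(z_q, codebook.weight[indices])  # same values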
From 1db1c40f10398c2c72c02d8b0eeb8deb2817e7f1 Mon Sep 17 00:00:00 2001
From: Jonathan Coletti <88168630+JonathanColetti@users.noreply.github.com>
Date: Wed, 30 Jul 2025 00:18:18 -0400
Subject: [PATCH 2/2] Update comments to reflect required params

---
 src/vqvae/transformer_stack.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/vqvae/transformer_stack.py b/src/vqvae/transformer_stack.py
index d2e56b3..8c957cb 100644
--- a/src/vqvae/transformer_stack.py
+++ b/src/vqvae/transformer_stack.py
@@ -75,8 +75,8 @@ def forward(
         Args:
             x (torch.Tensor): The input tensor of shape (batch_size, sequence_length, d_model).
             sequence_id (torch.Tensor): The sequence ID tensor of shape (batch_size, sequence_length).
-            affine (Affine3D): The affine transformation tensor or None.
-            affine_mask (torch.Tensor): The affine mask tensor or None.
+            affine (Affine3D): The affine transformation tensor.
+            affine_mask (torch.Tensor): The affine mask tensor.
             chain_id (torch.Tensor): The protein chain tensor of shape (batch_size, sequence_length).
                 Only used in geometric attention.
 
@@ -89,4 +89,4 @@ def forward(
             chain_id = torch.ones(size=batch_dims, dtype=torch.int64, device=x.device)
         for block in self.blocks:
            x = block(x, attention_mask, sequence_id, affine, affine_mask, chain_id)
-        return self.norm(x), x
\ No newline at end of file
+        return self.norm(x), x
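
Note: the call-site effect of making affine, affine_mask, and
attention_mask required in PATCH 1/2, sketched with a stand-in that
mirrors only the new parameter order (the real forward body is not shown
in these patches): forgetting the geometry inputs now fails loudly
instead of silently running with affine=None.

    import torch

    def forward(x, affine, affine_mask, attention_mask,
                sequence_id=None, chain_id=None):
        return x  # stand-in body; the real method runs the block stack

    x = torch.randn(2, 16, 512)
    try:
        forward(x)  # old-style call that relied on the None defaults
    except TypeError as err:
        print(err)  # missing 3 required positional arguments: 'affine',
                    # 'affine_mask', and 'attention_mask'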