Hi!
I wasn't able to replicate AminoAseed's performance on reconstruction or on the measured benchmarks. Is there a minimal example showing tokenization/detokenization with reasonable RMSDs?
As a starting point, I pulled out the `WrappedOurPretrainedTokenizer` subclass and pasted it below; the main dependency is `VQVAEModel`. The state dict loads correctly, but I'm getting reconstruction RMSDs around 14 Å, which seems almost certainly wrong. I'm also unable to fit anything with reasonable AUROC using the tokens from the loaded tokenizer (unsurprising, since the reconstructions are essentially random). Any suggestions would be appreciated! (See the added `WrappedOurPretrainedTokenizer.encode` function.)
```python
import torch
import os
import time
import numpy as np
import joblib
import pickle
import urllib.request
import functools
from pathlib import Path
import Bio
import Bio.PDB

from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
from esm.utils.constants import esm3 as C

from vqvae_model import VQVAEModel


class WrappedOurPretrainedTokenizer():
    def __init__(self, device: torch.device | str = "cpu", model_cfg=None, pretrained_ckpt_path=None, ckpt_name=None):
        self.device = device
        # Make config match AminoAseed
        model_cfg.quantizer.use_linear_project = True
        model_cfg.quantizer.freeze_codebook = True
        model_cfg.encoder.d_out = 1024
        model_cfg.decoder.encoder_d_out = 1024

        self.model = VQVAEModel(model_cfg=model_cfg)
        model_states = torch.load(pretrained_ckpt_path, weights_only=False, map_location=self.device)["module"]
        new_model_states = {}
        for k, v in model_states.items():
            assert k.startswith("model.")
            new_model_states[k[6:]] = v
        self.model.load_state_dict(new_model_states)
        for p in self.model.parameters():
            p.requires_grad = False
        self.model = self.model.to(self.device)

        self.seq_tokenizer = EsmSequenceTokenizer()
        self.ckpt_name = ckpt_name
        # reference: https://github.com/evolutionaryscale/esm/blob/39a3a6cb1e722347947dc375e3f8e2ba80ed8b59/esm/utils/constants/esm3.py#L18C12-L18C35
        self.pad_token_id = self.model.quantizer.codebook.weight.shape[0] + 3

    def get_num_tokens(self):
        return self.model.quantizer.codebook.weight.shape[0] + 5

    def get_codebook_embedding(self,):
        return self.model.quantizer.codebook.weight

    @torch.no_grad()
    def encode(self, coords_BLAD: torch.Tensor, sequence: str, decode: bool = False):
        """
        coords_BLAD : [1, L, 37, 3] (same shape as from to_structure_encoder_inputs)
        sequence    : str of length L
        Returns:
            quantized_z : [1, L, D]
            decoded     : dict or None
        """
        coords = coords_BLAD.to(self.device)  # [1, L, 37, 3]
        B, L, A, D = coords.shape
        assert B == 1

        # originally: attention_mask = coords[:, :, 0, 0] == torch.inf; then ~ inside forward
        padding_mask = coords[:, :, 0, 0] == torch.inf  # True where padded
        attention_mask = padding_mask  # pass this; VQVAEModel will invert it
        residue_index = torch.arange(L, device=self.device).unsqueeze(0)  # [1, L]

        sequence = sequence.replace(C.MASK_STR_SHORT, "<mask>")
        seq_ids = self.seq_tokenizer.encode(sequence, add_special_tokens=False)
        seq_ids = torch.as_tensor(seq_ids, dtype=torch.int64, device=self.device)  # [L]
        assert seq_ids.shape[0] == L
        seq_residue_tokens = seq_ids  # [L]

        input_list = (coords, attention_mask, residue_index, seq_residue_tokens, None)
        quantized_z, quantized_indices, z = self.model(input_list, use_as_tokenizer=True)

        decoded = None
        if decode:
            decoded = self.model.decoder.decode(
                quantized_z,
                quantized_indices,
                ~attention_mask.bool(),  # forward flips mask; decoder expects same convention as in training
                None,
            )
        return quantized_z, quantized_indices, decoded
```
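For reference, here is roughly how I'm driving the wrapper and measuring backbone RMSD. This is only a sketch: the PDB and checkpoint paths are placeholders, `model_cfg` comes from the released config (elided here), I'm assuming `ProteinChain.from_pdb` / `to_structure_encoder_inputs` from the esm package produce the `[1, L, 37, 3]` coords the `encode` docstring expects, and the `"bb_coords"` key I read the reconstructed backbone from in `decoded` is a guess on my part.

```python
import numpy as np
from esm.utils.structure.protein_chain import ProteinChain  # assuming this import path


def kabsch_rmsd(P: np.ndarray, Q: np.ndarray) -> float:
    """RMSD between two [N, 3] coordinate sets after optimal superposition (Kabsch)."""
    P = P - P.mean(axis=0)
    Q = Q - Q.mean(axis=0)
    U, _, Vt = np.linalg.svd(P.T @ Q)
    d = np.sign(np.linalg.det(Vt.T @ U.T))
    R = Vt.T @ np.diag([1.0, 1.0, d]) @ U.T
    return float(np.sqrt((((P @ R.T) - Q) ** 2).sum(axis=1).mean()))


chain = ProteinChain.from_pdb("example.pdb")  # placeholder structure
coords, plddt, residue_index = chain.to_structure_encoder_inputs()  # coords: [1, L, 37, 3]

tokenizer = WrappedOurPretrainedTokenizer(
    device="cpu",
    model_cfg=model_cfg,                      # loaded from the released config (elided)
    pretrained_ckpt_path="path/to/ckpt.pt",   # placeholder path
)
quantized_z, quantized_indices, decoded = tokenizer.encode(coords, chain.sequence, decode=True)

# "bb_coords" is a guess at where the reconstructed backbone lives in `decoded`;
# please correct me if the decoder returns it under a different key / shape.
pred_bb = decoded["bb_coords"]                        # assumed [1, L, 3, 3] for N, CA, C
pred_ca = pred_bb[0, :, 1, :].detach().cpu().numpy()  # CA is atom index 1
true_ca = coords[0, :, 1, :].cpu().numpy()            # CA is also index 1 in the 37-atom layout
print("CA RMSD (Å):", kabsch_rmsd(pred_ca, true_ca))
```

If this is not how the decoder output is meant to be consumed, that could well explain the ~14 Å numbers I'm seeing.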