Skip to content

Minimal script to tokenize/reconstruct? #9

@rdilip

Description

@rdilip

Hi!

I wasn't able to replicate AminoAseed's performance on reconstruction or on the measured benchmarks. Is there a minimal example showing tokenization/detokenization with reasonable RMSDs?

As a starting point, I pulled out the WrappedOurPretrainedTokenizer subclass and pasted it below — the main dependency is the VQVAEModel. I load the state dict correctly, but I'm getting reconstructions around 14 Å, which seems almost certainly wrong. I'm also unable to fit anything with reasonable AUROC using the tokens from the loaded tokenizer (which is unsurprising, since the reconstructions are essentially random). Any suggestions would be appreciated! (See the added WrappedOurPretrainedTokenizer.encode function.)

import torch
import os
import time
import numpy as np
import joblib
import pickle
import urllib.request
import functools
from pathlib import Path
import Bio
import Bio.PDB
import torch
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer
from esm.utils.constants import esm3 as C


from vqvae_model import VQVAEModel

class WrappedOurPretrainedTokenizer():

    def __init__(self, device: torch.device | str = "cpu", model_cfg=None, pretrained_ckpt_path=None, ckpt_name=None):
        self.device = device

        # Make config match AminoAseed
        model_cfg.quantizer.use_linear_project = True
        model_cfg.quantizer.freeze_codebook = True        model_cfg.encoder.d_out = 1024
        model_cfg.decoder.encoder_d_out = 1024

        self.model = VQVAEModel(model_cfg=model_cfg)
        model_states = torch.load(pretrained_ckpt_path, weights_only=False, map_location=self.device)["module"]
        new_model_states = {}
        for k,v in model_states.items():
            assert k.startswith("model.")
            new_model_states[k[6:]] = v
        self.model.load_state_dict(new_model_states)
        for p in self.model.parameters():
            p.requires_grad = False
        self.model = self.model.to(self.device)
        
        self.seq_tokenizer = EsmSequenceTokenizer()

        self.ckpt_name = ckpt_name

        # reference: https://github.com/evolutionaryscale/esm/blob/39a3a6cb1e722347947dc375e3f8e2ba80ed8b59/esm/utils/constants/esm3.py#L18C12-L18C35
        self.pad_token_id = self.model.quantizer.codebook.weight.shape[0] + 3

    def get_num_tokens(self):
        return self.model.quantizer.codebook.weight.shape[0] + 5
    
    def get_codebook_embedding(self,):
        return self.model.quantizer.codebook.weight

    @torch.no_grad()
    def encode(self, coords_BLAD: torch.Tensor, sequence: str, decode: bool = False):
        """
        coords_BLAD : [1, L, 37, 3]   (same shape as from to_structure_encoder_inputs)
        sequence    : str of length L
        Returns:
            quantized_z : [1, L, D]
            decoded     : dict or None
        """
        coords = coords_BLAD.to(self.device)          # [1, L, 37, 3]
        B, L, A, D = coords.shape
        assert B == 1

        # originally: attention_mask = coords[:, :, 0, 0] == torch.inf; then ~ inside forward
        padding_mask = coords[:, :, 0, 0] == torch.inf  # True where padded
        attention_mask = padding_mask                   # pass this; VQVAEModel will invert it
        residue_index = torch.arange(L, device=self.device).unsqueeze(0)  # [1, L]
        sequence = sequence.replace(C.MASK_STR_SHORT, "<mask>")
        seq_ids = self.seq_tokenizer.encode(sequence, add_special_tokens=False)
        seq_ids = torch.as_tensor(seq_ids, dtype=torch.int64, device=self.device)  # [L]
        assert seq_ids.shape[0] == L


        seq_residue_tokens = seq_ids  # [L]

        input_list = (coords, attention_mask, residue_index, seq_residue_tokens, None)
        quantized_z, quantized_indices, z = self.model(input_list, use_as_tokenizer=True)


        decoded = None
        if decode:
            decoded = self.model.decoder.decode(
                quantized_z,
                quantized_indices,
                ~attention_mask.bool(),  # forward flips mask; decoder expects same convention as in training
                None,
            )

        return quantized_z, quantized_indices, decoded
    

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions