Skip to content

Fix ESMC 300M/600M local loading: torch.load the .pth (repos ship no safetensors)#332

Open
sakzgupzzz wants to merge 1 commit into
Biohub:mainfrom
sakzgupzzz:fix-esmc-300m-600m-pth-loading
Open

Fix ESMC 300M/600M local loading: torch.load the .pth (repos ship no safetensors)#332
sakzgupzzz wants to merge 1 commit into
Biohub:mainfrom
sakzgupzzz:fix-esmc-300m-600m-pth-loading

Conversation

@sakzgupzzz
Copy link
Copy Markdown

ESMC.from_pretrained("esmc_300m") and ("esmc_600m") currently fail with:

ValueError: Directory ".../snapshots/..." does not contain a valid checkpoint.
Expected either a sharded checkpoint with an index file, or a single model file.

Cause

ESMC_300M_202412 / ESMC_600M_202412 in esm/pretrained.py load weights via
load_torch_model(model, data_root("esmc-300"|"esmc-600")). huggingface_hub.load_torch_model
defaults to safe=True, so given a directory it looks for a safetensors checkpoint.
But the biohub/esmc-300m-2024-12 and biohub/esmc-600m-2024-12 repos ship a single
legacy checkpoint at data/weights/esmc_*_2024_12_v0.pth and no safetensors — so the
lookup fails.

The 6B repo (biohub/esmc-6b-2024-12) ships sharded safetensors + an index, so its
loader works; this PR leaves it unchanged.

Fix

Load the .pth explicitly with torch.load + load_state_dict(assign=True), exactly
as the ESM3_* loaders in the same file already do.

Testing

Verified esmc_600m loads and runs end-to-end on a GPU (downloads the .pth, builds
the model with d_model=1152, n_heads=18, n_layers=36, embeds 65 protein sequences via
logits(..., return_embeddings=True)). The esmc_300m change is the identical pattern
against the verified data/weights/esmc_300m_2024_12_v0.pth filename.

The 300M and 600M HF repos (biohub/esmc-300m-2024-12, biohub/esmc-600m-2024-12)
ship a single legacy checkpoint at data/weights/esmc_*_2024_12_v0.pth and no
safetensors. ESMC_300M_202412 / ESMC_600M_202412 loaded them via
load_torch_model(model, dir), which defaults to safe=True and looks for a
safetensors checkpoint in the directory, so loading fails with:

    ValueError: Directory '.../snapshots/...' does not contain a valid
    checkpoint. Expected either a sharded checkpoint with an index file,
    or a single model file.

Load the .pth explicitly with torch.load + load_state_dict(assign=True), exactly
as the ESM3 loaders in this file already do. The 6B loader is unchanged because
its repo ships sharded safetensors, for which load_torch_model works.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant