FastPLMs

FastPLMs Hero Image

FastPLMs is an open-source initiative dedicated to accelerating pretrained protein language models (pLMs). By replacing native, often suboptimal attention implementations with Flash Attention or Flex Attention, we provide high-performance alternatives that are fully compatible with the HuggingFace transformers ecosystem.


Table of Contents

  1. Introduction
  2. Supported Models
  3. Attention Backends
  4. Embedding & Pooling
  5. Concrete Examples
  6. Testing & Benchmarking
  7. Installation & Docker

Introduction

What are Protein Language Models (pLMs)?

Protein Language Models are transformer-based architectures trained on massive datasets of protein sequences (such as UniProt). These models learn the "grammar" of proteins, capturing evolutionary information, structural constraints, and functional motifs. They are used for:

  • Representation Learning: Generating high-dimensional embeddings for downstream tasks (e.g., stability, function prediction).
  • Protein Generation: Designing novel sequences with specific properties.
  • Structure Prediction: Mapping sequences to their 3D folds (e.g., Boltz2).

What is this repository?

FastPLMs provides optimized versions of these models. Our focus is on:

  • Speed: Drastically faster inference through optimized attention kernels.
  • Memory Efficiency: Lower VRAM usage, enabling larger batch sizes or longer sequences.
  • Seamless Integration: Use AutoModel.from_pretrained(..., trust_remote_code=True) to load our optimized weights directly from HuggingFace.

Supported Models

We maintain a comprehensive HuggingFace Collection of optimized models. Below is a summary of the supported families and their origins.

Model Registry Summary

| Model Family | Organization | Official Implementation | FastPLMs Optimization | Checkpoints |
|---|---|---|---|---|
| E1 | Profluent Bio | Profluent-Bio/E1 | Flex Attention, Block-Causal | 150M, 300M, 600M |
| ESM2 | Meta AI | facebookresearch/esm | Flash (SDPA) / Flex Attention | 8M, 35M, 150M, 650M, 3B |
| ESM++ | EvolutionaryScale | EvolutionaryScale/esm | Optimized SDPA / Flex | Small (300M), Large (600M) |
| DPLM | ByteDance | N/A | Diffusion Optimized Attention | 150M, 650M, 3B |
| DPLM2 | ByteDance | N/A | Multimodal Diffusion | 150M, 650M, 3B |
| Boltz2 | MIT / Various | jwohlwend/boltz | Optimized Structure Prediction | Standard |

Full Model List

| Model Key | Family | Parameters | Organization | FastPLMs Repo ID | Official Reference |
|---|---|---|---|---|---|
| `e1_150m` | E1 | 150M | Profluent Bio | Synthyra/Profluent-E1-150M | Profluent-Bio/E1-150m |
| `e1_300m` | E1 | 300M | Profluent Bio | Synthyra/Profluent-E1-300M | Profluent-Bio/E1-300m |
| `e1_600m` | E1 | 600M | Profluent Bio | Synthyra/Profluent-E1-600M | Profluent-Bio/E1-600m |
| `esm2_8m` | ESM2 | 8M | Meta AI | Synthyra/ESM2-8M | facebook/esm2_t6_8M_UR50D |
| `esm2_35m` | ESM2 | 35M | Meta AI | Synthyra/ESM2-35M | facebook/esm2_t12_35M_UR50D |
| `esm2_150m` | ESM2 | 150M | Meta AI | Synthyra/ESM2-150M | facebook/esm2_t30_150M_UR50D |
| `esm2_650m` | ESM2 | 650M | Meta AI | Synthyra/ESM2-650M | facebook/esm2_t33_650M_UR50D |
| `esm2_3b` | ESM2 | 3B | Meta AI | Synthyra/ESM2-3B | facebook/esm2_t36_3B_UR50D |
| `esmplusplus_small` | ESM++ | 300M | EvolutionaryScale | Synthyra/ESMplusplus_small | EvolutionaryScale/esmc-300m |
| `esmplusplus_large` | ESM++ | 600M | EvolutionaryScale | Synthyra/ESMplusplus_large | EvolutionaryScale/esmc-600m |
| `dplm_150m` | DPLM | 150M | ByteDance | Synthyra/DPLM-150M | airkingbd/dplm_150m |
| `dplm_650m` | DPLM | 650M | ByteDance | Synthyra/DPLM-650M | airkingbd/dplm_650m |
| `dplm_3b` | DPLM | 3B | ByteDance | Synthyra/DPLM-3B | airkingbd/dplm_3b |
| `dplm2_150m` | DPLM2 | 150M | ByteDance | Synthyra/DPLM2-150M | airkingbd/dplm2_150m |
| `dplm2_650m` | DPLM2 | 650M | ByteDance | Synthyra/DPLM2-650M | airkingbd/dplm2_650m |
| `dplm2_3b` | DPLM2 | 3B | ByteDance | Synthyra/DPLM2-3B | airkingbd/dplm2_3b |
| `boltz2` | Boltz2 | - | MIT / Various | Synthyra/Boltz2 | jwohlwend/boltz |

Attention Backends

All FastPLMs models share a common set of attention backends, controlled via config.attn_backend. The default is "sdpa", which is safe on all hardware and numerically equivalent to standard attention.

Backend Comparison

| Backend | Key | Speed | Numerical Equivalence | Availability |
|---|---|---|---|---|
| PyTorch SDPA | `"sdpa"` | Fast | Exact | Any PyTorch ≥ 2.0 |
| Flash Attention | `"kernels_flash"` | Fastest | Approximate | Requires `pip install kernels` (pre-built) |
| Flex Attention | `"flex"` | Very fast | ~Exact | Requires PyTorch ≥ 2.5 |
| Auto | `"auto"` | Best available | Backend-dependent | Always (selects best available) |

SDPA (default)

PyTorch's scaled_dot_product_attention dispatches to a fused CUDA kernel (cuDNN or efficient attention) that is faster and more memory-efficient than naive attention, while being mathematically identical to it. This is the recommended default for reproducibility and general use. It is also the only backend where output_attentions=True is handled natively; with other backends, attentions are computed via a separate naive matrix multiplication when requested.
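For reference, the computation SDPA performs can be written out in plain Python. This is an illustrative naive implementation of softmax(QKᵀ/√d)·V, not the fused kernel; the fused kernel computes the same values with far less memory traffic:

```python
import math

def naive_sdpa(Q, K, V):
    """Reference scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.
    Q, K, V are lists of row vectors (lists of floats)."""
    d = len(Q[0])
    out = []
    for q in Q:
        # scaled dot products of this query against every key
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in K]
        m = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        weights = [e / z for e in exps]
        # weighted sum of the value rows
        out.append([sum(w * v[j] for w, v in zip(weights, V)) for j in range(len(V[0]))])
    return out

Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
print(naive_sdpa(Q, K, V))
```

Each output row is a convex combination of the value rows, with the first query attending most strongly to the first key and vice versa.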

Flash Attention (kernels_flash)

Flash Attention 2 and 3 are typically the fastest options on Ampere (A100) and Hopper (H100) GPUs, often 2–4× faster than SDPA at long sequence lengths. Flash Attention achieves this by tiling the computation and applying an online softmax, which means the results are not bitwise identical to SDPA or naive attention. Differences are on the order of floating-point rounding and are often inconsequential for standard inference — but they are not guaranteed to be so. They can compound across layers, interact with low-precision dtypes (fp16/bf16), or affect sensitive downstream tasks. Flash Attention is standard practice in large model training and the trade-off is well understood, but it should not be treated as a drop-in numerical equivalent of SDPA. If exact reproducibility or numerical sensitivity is a concern, use "sdpa" instead.
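The online-softmax trick at the heart of Flash Attention can be illustrated in a few lines of plain Python: process the scores tile by tile, keeping a running max and rescaling the running normalizer whenever the max changes. This is a sketch of the idea, not the CUDA kernel:

```python
import math

def softmax(xs):
    """Standard two-pass softmax over the full score list."""
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    z = sum(es)
    return [e / z for e in es]

def online_softmax(xs, tile=4):
    """Tiled (online) softmax as in Flash Attention: accumulate the
    normalizer tile by tile, rescaling when the running max changes,
    then emit the weights with the final statistics."""
    m = float("-inf")  # running max
    z = 0.0            # running normalizer
    for i in range(0, len(xs), tile):
        block = xs[i:i + tile]
        new_m = max(m, max(block))
        z = z * math.exp(m - new_m) + sum(math.exp(x - new_m) for x in block)
        m = new_m
    return [math.exp(x - m) / z for x in xs]

xs = [0.1 * i for i in range(10)]
a, b = softmax(xs), online_softmax(xs)
print(max(abs(p - q) for p, q in zip(a, b)))  # differences at rounding level
```

Because the accumulation order differs from the two-pass version, the results agree only up to floating-point rounding, which is exactly the source of the small deviations described above.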

No compilation required. FastPLMs uses the HuggingFace kernels package to load pre-built Flash Attention 2/3 binaries at runtime — no C++ compiler, no CUDA toolkit version pinning, no waiting:

```bash
pip install kernels
```

Building flash-attn from source is notoriously painful. The Ninja build system parallelizes aggressively across all available CPU cores, and each NVCC/CICC compiler process it spawns can consume 5–8 GB of RAM on its own. On a 64-core machine this can push peak RAM usage to ~300 GB, and even on a throttled single-threaded build (MAX_JOBS=1 NVCC_THREADS=1) the compile still takes many hours while grinding through paging. Pre-built community wheels cover 384+ version/GPU/CUDA/platform combinations and still routinely fall short of matching a user's exact environment. This is the point where most people give up and go without Flash Attention entirely. The kernels package sidesteps all of this by fetching a pre-compiled binary matched to your GPU architecture (SM80 for Ampere, SM90 for Hopper). If no compatible binary exists for your hardware, it gracefully falls back to flex or sdpa rather than erroring.

Flex Attention (flex)

PyTorch's flex_attention (PyTorch ≥ 2.5) generates a fused Triton kernel customized to the mask pattern at hand. It is numerically very close to SDPA — typically within floating-point rounding of naive computation. The primary advantage is that it can apply a block mask that skips padding tokens entirely, providing a meaningful speedup on batches with variable-length sequences (no compute wasted on padding). E1 uses a block-causal variant of this mask.

The first forward pass triggers JIT compilation via Triton, which can take 30–120 seconds. All subsequent calls are fast. Combining with torch.compile yields the best sustained throughput.

Auto (auto)

Automatically selects the best available backend in order of preference: `"kernels_flash"` → `"flex"` → `"sdpa"`. Useful when you want maximum speed without configuring the environment manually, and you accept that the resolved backend may differ across machines.
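The preference order amounts to a simple fallback chain. The sketch below is illustrative only; the library's actual detection logic may differ:

```python
def resolve_auto_backend(kernels_flash_ok: bool, flex_ok: bool) -> str:
    """Illustrative sketch of the "auto" preference order:
    kernels_flash, then flex, then sdpa."""
    if kernels_flash_ok:
        return "kernels_flash"
    if flex_ok:
        return "flex"
    return "sdpa"  # always available on PyTorch >= 2.0

print(resolve_auto_backend(False, True))  # -> "flex"
```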

Setting the Backend

At load time (all models):

```python
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("Synthyra/ESM2-150M", trust_remote_code=True)
config.attn_backend = "flex"  # "sdpa", "kernels_flash", "flex", or "auto"
model = AutoModel.from_pretrained("Synthyra/ESM2-150M", config=config, trust_remote_code=True)
```

After load time (DPLM and DPLM2 only):

DPLM and DPLM2 expose an attn_backend property on the model that propagates the change to all attention layers immediately:

```python
model = AutoModel.from_pretrained("Synthyra/DPLM-150M", trust_remote_code=True)
model.attn_backend = "flex"  # updates every attention layer in-place
```

For ESM2, E1, and ESM++, the backend must be set on the config before calling from_pretrained.

Returning Attention Maps

All backends support output_attentions=True. SDPA handles this natively; with Flash Attention or Flex, attention weights are computed via a separate naive matrix multiplication and appended to the output, so enabling it negates the memory savings of those backends. Use it only for inspection or contact prediction, not during high-throughput inference.


Embedding & Pooling

The EmbeddingMixin (shared across all models) provides a standardized way to extract representations from proteins.

The Pooler

The Pooler class aggregates per-residue representations into a single fixed-size vector. Supported strategies include:

  • mean: Mask-aware average of all residues.
  • cls: The first token's representation (standard for classification).
  • max: Element-wise maximum across the sequence.
  • var / std: Element-wise variance or standard deviation.
  • norm: L2 normalization.
  • median: Element-wise median.
  • parti: Experimental PageRank-based attention pooling.
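To make "mask-aware" concrete, here is a minimal plain-Python sketch of the mean strategy. It is illustrative, not the Pooler source, and assumes a [seq_len][hidden_size] layout with a 0/1 padding mask:

```python
def masked_mean_pool(hidden, mask):
    """Mask-aware mean pooling: average residue vectors where mask == 1,
    ignoring padding positions. Assumes at least one real token."""
    dim = len(hidden[0])
    total = [0.0] * dim
    n = 0
    for vec, m in zip(hidden, mask):
        if m:
            total = [t + v for t, v in zip(total, vec)]
            n += 1
    return [t / n for t in total]

hidden = [[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]  # last row is padding
mask = [1, 1, 0]
print(masked_mean_pool(hidden, mask))  # -> [2.0, 3.0]
```

A plain mean over all positions would be dragged toward zero by the padding row; the mask keeps padding out of the average.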

Concrete Examples

1. Batch Embedding with SQLite (Scalable)

Ideal for embedding millions of sequences when you need to stream data or avoid exhausting RAM.

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("Synthyra/ESM2-150M", trust_remote_code=True).cuda()

sequences = ["MALWMRLLPLLALLALWGPDPAAA", "MKTIIALSYIFCLVFA", ...]

# Embed and store in SQLite
model.embed_dataset(
    sequences=sequences,
    batch_size=64,
    pooling_types=['mean', 'cls'],  # concatenates both
    sql=True,
    sql_db_path='large_protein_db.db',
    embed_dtype=torch.float32
)
```

2. Embedding from a FASTA File

Pass a FASTA file path directly — no manual parsing required. Multi-line sequences are handled automatically. You can combine fasta_path with an explicit sequences list and the two sources are merged before embedding.

```python
# Embed all sequences in a FASTA file and save to SQLite
model.embed_dataset(
    fasta_path='my_proteins.fasta',
    batch_size=64,
    pooling_types=['mean'],
    sql=True,
    sql_db_path='my_proteins.db',
)

# Mix a FASTA file with an explicit list
model.embed_dataset(
    sequences=["MKTIIALSYIFCLVFA"],
    fasta_path='additional_proteins.fasta',
    batch_size=32,
    save=True,
    save_path='combined_embeddings.pth',
)
```
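To make the multi-line handling concrete, here is a minimal FASTA reader of the kind embed_dataset performs internally. This is an illustrative sketch, not the library's actual parser:

```python
import tempfile

def read_fasta(path):
    """Minimal FASTA reader: collects the lines under each '>' header
    and joins them, so wrapped (multi-line) sequences come out whole."""
    records = {}
    header = None
    with open(path) as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            if line.startswith(">"):
                header = line[1:]
                records[header] = []
            else:
                records[header].append(line)
    return {h: "".join(parts) for h, parts in records.items()}

# demo with a two-record file containing a wrapped sequence
with tempfile.NamedTemporaryFile("w", suffix=".fasta", delete=False) as fh:
    fh.write(">seq1\nMKTII\nALSYI\n>seq2\nMALW\n")
    demo_path = fh.name

print(read_fasta(demo_path))  # -> {'seq1': 'MKTIIALSYI', 'seq2': 'MALW'}
```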

3. High-Throughput In-Memory Embedding

Perfect for medium-sized datasets that fit in memory.

```python
# Embed and return as a dictionary
embeddings = model.embed_dataset(
    sequences=sequences,
    batch_size=128,
    pooling_types=['mean'],
    save=True,
    save_path='my_embeddings.pth'
)

# Access an embedding
seq_vector = embeddings["MALWMRLLPLLALLALWGPDPAAA"]  # torch.Tensor
```

4. Custom Pooling & Multi-Strategy

Concatenate multiple mathematical representations for richer downstream features.

```python
# Use a variety of pooling types
embeddings = model.embed_dataset(
    sequences=sequences,
    pooling_types=['mean', 'max', 'std', 'var'],  # all four concatenated
    batch_size=32,
    full_embeddings=False
)

# Resulting vector size: 4 * hidden_size
print(embeddings[sequences[0]].shape)
```

Testing & Benchmarking

FastPLMs includes a robust CLI-based testing suite under testing/.

Running the Suite

  • Compliance Checks: Verify that optimized models match reference outputs.

    `py -m testing.run_compliance --families esm2`

  • Throughput Benchmarks: Measure tokens/sec and peak memory.

    `py -m testing.run_throughput --device cuda --lengths 512,1024`

  • Run Everything: Execute the full suite across all families.

    `py -m testing.run_all --full-models`

Results are saved to testing/results/<timestamp>/ as metrics.json, metrics.csv, and high-resolution plots.


Installation & Docker

Local Installation

```bash
git clone https://github.com/Synthyra/FastPLMs.git
cd FastPLMs
pip install -r requirements.txt
```

Docker (Recommended for Testing)

```bash
# Build the image
docker build -t fastplms-test -f Dockerfile .

# Run benchmarks inside container
docker run --rm --gpus all -it -v ${PWD}:/workspace fastplms-test \
    python -m testing.run_throughput --device cuda
```

Suggestions & Contributions

Found a bug or have a feature request? Please open a GitHub Issue. We are actively looking for contributions to optimize more pLM architectures!
