FastPLMs is an open-source initiative dedicated to accelerating pretrained protein language models (pLMs). By replacing native, often suboptimal attention implementations with Flash Attention or Flex Attention, we provide high-performance alternatives that are fully compatible with the HuggingFace transformers ecosystem.
- Introduction
- Supported Models
- Attention Backends
- Embedding & Pooling
- Concrete Examples
- Testing & Benchmarking
- Installation & Docker
Protein Language Models are transformer-based architectures trained on massive datasets of protein sequences (such as UniProt). These models learn the "grammar" of proteins, capturing evolutionary information, structural constraints, and functional motifs. They are used for:
- Representation Learning: Generating high-dimensional embeddings for downstream tasks (e.g., stability, function prediction).
- Protein Generation: Designing novel sequences with specific properties.
- Structure Prediction: Mapping sequences to their 3D folds (e.g., Boltz2).
FastPLMs provides optimized versions of these models. Our focus is on:
- Speed: Drastically faster inference through optimized attention kernels.
- Memory Efficiency: Lower VRAM usage, enabling larger batch sizes or longer sequences.
- Seamless Integration: Use `AutoModel.from_pretrained(..., trust_remote_code=True)` to load our optimized weights directly from HuggingFace.
We maintain a comprehensive HuggingFace Collection of optimized models. Below is a summary of the supported families and their origins.
| Model Family | Organization | Official Implementation | FastPLMs Optimization | Checkpoints |
|---|---|---|---|---|
| E1 | Profluent Bio | Profluent-Bio/E1 | Flex Attention, Block-Causal | 150M, 300M, 600M |
| ESM2 | Meta AI | facebookresearch/esm | Flash (SDPA) / Flex Attention | 8M, 35M, 150M, 650M, 3B |
| ESM++ | EvolutionaryScale | EvolutionaryScale/esm | Optimized SDPA / Flex | Small (300M), Large (600M) |
| DPLM | ByteDance | N/A | Diffusion Optimized Attention | 150M, 650M, 3B |
| DPLM2 | ByteDance | N/A | Multimodal Diffusion | 150M, 650M, 3B |
| Boltz2 | MIT / Various | jwohlwend/boltz | Optimized Structure Prediction | Standard |
All FastPLMs models share a common set of attention backends, controlled via config.attn_backend. The default is "sdpa", which is safe on all hardware and numerically equivalent to standard attention.
| Backend | Key | Speed | Numerical Equivalence | Availability |
|---|---|---|---|---|
| PyTorch SDPA | "sdpa" | Fast | Exact | Any PyTorch ≥ 2.0 |
| Flash Attention | "kernels_flash" | Fastest | Approximate | Requires `pip install kernels` (pre-built) |
| Flex Attention | "flex" | Very fast | ~Exact | Requires PyTorch ≥ 2.5 |
| Auto | "auto" | Depends on resolved backend | Depends on resolved backend | Always (selects best available) |
PyTorch's scaled_dot_product_attention dispatches to a fused CUDA kernel (cuDNN or efficient attention) that is faster and more memory-efficient than naive attention, while computing the same mathematical function. This is the recommended default for reproducibility and general use. Note that scaled_dot_product_attention does not itself return attention weights, so requesting output_attentions=True is served by a separate attention computation, as described below for all backends.
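As a quick sanity check of the equivalence claim, the sketch below compares `torch.nn.functional.scaled_dot_product_attention` against a naive `softmax(QKᵀ/√d)V` in float32 on CPU. The shapes and the tolerance are illustrative assumptions; expect agreement within floating-point tolerance rather than bitwise equality.

```python
import math

import torch
import torch.nn.functional as F


def naive_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Reference attention: softmax(Q K^T / sqrt(d)) V, materializing the full score matrix."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    return scores.softmax(dim=-1) @ v


torch.manual_seed(0)
# (batch, heads, seq_len, head_dim); shapes chosen only for illustration
q, k, v = (torch.randn(2, 4, 64, 32) for _ in range(3))

reference = naive_attention(q, k, v)
fused = F.scaled_dot_product_attention(q, k, v)  # default scale is 1/sqrt(head_dim)

max_abs_diff = (reference - fused).abs().max().item()
print(f"max |naive - sdpa| = {max_abs_diff:.2e}")
```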
Flash Attention 2 and 3 are typically the fastest options on Ampere (A100) and Hopper (H100) GPUs, often 2–4× faster than SDPA at long sequence lengths. Flash Attention achieves this by tiling the computation and applying an online softmax, which means the results are not bitwise identical to SDPA or naive attention. Differences are on the order of floating-point rounding and are often inconsequential for standard inference — but they are not guaranteed to be so. They can compound across layers, interact with low-precision dtypes (fp16/bf16), or affect sensitive downstream tasks. Flash Attention is standard practice in large model training and the trade-off is well understood, but it should not be treated as a drop-in numerical equivalent of SDPA. If exact reproducibility or numerical sensitivity is a concern, use "sdpa" instead.
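To make "tiling plus online softmax" concrete, here is a minimal pure-PyTorch sketch for a single head. It walks over keys and values in blocks, keeps a running row maximum and a running softmax denominator, and rescales the partial output whenever the maximum changes, so the full score matrix is never materialized. This illustrates the technique only; it is not the FlashAttention kernel, and the block size is an arbitrary assumption. Because sums are accumulated in a different order, the result matches naive attention up to floating-point rounding.

```python
import math

import torch


def tiled_attention(q, k, v, block_size: int = 16):
    """Single-head attention over key/value tiles with an online softmax.

    q: (n_q, d), k: (n_k, d), v: (n_k, d_v).
    """
    scale = 1.0 / math.sqrt(q.size(-1))
    n_q = q.size(0)
    out = q.new_zeros(n_q, v.size(-1))
    row_max = q.new_full((n_q, 1), float("-inf"))
    row_sum = q.new_zeros(n_q, 1)

    for start in range(0, k.size(0), block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        scores = (q @ k_blk.T) * scale  # (n_q, block)

        new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)  # rescales earlier partial results
        p = torch.exp(scores - new_max)

        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ v_blk
        row_max = new_max

    return out / row_sum


torch.manual_seed(0)
q, k, v = torch.randn(8, 32), torch.randn(100, 32), torch.randn(100, 32)
naive = torch.softmax(q @ k.T / math.sqrt(32), dim=-1) @ v
tiled = tiled_attention(q, k, v)
print((naive - tiled).abs().max().item())  # small, but not guaranteed to be exactly 0
```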
No compilation required. FastPLMs uses the HuggingFace kernels package to load pre-built Flash Attention 2/3 binaries at runtime — no C++ compiler, no CUDA toolkit version pinning, no waiting:
```bash
pip install kernels
```
Building flash-attn from source is notoriously painful. The Ninja build system parallelizes aggressively across all available CPU cores, and each NVCC/CICC compiler process it spawns can consume 5–8 GB of RAM on its own. On a 64-core machine this can push peak RAM usage to ~300 GB, and even on a throttled single-threaded build (MAX_JOBS=1 NVCC_THREADS=1) the compile still takes many hours while grinding through paging. Pre-built community wheels cover 384+ version/GPU/CUDA/platform combinations and still routinely fall short of matching a user's exact environment. This is the point where most people give up and go without Flash Attention entirely. The kernels package sidesteps all of this by fetching a pre-compiled binary matched to your GPU architecture (SM80 for Ampere, SM90 for Hopper). If no compatible binary exists for your hardware, FastPLMs falls back to flex or sdpa instead of raising an error.
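Pre-built binaries are keyed on the GPU's compute capability. The hypothetical helpers below show how such a tag can be derived from `torch.cuda.get_device_capability()`, which reports (8, 0) on an A100 and (9, 0) on an H100. The function names are illustrative only; the kernels package performs its own matching, and this sketch does not replicate that logic.

```python
from typing import Optional, Tuple

import torch


def sm_tag(capability: Tuple[int, int]) -> str:
    """Format a CUDA compute capability such as (8, 0) as an architecture tag like 'sm80'."""
    major, minor = capability
    return f"sm{major}{minor}"


def current_sm_tag() -> Optional[str]:
    """Return the SM tag of the active CUDA device, or None when no GPU is visible."""
    if not torch.cuda.is_available():
        return None
    return sm_tag(torch.cuda.get_device_capability())


# e.g. 'sm80' on an A100, 'sm90' on an H100, None on a CPU-only machine
print(current_sm_tag())
```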
PyTorch's flex_attention (PyTorch ≥ 2.5) generates a fused Triton kernel customized to the mask pattern at hand. It is numerically very close to SDPA — typically within floating-point rounding of naive computation. The primary advantage is that it can apply a block mask that skips padding tokens entirely, providing a meaningful speedup on batches with variable-length sequences (no compute wasted on padding). E1 uses a block-causal variant of this mask.
The first forward pass triggers JIT compilation via Triton, which can take 30–120 seconds. All subsequent calls are fast. Combining with torch.compile yields the best sustained throughput.
Automatically selects the best available backend in order of preference: kernels_flash → flex → sdpa. Useful when you want maximum speed without configuring the environment manually, and you accept that the resolved backend may differ across machines.
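The resolution order can be pictured with a small, hypothetical helper. It picks kernels_flash when the kernels package is importable and a CUDA device is present (the CUDA check is our assumption), then flex when the installed PyTorch is at least 2.5 (the requirement listed in the table above), and otherwise settles on sdpa. This mirrors only the stated preference order and is not FastPLMs' actual selection code.

```python
import importlib.util
from typing import Tuple


def parse_version(version: str) -> Tuple[int, int]:
    """Extract (major, minor) from version strings like '2.5.1+cu121'."""
    major, minor = version.split("+")[0].split(".")[:2]
    return int(major), int(minor)


def resolve_attn_backend(has_kernels: bool, has_cuda: bool, torch_version: str) -> str:
    """Pick a backend in the order kernels_flash -> flex -> sdpa (illustrative only)."""
    if has_kernels and has_cuda:
        return "kernels_flash"
    if parse_version(torch_version) >= (2, 5):
        return "flex"
    return "sdpa"


def detect_environment() -> Tuple[bool, bool, str]:
    """Probe the running environment; torch is imported lazily so the helpers above stay pure."""
    import torch

    has_kernels = importlib.util.find_spec("kernels") is not None
    return has_kernels, torch.cuda.is_available(), torch.__version__


print(resolve_attn_backend(*detect_environment()))
```

Keeping the decision in a pure function makes the fallback order easy to test without a GPU.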
At load time (all models):

```python
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("Synthyra/ESM2-150M", trust_remote_code=True)
config.attn_backend = "flex"  # "sdpa", "kernels_flash", "flex", or "auto"
model = AutoModel.from_pretrained("Synthyra/ESM2-150M", config=config, trust_remote_code=True)
```

After load time (DPLM and DPLM2 only):
DPLM and DPLM2 expose an attn_backend property on the model that propagates the change to all attention layers immediately:

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("Synthyra/DPLM-150M", trust_remote_code=True)
model.attn_backend = "flex"  # updates every attention layer in-place
```

For ESM2, E1, and ESM++, the backend must be set on the config before calling from_pretrained.
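The attn_backend property described above for DPLM and DPLM2 can be implemented with a setter that walks the module tree and updates every attention layer. The classes below (`ToyAttention`, `ToyModel`) are hypothetical stand-ins written to show the pattern; they are not DPLM's actual modules.

```python
import torch.nn as nn

VALID_BACKENDS = {"sdpa", "kernels_flash", "flex", "auto"}


class ToyAttention(nn.Module):
    def __init__(self, attn_backend: str = "sdpa"):
        super().__init__()
        self.attn_backend = attn_backend


class ToyModel(nn.Module):
    def __init__(self, num_layers: int = 3):
        super().__init__()
        self.layers = nn.ModuleList(ToyAttention() for _ in range(num_layers))
        self._attn_backend = "sdpa"

    @property
    def attn_backend(self) -> str:
        return self._attn_backend

    @attn_backend.setter
    def attn_backend(self, backend: str) -> None:
        if backend not in VALID_BACKENDS:
            raise ValueError(f"Unknown attention backend: {backend!r}")
        self._attn_backend = backend
        # Propagate to every attention submodule, however deeply nested.
        for module in self.modules():
            if isinstance(module, ToyAttention):
                module.attn_backend = backend


model = ToyModel()
model.attn_backend = "flex"
print([layer.attn_backend for layer in model.layers])  # ['flex', 'flex', 'flex']
```

Walking `self.modules()` rather than a fixed list means nested or wrapped layers are still reached.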
All backends support output_attentions=True. For the optimized backends (SDPA, Flash Attention, Flex), attention weights are computed via a separate naive matrix multiplication and appended to the output — so enabling this negates the memory savings of those backends. Use it only for inspection or contact prediction, not during high-throughput inference.
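For reference, a separate weight computation amounts to materializing the full attention matrix, which costs memory quadratic in sequence length per head. The sketch below assumes a boolean key-padding mask (True for real tokens); the function name and mask convention are illustrative, not FastPLMs' internal code.

```python
import math

import torch


def attention_with_weights(q, k, v, key_padding_mask):
    """Return (output, weights) for q/k/v of shape (B, H, L, D).

    key_padding_mask: (B, L) bool, True where the token is real, False for padding.
    The weights tensor has shape (B, H, L, L), i.e. quadratic in sequence length.
    """
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    # Broadcast the mask over heads and query positions; padded keys get -inf.
    scores = scores.masked_fill(~key_padding_mask[:, None, None, :], float("-inf"))
    weights = scores.softmax(dim=-1)
    return weights @ v, weights


torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 6, 8) for _ in range(3))
mask = torch.tensor([[True] * 6, [True] * 4 + [False] * 2])  # second sequence has 2 padding tokens
out, attn = attention_with_weights(q, k, v, mask)
print(attn.shape)  # torch.Size([2, 4, 6, 6])
```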
The EmbeddingMixin (shared across all models) provides a standardized way to extract representations from proteins.
The Pooler class aggregates per-residue representations into a single fixed-size vector per sequence. Supported strategies include:
- mean: Mask-aware average of all residues.
- cls: The first token's representation (standard for classification).
- max: Element-wise maximum across the sequence.
- var/std: Variance or standard deviation of representations.
- norm: L2 normalization.
- median: Element-wise median.
- parti: Experimental PageRank-based attention pooling.
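A minimal sketch of a few of these strategies, assuming residue embeddings of shape (batch, seq_len, hidden) and an attention mask with 1 for real tokens and 0 for padding. It shows mask-aware reductions and how concatenating several pooling types yields `len(pooling_types) * hidden_size` features. It is not the Pooler class itself: the use of population variance is our assumption, and special-token handling is ignored.

```python
import torch


def pool(hidden: torch.Tensor, mask: torch.Tensor, pooling_types) -> torch.Tensor:
    """Mask-aware pooling of (B, L, H) residue embeddings into (B, len(pooling_types) * H)."""
    m = mask.unsqueeze(-1).to(hidden.dtype)   # (B, L, 1)
    counts = m.sum(dim=1).clamp(min=1)        # number of real tokens per sequence
    mean = (hidden * m).sum(dim=1) / counts

    pooled = []
    for kind in pooling_types:
        if kind == "mean":
            pooled.append(mean)
        elif kind == "cls":
            pooled.append(hidden[:, 0])
        elif kind == "max":
            # Padding positions are set to -inf so they never win the max.
            pooled.append(hidden.masked_fill(m == 0, float("-inf")).max(dim=1).values)
        elif kind in ("var", "std"):
            var = (((hidden - mean.unsqueeze(1)) ** 2) * m).sum(dim=1) / counts
            pooled.append(var if kind == "var" else var.sqrt())
        else:
            raise ValueError(f"pooling type not covered by this sketch: {kind}")
    return torch.cat(pooled, dim=-1)


torch.manual_seed(0)
hidden = torch.randn(2, 5, 8)
mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])
features = pool(hidden, mask, ["mean", "max", "std", "var"])
print(features.shape)  # torch.Size([2, 32]), i.e. 4 * hidden_size
```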
Ideal for embedding millions of sequences when you need to stream data or avoid running out of RAM.

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("Synthyra/ESM2-150M", trust_remote_code=True).cuda()
sequences = ["MALWMRLLPLLALLALWGPDPAAA", "MKTIIALSYIFCLVFA"]  # ... add more sequences

# Embed and store in SQLite
model.embed_dataset(
    sequences=sequences,
    batch_size=64,
    pooling_types=['mean', 'cls'],  # Concatenates both
    sql=True,
    sql_db_path='large_protein_db.db',
    embed_dtype=torch.float32,
)
```

Pass a FASTA file path directly; no manual parsing is required. Multi-line sequences are handled automatically. You can combine fasta_path with an explicit sequences list, and the two sources are merged before embedding.
```python
# Embed all sequences in a FASTA file and save to SQLite
model.embed_dataset(
    fasta_path='my_proteins.fasta',
    batch_size=64,
    pooling_types=['mean'],
    sql=True,
    sql_db_path='my_proteins.db',
)

# Mix a FASTA file with an explicit list
model.embed_dataset(
    sequences=["MKTIIALSYIFCLVFA"],
    fasta_path='additional_proteins.fasta',
    batch_size=32,
    save=True,
    save_path='combined_embeddings.pth',
)
```

Perfect for medium-sized datasets that fit in memory.
```python
# Embed and return as a dictionary
embeddings = model.embed_dataset(
    sequences=sequences,
    batch_size=128,
    pooling_types=['mean'],
    save=True,
    save_path='my_embeddings.pth',
)

# Access embedding
seq_vector = embeddings["MALWMRLLPLLALLALWGPDPAAA"]  # torch.Tensor
```

Concatenate several pooling types for richer downstream features.
```python
# Use a variety of pooling types
embeddings = model.embed_dataset(
    sequences=sequences,
    pooling_types=['mean', 'max', 'std', 'var'],  # All 4 concatenated
    batch_size=32,
    full_embeddings=False,
)

# Resulting vector size: 4 * hidden_size
print(embeddings[sequences[0]].shape)
```

FastPLMs includes a robust CLI-based testing suite under testing/.
- Compliance Checks: Verify that optimized models match reference outputs.

  ```bash
  python -m testing.run_compliance --families esm2
  ```

- Throughput Benchmarks: Measure tokens/sec and peak memory.

  ```bash
  python -m testing.run_throughput --device cuda --lengths 512,1024
  ```

- Run Everything: Execute the full suite across all families.

  ```bash
  python -m testing.run_all --full-models
  ```
Results are saved to testing/results/<timestamp>/ as metrics.json, metrics.csv, and high-resolution plots.
```bash
git clone https://github.com/Synthyra/FastPLMs.git
cd FastPLMs
pip install -r requirements.txt
```

```bash
# Build the image
docker build -t fastplms-test -f Dockerfile .

# Run benchmarks inside container
docker run --rm --gpus all -it -v ${PWD}:/workspace fastplms-test \
    python -m testing.run_throughput --device cuda
```

Found a bug or have a feature request? Please open a GitHub Issue. We are actively looking for contributions to optimize more pLM architectures!