diff --git a/README.md b/README.md index dfbd0b8d..40932d2f 100644 --- a/README.md +++ b/README.md @@ -223,7 +223,7 @@ spi = StructurePredictionInput( ) result = ESMFold2InputBuilder().fold( - model, spi, num_loops=3, num_sampling_steps=50, num_diffusion_samples=1, seed=0 + model, spi, num_loops=20, num_sampling_steps=100, num_diffusion_samples=1, seed=0 ) print(f"pLDDT mean: {float(result.plddt.mean()):.3f}, pTM: {float(result.ptm):.3f}, ipTM: {float(result.iptm):.3f}") @@ -266,7 +266,10 @@ ca2_input = StructurePredictionInput( sequences=[ProteinInput(id="A", sequence=ca2_sequence)] ) -config = FoldingConfig(num_loops=3, num_sampling_steps=32) +config = FoldingConfig( + num_loops=20, + num_sampling_steps=100 +) result = client.fold_all_atom(ca2_input, config=config) with open("result.cif", "w") as f: diff --git a/cookbook/tutorials/binder_design.ipynb b/cookbook/tutorials/binder_design.ipynb index a9219583..7ec1b867 100644 --- a/cookbook/tutorials/binder_design.ipynb +++ b/cookbook/tutorials/binder_design.ipynb @@ -7,10 +7,26 @@ "source": [ "## [Tutorial](https://github.com/biohub/esm/tree/main/cookbook/tutorials): How to run minibinder + scFv design fully end-to-end.\n", "\n", - "In this tutorial we will use [Modal](https://modal.com/) to parallelize binder design and synthesize a selection,\n", - "using the protocol described in the ESMC and ESMFold2 paper titled [\"Language Modeling Materializes a World Model of Protein Biology\"](https://biohub.ai/papers/esm_protein.pdf).\n", + "In this notebook we will use [Modal](https://modal.com/) to parallelize binder design and synthesize a selection, using the protocol described in the ESMC and ESMFold2 paper titled [\"Language Modeling Materializes a World Model of Protein Biology\"](https://biohub.ai/papers/esm_protein.pdf).\n", "\n", - "Biohub used this approach to design minibinders and scFvs against five therapeutically relevant targets — PDGFRB, EGFR, PD-L1, CD45, and CTLA4 — spanning receptor tyrosine kinases, immune checkpoints, and cell-surface phosphatases. Binders exhibit nanomolar affinity, target specificity, and functional activity in laboratory assays." + "Biohub used this approach to design minibinders and scFvs against five therapeutically relevant targets — PDGFRB, EGFR, PD-L1, CD45, and CTLA4 — spanning receptor tyrosine kinases, immune checkpoints, and cell-surface phosphatases. Binders exhibit nanomolar affinity, target specificity, and functional activity in laboratory assays.\n", + "\n", + "\n", + "**You'll need:**\n", + "- A target protein sequence (or pick one of the built-in presets)\n", + "- A [Modal](https://modal.com/) account and token (free tier works to get started)\n", + "- The notebook contains cells launching independent trajectories and a small sweep (N=256). With Modal H100 pricing, this will cost ≈$2 and ≈$150, respectively\n", + "\n", + "**You'll get:** a file of top-ranked designed binder sequences, plus 3D structures of the predicted complexes that you can view in the notebook.\n", + "\n", + "**Workflow:**\n", + "1. **Setup** (one-time): install dependencies, get a Modal token, deploy the design app\n", + "2. **Try one job**: pick a target and binder type, run a single design end-to-end as a sanity check\n", + "3. **Run a sweep**: launch many parallel jobs to produce real candidates\n", + "4. **Pick the designs to order**: filter and rank, save the shortlist\n", + "\n", + "\n", + "> **Why does this notebook need Modal?** Each design job needs a GPU and a few minutes of compute, and a real campaign means launching hundreds of these in parallel. [Modal](https://modal.com/) is a cloud service that lets this notebook spin up remote GPUs on demand, run each job on one, and return results to you. You don't manage any servers or containers yourself. When you \"deploy\" the design app to Modal (in section 1), you're uploading a Python file that defines the job; Modal then runs that code for you whenever the notebook calls `app.design.spawn(...)`." ] }, { @@ -18,7 +34,7 @@ "id": "4421cdaa", "metadata": {}, "source": [ - "### One-time setup" + "### 1. Setup" ] }, { @@ -45,6 +61,40 @@ "# ! modal token new # Create" ] }, + { + "cell_type": "markdown", + "id": "c88594a9", + "metadata": {}, + "source": [ + "### Deploy the design app to Modal\n", + "\n", + "The file `binder_design.py` lives in the same folder as this notebook. It defines the GPU job, the model, and the design loop.\n", + "\n", + "You only need to deploy once. Re-run this cell only if the underlying `.py` file changes." + ] + }, + { + "cell_type": "markdown", + "id": "8e88a943", + "metadata": {}, + "source": [ + "### If you're running on Colab\n", + "\n", + "Colab only pulls the notebook itself when you open it, not the surrounding files from the repo. Run the cell below to grab `binder_design.py` into your Colab workspace. If you're running locally, skip it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "af900bae", + "metadata": {}, + "outputs": [], + "source": [ + "# Colab only: download binder_design.py into the working directory\n", + "! wget -q https://raw.githubusercontent.com/Biohub/esm/main/cookbook/tutorials/binder_design.py\n", + "! ls binder_design.py " + ] + }, { "cell_type": "code", "execution_count": null, @@ -87,7 +137,11 @@ "id": "1fe9a141", "metadata": {}, "source": [ - "### App setup" + "### App setup\n", + "\n", + "`modal.Cls.from_name(...)` grabs a handle to the design app you just deployed, without rerunning anything on Modal yet. Instantiating it gives you `app`, which is what you'll call `app.design.spawn(...)` on to launch design jobs.\n", + "\n", + "`use_scaling_critics=False` is the default. Setting it to `True` adds extra critic models from the paper that improve selection at the cost of more compute per job." ] }, { @@ -97,8 +151,7 @@ "metadata": {}, "outputs": [], "source": [ - "# ESMFold2Design = modal.Cls.from_name(\"esmfold2-design-jun1-11am\", \"ESMFold2DesignModal\")\n", - "ESMFold2Design = modal.Cls.from_name(\"esmfold2-design-jun1-12pm\", \"ESMFold2DesignModal\")\n", + "ESMFold2Design = modal.Cls.from_name(\"esmfold2-design\", \"ESMFold2DesignModal\")\n", "# Set 'use_scaling_critics' to evaluate with the additional critics.\n", "# Off by default. But cells below were populated with them enabled.\n", "app = ESMFold2Design(use_scaling_critics=False)" @@ -109,7 +162,21 @@ "id": "159b63df", "metadata": {}, "source": [ - "### Run one job - interactive" + "## 2. Try one design job\n", + "\n", + "Run a single job end-to-end before launching a sweep. This is a sanity check that everything is wired up and that the target/scaffold combo you've chosen produces a sensible complex.\n", + "\n", + "Pick **one** of the two options below and run only that cell. (They both define a variable called `future`, so running both back-to-back overwrites the first one.)\n", + "\n", + "**Option 1** uses a built-in target and binder scaffold. Available targets: `ctla4`, `egfr`, `pdgfrb`, `pd-l1`, `cd45`. Available binder types: `minibinder`, `trastuzumab_framework_vhvl` (an antibody scaffold). Easiest if your target is one of the built-ins.\n", + "\n", + "**Option 2** takes your own target sequence and binder scaffold. In the binder scaffold, `#` means \"design this position\" and any amino acid letter means \"keep this position fixed.\" For example:\n", + "- `\"#\" * 60` designs a fully free 60-residue minibinder\n", + "- A trastuzumab-style antibody scaffold (shown below) fixes the framework regions and lets the model design the CDR loops \n", + "\n", + "If you're designing an antibody, pass `is_antibody=True` so the selection step later uses the antibody-appropriate scoring.\n", + "\n", + "Jobs run on Modal in the background. The `dashboard_url` link the cell prints is a clickable link to live progress." ] }, { @@ -281,7 +348,12 @@ "id": "fc105292", "metadata": {}, "source": [ - "### Run a sweep - async" + "## 3. Run a sweep for real designs\n", + "For real candidates worth ordering, sweep across many seeds (and optionally multiple targets, binder types, or lengths) and select the best.\n", + "\n", + "Edit `line_sweeps` below to define your campaign. Each key is a sweep axis; the notebook runs one job per combination of values. The default below sweeps 128 seeds across two binder types against PD-L1. \n", + "\n", + "**Before you click Run on the Launch cell, check the printed shape of the dataframe and confirm it's the number of jobs you intended.** " ] }, { @@ -423,6 +495,26 @@ ")" ] }, + { + "cell_type": "markdown", + "id": "13f0542e", + "metadata": {}, + "source": [ + "### Coming back later (important)\n", + "\n", + "Your jobs are now running on Modal's GPUs. **You do not need to keep this notebook open.** Modal continues running the jobs on its own and holds the results for up to 7 days.\n", + "\n", + "**If you're running on Colab:** the runtime is wiped when you disconnect. To resume a sweep on Colab, mount Google Drive **before** running the Launch cell and point `save_dir` at a Drive path, e.g.:\n", + "\n", + "\\`\\`\\`python\n", + "from google.colab import drive\n", + "drive.mount('/content/drive')\n", + "save_dir = Path('/content/drive/MyDrive/binder_sweep')\n", + "\\`\\`\\`\n", + "\n", + "When you reopen the notebook later, you'll need to re-install the dependencies, re-authenticate Modal (`modal token new`), re-mount Drive, then resume from the Monitor cell." + ] + }, { "cell_type": "code", "execution_count": 34, @@ -488,6 +580,14 @@ "df.status.value_counts()" ] }, + { + "cell_type": "markdown", + "id": "32fbc8d8", + "metadata": {}, + "source": [ + "The Collect cell waits for all jobs that succeeded and unpacks their results. Jobs that failed or are still running are skipped, so you can re-run Monitor + Collect periodically as more jobs finish." + ] + }, { "cell_type": "code", "execution_count": 48, @@ -532,6 +632,23 @@ "df_success[\"result_df\"] = [pd.DataFrame(r[2]) for r in tqdm(df_success.result)] # pyright: ignore" ] }, + { + "cell_type": "markdown", + "id": "d53ad844", + "metadata": {}, + "source": [ + "## 4. Pick the designs to order\n", + "\n", + "This is your final shortlist. The cell below:\n", + "\n", + "1. Combines results from all successful jobs into one dataframe.\n", + "2. Filters minibinders to isoelectric point under 6 (helps with solubility and expression). Antibodies pass through unfiltered.\n", + "3. Scores each unique designed sequence by averaging `iptm` (interface predicted TM-score, higher is better) and an `iptm_proxy` term across its trajectories.\n", + "4. Returns the top 84 designs per (target, binder type), saved to `selection.parquet` inside your `save_dir`.\n", + "\n", + "84 is a plate-friendly number for ordering and screening. The cutoff and the isoelectric-point filter are currently hardcoded inside the cell, so to change them, edit the values directly in the function.\n" + ] + }, { "cell_type": "code", "execution_count": 49, diff --git a/cookbook/tutorials/esmfold2.ipynb b/cookbook/tutorials/esmfold2.ipynb index 7b14ce59..a7b4d5a3 100644 --- a/cookbook/tutorials/esmfold2.ipynb +++ b/cookbook/tutorials/esmfold2.ipynb @@ -263,12 +263,17 @@ "\n", "Configure the folding parameters to control prediction quality and computational cost:\n", "\n", - "* **num_loops**: Number of times the structure is iteratively refined. More loops generally improve accuracy but increase runtime (typical range: 3-5).\n", + "* **num_loops** (default `20`): Number of times the structure is iteratively refined. More loops generally improve accuracy but increase runtime.\n", "\n", - "* **num_sampling_steps**: Number of diffusion sampling steps. Higher values produce more accurate structures but take longer (typical range: 200-400).\n", + "* **num_sampling_steps** (default `100`): Number of diffusion sampling steps. Higher values produce more accurate structures but take longer.\n", "\n", + "* **lm_dropout** (default `0.3`): Dropout probability on LM pair embeddings. When > 0, dropout is applied.\n", "\n", - "We use a moderate number of loops (3) and sampling steps (32) to balance accuracy and runtime for this example. For more accurate results, especially when modeling flexible ligands or uncertain binding modes, you can increase num_loops and num_sampling_steps at the cost of longer runtimes." + "* **lm_mask_pct** (default `0.1` for `esmfold2-fast`, `0.0` for `esmfold2`): Fraction of sequence residues randomly masked before the PLM backbone. Leave unset (`None`) to use the model-dependent default.\n", + "\n", + "* **msa_max_depth** (default `1024`): Number of MSA rows randomly subsampled on each loop. Set to `None` to disable MSA subsampling.\n", + "\n", + "We use a moderate number of loops (10) to balance accuracy and runtime for this example. For more accurate results, especially when modeling flexible ligands or uncertain binding modes, you can increase num_loops at the cost of longer runtimes." ] }, { @@ -277,7 +282,7 @@ "metadata": {}, "outputs": [], "source": [ - "config = FoldingConfig(num_loops=3, num_sampling_steps=32, include_pae=True)" + "config = FoldingConfig(num_loops=10, num_sampling_steps=100, include_pae=True)" ] }, { diff --git a/esm/models/esmfold2/processor.py b/esm/models/esmfold2/processor.py index 0a67b5d9..6432336d 100644 --- a/esm/models/esmfold2/processor.py +++ b/esm/models/esmfold2/processor.py @@ -42,6 +42,47 @@ def _seed_context(seed: int | None): torch.cuda.set_rng_state_all(cuda_state) +@contextmanager +def _lm_dropout_context(model: Any, lm_dropout: float | None): + """Temporarily set LM-embedding dropout for the wrapped forward, restoring on exit. + + Applies dropout to ``lm_z`` at inference (``training=True``, fresh mask per + loop) so repeated folds give a diverse ensemble. Release models read + ``config.lm_encoder.lm_dropout``, the experimental model a top-level + ``config.lm_dropout`` — disambiguated by ``config.type``. ``None``/``0`` is a no-op. + """ + if not lm_dropout: + yield + return + + cfg = model.config + lm_encoder_cfg = getattr(cfg, "lm_encoder", None) + if lm_encoder_cfg is not None and getattr(cfg, "type", None) != "experimental": + saved = (lm_encoder_cfg.lm_dropout, lm_encoder_cfg.per_loop_lm_dropout) + lm_encoder_cfg.lm_dropout = lm_dropout + lm_encoder_cfg.per_loop_lm_dropout = True + try: + yield + finally: + lm_encoder_cfg.lm_dropout, lm_encoder_cfg.per_loop_lm_dropout = saved + elif hasattr(cfg, "lm_dropout"): + saved = ( + cfg.lm_dropout, + getattr(cfg, "force_lm_dropout_during_inference", False), + ) + cfg.lm_dropout = lm_dropout + cfg.force_lm_dropout_during_inference = True + try: + yield + finally: + cfg.lm_dropout, cfg.force_lm_dropout_during_inference = saved + else: + raise ValueError( + "lm_dropout was requested but this model's config exposes neither " + "`lm_encoder.lm_dropout` nor a top-level `lm_dropout`." + ) + + def clean_esmfold2_input(input: StructurePredictionInput) -> StructurePredictionInput: """Group identical protein sequences into the same ProteinInput with multiple ids. @@ -295,7 +336,9 @@ def fold( noise_scale: float | None = None, step_scale: float | None = None, max_inference_sigma: float | None = None, + lm_mask_pct: float | None = None, early_exit: bool = False, + lm_dropout: float | None = 0.3, complex_id: str = "pred", ) -> MolecularComplexResult | list[MolecularComplexResult]: """Fold a structure end-to-end: encode → model → decode. @@ -312,6 +355,13 @@ def fold( Seeds both input prep (SMILES conformer generation) and diffusion sampling. noise_scale, step_scale, max_inference_sigma, early_exit Optional sampler overrides forwarded to the model when not None. + lm_mask_pct : float, optional + Fraction of sequence residues randomly masked before the PLM backbone. + Overrides the checkpoint config when not None. + lm_dropout : float, optional + LM-embedding dropout for this fold (fresh mask per loop → diverse + ensemble on repeated folds). Defaults to ``0.3`` (paper folding-eval + value); ``0``/``None`` disables. complex_id : str Identifier assigned to the predicted MolecularComplex(es). @@ -331,17 +381,20 @@ def fold( sampler_kwargs["step_scale"] = step_scale if max_inference_sigma is not None: sampler_kwargs["max_inference_sigma"] = max_inference_sigma + if lm_mask_pct is not None: + sampler_kwargs["lm_mask_pct"] = lm_mask_pct with torch.no_grad(): with _seed_context(seed) if seed is not None else nullcontext(): - output = model( - **features, - num_loops=num_loops, - num_sampling_steps=num_sampling_steps, - num_diffusion_samples=num_diffusion_samples, - early_exit=early_exit, - **sampler_kwargs, - ) + with _lm_dropout_context(model, lm_dropout): + output = model( + **features, + num_loops=num_loops, + num_sampling_steps=num_sampling_steps, + num_diffusion_samples=num_diffusion_samples, + early_exit=early_exit, + **sampler_kwargs, + ) return self.decode( output, diff --git a/esm/sdk/api.py b/esm/sdk/api.py index 54221ff1..d0a68e47 100644 --- a/esm/sdk/api.py +++ b/esm/sdk/api.py @@ -367,6 +367,12 @@ class FoldingConfig: include_pair_chains_iptm: bool = False num_sampling_steps: int = 100 num_loops: int = 20 + lm_dropout: float = 0.3 + lm_mask_pct: float | None = ( + None # If not provided, defaults to 0.1 for ESMFOLD2_FAST and 0.0 for ESMFOLD2 + ) + msa_max_depth: int | None = 1024 + msa_column_mask_rate: float = 0.1 include_embeddings: bool = False diff --git a/esm/sdk/forge.py b/esm/sdk/forge.py index 1c807c77..c1273dcb 100644 --- a/esm/sdk/forge.py +++ b/esm/sdk/forge.py @@ -30,7 +30,12 @@ from esm.sdk.base_forge_client import _BaseForgeInferenceClient from esm.sdk.retry import retry_decorator from esm.utils.constants.api import MIMETYPE_ES_PICKLE -from esm.utils.constants.models import ESMFOLD2_FAST, ESMFOLD2_MAX_MSA_SEQS +from esm.utils.constants.models import ( + DEFAULT_ESMFOLD2_FAST_LM_MASK_PCT, + ESMFOLD2, + ESMFOLD2_FAST, + ESMFOLD2_MAX_MSA_SEQS, +) from esm.utils.misc import deserialize_tensors, maybe_list, maybe_tensor from esm.utils.msa import MSA from esm.utils.structure.input_builder import ( @@ -52,6 +57,18 @@ # fmt: on +def _resolve_lm_mask_pct(lm_mask_pct: float | None, model_name: str | None) -> float: + f""" + An explicit value is always honored. When unset (``None``), + {ESMFOLD2_FAST} (or no model name) defaults to 0.1, + {ESMFOLD2} defaults to 0.0. + """ + if lm_mask_pct is not None: + return lm_mask_pct + is_fast_or_default = model_name is None or model_name == ESMFOLD2_FAST + return DEFAULT_ESMFOLD2_FAST_LM_MASK_PCT if is_fast_or_default else 0.0 + + def _list_to_function_annotations(l) -> list[FunctionAnnotation] | None: if l is None or len(l) <= 0: return None @@ -150,6 +167,10 @@ def _process_fold_request( request["include_pair_chains_iptm"] = config.include_pair_chains_iptm request["num_sampling_steps"] = config.num_sampling_steps request["num_loops"] = config.num_loops + request["lm_dropout"] = config.lm_dropout + request["lm_mask_pct"] = _resolve_lm_mask_pct(config.lm_mask_pct, model_name) + request["msa_max_depth"] = config.msa_max_depth + request["msa_column_mask_rate"] = config.msa_column_mask_rate request["include_embeddings"] = config.include_embeddings request["model"] = model_name @@ -371,6 +392,10 @@ def _process_fold_all_atom_request( request["include_pae"] = config.include_pae request["num_sampling_steps"] = config.num_sampling_steps request["num_loops"] = config.num_loops + request["lm_dropout"] = config.lm_dropout + request["lm_mask_pct"] = _resolve_lm_mask_pct(config.lm_mask_pct, model_name) + request["msa_max_depth"] = config.msa_max_depth + request["msa_column_mask_rate"] = config.msa_column_mask_rate request["include_embeddings"] = config.include_embeddings return request diff --git a/esm/utils/constants/models.py b/esm/utils/constants/models.py index 8a613225..cc64f0d7 100644 --- a/esm/utils/constants/models.py +++ b/esm/utils/constants/models.py @@ -11,7 +11,9 @@ ESMC_6B = "esmc_6b" ESMFOLD2_FAST = "esmfold2-fast-2026-05" ESMFOLD2 = "esmfold2-2026-05" + ESMFOLD2_MAX_MSA_SEQS = 16384 +DEFAULT_ESMFOLD2_FAST_LM_MASK_PCT = 0.1 def forge_only_return_single_layer_hidden_states(model_name: str):