diff --git a/docs/_static/draw_pkg_treemap.py b/docs/_static/draw_pkg_treemap.py index d10b88b8e..5239eb26f 100644 --- a/docs/_static/draw_pkg_treemap.py +++ b/docs/_static/draw_pkg_treemap.py @@ -19,5 +19,5 @@ fig = pmv.py_pkg_treemap(pkg_name.replace("-", "_")) fig.layout.title.update(text=f"{pkg_name} Package Structure", font_size=20, x=0.5, y=0.98) fig.show() -# pmv.io.save_and_compress_svg(fig, f"{module_dir}/{pkg_name}-pkg-treemap.svg") +pmv.io.save_and_compress_svg(fig, f"{module_dir}/{pkg_name}-pkg-treemap.svg") fig.write_html(f"{module_dir}/{pkg_name}-pkg-treemap.html", include_plotlyjs="cdn") diff --git a/docs/_static/torch-sim-pkg-treemap.html b/docs/_static/torch-sim-pkg-treemap.html new file mode 100644 index 000000000..40916faa7 --- /dev/null +++ b/docs/_static/torch-sim-pkg-treemap.html @@ -0,0 +1,7 @@ + + + +
+
+ + diff --git a/docs/_static/torch-sim-pkg-treemap.svg b/docs/_static/torch-sim-pkg-treemap.svg new file mode 100644 index 000000000..81070aeed --- /dev/null +++ b/docs/_static/torch-sim-pkg-treemap.svg @@ -0,0 +1 @@ +torch_sim (18,642 lines, 100.0%)modelsintegratorsoptimizerstrajectory9875%state9675%transforms9495%autobatching9495%elastic8935%constraints8555%runners7024%workflowsmath6173%neighborspropertiesio3862%testing3352%quantities2731%monte_carlo1961%symmetrize1771%units991%telemetry740.397%typing350.188%soft_sphere6644%pair_potential6163%lennard_jones3992%fairchem_legacy3632%morse3142%mace2882%interface2741%metatomic2271%particle_life2181%fairchem1891%graphpes_framework1471%mattersim1231%nequip_framework350.188%orb310.166%sevennet230.123%graphpes110.059%npt1,77310%nvt5163%md4242%nve860.461%lbfgs4422%bfgs3832%cell_filters3512%fire3422%state1201%gradient_descent820.44%a2c6864%torch_nl2241%vesin2191%alchemiops1511%correlations4272%torch-sim Package Structure diff --git a/tests/models/conftest.py b/tests/models/conftest.py index 7679224cc..4c78118d4 100644 --- a/tests/models/conftest.py +++ b/tests/models/conftest.py @@ -79,6 +79,8 @@ def make_validate_model_outputs_test( model_fixture_name: str, device: torch.device = DEVICE, dtype: torch.dtype = DTYPE, + *, + check_detached: bool = True, ): """Factory function to create model output validation tests. @@ -86,13 +88,15 @@ def make_validate_model_outputs_test( model_fixture_name: Name of the model fixture to validate device: Device to run validation on dtype: Data type to use for validation + check_detached: Whether to assert output tensors are detached from the + autograd graph (skipped for models with ``retain_graph=True``). 
""" from torch_sim.models.interface import validate_model_outputs def test_model_output_validation(request: pytest.FixtureRequest) -> None: """Test that a model implementation follows the ModelInterface contract.""" model: ModelInterface = request.getfixturevalue(model_fixture_name) - validate_model_outputs(model, device, dtype) + validate_model_outputs(model, device, dtype, check_detached=check_detached) test_model_output_validation.__name__ = f"test_{model_fixture_name}_output_validation" return test_model_output_validation diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index bcbee2e16..44284c260 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -30,7 +30,7 @@ @pytest.fixture def eqv2_uma_model_pbc() -> FairChemModel: """UMA model for periodic boundary condition systems.""" - return FairChemModel(model="uma-s-1", task_name="omat", device=DEVICE) + return FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE) @pytest.mark.skipif( @@ -40,7 +40,7 @@ def eqv2_uma_model_pbc() -> FairChemModel: def test_task_initialization(task_name: str) -> None: """Test that different UMA task names work correctly.""" model = FairChemModel( - model="uma-s-1", task_name=task_name, device=torch.device("cpu") + model="uma-s-1p1", task_name=task_name, device=torch.device("cpu") ) assert model.task_name assert str(model.task_name.value) == task_name @@ -77,7 +77,7 @@ def test_homogeneous_batching(task_name: str, systems_func: Callable) -> None: for mol in systems: mol.info |= {"charge": 0, "spin": 1} - model = FairChemModel(model="uma-s-1", task_name=task_name, device=DEVICE) + model = FairChemModel(model="uma-s-1p1", task_name=task_name, device=DEVICE) state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE) results = model(state) @@ -109,7 +109,7 @@ def test_heterogeneous_tasks() -> None: systems[0].info |= {"charge": 0, "spin": 1} model = FairChemModel( - model="uma-s-1", + model="uma-s-1p1", 
task_name=task_name, device=DEVICE, ) @@ -150,7 +150,7 @@ def test_batch_size_variations(systems_func: Callable, expected_count: int) -> N """Test batching with different numbers and sizes of systems.""" systems = systems_func() - model = FairChemModel(model="uma-s-1", task_name="omat", device=DEVICE) + model = FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE) state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE) results = model(state) @@ -170,7 +170,7 @@ def test_stress_computation(*, compute_stress: bool) -> None: systems = [bulk("Si", "diamond", a=5.43), bulk("Al", "fcc", a=4.05)] model = FairChemModel( - model="uma-s-1", + model="uma-s-1p1", task_name="omat", device=DEVICE, compute_stress=compute_stress, @@ -191,7 +191,7 @@ def test_stress_computation(*, compute_stress: bool) -> None: ) def test_device_consistency() -> None: """Test device consistency between model and data.""" - model = FairChemModel(model="uma-s-1", task_name="omat", device=DEVICE) + model = FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE) system = bulk("Si", "diamond", a=5.43) state = ts.io.atoms_to_state([system], device=DEVICE, dtype=DTYPE) @@ -205,7 +205,7 @@ def test_device_consistency() -> None: ) def test_empty_batch_error() -> None: """Test that empty batches raise appropriate errors.""" - model = FairChemModel(model="uma-s-1", task_name="omat", device=torch.device("cpu")) + model = FairChemModel(model="uma-s-1p1", task_name="omat", device=torch.device("cpu")) with pytest.raises((ValueError, RuntimeError, IndexError)): model(ts.io.atoms_to_state([], device=torch.device("cpu"), dtype=torch.float32)) @@ -215,7 +215,7 @@ def test_empty_batch_error() -> None: ) def test_load_from_checkpoint_path() -> None: """Test loading model from a saved checkpoint file path.""" - checkpoint_path = pretrained_checkpoint_path_from_name("uma-s-1") + checkpoint_path = pretrained_checkpoint_path_from_name("uma-s-1p1") loaded_model = FairChemModel( 
model=str(checkpoint_path), task_name="omat", device=DEVICE ) @@ -274,7 +274,7 @@ def test_fairchem_charge_spin(charge: float, spin: float) -> None: # Create model with UMA omol task (supports charge/spin for molecules) model = FairChemModel( - model="uma-s-1", + model="uma-s-1p1", task_name="omol", device=DEVICE, ) @@ -305,7 +305,7 @@ def test_fairchem_single_step_relax(rattled_si_sim_state: ts.SimState) -> None: that it doesn't have issues with the computational graph (e.g., missing .detach() calls). """ - model = FairChemModel(model="uma-s-1", task_name="omat", device=DEVICE) + model = FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE) state = rattled_si_sim_state.to(device=DEVICE, dtype=DTYPE) # Initialize FIRE optimizer diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py index adc0de8b7..63681c8ef 100644 --- a/torch_sim/models/interface.py +++ b/torch_sim/models/interface.py @@ -168,8 +168,45 @@ def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]: """ +def _check_output_detached( + output: dict[str, torch.Tensor], model: ModelInterface +) -> None: + """Check that output tensors match the model's graph retention setting. + + When ``retain_graph`` is absent or ``False``, all tensors must be detached. + When ``retain_graph`` is ``True``, all tensors must have ``requires_grad``. + + Args: + output: Model output dictionary mapping keys to tensors. + model: The model that produced the output. + + Raises: + ValueError: If tensors are not detached when ``retain_graph`` is + ``False``, or lack gradients when ``retain_graph`` is ``True``. + """ + retain_graph = getattr(model, "retain_graph", False) + for key, tensor in output.items(): + if not isinstance(tensor, torch.Tensor): + continue + if retain_graph and not tensor.requires_grad: + raise ValueError( + f"Output tensor '{key}' does not have gradients but model.retain_graph " + "is True. Ensure the tensor is part of the computation graph." 
+ ) + if not retain_graph and tensor.requires_grad: + raise ValueError( + f"Output tensor '{key}' is not detached from the computation graph. " + "Call .detach() on the tensor before returning it, or set " + "model.retain_graph = True if graph retention is intentional." + ) + + def validate_model_outputs( # noqa: C901, PLR0915 - model: ModelInterface, device: torch.device, dtype: torch.dtype + model: ModelInterface, + device: torch.device, + dtype: torch.dtype, + *, + check_detached: bool = False, ) -> None: """Validate the outputs of a model implementation against the interface requirements. @@ -181,6 +218,10 @@ def validate_model_outputs( # noqa: C901, PLR0915 model (ModelInterface): Model implementation to validate. device (torch.device): Device to run the validation tests on. dtype (torch.dtype): Data type to use for validation tensors. + check_detached (bool): If ``True``, assert that all output tensors are + detached from the autograd graph, unless the model has a + ``retain_graph`` attribute set to ``True``. Defaults to ``False`` so + that external callers are not immediately broken. 
Raises: AssertionError: If the model doesn't conform to the required interface, @@ -229,8 +270,16 @@ def validate_model_outputs( # noqa: C901, PLR0915 og_system_idx = system_idx.clone() og_atomic_nums = sim_state.atomic_numbers.clone() + if check_detached and hasattr(model, "retain_graph"): + model.__dict__["retain_graph"] = True + _check_output_detached(model.forward(sim_state), model) + model.__dict__["retain_graph"] = False + model_output = model.forward(sim_state) + if check_detached: + _check_output_detached(model_output, model) + # assert model did not mutate the input if not torch.allclose(og_positions, sim_state.positions): raise ValueError(f"{og_positions=} != {sim_state.positions=}") diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py index d0a0237a1..47417d089 100644 --- a/torch_sim/models/orb.py +++ b/torch_sim/models/orb.py @@ -19,6 +19,13 @@ class OrbModel(OrbTorchSimModel): """ORB model wrapper for torch-sim.""" + def forward(self, *args: Any, **kwargs: Any) -> dict[str, Any]: + """Run forward pass, detaching all tensor outputs from the autograd graph.""" + output = super().forward(*args, **kwargs) + return { # detach tensors as energy is not detached by default + k: v.detach() if hasattr(v, "detach") else v for k, v in output.items() + } + except ImportError as exc: warnings.warn(f"Orb import failed: {traceback.format_exc()}", stacklevel=2) diff --git a/torch_sim/models/pair_potential.py b/torch_sim/models/pair_potential.py index 6d13dbecb..a537e6e97 100644 --- a/torch_sim/models/pair_potential.py +++ b/torch_sim/models/pair_potential.py @@ -608,7 +608,7 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] ) ) - return results + return {k: v.detach() for k, v in results.items()} class PairForcesModel(ModelInterface):