diff --git a/docs/_static/draw_pkg_treemap.py b/docs/_static/draw_pkg_treemap.py
index d10b88b8e..5239eb26f 100644
--- a/docs/_static/draw_pkg_treemap.py
+++ b/docs/_static/draw_pkg_treemap.py
@@ -19,5 +19,5 @@
fig = pmv.py_pkg_treemap(pkg_name.replace("-", "_"))
fig.layout.title.update(text=f"{pkg_name} Package Structure", font_size=20, x=0.5, y=0.98)
fig.show()
-# pmv.io.save_and_compress_svg(fig, f"{module_dir}/{pkg_name}-pkg-treemap.svg")
+pmv.io.save_and_compress_svg(fig, f"{module_dir}/{pkg_name}-pkg-treemap.svg")
fig.write_html(f"{module_dir}/{pkg_name}-pkg-treemap.html", include_plotlyjs="cdn")
diff --git a/docs/_static/torch-sim-pkg-treemap.html b/docs/_static/torch-sim-pkg-treemap.html
new file mode 100644
index 000000000..40916faa7
--- /dev/null
+++ b/docs/_static/torch-sim-pkg-treemap.html
@@ -0,0 +1,7 @@
+
+
+
+
+
+
diff --git a/docs/_static/torch-sim-pkg-treemap.svg b/docs/_static/torch-sim-pkg-treemap.svg
new file mode 100644
index 000000000..81070aeed
--- /dev/null
+++ b/docs/_static/torch-sim-pkg-treemap.svg
@@ -0,0 +1 @@
+
diff --git a/tests/models/conftest.py b/tests/models/conftest.py
index 7679224cc..4c78118d4 100644
--- a/tests/models/conftest.py
+++ b/tests/models/conftest.py
@@ -79,6 +79,8 @@ def make_validate_model_outputs_test(
model_fixture_name: str,
device: torch.device = DEVICE,
dtype: torch.dtype = DTYPE,
+ *,
+ check_detached: bool = True,
):
"""Factory function to create model output validation tests.
@@ -86,13 +88,15 @@ def make_validate_model_outputs_test(
model_fixture_name: Name of the model fixture to validate
device: Device to run validation on
dtype: Data type to use for validation
+        check_detached: Whether to validate output tensors against the model's
+            graph setting (models with ``retain_graph=True`` must stay attached).
"""
from torch_sim.models.interface import validate_model_outputs
def test_model_output_validation(request: pytest.FixtureRequest) -> None:
"""Test that a model implementation follows the ModelInterface contract."""
model: ModelInterface = request.getfixturevalue(model_fixture_name)
- validate_model_outputs(model, device, dtype)
+ validate_model_outputs(model, device, dtype, check_detached=check_detached)
test_model_output_validation.__name__ = f"test_{model_fixture_name}_output_validation"
return test_model_output_validation
diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py
index bcbee2e16..44284c260 100644
--- a/tests/models/test_fairchem.py
+++ b/tests/models/test_fairchem.py
@@ -30,7 +30,7 @@
@pytest.fixture
def eqv2_uma_model_pbc() -> FairChemModel:
"""UMA model for periodic boundary condition systems."""
- return FairChemModel(model="uma-s-1", task_name="omat", device=DEVICE)
+ return FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE)
@pytest.mark.skipif(
@@ -40,7 +40,7 @@ def eqv2_uma_model_pbc() -> FairChemModel:
def test_task_initialization(task_name: str) -> None:
"""Test that different UMA task names work correctly."""
model = FairChemModel(
- model="uma-s-1", task_name=task_name, device=torch.device("cpu")
+ model="uma-s-1p1", task_name=task_name, device=torch.device("cpu")
)
assert model.task_name
assert str(model.task_name.value) == task_name
@@ -77,7 +77,7 @@ def test_homogeneous_batching(task_name: str, systems_func: Callable) -> None:
for mol in systems:
mol.info |= {"charge": 0, "spin": 1}
- model = FairChemModel(model="uma-s-1", task_name=task_name, device=DEVICE)
+ model = FairChemModel(model="uma-s-1p1", task_name=task_name, device=DEVICE)
state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE)
results = model(state)
@@ -109,7 +109,7 @@ def test_heterogeneous_tasks() -> None:
systems[0].info |= {"charge": 0, "spin": 1}
model = FairChemModel(
- model="uma-s-1",
+ model="uma-s-1p1",
task_name=task_name,
device=DEVICE,
)
@@ -150,7 +150,7 @@ def test_batch_size_variations(systems_func: Callable, expected_count: int) -> N
"""Test batching with different numbers and sizes of systems."""
systems = systems_func()
- model = FairChemModel(model="uma-s-1", task_name="omat", device=DEVICE)
+ model = FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE)
state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE)
results = model(state)
@@ -170,7 +170,7 @@ def test_stress_computation(*, compute_stress: bool) -> None:
systems = [bulk("Si", "diamond", a=5.43), bulk("Al", "fcc", a=4.05)]
model = FairChemModel(
- model="uma-s-1",
+ model="uma-s-1p1",
task_name="omat",
device=DEVICE,
compute_stress=compute_stress,
@@ -191,7 +191,7 @@ def test_stress_computation(*, compute_stress: bool) -> None:
)
def test_device_consistency() -> None:
"""Test device consistency between model and data."""
- model = FairChemModel(model="uma-s-1", task_name="omat", device=DEVICE)
+ model = FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE)
system = bulk("Si", "diamond", a=5.43)
state = ts.io.atoms_to_state([system], device=DEVICE, dtype=DTYPE)
@@ -205,7 +205,7 @@ def test_device_consistency() -> None:
)
def test_empty_batch_error() -> None:
"""Test that empty batches raise appropriate errors."""
- model = FairChemModel(model="uma-s-1", task_name="omat", device=torch.device("cpu"))
+ model = FairChemModel(model="uma-s-1p1", task_name="omat", device=torch.device("cpu"))
with pytest.raises((ValueError, RuntimeError, IndexError)):
model(ts.io.atoms_to_state([], device=torch.device("cpu"), dtype=torch.float32))
@@ -215,7 +215,7 @@ def test_empty_batch_error() -> None:
)
def test_load_from_checkpoint_path() -> None:
"""Test loading model from a saved checkpoint file path."""
- checkpoint_path = pretrained_checkpoint_path_from_name("uma-s-1")
+ checkpoint_path = pretrained_checkpoint_path_from_name("uma-s-1p1")
loaded_model = FairChemModel(
model=str(checkpoint_path), task_name="omat", device=DEVICE
)
@@ -274,7 +274,7 @@ def test_fairchem_charge_spin(charge: float, spin: float) -> None:
# Create model with UMA omol task (supports charge/spin for molecules)
model = FairChemModel(
- model="uma-s-1",
+ model="uma-s-1p1",
task_name="omol",
device=DEVICE,
)
@@ -305,7 +305,7 @@ def test_fairchem_single_step_relax(rattled_si_sim_state: ts.SimState) -> None:
that it doesn't have issues with the computational graph (e.g., missing
.detach() calls).
"""
- model = FairChemModel(model="uma-s-1", task_name="omat", device=DEVICE)
+ model = FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE)
state = rattled_si_sim_state.to(device=DEVICE, dtype=DTYPE)
# Initialize FIRE optimizer
diff --git a/torch_sim/models/interface.py b/torch_sim/models/interface.py
index adc0de8b7..63681c8ef 100644
--- a/torch_sim/models/interface.py
+++ b/torch_sim/models/interface.py
@@ -168,8 +168,45 @@ def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]:
"""
+def _check_output_detached(
+ output: dict[str, torch.Tensor], model: ModelInterface
+) -> None:
+ """Check that output tensors match the model's graph retention setting.
+
+ When ``retain_graph`` is absent or ``False``, all tensors must be detached.
+ When ``retain_graph`` is ``True``, all tensors must have ``requires_grad``.
+
+ Args:
+ output: Model output dictionary mapping keys to tensors.
+ model: The model that produced the output.
+
+ Raises:
+ ValueError: If tensors are not detached when ``retain_graph`` is
+ ``False``, or lack gradients when ``retain_graph`` is ``True``.
+ """
+ retain_graph = getattr(model, "retain_graph", False)
+ for key, tensor in output.items():
+ if not isinstance(tensor, torch.Tensor):
+ continue
+ if retain_graph and not tensor.requires_grad:
+ raise ValueError(
+ f"Output tensor '{key}' does not have gradients but model.retain_graph "
+ "is True. Ensure the tensor is part of the computation graph."
+ )
+ if not retain_graph and tensor.requires_grad:
+ raise ValueError(
+ f"Output tensor '{key}' is not detached from the computation graph. "
+ "Call .detach() on the tensor before returning it, or set "
+ "model.retain_graph = True if graph retention is intentional."
+ )
+
+
def validate_model_outputs( # noqa: C901, PLR0915
- model: ModelInterface, device: torch.device, dtype: torch.dtype
+ model: ModelInterface,
+ device: torch.device,
+ dtype: torch.dtype,
+ *,
+ check_detached: bool = False,
) -> None:
"""Validate the outputs of a model implementation against the interface requirements.
@@ -181,6 +218,10 @@ def validate_model_outputs( # noqa: C901, PLR0915
model (ModelInterface): Model implementation to validate.
device (torch.device): Device to run the validation tests on.
dtype (torch.dtype): Data type to use for validation tensors.
+        check_detached (bool): If ``True``, assert that all output tensors are
+            detached from the autograd graph; for models whose ``retain_graph``
+            attribute is ``True``, assert instead that outputs remain attached.
+            Defaults to ``False`` so external callers are not immediately broken.
Raises:
AssertionError: If the model doesn't conform to the required interface,
@@ -229,8 +270,16 @@ def validate_model_outputs( # noqa: C901, PLR0915
og_system_idx = system_idx.clone()
og_atomic_nums = sim_state.atomic_numbers.clone()
+ if check_detached and hasattr(model, "retain_graph"):
+ model.__dict__["retain_graph"] = True
+ _check_output_detached(model.forward(sim_state), model)
+ model.__dict__["retain_graph"] = False
+
model_output = model.forward(sim_state)
+ if check_detached:
+ _check_output_detached(model_output, model)
+
# assert model did not mutate the input
if not torch.allclose(og_positions, sim_state.positions):
raise ValueError(f"{og_positions=} != {sim_state.positions=}")
diff --git a/torch_sim/models/orb.py b/torch_sim/models/orb.py
index d0a0237a1..47417d089 100644
--- a/torch_sim/models/orb.py
+++ b/torch_sim/models/orb.py
@@ -19,6 +19,13 @@
class OrbModel(OrbTorchSimModel):
"""ORB model wrapper for torch-sim."""
+ def forward(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
+ """Run forward pass, detaching outputs unless retain_graph is True."""
+ output = super().forward(*args, **kwargs)
+ return { # detach tensors as energy is not detached by default
+ k: v.detach() if hasattr(v, "detach") else v for k, v in output.items()
+ }
+
except ImportError as exc:
warnings.warn(f"Orb import failed: {traceback.format_exc()}", stacklevel=2)
diff --git a/torch_sim/models/pair_potential.py b/torch_sim/models/pair_potential.py
index 6d13dbecb..a537e6e97 100644
--- a/torch_sim/models/pair_potential.py
+++ b/torch_sim/models/pair_potential.py
@@ -608,7 +608,7 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
)
)
- return results
+ return {k: v.detach() for k, v in results.items()}
class PairForcesModel(ModelInterface):