Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/_static/draw_pkg_treemap.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@
fig = pmv.py_pkg_treemap(pkg_name.replace("-", "_"))
fig.layout.title.update(text=f"{pkg_name} Package Structure", font_size=20, x=0.5, y=0.98)
fig.show()
# pmv.io.save_and_compress_svg(fig, f"{module_dir}/{pkg_name}-pkg-treemap.svg")
pmv.io.save_and_compress_svg(fig, f"{module_dir}/{pkg_name}-pkg-treemap.svg")
fig.write_html(f"{module_dir}/{pkg_name}-pkg-treemap.html", include_plotlyjs="cdn")
7 changes: 7 additions & 0 deletions docs/_static/torch-sim-pkg-treemap.html

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions docs/_static/torch-sim-pkg-treemap.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 5 additions & 1 deletion tests/models/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,20 +79,24 @@ def make_validate_model_outputs_test(
model_fixture_name: str,
device: torch.device = DEVICE,
dtype: torch.dtype = DTYPE,
*,
check_detached: bool = True,
):
"""Factory function to create model output validation tests.

Args:
model_fixture_name: Name of the model fixture to validate
device: Device to run validation on
dtype: Data type to use for validation
check_detached: Whether to assert output tensors are detached from the
autograd graph (skipped for models with ``retain_graph=True``).
"""
from torch_sim.models.interface import validate_model_outputs

def test_model_output_validation(request: pytest.FixtureRequest) -> None:
"""Test that a model implementation follows the ModelInterface contract."""
model: ModelInterface = request.getfixturevalue(model_fixture_name)
validate_model_outputs(model, device, dtype)
validate_model_outputs(model, device, dtype, check_detached=check_detached)

test_model_output_validation.__name__ = f"test_{model_fixture_name}_output_validation"
return test_model_output_validation
22 changes: 11 additions & 11 deletions tests/models/test_fairchem.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
@pytest.fixture
def eqv2_uma_model_pbc() -> FairChemModel:
"""UMA model for periodic boundary condition systems."""
return FairChemModel(model="uma-s-1", task_name="omat", device=DEVICE)
return FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE)


@pytest.mark.skipif(
Expand All @@ -40,7 +40,7 @@ def eqv2_uma_model_pbc() -> FairChemModel:
def test_task_initialization(task_name: str) -> None:
"""Test that different UMA task names work correctly."""
model = FairChemModel(
model="uma-s-1", task_name=task_name, device=torch.device("cpu")
model="uma-s-1p1", task_name=task_name, device=torch.device("cpu")
)
assert model.task_name
assert str(model.task_name.value) == task_name
Expand Down Expand Up @@ -77,7 +77,7 @@ def test_homogeneous_batching(task_name: str, systems_func: Callable) -> None:
for mol in systems:
mol.info |= {"charge": 0, "spin": 1}

model = FairChemModel(model="uma-s-1", task_name=task_name, device=DEVICE)
model = FairChemModel(model="uma-s-1p1", task_name=task_name, device=DEVICE)
state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE)
results = model(state)

Expand Down Expand Up @@ -109,7 +109,7 @@ def test_heterogeneous_tasks() -> None:
systems[0].info |= {"charge": 0, "spin": 1}

model = FairChemModel(
model="uma-s-1",
model="uma-s-1p1",
task_name=task_name,
device=DEVICE,
)
Expand Down Expand Up @@ -150,7 +150,7 @@ def test_batch_size_variations(systems_func: Callable, expected_count: int) -> N
"""Test batching with different numbers and sizes of systems."""
systems = systems_func()

model = FairChemModel(model="uma-s-1", task_name="omat", device=DEVICE)
model = FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE)
state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE)
results = model(state)

Expand All @@ -170,7 +170,7 @@ def test_stress_computation(*, compute_stress: bool) -> None:
systems = [bulk("Si", "diamond", a=5.43), bulk("Al", "fcc", a=4.05)]

model = FairChemModel(
model="uma-s-1",
model="uma-s-1p1",
task_name="omat",
device=DEVICE,
compute_stress=compute_stress,
Expand All @@ -191,7 +191,7 @@ def test_stress_computation(*, compute_stress: bool) -> None:
)
def test_device_consistency() -> None:
"""Test device consistency between model and data."""
model = FairChemModel(model="uma-s-1", task_name="omat", device=DEVICE)
model = FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE)
system = bulk("Si", "diamond", a=5.43)
state = ts.io.atoms_to_state([system], device=DEVICE, dtype=DTYPE)

Expand All @@ -205,7 +205,7 @@ def test_device_consistency() -> None:
)
def test_empty_batch_error() -> None:
"""Test that empty batches raise appropriate errors."""
model = FairChemModel(model="uma-s-1", task_name="omat", device=torch.device("cpu"))
model = FairChemModel(model="uma-s-1p1", task_name="omat", device=torch.device("cpu"))
with pytest.raises((ValueError, RuntimeError, IndexError)):
model(ts.io.atoms_to_state([], device=torch.device("cpu"), dtype=torch.float32))

Expand All @@ -215,7 +215,7 @@ def test_empty_batch_error() -> None:
)
def test_load_from_checkpoint_path() -> None:
"""Test loading model from a saved checkpoint file path."""
checkpoint_path = pretrained_checkpoint_path_from_name("uma-s-1")
checkpoint_path = pretrained_checkpoint_path_from_name("uma-s-1p1")
loaded_model = FairChemModel(
model=str(checkpoint_path), task_name="omat", device=DEVICE
)
Expand Down Expand Up @@ -274,7 +274,7 @@ def test_fairchem_charge_spin(charge: float, spin: float) -> None:

# Create model with UMA omol task (supports charge/spin for molecules)
model = FairChemModel(
model="uma-s-1",
model="uma-s-1p1",
task_name="omol",
device=DEVICE,
)
Expand Down Expand Up @@ -305,7 +305,7 @@ def test_fairchem_single_step_relax(rattled_si_sim_state: ts.SimState) -> None:
that it doesn't have issues with the computational graph (e.g., missing
.detach() calls).
"""
model = FairChemModel(model="uma-s-1", task_name="omat", device=DEVICE)
model = FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE)
state = rattled_si_sim_state.to(device=DEVICE, dtype=DTYPE)

# Initialize FIRE optimizer
Expand Down
51 changes: 50 additions & 1 deletion torch_sim/models/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,45 @@ def forward(self, state: SimState, **kwargs) -> dict[str, torch.Tensor]:
"""


def _check_output_detached(
    output: dict[str, torch.Tensor], model: ModelInterface
) -> None:
    """Validate that output tensors agree with the model's graph-retention flag.

    With ``retain_graph`` unset or ``False``, every output tensor must be
    detached. With ``retain_graph`` set to ``True``, every output tensor must
    still require gradients.

    Args:
        output: Dictionary of model outputs keyed by quantity name.
        model: The model whose ``retain_graph`` attribute is consulted
            (treated as ``False`` when the attribute is absent).

    Raises:
        ValueError: If a tensor still tracks gradients while ``retain_graph``
            is ``False``, or lacks gradients while ``retain_graph`` is ``True``.
    """
    wants_graph = getattr(model, "retain_graph", False)
    # Non-tensor values (e.g. plain floats or metadata) carry no autograd state.
    tensor_items = (
        (k, v) for k, v in output.items() if isinstance(v, torch.Tensor)
    )
    for key, tensor in tensor_items:
        has_grad = tensor.requires_grad
        if wants_graph and not has_grad:
            raise ValueError(
                f"Output tensor '{key}' does not have gradients but model.retain_graph "
                "is True. Ensure the tensor is part of the computation graph."
            )
        if has_grad and not wants_graph:
            raise ValueError(
                f"Output tensor '{key}' is not detached from the computation graph. "
                "Call .detach() on the tensor before returning it, or set "
                "model.retain_graph = True if graph retention is intentional."
            )


def validate_model_outputs( # noqa: C901, PLR0915
model: ModelInterface, device: torch.device, dtype: torch.dtype
model: ModelInterface,
device: torch.device,
dtype: torch.dtype,
*,
check_detached: bool = False,
) -> None:
"""Validate the outputs of a model implementation against the interface requirements.
Expand All @@ -181,6 +218,10 @@ def validate_model_outputs( # noqa: C901, PLR0915
model (ModelInterface): Model implementation to validate.
device (torch.device): Device to run the validation tests on.
dtype (torch.dtype): Data type to use for validation tensors.
check_detached (bool): If ``True``, assert that all output tensors are
detached from the autograd graph, unless the model has a
``retain_graph`` attribute set to ``True``. Defaults to ``False`` so
that external callers are not immediately broken.
Raises:
AssertionError: If the model doesn't conform to the required interface,
Expand Down Expand Up @@ -229,8 +270,16 @@ def validate_model_outputs( # noqa: C901, PLR0915
og_system_idx = system_idx.clone()
og_atomic_nums = sim_state.atomic_numbers.clone()

if check_detached and hasattr(model, "retain_graph"):
model.__dict__["retain_graph"] = True
_check_output_detached(model.forward(sim_state), model)
model.__dict__["retain_graph"] = False

model_output = model.forward(sim_state)

if check_detached:
_check_output_detached(model_output, model)

# assert model did not mutate the input
if not torch.allclose(og_positions, sim_state.positions):
raise ValueError(f"{og_positions=} != {sim_state.positions=}")
Expand Down
7 changes: 7 additions & 0 deletions torch_sim/models/orb.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@
class OrbModel(OrbTorchSimModel):
"""ORB model wrapper for torch-sim."""

def forward(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
"""Run forward pass, detaching outputs unless retain_graph is True."""
output = super().forward(*args, **kwargs)
return { # detach tensors as energy is not detached by default
k: v.detach() if hasattr(v, "detach") else v for k, v in output.items()
}

except ImportError as exc:
warnings.warn(f"Orb import failed: {traceback.format_exc()}", stacklevel=2)

Expand Down
2 changes: 1 addition & 1 deletion torch_sim/models/pair_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
)
)

return results
return {k: v.detach() for k, v in results.items()}


class PairForcesModel(ModelInterface):
Expand Down
Loading