Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 54 additions & 2 deletions tests/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,14 +704,66 @@ def test_count_degrees_of_freedom_multi_system_sum(
_assert_dof_per_system(mixed_double_sim_state, constraint_list, expected_dof)


def test_count_degrees_of_freedom_clamped_to_zero(
def test_count_degrees_of_freedom_partial_system_constraint(
mixed_double_sim_state: ts.SimState,
) -> None:
"""count_degrees_of_freedom only changes systems targeted by a constraint."""
total_dof_per_system = 3 * mixed_double_sim_state.n_atoms_per_system
constraint_list: list[Constraint] = [FixCom([0])]
expected_dof = total_dof_per_system.clone()
expected_dof[0] -= 3
_assert_dof_per_system(mixed_double_sim_state, constraint_list, expected_dof.tolist())


def test_count_degrees_of_freedom_matches_sim_state_method(
mixed_double_sim_state: ts.SimState,
) -> None:
"""DOF helper and SimState method agree for strictly positive DOF."""
n_atoms_in_first_system = int(mixed_double_sim_state.n_atoms_per_system[0].item())
constraint_list: list[Constraint] = [FixAtoms(atom_idx=[0, n_atoms_in_first_system])]
mixed_double_sim_state.constraints = constraint_list
dof_from_method = mixed_double_sim_state.get_number_of_degrees_of_freedom()
dof_from_helper = count_degrees_of_freedom(mixed_double_sim_state, constraint_list)
assert torch.equal(dof_from_method, dof_from_helper)


def test_count_degrees_of_freedom_none_constraints_returns_unconstrained(
mixed_double_sim_state: ts.SimState,
) -> None:
"""Omitting constraints returns unconstrained DOF (3 * n_atoms_per_system)."""
expected_dof = 3 * mixed_double_sim_state.n_atoms_per_system
dof = count_degrees_of_freedom(mixed_double_sim_state)
assert torch.equal(dof, expected_dof)


def test_count_degrees_of_freedom_helper_clamps_but_state_method_raises(
cu_sim_state: ts.SimState,
) -> None:
"""count_degrees_of_freedom clamps per-system values at zero."""
"""Helper clamps zero DOF while SimState method rejects non-positive DOF."""
all_atom_indices = torch.arange(cu_sim_state.n_atoms, device=cu_sim_state.device)
constraint_list: list[Constraint] = [FixAtoms(atom_idx=all_atom_indices), FixCom([0])]
_assert_dof_per_system(cu_sim_state, constraint_list, [0])

cu_sim_state.constraints = constraint_list
with pytest.raises(ValueError, match="Degrees of freedom cannot be zero or negative"):
cu_sim_state.get_number_of_degrees_of_freedom()


@pytest.mark.parametrize(
"invalid_constraints",
[
[FixAtoms(atom_idx=[999])],
[FixCom([999])],
],
)
def test_count_degrees_of_freedom_rejects_out_of_bounds_constraint(
cu_sim_state: ts.SimState,
invalid_constraints: list[Constraint],
) -> None:
"""count_degrees_of_freedom rejects constraints with out-of-bounds indices."""
with pytest.raises(ValueError, match=r"has indices up to.*only has.*atoms"):
count_degrees_of_freedom(cu_sim_state, invalid_constraints)


@pytest.mark.parametrize(
("cell_filter", "fire_flavor"),
Expand Down
29 changes: 18 additions & 11 deletions torch_sim/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,28 +613,35 @@ def __repr__(self) -> str:
def count_degrees_of_freedom(
state: SimState, constraints: list[Constraint] | None = None
) -> torch.Tensor:
"""Count the total degrees of freedom in a system with constraints.
"""Count per-system degrees of freedom with compatibility checks.

This function calculates the total number of degrees of freedom by starting
with the unconstrained count (n_atoms * 3) and subtracting the degrees of
freedom removed by each constraint.
This helper computes one DOF value per system. When ``constraints`` are
supplied, it validates that they are compatible with ``state`` before
counting.

Args:
state: Simulation state
constraints: List of active constraints (optional)
constraints: Constraints to evaluate. If ``None``, returns unconstrained
DOF (3 * n_atoms_per_system). Use ``state.get_number_of_degrees_of_freedom()``
to count with state-attached constraints.

Returns:
Degrees of freedom per system as a tensor of shape (n_systems,)
"""
# Start with unconstrained DOF per system
total_dof = 3 * state.n_atoms_per_system
if constraints is not None:
validate_constraints(constraints, state)
return torch.clamp(_dof_per_system(state, constraints), min=0)


# Subtract DOF removed by constraints
def _dof_per_system(
state: SimState, constraints: list[Constraint] | None = None
) -> torch.Tensor:
"""Compute unconstrained-minus-removed DOF per system."""
dof_per_system = 3 * state.n_atoms_per_system
if constraints is not None:
for constraint in constraints:
total_dof -= constraint.get_removed_dof(state)

return torch.clamp(total_dof, min=0)
dof_per_system -= constraint.get_removed_dof(state)
return dof_per_system


def check_no_index_out_of_bounds(
Expand Down
23 changes: 12 additions & 11 deletions torch_sim/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@
from phonopy.structure.atoms import PhonopyAtoms
from pymatgen.core import Structure

from torch_sim.constraints import Constraint, merge_constraints, validate_constraints
from torch_sim.constraints import (
Constraint,
_dof_per_system,
merge_constraints,
validate_constraints,
)


def coerce_prng(rng: PRNGLike, device: torch.device) -> torch.Generator:
Expand Down Expand Up @@ -415,17 +420,13 @@ def get_number_of_degrees_of_freedom(self) -> torch.Tensor:
torch.Tensor: Number of degrees of freedom per system, with shape
(n_systems,). Each system starts with 3 * n_atoms_per_system degrees
of freedom, minus any degrees removed by constraints.
"""
# Start with unconstrained DOF: 3 degrees per atom
dof_per_system = 3 * self.n_atoms_per_system

# Subtract DOF removed by constraints
if self.constraints is not None:
for constraint in self.constraints:
removed_dof = constraint.get_removed_dof(self)
dof_per_system -= removed_dof

# Ensure non-negative DOF
Raises:
ValueError: If any system has zero or negative degrees of freedom.
This strict behavior is used by simulation routines that require
physically valid DOF.
"""
dof_per_system = _dof_per_system(self, self.constraints)
if (dof_per_system <= 0).any():
raise ValueError("Degrees of freedom cannot be zero or negative")
return dof_per_system
Expand Down
Loading