diff --git a/tests/test_constraints.py b/tests/test_constraints.py index dd092a70..76932e96 100644 --- a/tests/test_constraints.py +++ b/tests/test_constraints.py @@ -704,14 +704,66 @@ def test_count_degrees_of_freedom_multi_system_sum( _assert_dof_per_system(mixed_double_sim_state, constraint_list, expected_dof) -def test_count_degrees_of_freedom_clamped_to_zero( +def test_count_degrees_of_freedom_partial_system_constraint( + mixed_double_sim_state: ts.SimState, +) -> None: + """count_degrees_of_freedom only changes systems targeted by a constraint.""" + total_dof_per_system = 3 * mixed_double_sim_state.n_atoms_per_system + constraint_list: list[Constraint] = [FixCom([0])] + expected_dof = total_dof_per_system.clone() + expected_dof[0] -= 3 + _assert_dof_per_system(mixed_double_sim_state, constraint_list, expected_dof.tolist()) + + +def test_count_degrees_of_freedom_matches_sim_state_method( + mixed_double_sim_state: ts.SimState, +) -> None: + """DOF helper and SimState method agree for strictly positive DOF.""" + n_atoms_in_first_system = int(mixed_double_sim_state.n_atoms_per_system[0].item()) + constraint_list: list[Constraint] = [FixAtoms(atom_idx=[0, n_atoms_in_first_system])] + mixed_double_sim_state.constraints = constraint_list + dof_from_method = mixed_double_sim_state.get_number_of_degrees_of_freedom() + dof_from_helper = count_degrees_of_freedom(mixed_double_sim_state, constraint_list) + assert torch.equal(dof_from_method, dof_from_helper) + + +def test_count_degrees_of_freedom_none_constraints_returns_unconstrained( + mixed_double_sim_state: ts.SimState, +) -> None: + """Omitting constraints returns unconstrained DOF (3 * n_atoms_per_system).""" + expected_dof = 3 * mixed_double_sim_state.n_atoms_per_system + dof = count_degrees_of_freedom(mixed_double_sim_state) + assert torch.equal(dof, expected_dof) + + +def test_count_degrees_of_freedom_helper_clamps_but_state_method_raises( cu_sim_state: ts.SimState, ) -> None: - """count_degrees_of_freedom clamps per-system values at zero.""" + """Helper clamps zero DOF while SimState method rejects non-positive DOF.""" all_atom_indices = torch.arange(cu_sim_state.n_atoms, device=cu_sim_state.device) constraint_list: list[Constraint] = [FixAtoms(atom_idx=all_atom_indices), FixCom([0])] _assert_dof_per_system(cu_sim_state, constraint_list, [0]) + cu_sim_state.constraints = constraint_list + with pytest.raises(ValueError, match="Degrees of freedom cannot be zero or negative"): + cu_sim_state.get_number_of_degrees_of_freedom() + + +@pytest.mark.parametrize( + "invalid_constraints", + [ + [FixAtoms(atom_idx=[999])], + [FixCom([999])], + ], +) +def test_count_degrees_of_freedom_rejects_out_of_bounds_constraint( + cu_sim_state: ts.SimState, + invalid_constraints: list[Constraint], +) -> None: + """count_degrees_of_freedom rejects constraints with out-of-bounds indices.""" + with pytest.raises(ValueError, match=r"has indices up to.*only has.*atoms"): + count_degrees_of_freedom(cu_sim_state, invalid_constraints) + @pytest.mark.parametrize( ("cell_filter", "fire_flavor"), diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index d1557a4b..1a55de2e 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -613,28 +613,35 @@ def __repr__(self) -> str: def count_degrees_of_freedom( state: SimState, constraints: list[Constraint] | None = None ) -> torch.Tensor: - """Count the total degrees of freedom in a system with constraints. + """Count per-system degrees of freedom with compatibility checks. - This function calculates the total number of degrees of freedom by starting - with the unconstrained count (n_atoms * 3) and subtracting the degrees of - freedom removed by each constraint. + This helper computes one DOF value per system. When ``constraints`` are + supplied, it validates that they are compatible with ``state`` before + counting. Args: state: Simulation state - constraints: List of active constraints (optional) + constraints: Constraints to evaluate. If ``None``, returns unconstrained + DOF (3 * n_atoms_per_system). Use ``state.get_number_of_degrees_of_freedom()`` + to count with state-attached constraints. Returns: Degrees of freedom per system as a tensor of shape (n_systems,) """ - # Start with unconstrained DOF per system - total_dof = 3 * state.n_atoms_per_system + if constraints is not None: + validate_constraints(constraints, state) + return torch.clamp(_dof_per_system(state, constraints), min=0) + - # Subtract DOF removed by constraints +def _dof_per_system( + state: SimState, constraints: list[Constraint] | None = None +) -> torch.Tensor: + """Compute unconstrained-minus-removed DOF per system.""" + dof_per_system = 3 * state.n_atoms_per_system if constraints is not None: for constraint in constraints: - total_dof -= constraint.get_removed_dof(state) - - return torch.clamp(total_dof, min=0) + dof_per_system -= constraint.get_removed_dof(state) + return dof_per_system def check_no_index_out_of_bounds( diff --git a/torch_sim/state.py b/torch_sim/state.py index 3fb91e39..76021b52 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -22,7 +22,12 @@ from phonopy.structure.atoms import PhonopyAtoms from pymatgen.core import Structure -from torch_sim.constraints import Constraint, merge_constraints, validate_constraints +from torch_sim.constraints import ( + Constraint, + _dof_per_system, + merge_constraints, + validate_constraints, +) def coerce_prng(rng: PRNGLike, device: torch.device) -> torch.Generator: @@ -415,17 +420,13 @@ def get_number_of_degrees_of_freedom(self) -> torch.Tensor: torch.Tensor: Number of degrees of freedom per system, with shape (n_systems,). Each system starts with 3 * n_atoms_per_system degrees of freedom, minus any degrees removed by constraints. - """ - # Start with unconstrained DOF: 3 degrees per atom - dof_per_system = 3 * self.n_atoms_per_system - - # Subtract DOF removed by constraints - if self.constraints is not None: - for constraint in self.constraints: - removed_dof = constraint.get_removed_dof(self) - dof_per_system -= removed_dof - # Ensure non-negative DOF + Raises: + ValueError: If any system has zero or negative degrees of freedom. + This strict behavior is used by simulation routines that require + physically valid DOF. + """ + dof_per_system = _dof_per_system(self, self.constraints) if (dof_per_system <= 0).any(): raise ValueError("Degrees of freedom cannot be zero or negative") return dof_per_system