diff --git a/.gitignore b/.gitignore index 8fe23290..b0bdd821 100644 --- a/.gitignore +++ b/.gitignore @@ -46,3 +46,6 @@ uv.lock # agents .agents/ + +# treemap html file +docs/_static/torch-sim-pkg-treemap.html diff --git a/docs/_static/torch-sim-pkg-treemap.svg b/docs/_static/torch-sim-pkg-treemap.svg index 81070aee..c7344fa9 100644 --- a/docs/_static/torch-sim-pkg-treemap.svg +++ b/docs/_static/torch-sim-pkg-treemap.svg @@ -1 +1 @@ -torch_sim (18,642 lines, 100.0%)modelsintegratorsoptimizerstrajectory9875%state9675%transforms9495%autobatching9495%elastic8935%constraints8555%runners7024%workflowsmath6173%neighborspropertiesio3862%testing3352%quantities2731%monte_carlo1961%symmetrize1771%units991%telemetry740.397%typing350.188%soft_sphere6644%pair_potential6163%lennard_jones3992%fairchem_legacy3632%morse3142%mace2882%interface2741%metatomic2271%particle_life2181%fairchem1891%graphpes_framework1471%mattersim1231%nequip_framework350.188%orb310.166%sevennet230.123%graphpes110.059%npt1,77310%nvt5163%md4242%nve860.461%lbfgs4422%bfgs3832%cell_filters3512%fire3422%state1201%gradient_descent820.44%a2c6864%torch_nl2241%vesin2191%alchemiops1511%correlations4272%torch-sim Package Structure +torch_sim (17,540 lines, 100.0%)modelsintegratorsoptimizerstrajectory9846%state9676%transforms9495%autobatching9495%elastic8935%constraints8555%runners7064%workflowsmath6174%neighborspropertiesio3862%testing3352%quantities2732%monte_carlo1961%symmetrize1771%units991%telemetry740.422%typing350.2%pair_potential4543%fairchem_legacy3632%soft_sphere3202%mace2882%interface2742%metatomic2271%fairchem1891%graphpes_framework1471%mattersim1231%morse1201%lennard_jones1141%particle_life1031%nequip_framework350.2%orb310.177%sevennet230.131%graphpes110.0627%npt1,77310%nvt5163%md4242%nve860.49%lbfgs4423%bfgs3832%cell_filters3512%fire3422%state1201%gradient_descent820.468%a2c6834%torch_nl2241%vesin2191%alchemiops1511%correlations4272%torch-sim Package Structure diff --git a/examples/tutorials/diff_sim.py b/examples/tutorials/diff_sim.py index 1092c275..cf96c6e6 100644 --- a/examples/tutorials/diff_sim.py +++ b/examples/tutorials/diff_sim.py @@ -9,17 +9,19 @@ # # %% +import typing import torch import matplotlib.pyplot as plt +from torch_sim.state import SimState from torch_sim.models.soft_sphere import ( + SoftSphereMultiModel, soft_sphere_pair, - DEFAULT_SIGMA, - DEFAULT_EPSILON, - DEFAULT_ALPHA, ) -from torch_sim import transforms -from collections.abc import Callable -from dataclasses import dataclass +from torch_sim.neighbors import torch_nl_n2 +from torch_sim.optimizers.gradient_descent import ( + gradient_descent_init, + gradient_descent_step, +) from torch._functorch import config config.donated_buffer = False @@ -99,244 +101,16 @@ def draw_system( plt.xlim([0, 1.5]) plt.ylim([-0.2, 0.8]) -# model = SoftSphereMultiModel(sigma_matrix=torch.tensor([1.0])) dr = torch.linspace(0, 3.0, 80) -plt.plot(dr, soft_sphere_pair(dr, sigma=1), "b-", linewidth=3) -plt.fill_between(dr, soft_sphere_pair(dr), alpha=0.4) +_z = torch.zeros_like(dr, dtype=torch.long) +plt.plot(dr, soft_sphere_pair(dr, _z, _z, sigma=1), "b-", linewidth=3) +plt.fill_between(dr, soft_sphere_pair(dr, _z, _z), alpha=0.4) plt.xlabel(r"$r$", fontsize=20) plt.ylabel(r"$U(r)$", fontsize=20) plt.show() -# %% [markdown] -""" -## Define the simple TorchSim model for the soft sphere potential. -""" - - -# %% -@dataclass -class BaseState: - """Simple simulation state""" - - positions: torch.Tensor - cell: torch.Tensor - pbc: torch.Tensor - species: torch.Tensor - - -class SoftSphereMultiModel(torch.nn.Module): - """Soft sphere potential""" - - def __init__( - self, - species: torch.Tensor | None = None, - sigma_matrix: torch.Tensor | None = None, - epsilon_matrix: torch.Tensor | None = None, - alpha_matrix: torch.Tensor | None = None, - device: torch.device | None = None, - dtype: torch.dtype = torch.float32, - *, # Force keyword-only arguments - pbc: torch.Tensor | bool = True, - cutoff: float | None = None, - ) -> None: - """Initialize a soft sphere model for multi-component systems.""" - super().__init__() - self.device = device or torch.device("cpu") - self.dtype = dtype - self.pbc = ( - pbc - if isinstance(pbc, torch.Tensor) - else torch.tensor([pbc] * 3, dtype=torch.bool) - ) - - # Store species list and determine number of unique species - self.species = species - n_species = len(torch.unique(species)) - - # Initialize parameter matrices with defaults if not provided - default_sigma = DEFAULT_SIGMA.to(device=self.device, dtype=self.dtype) - default_epsilon = DEFAULT_EPSILON.to(device=self.device, dtype=self.dtype) - default_alpha = DEFAULT_ALPHA.to(device=self.device, dtype=self.dtype) - - # Validate matrix shapes match number of species - if sigma_matrix is not None and sigma_matrix.shape != (n_species, n_species): - raise ValueError(f"sigma_matrix must have shape ({n_species}, {n_species})") - if epsilon_matrix is not None and epsilon_matrix.shape != ( - n_species, - n_species, - ): - raise ValueError(f"epsilon_matrix must have shape ({n_species}, {n_species})") - if alpha_matrix is not None and alpha_matrix.shape != (n_species, n_species): - raise ValueError(f"alpha_matrix must have shape ({n_species}, {n_species})") - - # Create parameter matrices, using defaults if not provided - self.sigma_matrix = ( - sigma_matrix - if sigma_matrix is not None - else default_sigma - * torch.ones((n_species, n_species), dtype=dtype, device=device) - ) - self.epsilon_matrix = ( - epsilon_matrix - if epsilon_matrix is not None - else default_epsilon - * torch.ones((n_species, n_species), dtype=dtype, device=device) - ) - self.alpha_matrix = ( - alpha_matrix - if alpha_matrix is not None - else default_alpha - * torch.ones((n_species, n_species), dtype=dtype, device=device) - ) - - # Ensure parameter matrices are symmetric (required for energy conservation) - for matrix_name in ("sigma_matrix", "epsilon_matrix", "alpha_matrix"): - matrix = getattr(self, matrix_name) - if not torch.allclose(matrix, matrix.T): - raise ValueError(f"{matrix_name} is not symmetric") - - # Set interaction cutoff distance - self.cutoff = torch.tensor( - cutoff or float(self.sigma_matrix.max()), dtype=dtype, device=device - ) - - def forward( - self, - custom_state: BaseState, - species: torch.Tensor | None = None, - ) -> dict[str, torch.Tensor]: - """Compute energies and forces for a single unbatched system with multiple - species.""" - # Convert inputs to proper device/dtype and handle species - positions = custom_state.positions.requires_grad_(True) - cell = custom_state.cell - species = custom_state.species - - if species is not None: - species = species.to(device=self.device, dtype=torch.long) - else: - species = self.species - - species_idx = species - - # Direct N^2 computation of all pairs (minimum image convention) - dr_vec, distances = transforms.get_pair_displacements( - positions=positions, - cell=cell, - pbc=self.pbc, - ) - # Remove self-interactions and apply cutoff - mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device) - distances = distances.masked_fill(mask, float("inf")) - mask = distances < self.cutoff - - # Get valid pairs and their displacements - i, j = torch.where(mask) - mapping = torch.stack([j, i]) - dr_vec = dr_vec[mask] - distances = distances[mask] - - # Look up species-specific parameters for each interacting pair - pair_species_1 = species_idx[mapping[0]] # Species of first atom in pair - pair_species_2 = species_idx[mapping[1]] # Species of second atom in pair - - # Get interaction parameters from parameter matrices - pair_sigmas = self.sigma_matrix[pair_species_1, pair_species_2] - pair_epsilons = self.epsilon_matrix[pair_species_1, pair_species_2] - pair_alphas = self.alpha_matrix[pair_species_1, pair_species_2] - - # Calculate pair energies using species-specific parameters - pair_energies = soft_sphere_pair( - distances, sigma=pair_sigmas, epsilon=pair_epsilons, alpha=pair_alphas - ) - - # Initialize results with total energy (divide by 2 to avoid double counting) - potential_energy = pair_energies.sum() / 2 - - grad_outputs: list[torch.Tensor] = [ - torch.ones_like( - potential_energy, - ) - ] - grad = torch.autograd.grad( - outputs=[ - potential_energy, - ], - inputs=[positions], - grad_outputs=grad_outputs, - create_graph=False, - retain_graph=True, - ) - - force_grad = grad[0] - if force_grad is not None: - forces = torch.neg(force_grad) - - return {"energy": potential_energy, "forces": forces} - - -# %% [markdown] -""" -## Gradient Descent - -We will use a simple gradient descent to optimize the positions of the particles. -""" - - -# %% -@dataclass -class GDState(BaseState): - """Simple simulation state""" - - forces: torch.Tensor - energy: torch.Tensor - - -def gradient_descent( - model: torch.nn.Module, *, lr: torch.Tensor | float = 0.01 -) -> tuple[Callable[[BaseState], GDState], Callable[[GDState], GDState]]: - """Initialize a gradient descent optimization.""" - - def gd_init( - state: BaseState, - ) -> GDState: - """Initialize the gradient descent optimization state.""" - - # Get initial forces and energy from model - model_output = model(state) - energy = model_output["energy"] - forces = model_output["forces"] - - return GDState( - positions=state.positions, - forces=forces, - energy=energy, - cell=state.cell, - pbc=state.pbc, - species=state.species, - ) - - def gd_step(state: GDState, lr: torch.Tensor | float = lr) -> GDState: - """Perform one gradient descent optimization step to update the - atomic positions. The cell is not optimized.""" - - # Update positions using forces and per-atom learning rates - state.positions = state.positions + lr * state.forces - - # Get updated forces and energy from model - model_output = model(state) - - # Update state with new forces and energy - state.forces = model_output["forces"] - state.energy = model_output["energy"] - - return state - - return gd_init, gd_step - - # %% [markdown] """ ## Setup the simulation environment. @@ -358,15 +132,14 @@ def box_size_at_packing_fraction( def species_sigma(diameter: torch.Tensor) -> torch.Tensor: - d_AA = diameter - d_BB = 1 + d_BB = torch.ones_like(diameter) d_AB = 0.5 * (diameter + 1) - return torch.tensor([[d_AA, d_AB], [d_AB, d_BB]]) + return torch.stack([diameter, d_AB, d_AB, d_BB]).reshape(2, 2) N = 128 N_2 = N // 2 -species = torch.tensor([0] * (N_2) + [1] * (N_2), dtype=torch.int32) +species = torch.tensor([0] * (N_2) + [1] * (N_2), dtype=torch.long) simulation_steps = 1000 packing_fraction = 0.98 markersize = 260 @@ -376,30 +149,30 @@ def species_sigma(diameter: torch.Tensor) -> torch.Tensor: def simulation( diameter: torch.Tensor, seed: int = 42 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # Create the simulation environment. box_size = box_size_at_packing_fraction(diameter, packing_fraction) - cell = torch.eye(3) * box_size - # Create the energy function. + cell = (torch.eye(3) * box_size).unsqueeze(0) sigma = species_sigma(diameter) - model = SoftSphereMultiModel(sigma_matrix=sigma, species=species) - model = torch.compile(model) - # Randomly initialize the system. - # Fix seed for reproducible random positions + model = SoftSphereMultiModel( + atomic_numbers=species, + sigma_matrix=sigma, + dtype=torch.float32, + neighbor_list_fn=torch_nl_n2, + retain_graph=True, + ) + # Use aot_eager backend as Inductor has issues with scatter operations (index_add/scatter_add) + model = typing.cast(SoftSphereMultiModel, torch.compile(model, backend="aot_eager")) torch.manual_seed(seed) R = torch.rand(N, 3) * box_size - - # Minimize to the nearest minimum. - init_fn, apply_fn = gradient_descent(model, lr=0.1) # ty: ignore[invalid-argument-type] - - custom_state = BaseState( + state = SimState( positions=R, + masses=torch.ones(N), cell=cell, - species=species, - pbc=torch.tensor([True] * 3, dtype=torch.bool), + pbc=True, + atomic_numbers=species, ) - state = init_fn(custom_state) + state = gradient_descent_init(state, model) for _ in range(simulation_steps): - state = apply_fn(state) + state = gradient_descent_step(state, model, pos_lr=0.1) return box_size, model(state)["energy"], state.positions @@ -482,41 +255,28 @@ def short_simulation( ) -> tuple[torch.Tensor, torch.Tensor]: diameter = diameter.requires_grad_(True) box_size = box_size_at_packing_fraction(diameter, packing_fraction) - cell = torch.eye(3) * box_size - # Create the energy function. + cell = (torch.eye(3) * box_size).unsqueeze(0) sigma = species_sigma(diameter) - model = SoftSphereMultiModel(sigma_matrix=sigma, species=species) - - # Minimize to the nearest minimum. - init_fn, apply_fn = gradient_descent(model, lr=0.1) - - custom_state = BaseState( + model = SoftSphereMultiModel( + atomic_numbers=species, + sigma_matrix=sigma, + dtype=torch.float32, + neighbor_list_fn=torch_nl_n2, + retain_graph=True, + ) + state = SimState( positions=R, + masses=torch.ones(N), cell=cell, - species=species, - pbc=torch.tensor([True, True, True], dtype=torch.bool), - ) - state = init_fn(custom_state) - for i in range(short_simulation_steps): - state = apply_fn(state) - - grad_outputs: list[torch.Tensor] = [ - torch.ones_like( - diameter, - ) - ] - grad = torch.autograd.grad( - outputs=[ - model(state)["energy"], - ], - inputs=[diameter], - grad_outputs=grad_outputs, - create_graph=True, - retain_graph=False, + pbc=True, + atomic_numbers=species, ) - - dU_dd = grad[0] - return model(state)["energy"], dU_dd + state = gradient_descent_init(state, model) + for _ in range(short_simulation_steps): + state = gradient_descent_step(state, model, pos_lr=0.1) + energy = model(state)["energy"] + (dU_dd,) = torch.autograd.grad(energy, diameter, create_graph=True) + return energy, dU_dd # %% diff --git a/pyproject.toml b/pyproject.toml index 38e08515..4322706b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,6 +111,7 @@ ignore = [ "PLR2004", # Magic value used in comparison, consider replacing {value} with a constant variable "PLW2901", # Outer for loop variable overwritten by inner assignment target "PTH", # flake8-use-pathlib + "RUF002", # Greek letters are discouraged "S301", # pickle and modules that wrap it can be unsafe, possible security issue "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes "SIM105", # Use contextlib.suppress instead of try-except-pass diff --git a/tests/models/test_lennard_jones.py b/tests/models/test_lennard_jones.py index f70817e4..a3e5ed5a 100644 --- a/tests/models/test_lennard_jones.py +++ b/tests/models/test_lennard_jones.py @@ -1,408 +1,137 @@ -"""Cheap integration tests ensuring different parts of TorchSim work together.""" +"""Tests for the Lennard-Jones pair functions and wrapped model.""" -from collections.abc import Callable - -import pytest import torch -from ase.build import bulk import torch_sim as ts -from tests.conftest import DEVICE -from torch_sim import neighbors -from torch_sim.models.interface import validate_model_outputs from torch_sim.models.lennard_jones import ( LennardJonesModel, - UnbatchedLennardJonesModel, lennard_jones_pair, lennard_jones_pair_force, ) +def _dummy_z(n: int) -> torch.Tensor: + return torch.ones(n, dtype=torch.long) + + def test_lennard_jones_pair_minimum() -> None: - """Test that the potential has its minimum at r=sigma.""" - dr = torch.linspace(0.8, 1.2, 100) - dr = dr.reshape(-1, 1) - energy = lennard_jones_pair(dr, sigma=1.0, epsilon=1.0) - min_idx = torch.argmin(energy) + """Minimum of LJ is at r = 2^(1/6) * sigma.""" + dr = torch.linspace(0.9, 1.5, 500) + z = _dummy_z(len(dr)) + energies = lennard_jones_pair(dr, z, z, sigma=1.0, epsilon=1.0) + min_r = dr[energies.argmin()] + assert abs(min_r.item() - 2 ** (1 / 6)) < 0.01 - torch.testing.assert_close( - dr[min_idx], torch.tensor([2 ** (1 / 6)]), rtol=1e-2, atol=1e-2 - ) + +def test_lennard_jones_pair_energy_at_minimum() -> None: + """Energy at minimum equals -epsilon.""" + r_min = torch.tensor([2 ** (1 / 6)]) + z = _dummy_z(1) + e = lennard_jones_pair(r_min, z, z, sigma=1.0, epsilon=2.0) + torch.testing.assert_close(e, torch.tensor([-2.0]), rtol=1e-5, atol=1e-5) -def test_lennard_jones_pair_scaling() -> None: - """Test that the potential scales correctly with epsilon.""" - dr = torch.ones(5, 5) * 1.5 - e1 = lennard_jones_pair(dr, sigma=1.0, epsilon=1.0) - e2 = lennard_jones_pair(dr, sigma=1.0, epsilon=2.0) - torch.testing.assert_close(e2, 2 * e1) +def test_lennard_jones_pair_epsilon_scaling() -> None: + """Energy scales linearly with epsilon.""" + dr = torch.tensor([1.5]) + z = _dummy_z(1) + e1 = lennard_jones_pair(dr, z, z, sigma=1.0, epsilon=1.0) + e2 = lennard_jones_pair(dr, z, z, sigma=1.0, epsilon=3.0) + torch.testing.assert_close(e2, 3.0 * e1) def test_lennard_jones_pair_repulsive_core() -> None: - """Test that the potential is strongly repulsive at short distances.""" - dr_close = torch.tensor([[0.5]]) # Less than sigma - dr_far = torch.tensor([[2.0]]) # Greater than sigma - e_close = lennard_jones_pair(dr_close) - e_far = lennard_jones_pair(dr_far) + """The potential is strongly repulsive at short distances.""" + z = _dummy_z(1) + e_close = lennard_jones_pair(torch.tensor([0.5]), z, z) + e_far = lennard_jones_pair(torch.tensor([2.0]), z, z) assert e_close > e_far assert e_close > 0 # Repulsive assert e_far < 0 # Attractive -def test_lennard_jones_pair_tensor_params() -> None: - """Test that the function works with tensor parameters.""" - dr = torch.ones(3, 3) * 1.5 - sigma = torch.ones(3, 3) - epsilon = torch.ones(3, 3) * 2.0 - energy = lennard_jones_pair(dr, sigma=sigma, epsilon=epsilon) - assert energy.shape == (3, 3) - - def test_lennard_jones_pair_zero_distance() -> None: - """Test that the function handles zero distances gracefully.""" - dr = torch.zeros(2, 2) - energy = lennard_jones_pair(dr) + """The function handles zero distances gracefully.""" + dr = torch.zeros(2) + z = _dummy_z(2) + energy = lennard_jones_pair(dr, z, z) assert not torch.isnan(energy).any() assert not torch.isinf(energy).any() -def test_lennard_jones_pair_batch() -> None: - """Test that the function works with batched inputs.""" - batch_size = 10 - n_particles = 5 - dr = torch.rand(batch_size, n_particles, n_particles) + 0.5 - energy = lennard_jones_pair(dr) - assert energy.shape == (batch_size, n_particles, n_particles) - - def test_lennard_jones_pair_force_scaling() -> None: - """Test that the force scales correctly with epsilon.""" - dr = torch.ones(5, 5) * 1.5 + """Force scales linearly with epsilon.""" + dr = torch.tensor([1.5]) f1 = lennard_jones_pair_force(dr, sigma=1.0, epsilon=1.0) f2 = lennard_jones_pair_force(dr, sigma=1.0, epsilon=2.0) - assert torch.allclose(f2, 2 * f1) + torch.testing.assert_close(f2, 2.0 * f1) def test_lennard_jones_pair_force_repulsive_core() -> None: - """Test that the force is strongly repulsive at short distances.""" - dr_close = torch.tensor([[0.5]]) # Less than sigma - dr_far = torch.tensor([[2.0]]) # Greater than sigma - f_close = lennard_jones_pair_force(dr_close) - f_far = lennard_jones_pair_force(dr_far) + """Force is repulsive at short distances and attractive at long distances.""" + f_close = lennard_jones_pair_force(torch.tensor([0.5])) + f_far = lennard_jones_pair_force(torch.tensor([2.0])) assert f_close > 0 # Repulsive assert f_far < 0 # Attractive - assert abs(f_close) > abs(f_far) # Stronger at short range - - -def test_lennard_jones_pair_force_tensor_params() -> None: - """Test that the function works with tensor parameters.""" - dr = torch.ones(3, 3) * 1.5 - sigma = torch.ones(3, 3) - epsilon = torch.ones(3, 3) * 2.0 - force = lennard_jones_pair_force(dr, sigma=sigma, epsilon=epsilon) - assert force.shape == (3, 3) + assert abs(f_close) > abs(f_far) def test_lennard_jones_pair_force_zero_distance() -> None: - """Test that the function handles zero distances gracefully.""" - dr = torch.zeros(2, 2) + """The force function handles zero distances gracefully.""" + dr = torch.zeros(2) force = lennard_jones_pair_force(dr) assert not torch.isnan(force).any() assert not torch.isinf(force).any() -def test_lennard_jones_pair_force_batch() -> None: - """Test that the function works with batched inputs.""" - batch_size = 10 - n_particles = 5 - dr = torch.rand(batch_size, n_particles, n_particles) + 0.5 - force = lennard_jones_pair_force(dr) - assert force.shape == (batch_size, n_particles, n_particles) - - def test_lennard_jones_force_energy_consistency() -> None: - """Test that the force is consistent with the energy gradient.""" + """Force is consistent with the energy gradient.""" dr = torch.linspace(0.8, 2.0, 100, requires_grad=True) - dr = dr.reshape(-1, 1) + z = _dummy_z(len(dr)) - # Calculate force directly force_direct = lennard_jones_pair_force(dr) - # Calculate force from energy gradient - energy = lennard_jones_pair(dr) + energy = lennard_jones_pair(dr, z, z) force_from_grad = -torch.autograd.grad(energy.sum(), dr, create_graph=True)[0] - # Compare forces (allowing for some numerical differences) - assert torch.allclose(force_direct, force_from_grad, rtol=1e-4, atol=1e-4) - - -# NOTE: This is a large system to robustly compare neighbor-list implementations. -@pytest.fixture -def ar_supercell_sim_state_large() -> ts.SimState: - """Create a face-centered cubic (FCC) Argon structure.""" - # Create FCC Ar using ASE, with 4x4x4 supercell - ar_atoms = bulk("Ar", "fcc", a=5.26, cubic=True).repeat([4, 4, 4]) - return ts.io.atoms_to_state(ar_atoms, DEVICE, torch.float64) - - -@pytest.fixture -def models( - ar_supercell_sim_state_large: ts.SimState, -) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: - """Create default and explicit N2 neighbor-list outputs with Argon parameters.""" - model_kwargs: dict[str, float | bool | torch.dtype] = { - "sigma": 3.405, # Å, typical for Ar - "epsilon": 0.0104, # eV, typical for Ar - "dtype": torch.float64, - "compute_forces": True, - "compute_stress": True, - "per_atom_energies": True, - "per_atom_stresses": True, - } - cutoff = 2.5 * 3.405 # Standard LJ cutoff * sigma - model_default = LennardJonesModel(cutoff=cutoff, **model_kwargs) - model_n2 = LennardJonesModel( - cutoff=cutoff, - neighbor_list_fn=neighbors.torch_nl_n2, - **model_kwargs, - ) - - return ( - model_default(ar_supercell_sim_state_large), - model_n2(ar_supercell_sim_state_large), - ) - - -@pytest.fixture -def standard_vs_batched_models( - ar_supercell_sim_state_large: ts.SimState, -) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: - """Create unbatched and default Lennard-Jones outputs for parity checks.""" - model_kwargs: dict[str, float | bool | torch.dtype | torch.device] = { - "sigma": 3.405, - "epsilon": 0.0104, - "dtype": torch.float64, - "device": DEVICE, - "compute_forces": True, - "compute_stress": True, - "per_atom_energies": True, - "per_atom_stresses": True, - } - cutoff = 2.5 * 3.405 - standard = UnbatchedLennardJonesModel(cutoff=cutoff, **model_kwargs) - batched = LennardJonesModel(cutoff=cutoff, **model_kwargs) - return standard(ar_supercell_sim_state_large), batched(ar_supercell_sim_state_large) - - -def test_energy_match( - models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], -) -> None: - """Test that total energy matches across neighbor-list implementations.""" - results_default, results_n2 = models - assert torch.allclose(results_default["energy"], results_n2["energy"], rtol=1e-10) - - -def test_per_atom_energy_match( - models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], -) -> None: - """Test that per-atom energy matches across neighbor-list implementations.""" - results_default, results_n2 = models - assert torch.allclose(results_default["energies"], results_n2["energies"], rtol=1e-10) - - -def test_forces_match( - models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], -) -> None: - """Test that forces match across neighbor-list implementations.""" - results_default, results_n2 = models - assert torch.allclose(results_default["forces"], results_n2["forces"], rtol=1e-10) - - -def test_stress_match( - models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], -) -> None: - """Test that stress tensors match across neighbor-list implementations.""" - results_default, results_n2 = models - assert torch.allclose(results_default["stress"], results_n2["stress"], rtol=1e-10) - - -def test_per_atom_stress_match( - models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], -) -> None: - """Test that per-atom stress tensors match across neighbor-list implementations.""" - results_default, results_n2 = models - assert torch.allclose(results_default["stresses"], results_n2["stresses"], rtol=1e-10) - - -@pytest.mark.parametrize( - "key", - ["energy", "energies", "forces", "stress", "stresses"], -) -def test_batched_lj_matches_standard( - standard_vs_batched_models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], - key: str, -) -> None: - """Test that the default batched model matches the unbatched reference.""" - standard_out, batched_out = standard_vs_batched_models - torch.testing.assert_close( - batched_out[key], standard_out[key], rtol=1e-10, atol=1e-10 - ) - - -def test_batched_lj_multi_system_matches_standard( - ar_double_sim_state: ts.SimState, -) -> None: - """Test default model multi-system parity with unbatched reference.""" - model_kwargs: dict[str, float | bool | torch.dtype | torch.device] = { - "sigma": 3.405, - "epsilon": 0.0104, - "dtype": torch.float64, - "device": DEVICE, - "compute_forces": True, - "compute_stress": True, - } - cutoff = 2.5 * 3.405 - standard = UnbatchedLennardJonesModel(cutoff=cutoff, **model_kwargs) - batched = LennardJonesModel(cutoff=cutoff, **model_kwargs) - - standard_out = standard(ar_double_sim_state) - batched_out = batched(ar_double_sim_state) - - assert batched_out["energy"].shape == (ar_double_sim_state.n_systems,) - for key in ("energy", "forces", "stress"): - torch.testing.assert_close( - batched_out[key], standard_out[key], rtol=1e-10, atol=1e-10 - ) - - -def test_force_conservation( - models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], -) -> None: - """Test that forces sum to zero.""" - results_default, _ = models - assert torch.allclose( - results_default["forces"].sum(dim=0), - torch.zeros(3, dtype=torch.float64), - atol=1e-10, - ) - - -@pytest.mark.parametrize("model_cls", [UnbatchedLennardJonesModel, LennardJonesModel]) -@pytest.mark.parametrize( - "neighbor_list_fn", - [neighbors.torch_nl_linked_cell, neighbors.torch_nl_n2], -) -def test_custom_neighbor_list_fn_matches_default( - model_cls: type[UnbatchedLennardJonesModel], - neighbor_list_fn: Callable[..., tuple[torch.Tensor, torch.Tensor, torch.Tensor]], - ar_supercell_sim_state_large: ts.SimState, -) -> None: - """Test that custom neighbor-list implementations match default behavior.""" - model_kwargs: dict[str, float | bool | torch.dtype | torch.device] = { - "sigma": 3.405, - "epsilon": 0.0104, - "dtype": torch.float64, - "device": DEVICE, - "compute_forces": True, - "compute_stress": True, - } - cutoff = 2.5 * 3.405 - default_model = model_cls( - cutoff=cutoff, - neighbor_list_fn=neighbors.torchsim_nl, - **model_kwargs, - ) - custom_model = model_cls( - cutoff=cutoff, - neighbor_list_fn=neighbor_list_fn, - **model_kwargs, - ) - - default_out = default_model(ar_supercell_sim_state_large) - custom_out = custom_model(ar_supercell_sim_state_large) - - for key in ("energy", "forces", "stress"): - torch.testing.assert_close( - custom_out[key], default_out[key], rtol=1e-10, atol=1e-10 - ) - + torch.testing.assert_close(force_direct, force_from_grad, rtol=1e-4, atol=1e-4) -def test_stress_tensor_symmetry( - models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], -) -> None: - """Test that stress tensor is symmetric.""" - results_default, _ = models - # select trailing two dimensions - stress_tensor = results_default["stress"][0] - assert torch.allclose(stress_tensor, stress_tensor.T, atol=1e-10) - - -def test_validate_model_outputs(lj_model: LennardJonesModel) -> None: - """Test that the model outputs are valid.""" - validate_model_outputs(lj_model, DEVICE, torch.float64) - - -def test_unwrapped_positions_consistency() -> None: - """Test that wrapped and unwrapped positions give identical results. - - This tests that models correctly handle positions outside the unit cell - by wrapping them before neighbor list computation. - """ - # Create a periodic system - ar_atoms = bulk("Ar", "fcc", a=5.26, cubic=True).repeat([2, 2, 2]) - cell = torch.tensor(ar_atoms.get_cell().array, dtype=torch.float64, device=DEVICE) - # Create wrapped state (positions inside unit cell) - state_wrapped = ts.io.atoms_to_state(ar_atoms, DEVICE, torch.float64) - - # Create unwrapped state by shifting some atoms outside the cell - positions_unwrapped = state_wrapped.positions.clone() - # Shift first half of atoms by +1 cell vector in x direction - n_atoms = positions_unwrapped.shape[0] - positions_unwrapped[: n_atoms // 2] += cell[0] - # Shift some atoms by -1 cell vector in y direction - positions_unwrapped[n_atoms // 4 : n_atoms // 2] -= cell[1] - - state_unwrapped = ts.SimState.from_state(state_wrapped, positions=positions_unwrapped) - - # Create model +def test_lennard_jones_model_evaluation(si_double_sim_state: ts.SimState) -> None: + """LennardJonesModel (wrapped PairPotentialModel) evaluates correctly.""" model = LennardJonesModel( sigma=3.405, epsilon=0.0104, cutoff=2.5 * 3.405, dtype=torch.float64, - device=DEVICE, compute_forces=True, compute_stress=True, ) + results = model(si_double_sim_state) + assert "energy" in results + assert "forces" in results + assert "stress" in results + assert results["energy"].shape == (si_double_sim_state.n_systems,) + assert results["forces"].shape == (si_double_sim_state.n_atoms, 3) + assert results["stress"].shape == (si_double_sim_state.n_systems, 3, 3) - # Compute results - results_wrapped = model(state_wrapped) - results_unwrapped = model(state_unwrapped) - - # Verify energy matches - torch.testing.assert_close( - results_wrapped["energy"], - results_unwrapped["energy"], - rtol=1e-10, - atol=1e-10, - msg="Energies should match for wrapped and unwrapped positions", - ) - # Verify forces match - torch.testing.assert_close( - results_wrapped["forces"], - results_unwrapped["forces"], - rtol=1e-10, - atol=1e-10, - msg="Forces should match for wrapped and unwrapped positions", - ) - - # Verify stress matches - torch.testing.assert_close( - results_wrapped["stress"], - results_unwrapped["stress"], - rtol=1e-10, - atol=1e-10, - msg="Stress should match for wrapped and unwrapped positions", +def test_lennard_jones_model_force_conservation( + si_double_sim_state: ts.SimState, +) -> None: + """LennardJonesModel forces sum to zero (Newton's third law).""" + model = LennardJonesModel( + sigma=3.405, + epsilon=0.0104, + cutoff=2.5 * 3.405, + dtype=torch.float64, + compute_forces=True, ) + results = model(si_double_sim_state) + for sys_idx in range(si_double_sim_state.n_systems): + mask = si_double_sim_state.system_idx == sys_idx + assert torch.allclose( + results["forces"][mask].sum(dim=0), + torch.zeros(3, dtype=torch.float64), + atol=1e-10, + ) diff --git a/tests/models/test_morse.py b/tests/models/test_morse.py index f65cb10b..5fe8aa5c 100644 --- a/tests/models/test_morse.py +++ b/tests/models/test_morse.py @@ -1,144 +1,115 @@ -"""Tests for Morse potential calculator using copper parameters.""" +"""Tests for the Morse pair functions and wrapped model.""" -import pytest import torch import torch_sim as ts from torch_sim.models.morse import MorseModel, morse_pair, morse_pair_force -def test_morse_pair_minimum() -> None: - """Test that the potential has its minimum at r=sigma.""" - dr = torch.linspace(0.8, 1.2, 100) - dr = dr.reshape(-1, 1) - energy = morse_pair(dr) - min_idx = torch.argmin(energy) - torch.testing.assert_close(dr[min_idx], torch.tensor([1.0]), rtol=0.01, atol=1e-5) +def _dummy_z(n: int) -> torch.Tensor: + return torch.ones(n, dtype=torch.long) + + +def test_morse_pair_minimum_at_sigma() -> None: + """Morse minimum is at r = sigma.""" + dr = torch.linspace(0.5, 2.0, 500) + z = _dummy_z(len(dr)) + energies = morse_pair(dr, z, z, sigma=1.0, epsilon=5.0, alpha=5.0) + min_r = dr[energies.argmin()] + assert abs(min_r.item() - 1.0) < 0.01 + + +def test_morse_pair_energy_at_minimum() -> None: + """Morse energy at minimum equals -epsilon.""" + dr = torch.tensor([1.0]) + z = _dummy_z(1) + e = morse_pair(dr, z, z, sigma=1.0, epsilon=5.0, alpha=5.0) + torch.testing.assert_close(e, torch.tensor([-5.0]), rtol=1e-5, atol=1e-5) def test_morse_pair_scaling() -> None: - """Test that the potential scales correctly with epsilon.""" - dr = torch.ones(5, 5) * 1.5 - e1 = morse_pair(dr, epsilon=1.0) - e2 = morse_pair(dr, epsilon=2.0) - torch.testing.assert_close(e2, 2 * e1, rtol=1e-5, atol=1e-5) - - -def test_morse_pair_asymptotic() -> None: - """Test that the potential approaches -epsilon at large distances.""" - dr = torch.tensor([[1.0]]) # Large distance - epsilon = 5.0 - energy = morse_pair(dr, epsilon=epsilon) - torch.testing.assert_close( - energy, -epsilon * torch.ones_like(energy), rtol=1e-2, atol=1e-5 - ) + """Energy scales linearly with epsilon.""" + dr = torch.tensor([1.5]) + z = _dummy_z(1) + e1 = morse_pair(dr, z, z, epsilon=1.0) + e2 = morse_pair(dr, z, z, epsilon=2.0) + torch.testing.assert_close(e2, 2.0 * e1, rtol=1e-5, atol=1e-5) def test_morse_pair_force_scaling() -> None: - """Test that the force scales correctly with epsilon.""" - dr = torch.ones(5, 5) * 1.5 - f1 = morse_pair_force(dr, epsilon=1.0) - f2 = morse_pair_force(dr, epsilon=2.0) - assert torch.allclose(f2, 2 * f1) + """Force scales linearly with epsilon.""" + dr = torch.tensor([1.5]) + z = _dummy_z(1) + f1 = morse_pair_force(dr, z, z, epsilon=1.0) + f2 = morse_pair_force(dr, z, z, epsilon=2.0) + torch.testing.assert_close(f2, 2.0 * f1) def test_morse_force_energy_consistency() -> None: - """Test that the force is consistent with the energy gradient.""" + """Force is consistent with the energy gradient.""" dr = torch.linspace(0.8, 2.0, 100, requires_grad=True) - dr = dr.reshape(-1, 1) + z = _dummy_z(len(dr)) - # Calculate force directly - force_direct = morse_pair_force(dr) + force_direct = morse_pair_force(dr, z, z) - # Calculate force from energy gradient - energy = morse_pair(dr) + energy = morse_pair(dr, z, z) force_from_grad = -torch.autograd.grad(energy.sum(), dr, create_graph=True)[0] - # Compare forces - assert torch.allclose(force_direct, force_from_grad, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(force_direct, force_from_grad, rtol=1e-4, atol=1e-4) def test_morse_alpha_effect() -> None: - """Test that larger alpha values make the potential well narrower.""" + """Larger alpha values make the potential well narrower.""" dr = torch.linspace(0.8, 1.2, 100) - dr = dr.reshape(-1, 1) + z = _dummy_z(len(dr)) - energy1 = morse_pair(dr, alpha=5.0) - energy2 = morse_pair(dr, alpha=10.0) + energy1 = morse_pair(dr, z, z, alpha=5.0) + energy2 = morse_pair(dr, z, z, alpha=10.0) - # Calculate width at half minimum def get_well_width(energy: torch.Tensor) -> torch.Tensor: min_e = torch.min(energy) half_e = min_e / 2 mask = energy < half_e return dr[mask].max() - dr[mask].min() - width1 = get_well_width(energy1) - width2 = get_well_width(energy2) - assert width2 < width1 # Higher alpha should give narrower well - - -@pytest.fixture -def models( - cu_supercell_sim_state: ts.SimState, -) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: - """Create both neighbor list and direct calculators with Copper parameters.""" - # Parameters for Copper (Cu) using Morse potential - # Values from: https://doi.org/10.1016/j.commatsci.2004.12.069 - model_kwargs: dict[str, float | bool | torch.dtype] = { - "sigma": 2.55, # Å, equilibrium distance - "epsilon": 0.436, # eV, dissociation energy - "alpha": 1.359, # Å^-1, controls potential well width - "dtype": torch.float64, - "compute_forces": True, - "compute_stress": True, - } - cutoff = 2.5 * 2.55 # Similar scaling as LJ cutoff - model_nl = MorseModel(use_neighbor_list=True, cutoff=cutoff, **model_kwargs) - model_direct = MorseModel(use_neighbor_list=False, cutoff=cutoff, **model_kwargs) - - return model_nl(cu_supercell_sim_state), model_direct(cu_supercell_sim_state) - - -def test_energy_match( - models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], -) -> None: - """Test that total energy matches between neighbor list and direct calculations.""" - results_nl, results_direct = models - assert torch.allclose(results_nl["energy"], results_direct["energy"], rtol=1e-10) - - -def test_forces_match( - models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], -) -> None: - """Test that forces match between neighbor list and direct calculations.""" - results_nl, results_direct = models - assert torch.allclose(results_nl["forces"], results_direct["forces"], rtol=1e-10) - - -def test_stress_match( - models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], -) -> None: - """Test that stress tensors match between neighbor list and direct calculations.""" - results_nl, results_direct = models - assert torch.allclose(results_nl["stress"], results_direct["stress"], rtol=1e-10) - - -def test_force_conservation( - models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], -) -> None: - """Test that forces sum to zero (Newton's third law).""" - results_nl, _ = models - assert torch.allclose( - results_nl["forces"].sum(dim=0), torch.zeros(3, dtype=torch.float64), atol=1e-10 - ) + assert get_well_width(energy2) < get_well_width(energy1) -def test_stress_tensor_symmetry( - models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], -) -> None: - """Test that stress tensor is symmetric.""" - results_nl, _ = models - assert torch.allclose( - results_nl["stress"].squeeze(), results_nl["stress"].squeeze().T, atol=1e-10 +def test_morse_model_evaluation(si_double_sim_state: ts.SimState) -> None: + """MorseModel (wrapped PairPotentialModel) evaluates correctly.""" + model = MorseModel( + sigma=2.55, + epsilon=0.436, + alpha=1.359, + cutoff=6.0, + dtype=torch.float64, + compute_forces=True, + compute_stress=True, + ) + results = model(si_double_sim_state) + assert "energy" in results + assert "forces" in results + assert "stress" in results + assert results["energy"].shape == (si_double_sim_state.n_systems,) + assert results["forces"].shape == (si_double_sim_state.n_atoms, 3) + assert results["stress"].shape == (si_double_sim_state.n_systems, 3, 3) + + +def test_morse_model_force_conservation(si_double_sim_state: ts.SimState) -> None: + """MorseModel forces sum to zero (Newton's third law).""" + model = MorseModel( + sigma=2.55, + epsilon=0.436, + alpha=1.359, + cutoff=6.0, + dtype=torch.float64, + compute_forces=True, ) + results = model(si_double_sim_state) + for sys_idx in range(si_double_sim_state.n_systems): + mask = si_double_sim_state.system_idx == sys_idx + assert torch.allclose( + results["forces"][mask].sum(dim=0), + torch.zeros(3, dtype=torch.float64), + atol=1e-10, + ) diff --git a/tests/models/test_pair_potential.py b/tests/models/test_pair_potential.py index b0b62cc1..355d8171 100644 --- a/tests/models/test_pair_potential.py +++ b/tests/models/test_pair_potential.py @@ -1,40 +1,75 @@ -"""Tests for the general pair potential model and standard pair functions.""" +"""Tests for the general pair potential model and pair forces model.""" import functools import pytest import torch +from ase.build import bulk import torch_sim as ts from tests.conftest import DEVICE, DTYPE from tests.models.conftest import make_validate_model_outputs_test -from torch_sim.models.interface import ModelInterface -from torch_sim.models.lennard_jones import LennardJonesModel -from torch_sim.models.morse import MorseModel +from torch_sim import io +from torch_sim.models.lennard_jones import LennardJonesModel, lennard_jones_pair +from torch_sim.models.morse import morse_pair from torch_sim.models.pair_potential import ( - MultiSoftSpherePairFn, PairForcesModel, PairPotentialModel, full_to_half_list, - lj_pair, - morse_pair, - particle_life_pair_force, - soft_sphere_pair, ) -from torch_sim.models.soft_sphere import SoftSphereModel +from torch_sim.models.particle_life import particle_life_pair_force +from torch_sim.models.soft_sphere import soft_sphere_pair +from torch_sim.neighbors import torch_nl_n2 + + +# BMHTF (Born-Meyer-Huggins-Tosi-Fumi) potential for NaCl +# Na-Cl interaction parameters +BMHTF_A = 20.3548 +BMHTF_B = 3.1546 +BMHTF_C = 674.4793 +BMHTF_D = 837.0770 +BMHTF_SIGMA = 2.755 +BMHTF_CUTOFF = 10.0 + + +def bmhtf_pair( + dr: torch.Tensor, + zi: torch.Tensor, # noqa: ARG001 + zj: torch.Tensor, # noqa: ARG001 + A: float, + B: float, + C: float, + D: float, + sigma: float, +) -> torch.Tensor: + """Born-Meyer-Huggins-Tosi-Fumi (BMHTF) potential for ionic crystals.""" + exp_term = A * torch.exp(B * (sigma - dr)) + r6_term = C / dr.pow(6) + r8_term = D / dr.pow(8) + energy = exp_term - r6_term - r8_term + return torch.where(dr > 0, energy, torch.zeros_like(energy)) -# Argon LJ parameters -LJ_SIGMA = 3.405 -LJ_EPSILON = 0.0104 -LJ_CUTOFF = 2.5 * LJ_SIGMA +@pytest.fixture +def nacl_sim_state() -> ts.SimState: + """NaCl structure for BMHTF potential tests.""" + nacl_atoms = bulk("NaCl", "rocksalt", a=5.64) + return io.atoms_to_state(nacl_atoms, device=DEVICE, dtype=DTYPE) @pytest.fixture -def lj_model_pp() -> PairPotentialModel: +def bmhtf_model_pp() -> PairPotentialModel: + """BMHTF model using PairPotentialModel to test general case.""" return PairPotentialModel( - pair_fn=functools.partial(lj_pair, sigma=LJ_SIGMA, epsilon=LJ_EPSILON), - cutoff=LJ_CUTOFF, + pair_fn=functools.partial( + bmhtf_pair, + A=BMHTF_A, + B=BMHTF_B, + C=BMHTF_C, + D=BMHTF_D, + sigma=BMHTF_SIGMA, + ), + cutoff=BMHTF_CUTOFF, dtype=DTYPE, compute_forces=True, compute_stress=True, @@ -56,7 +91,7 @@ def particle_life_model() -> PairForcesModel: # Interface validation via factory test_pair_potential_model_outputs = make_validate_model_outputs_test( - model_fixture_name="lj_model_pp", device=DEVICE, dtype=DTYPE + model_fixture_name="bmhtf_model_pp", device=DEVICE, dtype=DTYPE ) test_pair_forces_model_outputs = make_validate_model_outputs_test( @@ -64,267 +99,6 @@ def particle_life_model() -> PairForcesModel: ) -def _dummy_z(n: int) -> torch.Tensor: - return torch.ones(n, dtype=torch.long) - - -def test_lj_pair_minimum() -> None: - """Minimum of LJ is at r = 2^(1/6) * sigma.""" - dr = torch.linspace(0.9, 1.5, 500) - z = _dummy_z(len(dr)) - energies = lj_pair(dr, z, z, sigma=1.0, epsilon=1.0) - min_r = dr[energies.argmin()] - assert abs(min_r.item() - 2 ** (1 / 6)) < 0.01 - - -def test_lj_pair_energy_at_minimum() -> None: - """Energy at minimum equals -epsilon.""" - r_min = torch.tensor([2 ** (1 / 6)]) - z = _dummy_z(1) - e = lj_pair(r_min, z, z, sigma=1.0, epsilon=2.0) - torch.testing.assert_close(e, torch.tensor([-2.0]), rtol=1e-5, atol=1e-5) - - -def test_lj_pair_epsilon_scaling() -> None: - dr = torch.tensor([1.5]) - z = _dummy_z(1) - e1 = lj_pair(dr, z, z, sigma=1.0, epsilon=1.0) - e2 = lj_pair(dr, z, z, sigma=1.0, epsilon=3.0) - torch.testing.assert_close(e2, 3.0 * e1) - - -def test_morse_pair_minimum_at_sigma() -> None: - """Morse minimum is at r = sigma.""" - dr = torch.linspace(0.5, 2.0, 500) - z = _dummy_z(len(dr)) - energies = morse_pair(dr, z, z, sigma=1.0, epsilon=5.0, alpha=5.0) - min_r = dr[energies.argmin()] - assert abs(min_r.item() - 1.0) < 0.01 - - -def test_morse_pair_energy_at_minimum() -> None: - """Morse energy at minimum equals -epsilon.""" - dr = torch.tensor([1.0]) - z = _dummy_z(1) - e = morse_pair(dr, z, z, sigma=1.0, epsilon=5.0, alpha=5.0) - torch.testing.assert_close(e, torch.tensor([-5.0]), rtol=1e-5, atol=1e-5) - - -def test_soft_sphere_zero_beyond_sigma() -> None: - """Soft-sphere energy is zero for r >= sigma.""" - dr = torch.tensor([1.0, 1.5, 2.0]) - z = _dummy_z(3) - e = soft_sphere_pair(dr, z, z, sigma=1.0) - assert e[1] == 0.0 - assert e[2] == 0.0 - - -def test_soft_sphere_repulsive_only() -> None: - """Soft-sphere energies are non-negative for r < sigma.""" - dr = torch.linspace(0.1, 0.99, 50) - z = _dummy_z(len(dr)) - e = soft_sphere_pair(dr, z, z, sigma=1.0, epsilon=1.0, alpha=2.0) - assert (e >= 0).all() - - -def test_particle_life_force_inner() -> None: - """For dr < beta the force is negative (repulsive).""" - dr = torch.tensor([0.1, 0.2]) - z = _dummy_z(2) - f = particle_life_pair_force(dr, z, z, A=1.0, beta=0.3, sigma=1.0) - assert (f < 0).all() - - -def test_particle_life_force_zero_beyond_sigma() -> None: - dr = torch.tensor([1.0, 1.5]) - z = _dummy_z(2) - f = particle_life_pair_force(dr, z, z, A=1.0, beta=0.3, sigma=1.0) - assert f[0] == 0.0 - assert f[1] == 0.0 - - -def _make_mss( - sigma: float = 1.0, epsilon: float = 1.0, alpha: float = 2.0 -) -> MultiSoftSpherePairFn: - """Two-species MultiSoftSpherePairFn with uniform parameters.""" - n = 2 - return MultiSoftSpherePairFn( - atomic_numbers=torch.tensor([18, 36]), - sigma_matrix=torch.full((n, n), sigma), - epsilon_matrix=torch.full((n, n), epsilon), - alpha_matrix=torch.full((n, n), alpha), - ) - - -def test_multi_soft_sphere_zero_beyond_sigma() -> None: - """Energy is zero for r >= sigma.""" - fn = _make_mss(sigma=1.0) - dr = torch.tensor([1.0, 1.5]) - zi = zj = torch.tensor([18, 36]) - e = fn(dr, zi, zj) - assert (e == 0.0).all() - - -def test_multi_soft_sphere_repulsive_only() -> None: - """Energy is non-negative for r < sigma.""" - fn = _make_mss(sigma=2.0, epsilon=1.0, alpha=2.0) - dr = torch.linspace(0.1, 1.99, 20) - zi = zj = torch.full((20,), 18, dtype=torch.long) - assert (fn(dr, zi, zj) >= 0).all() - - -def test_multi_soft_sphere_species_lookup() -> None: - """Different species pairs use the correct off-diagonal parameters.""" - sigma_matrix = torch.tensor([[1.0, 2.0], [2.0, 3.0]]) - epsilon_matrix = torch.ones(2, 2) - alpha_matrix = torch.full((2, 2), 2.0) - fn = MultiSoftSpherePairFn( - atomic_numbers=torch.tensor([18, 36]), - sigma_matrix=sigma_matrix, - epsilon_matrix=epsilon_matrix, - alpha_matrix=alpha_matrix, - ) - dr = torch.tensor([0.5]) - zi_same = torch.tensor([18]) - zj_same = torch.tensor([18]) - zi_cross = torch.tensor([18]) - zj_cross = torch.tensor([36]) - e_same = fn(dr, zi_same, zj_same) # sigma=1.0, r=0.5 < sigma → non-zero - e_cross = fn(dr, zi_cross, zj_cross) # sigma=2.0, r=0.5 < sigma → non-zero - # cross pair has larger sigma so (1 - r/sigma) is larger → higher energy - assert e_cross > e_same - - -def test_multi_soft_sphere_alpha_matrix_default() -> None: - """Omitting alpha_matrix defaults to 2.0 for all pairs.""" - fn_default = MultiSoftSpherePairFn( - atomic_numbers=torch.tensor([18, 36]), - sigma_matrix=torch.full((2, 2), 1.0), - epsilon_matrix=torch.full((2, 2), 1.0), - ) - fn_explicit = _make_mss(sigma=1.0, epsilon=1.0, alpha=2.0) - dr = torch.tensor([0.5]) - zi = zj = torch.tensor([18]) - torch.testing.assert_close(fn_default(dr, zi, zj), fn_explicit(dr, zi, zj)) - - -def test_multi_soft_sphere_bad_matrix_shape_raises() -> None: - with pytest.raises(ValueError, match="sigma_matrix"): - MultiSoftSpherePairFn( - atomic_numbers=torch.tensor([18, 36]), - sigma_matrix=torch.ones(3, 3), # wrong shape - epsilon_matrix=torch.ones(2, 2), - ) - - -def _build_potential_model_pair(name: str) -> tuple[PairPotentialModel, ModelInterface]: - """Return (PairPotentialModel, reference_model) for a named potential.""" - if name == "lj-half": - pp = PairPotentialModel( - pair_fn=functools.partial(lj_pair, sigma=LJ_SIGMA, epsilon=LJ_EPSILON), - cutoff=LJ_CUTOFF, - dtype=DTYPE, - compute_forces=True, - compute_stress=True, - per_atom_energies=True, - per_atom_stresses=True, - reduce_to_half_list=True, - ) - ref = LennardJonesModel( - sigma=LJ_SIGMA, - epsilon=LJ_EPSILON, - cutoff=LJ_CUTOFF, - dtype=DTYPE, - compute_forces=True, - compute_stress=True, - per_atom_energies=True, - per_atom_stresses=True, - ) - return pp, ref - if name == "lj-full": - pp = PairPotentialModel( - pair_fn=functools.partial(lj_pair, sigma=LJ_SIGMA, epsilon=LJ_EPSILON), - cutoff=LJ_CUTOFF, - dtype=DTYPE, - compute_forces=True, - compute_stress=True, - per_atom_energies=True, - per_atom_stresses=True, - reduce_to_half_list=False, - ) - ref = LennardJonesModel( - sigma=LJ_SIGMA, - epsilon=LJ_EPSILON, - cutoff=LJ_CUTOFF, - dtype=DTYPE, - compute_forces=True, - compute_stress=True, - per_atom_energies=True, - per_atom_stresses=True, - ) - return pp, ref - if name == "morse": - sigma, epsilon, alpha, cutoff = 4.0, 5.0, 5.0, 5.0 - pp = PairPotentialModel( - pair_fn=functools.partial( - morse_pair, sigma=sigma, epsilon=epsilon, alpha=alpha - ), - cutoff=cutoff, - dtype=DTYPE, - compute_forces=True, - compute_stress=True, - ) - ref = MorseModel( - sigma=sigma, - epsilon=epsilon, - alpha=alpha, - cutoff=cutoff, - dtype=DTYPE, - compute_forces=True, - compute_stress=True, - ) - return pp, ref - if name == "soft_sphere": - sigma, epsilon, alpha = 5, 0.0104, 2.0 - pp = PairPotentialModel( - pair_fn=functools.partial( - soft_sphere_pair, sigma=sigma, epsilon=epsilon, alpha=alpha - ), - cutoff=sigma, - dtype=DTYPE, - compute_forces=True, - compute_stress=True, - ) - ref = SoftSphereModel( - sigma=sigma, - epsilon=epsilon, - alpha=alpha, - cutoff=sigma, - dtype=DTYPE, - compute_forces=True, - compute_stress=True, - ) - return pp, ref - - raise ValueError(f"Unknown potential: {name}") - - -@pytest.mark.parametrize("potential", ["lj-half", "lj-full", "morse", "soft_sphere"]) -def test_potential_matches_reference( - mixed_double_sim_state: ts.SimState, - potential: str, -) -> None: - """PairPotentialModel matches the dedicated reference model.""" - model_pp, model_ref = _build_potential_model_pair(potential) - out_pp = model_pp(mixed_double_sim_state) - out_ref = model_ref(mixed_double_sim_state) - - assert (out_pp["energy"] != 0).all() - - for key in out_pp: - torch.testing.assert_close(out_pp[key], out_ref[key], rtol=1e-4, atol=1e-5) - - def test_full_to_half_list_removes_duplicates() -> None: """i < j mask halves a symmetric full neighbor list.""" # 3-atom full list: (0,1),(1,0),(0,2),(2,0),(1,2),(2,1) @@ -353,13 +127,17 @@ def test_full_to_half_list_preserves_system_and_shifts() -> None: @pytest.mark.parametrize("key", ["energy", "forces", "stress", "stresses"]) def test_half_list_matches_full(si_double_sim_state: ts.SimState, key: str) -> None: """reduce_to_half_list=True gives the same result as the default full list.""" - fn = functools.partial(lj_pair, sigma=LJ_SIGMA, epsilon=LJ_EPSILON) + # Argon LJ parameters + sigma = 3.405 + epsilon = 0.0104 + cutoff = 2.5 * sigma + fn = functools.partial(lennard_jones_pair, sigma=sigma, epsilon=epsilon) needs_forces = key in ("forces", "stress", "stresses") needs_stress = key in ("stress", "stresses") common = dict( pair_fn=fn, - cutoff=LJ_CUTOFF, - dtype=DTYPE, + cutoff=cutoff, + dtype=si_double_sim_state.dtype, compute_forces=needs_forces, compute_stress=needs_stress, per_atom_stresses=(key == "stresses"), @@ -371,14 +149,25 @@ def test_half_list_matches_full(si_double_sim_state: ts.SimState, key: str) -> N torch.testing.assert_close(out_half[key], out_full[key], rtol=1e-10, atol=1e-14) -@pytest.mark.parametrize("potential", ["lj", "morse", "soft_sphere"]) +@pytest.mark.parametrize("potential", ["bmhtf", "morse", "soft_sphere"]) def test_autograd_force_fn_matches_potential_model( - si_double_sim_state: ts.SimState, potential: str + nacl_sim_state: ts.SimState, + si_double_sim_state: ts.SimState, + potential: str, ) -> None: """PairForcesModel with -dV/dr force fn matches PairPotentialModel forces/stress.""" - if potential == "lj": - pair_fn = functools.partial(lj_pair, sigma=LJ_SIGMA, epsilon=LJ_EPSILON) - cutoff = LJ_CUTOFF + # Use NaCl for BMHTF, si_double for others + sim_state = nacl_sim_state if potential == "bmhtf" else si_double_sim_state + if potential == "bmhtf": + pair_fn = functools.partial( + bmhtf_pair, + A=BMHTF_A, + B=BMHTF_B, + C=BMHTF_C, + D=BMHTF_D, + sigma=BMHTF_SIGMA, + ) + cutoff = BMHTF_CUTOFF elif potential == "morse": pair_fn = functools.partial(morse_pair, sigma=4.0, epsilon=5.0, alpha=5.0) cutoff = 5.0 @@ -395,7 +184,7 @@ def force_fn(dr: torch.Tensor, zi: torch.Tensor, zj: torch.Tensor) -> torch.Tens model_pp = PairPotentialModel( pair_fn=pair_fn, cutoff=cutoff, - dtype=DTYPE, + dtype=sim_state.dtype, compute_forces=True, compute_stress=True, per_atom_stresses=True, @@ -403,12 +192,12 @@ def force_fn(dr: torch.Tensor, zi: torch.Tensor, zj: torch.Tensor) -> torch.Tens model_pf = PairForcesModel( force_fn=force_fn, cutoff=cutoff, - dtype=DTYPE, + dtype=sim_state.dtype, compute_stress=True, per_atom_stresses=True, ) - out_pp = model_pp(si_double_sim_state) - out_pf = model_pf(si_double_sim_state) + out_pp = model_pp(sim_state) + out_pf = model_pf(sim_state) assert (out_pp["forces"] != 0.0).all() @@ -435,3 +224,130 @@ def test_forces_model_half_list_matches_full( out_full = model_full(si_double_sim_state) out_half = model_half(si_double_sim_state) torch.testing.assert_close(out_half[key], out_full[key], rtol=1e-10, atol=1e-13) + + +def test_force_conservation( + bmhtf_model_pp: PairPotentialModel, nacl_sim_state: ts.SimState +) -> None: + """Forces sum to zero (Newton's third law).""" + out = bmhtf_model_pp(nacl_sim_state) + for sys_idx in range(nacl_sim_state.n_systems): + mask = nacl_sim_state.system_idx == sys_idx + assert torch.allclose( + out["forces"][mask].sum(dim=0), + torch.zeros(3, dtype=nacl_sim_state.dtype), + atol=1e-10, + ) + + +def test_stress_tensor_symmetry( + bmhtf_model_pp: PairPotentialModel, nacl_sim_state: ts.SimState +) -> None: + """Stress tensor is symmetric.""" + out = bmhtf_model_pp(nacl_sim_state) + for i in range(nacl_sim_state.n_systems): + stress = out["stress"][i] + assert torch.allclose(stress, stress.T, atol=1e-10) + + +def test_multi_system(ar_double_sim_state: ts.SimState) -> None: + """Multi-system batched evaluation matches single-system evaluation.""" + model = LennardJonesModel( + sigma=3.405, + epsilon=0.0104, + cutoff=2.5 * 3.405, + dtype=torch.float64, + device=DEVICE, + compute_forces=True, + compute_stress=True, + ) + out = model(ar_double_sim_state) + + assert out["energy"].shape == (ar_double_sim_state.n_systems,) + # Both systems are identical, so energies should match + torch.testing.assert_close(out["energy"][0], out["energy"][1], rtol=1e-10, atol=1e-10) + + +def test_unwrapped_positions_consistency() -> None: + """Wrapped and unwrapped positions give identical results.""" + ar_atoms = bulk("Ar", "fcc", a=5.26, cubic=True).repeat([2, 2, 2]) + cell = torch.tensor(ar_atoms.get_cell().array, dtype=torch.float64, device=DEVICE) + + state_wrapped = ts.io.atoms_to_state(ar_atoms, DEVICE, torch.float64) + + positions_unwrapped = state_wrapped.positions.clone() + n_atoms = positions_unwrapped.shape[0] + positions_unwrapped[: n_atoms // 2] += cell[0] + positions_unwrapped[n_atoms // 4 : n_atoms // 2] -= cell[1] + + state_unwrapped = ts.SimState.from_state(state_wrapped, positions=positions_unwrapped) + + model = LennardJonesModel( + sigma=3.405, + epsilon=0.0104, + cutoff=2.5 * 3.405, + dtype=torch.float64, + device=DEVICE, + compute_forces=True, + compute_stress=True, + ) + + results_wrapped = model(state_wrapped) + results_unwrapped = model(state_unwrapped) + + for key in ("energy", "forces", "stress"): + torch.testing.assert_close( + results_wrapped[key], results_unwrapped[key], rtol=1e-10, atol=1e-10 + ) + + +def test_retain_graph_allows_param_grad(nacl_sim_state: ts.SimState) -> None: + """With retain_graph=True, energy graph survives force computation so we can + differentiate energy w.r.t. model parameters (e.g. A, B, C, D).""" + A = torch.tensor(BMHTF_A, dtype=nacl_sim_state.dtype, requires_grad=True) + pair_fn = functools.partial( + bmhtf_pair, + A=A, + B=BMHTF_B, + C=BMHTF_C, + D=BMHTF_D, + sigma=BMHTF_SIGMA, + ) + model = PairPotentialModel( + pair_fn=pair_fn, + cutoff=BMHTF_CUTOFF, + dtype=nacl_sim_state.dtype, + compute_forces=True, + neighbor_list_fn=torch_nl_n2, + retain_graph=True, + ) + out = model(nacl_sim_state) + assert out["forces"] is not None + (grad,) = torch.autograd.grad(out["energy"].sum(), A) + assert grad.shape == A.shape + assert grad.abs() > 0 + + +def test_no_retain_graph_frees_graph(nacl_sim_state: ts.SimState) -> None: + """Without retain_graph, differentiating energy w.r.t. parameters after force + computation raises because the graph has been freed.""" + A = torch.tensor(BMHTF_A, dtype=nacl_sim_state.dtype, requires_grad=True) + pair_fn = functools.partial( + bmhtf_pair, + A=A, + B=BMHTF_B, + C=BMHTF_C, + D=BMHTF_D, + sigma=BMHTF_SIGMA, + ) + model = PairPotentialModel( + pair_fn=pair_fn, + cutoff=BMHTF_CUTOFF, + dtype=nacl_sim_state.dtype, + compute_forces=True, + neighbor_list_fn=torch_nl_n2, + retain_graph=False, + ) + out = model(nacl_sim_state) + with pytest.raises(RuntimeError, match="does not require grad"): + torch.autograd.grad(out["energy"].sum(), A) diff --git a/tests/models/test_particle_life.py b/tests/models/test_particle_life.py new file mode 100644 index 00000000..d93b1f3b --- /dev/null +++ b/tests/models/test_particle_life.py @@ -0,0 +1,80 @@ +"""Tests for the particle life force function and wrapped model.""" + +import torch + +import torch_sim as ts +from torch_sim.models.particle_life import ParticleLifeModel, particle_life_pair_force + + +def _dummy_z(n: int) -> torch.Tensor: + return torch.ones(n, dtype=torch.long) + + +def test_inner_region_repulsive() -> None: + """For dr < beta the force is negative (repulsive).""" + dr = torch.tensor([0.1, 0.2]) + z = _dummy_z(2) + f = particle_life_pair_force(dr, z, z, A=1.0, beta=0.3, sigma=1.0) + assert (f < 0).all() + + +def test_zero_beyond_sigma() -> None: + """Force is zero at and beyond sigma.""" + dr = torch.tensor([1.0, 1.5]) + z = _dummy_z(2) + f = particle_life_pair_force(dr, z, z, A=1.0, beta=0.3, sigma=1.0) + assert (f == 0.0).all() + + +def test_amplitude_scaling() -> None: + """Outer-region force scales with A.""" + dr = torch.tensor([0.6]) # between beta and sigma + z = _dummy_z(1) + f1 = particle_life_pair_force(dr, z, z, A=1.0, beta=0.3, sigma=1.0) + f2 = particle_life_pair_force(dr, z, z, A=3.0, beta=0.3, sigma=1.0) + torch.testing.assert_close(f2, 3.0 * f1) + + +def test_particle_life_model_evaluation(si_double_sim_state: ts.SimState) -> None: + """ParticleLifeModel (wrapped PairForcesModel) evaluates correctly.""" + model = ParticleLifeModel( + A=1.0, + beta=0.3, + sigma=5.26, + cutoff=5.26, + dtype=si_double_sim_state.dtype, + compute_stress=True, + ) + results = model(si_double_sim_state) + assert "energy" in results + assert "forces" in results + assert "stress" in results + assert results["energy"].shape == (si_double_sim_state.n_systems,) + assert results["forces"].shape == (si_double_sim_state.n_atoms, 3) + assert results["stress"].shape == (si_double_sim_state.n_systems, 3, 3) + # Energy should be zeros for PairForcesModel + assert torch.allclose( + results["energy"], + torch.zeros(si_double_sim_state.n_systems, dtype=si_double_sim_state.dtype), + ) + + +def test_particle_life_model_force_conservation( + si_double_sim_state: ts.SimState, +) -> None: + """ParticleLifeModel forces sum to zero (Newton's third law).""" + model = ParticleLifeModel( + A=1.0, + beta=0.3, + sigma=5.26, + cutoff=5.26, + dtype=torch.float64, + ) + results = model(si_double_sim_state) + for sys_idx in range(si_double_sim_state.n_systems): + mask = si_double_sim_state.system_idx == sys_idx + assert torch.allclose( + results["forces"][mask].sum(dim=0), + torch.zeros(3, dtype=torch.float64), + atol=1e-10, + ) diff --git a/tests/models/test_soft_sphere.py b/tests/models/test_soft_sphere.py index 1b1cae6f..daa32102 100644 --- a/tests/models/test_soft_sphere.py +++ b/tests/models/test_soft_sphere.py @@ -1,146 +1,36 @@ -"""Tests for soft sphere models ensuring different parts of TorchSim work together.""" +"""Tests for the soft sphere pair functions, wrapped model, and multi-species models.""" import pytest import torch import torch_sim as ts -import torch_sim.models.soft_sphere as ss -from tests.conftest import DEVICE -from torch_sim.models.interface import validate_model_outputs - - -def _make_soft_sphere_model( - *, use_neighbor_list: bool, with_per_atom: bool = False -) -> ss.SoftSphereModel: - """Create a SoftSphereModel with common test defaults.""" - model_kwargs: dict[str, float | bool | torch.dtype] = { - "sigma": 3.405, # Å, typical for Ar - "epsilon": 0.0104, # eV, typical for Ar - "alpha": 2.0, - "dtype": torch.float64, - "compute_forces": True, - "compute_stress": True, - } - if with_per_atom: - model_kwargs["per_atom_energies"] = True - model_kwargs["per_atom_stresses"] = True - return ss.SoftSphereModel(use_neighbor_list=use_neighbor_list, **model_kwargs) - - -@pytest.fixture -def models( - fe_supercell_sim_state: ts.SimState, -) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: - """Create both neighbor list and direct calculators.""" - model_nl = _make_soft_sphere_model(use_neighbor_list=True) - model_direct = _make_soft_sphere_model(use_neighbor_list=False) - - return model_nl(fe_supercell_sim_state), model_direct(fe_supercell_sim_state) - - -@pytest.fixture -def models_with_per_atom( - fe_supercell_sim_state: ts.SimState, -) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: - """Create calculators with per-atom properties enabled.""" - model_nl = _make_soft_sphere_model(use_neighbor_list=True, with_per_atom=True) - model_direct = _make_soft_sphere_model(use_neighbor_list=False, with_per_atom=True) - - return model_nl(fe_supercell_sim_state), model_direct(fe_supercell_sim_state) - - -@pytest.fixture -def small_system() -> tuple[torch.Tensor, torch.Tensor]: - """Create a small simple cubic system for testing.""" - positions = torch.tensor( - [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]], - dtype=torch.float64, - ) - cell = torch.eye(3, dtype=torch.float64) * 2.0 - return positions, cell - - -@pytest.fixture -def small_sim_state(small_system: tuple[torch.Tensor, torch.Tensor]) -> ts.SimState: - """Create a small SimState for testing.""" - positions, cell = small_system - return ts.SimState( - positions=positions, - cell=cell, - pbc=True, - masses=torch.ones(positions.shape[0], dtype=torch.float64), - atomic_numbers=torch.ones(positions.shape[0], dtype=torch.long), - ) - - -@pytest.fixture -def small_batched_sim_state(small_sim_state: ts.SimState) -> ts.SimState: - """Create a batched state from the small system.""" - return ts.concatenate_states( - [small_sim_state, small_sim_state], device=small_sim_state.device - ) - - -@pytest.mark.parametrize("output_key", ["energy", "forces", "stress"]) -def test_outputs_match( - models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], - output_key: str, -) -> None: - """Test that outputs match between neighbor list and direct calculations.""" - results_nl, results_direct = models - assert torch.allclose(results_nl[output_key], results_direct[output_key], rtol=1e-10) +from torch_sim.models.soft_sphere import ( + MultiSoftSpherePairFn, + SoftSphereModel, + SoftSphereMultiModel, + soft_sphere_pair, +) -def test_force_conservation( - models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], -) -> None: - """Test that forces sum to zero.""" - results_nl, _ = models - assert torch.allclose( - results_nl["forces"].sum(dim=0), torch.zeros(3, dtype=torch.float64), atol=1e-10 - ) +def _dummy_z(n: int) -> torch.Tensor: + return torch.ones(n, dtype=torch.long) -def test_stress_tensor_symmetry( - models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], -) -> None: - """Test that stress tensor is symmetric.""" - results_nl, _ = models - assert torch.allclose(results_nl["stress"], results_nl["stress"].T, atol=1e-10) +def test_soft_sphere_zero_beyond_sigma() -> None: + """Soft-sphere energy is zero for r >= sigma.""" + dr = torch.tensor([1.0, 1.5, 2.0]) + z = _dummy_z(3) + e = soft_sphere_pair(dr, z, z, sigma=1.0) + assert e[1] == 0.0 + assert e[2] == 0.0 -def test_validate_model_outputs() -> None: - """Test that the model outputs are valid.""" - model_nl = _make_soft_sphere_model(use_neighbor_list=True) - model_direct = _make_soft_sphere_model(use_neighbor_list=False) - for out in (model_nl, model_direct): - validate_model_outputs(out, DEVICE, torch.float64) - - -@pytest.mark.parametrize( - ("per_atom_key", "total_key"), [("energies", "energy"), ("stresses", "stress")] -) -def test_per_atom_properties( - models_with_per_atom: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]], - per_atom_key: str, - total_key: str, -) -> None: - """Test that per-atom properties are calculated correctly.""" - results_nl, results_direct = models_with_per_atom - - # Check per-atom properties are calculated and match - assert torch.allclose( - results_nl[per_atom_key], results_direct[per_atom_key], rtol=1e-10 - ) - - # Check sum of per-atom properties matches total property - if per_atom_key == "energies": - assert torch.allclose( - results_nl[per_atom_key].sum(), results_nl[total_key], rtol=1e-10 - ) - else: # stresses - total_from_atoms = results_nl[per_atom_key].sum(dim=0) - assert torch.allclose(total_from_atoms, results_nl[total_key], rtol=1e-10) +def test_soft_sphere_repulsive_only() -> None: + """Soft-sphere energies are non-negative for r < sigma.""" + dr = torch.linspace(0.1, 0.99, 50) + z = _dummy_z(len(dr)) + e = soft_sphere_pair(dr, z, z, sigma=1.0, epsilon=1.0, alpha=2.0) + assert (e >= 0).all() @pytest.mark.parametrize( @@ -155,108 +45,102 @@ def test_soft_sphere_pair_single( distance: float, sigma: float, epsilon: float, alpha: float, expected: float ) -> None: """Test the soft sphere pair calculation for single values.""" - energy = ss.soft_sphere_pair( - torch.tensor(distance), - torch.tensor(sigma), - torch.tensor(epsilon), - torch.tensor(alpha), + dr = torch.tensor([distance]) + z = _dummy_z(1) + energy = soft_sphere_pair(dr, z, z, sigma=sigma, epsilon=epsilon, alpha=alpha) + torch.testing.assert_close(energy, torch.tensor([expected])) + + +def _make_mss( + sigma: float = 1.0, epsilon: float = 1.0, alpha: float = 2.0 +) -> MultiSoftSpherePairFn: + """Two-species MultiSoftSpherePairFn with uniform parameters.""" + n = 2 + return MultiSoftSpherePairFn( + atomic_numbers=torch.tensor([18, 36]), + sigma_matrix=torch.full((n, n), sigma), + epsilon_matrix=torch.full((n, n), epsilon), + alpha_matrix=torch.full((n, n), alpha), ) - assert torch.allclose(energy, torch.tensor(expected)) - - -def test_model_initialization_defaults() -> None: - """Test initialization with default parameters.""" - model = ss.SoftSphereModel() - - # Check default parameters are used - assert torch.allclose(model.sigma, ss.DEFAULT_SIGMA) - assert torch.allclose(model.epsilon, ss.DEFAULT_EPSILON) - assert torch.allclose(model.alpha, ss.DEFAULT_ALPHA) - assert torch.allclose(model.cutoff, ss.DEFAULT_SIGMA) # Default cutoff is sigma - - -@pytest.mark.parametrize( - ("param_name", "param_value", "expected_dtype"), - [ - ("sigma", 2.0, torch.float64), - ("epsilon", 3.0, torch.float64), - ("alpha", 4.0, torch.float64), - ("cutoff", 5.0, torch.float64), - ], -) -def test_model_initialization_custom_params( - param_name: str, param_value: float, expected_dtype: torch.dtype -) -> None: - """Test initialization with custom parameters.""" - model = ss.SoftSphereModel(**{param_name: param_value, "dtype": expected_dtype}) - - param_tensor = getattr(model, param_name) - assert torch.allclose(param_tensor, torch.tensor(param_value, dtype=expected_dtype)) - assert param_tensor.dtype == expected_dtype -@pytest.mark.parametrize( - ("flag_name", "flag_value"), - [ - ("compute_forces", False), - ("compute_stress", True), - ("per_atom_energies", True), - ("per_atom_stresses", True), - ("use_neighbor_list", False), - ], -) -def test_model_initialization_custom_flags(*, flag_name: str, flag_value: bool) -> None: - """Test initialization with custom flags.""" - model = ss.SoftSphereModel(**{flag_name: flag_value}) - - # For compute_forces and compute_stress, we need to check the private attributes - if flag_name == "compute_forces": - flag_name = "_compute_forces" - elif flag_name == "compute_stress": - flag_name = "_compute_stress" +def test_multi_soft_sphere_zero_beyond_sigma() -> None: + """Energy is zero for r >= sigma.""" + fn = _make_mss(sigma=1.0) + dr = torch.tensor([1.0, 1.5]) + zi = zj = torch.tensor([18, 36]) + e = fn(dr, zi, zj) + assert (e == 0.0).all() - assert getattr(model, flag_name) is flag_value +def test_multi_soft_sphere_repulsive_only() -> None: + """Energy is non-negative for r < sigma.""" + fn = _make_mss(sigma=2.0, epsilon=1.0, alpha=2.0) + dr = torch.linspace(0.1, 1.99, 20) + zi = zj = torch.full((20,), 18, dtype=torch.long) + assert (fn(dr, zi, zj) >= 0).all() -@pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) -def test_model_dtype(dtype: torch.dtype) -> None: - """Test model with different dtypes.""" - model = ss.SoftSphereModel(dtype=dtype) - assert model.sigma.dtype == dtype - assert model.epsilon.dtype == dtype - assert model.alpha.dtype == dtype - assert model.cutoff.dtype == dtype +def test_multi_soft_sphere_species_lookup() -> None: + """Different species pairs use the correct off-diagonal parameters.""" + sigma_matrix = torch.tensor([[1.0, 2.0], [2.0, 3.0]]) + epsilon_matrix = torch.ones(2, 2) + alpha_matrix = torch.full((2, 2), 2.0) + fn = MultiSoftSpherePairFn( + atomic_numbers=torch.tensor([18, 36]), + sigma_matrix=sigma_matrix, + epsilon_matrix=epsilon_matrix, + alpha_matrix=alpha_matrix, + ) + dr = torch.tensor([0.5]) + zi_same = torch.tensor([18]) + zj_same = torch.tensor([18]) + zi_cross = torch.tensor([18]) + zj_cross = torch.tensor([36]) + e_same = fn(dr, zi_same, zj_same) # sigma=1.0, r=0.5 < sigma → non-zero + e_cross = fn(dr, zi_cross, zj_cross) # sigma=2.0, r=0.5 < sigma → non-zero + # cross pair has larger sigma so (1 - r/sigma) is larger → higher energy + assert e_cross > e_same + + +def test_multi_soft_sphere_alpha_matrix_default() -> None: + """Omitting alpha_matrix defaults to 2.0 for all pairs.""" + fn_default = MultiSoftSpherePairFn( + atomic_numbers=torch.tensor([18, 36]), + sigma_matrix=torch.full((2, 2), 1.0), + epsilon_matrix=torch.full((2, 2), 1.0), + ) + fn_explicit = _make_mss(sigma=1.0, epsilon=1.0, alpha=2.0) + dr = torch.tensor([0.5]) + zi = zj = torch.tensor([18]) + torch.testing.assert_close(fn_default(dr, zi, zj), fn_explicit(dr, zi, zj)) + + +def test_multi_soft_sphere_bad_matrix_shape_raises() -> None: + with pytest.raises(ValueError, match="sigma_matrix"): + MultiSoftSpherePairFn( + atomic_numbers=torch.tensor([18, 36]), + sigma_matrix=torch.ones(3, 3), # wrong shape + epsilon_matrix=torch.ones(2, 2), + ) def test_multispecies_initialization_defaults() -> None: - """Test initialization of multi-species model with defaults.""" - dtype = torch.float32 - model = ss.SoftSphereMultiModel(n_species=2, dtype=dtype) - - # Check matrices are created with defaults + """Multi-species model initializes with default parameters.""" + model = SoftSphereMultiModel(atomic_numbers=torch.tensor([0, 1]), dtype=torch.float32) assert model.sigma_matrix.shape == (2, 2) assert model.epsilon_matrix.shape == (2, 2) assert model.alpha_matrix.shape == (2, 2) - # Check default values - ones = torch.ones(2, 2, dtype=dtype) - assert torch.allclose(model.sigma_matrix, ss.DEFAULT_SIGMA * ones) - assert torch.allclose(model.epsilon_matrix, ss.DEFAULT_EPSILON * ones) - assert torch.allclose(model.alpha_matrix, ss.DEFAULT_ALPHA * ones) - - # Check cutoff is max sigma - assert model.cutoff.item() == ss.DEFAULT_SIGMA.item() - def test_multispecies_initialization_custom() -> None: - """Test initialization of multi-species model with custom parameters.""" + """Multi-species model stores custom parameter matrices.""" sigma_matrix = torch.tensor([[1.0, 1.5], [1.5, 2.0]], dtype=torch.float64) epsilon_matrix = torch.tensor([[1.0, 0.5], [0.5, 1.5]], dtype=torch.float64) alpha_matrix = torch.tensor([[2.0, 3.0], [3.0, 4.0]], dtype=torch.float64) - model = ss.SoftSphereMultiModel( - n_species=2, + model = SoftSphereMultiModel( + atomic_numbers=torch.tensor([0, 1]), sigma_matrix=sigma_matrix, epsilon_matrix=epsilon_matrix, alpha_matrix=alpha_matrix, @@ -264,25 +148,20 @@ def test_multispecies_initialization_custom() -> None: dtype=torch.float64, ) - # Check matrices are stored correctly assert torch.allclose(model.sigma_matrix, sigma_matrix) assert torch.allclose(model.epsilon_matrix, epsilon_matrix) assert torch.allclose(model.alpha_matrix, alpha_matrix) - - # Check cutoff is set explicitly assert model.cutoff.item() == 3.0 def test_multispecies_matrix_validation() -> None: - """Test validation of parameter matrices.""" - # Create incorrect-sized matrices (2x2 instead of 3x3) + """Incorrectly sized matrices raise ValueError.""" sigma_matrix = torch.tensor([[1.0, 1.5], [1.5, 2.0]]) epsilon_matrix = torch.tensor([[1.0, 0.5], [0.5, 1.5]]) - # Should raise ValueError due to matrix size mismatch with pytest.raises(ValueError, match="sigma_matrix must have shape"): - ss.SoftSphereMultiModel( - n_species=3, + SoftSphereMultiModel( + atomic_numbers=torch.tensor([0, 1, 2]), sigma_matrix=sigma_matrix, epsilon_matrix=epsilon_matrix, ) @@ -297,52 +176,98 @@ def test_multispecies_matrix_validation() -> None: ], ) def test_matrix_symmetry_validation(matrix_name: str, matrix: torch.Tensor) -> None: - """Test that parameter matrices are validated for symmetry.""" - # Create symmetric matrices for the other parameters + """Parameter matrices are validated for symmetry.""" symmetric_matrix = torch.tensor([[1.0, 1.5], [1.5, 2.0]]) - params = { - "n_species": 2, + "atomic_numbers": torch.tensor([0, 1]), "sigma_matrix": symmetric_matrix, "epsilon_matrix": symmetric_matrix, "alpha_matrix": symmetric_matrix, } - - # Replace one matrix with the non-symmetric version params[matrix_name] = matrix - # Should raise ValueError due to asymmetric matrix with pytest.raises(ValueError, match="is not symmetric"): - ss.SoftSphereMultiModel(**params) + SoftSphereMultiModel(**params) def test_multispecies_cutoff_default() -> None: - """Test that the default cutoff is the maximum sigma value.""" + """Default cutoff is the maximum sigma value.""" sigma_matrix = torch.tensor([[1.0, 1.5, 2.0], [1.5, 2.0, 2.5], [2.0, 2.5, 3.0]]) + model = SoftSphereMultiModel( + atomic_numbers=torch.tensor([0, 1, 2]), sigma_matrix=sigma_matrix + ) + assert model.cutoff.item() == 3.0 - model = ss.SoftSphereMultiModel(n_species=3, sigma_matrix=sigma_matrix) - # Cutoff should default to max value in sigma_matrix - assert model.cutoff.item() == 3.0 +def test_multispecies_evaluation() -> None: + """Multi-species model evaluates without error on a small system.""" + sigma_matrix = torch.tensor([[1.0, 0.8], [0.8, 0.6]], dtype=torch.float64) + epsilon_matrix = torch.tensor([[1.0, 0.5], [0.5, 2.0]], dtype=torch.float64) + model = SoftSphereMultiModel( + atomic_numbers=torch.tensor([0, 1]), + sigma_matrix=sigma_matrix, + epsilon_matrix=epsilon_matrix, + dtype=torch.float64, + compute_forces=True, + compute_stress=True, + ) + + positions = torch.tensor( + [[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [0.0, 0.5, 0.0], [0.5, 0.5, 0.0]], + dtype=torch.float64, + ) + cell = torch.eye(3, dtype=torch.float64) * 2.0 + state = ts.SimState( + positions=positions, + cell=cell, + pbc=True, + masses=torch.ones(4, dtype=torch.float64), + atomic_numbers=torch.tensor([0, 0, 1, 1], dtype=torch.long), + ) + results = model(state) + assert "energy" in results + assert "forces" in results + assert "stress" in results + + +def test_soft_sphere_model_evaluation(si_double_sim_state: ts.SimState) -> None: + """SoftSphereModel (wrapped PairPotentialModel) evaluates correctly.""" + model = SoftSphereModel( + sigma=5.0, + epsilon=0.0104, + alpha=2.0, + cutoff=5.0, + dtype=torch.float64, + compute_forces=True, + compute_stress=True, + ) + results = model(si_double_sim_state) + assert "energy" in results + assert "forces" in results + assert "stress" in results + assert results["energy"].shape == (si_double_sim_state.n_systems,) + assert results["forces"].shape == (si_double_sim_state.n_atoms, 3) + assert results["stress"].shape == (si_double_sim_state.n_systems, 3, 3) -@pytest.mark.parametrize( - ("flag_name", "flag_value"), - [ - ("pbc", torch.tensor([True, True, True])), - ("pbc", torch.tensor([False, False, False])), - ("compute_forces", False), - ("compute_stress", True), - ("per_atom_energies", True), - ("per_atom_stresses", False), - ("use_neighbor_list", True), - ("use_neighbor_list", False), - ], -) -def test_multispecies_model_flags(*, flag_name: str, flag_value: bool) -> None: - """Test flags of the SoftSphereMultiModel.""" - model = ss.SoftSphereMultiModel(n_species=2, **{flag_name: flag_value}) - # For SoftSphereMultiModel, we don't need to convert attribute names - # as it uses public attribute names for all flags - assert getattr(model, flag_name) is flag_value +def test_soft_sphere_model_force_conservation( + si_double_sim_state: ts.SimState, +) -> None: + """SoftSphereModel forces sum to zero (Newton's third law).""" + model = SoftSphereModel( + sigma=5.0, + epsilon=0.0104, + alpha=2.0, + cutoff=5.0, + dtype=torch.float64, + compute_forces=True, + ) + results = model(si_double_sim_state) + for sys_idx in range(si_double_sim_state.n_systems): + mask = si_double_sim_state.system_idx == sys_idx + assert torch.allclose( + results["forces"][mask].sum(dim=0), + torch.zeros(3, dtype=torch.float64), + atol=1e-10, + ) diff --git a/tests/test_fix_symmetry.py b/tests/test_fix_symmetry.py index a770168b..ea5076d7 100644 --- a/tests/test_fix_symmetry.py +++ b/tests/test_fix_symmetry.py @@ -11,9 +11,10 @@ from ase.stress import full_3x3_to_voigt_6_stress, voigt_6_to_full_3x3_stress import torch_sim as ts +from tests.conftest import DEVICE, DTYPE from torch_sim.constraints import FixCom, FixSymmetry from torch_sim.models.interface import ModelInterface -from torch_sim.models.lennard_jones import UnbatchedLennardJonesModel +from torch_sim.models.lennard_jones import LennardJonesModel from torch_sim.symmetrize import get_symmetry_datasets @@ -22,15 +23,11 @@ SPACEGROUPS = {"fcc": 225, "hcp": 194, "diamond": 227, "bcc": 229, "p6bar": 174} MAX_STEPS = 30 -DTYPE = torch.float64 SYMPREC = 0.01 -CPU = torch.device("cpu") +REPEATS = 2 -# === Structure helpers === - - -def make_structure(name: str) -> Atoms: +def make_structure(name: str, repeats: int = REPEATS) -> Atoms: """Create a test structure by name (fcc/hcp/diamond/bcc/p6bar + _rotated suffix).""" base = name.replace("_rotated", "") builders = { @@ -46,6 +43,8 @@ def make_structure(name: str) -> Atoms: ), } atoms = builders[base]() + # make a supercell to exaggerate the impact of symmetry breaking noise + atoms = atoms.repeat([repeats, repeats, repeats]) if "_rotated" in name: rotation_product = np.eye(3) for axis_idx in range(3): @@ -63,13 +62,10 @@ def make_structure(name: str) -> Atoms: return atoms -# === Fixtures === - - @pytest.fixture -def model() -> UnbatchedLennardJonesModel: +def model() -> LennardJonesModel: """LJ model for testing.""" - return UnbatchedLennardJonesModel( + return LennardJonesModel( sigma=1.0, epsilon=0.05, cutoff=6.0, @@ -79,42 +75,58 @@ def model() -> UnbatchedLennardJonesModel: class NoisyModelWrapper(ModelInterface): - """Wrapper that adds noise to forces and stress.""" + """Wrapper that adds Weibull-distributed noise to forces and stress. + + Uses Weibull noise (heavy-tailed) rather than Gaussian so that occasional + large perturbations can break symmetry in negative-control tests. This + also better mimics real ML potential errors, which have heavy tails. + """ - model: UnbatchedLennardJonesModel - rng: np.random.Generator + model: LennardJonesModel noise_scale: float + concentration: float def __init__( self, - model: UnbatchedLennardJonesModel, - noise_scale: float = 1e-4, + model: LennardJonesModel, + noise_scale: float = 1e-1, + concentration: float = 1.0, ) -> None: super().__init__() self.model = model - self.rng = np.random.default_rng(seed=1) self.noise_scale = noise_scale + self.concentration = concentration self._device = model.device self._dtype = model.dtype self._compute_stress = model.compute_stress self._compute_forces = model.compute_forces def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]: - """Forward pass with added noise.""" + """Forward pass with added Weibull noise.""" results = self.model(state, **kwargs) for key in ("forces", "stress"): if key in results: - noise = torch.tensor( - self.rng.normal(size=results[key].shape), - dtype=results[key].dtype, - device=results[key].device, + shape = results[key].shape + # Random direction on the unit sphere (per element for stress) + direction = torch.randn(shape, generator=state.rng) + direction = direction / torch.norm(direction, dim=-1, keepdim=True).clamp( + min=1e-12 + ) + # Weibull magnitude via inverse CDF: scale * (-ln(U))^(1/k) + u = torch.rand(shape[0], generator=state.rng) + magnitudes = self.noise_scale * (-torch.log(u)).pow( + 1.0 / self.concentration ) - results[key] = results[key] + self.noise_scale * noise + if key == "forces": + noise = magnitudes.unsqueeze(-1) * direction + else: + noise = magnitudes.view(-1, 1, 1) * direction + results[key] = results[key] + noise return results @pytest.fixture -def noisy_lj_model(model: UnbatchedLennardJonesModel) -> NoisyModelWrapper: +def noisy_lj_model(model: LennardJonesModel) -> NoisyModelWrapper: """LJ model with noise added to forces/stress.""" return NoisyModelWrapper(model) @@ -123,7 +135,7 @@ def noisy_lj_model(model: UnbatchedLennardJonesModel) -> NoisyModelWrapper: def p6bar_both_constraints() -> tuple[ts.SimState, FixSymmetry, Atoms, ASEFixSymmetry]: """P-6 structure with both TorchSim and ASE constraints (shared setup).""" atoms = make_structure("p6bar") - state = ts.io.atoms_to_state(atoms, CPU, DTYPE) + state = ts.io.atoms_to_state(atoms, DEVICE, DTYPE) ts_constraint = FixSymmetry.from_state(state, symprec=SYMPREC) ase_atoms = atoms.copy() ase_refine_symmetry(ase_atoms, symprec=SYMPREC) @@ -131,9 +143,6 @@ def p6bar_both_constraints() -> tuple[ts.SimState, FixSymmetry, Atoms, ASEFixSym return state, ts_constraint, ase_atoms, ase_constraint -# === Optimization helper === - - def run_optimization_check_symmetry( state: ts.SimState, model: ModelInterface, @@ -168,9 +177,6 @@ def run_optimization_check_symmetry( } -# === Tests: Creation & Basics === - - class TestFixSymmetryCreation: """Tests for FixSymmetry creation and basic behavior.""" @@ -178,14 +184,15 @@ def test_from_state_batched(self) -> None: """Batched state with FCC + diamond gets correct ops, atom counts, and DOF.""" state = ts.io.atoms_to_state( [make_structure("fcc"), make_structure("diamond")], - CPU, + DEVICE, DTYPE, ) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) assert len(constraint.rotations) == 2 - assert constraint.rotations[0].shape[0] == 48 # cubic - assert constraint.symm_maps[0].shape == (48, 1) # Cu: 1 atom - assert constraint.symm_maps[1].shape == (48, 2) # Si: 2 atoms + assert constraint.rotations[0].shape[0] == 48 * REPEATS**3 # cubic + n_ops = 48 * REPEATS**3 + assert constraint.symm_maps[0].shape == (n_ops, REPEATS**3) + assert constraint.symm_maps[1].shape == (n_ops, 2 * REPEATS**3) assert torch.all(constraint.get_removed_dof(state) == 0) def test_p1_identity_is_noop(self) -> None: @@ -196,7 +203,7 @@ def test_p1_identity_is_noop(self) -> None: cell=[[3.0, 0.1, 0.2], [0.15, 3.5, 0.1], [0.2, 0.15, 4.0]], pbc=True, ) - state = ts.io.atoms_to_state(atoms, CPU, DTYPE) + state = ts.io.atoms_to_state(atoms, DEVICE, DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) assert constraint.rotations[0].shape[0] == 1 @@ -217,7 +224,7 @@ def test_from_state_refine_symmetry(self, *, refine: bool) -> None: atoms = make_structure("fcc") rng = np.random.default_rng(42) atoms.positions += rng.standard_normal(atoms.positions.shape) * 0.001 - state = ts.io.atoms_to_state(atoms, CPU, DTYPE) + state = ts.io.atoms_to_state(atoms, DEVICE, DTYPE) orig_pos = state.positions.clone() _ = FixSymmetry.from_state(state, symprec=SYMPREC, refine_symmetry_state=refine) if not refine: @@ -235,7 +242,7 @@ def test_refine_symmetry_produces_correct_spacegroup( expected = SPACEGROUPS[structure_name] rng = np.random.default_rng(42) atoms.positions += rng.standard_normal(atoms.positions.shape) * 0.001 - state = ts.io.atoms_to_state(atoms, CPU, DTYPE) + state = ts.io.atoms_to_state(atoms, DEVICE, DTYPE) refined_cell, refined_pos = refine_symmetry( state.row_vector_cell[0], @@ -250,23 +257,26 @@ def test_refine_symmetry_produces_correct_spacegroup( assert datasets[0].number == expected def test_cubic_forces_vanish(self) -> None: - """Asymmetric force on single cubic atom symmetrizes to zero.""" + """Asymmetric force on cubic atoms symmetrizes to zero.""" state = ts.io.atoms_to_state( [make_structure("fcc"), make_structure("diamond")], - CPU, + DEVICE, DTYPE, ) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) - forces = torch.tensor( - [[1.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.0, 0.5, 0.5]], - dtype=DTYPE, - ) + n_atoms = state.positions.shape[0] + # Random asymmetric forces on all atoms + forces = torch.randn(n_atoms, 3, device=DEVICE, dtype=DTYPE, generator=state.rng) constraint.adjust_forces(state, forces) - assert torch.allclose(forces[0], torch.zeros(3, dtype=DTYPE), atol=1e-10) + # All FCC atoms (first 27 in 3x3x3 supercell) should have zero forces + n_fcc = REPEATS**3 + assert torch.allclose( + forces[:n_fcc], torch.zeros(n_fcc, 3, dtype=DTYPE), atol=1e-10 + ) def test_large_deformation_clamped(self) -> None: """Per-step deformation > 0.25 is clamped rather than rejected.""" - state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) + state = ts.io.atoms_to_state(make_structure("fcc"), DEVICE, DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) orig_cell = state.cell.clone() new_cell = state.cell.clone() * 1.5 # 50% strain, well over 0.25 @@ -282,7 +292,7 @@ def test_large_deformation_clamped(self) -> None: def test_nan_deformation_raises(self) -> None: """NaN in proposed cell raises RuntimeError instead of propagating.""" - state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) + state = ts.io.atoms_to_state(make_structure("fcc"), DEVICE, DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) new_cell = state.cell.clone() new_cell[0, 0, 0] = float("nan") @@ -304,7 +314,7 @@ def test_init_mismatched_lengths_raises(self) -> None: def test_adjust_skipped_when_disabled(self, method: str) -> None: """adjust_positions=False / adjust_cell=False leaves data unchanged.""" flag = method.replace("adjust_", "") # "positions" or "cell" - state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) + state = ts.io.atoms_to_state(make_structure("fcc"), DEVICE, DTYPE) constraint = FixSymmetry.from_state( state, symprec=SYMPREC, @@ -319,9 +329,6 @@ def test_adjust_skipped_when_disabled(self, method: str) -> None: assert torch.equal(data, expected) -# === Tests: Comparison with ASE === - - class TestFixSymmetryComparisonWithASE: """Compare TorchSim FixSymmetry with ASE's implementation on P-6 structure.""" @@ -384,15 +391,12 @@ def test_position_symmetrization_matches_ase( assert np.allclose(new_pos_ts.numpy(), new_pos_ase, atol=1e-10) -# === Tests: Merge, Select, Reindex === - - class TestFixSymmetryMergeSelectReindex: """Tests for reindex/merge API, select, and concatenation.""" def test_reindex_preserves_symmetry_data(self) -> None: """reindex shifts system_idx but preserves rotations and symm_maps.""" - state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) + state = ts.io.atoms_to_state(make_structure("fcc"), DEVICE, DTYPE) orig = FixSymmetry.from_state(state, symprec=SYMPREC) shifted = orig.reindex(atom_offset=100, system_offset=5) assert shifted.system_idx.item() == 5 @@ -401,8 +405,8 @@ def test_reindex_preserves_symmetry_data(self) -> None: def test_merge_two_constraints(self) -> None: """Merge two single-system constraints via reindex + merge.""" - s1 = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) - s2 = ts.io.atoms_to_state(make_structure("diamond"), CPU, DTYPE) + s1 = ts.io.atoms_to_state(make_structure("fcc"), DEVICE, DTYPE) + s2 = ts.io.atoms_to_state(make_structure("diamond"), DEVICE, DTYPE) c1 = FixSymmetry.from_state(s1) c2 = FixSymmetry.from_state(s2).reindex(atom_offset=0, system_offset=1) merged = FixSymmetry.merge([c1, c2]) @@ -417,9 +421,9 @@ def test_merge_multi_system_no_duplicate_indices(self) -> None: make_structure("hcp"), ] atoms_b = [make_structure("bcc"), make_structure("fcc")] - c_a = FixSymmetry.from_state(ts.io.atoms_to_state(atoms_a, CPU, DTYPE)) + c_a = FixSymmetry.from_state(ts.io.atoms_to_state(atoms_a, DEVICE, DTYPE)) c_b = FixSymmetry.from_state( - ts.io.atoms_to_state(atoms_b, CPU, DTYPE), + ts.io.atoms_to_state(atoms_b, DEVICE, DTYPE), ).reindex(atom_offset=0, system_offset=3) merged = FixSymmetry.merge([c_a, c_b]) assert merged.system_idx.tolist() == [0, 1, 2, 3, 4] @@ -428,12 +432,12 @@ def test_system_constraint_merge_multi_system_via_concatenate(self) -> None: """Regression: merging multi-system FixCom via concatenate_states.""" s1 = ts.io.atoms_to_state( [make_structure("fcc"), make_structure("diamond")], - CPU, + DEVICE, DTYPE, ) s2 = ts.io.atoms_to_state( [make_structure("bcc"), make_structure("hcp")], - CPU, + DEVICE, DTYPE, ) s1.constraints = [FixCom(system_idx=torch.tensor([0, 1]))] @@ -445,8 +449,8 @@ def test_system_constraint_merge_multi_system_via_concatenate(self) -> None: def test_concatenate_states_with_fix_symmetry(self) -> None: """FixSymmetry survives concatenate_states and still symmetrizes correctly.""" - s1 = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) - s2 = ts.io.atoms_to_state(make_structure("diamond"), CPU, DTYPE) + s1 = ts.io.atoms_to_state(make_structure("fcc"), DEVICE, DTYPE) + s2 = ts.io.atoms_to_state(make_structure("diamond"), DEVICE, DTYPE) s1.constraints = [FixSymmetry.from_state(s1, symprec=SYMPREC)] s2.constraints = [FixSymmetry.from_state(s2, symprec=SYMPREC)] combined = ts.concatenate_states([s1, s2]) @@ -454,64 +458,70 @@ def test_concatenate_states_with_fix_symmetry(self) -> None: assert isinstance(constraint, FixSymmetry) assert constraint.system_idx.tolist() == [0, 1] assert len(constraint.rotations) == 2 - # Forces on single FCC atom should still vanish - forces = torch.tensor( - [[1.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.0, 0.5, 0.5]], - dtype=DTYPE, + # Forces on FCC atoms should still vanish after symmetrization + n_atoms = combined.positions.shape[0] + n_fcc = REPEATS**3 + forces = torch.randn( + n_atoms, 3, device=DEVICE, dtype=DTYPE, generator=combined.rng ) constraint.adjust_forces(combined, forces) - assert torch.allclose(forces[0], torch.zeros(3, dtype=DTYPE), atol=1e-10) + assert torch.allclose( + forces[:n_fcc], torch.zeros(n_fcc, 3, dtype=DTYPE), atol=1e-10 + ) def test_select_sub_constraint(self) -> None: """Select second system from batched constraint.""" state = ts.io.atoms_to_state( [make_structure("fcc"), make_structure("diamond")], - CPU, + DEVICE, DTYPE, ) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) selected = constraint.select_sub_constraint(torch.tensor([1, 2]), sys_idx=1) assert selected is not None - assert selected.symm_maps[0].shape[1] == 2 + # diamond has 2 atoms per unit cell + assert selected.symm_maps[0].shape[1] == 2 * REPEATS**3 assert selected.system_idx.item() == 0 def test_select_constraint_by_mask(self) -> None: """Select first system via system_mask.""" state = ts.io.atoms_to_state( [make_structure("fcc"), make_structure("diamond")], - CPU, + DEVICE, DTYPE, ) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + n_atoms = state.positions.shape[0] + atom_mask = torch.zeros(n_atoms, dtype=torch.bool) + atom_mask[0] = True # keep at least one atom from first system selected = constraint.select_constraint( - atom_mask=torch.tensor([True, False, False]), + atom_mask=atom_mask, system_mask=torch.tensor([True, False]), ) assert selected is not None assert len(selected.rotations) == 1 - assert selected.rotations[0].shape[0] == 48 + n_ops = 48 * REPEATS**3 # cubic x supercell translations + assert selected.rotations[0].shape[0] == n_ops def test_select_returns_none_for_nonexistent(self) -> None: """select_sub_constraint and select_constraint return None when no match.""" state = ts.io.atoms_to_state( [make_structure("fcc"), make_structure("diamond")], - CPU, + DEVICE, DTYPE, ) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) + n_atoms = state.positions.shape[0] assert constraint.select_sub_constraint(torch.tensor([0]), sys_idx=99) is None assert ( constraint.select_constraint( - atom_mask=torch.zeros(3, dtype=torch.bool), + atom_mask=torch.zeros(n_atoms, dtype=torch.bool), system_mask=torch.zeros(2, dtype=torch.bool), ) is None ) -# === Tests: build_symmetry_map chunked path === - - def test_build_symmetry_map_chunked_matches_vectorized() -> None: """Per-op loop gives same result as vectorized path.""" import torch_sim.symmetrize as sym_mod @@ -521,11 +531,11 @@ def test_build_symmetry_map_chunked_matches_vectorized() -> None: build_symmetry_map, ) - state = ts.io.atoms_to_state(make_structure("p6bar"), CPU, DTYPE) + state = ts.io.atoms_to_state(make_structure("p6bar"), DEVICE, DTYPE) cell = state.row_vector_cell[0] frac = state.positions @ torch.linalg.inv(cell) dataset = _moyo_dataset(cell, frac, state.atomic_numbers) - rotations, translations = _extract_symmetry_ops(dataset, DTYPE, CPU) + rotations, translations = _extract_symmetry_ops(dataset, DTYPE, DEVICE) old_threshold = sym_mod._SYMM_MAP_CHUNK_THRESHOLD # noqa: SLF001 try: @@ -538,9 +548,6 @@ def test_build_symmetry_map_chunked_matches_vectorized() -> None: assert torch.equal(vectorized, chunked) -# === Tests: Optimization === - - class TestFixSymmetryWithOptimization: """Test FixSymmetry with actual optimization routines.""" @@ -560,7 +567,7 @@ def test_distorted_preserves_symmetry( """Compressed structure relaxes while preserving symmetry.""" atoms = make_structure(structure_name) expected = SPACEGROUPS[structure_name] - state = ts.io.atoms_to_state(atoms, CPU, DTYPE) + state = ts.io.atoms_to_state(atoms, DEVICE, DTYPE) constraint = FixSymmetry.from_state( state, symprec=SYMPREC, @@ -581,11 +588,11 @@ def test_distorted_preserves_symmetry( @pytest.mark.parametrize("cell_filter", [ts.CellFilter.unit, ts.CellFilter.frechet]) def test_cell_filter_preserves_symmetry( self, - model: UnbatchedLennardJonesModel, + model: LennardJonesModel, cell_filter: ts.CellFilter, ) -> None: """Cell filters with FixSymmetry preserve symmetry.""" - state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) + state = ts.io.atoms_to_state(make_structure("fcc"), DEVICE, DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) state.constraints = [constraint] initial = get_symmetry_datasets(state, symprec=SYMPREC) @@ -607,7 +614,7 @@ def test_lbfgs_preserves_symmetry( cell_filter: ts.CellFilter, ) -> None: """Regression: LBFGS must use set_constrained_cell for FixSymmetry support.""" - state = ts.io.atoms_to_state(make_structure("bcc"), CPU, DTYPE) + state = ts.io.atoms_to_state(make_structure("bcc"), DEVICE, DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) state.constraints = [constraint] state.cell = state.cell * 0.95 @@ -630,10 +637,12 @@ def test_noisy_model_loses_symmetry_without_constraint( self, noisy_lj_model: NoisyModelWrapper, ) -> None: - """Negative control: without FixSymmetry, noise breaks rotated BCC symmetry.""" - state = ts.io.atoms_to_state(make_structure("bcc_rotated"), CPU, DTYPE) - noisier_model = NoisyModelWrapper(noisy_lj_model.model, noise_scale=5e-4) - result = run_optimization_check_symmetry(state, noisier_model, constraint=None) + """Negative control: without FixSymmetry, Weibull noise breaks BCC symmetry.""" + # Need supercell to reliably break symmetry. Previously test pinned to magic seed. + state = ts.io.atoms_to_state( + make_structure("bcc_rotated", repeats=max(REPEATS, 2)), DEVICE, DTYPE + ) + result = run_optimization_check_symmetry(state, noisy_lj_model, constraint=None) assert result["initial_spacegroups"][0] == 229 assert result["final_spacegroups"][0] != 229 @@ -646,7 +655,7 @@ def test_noisy_model_preserves_symmetry_with_constraint( ) -> None: """With FixSymmetry, noisy forces still preserve symmetry.""" name = "bcc_rotated" if rotated else "bcc" - state = ts.io.atoms_to_state(make_structure(name), CPU, DTYPE) + state = ts.io.atoms_to_state(make_structure(name), DEVICE, DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) result = run_optimization_check_symmetry( state, @@ -665,7 +674,7 @@ def test_cumulative_strain_clamp_direct(self) -> None: 1. The cell doesn't drift beyond the strain envelope 2. Symmetry is preserved after many small steps """ - state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE) + state = ts.io.atoms_to_state(make_structure("fcc", repeats=1), DEVICE, DTYPE) constraint = FixSymmetry.from_state(state, symprec=SYMPREC) constraint.max_cumulative_strain = 0.15 assert constraint.reference_cells is not None diff --git a/tests/test_integrators.py b/tests/test_integrators.py index b3d8c336..dd9b49d0 100644 --- a/tests/test_integrators.py +++ b/tests/test_integrators.py @@ -339,7 +339,7 @@ def test_nvt_nose_hoover(ar_double_sim_state: ts.SimState, lj_model: LennardJone temperatures_list = [t.tolist() for t in temperatures_tensor.T] assert torch.allclose( temperatures_tensor[-1], - torch.tensor([299.9910, 299.6800], dtype=dtype), + torch.tensor([300.0096, 299.7024], dtype=dtype), ) energies_tensor = torch.stack(energies) @@ -728,7 +728,7 @@ def test_npt_nose_hoover(ar_double_sim_state: ts.SimState, lj_model: LennardJone temperatures_list = [t.tolist() for t in temperatures_tensor.T] assert torch.allclose( temperatures_tensor[-1], - torch.tensor([297.8602, 297.5306], dtype=dtype), + torch.tensor([298.2752, 297.9444], dtype=dtype), ) energies_tensor = torch.stack(energies) @@ -1041,11 +1041,6 @@ def test_compute_cell_force_atoms_per_system(): assert abs(force_ratio - 8.0) / 8.0 < 0.1 -# --------------------------------------------------------------------------- -# Reproducibility tests -# --------------------------------------------------------------------------- - - def test_nvt_langevin_reproducibility( ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel ): diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index cae16fb3..e51ead76 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -1096,25 +1096,26 @@ def test_optimizer_batch_consistency( lj_model: ModelInterface, ) -> None: """Test batched optimizer is consistent with individual optimizations.""" - generator = torch.Generator(device=ar_supercell_sim_state.device) - # Create two distinct initial states by cloning and perturbing state1_orig = ar_supercell_sim_state.clone() + state1_orig.rng = 0 # Apply identical perturbations to state1_orig # for state_item in [state1_orig, state2_orig]: # Old loop structure - generator.manual_seed(43) # Reset seed for positions state1_orig.positions += ( torch.randn( - state1_orig.positions.shape, device=state1_orig.device, generator=generator + state1_orig.positions.shape, + device=state1_orig.device, + generator=state1_orig.rng, ) * 0.1 ) if filter_func: - generator.manual_seed(44) # Reset seed for cell state1_orig.cell += ( torch.randn( - state1_orig.cell.shape, device=state1_orig.device, generator=generator + state1_orig.cell.shape, + device=state1_orig.device, + generator=state1_orig.rng, ) * 0.01 ) @@ -1177,7 +1178,7 @@ def energy_converged(e_current: torch.Tensor, e_prev: torch.Tensor) -> bool: # Converge when all batch energies have converged while not torch.allclose(e_current_batch, e_prev_batch, atol=1e-6): e_prev_batch = e_current_batch.clone() - batch_opt_state = step_fn_batch(model=lj_model, state=batch_opt_state) + batch_opt_state = step_fn_batch(model=lj_model, state=batch_opt_state, dt_max=0.3) e_current_batch = batch_opt_state.energy.clone() steps_batch += 1 if steps_batch > 1000: diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py index 0660d757..91323880 100644 --- a/torch_sim/constraints.py +++ b/torch_sim/constraints.py @@ -859,8 +859,6 @@ def from_state( reference_cells=reference_cells, ) - # === Symmetrization hooks === - def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None: """Symmetrize forces according to crystal symmetry.""" self._symmetrize_rank1(state, forces) @@ -959,8 +957,6 @@ def _symmetrize_rank1(self, state: SimState, vectors: torch.Tensor) -> None: self.symm_maps[ci], ) - # === Constraint interface === - def get_removed_dof(self, state: SimState) -> torch.Tensor: """Returns zero - constrains direction, not DOF count.""" return torch.zeros(state.n_systems, dtype=torch.long, device=state.device) diff --git a/torch_sim/elastic.py b/torch_sim/elastic.py index 94b95b79..944cdb81 100644 --- a/torch_sim/elastic.py +++ b/torch_sim/elastic.py @@ -1,4 +1,4 @@ -# ruff: noqa: RUF002, RUF003, PLC2401 +# ruff: noqa: RUF003, PLC2401 """Calculation of elastic properties of crystals. Primary Sources and References for Crystal Elasticity. diff --git a/torch_sim/math.py b/torch_sim/math.py index 5f7604f6..64a2edd7 100644 --- a/torch_sim/math.py +++ b/torch_sim/math.py @@ -1,6 +1,6 @@ """Mathematical operations and utilities. Adapted from https://github.com/abhijeetgangan/torch_matfunc.""" -# ruff: noqa: FBT001, FBT002, RUF002 +# ruff: noqa: FBT001, FBT002 from typing import Final diff --git a/torch_sim/models/lennard_jones.py b/torch_sim/models/lennard_jones.py index 518355cc..c4b67c24 100644 --- a/torch_sim/models/lennard_jones.py +++ b/torch_sim/models/lennard_jones.py @@ -1,158 +1,92 @@ -"""Classical pairwise interatomic potential model. +"""Lennard-Jones 12-6 potential model. -This module implements the Lennard-Jones potential for molecular dynamics simulations. -It provides efficient calculation of energies, forces, and stresses based on the -classic 12-6 potential function. The implementation supports both full pairwise -calculations and neighbor list-based optimizations. +Thin wrapper around :class:`~torch_sim.models.pair_potential.PairPotentialModel` with +the :func:`lennard_jones_pair` energy function baked in. Example:: - # Create a Lennard-Jones model with default parameters - model = LennardJonesModel(device=torch.device("cuda")) - - # Create a model with custom parameters - model = LennardJonesModel( - sigma=3.405, # Angstroms - epsilon=0.01032, # eV - cutoff=10.0, # Angstroms - compute_stress=True, - ) - - # Calculate properties for a simulation state - output = model(sim_state) - energy = output["energy"] - forces = output["forces"] + model = LennardJonesModel(sigma=3.405, epsilon=0.0104, cutoff=8.5) + results = model(sim_state) """ -from collections.abc import Callable +from __future__ import annotations + +import functools +from collections.abc import Callable # noqa: TC003 import torch -import torch_sim as ts -from torch_sim import transforms -from torch_sim.models.interface import ModelInterface +from torch_sim.models.pair_potential import PairPotentialModel from torch_sim.neighbors import torchsim_nl -DEFAULT_SIGMA = 1.0 -DEFAULT_EPSILON = 1.0 - - def lennard_jones_pair( dr: torch.Tensor, - sigma: float | torch.Tensor = DEFAULT_SIGMA, - epsilon: float | torch.Tensor = DEFAULT_EPSILON, + zi: torch.Tensor, # noqa: ARG001 + zj: torch.Tensor, # noqa: ARG001 + sigma: torch.Tensor | float = 1.0, + epsilon: torch.Tensor | float = 1.0, ) -> torch.Tensor: - """Calculate pairwise Lennard-Jones interaction energies between particles. - - Implements the standard 12-6 Lennard-Jones potential that combines short-range - repulsion with longer-range attraction. The potential has a minimum at r=sigma. + """Lennard-Jones 12-6 pair energy. - The functional form is: - V(r) = 4*epsilon*[(sigma/r)^12 - (sigma/r)^6] + V(r) = 4ε[(σ/r)¹² - (σ/r)⁶] Args: - dr: Pairwise distances between particles. Shape: [n, m]. - sigma: Distance at which potential reaches its minimum. Either a scalar float - or tensor of shape [n, m] for particle-specific interaction distances. - epsilon: Depth of the potential well (energy scale). Either a scalar float - or tensor of shape [n, m] for pair-specific interaction strengths. + dr: Pairwise distances, shape [n_pairs]. + zi: Atomic numbers of first atoms (unused, for interface compatibility). + zj: Atomic numbers of second atoms (unused, for interface compatibility). + sigma: Length scale. Defaults to 1.0. + epsilon: Energy scale. Defaults to 1.0. Returns: - torch.Tensor: Pairwise Lennard-Jones interaction energies between particles. - Shape: [n, m]. Each element [i,j] represents the interaction energy between - particles i and j. + Pair energies, shape [n_pairs]. """ - # Calculate inverse dr and its powers - idr = sigma / dr - idr2 = idr * idr - idr6 = idr2 * idr2 * idr2 + idr6 = (sigma / dr).pow(6) idr12 = idr6 * idr6 - - # Calculate potential energy energy = 4.0 * epsilon * (idr12 - idr6) - - # Handle potential numerical instabilities and infinities return torch.where(dr > 0, energy, torch.zeros_like(energy)) - # return torch.nan_to_num(energy, nan=0.0, posinf=0.0, neginf=0.0) def lennard_jones_pair_force( dr: torch.Tensor, - sigma: float | torch.Tensor = DEFAULT_SIGMA, - epsilon: float | torch.Tensor = DEFAULT_EPSILON, + sigma: torch.Tensor | float = 1.0, + epsilon: torch.Tensor | float = 1.0, ) -> torch.Tensor: - """Calculate pairwise Lennard-Jones forces between particles. - - Implements the force derived from the 12-6 Lennard-Jones potential. The force - is repulsive at short range and attractive at long range, with a zero-crossing - at r=sigma. - - The functional form is: - F(r) = 24*epsilon/r * [(2*sigma^12/r^12) - (sigma^6/r^6)] + """Lennard-Jones 12-6 pair force (negative gradient of energy). - This is the negative gradient of the Lennard-Jones potential energy. + F(r) = 24ε/r [2(σ/r)¹² - (σ/r)⁶] Args: - dr: Pairwise distances between particles. Shape: [n, m]. - sigma: Distance at which force changes from repulsive to attractive. - Either a scalar float or tensor of shape [n, m] for particle-specific - interaction distances. - epsilon: Energy scale of the interaction. Either a scalar float or tensor - of shape [n, m] for pair-specific interaction strengths. + dr: Pairwise distances, shape [n_pairs]. + sigma: Length scale. Defaults to 1.0. + epsilon: Energy scale. Defaults to 1.0. Returns: - torch.Tensor: Pairwise Lennard-Jones forces between particles. Shape: [n, m]. - Each element [i,j] represents the force magnitude between particles i and j. - Positive values indicate repulsion, negative values indicate attraction. + Pair force magnitudes (positive = repulsive), shape [n_pairs]. """ - # Calculate inverse dr and its powers idr = sigma / dr - idr2 = idr * idr - idr6 = idr2 * idr2 * idr2 + idr6 = idr.pow(6) idr12 = idr6 * idr6 - - # Calculate force (negative gradient of potential) - # F = -24*epsilon/r * ((sigma/r)^6 - 2*(sigma/r)^12) force = 24.0 * epsilon / dr * (2.0 * idr12 - idr6) - - # Handle potential numerical instabilities and infinities return torch.where(dr > 0, force, torch.zeros_like(force)) -class UnbatchedLennardJonesModel(ModelInterface): - """Unbatched Lennard-Jones model. - - Implements the Lennard-Jones 12-6 potential for molecular dynamics simulations. - This implementation loops over systems in batched inputs and is intended for - testing or baseline comparisons with the default batched model. +class LennardJonesModel(PairPotentialModel): + """Lennard-Jones 12-6 pair potential model. - Attributes: - sigma (torch.Tensor): Length parameter controlling particle size/repulsion - distance. - epsilon (torch.Tensor): Energy parameter controlling interaction strength. - cutoff (torch.Tensor): Distance cutoff for truncating potential calculation. - device (torch.device): Device where calculations are performed. - dtype (torch.dtype): Data type used for calculations. - compute_forces (bool): Whether to compute atomic forces. - compute_stress (bool): Whether to compute stress tensor. - per_atom_energies (bool): Whether to compute per-atom energy decomposition. - per_atom_stresses (bool): Whether to compute per-atom stress decomposition. - neighbor_list_fn (Callable): Function used to construct neighbor lists. + Convenience subclass that fixes the pair function to :func:`lj_pair` so the + caller only needs to supply ``sigma`` and ``epsilon``. Example:: - # Basic usage with default parameters - lj_model = UnbatchedLennardJonesModel(device=torch.device("cuda")) - results = lj_model(sim_state) - - # Custom parameterization for Argon - ar_model = UnbatchedLennardJonesModel( - sigma=3.405, # Å - epsilon=0.0104, # eV - cutoff=8.5, # Å + model = LennardJonesModel( + sigma=3.405, + epsilon=0.0104, + cutoff=2.5 * 3.405, + compute_forces=True, compute_stress=True, ) + results = model(sim_state) """ def __init__( @@ -160,360 +94,48 @@ def __init__( sigma: float = 1.0, epsilon: float = 1.0, device: torch.device | None = None, - dtype: torch.dtype = torch.float32, - *, # Force keyword-only arguments + dtype: torch.dtype = torch.float64, + *, compute_forces: bool = True, compute_stress: bool = False, per_atom_energies: bool = False, per_atom_stresses: bool = False, neighbor_list_fn: Callable = torchsim_nl, cutoff: float | None = None, + retain_graph: bool = False, ) -> None: - """Initialize the Lennard-Jones potential calculator. - - Creates a model with specified interaction parameters and computational flags. - The model can be configured to compute different properties (forces, stresses) - and use different optimization strategies. - - Args: - sigma (float): Length parameter of the Lennard-Jones potential in distance - units. Controls the size of particles. Defaults to 1.0. - epsilon (float): Energy parameter of the Lennard-Jones potential in energy - units. Controls the strength of the interaction. Defaults to 1.0. - device (torch.device | None): Device to run computations on. If None, uses - CPU. Defaults to None. - dtype (torch.dtype): Data type for calculations. Defaults to torch.float32. - compute_forces (bool): Whether to compute forces. Defaults to True. - compute_stress (bool): Whether to compute stress tensor. Defaults to False. - per_atom_energies (bool): Whether to compute per-atom energy decomposition. - Defaults to False. - per_atom_stresses (bool): Whether to compute per-atom stress decomposition. - Defaults to False. - neighbor_list_fn (Callable): Batched neighbor-list function to use when - constructing interactions. Defaults to torchsim_nl. - cutoff (float | None): Cutoff distance for interactions in distance units. - If None, uses 2.5*sigma. Defaults to None. - - Example:: - - # Model with custom parameters - model = UnbatchedLennardJonesModel( - sigma=3.405, - epsilon=0.01032, - device=torch.device("cuda"), - dtype=torch.float64, - compute_stress=True, - per_atom_energies=True, - cutoff=10.0, - ) - """ - super().__init__() - self._device = device or torch.device("cpu") - self._dtype = dtype - self._compute_forces = compute_forces - self._compute_stress = compute_stress - self.per_atom_energies = per_atom_energies - self.per_atom_stresses = per_atom_stresses - self.neighbor_list_fn = neighbor_list_fn - - # Convert parameters to tensors - self.sigma = torch.tensor(sigma, dtype=dtype, device=self.device) - self.cutoff = torch.tensor(cutoff or 2.5 * sigma, dtype=dtype, device=self.device) - self.epsilon = torch.tensor(epsilon, dtype=dtype, device=self.device) - - def unbatched_forward( - self, - state: ts.SimState, - ) -> dict[str, torch.Tensor]: - """Compute Lennard-Jones properties for a single unbatched system. - - Internal implementation that processes a single, non-batched simulation state. - This method handles the core computations of pair interactions, neighbor lists, - and property calculations. - - Args: - state (SimState): Single, non-batched simulation state containing atomic - positions, cell vectors, and other system information. - - Returns: - dict[str, torch.Tensor]: Computed properties: - - "energy": Total potential energy (scalar) - - "forces": Atomic forces with shape [n_atoms, 3] (if - compute_forces=True) - - "stress": Stress tensor with shape [3, 3] (if compute_stress=True) - - "energies": Per-atom energies with shape [n_atoms] (if - per_atom_energies=True) - - "stresses": Per-atom stresses with shape [n_atoms, 3, 3] (if - per_atom_stresses=True) - - Notes: - Neighbor lists are always used to construct interacting pairs. - """ - positions = state.positions - cell = state.row_vector_cell - cell = cell.squeeze() - - # Ensure system_idx exists (create if None for single system) - system_idx = ( - state.system_idx - if state.system_idx is not None - else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device) - ) - - # Wrap positions into the unit cell - wrapped_positions = ( - ts.transforms.pbc_wrap_batched(positions, state.cell, system_idx, state.pbc) - if state.pbc.any() - else positions - ) - - mapping, _, shifts_idx = self.neighbor_list_fn( - positions=wrapped_positions, - cell=cell, - pbc=state.pbc, - cutoff=self.cutoff, - system_idx=system_idx, - ) - # Pass shifts_idx directly - get_pair_displacements will convert them - dr_vec, distances = transforms.get_pair_displacements( - positions=wrapped_positions, - cell=cell, - pbc=state.pbc, - pairs=(mapping[0], mapping[1]), - shifts=shifts_idx, - ) - - # Calculate pair energies and apply cutoff - pair_energies = lennard_jones_pair( - distances, sigma=self.sigma, epsilon=self.epsilon - ) - # Zero out energies beyond cutoff - mask = distances < self.cutoff - pair_energies = torch.where(mask, pair_energies, torch.zeros_like(pair_energies)) - - # Initialize results with total energy (sum/2 to avoid double counting) - results = {"energy": 0.5 * pair_energies.sum()} - - if self.per_atom_energies: - atom_energies = torch.zeros( - positions.shape[0], dtype=self.dtype, device=self.device - ) - # Each atom gets half of the pair energy - atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies) - atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies) - results["energies"] = atom_energies - - if self.compute_forces or self.compute_stress: - # Calculate forces and apply cutoff - pair_forces = lennard_jones_pair_force( - distances, sigma=self.sigma, epsilon=self.epsilon - ) - pair_forces = torch.where(mask, pair_forces, torch.zeros_like(pair_forces)) - - # Project forces along displacement vectors - force_vectors = (pair_forces / distances)[:, None] * dr_vec - - if self.compute_forces: - # Initialize forces tensor - forces = torch.zeros_like(positions) - # Add force contributions (f_ij on i, -f_ij on j) - forces.index_add_(0, mapping[0], -force_vectors) - forces.index_add_(0, mapping[1], force_vectors) - results["forces"] = forces - - if self.compute_stress and cell is not None: - # Compute stress tensor - stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors) - volume = torch.abs(torch.linalg.det(cell)) - - results["stress"] = -stress_per_pair.sum(dim=0) / volume - - if self.per_atom_stresses: - atom_stresses = torch.zeros( - (state.positions.shape[0], 3, 3), - dtype=self.dtype, - device=self.device, - ) - atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair) - atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair) - results["stresses"] = atom_stresses / volume - - return results - - def forward(self, state: ts.SimState, **_kwargs: object) -> dict[str, torch.Tensor]: - """Compute Lennard-Jones energies, forces, and stresses for a system. - - Main entry point for Lennard-Jones calculations that handles batched states by - dispatching each system to the unbatched implementation and combining results. + """Initialize the Lennard-Jones model. Args: - state (SimState): Input state containing atomic positions, cell vectors, - and other system information. - **_kwargs: Unused; accepted for interface compatibility. - - Returns: - dict[str, torch.Tensor]: Computed properties: - - "energy": Potential energy with shape [n_systems] - - "forces": Atomic forces with shape [n_atoms, 3] (if - compute_forces=True) - - "stress": Stress tensor with shape [n_systems, 3, 3] (if - compute_stress=True) - - "energies": Per-atom energies with shape [n_atoms] (if - per_atom_energies=True) - - "stresses": Per-atom stresses with shape [n_atoms, 3, 3] (if - per_atom_stresses=True) - - Raises: - ValueError: If system cannot be inferred for multi-cell systems. - - Example:: - - # Compute properties for a simulation state - model = UnbatchedLennardJonesModel(compute_stress=True) - results = model(sim_state) - - energy = results["energy"] # Shape: [n_systems] - forces = results["forces"] # Shape: [n_atoms, 3] - stress = results["stress"] # Shape: [n_systems, 3, 3] - energies = results["energies"] # Shape: [n_atoms] - stresses = results["stresses"] # Shape: [n_atoms, 3, 3] + sigma: Length scale parameter. Defaults to 1.0. + epsilon: Energy scale parameter. Defaults to 1.0. + device: Device for computations. Defaults to CPU. + dtype: Floating-point dtype. Defaults to torch.float32. + compute_forces: Whether to compute atomic forces. Defaults to True. + compute_stress: Whether to compute the stress tensor. Defaults to False. + per_atom_energies: Whether to return per-atom energies. Defaults to False. + per_atom_stresses: Whether to return per-atom stresses. Defaults to False. + neighbor_list_fn: Neighbor-list constructor. Defaults to torchsim_nl. + cutoff: Interaction cutoff. Defaults to 2.5 * sigma. + retain_graph: Keep computation graph for differentiable simulation. """ - sim_state = state - - if sim_state.system_idx is None and sim_state.cell.shape[0] > 1: - raise ValueError("System can only be inferred for batch size 1.") - - outputs = [ - self.unbatched_forward(sim_state[idx]) for idx in range(sim_state.n_systems) - ] - properties = outputs[0] - - # we always return tensors - # per atom properties are returned as (atoms, ...) tensors - # global properties are returned as shape (..., n) tensors - results: dict[str, torch.Tensor] = {} - for key in ("stress", "energy"): - if key in properties: - results[key] = torch.stack([out[key] for out in outputs]) - for key in ("forces", "energies", "stresses"): - if key in properties: - results[key] = torch.cat([out[key] for out in outputs], dim=0) - - return results - - -class LennardJonesModel(UnbatchedLennardJonesModel): - """Default vectorized Lennard-Jones model for batched systems. - - This class computes Lennard-Jones energies, forces, and stresses for all systems in - a batch in one pass, avoiding Python loops over systems in the model forward path. - Use this class for production runs. - """ - - def forward( # noqa: PLR0915 - self, state: ts.SimState, **_kwargs: object - ) -> dict[str, torch.Tensor]: - """Compute Lennard-Jones properties with batched tensor operations.""" - sim_state = state - - if sim_state.system_idx is None and sim_state.cell.shape[0] > 1: - raise ValueError("System can only be inferred for batch size 1.") - - positions = sim_state.positions - row_cell = sim_state.row_vector_cell - pbc = sim_state.pbc - - system_idx = ( - sim_state.system_idx - if sim_state.system_idx is not None - else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device) - ) - - wrapped_positions = ( - ts.transforms.pbc_wrap_batched(positions, sim_state.cell, system_idx, pbc) - if pbc.any() - else positions - ) - - if pbc.ndim == 1: - pbc_batched = pbc.unsqueeze(0).expand(sim_state.n_systems, -1) - else: - pbc_batched = pbc - - mapping, system_mapping, shifts_idx = self.neighbor_list_fn( - positions=wrapped_positions, - cell=row_cell, - pbc=pbc_batched, - cutoff=self.cutoff, - system_idx=system_idx, + self.sigma = sigma + self.epsilon = epsilon + pair_fn = functools.partial(lennard_jones_pair, sigma=sigma, epsilon=epsilon) + super().__init__( + pair_fn=pair_fn, + cutoff=cutoff if cutoff is not None else 2.5 * sigma, + device=device, + dtype=dtype, + compute_forces=compute_forces, + compute_stress=compute_stress, + per_atom_energies=per_atom_energies, + per_atom_stresses=per_atom_stresses, + neighbor_list_fn=neighbor_list_fn, + reduce_to_half_list=True, + retain_graph=retain_graph, ) - cell_shifts = transforms.compute_cell_shifts(row_cell, shifts_idx, system_mapping) - dr_vec = ( - wrapped_positions[mapping[1]] - wrapped_positions[mapping[0]] + cell_shifts - ) - distances = dr_vec.norm(dim=1) - - cutoff_mask = distances < self.cutoff - pair_energies = lennard_jones_pair( - distances, sigma=self.sigma, epsilon=self.epsilon - ) - pair_energies = torch.where( - cutoff_mask, pair_energies, torch.zeros_like(pair_energies) - ) - - n_systems = sim_state.n_systems - results: dict[str, torch.Tensor] = {} - energy = torch.zeros(n_systems, dtype=self.dtype, device=self.device) - energy.index_add_(0, system_mapping, 0.5 * pair_energies) - results["energy"] = energy - - if self.per_atom_energies: - atom_energies = torch.zeros( - positions.shape[0], dtype=self.dtype, device=self.device - ) - atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies) - atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies) - results["energies"] = atom_energies - - if self.compute_forces or self.compute_stress: - pair_forces = lennard_jones_pair_force( - distances, sigma=self.sigma, epsilon=self.epsilon - ) - pair_forces = torch.where( - cutoff_mask, pair_forces, torch.zeros_like(pair_forces) - ) - safe_distances = torch.where( - distances > 0, distances, torch.ones_like(distances) - ) - force_vectors = (pair_forces / safe_distances)[:, None] * dr_vec - - if self.compute_forces: - forces = torch.zeros_like(positions) - forces.index_add_(0, mapping[0], -force_vectors) - forces.index_add_(0, mapping[1], force_vectors) - results["forces"] = forces - - if self.compute_stress: - stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors) - volumes = torch.abs(torch.linalg.det(row_cell)) - stress = torch.zeros( - (n_systems, 3, 3), - dtype=self.dtype, - device=self.device, - ) - stress.index_add_(0, system_mapping, -stress_per_pair) - results["stress"] = stress / volumes[:, None, None] - - if self.per_atom_stresses: - atom_stresses = torch.zeros( - (positions.shape[0], 3, 3), - dtype=self.dtype, - device=self.device, - ) - atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair) - atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair) - atom_volumes = volumes[system_idx] - results["stresses"] = atom_stresses / atom_volumes[:, None, None] - return results +# Keep old name as alias for backward compatibility +UnbatchedLennardJonesModel = LennardJonesModel diff --git a/torch_sim/models/morse.py b/torch_sim/models/morse.py index 2c881808..1e786c2a 100644 --- a/torch_sim/models/morse.py +++ b/torch_sim/models/morse.py @@ -1,395 +1,144 @@ -"""Anharmonic interatomic potential for molecular dynamics. +"""Morse potential model. -This module implements the Morse potential for molecular dynamics simulations. -The Morse potential provides a more realistic description of anharmonic bond -behavior than simple harmonic potentials, capturing bond breaking and formation. -It includes both energy and force calculations with support for neighbor lists. +Thin wrapper around :class:`~torch_sim.models.pair_potential.PairPotentialModel` with +the :func:`morse_pair` energy function baked in. Example:: - # Create a Morse model with default parameters - model = MorseModel(device=torch.device("cuda")) - - # Calculate properties for a simulation state - output = model(sim_state) - energy = output["energy"] - forces = output["forces"] + model = MorseModel(sigma=2.55, epsilon=0.436, alpha=1.359, cutoff=6.0) + results = model(sim_state) +""" -Notes: - The Morse potential follows the form: - V(r) = D_e * (1 - exp(-a(r-r_e)))^2 +from __future__ import annotations - Where: - - D_e (epsilon) is the well depth (dissociation energy) - - r_e (sigma) is the equilibrium bond distance - - a (alpha) controls the width of the potential well -""" +import functools +from collections.abc import Callable # noqa: TC003 import torch -import torch_sim as ts -from torch_sim import transforms -from torch_sim.models.interface import ModelInterface +from torch_sim.models.pair_potential import PairPotentialModel from torch_sim.neighbors import torchsim_nl -DEFAULT_SIGMA = 1.0 -DEFAULT_EPSILON = 5.0 -DEFAULT_ALPHA = 5.0 - - def morse_pair( dr: torch.Tensor, - sigma: float | torch.Tensor = DEFAULT_SIGMA, - epsilon: float | torch.Tensor = DEFAULT_EPSILON, - alpha: float | torch.Tensor = DEFAULT_ALPHA, + zi: torch.Tensor, # noqa: ARG001 + zj: torch.Tensor, # noqa: ARG001 + sigma: torch.Tensor | float = 1.0, + epsilon: torch.Tensor | float = 5.0, + alpha: torch.Tensor | float = 5.0, ) -> torch.Tensor: - """Calculate pairwise Morse potential energies between particles. - - Implements the Morse potential that combines short-range repulsion with - longer-range attraction. The potential has a minimum at r=sigma and approaches - -epsilon as r→∞. + """Morse pair energy. - The functional form is: - V(r) = epsilon * (1 - exp(-alpha*(r-sigma)))^2 - epsilon + V(r) = ε(1 - exp(-α(r - σ)))² - ε Args: - dr: Pairwise distances between particles. Shape: [n, m]. - sigma: Distance at which potential reaches its minimum. Either a scalar float - or tensor of shape [n, m] for particle-specific equilibrium distances. - epsilon: Depth of the potential well (energy scale). Either a scalar float - or tensor of shape [n, m] for pair-specific interaction strengths. - alpha: Controls the width of the potential well. Larger values give a narrower - well. Either a scalar float or tensor of shape [n, m]. + dr: Pairwise distances, shape [n_pairs]. + zi: Atomic numbers of first atoms (unused). + zj: Atomic numbers of second atoms (unused). + sigma: Equilibrium bond distance. Defaults to 1.0. + epsilon: Well depth / dissociation energy. Defaults to 5.0. + alpha: Width parameter. Defaults to 5.0. Returns: - torch.Tensor: Pairwise Morse interaction energies between particles. - Shape: [n, m]. Each element [i,j] represents the interaction energy between - particles i and j. + Pair energies, shape [n_pairs]. """ - # Calculate potential energy energy = epsilon * (1.0 - torch.exp(-alpha * (dr - sigma))).pow(2) - epsilon - - # Handle potential numerical instabilities return torch.where(dr > 0, energy, torch.zeros_like(energy)) def morse_pair_force( dr: torch.Tensor, - sigma: float | torch.Tensor = DEFAULT_SIGMA, - epsilon: float | torch.Tensor = DEFAULT_EPSILON, - alpha: float | torch.Tensor = DEFAULT_ALPHA, + zi: torch.Tensor, # noqa: ARG001 + zj: torch.Tensor, # noqa: ARG001 + sigma: torch.Tensor | float = 1.0, + epsilon: torch.Tensor | float = 5.0, + alpha: torch.Tensor | float = 5.0, ) -> torch.Tensor: - """Calculate pairwise Morse forces between particles. - - Implements the force derived from the Morse potential. The force changes - from repulsive to attractive at r=sigma. - - The functional form is: - F(r) = 2*alpha*epsilon * exp(-alpha*(r-sigma)) * (1 - exp(-alpha*(r-sigma))) + """Morse pair force (negative gradient of energy). - This is the negative gradient of the Morse potential energy. + F(r) = -2αε exp(-α(r-σ)) (1 - exp(-α(r-σ))) Args: - dr: Pairwise distances between particles. Shape: [n, m]. - sigma: Distance at which force changes from repulsive to attractive. - Either a scalar float or tensor of shape [n, m]. - epsilon: Energy scale of the interaction. Either a scalar float or tensor - of shape [n, m]. - alpha: Controls the force range and stiffness. Either a scalar float or - tensor of shape [n, m]. + dr: Pairwise distances. + zi: Atomic numbers of first atoms (unused). + zj: Atomic numbers of second atoms (unused). + sigma: Equilibrium distance. Defaults to 1.0. + epsilon: Well depth. Defaults to 5.0. + alpha: Width parameter. Defaults to 5.0. Returns: - torch.Tensor: Pairwise Morse forces between particles. Shape: [n, m]. - Positive values indicate repulsion, negative values indicate attraction. + Pair force magnitudes. """ exp_term = torch.exp(-alpha * (dr - sigma)) force = -2.0 * alpha * epsilon * exp_term * (1.0 - exp_term) - - # Handle potential numerical instabilities return torch.where(dr > 0, force, torch.zeros_like(force)) -class MorseModel(ModelInterface): - """Morse potential energy and force calculator. - - Implements the Morse potential for molecular dynamics simulations. This model - is particularly useful for modeling covalent bonds as it can accurately describe - bond stretching, breaking, and anharmonic behavior. Unlike the Lennard-Jones - potential, Morse is often better for cases where accurate dissociation energy - and bond dynamics are important. +class MorseModel(PairPotentialModel): + """Morse pair potential model. - Attributes: - sigma (torch.Tensor): Equilibrium bond length (r_e) in distance units. - epsilon (torch.Tensor): Dissociation energy (D_e) in energy units. - alpha (torch.Tensor): Parameter controlling the width/steepness of the potential. - cutoff (torch.Tensor): Distance cutoff for truncating potential calculation. - device (torch.device): Device where calculations are performed. - dtype (torch.dtype): Data type used for calculations. - compute_forces (bool): Whether to compute atomic forces. - compute_stress (bool): Whether to compute stress tensor. - per_atom_energies (bool): Whether to compute per-atom energy decomposition. - per_atom_stresses (bool): Whether to compute per-atom stress decomposition. - use_neighbor_list (bool): Whether to use neighbor list optimization. + Convenience subclass that fixes the pair function to :func:`morse_pair` so the + caller only needs to supply ``sigma``, ``epsilon``, and ``alpha``. - Examples: - ```py - # Basic usage with default parameters - morse_model = MorseModel(device=torch.device("cuda")) - results = morse_model(sim_state) + Example:: - # Model parameterized for O-H bonds in water, atomic units - oh_model = MorseModel( - sigma=0.96, - epsilon=4.52, - alpha=2.0, + model = MorseModel( + sigma=2.55, + epsilon=0.436, + alpha=1.359, + cutoff=6.0, compute_forces=True, - compute_stress=True, ) - ``` + results = model(sim_state) """ def __init__( self, - sigma: float | torch.Tensor = 1.0, - epsilon: float | torch.Tensor = 5.0, - alpha: float | torch.Tensor = 5.0, + sigma: float = 1.0, + epsilon: float = 5.0, + alpha: float = 5.0, device: torch.device | None = None, - dtype: torch.dtype = torch.float32, - *, # Force keyword-only arguments - compute_forces: bool = False, + dtype: torch.dtype = torch.float64, + *, + compute_forces: bool = True, compute_stress: bool = False, per_atom_energies: bool = False, per_atom_stresses: bool = False, - use_neighbor_list: bool = True, - cutoff: float | torch.Tensor | None = None, + neighbor_list_fn: Callable = torchsim_nl, + cutoff: float | None = None, + retain_graph: bool = False, ) -> None: - """Initialize the Morse potential calculator. - - Creates a model with specified interaction parameters and computational flags. - The Morse potential is defined by three key parameters: sigma (equilibrium - distance), epsilon (dissociation energy), and alpha (width control). + """Initialize the Morse potential model. Args: - sigma (float): Equilibrium bond distance (r_e) in distance units. - Defaults to 1.0. - epsilon (float): Dissociation energy (D_e) in energy units. - Defaults to 5.0. - alpha (float): Controls the width/steepness of the potential well. - Larger values create a narrower well. Defaults to 5.0. - device (torch.device | None): Device to run computations on. If None, uses - CPU. Defaults to None. - dtype (torch.dtype): Data type for calculations. Defaults to torch.float32. - compute_forces (bool): Whether to compute forces. Defaults to False. - compute_stress (bool): Whether to compute stress tensor. Defaults to False. - per_atom_energies (bool): Whether to compute per-atom energy decomposition. - Defaults to False. - per_atom_stresses (bool): Whether to compute per-atom stress decomposition. - Defaults to False. - use_neighbor_list (bool): Whether to use a neighbor list for optimization. - Significantly faster for large systems. Defaults to True. - cutoff (float | None): Cutoff distance for interactions in distance units. - If None, uses 2.5*sigma. Defaults to None. - - Examples: - ```py - # Basic model with default parameters - model = MorseModel() - - # Model for diatomic hydrogen - model = MorseModel( - sigma=0.74, # Å - epsilon=4.75, # eV - alpha=1.94, # Steepness parameter - compute_forces=True, - ) - ``` - - Notes: - The alpha parameter can be related to the harmonic force constant k and - dissociation energy D_e by: alpha = sqrt(k/(2*D_e)) - """ - super().__init__() - self._device = device or torch.device("cpu") - self._dtype = dtype - self._compute_forces = compute_forces - self._compute_stress = compute_stress - self._per_atom_energies = per_atom_energies - self._per_atom_stresses = per_atom_stresses - self.use_neighbor_list = use_neighbor_list - # Convert parameters to tensors - self.sigma = torch.as_tensor(sigma, dtype=self.dtype, device=self.device) - self.cutoff = torch.as_tensor( - cutoff or 2.5 * sigma, dtype=self.dtype, device=self.device - ) - self.epsilon = torch.as_tensor(epsilon, dtype=self.dtype, device=self.device) - self.alpha = torch.as_tensor(alpha, dtype=self.dtype, device=self.device) - - def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: - """Compute Morse potential properties for a single unbatched system. - - Internal implementation that processes a single, non-batched simulation state. - This method handles the core computations of pair interactions, including - neighbor list construction, distance calculations, and property computation. - - Args: - state (SimState): Single, non-batched simulation state containing atomic - positions, cell vectors, and other system information. - - Returns: - dict[str, torch.Tensor]: Computed properties: - - "energy": Total potential energy (scalar) - - "forces": Atomic forces with shape [n_atoms, 3] (if - compute_forces=True) - - "stress": Stress tensor with shape [3, 3] (if compute_stress=True) - - "energies": Per-atom energies with shape [n_atoms] (if - per_atom_energies=True) - - "stresses": Per-atom stresses with shape [n_atoms, 3, 3] (if - per_atom_stresses=True) - - Notes: - This method can work with both neighbor list and full pairwise calculations. - In both cases, interactions are truncated at the cutoff distance. + sigma: Equilibrium bond distance. Defaults to 1.0. + epsilon: Well depth / dissociation energy. Defaults to 5.0. + alpha: Width parameter. Defaults to 5.0. + device: Device for computations. Defaults to CPU. + dtype: Floating-point dtype. Defaults to torch.float32. + compute_forces: Whether to compute atomic forces. Defaults to True. + compute_stress: Whether to compute the stress tensor. Defaults to False. + per_atom_energies: Whether to return per-atom energies. Defaults to False. + per_atom_stresses: Whether to return per-atom stresses. Defaults to False. + neighbor_list_fn: Neighbor-list constructor. Defaults to torchsim_nl. + cutoff: Interaction cutoff. Defaults to 2.5 * sigma. + retain_graph: Keep computation graph for differentiable simulation. """ - positions = state.positions - cell = state.row_vector_cell - cell = cell.squeeze() - pbc = state.pbc - - # Wrap positions into the unit cell - wrapped_positions = ( - ts.transforms.pbc_wrap_batched(positions, state.cell, state.system_idx, pbc) - if pbc.any() - else positions + self.sigma = sigma + self.epsilon = epsilon + self.alpha = alpha + pair_fn = functools.partial(morse_pair, sigma=sigma, epsilon=epsilon, alpha=alpha) + super().__init__( + pair_fn=pair_fn, + cutoff=cutoff if cutoff is not None else 2.5 * sigma, + device=device, + dtype=dtype, + compute_forces=compute_forces, + compute_stress=compute_stress, + per_atom_energies=per_atom_energies, + per_atom_stresses=per_atom_stresses, + neighbor_list_fn=neighbor_list_fn, + reduce_to_half_list=True, + retain_graph=retain_graph, ) - - if self.use_neighbor_list: - mapping, _, shifts_idx = torchsim_nl( - positions=wrapped_positions, - cell=cell, - pbc=pbc, - cutoff=self.cutoff, - system_idx=state.system_idx, - ) - # Pass shifts_idx directly - get_pair_displacements will convert them - dr_vec, distances = transforms.get_pair_displacements( - positions=wrapped_positions, - cell=cell, - pbc=pbc, - pairs=(mapping[0], mapping[1]), - shifts=shifts_idx, - ) - else: - dr_vec, distances = transforms.get_pair_displacements( - positions=wrapped_positions, - cell=cell, - pbc=pbc, - ) - mask = torch.eye( - wrapped_positions.shape[0], dtype=torch.bool, device=self.device - ) - distances = distances.masked_fill(mask, float("inf")) - mask = distances < self.cutoff - i, j = torch.where(mask) - mapping = torch.stack([j, i]) - dr_vec = dr_vec[mask] - distances = distances[mask] - - # Calculate pair energies and apply cutoff - pair_energies = morse_pair( - distances, sigma=self.sigma, epsilon=self.epsilon, alpha=self.alpha - ) - mask = distances < self.cutoff - pair_energies = torch.where(mask, pair_energies, torch.zeros_like(pair_energies)) - - # Initialize results with total energy (sum/2 to avoid double counting) - results = {"energy": 0.5 * pair_energies.sum()} - - if self._per_atom_energies: - atom_energies = torch.zeros( - positions.shape[0], dtype=self.dtype, device=self.device - ) - atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies) - atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies) - results["energies"] = atom_energies - - if self.compute_forces or self.compute_stress: - pair_forces = morse_pair_force( - distances, sigma=self.sigma, epsilon=self.epsilon, alpha=self.alpha - ) - pair_forces = torch.where(mask, pair_forces, torch.zeros_like(pair_forces)) - - force_vectors = (pair_forces / distances)[:, None] * dr_vec - - if self.compute_forces: - forces = torch.zeros_like(state.positions) - forces.index_add_(0, mapping[0], -force_vectors) - forces.index_add_(0, mapping[1], force_vectors) - results["forces"] = forces - - if self.compute_stress and state.cell is not None: - stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors) - volume = torch.abs(torch.linalg.det(state.cell)) - - results["stress"] = -stress_per_pair.sum(dim=0) / volume - - if self._per_atom_stresses: - atom_stresses = torch.zeros( - (state.positions.shape[0], 3, 3), - dtype=self.dtype, - device=self.device, - ) - atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair) - atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair) - results["stresses"] = atom_stresses / volume - - return results - - def forward(self, state: ts.SimState, **_kwargs: object) -> dict[str, torch.Tensor]: - """Compute Morse potential energies, forces, and stresses for a system. - - Main entry point for Morse potential calculations that handles batched states - by dispatching each batch to the unbatched implementation and combining results. - - Args: - state (SimState): Input state containing atomic positions, cell vectors, - and other system information. - **_kwargs: Unused; accepted for interface compatibility. - - Returns: - dict[str, torch.Tensor]: Computed properties: - - "energy": Potential energy with shape [n_systems] - - "forces": Atomic forces with shape [n_atoms, 3] - (if compute_forces=True) - - "stress": Stress tensor with shape [n_systems, 3, 3] - (if compute_stress=True) - - May include additional outputs based on configuration - - Raises: - ValueError: If batch cannot be inferred for multi-cell systems. - - Examples: - ```py - # Compute properties for a simulation state - model = MorseModel(compute_forces=True) - results = model(sim_state) - - energy = results["energy"] # Shape: [n_systems] - forces = results["forces"] # Shape: [n_atoms, 3] - ``` - """ - outputs = [self.unbatched_forward(state[i]) for i in range(state.n_systems)] - properties = outputs[0] - - # we always return tensors - # per atom properties are returned as (atoms, ...) tensors - # global properties are returned as shape (..., n) tensors - results: dict[str, torch.Tensor] = {} - for key in ("stress", "energy"): - if key in properties: - results[key] = torch.stack([out[key] for out in outputs]) - for key in ("forces",): - if key in properties: - results[key] = torch.cat([out[key] for out in outputs], dim=0) - - return results diff --git a/torch_sim/models/pair_potential.py b/torch_sim/models/pair_potential.py index a537e6e9..1db1c2c9 100644 --- a/torch_sim/models/pair_potential.py +++ b/torch_sim/models/pair_potential.py @@ -1,4 +1,4 @@ -"""General batched pair potential model and standard pair interaction functions. +"""General batched pair potential and pair forces models. This module provides :class:`PairPotentialModel`, a flexible wrapper that turns any pairwise energy function into a full TorchSim model with forces (via autograd) and @@ -9,26 +9,63 @@ (e.g. the asymmetric particle-life interaction) that cannot be expressed as the gradient of a scalar energy. -Standard pair energy functions (all JIT-compatible): +The pair function signature required by :class:`PairPotentialModel` is: +``pair_fn(distances, atomic_numbers_i, atomic_numbers_j) -> pair_energies``, +where all arguments are 1-D tensors of length n_pairs and the return value is a +1-D tensor of pair energies. Additional parameters (e.g., ``sigma``, ``epsilon``) +can be bound using :func:`functools.partial`. + +Notes: + - The ``cutoff`` parameter determines the neighbor list construction range. + Pairs beyond the cutoff are excluded from energy/force calculations. If your + potential has its own natural cutoff (e.g., WCA potential), ensure the model's + ``cutoff`` is at least as large. + - The ``atomic_numbers_i`` and ``atomic_numbers_j`` arguments are provided for + type-dependent potentials, but can be ignored (e.g., with ``# noqa: ARG001``) + for type-independent potentials like Lennard-Jones. + - The ``dtype`` of the SimState must match the model's ``dtype``. The model will + raise a ``TypeError`` if they don't match. + - Use ``reduce_to_half_list=True`` for symmetric potentials to halve computation + time. Only use ``False`` for asymmetric interactions or when you need the + full neighbor list for other purposes. -* :func:`lj_pair` — Lennard-Jones 12-6 -* :func:`morse_pair` — Morse potential -* :func:`soft_sphere_pair` — soft-sphere repulsion -* :func:`particle_life_pair_force` — asymmetric particle-life force (use with - :class:`PairForcesModel`) Example:: - from torch_sim.models.pair_potential import PairPotentialModel, lj_pair + from torch_sim.models.pair_potential import PairPotentialModel + from torch_sim import io + from ase.build import bulk import functools + import torch + + + def bmhtf_pair(dr, zi, zj, A, B, C, D, sigma): + # Born-Meyer-Huggins-Tosi-Fumi (BMHTF) potential for ionic crystals + # V(r) = A * exp(B * (sigma - r)) - C/r^6 - D/r^8 + exp_term = A * torch.exp(B * (sigma - dr)) + r6_term = C / dr.pow(6) + r8_term = D / dr.pow(8) + energy = exp_term - r6_term - r8_term + return torch.where(dr > 0, energy, torch.zeros_like(energy)) + + + # Na-Cl interaction parameters + fn = functools.partial( + bmhtf_pair, + A=20.3548, + B=3.1546, + C=674.4793, + D=837.0770, + sigma=2.755, + ) + model = PairPotentialModel(pair_fn=fn, cutoff=10.0) - fn = functools.partial(lj_pair, sigma=1.0, epsilon=1.0) - model = PairPotentialModel(pair_fn=fn, cutoff=2.5) + # Create NaCl structure using ASE + nacl_atoms = bulk("NaCl", "rocksalt", a=5.64) + sim_state = io.atoms_to_state(nacl_atoms, device=torch.device("cpu")) results = model(sim_state) """ -# ruff: noqa: RUF002 - from __future__ import annotations from typing import TYPE_CHECKING @@ -46,208 +83,6 @@ from torch_sim.state import SimState -@torch.jit.script -def lj_pair( - dr: torch.Tensor, - zi: torch.Tensor, # noqa: ARG001 - zj: torch.Tensor, # noqa: ARG001 - sigma: float = 1.0, - epsilon: float = 1.0, -) -> torch.Tensor: - """Lennard-Jones 12-6 pair energy. - - V(r) = 4ε[(σ/r)¹² - (σ/r)⁶] - - Args: - dr: Pairwise distances, shape [n_pairs]. - zi: Atomic numbers of first atoms (unused, for interface compatibility). - zj: Atomic numbers of second atoms (unused, for interface compatibility). - sigma: Length scale. Defaults to 1.0. - epsilon: Energy scale. Defaults to 1.0. - - Returns: - Pair energies, shape [n_pairs]. - """ - idr6 = (sigma / dr).pow(6) - return 4.0 * epsilon * (idr6 * idr6 - idr6) - - -@torch.jit.script -def morse_pair( - dr: torch.Tensor, - zi: torch.Tensor, # noqa: ARG001 - zj: torch.Tensor, # noqa: ARG001 - sigma: float = 1.0, - epsilon: float = 5.0, - alpha: float = 5.0, -) -> torch.Tensor: - """Morse pair energy. - - V(r) = ε(1 - exp(-α(r - σ)))² - ε - - Args: - dr: Pairwise distances, shape [n_pairs]. - zi: Atomic numbers of first atoms (unused). - zj: Atomic numbers of second atoms (unused). - sigma: Equilibrium bond distance. Defaults to 1.0. - epsilon: Well depth / dissociation energy. Defaults to 5.0. - alpha: Width parameter. Defaults to 5.0. - - Returns: - Pair energies, shape [n_pairs]. - """ - return epsilon * (1.0 - torch.exp(-alpha * (dr - sigma))).pow(2) - epsilon - - -@torch.jit.script -def soft_sphere_pair( - dr: torch.Tensor, - zi: torch.Tensor, # noqa: ARG001 - zj: torch.Tensor, # noqa: ARG001 - sigma: float = 1.0, - epsilon: float = 1.0, - alpha: float = 2.0, -) -> torch.Tensor: - """Soft-sphere repulsive pair energy (zero beyond sigma). - - V(r) = ε/α * (1 - r/σ)^α for r < σ, else 0 - - Args: - dr: Pairwise distances, shape [n_pairs]. - zi: Atomic numbers of first atoms (unused). - zj: Atomic numbers of second atoms (unused). - sigma: Interaction diameter / cutoff. Defaults to 1.0. - epsilon: Energy scale. Defaults to 1.0. - alpha: Repulsion exponent. Defaults to 2.0. - - Returns: - Pair energies, shape [n_pairs]. - """ - energy = epsilon / alpha * (1.0 - dr / sigma).pow(alpha) - return torch.where(dr < sigma, energy, torch.zeros_like(energy)) - - -@torch.jit.script -def particle_life_pair_force( - dr: torch.Tensor, - zi: torch.Tensor, # noqa: ARG001 - zj: torch.Tensor, # noqa: ARG001 - A: float = 1.0, - beta: float = 0.3, - sigma: float = 1.0, -) -> torch.Tensor: - """Asymmetric particle-life scalar force magnitude. - - This is a *force* function (not an energy), intended for use with - :class:`PairForcesModel`. - - Args: - dr: Pairwise distances, shape [n_pairs]. - zi: Atomic numbers of first atoms (unused). - zj: Atomic numbers of second atoms (unused). - A: Interaction amplitude. Defaults to 1.0. - beta: Inner radius. Defaults to 0.3. - sigma: Outer radius / cutoff. Defaults to 1.0. - - Returns: - Scalar force magnitudes, shape [n_pairs]. - """ - inner_mask = dr < beta - outer_mask = (dr >= beta) & (dr < sigma) - inner_force = dr / beta - 1.0 - outer_force = A * (1.0 - torch.abs(2.0 * dr - 1.0 - beta) / (1.0 - beta)) - return torch.where(inner_mask, inner_force, torch.zeros_like(dr)) + torch.where( - outer_mask, outer_force, torch.zeros_like(dr) - ) - - -class MultiSoftSpherePairFn(torch.nn.Module): - """Species-dependent soft-sphere pair energy function. - - Holds per-species-pair parameter matrices and looks up sigma, epsilon, and alpha - for each interacting pair via their atomic numbers. Pass an instance to - :class:`PairPotentialModel`. - - Example:: - - fn = MultiSoftSpherePairFn( - atomic_numbers=torch.tensor([18, 36]), # Ar and Kr - sigma_matrix=torch.tensor([[3.4, 3.6], [3.6, 3.7]]), - epsilon_matrix=torch.tensor([[0.01, 0.012], [0.012, 0.014]]), - ) - model = PairPotentialModel(pair_fn=fn, cutoff=float(fn.sigma_matrix.max())) - """ - - def __init__( - self, - atomic_numbers: torch.Tensor, - sigma_matrix: torch.Tensor, - epsilon_matrix: torch.Tensor, - alpha_matrix: torch.Tensor | None = None, - ) -> None: - """Initialize species-dependent soft-sphere parameters. - - Args: - atomic_numbers: 1-D tensor of the unique atomic numbers present, used to - map ``zi``/``zj`` to row/column indices. Shape: [n_species]. - sigma_matrix: Symmetric matrix of interaction diameters. Shape: - [n_species, n_species]. - epsilon_matrix: Symmetric matrix of energy scales. Shape: - [n_species, n_species]. - alpha_matrix: Symmetric matrix of repulsion exponents. If None, defaults - to 2.0 for all pairs. Shape: [n_species, n_species]. - """ - super().__init__() - self.z_to_idx: torch.Tensor - self.atomic_numbers: torch.Tensor - self.sigma_matrix: torch.Tensor - self.epsilon_matrix: torch.Tensor - self.alpha_matrix: torch.Tensor - - n = len(atomic_numbers) - if sigma_matrix.shape != (n, n): - raise ValueError(f"sigma_matrix must have shape ({n}, {n})") - if epsilon_matrix.shape != (n, n): - raise ValueError(f"epsilon_matrix must have shape ({n}, {n})") - if alpha_matrix is not None and alpha_matrix.shape != (n, n): - raise ValueError(f"alpha_matrix must have shape ({n}, {n})") - - self.register_buffer("atomic_numbers", atomic_numbers) - self.register_buffer("sigma_matrix", sigma_matrix) - self.register_buffer("epsilon_matrix", epsilon_matrix) - self.register_buffer( - "alpha_matrix", - alpha_matrix if alpha_matrix is not None else torch.full((n, n), 2.0), - ) - # Build a lookup table: atomic_number -> species index - max_z = int(atomic_numbers.max().item()) + 1 - z_to_idx = torch.full((max_z,), -1, dtype=torch.long) - for idx, z in enumerate(atomic_numbers.tolist()): - z_to_idx[int(z)] = idx - self.register_buffer("z_to_idx", z_to_idx) - - def forward( - self, dr: torch.Tensor, zi: torch.Tensor, zj: torch.Tensor - ) -> torch.Tensor: - """Compute per-pair soft-sphere energies using species lookup. - - Args: - dr: Pairwise distances, shape [n_pairs]. - zi: Atomic numbers of first atoms, shape [n_pairs]. - zj: Atomic numbers of second atoms, shape [n_pairs]. - - Returns: - Pair energies, shape [n_pairs]. - """ - idx_i = self.z_to_idx[zi] - idx_j = self.z_to_idx[zj] - sigma = self.sigma_matrix[idx_i, idx_j] - epsilon = self.epsilon_matrix[idx_i, idx_j] - alpha = self.alpha_matrix[idx_i, idx_j] - energy = epsilon / alpha * (1.0 - dr / sigma).pow(alpha) - return torch.where(dr < sigma, energy, torch.zeros_like(energy)) - - def full_to_half_list( mapping: torch.Tensor, system_mapping: torch.Tensor, @@ -398,7 +233,7 @@ def _virial_stress( volumes = torch.abs(torch.linalg.det(row_cell)) stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors) stress = torch.zeros((n_systems, 3, 3), dtype=dtype, device=device) - stress.index_add_(0, system_mapping, -stress_per_pair) + stress = stress.index_add(0, system_mapping, -stress_per_pair) stress = stress / volumes[:, None, None] return stress, stress_per_pair, volumes @@ -428,17 +263,18 @@ def _accumulate_stress( dtype, device, ) - stress_scale = 2.0 if half else 1.0 + stress_scale = 1.0 if half else 0.5 out: dict[str, torch.Tensor] = {"stress": stress * stress_scale} if per_atom: - # Half list: each pair once → weight 1.0 per endpoint. - # Full list: each pair twice (i→j and j→i) → weight 0.5 per endpoint. - w = 1.0 if half else 0.5 + # Each endpoint (i and j) gets half the pair's contribution. + # Half list: each unique pair appears once → w = 0.5. + # Full list: each unique pair appears twice → w = 0.25. + w = 0.5 if half else 0.25 n_atoms = positions.shape[0] atom_stresses = torch.zeros((n_atoms, 3, 3), dtype=dtype, device=device) - atom_stresses.index_add_(0, mapping[0], -w * stress_per_pair) - atom_stresses.index_add_(0, mapping[1], -w * stress_per_pair) + atom_stresses = atom_stresses.index_add(0, mapping[0], -w * stress_per_pair) + atom_stresses = atom_stresses.index_add(0, mapping[1], -w * stress_per_pair) out["stresses"] = atom_stresses / volumes[system_idx, None, None] return out @@ -451,7 +287,10 @@ class PairPotentialModel(ModelInterface): callable of the form ``pair_fn(distances, atomic_numbers_i, atomic_numbers_j) -> pair_energies``, where all arguments are 1-D tensors of length n_pairs and the return value is a 1-D tensor of pair energies. Forces are obtained analytically - via autograd. + via autograd by differentiating the energy with respect to positions. + + When stress is computed, it uses the virial formula: σ = -1/V Σ_{ij} r_ij ⊗ f_ij, + where r_ij is the pair displacement vector and f_ij is the force vector. Example:: @@ -467,16 +306,17 @@ def lj_fn(dr, zi, zj): def __init__( self, pair_fn: Callable, + *, cutoff: float, device: torch.device | None = None, - dtype: torch.dtype = torch.float32, - *, + dtype: torch.dtype = torch.float64, compute_forces: bool = True, compute_stress: bool = False, per_atom_energies: bool = False, per_atom_stresses: bool = False, neighbor_list_fn: Callable = torchsim_nl, reduce_to_half_list: bool = False, + retain_graph: bool = False, ) -> None: """Initialize the pair potential model. @@ -496,6 +336,10 @@ def __init__( before computing interactions. Halves pair operations and makes accumulation patterns unambiguous. Only valid for symmetric pair functions; do not use for asymmetric interactions. Defaults to False. + retain_graph: If True, keep the computation graph after computing forces + so that the energy can still be differentiated w.r.t. model parameters + (e.g. for differentiable simulation / meta-optimization). + Defaults to False. """ super().__init__() self._device = device or torch.device("cpu") @@ -508,6 +352,7 @@ def __init__( self.neighbor_list_fn = neighbor_list_fn self.cutoff = torch.tensor(cutoff, dtype=dtype, device=self._device) self.reduce_to_half_list = reduce_to_half_list + self.retain_graph = retain_graph def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]: """Compute pair-potential properties with batched tensor operations. @@ -520,7 +365,17 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] dict with keys ``"energy"`` (shape ``[n_systems]``), optionally ``"forces"`` (``[n_atoms, 3]``), ``"stress"`` (``[n_systems, 3, 3]``), ``"energies"`` (``[n_atoms]``), ``"stresses"`` (``[n_atoms, 3, 3]``). + + Raises: + TypeError: If the SimState's dtype does not match the model's dtype. """ + if state.dtype != self._dtype: + raise TypeError( + f"SimState dtype {state.dtype} does not match model dtype {self._dtype}. " + f"Either set the model dtype to {state.dtype} or convert the SimState " + f"to {self._dtype} using sim_state.to(dtype={self._dtype})." + ) + dtype = self._dtype half = self.reduce_to_half_list ( positions, @@ -555,16 +410,16 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] ew = 1.0 if half else 0.5 results: dict[str, torch.Tensor] = {} - energy = torch.zeros(n_systems, dtype=self._dtype, device=self._device) - energy.index_add_(0, system_mapping, ew * pair_energies) + energy = torch.zeros(n_systems, dtype=dtype, device=self._device) + energy = energy.index_add(0, system_mapping, ew * pair_energies) results["energy"] = energy if self.per_atom_energies: atom_energies = torch.zeros( - positions.shape[0], dtype=self._dtype, device=self._device + positions.shape[0], dtype=dtype, device=self._device ) - atom_energies.index_add_(0, mapping[0], ew * pair_energies) - atom_energies.index_add_(0, mapping[1], ew * pair_energies) + atom_energies = atom_energies.index_add(0, mapping[0], ew * pair_energies) + atom_energies = atom_energies.index_add(0, mapping[1], ew * pair_energies) results["energies"] = atom_energies if need_grad: @@ -572,6 +427,7 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] pair_energies.sum(), dist_for_grad, create_graph=False, + retain_graph=self.retain_graph, ) safe_dist = torch.where(distances > 0, distances, torch.ones_like(distances)) # force_vectors = -dV/dr * r̂_ij: positive (repulsive) pushes j away from i. @@ -581,13 +437,13 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] forces = torch.zeros_like(positions) if half: # Half list: each pair once → apply Newton's third law explicitly. - forces.index_add_(0, mapping[0], -force_vectors) - forces.index_add_(0, mapping[1], force_vectors) + forces = forces.index_add(0, mapping[0], -force_vectors) + forces = forces.index_add(0, mapping[1], force_vectors) else: # Full list: atom i appears as mapping[0] for every i→j pair, # covering all its neighbors. mapping[1] accumulation would # double-count, so we only accumulate on the source atom. - forces.index_add_(0, mapping[0], -force_vectors) + forces = forces.index_add(0, mapping[0], -force_vectors) results["forces"] = forces if self._compute_stress: @@ -601,14 +457,17 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] force_vectors, row_cell, n_systems, - self._dtype, + dtype, self._device, half=half, per_atom=self.per_atom_stresses, ) ) - return {k: v.detach() for k, v in results.items()} + if not self.retain_graph: + results = {k: v.detach() for k, v in results.items()} + + return results class PairForcesModel(ModelInterface): @@ -622,12 +481,17 @@ class PairForcesModel(ModelInterface): Forces are accumulated as: F_i += -f_ij * r̂_ij, F_j += +f_ij * r̂_ij + Note: + Unlike :class:`PairPotentialModel`, this class does not compute energies + (returns zeros) since there is no underlying energy function. Use + :class:`PairPotentialModel` when your interaction can be expressed as an + energy function, as it provides automatic force computation via autograd + and is generally more efficient. + Example:: - from torch_sim.models.pair_potential import ( - PairForcesModel, - particle_life_pair_force, - ) + from torch_sim.models.particle_life import particle_life_pair_force + from torch_sim.models.pair_potential import PairForcesModel import functools fn = functools.partial(particle_life_pair_force, A=1.0, beta=0.3, sigma=1.0) @@ -638,10 +502,10 @@ class PairForcesModel(ModelInterface): def __init__( self, force_fn: Callable, + *, cutoff: float, device: torch.device | None = None, - dtype: torch.dtype = torch.float32, - *, + dtype: torch.dtype = torch.float64, compute_stress: bool = False, per_atom_stresses: bool = False, neighbor_list_fn: Callable = torchsim_nl, @@ -686,7 +550,17 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] ``"forces"`` (shape ``[n_atoms, 3]``), and optionally ``"stress"`` (shape ``[n_systems, 3, 3]``) and ``"stresses"`` (shape ``[n_atoms, 3, 3]``). + + Raises: + TypeError: If the SimState's dtype does not match the model's dtype. """ + if state.dtype != self._dtype: + raise TypeError( + f"SimState dtype {state.dtype} does not match model dtype {self._dtype}. " + f"Either set the model dtype to {state.dtype} or convert the SimState " + f"to {self._dtype} using sim_state.to(dtype={self._dtype})." + ) + dtype = self._dtype half = self.reduce_to_half_list ( positions, @@ -715,11 +589,11 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] force_vectors = (pair_forces / safe_dist)[:, None] * dr_vec forces = torch.zeros_like(positions) - forces.index_add_(0, mapping[0], -force_vectors) - forces.index_add_(0, mapping[1], force_vectors) + forces = forces.index_add(0, mapping[0], -force_vectors) + forces = forces.index_add(0, mapping[1], force_vectors) results: dict[str, torch.Tensor] = { - "energy": torch.zeros(n_systems, dtype=self._dtype, device=self._device), + "energy": torch.zeros(n_systems, dtype=dtype, device=self._device), "forces": forces, } @@ -734,7 +608,7 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] force_vectors, row_cell, n_systems, - self._dtype, + dtype, self._device, half=half, per_atom=self.per_atom_stresses, diff --git a/torch_sim/models/particle_life.py b/torch_sim/models/particle_life.py index c227f272..d5e53767 100644 --- a/torch_sim/models/particle_life.py +++ b/torch_sim/models/particle_life.py @@ -1,94 +1,75 @@ -"""Particle life model for computing forces between particles.""" +"""Particle life model. -import torch - -import torch_sim as ts -from torch_sim import transforms -from torch_sim.models.interface import ModelInterface -from torch_sim.neighbors import torchsim_nl +Thin wrapper around :class:`~torch_sim.models.pair_potential.PairForcesModel` with +the :func:`particle_life_pair_force` force function +baked in. +Example:: -DEFAULT_BETA = torch.tensor(0.3) -DEFAULT_SIGMA = torch.tensor(1.0) + model = ParticleLifeModel(sigma=1.0, epsilon=1.0, beta=0.3, cutoff=1.0) + results = model(sim_state) +""" +from __future__ import annotations -def asymmetric_particle_pair_force( - dr: torch.Tensor, - A: torch.Tensor, - beta: torch.Tensor = DEFAULT_BETA, - sigma: torch.Tensor = DEFAULT_SIGMA, -) -> torch.Tensor: - """Asymmetric interaction between particles. - - Args: - dr: A tensor of shape [n, m] of pairwise distances between particles. - A: Interaction scale. Either a float scalar or a tensor of shape [n, m]. - beta: Inner radius of the interaction. Either a float scalar or tensor of - shape [n, m]. - sigma: Outer radius of the interaction. Either a float scalar or tensor of - shape [n, m]. +import functools +from collections.abc import Callable # noqa: TC003 - Returns: - torch.Tensor: Energies with shape [n, m]. - """ - inner_mask = dr < beta - outer_mask = (dr < sigma) & (dr > beta) - - def inner_force_fn(dr: torch.Tensor) -> torch.Tensor: - return dr / beta - 1 - - def intermediate_force_fn(dr: torch.Tensor) -> torch.Tensor: - return A * (1 - torch.abs(2 * dr - 1 - beta) / (1 - beta)) +import torch - return torch.where(inner_mask, inner_force_fn(dr), 0) + torch.where( - outer_mask, - intermediate_force_fn(dr), - 0, - ) +from torch_sim.models.pair_potential import PairForcesModel +from torch_sim.neighbors import torchsim_nl -def asymmetric_particle_pair_force_jit( +def particle_life_pair_force( dr: torch.Tensor, - A: torch.Tensor, - beta: torch.Tensor = DEFAULT_BETA, - sigma: torch.Tensor = DEFAULT_SIGMA, + zi: torch.Tensor, # noqa: ARG001 + zj: torch.Tensor, # noqa: ARG001 + A: torch.Tensor | float = 1.0, + beta: torch.Tensor | float = 0.3, + sigma: torch.Tensor | float = 1.0, ) -> torch.Tensor: - """Asymmetric interaction between particles. + """Asymmetric particle-life scalar force magnitude. + + This is a *force* function (not an energy), intended for use with + :class:`PairForcesModel`. Args: - dr: A tensor of shape [n, m] of pairwise distances between particles. - A: Interaction scale. Either a float scalar or a tensor of shape [n, m]. - beta: Inner radius of the interaction. Either a float scalar or tensor of - shape [n, m]. - sigma: Outer radius of the interaction. Either a float scalar or tensor of - shape [n, m]. + dr: Pairwise distances, shape [n_pairs]. + zi: Atomic numbers of first atoms (unused). + zj: Atomic numbers of second atoms (unused). + A: Interaction amplitude. Defaults to 1.0. + beta: Inner radius. Defaults to 0.3. + sigma: Outer radius / cutoff. Defaults to 1.0. Returns: - torch.Tensor: Energies with shape [n, m]. + Scalar force magnitudes, shape [n_pairs]. """ inner_mask = dr < beta - outer_mask = (dr < sigma) & (dr > beta) - - # Calculate inner forces directly - inner_forces = torch.where(inner_mask, dr / beta - 1, torch.zeros_like(dr)) - - # Calculate outer forces directly - outer_forces = torch.where( - outer_mask, - A * (1 - torch.abs(2 * dr - 1 - beta) / (1 - beta)), - torch.zeros_like(dr), + outer_mask = (dr >= beta) & (dr < sigma) + inner_force = dr / beta - 1.0 + outer_force = A * (1.0 - torch.abs(2.0 * dr - 1.0 - beta) / (1.0 - beta)) + return torch.where(inner_mask, inner_force, torch.zeros_like(dr)) + torch.where( + outer_mask, outer_force, torch.zeros_like(dr) ) - return inner_forces + outer_forces +class ParticleLifeModel(PairForcesModel): + """Asymmetric particle-life force model. -class ParticleLifeModel(ModelInterface): - """Calculator for asymmetric particle interaction. + Convenience subclass that fixes the force function to + :func:`particle_life_pair_force` so the caller only needs to supply + ``sigma``, ``epsilon`` (amplitude), and ``beta``. - This model implements an asymmetric interaction between particles based on - distance-dependent forces. The interaction is defined by three parameters: - sigma, epsilon, and beta. + Example:: + model = ParticleLifeModel( + sigma=1.0, + epsilon=1.0, + beta=0.3, + cutoff=1.0, + ) + results = model(sim_state) """ def __init__( @@ -97,192 +78,45 @@ def __init__( epsilon: float = 1.0, beta: float = 0.3, device: torch.device | None = None, - dtype: torch.dtype = torch.float32, - *, # Force keyword-only arguments - compute_forces: bool = False, + dtype: torch.dtype = torch.float64, + *, compute_stress: bool = False, - per_atom_energies: bool = False, per_atom_stresses: bool = False, - use_neighbor_list: bool = True, + neighbor_list_fn: Callable = torchsim_nl, cutoff: float | None = None, + **kwargs: object, # noqa: ARG002 ) -> None: - """Initialize the calculator. + """Initialize the particle life model. Args: - sigma: Outer radius of the interaction. - epsilon: Interaction scale. - beta: Inner radius of the interaction. - device: Device for computation. - dtype: Data type for tensors. - compute_forces: Whether to compute forces. - compute_stress: Whether to compute stress tensor. - per_atom_energies: Whether to compute per-atom energies. - per_atom_stresses: Whether to compute per-atom stresses. - use_neighbor_list: Whether to use neighbor list optimization. - cutoff: Interaction cutoff distance. Defaults to 2.5 * sigma. + sigma: Outer radius of the interaction. Defaults to 1.0. + epsilon: Interaction amplitude (``A`` parameter). Defaults to 1.0. + beta: Inner radius of the interaction. Defaults to 0.3. + device: Device for computations. Defaults to CPU. + dtype: Floating-point dtype. Defaults to torch.float32. + compute_forces: Accepted for backward compatibility (always True). + compute_stress: Whether to compute stress tensor. Defaults to False. + per_atom_energies: Accepted for backward compatibility (ignored — this + is a force-only model, energy is always zero). + per_atom_stresses: Whether to return per-atom stresses. Defaults to False. + use_neighbor_list: Accepted for backward compatibility (ignored). + neighbor_list_fn: Neighbor-list constructor. Defaults to torchsim_nl. + cutoff: Interaction cutoff. Defaults to 2.5 * sigma. + **kwargs: Additional keyword arguments. """ - super().__init__() - self._device = device or torch.device("cpu") - self._dtype = dtype - - self._compute_forces = compute_forces - self._compute_stress = compute_stress - self._per_atom_energies = per_atom_energies - self._per_atom_stresses = per_atom_stresses - - self.use_neighbor_list = use_neighbor_list - - # Convert parameters to tensors - self.sigma = torch.tensor(sigma, dtype=self.dtype, device=self.device) - self.cutoff = torch.tensor( - cutoff or 2.5 * sigma, dtype=self.dtype, device=self.device + self.sigma_param = sigma + self.epsilon = epsilon + self.beta = beta + force_fn = functools.partial( + particle_life_pair_force, A=epsilon, beta=beta, sigma=sigma ) - self.epsilon = torch.tensor(epsilon, dtype=self.dtype, device=self.device) - self.beta = torch.tensor(beta, dtype=self.dtype, device=self.device) - - def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: - """Compute energies and forces for a single unbatched system. - - Internal implementation that processes a single, non-batched simulation state. - This method handles the core computations of pair interactions, neighbor lists, - and property calculations. - - Args: - state: Single, non-batched simulation state containing atomic positions, - cell vectors, and other system information. - - Returns: - A dictionary containing the energy, forces, and stresses - """ - positions = state.positions - cell = state.row_vector_cell - - if cell.dim() == 3: # Check if there is an extra batch dimension - cell = cell.squeeze(0) # Squeeze the first dimension - - # Ensure system_idx exists (create if None for single system) - system_idx = ( - state.system_idx - if state.system_idx is not None - else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device) + super().__init__( + force_fn=force_fn, + cutoff=cutoff if cutoff is not None else 2.5 * sigma, + device=device, + dtype=dtype, + compute_stress=compute_stress, + per_atom_stresses=per_atom_stresses, + neighbor_list_fn=neighbor_list_fn, + reduce_to_half_list=False, ) - - # Wrap positions into the unit cell - wrapped_positions = ( - ts.transforms.pbc_wrap_batched(positions, state.cell, system_idx, state.pbc) - if state.pbc.any() - else positions - ) - - if self.use_neighbor_list: - mapping, _, shifts_idx = torchsim_nl( - positions=wrapped_positions, - cell=cell, - pbc=state.pbc, - cutoff=self.cutoff, - system_idx=system_idx, - ) - # Pass shifts_idx directly - get_pair_displacements will convert them - dr_vec, distances = transforms.get_pair_displacements( - positions=wrapped_positions, - cell=cell, - pbc=state.pbc, - pairs=(mapping[0], mapping[1]), - shifts=shifts_idx, - ) - else: - # Get all pairwise displacements - dr_vec, distances = transforms.get_pair_displacements( - positions=wrapped_positions, - cell=cell, - pbc=state.pbc, - ) - # Mask out self-interactions - mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device) - distances = distances.masked_fill(mask, float("inf")) - # Apply cutoff - mask = distances < self.cutoff - # Get valid pairs - match neighbor list convention for pair order - i, j = torch.where(mask) - mapping = torch.stack([j, i]) - # Get valid displacements and distances - dr_vec = dr_vec[mask] - distances = distances[mask] - - # Zero out energies beyond cutoff - mask = distances < self.cutoff - - # Initialize results with total energy (sum/2 to avoid double counting) - results: dict[str, torch.Tensor] = { - "energy": torch.tensor(0.0, dtype=self.dtype, device=self.device), - } - - # Calculate forces and apply cutoff - pair_forces = asymmetric_particle_pair_force_jit( - dr=distances, A=self.epsilon, sigma=self.sigma, beta=self.beta - ) - pair_forces = torch.where(mask, pair_forces, torch.zeros_like(pair_forces)) - - # Project forces along displacement vectors - force_vectors = (pair_forces / distances)[:, None] * dr_vec - - # Initialize forces tensor - forces = torch.zeros_like(state.positions) - # Add force contributions (f_ij on i, -f_ij on j) - forces.index_add_(0, mapping[0], -force_vectors) - forces.index_add_(0, mapping[1], force_vectors) - results["forces"] = forces - - return results - - def forward(self, state: ts.SimState, **_kwargs: object) -> dict[str, torch.Tensor]: - """Compute particle life energies and forces for a system. - - Main entry point for particle life calculations that handles batched states by - dispatching each batch to the unbatched implementation and combining results. - - Args: - state: Input state containing atomic positions, cell vectors, and other - system information. Can be a SimState object or a dictionary with the - same keys. - **_kwargs: Unused; accepted for interface compatibility. - - Returns: - dict[str, torch.Tensor]: Computed properties: - - "energy": Potential energy with shape [n_systems] - - "forces": Atomic forces with shape [n_atoms, 3] (if - compute_forces=True) - - "stress": Stress tensor with shape [n_systems, 3, 3] (if - compute_stress=True) - - "energies": Per-atom energies with shape [n_atoms] (if - per_atom_energies=True) - - "stresses": Per-atom stresses with shape [n_atoms, 3, 3] (if - per_atom_stresses=True) - - Raises: - ValueError: If batch cannot be inferred for multi-cell systems. - """ - sim_state = state - - if sim_state.system_idx is None and sim_state.cell.shape[0] > 1: - raise ValueError( - "system_idx can only be inferred if there is only one system." - ) - - outputs = [ - self.unbatched_forward(sim_state[idx]) for idx in range(sim_state.n_systems) - ] - properties = outputs[0] - - # we always return tensors - # per atom properties are returned as (atoms, ...) tensors - # global properties are returned as shape (..., n) tensors - results: dict[str, torch.Tensor] = {} - for key in ("stress", "energy"): - if key in properties: - results[key] = torch.stack([out[key] for out in outputs]) - for key in ("forces", "energies", "stresses"): - if key in properties: - results[key] = torch.cat([out[key] for out in outputs], dim=0) - - return results diff --git a/torch_sim/models/soft_sphere.py b/torch_sim/models/soft_sphere.py index 6ec7f26e..598eff88 100644 --- a/torch_sim/models/soft_sphere.py +++ b/torch_sim/models/soft_sphere.py @@ -1,616 +1,324 @@ -"""Soft sphere model for computing energies, forces and stresses. +"""Soft sphere potential model. -This module provides implementations of soft sphere potentials for molecular dynamics -simulations. Soft sphere potentials are repulsive interatomic potentials that model -the core repulsion between atoms, avoiding the infinite repulsion of hard sphere models -while maintaining computational efficiency. +Thin wrapper around :class:`~torch_sim.models.pair_potential.PairPotentialModel` with +the :func:`soft_sphere_pair` energy function baked in. The soft sphere potential has the form: - V(r) = epsilon * (sigma/r)^alpha -Where: - -* r is the distance between particles -* sigma is the effective diameter of the particles -* epsilon controls the energy scale -* alpha determines the steepness of the repulsion (typically alpha >= 2) - -Soft sphere models are particularly useful for: - -* Granular matter simulations -* Modeling excluded volume effects -* Initial equilibration of dense systems -* Coarse-grained molecular dynamics + V(r) = ε/α * (1 - r/σ)^α for r < σ, else 0 Example:: - # Create a soft sphere model with default parameters - model = SoftSphereModel() - - # Calculate properties for a simulation state + model = SoftSphereModel(sigma=1.0, epsilon=1.0, alpha=2.0) results = model(sim_state) - energy = results["energy"] - forces = results["forces"] # For multiple species with different interaction parameters multi_model = SoftSphereMultiModel( - species=particle_types, - sigma_matrix=size_matrix, - epsilon_matrix=strength_matrix, + atomic_numbers=torch.tensor([18, 36]), + sigma_matrix=torch.tensor([[1.0, 0.8], [0.8, 0.6]]), + epsilon_matrix=torch.tensor([[1.0, 0.5], [0.5, 2.0]]), ) results = multi_model(sim_state) """ -import torch +from __future__ import annotations -import torch_sim as ts -from torch_sim import transforms -from torch_sim.models.interface import ModelInterface -from torch_sim.neighbors import torchsim_nl +import functools +from collections.abc import Callable # noqa: TC003 +import torch -DEFAULT_SIGMA = torch.tensor(1.0) -DEFAULT_EPSILON = torch.tensor(1.0) -DEFAULT_ALPHA = torch.tensor(2.0) +from torch_sim.models.pair_potential import PairPotentialModel +from torch_sim.neighbors import torchsim_nl def soft_sphere_pair( dr: torch.Tensor, - sigma: float | torch.Tensor = DEFAULT_SIGMA, - epsilon: float | torch.Tensor = DEFAULT_EPSILON, - alpha: float | torch.Tensor = DEFAULT_ALPHA, + zi: torch.Tensor, # noqa: ARG001 + zj: torch.Tensor, # noqa: ARG001 + sigma: torch.Tensor | float = 1.0, + epsilon: torch.Tensor | float = 1.0, + alpha: torch.Tensor | float = 2.0, ) -> torch.Tensor: - """Calculate pairwise repulsive energies between soft spheres with finite-range - interactions. + """Soft-sphere repulsive pair energy (zero beyond sigma). - Computes a soft-core repulsive potential between particle pairs based on - their separation distance, size, and interaction parameters. The potential - goes to zero at finite range. + V(r) = ε/α * (1 - r/σ)^α for r < σ, else 0 Args: - dr: Pairwise distances between particles. Shape: [n, m]. - sigma: Particle diameters. Either a scalar float or tensor of shape [n, m] - for particle-specific sizes. - epsilon: Energy scale of the interaction. Either a scalar float or tensor - of shape [n, m] for pair-specific interaction strengths. - alpha: Stiffness exponent controlling the interaction decay. Either a scalar - float or tensor of shape [n, m]. + dr: Pairwise distances, shape [n_pairs]. + zi: Atomic numbers of first atoms (unused). + zj: Atomic numbers of second atoms (unused). + sigma: Interaction diameter / cutoff. Defaults to 1.0. + epsilon: Energy scale. Defaults to 1.0. + alpha: Repulsion exponent. Defaults to 2.0. Returns: - torch.Tensor: Pairwise interaction energies between particles. Shape: [n, m]. - Each element [i,j] represents the repulsive energy between particles i and j. + Pair energies, shape [n_pairs]. """ - - def fn(dr: torch.Tensor) -> torch.Tensor: - return epsilon / alpha * (1.0 - (dr / sigma)).pow(alpha) - - # Create mask for distances within cutoff i.e sigma - mask = dr < sigma - - # Use transforms.safe_mask to compute energies only where mask is True - return transforms.safe_mask(mask, fn, dr) + energy = epsilon / alpha * (1.0 - dr / sigma).pow(alpha) + return torch.where(dr < sigma, energy, torch.zeros_like(energy)) def soft_sphere_pair_force( dr: torch.Tensor, - sigma: torch.Tensor = DEFAULT_SIGMA, - epsilon: torch.Tensor = DEFAULT_EPSILON, - alpha: torch.Tensor = DEFAULT_ALPHA, + zi: torch.Tensor, # noqa: ARG001 + zj: torch.Tensor, # noqa: ARG001 + sigma: torch.Tensor | float = 1.0, + epsilon: torch.Tensor | float = 1.0, + alpha: torch.Tensor | float = 2.0, ) -> torch.Tensor: - """Computes the pairwise repulsive forces between soft spheres with finite range. + """Soft-sphere pair force (negative gradient of energy). - This function implements a soft-core repulsive interaction that smoothly goes to zero - at the cutoff distance sigma. The force magnitude is controlled by epsilon and its - stiffness by alpha. + F(r) = (ε/σ) (1 - r/σ)^(α-1) for r < σ, else 0 Args: - dr: A tensor of shape [n, m] containing pairwise distances between particles, - where n and m represent different particle indices. - sigma: Particle diameter defining the interaction cutoff distance. Can be either - a float scalar or a tensor of shape [n, m] for particle-specific diameters. - epsilon: Energy scale of the interaction. Can be either a float scalar or a - tensor of shape [n, m] for particle-specific interaction strengths. - alpha: Exponent controlling the stiffness of the repulsion. Higher values create - a harder repulsion. Can be either a float scalar or a tensor of shape [n, m]. + dr: Pairwise distances. + zi: Atomic numbers of first atoms (unused). + zj: Atomic numbers of second atoms (unused). + sigma: Interaction diameter. Defaults to 1.0. + epsilon: Energy scale. Defaults to 1.0. + alpha: Repulsion exponent. Defaults to 2.0. Returns: - torch.Tensor: Forces between particle pairs with shape [n, m]. Forces are zero - for distances greater than sigma. + Pair force magnitudes. """ + force = (epsilon / sigma) * (1.0 - (dr / sigma)).pow(alpha - 1) + mask = dr < sigma + return torch.where(mask, force, torch.zeros_like(force)) - def fn(dr: torch.Tensor) -> torch.Tensor: - return (epsilon / sigma) * (1.0 - (dr / sigma)).pow(alpha - 1) - # Create mask for distances within cutoff i.e sigma - mask = dr < sigma +class MultiSoftSpherePairFn(torch.nn.Module): + """Species-dependent soft-sphere pair energy function. - # Use transforms.safe_mask to compute energies only where mask is True - return transforms.safe_mask(mask, fn, dr) - - -class SoftSphereModel(ModelInterface): - """Calculator for soft sphere potential energies and forces. - - Implements a model for computing properties based on the soft sphere potential, - which describes purely repulsive interactions between particles. This potential - is useful for modeling systems where particles should not overlap but don't have - attractive interactions, such as granular materials and some colloidal systems. - - The potential energy between particles i and j is: - V_ij(r) = epsilon * (sigma/r)^alpha - - Attributes: - sigma (torch.Tensor): Effective particle diameter in distance units. - epsilon (torch.Tensor): Energy scale parameter in energy units. - alpha (torch.Tensor): Exponent controlling repulsion steepness (typically ≥ 2). - cutoff (torch.Tensor): Cutoff distance for interactions. - use_neighbor_list (bool): Whether to use neighbor list optimization. - _device (torch.device): Computation device (CPU/GPU). - _dtype (torch.dtype): Data type for tensor calculations. - _compute_forces (bool): Whether to compute forces. - _compute_stress (bool): Whether to compute stress tensor. - per_atom_energies (bool): Whether to compute per-atom energy decomposition. - per_atom_stresses (bool): Whether to compute per-atom stress decomposition. - - Examples: - ```py - # Basic usage with default parameters - model = SoftSphereModel() - results = model(sim_state) + Holds per-species-pair parameter matrices and looks up sigma, epsilon, and alpha + for each interacting pair via their atomic numbers. Pass an instance to + :class:`PairPotentialModel`. - # Custom parameters for colloidal system - colloid_model = SoftSphereModel( - sigma=2.0, # particle diameter in nm - epsilon=10.0, # energy scale in kJ/mol - alpha=12.0, # steep repulsion for hard colloids - compute_stress=True, - ) + Example:: - # Get forces for a system with periodic boundary conditions - results = colloid_model( - ts.SimState( - positions=positions, - cell=box_vectors, - pbc=torch.tensor([True, True, True]), - ) + fn = MultiSoftSpherePairFn( + atomic_numbers=torch.tensor([18, 36]), # Ar and Kr + sigma_matrix=torch.tensor([[3.4, 3.6], [3.6, 3.7]]), + epsilon_matrix=torch.tensor([[0.01, 0.012], [0.012, 0.014]]), ) - forces = results["forces"] # shape: [n_particles, 3] - ``` + model = PairPotentialModel(pair_fn=fn, cutoff=float(fn.sigma_matrix.max())) """ def __init__( self, - sigma: float | torch.Tensor = 1.0, - epsilon: float | torch.Tensor = 1.0, - alpha: float | torch.Tensor = 2.0, - device: torch.device | None = None, - dtype: torch.dtype = torch.float32, - *, # Force keyword-only arguments - compute_forces: bool = True, - compute_stress: bool = False, - per_atom_energies: bool = False, - per_atom_stresses: bool = False, - use_neighbor_list: bool = True, - cutoff: float | torch.Tensor | None = None, + atomic_numbers: torch.Tensor, + sigma_matrix: torch.Tensor, + epsilon_matrix: torch.Tensor, + alpha_matrix: torch.Tensor | None = None, ) -> None: - """Initialize the soft sphere model. - - Creates a soft sphere model with specified parameters for particle interactions - and computation options. + """Initialize species-dependent soft-sphere parameters. Args: - sigma (float): Effective particle diameter. Determines the distance - scale of the interaction. Defaults to 1.0. - epsilon (float): Energy scale parameter. Controls the strength of - the repulsion. Defaults to 1.0. - alpha (float): Exponent controlling repulsion steepness. Higher values - create steeper, more hard-sphere-like repulsion. Defaults to 2.0. - device (torch.device | None): Device for computations. If None, uses CPU. - Defaults to None. - dtype (torch.dtype): Data type for calculations. Defaults to torch.float32. - compute_forces (bool): Whether to compute forces. Defaults to True. - compute_stress (bool): Whether to compute stress tensor. Defaults to False. - per_atom_energies (bool): Whether to compute per-atom energy decomposition. - Defaults to False. - per_atom_stresses (bool): Whether to compute per-atom stress decomposition. - Defaults to False. - use_neighbor_list (bool): Whether to use a neighbor list for optimization. - Significantly faster for large systems. Defaults to True. - cutoff (float | None): Cutoff distance for interactions. If None, uses - the value of sigma. Defaults to None. - - Examples: - ```py - # Default model - model = SoftSphereModel() - - # WCA-like repulsive potential (derived from Lennard-Jones) - wca_model = SoftSphereModel( - sigma=1.0, - epsilon=1.0, - alpha=12.0, # Steep repulsion similar to r^-12 term in LJ - cutoff=2 ** (1 / 6), # WCA cutoff at minimum of LJ potential - ) - ``` + atomic_numbers: 1-D tensor of the unique atomic numbers present, used to + map ``zi``/``zj`` to row/column indices. Shape: [n_species]. + sigma_matrix: Symmetric matrix of interaction diameters. Shape: + [n_species, n_species]. + epsilon_matrix: Symmetric matrix of energy scales. Shape: + [n_species, n_species]. + alpha_matrix: Symmetric matrix of repulsion exponents. If None, defaults + to 2.0 for all pairs. Shape: [n_species, n_species]. """ super().__init__() - self._device = device or torch.device("cpu") - self._dtype = dtype - self._compute_forces = compute_forces - self._compute_stress = compute_stress - self.per_atom_energies = per_atom_energies - self.per_atom_stresses = per_atom_stresses - self.use_neighbor_list = use_neighbor_list - - # Convert interaction parameters to tensors with proper dtype/device - self.sigma = torch.as_tensor(sigma, dtype=dtype, device=self.device) - self.cutoff = torch.as_tensor(cutoff or sigma, dtype=dtype, device=self.device) - self.epsilon = torch.as_tensor(epsilon, dtype=dtype, device=self.device) - self.alpha = torch.as_tensor(alpha, dtype=dtype, device=self.device) - - def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]: - """Compute energies and forces for a single unbatched system. - - Internal implementation that processes a single, non-batched simulation state. - This method handles the core computations for pair interactions, including - neighbor list construction, distance calculations, and property computation. + n = len(atomic_numbers) + if sigma_matrix.shape != (n, n): + raise ValueError(f"sigma_matrix must have shape ({n}, {n})") + if epsilon_matrix.shape != (n, n): + raise ValueError(f"epsilon_matrix must have shape ({n}, {n})") + if alpha_matrix is not None and alpha_matrix.shape != (n, n): + raise ValueError(f"alpha_matrix must have shape ({n}, {n})") + + self.register_buffer("atomic_numbers", atomic_numbers) + self.sigma_matrix = sigma_matrix + self.epsilon_matrix = epsilon_matrix + self.alpha_matrix = ( + alpha_matrix if alpha_matrix is not None else torch.full((n, n), 2.0) + ) + max_z = int(atomic_numbers.max().item()) + 1 + z_to_idx = torch.full((max_z,), -1, dtype=torch.long) + for idx, z in enumerate(atomic_numbers.tolist()): + z_to_idx[int(z)] = idx + self.z_to_idx: torch.Tensor + self.register_buffer("z_to_idx", z_to_idx) + + def forward( + self, dr: torch.Tensor, zi: torch.Tensor, zj: torch.Tensor + ) -> torch.Tensor: + """Compute per-pair soft-sphere energies using species lookup. Args: - state (SimState): Single, non-batched simulation state containing atomic - positions, cell vectors, and other system information. + dr: Pairwise distances, shape [n_pairs]. + zi: Atomic numbers of first atoms, shape [n_pairs]. + zj: Atomic numbers of second atoms, shape [n_pairs]. Returns: - dict[str, torch.Tensor]: Computed properties: - - "energy": Total potential energy (scalar) - - "forces": Atomic forces with shape [n_atoms, 3] (if - compute_forces=True) - - "stress": Stress tensor with shape [3, 3] (if compute_stress=True) - - "energies": Per-atom energies with shape [n_atoms] (if - per_atom_energies=True) - - "stresses": Per-atom stresses with shape [n_atoms, 3, 3] (if - per_atom_stresses=True) - - Notes: - This method can work with both neighbor list and full pairwise calculations. - The soft sphere potential is purely repulsive, and forces are truncated at - the cutoff distance. + Pair energies, shape [n_pairs]. """ - positions = state.positions - cell = state.row_vector_cell - cell = cell.squeeze() - pbc = state.pbc - - # Ensure system_idx exists (create if None for single system) - system_idx = ( - state.system_idx - if state.system_idx is not None - else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device) - ) + idx_i = self.z_to_idx[zi] + idx_j = self.z_to_idx[zj] + sigma = self.sigma_matrix[idx_i, idx_j] + epsilon = self.epsilon_matrix[idx_i, idx_j] + alpha = self.alpha_matrix[idx_i, idx_j] + energy = epsilon / alpha * (1.0 - dr / sigma).pow(alpha) + return torch.where(dr < sigma, energy, torch.zeros_like(energy)) - # Wrap positions into the unit cell - wrapped_positions = ( - ts.transforms.pbc_wrap_batched(positions, state.cell, system_idx, pbc) - if pbc.any() - else positions - ) - if self.use_neighbor_list: - mapping, _, shifts_idx = torchsim_nl( - positions=wrapped_positions, - cell=cell, - pbc=pbc, - cutoff=self.cutoff, - system_idx=system_idx, - ) - # Pass shifts_idx directly - get_pair_displacements will convert them - dr_vec, distances = transforms.get_pair_displacements( - positions=wrapped_positions, - cell=cell, - pbc=pbc, - pairs=(mapping[0], mapping[1]), - shifts=shifts_idx, - ) - - else: - # Direct N^2 computation of all pairs - dr_vec, distances = transforms.get_pair_displacements( - positions=wrapped_positions, - cell=cell, - pbc=pbc, - ) - # Remove self-interactions and apply cutoff - mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device) - distances = distances.masked_fill(mask, float("inf")) - mask = distances < self.cutoff - - # Get valid pairs and their displacements - i, j = torch.where(mask) - mapping = torch.stack([j, i]) - dr_vec = dr_vec[mask] - distances = distances[mask] - - # Calculate pair energies using soft sphere potential - pair_energies = soft_sphere_pair( - distances, sigma=self.sigma, epsilon=self.epsilon, alpha=self.alpha +DEFAULT_SIGMA = torch.tensor(1.0) +DEFAULT_EPSILON = torch.tensor(1.0) +DEFAULT_ALPHA = torch.tensor(2.0) + + +class SoftSphereModel(PairPotentialModel): + """Soft-sphere repulsive pair potential model. + + Convenience subclass that fixes the pair function to :func:`soft_sphere_pair` + so the caller only needs to supply ``sigma``, ``epsilon``, and ``alpha``. + + Example:: + + model = SoftSphereModel( + sigma=3.405, + epsilon=0.0104, + alpha=2.0, + compute_forces=True, ) + results = model(sim_state) + """ - # Initialize results with total energy (divide by 2 to avoid double counting) - results = {"energy": 0.5 * pair_energies.sum()} - - if self.per_atom_energies: - # Compute per-atom energy contributions - atom_energies = torch.zeros( - positions.shape[0], dtype=self.dtype, device=self.device - ) - # Each atom gets half of the pair energy - atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies) - atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies) - results["energies"] = atom_energies - - if self.compute_forces or self.compute_stress: - # Calculate pair forces - pair_forces = soft_sphere_pair_force( - distances, sigma=self.sigma, epsilon=self.epsilon, alpha=self.alpha - ) - - # Project scalar forces onto displacement vectors - force_vectors = (pair_forces / distances)[:, None] * dr_vec - - if self.compute_forces: - # Compute atomic forces by accumulating pair contributions - forces = torch.zeros_like(positions) - forces.index_add_(0, mapping[0], -force_vectors) - forces.index_add_(0, mapping[1], force_vectors) - results["forces"] = forces - - if self.compute_stress and cell is not None: - # Compute stress tensor using virial formula - stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors) - volume = torch.abs(torch.linalg.det(cell)) - - results["stress"] = -stress_per_pair.sum(dim=0) / volume - - if self.per_atom_stresses: - # Compute per-atom stress contributions - atom_stresses = torch.zeros( - (positions.shape[0], 3, 3), dtype=self.dtype, device=self.device - ) - atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair) - atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair) - results["stresses"] = atom_stresses / volume - - return results - - def forward(self, state: ts.SimState, **_kwargs) -> dict[str, torch.Tensor]: - """Compute soft sphere potential energies, forces, and stresses for a system. - - Main entry point for soft sphere potential calculations that handles batched - states by dispatching each system to the unbatched implementation and combining - results. + def __init__( + self, + sigma: float = 1.0, + epsilon: float = 1.0, + alpha: float = 2.0, + device: torch.device | None = None, + dtype: torch.dtype = torch.float64, + *, + compute_forces: bool = True, + compute_stress: bool = False, + per_atom_energies: bool = False, + per_atom_stresses: bool = False, + neighbor_list_fn: Callable = torchsim_nl, + use_neighbor_list: bool = True, # noqa: ARG002 + cutoff: float | None = None, + retain_graph: bool = False, + ) -> None: + """Initialize the soft sphere model. Args: - state (SimState): Input state containing atomic positions, cell vectors, - and other system information. - **_kwargs: Unused; accepted for interface compatibility. - - Returns: - dict[str, torch.Tensor]: Computed properties: - - "energy": Potential energy with shape [n_systems] - - "forces": Atomic forces with shape [n_atoms, 3] - (if compute_forces=True) - - "stress": Stress tensor with shape [n_systems, 3, 3] - (if compute_stress=True) - - May include additional outputs based on configuration - - Raises: - ValueError: If system indices cannot be inferred for multi-cell systems. - - Examples: - ```py - # Compute properties for a simulation state - model = SoftSphereModel(compute_forces=True) - results = model(sim_state) - - energy = results["energy"] # Shape: [n_systems] - forces = results["forces"] # Shape: [n_atoms, 3] - ``` + sigma: Effective particle diameter. Defaults to 1.0. + epsilon: Energy scale parameter. Defaults to 1.0. + alpha: Repulsion exponent. Defaults to 2.0. + device: Device for computations. Defaults to CPU. + dtype: Floating-point dtype. Defaults to torch.float32. + compute_forces: Whether to compute atomic forces. Defaults to True. + compute_stress: Whether to compute the stress tensor. Defaults to False. + per_atom_energies: Whether to return per-atom energies. Defaults to False. + per_atom_stresses: Whether to return per-atom stresses. Defaults to False. + neighbor_list_fn: Neighbor-list constructor. Defaults to torchsim_nl. + use_neighbor_list: Accepted for backward compatibility (ignored). + cutoff: Interaction cutoff. Defaults to sigma. + retain_graph: Keep computation graph for differentiable simulation. """ - # Handle System indices if not provided - if state.system_idx is None and state.cell.shape[0] > 1: - raise ValueError( - "system_idx can only be inferred if there is only one system" - ) - - outputs = [self.unbatched_forward(state[i]) for i in range(state.n_systems)] - properties = outputs[0] - - # Combine results - results: dict[str, torch.Tensor] = {} - for key in ("stress", "energy"): - if key in properties: - results[key] = torch.stack([out[key] for out in outputs]) - for key in ("forces", "energies", "stresses"): - if key in properties: - results[key] = torch.cat([out[key] for out in outputs], dim=0) - - return results - - -class SoftSphereMultiModel(ModelInterface): - """Calculator for systems with multiple particle types. - - Extends the basic soft sphere model to support multiple particle types with - different interaction parameters for each pair of particle types. This enables - simulation of heterogeneous systems like mixtures, composites, or biomolecular - systems with different interaction strengths between different components. - - This model maintains matrices of interaction parameters (sigma, epsilon, alpha) - where each element [i,j] represents the parameter for interactions between - particle types i and j. - - Attributes: - species (torch.Tensor): Particle type indices for each particle in the system. - sigma_matrix (torch.Tensor): Matrix of distance parameters for each pair of types. - Shape: [n_types, n_types]. - epsilon_matrix (torch.Tensor): Matrix of energy scale parameters for each pair. - Shape: [n_types, n_types]. - alpha_matrix (torch.Tensor): Matrix of exponents for each pair of types. - Shape: [n_types, n_types]. - cutoff (torch.Tensor): Maximum interaction distance. - compute_forces (bool): Whether to compute forces. - compute_stress (bool): Whether to compute stress tensor. - per_atom_energies (bool): Whether to compute per-atom energy decomposition. - per_atom_stresses (bool): Whether to compute per-atom stress decomposition. - use_neighbor_list (bool): Whether to use neighbor list optimization. - periodic (bool): Whether to use periodic boundary conditions. - _device (torch.device): Computation device (CPU/GPU). - _dtype (torch.dtype): Data type for tensor calculations. - - Examples: - ```py - # Create a binary mixture with different interaction parameters - # Define interaction matrices (size 2x2 for binary system) - sigma_matrix = torch.tensor( - [ - [1.0, 0.8], # Type 0-0 and 0-1 interactions - [0.8, 0.6], # Type 1-0 and 1-1 interactions - ] - ) + self.sigma = sigma + self.epsilon = epsilon + self.alpha = alpha - epsilon_matrix = torch.tensor( - [ - [1.0, 0.5], # Type 0-0 and 0-1 interactions - [0.5, 2.0], # Type 1-0 and 1-1 interactions - ] + pair_fn = functools.partial( + soft_sphere_pair, sigma=sigma, epsilon=epsilon, alpha=alpha ) + super().__init__( + pair_fn=pair_fn, + cutoff=cutoff if cutoff is not None else sigma, + device=device, + dtype=dtype, + compute_forces=compute_forces, + compute_stress=compute_stress, + per_atom_energies=per_atom_energies, + per_atom_stresses=per_atom_stresses, + neighbor_list_fn=neighbor_list_fn, + reduce_to_half_list=True, + retain_graph=retain_graph, + ) + - # Particle type assignments (0 or 1 for each particle) - species = torch.tensor([0, 0, 1, 1, 0, 1]) +class SoftSphereMultiModel(PairPotentialModel): + """Multi-species soft-sphere potential model. + + Uses :class:`MultiSoftSpherePairFn` internally + to look up per-species-pair parameters from matrices. + + Example:: - # Create the model model = SoftSphereMultiModel( - species=species, - sigma_matrix=sigma_matrix, - epsilon_matrix=epsilon_matrix, + atomic_numbers=torch.tensor([18, 36]), + sigma_matrix=torch.tensor([[1.0, 0.8], [0.8, 0.6]]), + epsilon_matrix=torch.tensor([[1.0, 0.5], [0.5, 2.0]]), compute_forces=True, ) - - # Compute properties - results = model(simulation_state) - ``` + results = model(sim_state) """ def __init__( self, - n_species: int, + atomic_numbers: torch.Tensor, sigma_matrix: torch.Tensor | None = None, epsilon_matrix: torch.Tensor | None = None, alpha_matrix: torch.Tensor | None = None, device: torch.device | None = None, dtype: torch.dtype = torch.float64, - *, # Force keyword-only arguments + *, pbc: torch.Tensor | bool = True, compute_forces: bool = True, compute_stress: bool = False, per_atom_energies: bool = False, per_atom_stresses: bool = False, use_neighbor_list: bool = True, + neighbor_list_fn: Callable = torchsim_nl, cutoff: float | None = None, + retain_graph: bool = False, ) -> None: - """Initialize a soft sphere model for multi-component systems. - - Creates a model for systems with multiple particle types, each with potentially - different interaction parameters. + """Initialize the multi-species soft sphere model. Args: - n_species (int): Number of particle types. - sigma_matrix (torch.Tensor | None): Matrix of distance parameters for - each pair of types. Shape [n_types, n_types]. If None, uses default - value 1.0 for all pairs. Defaults to None. - epsilon_matrix (torch.Tensor | None): Matrix of energy scale parameters - for each pair of types. Shape [n_types, n_types]. If None, uses - default value 1.0 for all pairs. Defaults to None. - alpha_matrix (torch.Tensor | None): Matrix of exponents for each pair. - Shape [n_types, n_types]. If None, uses default value 2.0 for all - pairs. Defaults to None. - device (torch.device | None): Device for computations. If None, uses CPU. - Defaults to None. - dtype (torch.dtype): Data type for calculations. Defaults to torch.float32. - pbc (torch.Tensor | bool): Boolean tensor of shape (3,) indicating periodic - boundary conditions in each axis. If None, all axes are assumed to be - periodic. Defaults to True. - compute_forces (bool): Whether to compute forces. Defaults to True. - compute_stress (bool): Whether to compute stress tensor. Defaults to False. - per_atom_energies (bool): Whether to compute per-atom energy decomposition. - Defaults to False. - per_atom_stresses (bool): Whether to compute per-atom stress decomposition. - Defaults to False. - use_neighbor_list (bool): Whether to use a neighbor list for optimization. - Defaults to True. - cutoff (float | None): Cutoff distance for interactions. If None, uses - the maximum value from sigma_matrix. Defaults to None. - - Examples: - ```py - # Binary polymer mixture with different interactions - # Polymer A (type 0): larger, softer particles - # Polymer B (type 1): smaller, harder particles - - # Create species assignment (100 particles total) - species = torch.cat( - [ - torch.zeros(50, dtype=torch.long), # 50 particles of type 0 - torch.ones(50, dtype=torch.long), # 50 particles of type 1 - ] - ) - - # Interaction matrices - sigma = torch.tensor( - [ - [1.2, 1.0], # A-A and A-B interactions - [1.0, 0.8], # B-A and B-B interactions - ] - ) - - epsilon = torch.tensor( - [ - [1.0, 1.5], # A-A and A-B interactions - [1.5, 2.0], # B-A and B-B interactions - ] - ) - - # Create model with mixing rules - model = SoftSphereMultiModel( - species=species, - sigma_matrix=sigma, - epsilon_matrix=epsilon, - compute_forces=True, - ) - ``` - - Notes: - The interaction matrices must be symmetric for physical consistency - (e.g., interaction of type 0 with type 1 should be the same as type 1 - with type 0). + atomic_numbers: Atomic numbers of atoms in the system. May contain + duplicates; only the sorted unique values are used to define + species and determine matrix dimensions. + sigma_matrix: Symmetric matrix of interaction diameters. + Shape [n_species, n_species]. Defaults to 1.0 for all pairs. + epsilon_matrix: Symmetric matrix of energy scales. + Shape [n_species, n_species]. Defaults to 1.0 for all pairs. + alpha_matrix: Symmetric matrix of repulsion exponents. + Shape [n_species, n_species]. Defaults to 2.0 for all pairs. + device: Device for computations. Defaults to CPU. + dtype: Floating-point dtype. Defaults to torch.float64. + pbc: Periodic boundary conditions (kept for backward compat). Defaults + to True. + compute_forces: Whether to compute atomic forces. Defaults to True. + compute_stress: Whether to compute the stress tensor. Defaults to False. + per_atom_energies: Whether to return per-atom energies. Defaults to False. + per_atom_stresses: Whether to return per-atom stresses. Defaults to False. + use_neighbor_list: Accepted for backward compatibility (a neighbor list + is always used internally). Defaults to True. + neighbor_list_fn: Neighbor-list constructor. Defaults to torchsim_nl. + cutoff: Interaction cutoff. Defaults to max of sigma_matrix. + retain_graph: Keep computation graph for differentiable simulation. """ - super().__init__() - self._device = device or torch.device("cpu") - self._dtype = dtype self.pbc = torch.tensor([pbc] * 3) if isinstance(pbc, bool) else pbc - self._compute_forces = compute_forces - self._compute_stress = compute_stress - self.per_atom_energies = per_atom_energies - self.per_atom_stresses = per_atom_stresses self.use_neighbor_list = use_neighbor_list + unique_z = torch.unique(atomic_numbers).sort().values.long() + n_species = len(unique_z) self.n_species = n_species - # Initialize parameter matrices with defaults if not provided - default_sigma = DEFAULT_SIGMA.to(device=self.device, dtype=self.dtype) - default_epsilon = DEFAULT_EPSILON.to(device=self.device, dtype=self.dtype) - default_alpha = DEFAULT_ALPHA.to(device=self.device, dtype=self.dtype) + _device = device or torch.device("cpu") + default_sigma = DEFAULT_SIGMA.to(device=_device, dtype=dtype) + default_epsilon = DEFAULT_EPSILON.to(device=_device, dtype=dtype) + default_alpha = DEFAULT_ALPHA.to(device=_device, dtype=dtype) - # Validate matrix shapes match number of species if sigma_matrix is not None and sigma_matrix.shape != (n_species, n_species): raise ValueError(f"sigma_matrix must have shape ({n_species}, {n_species})") if epsilon_matrix is not None and epsilon_matrix.shape != ( @@ -621,246 +329,49 @@ def __init__( if alpha_matrix is not None and alpha_matrix.shape != (n_species, n_species): raise ValueError(f"alpha_matrix must have shape ({n_species}, {n_species})") - # Create parameter matrices, using defaults if not provided self.sigma_matrix = ( sigma_matrix if sigma_matrix is not None else default_sigma - * torch.ones((n_species, n_species), dtype=dtype, device=device) + * torch.ones((n_species, n_species), dtype=dtype, device=_device) ) self.epsilon_matrix = ( epsilon_matrix if epsilon_matrix is not None else default_epsilon - * torch.ones((n_species, n_species), dtype=dtype, device=device) + * torch.ones((n_species, n_species), dtype=dtype, device=_device) ) self.alpha_matrix = ( alpha_matrix if alpha_matrix is not None else default_alpha - * torch.ones((n_species, n_species), dtype=dtype, device=device) + * torch.ones((n_species, n_species), dtype=dtype, device=_device) ) - # Ensure parameter matrices are symmetric (required for energy conservation) for matrix_name in ("sigma_matrix", "epsilon_matrix", "alpha_matrix"): matrix = getattr(self, matrix_name) if not torch.allclose(matrix, matrix.T): raise ValueError(f"{matrix_name} is not symmetric") - # Set interaction cutoff distance - self.cutoff = torch.tensor( - cutoff or float(self.sigma_matrix.max()), dtype=dtype, device=device - ) - - def unbatched_forward( - self, - state: ts.SimState, - ) -> dict[str, torch.Tensor]: - """Compute energies and forces for a single unbatched system with multiple - species. - - Internal implementation that processes a single, non-batched simulation state. - This method handles all pair interactions between particles of different types - using the appropriate interaction parameters from the parameter matrices. + _cutoff = cutoff or float(self.sigma_matrix.detach().max()) - Args: - state (SimState): Single, non-batched simulation state containing atomic - positions, cell vectors, and other system information. Species indices - are read from ``state.atomic_numbers``. - - Returns: - dict[str, torch.Tensor]: Computed properties: - - "energy": Total potential energy (scalar) - - "forces": Atomic forces with shape [n_atoms, 3] - (if compute_forces=True) - - "stress": Stress tensor with shape [3, 3] - (if compute_stress=True) - - "energies": Per-atom energies with shape [n_atoms] - (if per_atom_energies=True) - - "stresses": Per-atom stresses with shape [n_atoms, 3, 3] - (if per_atom_stresses=True) - - Notes: - This method supports both neighbor list optimization and full pairwise - calculations based on the use_neighbor_list parameter. For each pair of - particles, it looks up the appropriate parameters based on the species - of the two particles. - """ - species_idx = state.atomic_numbers.to(device=self.device, dtype=torch.long) - - positions = state.positions - cell = state.row_vector_cell - cell = cell.squeeze() - - # Compute neighbor list or full distance matrix - if self.use_neighbor_list: - # Get neighbor list for efficient computation - # Ensure system_idx exists (create if None for single system) - system_idx = torch.zeros( - positions.shape[0], dtype=torch.long, device=self.device - ) - mapping, _, shifts_idx = torchsim_nl( - positions=positions, - cell=cell, - pbc=self.pbc, - cutoff=self.cutoff, - system_idx=system_idx, - ) - # Pass shifts_idx directly - get_pair_displacements will convert them - dr_vec, distances = transforms.get_pair_displacements( - positions=positions, - cell=cell, - pbc=self.pbc, - pairs=(mapping[0], mapping[1]), - shifts=shifts_idx, - ) - - else: - # Direct N^2 computation of all pairs - dr_vec, distances = transforms.get_pair_displacements( - positions=positions, - cell=cell, - pbc=self.pbc, - ) - # Remove self-interactions and apply cutoff - mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device) - distances = distances.masked_fill(mask, float("inf")) - mask = distances < self.cutoff - - # Get valid pairs and their displacements - i, j = torch.where(mask) - mapping = torch.stack([j, i]) - dr_vec = dr_vec[mask] - distances = distances[mask] - - # Look up species-specific parameters for each interacting pair - pair_species_1 = species_idx[mapping[0]] # Species of first atom in pair - pair_species_2 = species_idx[mapping[1]] # Species of second atom in pair - - # Get interaction parameters from parameter matrices - pair_sigmas = self.sigma_matrix[pair_species_1, pair_species_2] - pair_epsilons = self.epsilon_matrix[pair_species_1, pair_species_2] - pair_alphas = self.alpha_matrix[pair_species_1, pair_species_2] - - # Calculate pair energies using species-specific parameters - pair_energies = soft_sphere_pair( - distances, sigma=pair_sigmas, epsilon=pair_epsilons, alpha=pair_alphas + pair_fn = MultiSoftSpherePairFn( + atomic_numbers=unique_z.to(device=_device), + sigma_matrix=self.sigma_matrix, + epsilon_matrix=self.epsilon_matrix, + alpha_matrix=self.alpha_matrix, ) - # Initialize results with total energy (divide by 2 to avoid double counting) - results = {"energy": 0.5 * pair_energies.sum()} - - if self.per_atom_energies: - # Compute per-atom energy contributions - atom_energies = torch.zeros( - positions.shape[0], dtype=self.dtype, device=self.device - ) - # Each atom gets half of the pair energy - atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies) - atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies) - results["energies"] = atom_energies - - if self.compute_forces or self.compute_stress: - # Calculate pair forces - pair_forces = soft_sphere_pair_force( - distances, sigma=pair_sigmas, epsilon=pair_epsilons, alpha=pair_alphas - ) - - # Project scalar forces onto displacement vectors - force_vectors = (pair_forces / distances)[:, None] * dr_vec - - if self.compute_forces: - # Compute atomic forces by accumulating pair contributions - forces = torch.zeros_like(positions) - forces.index_add_(0, mapping[0], -force_vectors) - forces.index_add_(0, mapping[1], force_vectors) - results["forces"] = forces - - if self.compute_stress and cell is not None: - # Compute stress tensor using virial formula - stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors) - volume = torch.abs(torch.linalg.det(cell)) - - results["stress"] = -stress_per_pair.sum(dim=0) / volume - - if self.per_atom_stresses: - # Compute per-atom stress contributions - atom_stresses = torch.zeros( - (positions.shape[0], 3, 3), dtype=self.dtype, device=self.device - ) - atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair) - atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair) - results["stresses"] = atom_stresses / volume - - return results - - def forward(self, state: ts.SimState, **_kwargs) -> dict[str, torch.Tensor]: - """Compute soft sphere potential properties for multi-component systems. - - Main entry point for multi-species soft sphere calculations that handles - batched states by dispatching each batch to the unbatched implementation - and combining results. - - Args: - state (SimState): Input state containing atomic positions, cell vectors, - and other system information. - **_kwargs: Unused; accepted for interface compatibility. - - Returns: - dict[str, torch.Tensor]: Computed properties: - - "energy": Potential energy with shape [n_systems] - - "forces": Atomic forces with shape [n_atoms, 3] - (if compute_forces=True) - - "stress": Stress tensor with shape [n_systems, 3, 3] - (if compute_stress=True) - - May include additional outputs based on configuration - - Raises: - ValueError: If batch cannot be inferred for multi-cell systems or if - species information is missing. - - Examples: - ```py - # Create model for binary mixture - model = SoftSphereMultiModel( - species=particle_types, - sigma_matrix=distance_matrix, - epsilon_matrix=strength_matrix, - compute_forces=True, - ) - - # Calculate properties - results = model(simulation_state) - energy = results["energy"] - forces = results["forces"] - ``` - - Notes: - This method requires species information either provided during initialization - or included in the state object's metadata. - """ - if state.pbc != self.pbc: - raise ValueError("PBC mismatch between model and state") - - # Handle system indices if not provided - if state.system_idx is None and state.cell.shape[0] > 1: - raise ValueError( - "system_idx can only be inferred if there is only one system" - ) - - outputs = [ - self.unbatched_forward(state[sys_idx]) for sys_idx in range(state.n_systems) - ] - properties = outputs[0] - - # Combine results - results: dict[str, torch.Tensor] = {} - for key in ("stress", "energy", "forces", "energies", "stresses"): - if key in properties: - results[key] = torch.stack([out[key] for out in outputs]) - - for key in ("forces", "energies", "stresses"): - if key in properties: - results[key] = torch.cat([out[key] for out in outputs], dim=0) - - return results + super().__init__( + pair_fn=pair_fn, + cutoff=_cutoff, + device=device, + dtype=dtype, + compute_forces=compute_forces, + compute_stress=compute_stress, + per_atom_energies=per_atom_energies, + per_atom_stresses=per_atom_stresses, + neighbor_list_fn=neighbor_list_fn, + reduce_to_half_list=True, + retain_graph=retain_graph, + ) diff --git a/torch_sim/workflows/a2c.py b/torch_sim/workflows/a2c.py index 1ee49088..81d22078 100644 --- a/torch_sim/workflows/a2c.py +++ b/torch_sim/workflows/a2c.py @@ -282,7 +282,6 @@ def random_packed_structure( device=device, dtype=dtype, compute_forces=True, - use_neighbor_list=True, ) # Dummy atomic numbers @@ -402,15 +401,12 @@ def random_packed_structure_multi( # Convert fractional to cartesian coordinates positions_cart = torch.matmul(positions, cell) - # Initialize multi-species soft sphere potential calculator - n_species = len(element_counts) model = SoftSphereMultiModel( - n_species=n_species, + atomic_numbers=species_idx, sigma_matrix=diameter_matrix, device=device, dtype=dtype, compute_forces=True, - use_neighbor_list=True, ) state_dict = ts.SimState(