diff --git a/.gitignore b/.gitignore
index 8fe23290..b0bdd821 100644
--- a/.gitignore
+++ b/.gitignore
@@ -46,3 +46,6 @@ uv.lock
# agents
.agents/
+
+# treemap html file
+docs/_static/torch-sim-pkg-treemap.html
diff --git a/docs/_static/torch-sim-pkg-treemap.svg b/docs/_static/torch-sim-pkg-treemap.svg
index 81070aee..c7344fa9 100644
--- a/docs/_static/torch-sim-pkg-treemap.svg
+++ b/docs/_static/torch-sim-pkg-treemap.svg
@@ -1 +1 @@
-
+
diff --git a/examples/tutorials/diff_sim.py b/examples/tutorials/diff_sim.py
index 1092c275..cf96c6e6 100644
--- a/examples/tutorials/diff_sim.py
+++ b/examples/tutorials/diff_sim.py
@@ -9,17 +9,19 @@
#
# %%
+import typing
import torch
import matplotlib.pyplot as plt
+from torch_sim.state import SimState
from torch_sim.models.soft_sphere import (
+ SoftSphereMultiModel,
soft_sphere_pair,
- DEFAULT_SIGMA,
- DEFAULT_EPSILON,
- DEFAULT_ALPHA,
)
-from torch_sim import transforms
-from collections.abc import Callable
-from dataclasses import dataclass
+from torch_sim.neighbors import torch_nl_n2
+from torch_sim.optimizers.gradient_descent import (
+ gradient_descent_init,
+ gradient_descent_step,
+)
from torch._functorch import config
config.donated_buffer = False
@@ -99,244 +101,16 @@ def draw_system(
plt.xlim([0, 1.5])
plt.ylim([-0.2, 0.8])
-# model = SoftSphereMultiModel(sigma_matrix=torch.tensor([1.0]))
dr = torch.linspace(0, 3.0, 80)
-plt.plot(dr, soft_sphere_pair(dr, sigma=1), "b-", linewidth=3)
-plt.fill_between(dr, soft_sphere_pair(dr), alpha=0.4)
+_z = torch.zeros_like(dr, dtype=torch.long)
+plt.plot(dr, soft_sphere_pair(dr, _z, _z, sigma=1), "b-", linewidth=3)
+plt.fill_between(dr, soft_sphere_pair(dr, _z, _z), alpha=0.4)
plt.xlabel(r"$r$", fontsize=20)
plt.ylabel(r"$U(r)$", fontsize=20)
plt.show()
-# %% [markdown]
-"""
-## Define the simple TorchSim model for the soft sphere potential.
-"""
-
-
-# %%
-@dataclass
-class BaseState:
- """Simple simulation state"""
-
- positions: torch.Tensor
- cell: torch.Tensor
- pbc: torch.Tensor
- species: torch.Tensor
-
-
-class SoftSphereMultiModel(torch.nn.Module):
- """Soft sphere potential"""
-
- def __init__(
- self,
- species: torch.Tensor | None = None,
- sigma_matrix: torch.Tensor | None = None,
- epsilon_matrix: torch.Tensor | None = None,
- alpha_matrix: torch.Tensor | None = None,
- device: torch.device | None = None,
- dtype: torch.dtype = torch.float32,
- *, # Force keyword-only arguments
- pbc: torch.Tensor | bool = True,
- cutoff: float | None = None,
- ) -> None:
- """Initialize a soft sphere model for multi-component systems."""
- super().__init__()
- self.device = device or torch.device("cpu")
- self.dtype = dtype
- self.pbc = (
- pbc
- if isinstance(pbc, torch.Tensor)
- else torch.tensor([pbc] * 3, dtype=torch.bool)
- )
-
- # Store species list and determine number of unique species
- self.species = species
- n_species = len(torch.unique(species))
-
- # Initialize parameter matrices with defaults if not provided
- default_sigma = DEFAULT_SIGMA.to(device=self.device, dtype=self.dtype)
- default_epsilon = DEFAULT_EPSILON.to(device=self.device, dtype=self.dtype)
- default_alpha = DEFAULT_ALPHA.to(device=self.device, dtype=self.dtype)
-
- # Validate matrix shapes match number of species
- if sigma_matrix is not None and sigma_matrix.shape != (n_species, n_species):
- raise ValueError(f"sigma_matrix must have shape ({n_species}, {n_species})")
- if epsilon_matrix is not None and epsilon_matrix.shape != (
- n_species,
- n_species,
- ):
- raise ValueError(f"epsilon_matrix must have shape ({n_species}, {n_species})")
- if alpha_matrix is not None and alpha_matrix.shape != (n_species, n_species):
- raise ValueError(f"alpha_matrix must have shape ({n_species}, {n_species})")
-
- # Create parameter matrices, using defaults if not provided
- self.sigma_matrix = (
- sigma_matrix
- if sigma_matrix is not None
- else default_sigma
- * torch.ones((n_species, n_species), dtype=dtype, device=device)
- )
- self.epsilon_matrix = (
- epsilon_matrix
- if epsilon_matrix is not None
- else default_epsilon
- * torch.ones((n_species, n_species), dtype=dtype, device=device)
- )
- self.alpha_matrix = (
- alpha_matrix
- if alpha_matrix is not None
- else default_alpha
- * torch.ones((n_species, n_species), dtype=dtype, device=device)
- )
-
- # Ensure parameter matrices are symmetric (required for energy conservation)
- for matrix_name in ("sigma_matrix", "epsilon_matrix", "alpha_matrix"):
- matrix = getattr(self, matrix_name)
- if not torch.allclose(matrix, matrix.T):
- raise ValueError(f"{matrix_name} is not symmetric")
-
- # Set interaction cutoff distance
- self.cutoff = torch.tensor(
- cutoff or float(self.sigma_matrix.max()), dtype=dtype, device=device
- )
-
- def forward(
- self,
- custom_state: BaseState,
- species: torch.Tensor | None = None,
- ) -> dict[str, torch.Tensor]:
- """Compute energies and forces for a single unbatched system with multiple
- species."""
- # Convert inputs to proper device/dtype and handle species
- positions = custom_state.positions.requires_grad_(True)
- cell = custom_state.cell
- species = custom_state.species
-
- if species is not None:
- species = species.to(device=self.device, dtype=torch.long)
- else:
- species = self.species
-
- species_idx = species
-
- # Direct N^2 computation of all pairs (minimum image convention)
- dr_vec, distances = transforms.get_pair_displacements(
- positions=positions,
- cell=cell,
- pbc=self.pbc,
- )
- # Remove self-interactions and apply cutoff
- mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device)
- distances = distances.masked_fill(mask, float("inf"))
- mask = distances < self.cutoff
-
- # Get valid pairs and their displacements
- i, j = torch.where(mask)
- mapping = torch.stack([j, i])
- dr_vec = dr_vec[mask]
- distances = distances[mask]
-
- # Look up species-specific parameters for each interacting pair
- pair_species_1 = species_idx[mapping[0]] # Species of first atom in pair
- pair_species_2 = species_idx[mapping[1]] # Species of second atom in pair
-
- # Get interaction parameters from parameter matrices
- pair_sigmas = self.sigma_matrix[pair_species_1, pair_species_2]
- pair_epsilons = self.epsilon_matrix[pair_species_1, pair_species_2]
- pair_alphas = self.alpha_matrix[pair_species_1, pair_species_2]
-
- # Calculate pair energies using species-specific parameters
- pair_energies = soft_sphere_pair(
- distances, sigma=pair_sigmas, epsilon=pair_epsilons, alpha=pair_alphas
- )
-
- # Initialize results with total energy (divide by 2 to avoid double counting)
- potential_energy = pair_energies.sum() / 2
-
- grad_outputs: list[torch.Tensor] = [
- torch.ones_like(
- potential_energy,
- )
- ]
- grad = torch.autograd.grad(
- outputs=[
- potential_energy,
- ],
- inputs=[positions],
- grad_outputs=grad_outputs,
- create_graph=False,
- retain_graph=True,
- )
-
- force_grad = grad[0]
- if force_grad is not None:
- forces = torch.neg(force_grad)
-
- return {"energy": potential_energy, "forces": forces}
-
-
-# %% [markdown]
-"""
-## Gradient Descent
-
-We will use a simple gradient descent to optimize the positions of the particles.
-"""
-
-
-# %%
-@dataclass
-class GDState(BaseState):
- """Simple simulation state"""
-
- forces: torch.Tensor
- energy: torch.Tensor
-
-
-def gradient_descent(
- model: torch.nn.Module, *, lr: torch.Tensor | float = 0.01
-) -> tuple[Callable[[BaseState], GDState], Callable[[GDState], GDState]]:
- """Initialize a gradient descent optimization."""
-
- def gd_init(
- state: BaseState,
- ) -> GDState:
- """Initialize the gradient descent optimization state."""
-
- # Get initial forces and energy from model
- model_output = model(state)
- energy = model_output["energy"]
- forces = model_output["forces"]
-
- return GDState(
- positions=state.positions,
- forces=forces,
- energy=energy,
- cell=state.cell,
- pbc=state.pbc,
- species=state.species,
- )
-
- def gd_step(state: GDState, lr: torch.Tensor | float = lr) -> GDState:
- """Perform one gradient descent optimization step to update the
- atomic positions. The cell is not optimized."""
-
- # Update positions using forces and per-atom learning rates
- state.positions = state.positions + lr * state.forces
-
- # Get updated forces and energy from model
- model_output = model(state)
-
- # Update state with new forces and energy
- state.forces = model_output["forces"]
- state.energy = model_output["energy"]
-
- return state
-
- return gd_init, gd_step
-
-
# %% [markdown]
"""
## Setup the simulation environment.
@@ -358,15 +132,14 @@ def box_size_at_packing_fraction(
def species_sigma(diameter: torch.Tensor) -> torch.Tensor:
- d_AA = diameter
- d_BB = 1
+ d_BB = torch.ones_like(diameter)
d_AB = 0.5 * (diameter + 1)
- return torch.tensor([[d_AA, d_AB], [d_AB, d_BB]])
+ return torch.stack([diameter, d_AB, d_AB, d_BB]).reshape(2, 2)
N = 128
N_2 = N // 2
-species = torch.tensor([0] * (N_2) + [1] * (N_2), dtype=torch.int32)
+species = torch.tensor([0] * (N_2) + [1] * (N_2), dtype=torch.long)
simulation_steps = 1000
packing_fraction = 0.98
markersize = 260
@@ -376,30 +149,30 @@ def species_sigma(diameter: torch.Tensor) -> torch.Tensor:
def simulation(
diameter: torch.Tensor, seed: int = 42
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- # Create the simulation environment.
box_size = box_size_at_packing_fraction(diameter, packing_fraction)
- cell = torch.eye(3) * box_size
- # Create the energy function.
+ cell = (torch.eye(3) * box_size).unsqueeze(0)
sigma = species_sigma(diameter)
- model = SoftSphereMultiModel(sigma_matrix=sigma, species=species)
- model = torch.compile(model)
- # Randomly initialize the system.
- # Fix seed for reproducible random positions
+ model = SoftSphereMultiModel(
+ atomic_numbers=species,
+ sigma_matrix=sigma,
+ dtype=torch.float32,
+ neighbor_list_fn=torch_nl_n2,
+ retain_graph=True,
+ )
+ # Use aot_eager backend as Inductor has issues with scatter operations (index_add/scatter_add)
+ model = typing.cast(SoftSphereMultiModel, torch.compile(model, backend="aot_eager"))
torch.manual_seed(seed)
R = torch.rand(N, 3) * box_size
-
- # Minimize to the nearest minimum.
- init_fn, apply_fn = gradient_descent(model, lr=0.1) # ty: ignore[invalid-argument-type]
-
- custom_state = BaseState(
+ state = SimState(
positions=R,
+ masses=torch.ones(N),
cell=cell,
- species=species,
- pbc=torch.tensor([True] * 3, dtype=torch.bool),
+ pbc=True,
+ atomic_numbers=species,
)
- state = init_fn(custom_state)
+ state = gradient_descent_init(state, model)
for _ in range(simulation_steps):
- state = apply_fn(state)
+ state = gradient_descent_step(state, model, pos_lr=0.1)
return box_size, model(state)["energy"], state.positions
@@ -482,41 +255,28 @@ def short_simulation(
) -> tuple[torch.Tensor, torch.Tensor]:
diameter = diameter.requires_grad_(True)
box_size = box_size_at_packing_fraction(diameter, packing_fraction)
- cell = torch.eye(3) * box_size
- # Create the energy function.
+ cell = (torch.eye(3) * box_size).unsqueeze(0)
sigma = species_sigma(diameter)
- model = SoftSphereMultiModel(sigma_matrix=sigma, species=species)
-
- # Minimize to the nearest minimum.
- init_fn, apply_fn = gradient_descent(model, lr=0.1)
-
- custom_state = BaseState(
+ model = SoftSphereMultiModel(
+ atomic_numbers=species,
+ sigma_matrix=sigma,
+ dtype=torch.float32,
+ neighbor_list_fn=torch_nl_n2,
+ retain_graph=True,
+ )
+ state = SimState(
positions=R,
+ masses=torch.ones(N),
cell=cell,
- species=species,
- pbc=torch.tensor([True, True, True], dtype=torch.bool),
- )
- state = init_fn(custom_state)
- for i in range(short_simulation_steps):
- state = apply_fn(state)
-
- grad_outputs: list[torch.Tensor] = [
- torch.ones_like(
- diameter,
- )
- ]
- grad = torch.autograd.grad(
- outputs=[
- model(state)["energy"],
- ],
- inputs=[diameter],
- grad_outputs=grad_outputs,
- create_graph=True,
- retain_graph=False,
+ pbc=True,
+ atomic_numbers=species,
)
-
- dU_dd = grad[0]
- return model(state)["energy"], dU_dd
+ state = gradient_descent_init(state, model)
+ for _ in range(short_simulation_steps):
+ state = gradient_descent_step(state, model, pos_lr=0.1)
+ energy = model(state)["energy"]
+ (dU_dd,) = torch.autograd.grad(energy, diameter, create_graph=True)
+ return energy, dU_dd
# %%
diff --git a/pyproject.toml b/pyproject.toml
index 38e08515..4322706b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -111,6 +111,7 @@ ignore = [
"PLR2004", # Magic value used in comparison, consider replacing {value} with a constant variable
"PLW2901", # Outer for loop variable overwritten by inner assignment target
"PTH", # flake8-use-pathlib
+ "RUF002", # Greek letters are discouraged
"S301", # pickle and modules that wrap it can be unsafe, possible security issue
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
"SIM105", # Use contextlib.suppress instead of try-except-pass
diff --git a/tests/models/test_lennard_jones.py b/tests/models/test_lennard_jones.py
index f70817e4..a3e5ed5a 100644
--- a/tests/models/test_lennard_jones.py
+++ b/tests/models/test_lennard_jones.py
@@ -1,408 +1,137 @@
-"""Cheap integration tests ensuring different parts of TorchSim work together."""
+"""Tests for the Lennard-Jones pair functions and wrapped model."""
-from collections.abc import Callable
-
-import pytest
import torch
-from ase.build import bulk
import torch_sim as ts
-from tests.conftest import DEVICE
-from torch_sim import neighbors
-from torch_sim.models.interface import validate_model_outputs
from torch_sim.models.lennard_jones import (
LennardJonesModel,
- UnbatchedLennardJonesModel,
lennard_jones_pair,
lennard_jones_pair_force,
)
+def _dummy_z(n: int) -> torch.Tensor:
+ return torch.ones(n, dtype=torch.long)
+
+
def test_lennard_jones_pair_minimum() -> None:
- """Test that the potential has its minimum at r=sigma."""
- dr = torch.linspace(0.8, 1.2, 100)
- dr = dr.reshape(-1, 1)
- energy = lennard_jones_pair(dr, sigma=1.0, epsilon=1.0)
- min_idx = torch.argmin(energy)
+ """Minimum of LJ is at r = 2^(1/6) * sigma."""
+ dr = torch.linspace(0.9, 1.5, 500)
+ z = _dummy_z(len(dr))
+ energies = lennard_jones_pair(dr, z, z, sigma=1.0, epsilon=1.0)
+ min_r = dr[energies.argmin()]
+ assert abs(min_r.item() - 2 ** (1 / 6)) < 0.01
- torch.testing.assert_close(
- dr[min_idx], torch.tensor([2 ** (1 / 6)]), rtol=1e-2, atol=1e-2
- )
+
+def test_lennard_jones_pair_energy_at_minimum() -> None:
+ """Energy at minimum equals -epsilon."""
+ r_min = torch.tensor([2 ** (1 / 6)])
+ z = _dummy_z(1)
+ e = lennard_jones_pair(r_min, z, z, sigma=1.0, epsilon=2.0)
+ torch.testing.assert_close(e, torch.tensor([-2.0]), rtol=1e-5, atol=1e-5)
-def test_lennard_jones_pair_scaling() -> None:
- """Test that the potential scales correctly with epsilon."""
- dr = torch.ones(5, 5) * 1.5
- e1 = lennard_jones_pair(dr, sigma=1.0, epsilon=1.0)
- e2 = lennard_jones_pair(dr, sigma=1.0, epsilon=2.0)
- torch.testing.assert_close(e2, 2 * e1)
+def test_lennard_jones_pair_epsilon_scaling() -> None:
+ """Energy scales linearly with epsilon."""
+ dr = torch.tensor([1.5])
+ z = _dummy_z(1)
+ e1 = lennard_jones_pair(dr, z, z, sigma=1.0, epsilon=1.0)
+ e2 = lennard_jones_pair(dr, z, z, sigma=1.0, epsilon=3.0)
+ torch.testing.assert_close(e2, 3.0 * e1)
def test_lennard_jones_pair_repulsive_core() -> None:
- """Test that the potential is strongly repulsive at short distances."""
- dr_close = torch.tensor([[0.5]]) # Less than sigma
- dr_far = torch.tensor([[2.0]]) # Greater than sigma
- e_close = lennard_jones_pair(dr_close)
- e_far = lennard_jones_pair(dr_far)
+ """The potential is strongly repulsive at short distances."""
+ z = _dummy_z(1)
+ e_close = lennard_jones_pair(torch.tensor([0.5]), z, z)
+ e_far = lennard_jones_pair(torch.tensor([2.0]), z, z)
assert e_close > e_far
assert e_close > 0 # Repulsive
assert e_far < 0 # Attractive
-def test_lennard_jones_pair_tensor_params() -> None:
- """Test that the function works with tensor parameters."""
- dr = torch.ones(3, 3) * 1.5
- sigma = torch.ones(3, 3)
- epsilon = torch.ones(3, 3) * 2.0
- energy = lennard_jones_pair(dr, sigma=sigma, epsilon=epsilon)
- assert energy.shape == (3, 3)
-
-
def test_lennard_jones_pair_zero_distance() -> None:
- """Test that the function handles zero distances gracefully."""
- dr = torch.zeros(2, 2)
- energy = lennard_jones_pair(dr)
+ """The function handles zero distances gracefully."""
+ dr = torch.zeros(2)
+ z = _dummy_z(2)
+ energy = lennard_jones_pair(dr, z, z)
assert not torch.isnan(energy).any()
assert not torch.isinf(energy).any()
-def test_lennard_jones_pair_batch() -> None:
- """Test that the function works with batched inputs."""
- batch_size = 10
- n_particles = 5
- dr = torch.rand(batch_size, n_particles, n_particles) + 0.5
- energy = lennard_jones_pair(dr)
- assert energy.shape == (batch_size, n_particles, n_particles)
-
-
def test_lennard_jones_pair_force_scaling() -> None:
- """Test that the force scales correctly with epsilon."""
- dr = torch.ones(5, 5) * 1.5
+ """Force scales linearly with epsilon."""
+ dr = torch.tensor([1.5])
f1 = lennard_jones_pair_force(dr, sigma=1.0, epsilon=1.0)
f2 = lennard_jones_pair_force(dr, sigma=1.0, epsilon=2.0)
- assert torch.allclose(f2, 2 * f1)
+ torch.testing.assert_close(f2, 2.0 * f1)
def test_lennard_jones_pair_force_repulsive_core() -> None:
- """Test that the force is strongly repulsive at short distances."""
- dr_close = torch.tensor([[0.5]]) # Less than sigma
- dr_far = torch.tensor([[2.0]]) # Greater than sigma
- f_close = lennard_jones_pair_force(dr_close)
- f_far = lennard_jones_pair_force(dr_far)
+ """Force is repulsive at short distances and attractive at long distances."""
+ f_close = lennard_jones_pair_force(torch.tensor([0.5]))
+ f_far = lennard_jones_pair_force(torch.tensor([2.0]))
assert f_close > 0 # Repulsive
assert f_far < 0 # Attractive
- assert abs(f_close) > abs(f_far) # Stronger at short range
-
-
-def test_lennard_jones_pair_force_tensor_params() -> None:
- """Test that the function works with tensor parameters."""
- dr = torch.ones(3, 3) * 1.5
- sigma = torch.ones(3, 3)
- epsilon = torch.ones(3, 3) * 2.0
- force = lennard_jones_pair_force(dr, sigma=sigma, epsilon=epsilon)
- assert force.shape == (3, 3)
+ assert abs(f_close) > abs(f_far)
def test_lennard_jones_pair_force_zero_distance() -> None:
- """Test that the function handles zero distances gracefully."""
- dr = torch.zeros(2, 2)
+ """The force function handles zero distances gracefully."""
+ dr = torch.zeros(2)
force = lennard_jones_pair_force(dr)
assert not torch.isnan(force).any()
assert not torch.isinf(force).any()
-def test_lennard_jones_pair_force_batch() -> None:
- """Test that the function works with batched inputs."""
- batch_size = 10
- n_particles = 5
- dr = torch.rand(batch_size, n_particles, n_particles) + 0.5
- force = lennard_jones_pair_force(dr)
- assert force.shape == (batch_size, n_particles, n_particles)
-
-
def test_lennard_jones_force_energy_consistency() -> None:
- """Test that the force is consistent with the energy gradient."""
+ """Force is consistent with the energy gradient."""
dr = torch.linspace(0.8, 2.0, 100, requires_grad=True)
- dr = dr.reshape(-1, 1)
+ z = _dummy_z(len(dr))
- # Calculate force directly
force_direct = lennard_jones_pair_force(dr)
- # Calculate force from energy gradient
- energy = lennard_jones_pair(dr)
+ energy = lennard_jones_pair(dr, z, z)
force_from_grad = -torch.autograd.grad(energy.sum(), dr, create_graph=True)[0]
- # Compare forces (allowing for some numerical differences)
- assert torch.allclose(force_direct, force_from_grad, rtol=1e-4, atol=1e-4)
-
-
-# NOTE: This is a large system to robustly compare neighbor-list implementations.
-@pytest.fixture
-def ar_supercell_sim_state_large() -> ts.SimState:
- """Create a face-centered cubic (FCC) Argon structure."""
- # Create FCC Ar using ASE, with 4x4x4 supercell
- ar_atoms = bulk("Ar", "fcc", a=5.26, cubic=True).repeat([4, 4, 4])
- return ts.io.atoms_to_state(ar_atoms, DEVICE, torch.float64)
-
-
-@pytest.fixture
-def models(
- ar_supercell_sim_state_large: ts.SimState,
-) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
- """Create default and explicit N2 neighbor-list outputs with Argon parameters."""
- model_kwargs: dict[str, float | bool | torch.dtype] = {
- "sigma": 3.405, # Å, typical for Ar
- "epsilon": 0.0104, # eV, typical for Ar
- "dtype": torch.float64,
- "compute_forces": True,
- "compute_stress": True,
- "per_atom_energies": True,
- "per_atom_stresses": True,
- }
- cutoff = 2.5 * 3.405 # Standard LJ cutoff * sigma
- model_default = LennardJonesModel(cutoff=cutoff, **model_kwargs)
- model_n2 = LennardJonesModel(
- cutoff=cutoff,
- neighbor_list_fn=neighbors.torch_nl_n2,
- **model_kwargs,
- )
-
- return (
- model_default(ar_supercell_sim_state_large),
- model_n2(ar_supercell_sim_state_large),
- )
-
-
-@pytest.fixture
-def standard_vs_batched_models(
- ar_supercell_sim_state_large: ts.SimState,
-) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
- """Create unbatched and default Lennard-Jones outputs for parity checks."""
- model_kwargs: dict[str, float | bool | torch.dtype | torch.device] = {
- "sigma": 3.405,
- "epsilon": 0.0104,
- "dtype": torch.float64,
- "device": DEVICE,
- "compute_forces": True,
- "compute_stress": True,
- "per_atom_energies": True,
- "per_atom_stresses": True,
- }
- cutoff = 2.5 * 3.405
- standard = UnbatchedLennardJonesModel(cutoff=cutoff, **model_kwargs)
- batched = LennardJonesModel(cutoff=cutoff, **model_kwargs)
- return standard(ar_supercell_sim_state_large), batched(ar_supercell_sim_state_large)
-
-
-def test_energy_match(
- models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
-) -> None:
- """Test that total energy matches across neighbor-list implementations."""
- results_default, results_n2 = models
- assert torch.allclose(results_default["energy"], results_n2["energy"], rtol=1e-10)
-
-
-def test_per_atom_energy_match(
- models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
-) -> None:
- """Test that per-atom energy matches across neighbor-list implementations."""
- results_default, results_n2 = models
- assert torch.allclose(results_default["energies"], results_n2["energies"], rtol=1e-10)
-
-
-def test_forces_match(
- models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
-) -> None:
- """Test that forces match across neighbor-list implementations."""
- results_default, results_n2 = models
- assert torch.allclose(results_default["forces"], results_n2["forces"], rtol=1e-10)
-
-
-def test_stress_match(
- models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
-) -> None:
- """Test that stress tensors match across neighbor-list implementations."""
- results_default, results_n2 = models
- assert torch.allclose(results_default["stress"], results_n2["stress"], rtol=1e-10)
-
-
-def test_per_atom_stress_match(
- models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
-) -> None:
- """Test that per-atom stress tensors match across neighbor-list implementations."""
- results_default, results_n2 = models
- assert torch.allclose(results_default["stresses"], results_n2["stresses"], rtol=1e-10)
-
-
-@pytest.mark.parametrize(
- "key",
- ["energy", "energies", "forces", "stress", "stresses"],
-)
-def test_batched_lj_matches_standard(
- standard_vs_batched_models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
- key: str,
-) -> None:
- """Test that the default batched model matches the unbatched reference."""
- standard_out, batched_out = standard_vs_batched_models
- torch.testing.assert_close(
- batched_out[key], standard_out[key], rtol=1e-10, atol=1e-10
- )
-
-
-def test_batched_lj_multi_system_matches_standard(
- ar_double_sim_state: ts.SimState,
-) -> None:
- """Test default model multi-system parity with unbatched reference."""
- model_kwargs: dict[str, float | bool | torch.dtype | torch.device] = {
- "sigma": 3.405,
- "epsilon": 0.0104,
- "dtype": torch.float64,
- "device": DEVICE,
- "compute_forces": True,
- "compute_stress": True,
- }
- cutoff = 2.5 * 3.405
- standard = UnbatchedLennardJonesModel(cutoff=cutoff, **model_kwargs)
- batched = LennardJonesModel(cutoff=cutoff, **model_kwargs)
-
- standard_out = standard(ar_double_sim_state)
- batched_out = batched(ar_double_sim_state)
-
- assert batched_out["energy"].shape == (ar_double_sim_state.n_systems,)
- for key in ("energy", "forces", "stress"):
- torch.testing.assert_close(
- batched_out[key], standard_out[key], rtol=1e-10, atol=1e-10
- )
-
-
-def test_force_conservation(
- models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
-) -> None:
- """Test that forces sum to zero."""
- results_default, _ = models
- assert torch.allclose(
- results_default["forces"].sum(dim=0),
- torch.zeros(3, dtype=torch.float64),
- atol=1e-10,
- )
-
-
-@pytest.mark.parametrize("model_cls", [UnbatchedLennardJonesModel, LennardJonesModel])
-@pytest.mark.parametrize(
- "neighbor_list_fn",
- [neighbors.torch_nl_linked_cell, neighbors.torch_nl_n2],
-)
-def test_custom_neighbor_list_fn_matches_default(
- model_cls: type[UnbatchedLennardJonesModel],
- neighbor_list_fn: Callable[..., tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
- ar_supercell_sim_state_large: ts.SimState,
-) -> None:
- """Test that custom neighbor-list implementations match default behavior."""
- model_kwargs: dict[str, float | bool | torch.dtype | torch.device] = {
- "sigma": 3.405,
- "epsilon": 0.0104,
- "dtype": torch.float64,
- "device": DEVICE,
- "compute_forces": True,
- "compute_stress": True,
- }
- cutoff = 2.5 * 3.405
- default_model = model_cls(
- cutoff=cutoff,
- neighbor_list_fn=neighbors.torchsim_nl,
- **model_kwargs,
- )
- custom_model = model_cls(
- cutoff=cutoff,
- neighbor_list_fn=neighbor_list_fn,
- **model_kwargs,
- )
-
- default_out = default_model(ar_supercell_sim_state_large)
- custom_out = custom_model(ar_supercell_sim_state_large)
-
- for key in ("energy", "forces", "stress"):
- torch.testing.assert_close(
- custom_out[key], default_out[key], rtol=1e-10, atol=1e-10
- )
-
+ torch.testing.assert_close(force_direct, force_from_grad, rtol=1e-4, atol=1e-4)
-def test_stress_tensor_symmetry(
- models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
-) -> None:
- """Test that stress tensor is symmetric."""
- results_default, _ = models
- # select trailing two dimensions
- stress_tensor = results_default["stress"][0]
- assert torch.allclose(stress_tensor, stress_tensor.T, atol=1e-10)
-
-
-def test_validate_model_outputs(lj_model: LennardJonesModel) -> None:
- """Test that the model outputs are valid."""
- validate_model_outputs(lj_model, DEVICE, torch.float64)
-
-
-def test_unwrapped_positions_consistency() -> None:
- """Test that wrapped and unwrapped positions give identical results.
-
- This tests that models correctly handle positions outside the unit cell
- by wrapping them before neighbor list computation.
- """
- # Create a periodic system
- ar_atoms = bulk("Ar", "fcc", a=5.26, cubic=True).repeat([2, 2, 2])
- cell = torch.tensor(ar_atoms.get_cell().array, dtype=torch.float64, device=DEVICE)
- # Create wrapped state (positions inside unit cell)
- state_wrapped = ts.io.atoms_to_state(ar_atoms, DEVICE, torch.float64)
-
- # Create unwrapped state by shifting some atoms outside the cell
- positions_unwrapped = state_wrapped.positions.clone()
- # Shift first half of atoms by +1 cell vector in x direction
- n_atoms = positions_unwrapped.shape[0]
- positions_unwrapped[: n_atoms // 2] += cell[0]
- # Shift some atoms by -1 cell vector in y direction
- positions_unwrapped[n_atoms // 4 : n_atoms // 2] -= cell[1]
-
- state_unwrapped = ts.SimState.from_state(state_wrapped, positions=positions_unwrapped)
-
- # Create model
+def test_lennard_jones_model_evaluation(si_double_sim_state: ts.SimState) -> None:
+ """LennardJonesModel (wrapped PairPotentialModel) evaluates correctly."""
model = LennardJonesModel(
sigma=3.405,
epsilon=0.0104,
cutoff=2.5 * 3.405,
dtype=torch.float64,
- device=DEVICE,
compute_forces=True,
compute_stress=True,
)
+ results = model(si_double_sim_state)
+ assert "energy" in results
+ assert "forces" in results
+ assert "stress" in results
+ assert results["energy"].shape == (si_double_sim_state.n_systems,)
+ assert results["forces"].shape == (si_double_sim_state.n_atoms, 3)
+ assert results["stress"].shape == (si_double_sim_state.n_systems, 3, 3)
- # Compute results
- results_wrapped = model(state_wrapped)
- results_unwrapped = model(state_unwrapped)
-
- # Verify energy matches
- torch.testing.assert_close(
- results_wrapped["energy"],
- results_unwrapped["energy"],
- rtol=1e-10,
- atol=1e-10,
- msg="Energies should match for wrapped and unwrapped positions",
- )
- # Verify forces match
- torch.testing.assert_close(
- results_wrapped["forces"],
- results_unwrapped["forces"],
- rtol=1e-10,
- atol=1e-10,
- msg="Forces should match for wrapped and unwrapped positions",
- )
-
- # Verify stress matches
- torch.testing.assert_close(
- results_wrapped["stress"],
- results_unwrapped["stress"],
- rtol=1e-10,
- atol=1e-10,
- msg="Stress should match for wrapped and unwrapped positions",
+def test_lennard_jones_model_force_conservation(
+ si_double_sim_state: ts.SimState,
+) -> None:
+ """LennardJonesModel forces sum to zero (Newton's third law)."""
+ model = LennardJonesModel(
+ sigma=3.405,
+ epsilon=0.0104,
+ cutoff=2.5 * 3.405,
+ dtype=torch.float64,
+ compute_forces=True,
)
+ results = model(si_double_sim_state)
+ for sys_idx in range(si_double_sim_state.n_systems):
+ mask = si_double_sim_state.system_idx == sys_idx
+ assert torch.allclose(
+ results["forces"][mask].sum(dim=0),
+ torch.zeros(3, dtype=torch.float64),
+ atol=1e-10,
+ )
diff --git a/tests/models/test_morse.py b/tests/models/test_morse.py
index f65cb10b..5fe8aa5c 100644
--- a/tests/models/test_morse.py
+++ b/tests/models/test_morse.py
@@ -1,144 +1,115 @@
-"""Tests for Morse potential calculator using copper parameters."""
+"""Tests for the Morse pair functions and wrapped model."""
-import pytest
import torch
import torch_sim as ts
from torch_sim.models.morse import MorseModel, morse_pair, morse_pair_force
-def test_morse_pair_minimum() -> None:
- """Test that the potential has its minimum at r=sigma."""
- dr = torch.linspace(0.8, 1.2, 100)
- dr = dr.reshape(-1, 1)
- energy = morse_pair(dr)
- min_idx = torch.argmin(energy)
- torch.testing.assert_close(dr[min_idx], torch.tensor([1.0]), rtol=0.01, atol=1e-5)
+def _dummy_z(n: int) -> torch.Tensor:
+ return torch.ones(n, dtype=torch.long)
+
+
+def test_morse_pair_minimum_at_sigma() -> None:
+ """Morse minimum is at r = sigma."""
+ dr = torch.linspace(0.5, 2.0, 500)
+ z = _dummy_z(len(dr))
+ energies = morse_pair(dr, z, z, sigma=1.0, epsilon=5.0, alpha=5.0)
+ min_r = dr[energies.argmin()]
+ assert abs(min_r.item() - 1.0) < 0.01
+
+
+def test_morse_pair_energy_at_minimum() -> None:
+ """Morse energy at minimum equals -epsilon."""
+ dr = torch.tensor([1.0])
+ z = _dummy_z(1)
+ e = morse_pair(dr, z, z, sigma=1.0, epsilon=5.0, alpha=5.0)
+ torch.testing.assert_close(e, torch.tensor([-5.0]), rtol=1e-5, atol=1e-5)
def test_morse_pair_scaling() -> None:
- """Test that the potential scales correctly with epsilon."""
- dr = torch.ones(5, 5) * 1.5
- e1 = morse_pair(dr, epsilon=1.0)
- e2 = morse_pair(dr, epsilon=2.0)
- torch.testing.assert_close(e2, 2 * e1, rtol=1e-5, atol=1e-5)
-
-
-def test_morse_pair_asymptotic() -> None:
- """Test that the potential approaches -epsilon at large distances."""
- dr = torch.tensor([[1.0]]) # Large distance
- epsilon = 5.0
- energy = morse_pair(dr, epsilon=epsilon)
- torch.testing.assert_close(
- energy, -epsilon * torch.ones_like(energy), rtol=1e-2, atol=1e-5
- )
+ """Energy scales linearly with epsilon."""
+ dr = torch.tensor([1.5])
+ z = _dummy_z(1)
+ e1 = morse_pair(dr, z, z, epsilon=1.0)
+ e2 = morse_pair(dr, z, z, epsilon=2.0)
+ torch.testing.assert_close(e2, 2.0 * e1, rtol=1e-5, atol=1e-5)
def test_morse_pair_force_scaling() -> None:
- """Test that the force scales correctly with epsilon."""
- dr = torch.ones(5, 5) * 1.5
- f1 = morse_pair_force(dr, epsilon=1.0)
- f2 = morse_pair_force(dr, epsilon=2.0)
- assert torch.allclose(f2, 2 * f1)
+ """Force scales linearly with epsilon."""
+ dr = torch.tensor([1.5])
+ z = _dummy_z(1)
+ f1 = morse_pair_force(dr, z, z, epsilon=1.0)
+ f2 = morse_pair_force(dr, z, z, epsilon=2.0)
+ torch.testing.assert_close(f2, 2.0 * f1)
def test_morse_force_energy_consistency() -> None:
- """Test that the force is consistent with the energy gradient."""
+ """Force is consistent with the energy gradient."""
dr = torch.linspace(0.8, 2.0, 100, requires_grad=True)
- dr = dr.reshape(-1, 1)
+ z = _dummy_z(len(dr))
- # Calculate force directly
- force_direct = morse_pair_force(dr)
+ force_direct = morse_pair_force(dr, z, z)
- # Calculate force from energy gradient
- energy = morse_pair(dr)
+ energy = morse_pair(dr, z, z)
force_from_grad = -torch.autograd.grad(energy.sum(), dr, create_graph=True)[0]
- # Compare forces
- assert torch.allclose(force_direct, force_from_grad, rtol=1e-4, atol=1e-4)
+ torch.testing.assert_close(force_direct, force_from_grad, rtol=1e-4, atol=1e-4)
def test_morse_alpha_effect() -> None:
- """Test that larger alpha values make the potential well narrower."""
+ """Larger alpha values make the potential well narrower."""
dr = torch.linspace(0.8, 1.2, 100)
- dr = dr.reshape(-1, 1)
+ z = _dummy_z(len(dr))
- energy1 = morse_pair(dr, alpha=5.0)
- energy2 = morse_pair(dr, alpha=10.0)
+ energy1 = morse_pair(dr, z, z, alpha=5.0)
+ energy2 = morse_pair(dr, z, z, alpha=10.0)
- # Calculate width at half minimum
def get_well_width(energy: torch.Tensor) -> torch.Tensor:
min_e = torch.min(energy)
half_e = min_e / 2
mask = energy < half_e
return dr[mask].max() - dr[mask].min()
- width1 = get_well_width(energy1)
- width2 = get_well_width(energy2)
- assert width2 < width1 # Higher alpha should give narrower well
-
-
-@pytest.fixture
-def models(
- cu_supercell_sim_state: ts.SimState,
-) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
- """Create both neighbor list and direct calculators with Copper parameters."""
- # Parameters for Copper (Cu) using Morse potential
- # Values from: https://doi.org/10.1016/j.commatsci.2004.12.069
- model_kwargs: dict[str, float | bool | torch.dtype] = {
- "sigma": 2.55, # Å, equilibrium distance
- "epsilon": 0.436, # eV, dissociation energy
- "alpha": 1.359, # Å^-1, controls potential well width
- "dtype": torch.float64,
- "compute_forces": True,
- "compute_stress": True,
- }
- cutoff = 2.5 * 2.55 # Similar scaling as LJ cutoff
- model_nl = MorseModel(use_neighbor_list=True, cutoff=cutoff, **model_kwargs)
- model_direct = MorseModel(use_neighbor_list=False, cutoff=cutoff, **model_kwargs)
-
- return model_nl(cu_supercell_sim_state), model_direct(cu_supercell_sim_state)
-
-
-def test_energy_match(
- models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
-) -> None:
- """Test that total energy matches between neighbor list and direct calculations."""
- results_nl, results_direct = models
- assert torch.allclose(results_nl["energy"], results_direct["energy"], rtol=1e-10)
-
-
-def test_forces_match(
- models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
-) -> None:
- """Test that forces match between neighbor list and direct calculations."""
- results_nl, results_direct = models
- assert torch.allclose(results_nl["forces"], results_direct["forces"], rtol=1e-10)
-
-
-def test_stress_match(
- models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
-) -> None:
- """Test that stress tensors match between neighbor list and direct calculations."""
- results_nl, results_direct = models
- assert torch.allclose(results_nl["stress"], results_direct["stress"], rtol=1e-10)
-
-
-def test_force_conservation(
- models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
-) -> None:
- """Test that forces sum to zero (Newton's third law)."""
- results_nl, _ = models
- assert torch.allclose(
- results_nl["forces"].sum(dim=0), torch.zeros(3, dtype=torch.float64), atol=1e-10
- )
+ assert get_well_width(energy2) < get_well_width(energy1)
-def test_stress_tensor_symmetry(
- models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
-) -> None:
- """Test that stress tensor is symmetric."""
- results_nl, _ = models
- assert torch.allclose(
- results_nl["stress"].squeeze(), results_nl["stress"].squeeze().T, atol=1e-10
+def test_morse_model_evaluation(si_double_sim_state: ts.SimState) -> None:
+ """MorseModel (wrapped PairPotentialModel) evaluates correctly."""
+ model = MorseModel(
+ sigma=2.55,
+ epsilon=0.436,
+ alpha=1.359,
+ cutoff=6.0,
+ dtype=torch.float64,
+ compute_forces=True,
+ compute_stress=True,
+ )
+ results = model(si_double_sim_state)
+ assert "energy" in results
+ assert "forces" in results
+ assert "stress" in results
+ assert results["energy"].shape == (si_double_sim_state.n_systems,)
+ assert results["forces"].shape == (si_double_sim_state.n_atoms, 3)
+ assert results["stress"].shape == (si_double_sim_state.n_systems, 3, 3)
+
+
+def test_morse_model_force_conservation(si_double_sim_state: ts.SimState) -> None:
+ """MorseModel forces sum to zero (Newton's third law)."""
+ model = MorseModel(
+ sigma=2.55,
+ epsilon=0.436,
+ alpha=1.359,
+ cutoff=6.0,
+ dtype=torch.float64,
+ compute_forces=True,
)
+ results = model(si_double_sim_state)
+ for sys_idx in range(si_double_sim_state.n_systems):
+ mask = si_double_sim_state.system_idx == sys_idx
+ assert torch.allclose(
+ results["forces"][mask].sum(dim=0),
+ torch.zeros(3, dtype=torch.float64),
+ atol=1e-10,
+ )
diff --git a/tests/models/test_pair_potential.py b/tests/models/test_pair_potential.py
index b0b62cc1..355d8171 100644
--- a/tests/models/test_pair_potential.py
+++ b/tests/models/test_pair_potential.py
@@ -1,40 +1,75 @@
-"""Tests for the general pair potential model and standard pair functions."""
+"""Tests for the general pair potential model and pair forces model."""
import functools
import pytest
import torch
+from ase.build import bulk
import torch_sim as ts
from tests.conftest import DEVICE, DTYPE
from tests.models.conftest import make_validate_model_outputs_test
-from torch_sim.models.interface import ModelInterface
-from torch_sim.models.lennard_jones import LennardJonesModel
-from torch_sim.models.morse import MorseModel
+from torch_sim import io
+from torch_sim.models.lennard_jones import LennardJonesModel, lennard_jones_pair
+from torch_sim.models.morse import morse_pair
from torch_sim.models.pair_potential import (
- MultiSoftSpherePairFn,
PairForcesModel,
PairPotentialModel,
full_to_half_list,
- lj_pair,
- morse_pair,
- particle_life_pair_force,
- soft_sphere_pair,
)
-from torch_sim.models.soft_sphere import SoftSphereModel
+from torch_sim.models.particle_life import particle_life_pair_force
+from torch_sim.models.soft_sphere import soft_sphere_pair
+from torch_sim.neighbors import torch_nl_n2
+
+
+# BMHTF (Born-Meyer-Huggins-Tosi-Fumi) potential for NaCl
+# Na-Cl interaction parameters
+BMHTF_A = 20.3548
+BMHTF_B = 3.1546
+BMHTF_C = 674.4793
+BMHTF_D = 837.0770
+BMHTF_SIGMA = 2.755
+BMHTF_CUTOFF = 10.0
+
+
+def bmhtf_pair(
+ dr: torch.Tensor,
+ zi: torch.Tensor, # noqa: ARG001
+ zj: torch.Tensor, # noqa: ARG001
+ A: float,
+ B: float,
+ C: float,
+ D: float,
+ sigma: float,
+) -> torch.Tensor:
+ """Born-Meyer-Huggins-Tosi-Fumi (BMHTF) potential for ionic crystals."""
+ exp_term = A * torch.exp(B * (sigma - dr))
+ r6_term = C / dr.pow(6)
+ r8_term = D / dr.pow(8)
+ energy = exp_term - r6_term - r8_term
+ return torch.where(dr > 0, energy, torch.zeros_like(energy))
-# Argon LJ parameters
-LJ_SIGMA = 3.405
-LJ_EPSILON = 0.0104
-LJ_CUTOFF = 2.5 * LJ_SIGMA
+@pytest.fixture
+def nacl_sim_state() -> ts.SimState:
+ """NaCl structure for BMHTF potential tests."""
+ nacl_atoms = bulk("NaCl", "rocksalt", a=5.64)
+ return io.atoms_to_state(nacl_atoms, device=DEVICE, dtype=DTYPE)
@pytest.fixture
-def lj_model_pp() -> PairPotentialModel:
+def bmhtf_model_pp() -> PairPotentialModel:
+ """BMHTF model using PairPotentialModel to test general case."""
return PairPotentialModel(
- pair_fn=functools.partial(lj_pair, sigma=LJ_SIGMA, epsilon=LJ_EPSILON),
- cutoff=LJ_CUTOFF,
+ pair_fn=functools.partial(
+ bmhtf_pair,
+ A=BMHTF_A,
+ B=BMHTF_B,
+ C=BMHTF_C,
+ D=BMHTF_D,
+ sigma=BMHTF_SIGMA,
+ ),
+ cutoff=BMHTF_CUTOFF,
dtype=DTYPE,
compute_forces=True,
compute_stress=True,
@@ -56,7 +91,7 @@ def particle_life_model() -> PairForcesModel:
# Interface validation via factory
test_pair_potential_model_outputs = make_validate_model_outputs_test(
- model_fixture_name="lj_model_pp", device=DEVICE, dtype=DTYPE
+ model_fixture_name="bmhtf_model_pp", device=DEVICE, dtype=DTYPE
)
test_pair_forces_model_outputs = make_validate_model_outputs_test(
@@ -64,267 +99,6 @@ def particle_life_model() -> PairForcesModel:
)
-def _dummy_z(n: int) -> torch.Tensor:
- return torch.ones(n, dtype=torch.long)
-
-
-def test_lj_pair_minimum() -> None:
- """Minimum of LJ is at r = 2^(1/6) * sigma."""
- dr = torch.linspace(0.9, 1.5, 500)
- z = _dummy_z(len(dr))
- energies = lj_pair(dr, z, z, sigma=1.0, epsilon=1.0)
- min_r = dr[energies.argmin()]
- assert abs(min_r.item() - 2 ** (1 / 6)) < 0.01
-
-
-def test_lj_pair_energy_at_minimum() -> None:
- """Energy at minimum equals -epsilon."""
- r_min = torch.tensor([2 ** (1 / 6)])
- z = _dummy_z(1)
- e = lj_pair(r_min, z, z, sigma=1.0, epsilon=2.0)
- torch.testing.assert_close(e, torch.tensor([-2.0]), rtol=1e-5, atol=1e-5)
-
-
-def test_lj_pair_epsilon_scaling() -> None:
- dr = torch.tensor([1.5])
- z = _dummy_z(1)
- e1 = lj_pair(dr, z, z, sigma=1.0, epsilon=1.0)
- e2 = lj_pair(dr, z, z, sigma=1.0, epsilon=3.0)
- torch.testing.assert_close(e2, 3.0 * e1)
-
-
-def test_morse_pair_minimum_at_sigma() -> None:
- """Morse minimum is at r = sigma."""
- dr = torch.linspace(0.5, 2.0, 500)
- z = _dummy_z(len(dr))
- energies = morse_pair(dr, z, z, sigma=1.0, epsilon=5.0, alpha=5.0)
- min_r = dr[energies.argmin()]
- assert abs(min_r.item() - 1.0) < 0.01
-
-
-def test_morse_pair_energy_at_minimum() -> None:
- """Morse energy at minimum equals -epsilon."""
- dr = torch.tensor([1.0])
- z = _dummy_z(1)
- e = morse_pair(dr, z, z, sigma=1.0, epsilon=5.0, alpha=5.0)
- torch.testing.assert_close(e, torch.tensor([-5.0]), rtol=1e-5, atol=1e-5)
-
-
-def test_soft_sphere_zero_beyond_sigma() -> None:
- """Soft-sphere energy is zero for r >= sigma."""
- dr = torch.tensor([1.0, 1.5, 2.0])
- z = _dummy_z(3)
- e = soft_sphere_pair(dr, z, z, sigma=1.0)
- assert e[1] == 0.0
- assert e[2] == 0.0
-
-
-def test_soft_sphere_repulsive_only() -> None:
- """Soft-sphere energies are non-negative for r < sigma."""
- dr = torch.linspace(0.1, 0.99, 50)
- z = _dummy_z(len(dr))
- e = soft_sphere_pair(dr, z, z, sigma=1.0, epsilon=1.0, alpha=2.0)
- assert (e >= 0).all()
-
-
-def test_particle_life_force_inner() -> None:
- """For dr < beta the force is negative (repulsive)."""
- dr = torch.tensor([0.1, 0.2])
- z = _dummy_z(2)
- f = particle_life_pair_force(dr, z, z, A=1.0, beta=0.3, sigma=1.0)
- assert (f < 0).all()
-
-
-def test_particle_life_force_zero_beyond_sigma() -> None:
- dr = torch.tensor([1.0, 1.5])
- z = _dummy_z(2)
- f = particle_life_pair_force(dr, z, z, A=1.0, beta=0.3, sigma=1.0)
- assert f[0] == 0.0
- assert f[1] == 0.0
-
-
-def _make_mss(
- sigma: float = 1.0, epsilon: float = 1.0, alpha: float = 2.0
-) -> MultiSoftSpherePairFn:
- """Two-species MultiSoftSpherePairFn with uniform parameters."""
- n = 2
- return MultiSoftSpherePairFn(
- atomic_numbers=torch.tensor([18, 36]),
- sigma_matrix=torch.full((n, n), sigma),
- epsilon_matrix=torch.full((n, n), epsilon),
- alpha_matrix=torch.full((n, n), alpha),
- )
-
-
-def test_multi_soft_sphere_zero_beyond_sigma() -> None:
- """Energy is zero for r >= sigma."""
- fn = _make_mss(sigma=1.0)
- dr = torch.tensor([1.0, 1.5])
- zi = zj = torch.tensor([18, 36])
- e = fn(dr, zi, zj)
- assert (e == 0.0).all()
-
-
-def test_multi_soft_sphere_repulsive_only() -> None:
- """Energy is non-negative for r < sigma."""
- fn = _make_mss(sigma=2.0, epsilon=1.0, alpha=2.0)
- dr = torch.linspace(0.1, 1.99, 20)
- zi = zj = torch.full((20,), 18, dtype=torch.long)
- assert (fn(dr, zi, zj) >= 0).all()
-
-
-def test_multi_soft_sphere_species_lookup() -> None:
- """Different species pairs use the correct off-diagonal parameters."""
- sigma_matrix = torch.tensor([[1.0, 2.0], [2.0, 3.0]])
- epsilon_matrix = torch.ones(2, 2)
- alpha_matrix = torch.full((2, 2), 2.0)
- fn = MultiSoftSpherePairFn(
- atomic_numbers=torch.tensor([18, 36]),
- sigma_matrix=sigma_matrix,
- epsilon_matrix=epsilon_matrix,
- alpha_matrix=alpha_matrix,
- )
- dr = torch.tensor([0.5])
- zi_same = torch.tensor([18])
- zj_same = torch.tensor([18])
- zi_cross = torch.tensor([18])
- zj_cross = torch.tensor([36])
- e_same = fn(dr, zi_same, zj_same) # sigma=1.0, r=0.5 < sigma → non-zero
- e_cross = fn(dr, zi_cross, zj_cross) # sigma=2.0, r=0.5 < sigma → non-zero
- # cross pair has larger sigma so (1 - r/sigma) is larger → higher energy
- assert e_cross > e_same
-
-
-def test_multi_soft_sphere_alpha_matrix_default() -> None:
- """Omitting alpha_matrix defaults to 2.0 for all pairs."""
- fn_default = MultiSoftSpherePairFn(
- atomic_numbers=torch.tensor([18, 36]),
- sigma_matrix=torch.full((2, 2), 1.0),
- epsilon_matrix=torch.full((2, 2), 1.0),
- )
- fn_explicit = _make_mss(sigma=1.0, epsilon=1.0, alpha=2.0)
- dr = torch.tensor([0.5])
- zi = zj = torch.tensor([18])
- torch.testing.assert_close(fn_default(dr, zi, zj), fn_explicit(dr, zi, zj))
-
-
-def test_multi_soft_sphere_bad_matrix_shape_raises() -> None:
- with pytest.raises(ValueError, match="sigma_matrix"):
- MultiSoftSpherePairFn(
- atomic_numbers=torch.tensor([18, 36]),
- sigma_matrix=torch.ones(3, 3), # wrong shape
- epsilon_matrix=torch.ones(2, 2),
- )
-
-
-def _build_potential_model_pair(name: str) -> tuple[PairPotentialModel, ModelInterface]:
- """Return (PairPotentialModel, reference_model) for a named potential."""
- if name == "lj-half":
- pp = PairPotentialModel(
- pair_fn=functools.partial(lj_pair, sigma=LJ_SIGMA, epsilon=LJ_EPSILON),
- cutoff=LJ_CUTOFF,
- dtype=DTYPE,
- compute_forces=True,
- compute_stress=True,
- per_atom_energies=True,
- per_atom_stresses=True,
- reduce_to_half_list=True,
- )
- ref = LennardJonesModel(
- sigma=LJ_SIGMA,
- epsilon=LJ_EPSILON,
- cutoff=LJ_CUTOFF,
- dtype=DTYPE,
- compute_forces=True,
- compute_stress=True,
- per_atom_energies=True,
- per_atom_stresses=True,
- )
- return pp, ref
- if name == "lj-full":
- pp = PairPotentialModel(
- pair_fn=functools.partial(lj_pair, sigma=LJ_SIGMA, epsilon=LJ_EPSILON),
- cutoff=LJ_CUTOFF,
- dtype=DTYPE,
- compute_forces=True,
- compute_stress=True,
- per_atom_energies=True,
- per_atom_stresses=True,
- reduce_to_half_list=False,
- )
- ref = LennardJonesModel(
- sigma=LJ_SIGMA,
- epsilon=LJ_EPSILON,
- cutoff=LJ_CUTOFF,
- dtype=DTYPE,
- compute_forces=True,
- compute_stress=True,
- per_atom_energies=True,
- per_atom_stresses=True,
- )
- return pp, ref
- if name == "morse":
- sigma, epsilon, alpha, cutoff = 4.0, 5.0, 5.0, 5.0
- pp = PairPotentialModel(
- pair_fn=functools.partial(
- morse_pair, sigma=sigma, epsilon=epsilon, alpha=alpha
- ),
- cutoff=cutoff,
- dtype=DTYPE,
- compute_forces=True,
- compute_stress=True,
- )
- ref = MorseModel(
- sigma=sigma,
- epsilon=epsilon,
- alpha=alpha,
- cutoff=cutoff,
- dtype=DTYPE,
- compute_forces=True,
- compute_stress=True,
- )
- return pp, ref
- if name == "soft_sphere":
- sigma, epsilon, alpha = 5, 0.0104, 2.0
- pp = PairPotentialModel(
- pair_fn=functools.partial(
- soft_sphere_pair, sigma=sigma, epsilon=epsilon, alpha=alpha
- ),
- cutoff=sigma,
- dtype=DTYPE,
- compute_forces=True,
- compute_stress=True,
- )
- ref = SoftSphereModel(
- sigma=sigma,
- epsilon=epsilon,
- alpha=alpha,
- cutoff=sigma,
- dtype=DTYPE,
- compute_forces=True,
- compute_stress=True,
- )
- return pp, ref
-
- raise ValueError(f"Unknown potential: {name}")
-
-
-@pytest.mark.parametrize("potential", ["lj-half", "lj-full", "morse", "soft_sphere"])
-def test_potential_matches_reference(
- mixed_double_sim_state: ts.SimState,
- potential: str,
-) -> None:
- """PairPotentialModel matches the dedicated reference model."""
- model_pp, model_ref = _build_potential_model_pair(potential)
- out_pp = model_pp(mixed_double_sim_state)
- out_ref = model_ref(mixed_double_sim_state)
-
- assert (out_pp["energy"] != 0).all()
-
- for key in out_pp:
- torch.testing.assert_close(out_pp[key], out_ref[key], rtol=1e-4, atol=1e-5)
-
-
def test_full_to_half_list_removes_duplicates() -> None:
"""i < j mask halves a symmetric full neighbor list."""
# 3-atom full list: (0,1),(1,0),(0,2),(2,0),(1,2),(2,1)
@@ -353,13 +127,17 @@ def test_full_to_half_list_preserves_system_and_shifts() -> None:
@pytest.mark.parametrize("key", ["energy", "forces", "stress", "stresses"])
def test_half_list_matches_full(si_double_sim_state: ts.SimState, key: str) -> None:
"""reduce_to_half_list=True gives the same result as the default full list."""
- fn = functools.partial(lj_pair, sigma=LJ_SIGMA, epsilon=LJ_EPSILON)
+ # Argon LJ parameters
+ sigma = 3.405
+ epsilon = 0.0104
+ cutoff = 2.5 * sigma
+ fn = functools.partial(lennard_jones_pair, sigma=sigma, epsilon=epsilon)
needs_forces = key in ("forces", "stress", "stresses")
needs_stress = key in ("stress", "stresses")
common = dict(
pair_fn=fn,
- cutoff=LJ_CUTOFF,
- dtype=DTYPE,
+ cutoff=cutoff,
+ dtype=si_double_sim_state.dtype,
compute_forces=needs_forces,
compute_stress=needs_stress,
per_atom_stresses=(key == "stresses"),
@@ -371,14 +149,25 @@ def test_half_list_matches_full(si_double_sim_state: ts.SimState, key: str) -> N
torch.testing.assert_close(out_half[key], out_full[key], rtol=1e-10, atol=1e-14)
-@pytest.mark.parametrize("potential", ["lj", "morse", "soft_sphere"])
+@pytest.mark.parametrize("potential", ["bmhtf", "morse", "soft_sphere"])
def test_autograd_force_fn_matches_potential_model(
- si_double_sim_state: ts.SimState, potential: str
+ nacl_sim_state: ts.SimState,
+ si_double_sim_state: ts.SimState,
+ potential: str,
) -> None:
"""PairForcesModel with -dV/dr force fn matches PairPotentialModel forces/stress."""
- if potential == "lj":
- pair_fn = functools.partial(lj_pair, sigma=LJ_SIGMA, epsilon=LJ_EPSILON)
- cutoff = LJ_CUTOFF
+ # Use NaCl for BMHTF, si_double for others
+ sim_state = nacl_sim_state if potential == "bmhtf" else si_double_sim_state
+ if potential == "bmhtf":
+ pair_fn = functools.partial(
+ bmhtf_pair,
+ A=BMHTF_A,
+ B=BMHTF_B,
+ C=BMHTF_C,
+ D=BMHTF_D,
+ sigma=BMHTF_SIGMA,
+ )
+ cutoff = BMHTF_CUTOFF
elif potential == "morse":
pair_fn = functools.partial(morse_pair, sigma=4.0, epsilon=5.0, alpha=5.0)
cutoff = 5.0
@@ -395,7 +184,7 @@ def force_fn(dr: torch.Tensor, zi: torch.Tensor, zj: torch.Tensor) -> torch.Tens
model_pp = PairPotentialModel(
pair_fn=pair_fn,
cutoff=cutoff,
- dtype=DTYPE,
+ dtype=sim_state.dtype,
compute_forces=True,
compute_stress=True,
per_atom_stresses=True,
@@ -403,12 +192,12 @@ def force_fn(dr: torch.Tensor, zi: torch.Tensor, zj: torch.Tensor) -> torch.Tens
model_pf = PairForcesModel(
force_fn=force_fn,
cutoff=cutoff,
- dtype=DTYPE,
+ dtype=sim_state.dtype,
compute_stress=True,
per_atom_stresses=True,
)
- out_pp = model_pp(si_double_sim_state)
- out_pf = model_pf(si_double_sim_state)
+ out_pp = model_pp(sim_state)
+ out_pf = model_pf(sim_state)
assert (out_pp["forces"] != 0.0).all()
@@ -435,3 +224,130 @@ def test_forces_model_half_list_matches_full(
out_full = model_full(si_double_sim_state)
out_half = model_half(si_double_sim_state)
torch.testing.assert_close(out_half[key], out_full[key], rtol=1e-10, atol=1e-13)
+
+
+def test_force_conservation(
+ bmhtf_model_pp: PairPotentialModel, nacl_sim_state: ts.SimState
+) -> None:
+ """Forces sum to zero (Newton's third law)."""
+ out = bmhtf_model_pp(nacl_sim_state)
+ for sys_idx in range(nacl_sim_state.n_systems):
+ mask = nacl_sim_state.system_idx == sys_idx
+ assert torch.allclose(
+ out["forces"][mask].sum(dim=0),
+ torch.zeros(3, dtype=nacl_sim_state.dtype),
+ atol=1e-10,
+ )
+
+
+def test_stress_tensor_symmetry(
+ bmhtf_model_pp: PairPotentialModel, nacl_sim_state: ts.SimState
+) -> None:
+ """Stress tensor is symmetric."""
+ out = bmhtf_model_pp(nacl_sim_state)
+ for i in range(nacl_sim_state.n_systems):
+ stress = out["stress"][i]
+ assert torch.allclose(stress, stress.T, atol=1e-10)
+
+
+def test_multi_system(ar_double_sim_state: ts.SimState) -> None:
+ """Multi-system batched evaluation matches single-system evaluation."""
+ model = LennardJonesModel(
+ sigma=3.405,
+ epsilon=0.0104,
+ cutoff=2.5 * 3.405,
+ dtype=torch.float64,
+ device=DEVICE,
+ compute_forces=True,
+ compute_stress=True,
+ )
+ out = model(ar_double_sim_state)
+
+ assert out["energy"].shape == (ar_double_sim_state.n_systems,)
+ # Both systems are identical, so energies should match
+ torch.testing.assert_close(out["energy"][0], out["energy"][1], rtol=1e-10, atol=1e-10)
+
+
+def test_unwrapped_positions_consistency() -> None:
+ """Wrapped and unwrapped positions give identical results."""
+ ar_atoms = bulk("Ar", "fcc", a=5.26, cubic=True).repeat([2, 2, 2])
+ cell = torch.tensor(ar_atoms.get_cell().array, dtype=torch.float64, device=DEVICE)
+
+ state_wrapped = ts.io.atoms_to_state(ar_atoms, DEVICE, torch.float64)
+
+ positions_unwrapped = state_wrapped.positions.clone()
+ n_atoms = positions_unwrapped.shape[0]
+ positions_unwrapped[: n_atoms // 2] += cell[0]
+ positions_unwrapped[n_atoms // 4 : n_atoms // 2] -= cell[1]
+
+ state_unwrapped = ts.SimState.from_state(state_wrapped, positions=positions_unwrapped)
+
+ model = LennardJonesModel(
+ sigma=3.405,
+ epsilon=0.0104,
+ cutoff=2.5 * 3.405,
+ dtype=torch.float64,
+ device=DEVICE,
+ compute_forces=True,
+ compute_stress=True,
+ )
+
+ results_wrapped = model(state_wrapped)
+ results_unwrapped = model(state_unwrapped)
+
+ for key in ("energy", "forces", "stress"):
+ torch.testing.assert_close(
+ results_wrapped[key], results_unwrapped[key], rtol=1e-10, atol=1e-10
+ )
+
+
+def test_retain_graph_allows_param_grad(nacl_sim_state: ts.SimState) -> None:
+ """With retain_graph=True, energy graph survives force computation so we can
+ differentiate energy w.r.t. model parameters (e.g. A, B, C, D)."""
+ A = torch.tensor(BMHTF_A, dtype=nacl_sim_state.dtype, requires_grad=True)
+ pair_fn = functools.partial(
+ bmhtf_pair,
+ A=A,
+ B=BMHTF_B,
+ C=BMHTF_C,
+ D=BMHTF_D,
+ sigma=BMHTF_SIGMA,
+ )
+ model = PairPotentialModel(
+ pair_fn=pair_fn,
+ cutoff=BMHTF_CUTOFF,
+ dtype=nacl_sim_state.dtype,
+ compute_forces=True,
+ neighbor_list_fn=torch_nl_n2,
+ retain_graph=True,
+ )
+ out = model(nacl_sim_state)
+ assert out["forces"] is not None
+ (grad,) = torch.autograd.grad(out["energy"].sum(), A)
+ assert grad.shape == A.shape
+ assert grad.abs() > 0
+
+
+def test_no_retain_graph_frees_graph(nacl_sim_state: ts.SimState) -> None:
+ """Without retain_graph, differentiating energy w.r.t. parameters after force
+ computation raises because the graph has been freed."""
+ A = torch.tensor(BMHTF_A, dtype=nacl_sim_state.dtype, requires_grad=True)
+ pair_fn = functools.partial(
+ bmhtf_pair,
+ A=A,
+ B=BMHTF_B,
+ C=BMHTF_C,
+ D=BMHTF_D,
+ sigma=BMHTF_SIGMA,
+ )
+ model = PairPotentialModel(
+ pair_fn=pair_fn,
+ cutoff=BMHTF_CUTOFF,
+ dtype=nacl_sim_state.dtype,
+ compute_forces=True,
+ neighbor_list_fn=torch_nl_n2,
+ retain_graph=False,
+ )
+ out = model(nacl_sim_state)
+ with pytest.raises(RuntimeError, match="does not require grad"):
+ torch.autograd.grad(out["energy"].sum(), A)
diff --git a/tests/models/test_particle_life.py b/tests/models/test_particle_life.py
new file mode 100644
index 00000000..d93b1f3b
--- /dev/null
+++ b/tests/models/test_particle_life.py
@@ -0,0 +1,80 @@
+"""Tests for the particle life force function and wrapped model."""
+
+import torch
+
+import torch_sim as ts
+from torch_sim.models.particle_life import ParticleLifeModel, particle_life_pair_force
+
+
+def _dummy_z(n: int) -> torch.Tensor:
+ return torch.ones(n, dtype=torch.long)
+
+
+def test_inner_region_repulsive() -> None:
+ """For dr < beta the force is negative (repulsive)."""
+ dr = torch.tensor([0.1, 0.2])
+ z = _dummy_z(2)
+ f = particle_life_pair_force(dr, z, z, A=1.0, beta=0.3, sigma=1.0)
+ assert (f < 0).all()
+
+
+def test_zero_beyond_sigma() -> None:
+ """Force is zero at and beyond sigma."""
+ dr = torch.tensor([1.0, 1.5])
+ z = _dummy_z(2)
+ f = particle_life_pair_force(dr, z, z, A=1.0, beta=0.3, sigma=1.0)
+ assert (f == 0.0).all()
+
+
+def test_amplitude_scaling() -> None:
+ """Outer-region force scales with A."""
+ dr = torch.tensor([0.6]) # between beta and sigma
+ z = _dummy_z(1)
+ f1 = particle_life_pair_force(dr, z, z, A=1.0, beta=0.3, sigma=1.0)
+ f2 = particle_life_pair_force(dr, z, z, A=3.0, beta=0.3, sigma=1.0)
+ torch.testing.assert_close(f2, 3.0 * f1)
+
+
+def test_particle_life_model_evaluation(si_double_sim_state: ts.SimState) -> None:
+ """ParticleLifeModel (wrapped PairForcesModel) evaluates correctly."""
+ model = ParticleLifeModel(
+ A=1.0,
+ beta=0.3,
+ sigma=5.26,
+ cutoff=5.26,
+ dtype=si_double_sim_state.dtype,
+ compute_stress=True,
+ )
+ results = model(si_double_sim_state)
+ assert "energy" in results
+ assert "forces" in results
+ assert "stress" in results
+ assert results["energy"].shape == (si_double_sim_state.n_systems,)
+ assert results["forces"].shape == (si_double_sim_state.n_atoms, 3)
+ assert results["stress"].shape == (si_double_sim_state.n_systems, 3, 3)
+ # Energy should be zeros for PairForcesModel
+ assert torch.allclose(
+ results["energy"],
+ torch.zeros(si_double_sim_state.n_systems, dtype=si_double_sim_state.dtype),
+ )
+
+
+def test_particle_life_model_force_conservation(
+ si_double_sim_state: ts.SimState,
+) -> None:
+ """ParticleLifeModel forces sum to zero (Newton's third law)."""
+ model = ParticleLifeModel(
+ A=1.0,
+ beta=0.3,
+ sigma=5.26,
+ cutoff=5.26,
+ dtype=torch.float64,
+ )
+ results = model(si_double_sim_state)
+ for sys_idx in range(si_double_sim_state.n_systems):
+ mask = si_double_sim_state.system_idx == sys_idx
+ assert torch.allclose(
+ results["forces"][mask].sum(dim=0),
+ torch.zeros(3, dtype=torch.float64),
+ atol=1e-10,
+ )
diff --git a/tests/models/test_soft_sphere.py b/tests/models/test_soft_sphere.py
index 1b1cae6f..daa32102 100644
--- a/tests/models/test_soft_sphere.py
+++ b/tests/models/test_soft_sphere.py
@@ -1,146 +1,36 @@
-"""Tests for soft sphere models ensuring different parts of TorchSim work together."""
+"""Tests for the soft sphere pair functions, wrapped model, and multi-species models."""
import pytest
import torch
import torch_sim as ts
-import torch_sim.models.soft_sphere as ss
-from tests.conftest import DEVICE
-from torch_sim.models.interface import validate_model_outputs
-
-
-def _make_soft_sphere_model(
- *, use_neighbor_list: bool, with_per_atom: bool = False
-) -> ss.SoftSphereModel:
- """Create a SoftSphereModel with common test defaults."""
- model_kwargs: dict[str, float | bool | torch.dtype] = {
- "sigma": 3.405, # Å, typical for Ar
- "epsilon": 0.0104, # eV, typical for Ar
- "alpha": 2.0,
- "dtype": torch.float64,
- "compute_forces": True,
- "compute_stress": True,
- }
- if with_per_atom:
- model_kwargs["per_atom_energies"] = True
- model_kwargs["per_atom_stresses"] = True
- return ss.SoftSphereModel(use_neighbor_list=use_neighbor_list, **model_kwargs)
-
-
-@pytest.fixture
-def models(
- fe_supercell_sim_state: ts.SimState,
-) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
- """Create both neighbor list and direct calculators."""
- model_nl = _make_soft_sphere_model(use_neighbor_list=True)
- model_direct = _make_soft_sphere_model(use_neighbor_list=False)
-
- return model_nl(fe_supercell_sim_state), model_direct(fe_supercell_sim_state)
-
-
-@pytest.fixture
-def models_with_per_atom(
- fe_supercell_sim_state: ts.SimState,
-) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
- """Create calculators with per-atom properties enabled."""
- model_nl = _make_soft_sphere_model(use_neighbor_list=True, with_per_atom=True)
- model_direct = _make_soft_sphere_model(use_neighbor_list=False, with_per_atom=True)
-
- return model_nl(fe_supercell_sim_state), model_direct(fe_supercell_sim_state)
-
-
-@pytest.fixture
-def small_system() -> tuple[torch.Tensor, torch.Tensor]:
- """Create a small simple cubic system for testing."""
- positions = torch.tensor(
- [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]],
- dtype=torch.float64,
- )
- cell = torch.eye(3, dtype=torch.float64) * 2.0
- return positions, cell
-
-
-@pytest.fixture
-def small_sim_state(small_system: tuple[torch.Tensor, torch.Tensor]) -> ts.SimState:
- """Create a small SimState for testing."""
- positions, cell = small_system
- return ts.SimState(
- positions=positions,
- cell=cell,
- pbc=True,
- masses=torch.ones(positions.shape[0], dtype=torch.float64),
- atomic_numbers=torch.ones(positions.shape[0], dtype=torch.long),
- )
-
-
-@pytest.fixture
-def small_batched_sim_state(small_sim_state: ts.SimState) -> ts.SimState:
- """Create a batched state from the small system."""
- return ts.concatenate_states(
- [small_sim_state, small_sim_state], device=small_sim_state.device
- )
-
-
-@pytest.mark.parametrize("output_key", ["energy", "forces", "stress"])
-def test_outputs_match(
- models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
- output_key: str,
-) -> None:
- """Test that outputs match between neighbor list and direct calculations."""
- results_nl, results_direct = models
- assert torch.allclose(results_nl[output_key], results_direct[output_key], rtol=1e-10)
+from torch_sim.models.soft_sphere import (
+ MultiSoftSpherePairFn,
+ SoftSphereModel,
+ SoftSphereMultiModel,
+ soft_sphere_pair,
+)
-def test_force_conservation(
- models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
-) -> None:
- """Test that forces sum to zero."""
- results_nl, _ = models
- assert torch.allclose(
- results_nl["forces"].sum(dim=0), torch.zeros(3, dtype=torch.float64), atol=1e-10
- )
+def _dummy_z(n: int) -> torch.Tensor:
+ return torch.ones(n, dtype=torch.long)
-def test_stress_tensor_symmetry(
- models: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
-) -> None:
- """Test that stress tensor is symmetric."""
- results_nl, _ = models
- assert torch.allclose(results_nl["stress"], results_nl["stress"].T, atol=1e-10)
+def test_soft_sphere_zero_beyond_sigma() -> None:
+ """Soft-sphere energy is zero for r >= sigma."""
+ dr = torch.tensor([1.0, 1.5, 2.0])
+ z = _dummy_z(3)
+ e = soft_sphere_pair(dr, z, z, sigma=1.0)
+ assert e[1] == 0.0
+ assert e[2] == 0.0
-def test_validate_model_outputs() -> None:
- """Test that the model outputs are valid."""
- model_nl = _make_soft_sphere_model(use_neighbor_list=True)
- model_direct = _make_soft_sphere_model(use_neighbor_list=False)
- for out in (model_nl, model_direct):
- validate_model_outputs(out, DEVICE, torch.float64)
-
-
-@pytest.mark.parametrize(
- ("per_atom_key", "total_key"), [("energies", "energy"), ("stresses", "stress")]
-)
-def test_per_atom_properties(
- models_with_per_atom: tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]],
- per_atom_key: str,
- total_key: str,
-) -> None:
- """Test that per-atom properties are calculated correctly."""
- results_nl, results_direct = models_with_per_atom
-
- # Check per-atom properties are calculated and match
- assert torch.allclose(
- results_nl[per_atom_key], results_direct[per_atom_key], rtol=1e-10
- )
-
- # Check sum of per-atom properties matches total property
- if per_atom_key == "energies":
- assert torch.allclose(
- results_nl[per_atom_key].sum(), results_nl[total_key], rtol=1e-10
- )
- else: # stresses
- total_from_atoms = results_nl[per_atom_key].sum(dim=0)
- assert torch.allclose(total_from_atoms, results_nl[total_key], rtol=1e-10)
+def test_soft_sphere_repulsive_only() -> None:
+ """Soft-sphere energies are non-negative for r < sigma."""
+ dr = torch.linspace(0.1, 0.99, 50)
+ z = _dummy_z(len(dr))
+ e = soft_sphere_pair(dr, z, z, sigma=1.0, epsilon=1.0, alpha=2.0)
+ assert (e >= 0).all()
@pytest.mark.parametrize(
@@ -155,108 +45,102 @@ def test_soft_sphere_pair_single(
distance: float, sigma: float, epsilon: float, alpha: float, expected: float
) -> None:
"""Test the soft sphere pair calculation for single values."""
- energy = ss.soft_sphere_pair(
- torch.tensor(distance),
- torch.tensor(sigma),
- torch.tensor(epsilon),
- torch.tensor(alpha),
+ dr = torch.tensor([distance])
+ z = _dummy_z(1)
+ energy = soft_sphere_pair(dr, z, z, sigma=sigma, epsilon=epsilon, alpha=alpha)
+ torch.testing.assert_close(energy, torch.tensor([expected]))
+
+
+def _make_mss(
+ sigma: float = 1.0, epsilon: float = 1.0, alpha: float = 2.0
+) -> MultiSoftSpherePairFn:
+ """Two-species MultiSoftSpherePairFn with uniform parameters."""
+ n = 2
+ return MultiSoftSpherePairFn(
+ atomic_numbers=torch.tensor([18, 36]),
+ sigma_matrix=torch.full((n, n), sigma),
+ epsilon_matrix=torch.full((n, n), epsilon),
+ alpha_matrix=torch.full((n, n), alpha),
)
- assert torch.allclose(energy, torch.tensor(expected))
-
-
-def test_model_initialization_defaults() -> None:
- """Test initialization with default parameters."""
- model = ss.SoftSphereModel()
-
- # Check default parameters are used
- assert torch.allclose(model.sigma, ss.DEFAULT_SIGMA)
- assert torch.allclose(model.epsilon, ss.DEFAULT_EPSILON)
- assert torch.allclose(model.alpha, ss.DEFAULT_ALPHA)
- assert torch.allclose(model.cutoff, ss.DEFAULT_SIGMA) # Default cutoff is sigma
-
-
-@pytest.mark.parametrize(
- ("param_name", "param_value", "expected_dtype"),
- [
- ("sigma", 2.0, torch.float64),
- ("epsilon", 3.0, torch.float64),
- ("alpha", 4.0, torch.float64),
- ("cutoff", 5.0, torch.float64),
- ],
-)
-def test_model_initialization_custom_params(
- param_name: str, param_value: float, expected_dtype: torch.dtype
-) -> None:
- """Test initialization with custom parameters."""
- model = ss.SoftSphereModel(**{param_name: param_value, "dtype": expected_dtype})
-
- param_tensor = getattr(model, param_name)
- assert torch.allclose(param_tensor, torch.tensor(param_value, dtype=expected_dtype))
- assert param_tensor.dtype == expected_dtype
-@pytest.mark.parametrize(
- ("flag_name", "flag_value"),
- [
- ("compute_forces", False),
- ("compute_stress", True),
- ("per_atom_energies", True),
- ("per_atom_stresses", True),
- ("use_neighbor_list", False),
- ],
-)
-def test_model_initialization_custom_flags(*, flag_name: str, flag_value: bool) -> None:
- """Test initialization with custom flags."""
- model = ss.SoftSphereModel(**{flag_name: flag_value})
-
- # For compute_forces and compute_stress, we need to check the private attributes
- if flag_name == "compute_forces":
- flag_name = "_compute_forces"
- elif flag_name == "compute_stress":
- flag_name = "_compute_stress"
+def test_multi_soft_sphere_zero_beyond_sigma() -> None:
+ """Energy is zero for r >= sigma."""
+ fn = _make_mss(sigma=1.0)
+ dr = torch.tensor([1.0, 1.5])
+ zi = zj = torch.tensor([18, 36])
+ e = fn(dr, zi, zj)
+ assert (e == 0.0).all()
- assert getattr(model, flag_name) is flag_value
+def test_multi_soft_sphere_repulsive_only() -> None:
+ """Energy is non-negative for r < sigma."""
+ fn = _make_mss(sigma=2.0, epsilon=1.0, alpha=2.0)
+ dr = torch.linspace(0.1, 1.99, 20)
+ zi = zj = torch.full((20,), 18, dtype=torch.long)
+ assert (fn(dr, zi, zj) >= 0).all()
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float64])
-def test_model_dtype(dtype: torch.dtype) -> None:
- """Test model with different dtypes."""
- model = ss.SoftSphereModel(dtype=dtype)
- assert model.sigma.dtype == dtype
- assert model.epsilon.dtype == dtype
- assert model.alpha.dtype == dtype
- assert model.cutoff.dtype == dtype
+def test_multi_soft_sphere_species_lookup() -> None:
+ """Different species pairs use the correct off-diagonal parameters."""
+ sigma_matrix = torch.tensor([[1.0, 2.0], [2.0, 3.0]])
+ epsilon_matrix = torch.ones(2, 2)
+ alpha_matrix = torch.full((2, 2), 2.0)
+ fn = MultiSoftSpherePairFn(
+ atomic_numbers=torch.tensor([18, 36]),
+ sigma_matrix=sigma_matrix,
+ epsilon_matrix=epsilon_matrix,
+ alpha_matrix=alpha_matrix,
+ )
+ dr = torch.tensor([0.5])
+ zi_same = torch.tensor([18])
+ zj_same = torch.tensor([18])
+ zi_cross = torch.tensor([18])
+ zj_cross = torch.tensor([36])
+ e_same = fn(dr, zi_same, zj_same) # sigma=1.0, r=0.5 < sigma → non-zero
+ e_cross = fn(dr, zi_cross, zj_cross) # sigma=2.0, r=0.5 < sigma → non-zero
+ # cross pair has larger sigma so (1 - r/sigma) is larger → higher energy
+ assert e_cross > e_same
+
+
+def test_multi_soft_sphere_alpha_matrix_default() -> None:
+ """Omitting alpha_matrix defaults to 2.0 for all pairs."""
+ fn_default = MultiSoftSpherePairFn(
+ atomic_numbers=torch.tensor([18, 36]),
+ sigma_matrix=torch.full((2, 2), 1.0),
+ epsilon_matrix=torch.full((2, 2), 1.0),
+ )
+ fn_explicit = _make_mss(sigma=1.0, epsilon=1.0, alpha=2.0)
+ dr = torch.tensor([0.5])
+ zi = zj = torch.tensor([18])
+ torch.testing.assert_close(fn_default(dr, zi, zj), fn_explicit(dr, zi, zj))
+
+
+def test_multi_soft_sphere_bad_matrix_shape_raises() -> None:
+ with pytest.raises(ValueError, match="sigma_matrix"):
+ MultiSoftSpherePairFn(
+ atomic_numbers=torch.tensor([18, 36]),
+ sigma_matrix=torch.ones(3, 3), # wrong shape
+ epsilon_matrix=torch.ones(2, 2),
+ )
def test_multispecies_initialization_defaults() -> None:
- """Test initialization of multi-species model with defaults."""
- dtype = torch.float32
- model = ss.SoftSphereMultiModel(n_species=2, dtype=dtype)
-
- # Check matrices are created with defaults
+ """Multi-species model initializes with default parameters."""
+ model = SoftSphereMultiModel(atomic_numbers=torch.tensor([0, 1]), dtype=torch.float32)
assert model.sigma_matrix.shape == (2, 2)
assert model.epsilon_matrix.shape == (2, 2)
assert model.alpha_matrix.shape == (2, 2)
- # Check default values
- ones = torch.ones(2, 2, dtype=dtype)
- assert torch.allclose(model.sigma_matrix, ss.DEFAULT_SIGMA * ones)
- assert torch.allclose(model.epsilon_matrix, ss.DEFAULT_EPSILON * ones)
- assert torch.allclose(model.alpha_matrix, ss.DEFAULT_ALPHA * ones)
-
- # Check cutoff is max sigma
- assert model.cutoff.item() == ss.DEFAULT_SIGMA.item()
-
def test_multispecies_initialization_custom() -> None:
- """Test initialization of multi-species model with custom parameters."""
+ """Multi-species model stores custom parameter matrices."""
sigma_matrix = torch.tensor([[1.0, 1.5], [1.5, 2.0]], dtype=torch.float64)
epsilon_matrix = torch.tensor([[1.0, 0.5], [0.5, 1.5]], dtype=torch.float64)
alpha_matrix = torch.tensor([[2.0, 3.0], [3.0, 4.0]], dtype=torch.float64)
- model = ss.SoftSphereMultiModel(
- n_species=2,
+ model = SoftSphereMultiModel(
+ atomic_numbers=torch.tensor([0, 1]),
sigma_matrix=sigma_matrix,
epsilon_matrix=epsilon_matrix,
alpha_matrix=alpha_matrix,
@@ -264,25 +148,20 @@ def test_multispecies_initialization_custom() -> None:
dtype=torch.float64,
)
- # Check matrices are stored correctly
assert torch.allclose(model.sigma_matrix, sigma_matrix)
assert torch.allclose(model.epsilon_matrix, epsilon_matrix)
assert torch.allclose(model.alpha_matrix, alpha_matrix)
-
- # Check cutoff is set explicitly
assert model.cutoff.item() == 3.0
def test_multispecies_matrix_validation() -> None:
- """Test validation of parameter matrices."""
- # Create incorrect-sized matrices (2x2 instead of 3x3)
+ """Incorrectly sized matrices raise ValueError."""
sigma_matrix = torch.tensor([[1.0, 1.5], [1.5, 2.0]])
epsilon_matrix = torch.tensor([[1.0, 0.5], [0.5, 1.5]])
- # Should raise ValueError due to matrix size mismatch
with pytest.raises(ValueError, match="sigma_matrix must have shape"):
- ss.SoftSphereMultiModel(
- n_species=3,
+ SoftSphereMultiModel(
+ atomic_numbers=torch.tensor([0, 1, 2]),
sigma_matrix=sigma_matrix,
epsilon_matrix=epsilon_matrix,
)
@@ -297,52 +176,98 @@ def test_multispecies_matrix_validation() -> None:
],
)
def test_matrix_symmetry_validation(matrix_name: str, matrix: torch.Tensor) -> None:
- """Test that parameter matrices are validated for symmetry."""
- # Create symmetric matrices for the other parameters
+ """Parameter matrices are validated for symmetry."""
symmetric_matrix = torch.tensor([[1.0, 1.5], [1.5, 2.0]])
-
params = {
- "n_species": 2,
+ "atomic_numbers": torch.tensor([0, 1]),
"sigma_matrix": symmetric_matrix,
"epsilon_matrix": symmetric_matrix,
"alpha_matrix": symmetric_matrix,
}
-
- # Replace one matrix with the non-symmetric version
params[matrix_name] = matrix
- # Should raise ValueError due to asymmetric matrix
with pytest.raises(ValueError, match="is not symmetric"):
- ss.SoftSphereMultiModel(**params)
+ SoftSphereMultiModel(**params)
def test_multispecies_cutoff_default() -> None:
- """Test that the default cutoff is the maximum sigma value."""
+ """Default cutoff is the maximum sigma value."""
sigma_matrix = torch.tensor([[1.0, 1.5, 2.0], [1.5, 2.0, 2.5], [2.0, 2.5, 3.0]])
+ model = SoftSphereMultiModel(
+ atomic_numbers=torch.tensor([0, 1, 2]), sigma_matrix=sigma_matrix
+ )
+ assert model.cutoff.item() == 3.0
- model = ss.SoftSphereMultiModel(n_species=3, sigma_matrix=sigma_matrix)
- # Cutoff should default to max value in sigma_matrix
- assert model.cutoff.item() == 3.0
+def test_multispecies_evaluation() -> None:
+ """Multi-species model evaluates without error on a small system."""
+ sigma_matrix = torch.tensor([[1.0, 0.8], [0.8, 0.6]], dtype=torch.float64)
+ epsilon_matrix = torch.tensor([[1.0, 0.5], [0.5, 2.0]], dtype=torch.float64)
+ model = SoftSphereMultiModel(
+ atomic_numbers=torch.tensor([0, 1]),
+ sigma_matrix=sigma_matrix,
+ epsilon_matrix=epsilon_matrix,
+ dtype=torch.float64,
+ compute_forces=True,
+ compute_stress=True,
+ )
+
+ positions = torch.tensor(
+ [[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [0.0, 0.5, 0.0], [0.5, 0.5, 0.0]],
+ dtype=torch.float64,
+ )
+ cell = torch.eye(3, dtype=torch.float64) * 2.0
+ state = ts.SimState(
+ positions=positions,
+ cell=cell,
+ pbc=True,
+ masses=torch.ones(4, dtype=torch.float64),
+ atomic_numbers=torch.tensor([0, 0, 1, 1], dtype=torch.long),
+ )
+ results = model(state)
+ assert "energy" in results
+ assert "forces" in results
+ assert "stress" in results
+
+
+def test_soft_sphere_model_evaluation(si_double_sim_state: ts.SimState) -> None:
+ """SoftSphereModel (wrapped PairPotentialModel) evaluates correctly."""
+ model = SoftSphereModel(
+ sigma=5.0,
+ epsilon=0.0104,
+ alpha=2.0,
+ cutoff=5.0,
+ dtype=torch.float64,
+ compute_forces=True,
+ compute_stress=True,
+ )
+ results = model(si_double_sim_state)
+ assert "energy" in results
+ assert "forces" in results
+ assert "stress" in results
+ assert results["energy"].shape == (si_double_sim_state.n_systems,)
+ assert results["forces"].shape == (si_double_sim_state.n_atoms, 3)
+ assert results["stress"].shape == (si_double_sim_state.n_systems, 3, 3)
-@pytest.mark.parametrize(
- ("flag_name", "flag_value"),
- [
- ("pbc", torch.tensor([True, True, True])),
- ("pbc", torch.tensor([False, False, False])),
- ("compute_forces", False),
- ("compute_stress", True),
- ("per_atom_energies", True),
- ("per_atom_stresses", False),
- ("use_neighbor_list", True),
- ("use_neighbor_list", False),
- ],
-)
-def test_multispecies_model_flags(*, flag_name: str, flag_value: bool) -> None:
- """Test flags of the SoftSphereMultiModel."""
- model = ss.SoftSphereMultiModel(n_species=2, **{flag_name: flag_value})
- # For SoftSphereMultiModel, we don't need to convert attribute names
- # as it uses public attribute names for all flags
- assert getattr(model, flag_name) is flag_value
+def test_soft_sphere_model_force_conservation(
+ si_double_sim_state: ts.SimState,
+) -> None:
+ """SoftSphereModel forces sum to zero (Newton's third law)."""
+ model = SoftSphereModel(
+ sigma=5.0,
+ epsilon=0.0104,
+ alpha=2.0,
+ cutoff=5.0,
+ dtype=torch.float64,
+ compute_forces=True,
+ )
+ results = model(si_double_sim_state)
+ for sys_idx in range(si_double_sim_state.n_systems):
+ mask = si_double_sim_state.system_idx == sys_idx
+ assert torch.allclose(
+ results["forces"][mask].sum(dim=0),
+ torch.zeros(3, dtype=torch.float64),
+ atol=1e-10,
+ )
diff --git a/tests/test_fix_symmetry.py b/tests/test_fix_symmetry.py
index a770168b..ea5076d7 100644
--- a/tests/test_fix_symmetry.py
+++ b/tests/test_fix_symmetry.py
@@ -11,9 +11,10 @@
from ase.stress import full_3x3_to_voigt_6_stress, voigt_6_to_full_3x3_stress
import torch_sim as ts
+from tests.conftest import DEVICE, DTYPE
from torch_sim.constraints import FixCom, FixSymmetry
from torch_sim.models.interface import ModelInterface
-from torch_sim.models.lennard_jones import UnbatchedLennardJonesModel
+from torch_sim.models.lennard_jones import LennardJonesModel
from torch_sim.symmetrize import get_symmetry_datasets
@@ -22,15 +23,11 @@
SPACEGROUPS = {"fcc": 225, "hcp": 194, "diamond": 227, "bcc": 229, "p6bar": 174}
MAX_STEPS = 30
-DTYPE = torch.float64
SYMPREC = 0.01
-CPU = torch.device("cpu")
+REPEATS = 2
-# === Structure helpers ===
-
-
-def make_structure(name: str) -> Atoms:
+def make_structure(name: str, repeats: int = REPEATS) -> Atoms:
"""Create a test structure by name (fcc/hcp/diamond/bcc/p6bar + _rotated suffix)."""
base = name.replace("_rotated", "")
builders = {
@@ -46,6 +43,8 @@ def make_structure(name: str) -> Atoms:
),
}
atoms = builders[base]()
+ # make a supercell to exaggerate the impact of symmetry breaking noise
+ atoms = atoms.repeat([repeats, repeats, repeats])
if "_rotated" in name:
rotation_product = np.eye(3)
for axis_idx in range(3):
@@ -63,13 +62,10 @@ def make_structure(name: str) -> Atoms:
return atoms
-# === Fixtures ===
-
-
@pytest.fixture
-def model() -> UnbatchedLennardJonesModel:
+def model() -> LennardJonesModel:
"""LJ model for testing."""
- return UnbatchedLennardJonesModel(
+ return LennardJonesModel(
sigma=1.0,
epsilon=0.05,
cutoff=6.0,
@@ -79,42 +75,58 @@ def model() -> UnbatchedLennardJonesModel:
class NoisyModelWrapper(ModelInterface):
- """Wrapper that adds noise to forces and stress."""
+ """Wrapper that adds Weibull-distributed noise to forces and stress.
+
+ Uses Weibull noise (heavy-tailed) rather than Gaussian so that occasional
+ large perturbations can break symmetry in negative-control tests. This
+ also better mimics real ML potential errors, which have heavy tails.
+ """
- model: UnbatchedLennardJonesModel
- rng: np.random.Generator
+ model: LennardJonesModel
noise_scale: float
+ concentration: float
def __init__(
self,
- model: UnbatchedLennardJonesModel,
- noise_scale: float = 1e-4,
+ model: LennardJonesModel,
+ noise_scale: float = 1e-1,
+ concentration: float = 1.0,
) -> None:
super().__init__()
self.model = model
- self.rng = np.random.default_rng(seed=1)
self.noise_scale = noise_scale
+ self.concentration = concentration
self._device = model.device
self._dtype = model.dtype
self._compute_stress = model.compute_stress
self._compute_forces = model.compute_forces
def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]:
- """Forward pass with added noise."""
+ """Forward pass with added Weibull noise."""
results = self.model(state, **kwargs)
for key in ("forces", "stress"):
if key in results:
- noise = torch.tensor(
- self.rng.normal(size=results[key].shape),
- dtype=results[key].dtype,
- device=results[key].device,
+ shape = results[key].shape
+ # Random direction on the unit sphere (per element for stress)
+ direction = torch.randn(shape, generator=state.rng)
+ direction = direction / torch.norm(direction, dim=-1, keepdim=True).clamp(
+ min=1e-12
+ )
+ # Weibull magnitude via inverse CDF: scale * (-ln(U))^(1/k)
+ u = torch.rand(shape[0], generator=state.rng)
+ magnitudes = self.noise_scale * (-torch.log(u)).pow(
+ 1.0 / self.concentration
)
- results[key] = results[key] + self.noise_scale * noise
+ if key == "forces":
+ noise = magnitudes.unsqueeze(-1) * direction
+ else:
+ noise = magnitudes.view(-1, 1, 1) * direction
+ results[key] = results[key] + noise
return results
@pytest.fixture
-def noisy_lj_model(model: UnbatchedLennardJonesModel) -> NoisyModelWrapper:
+def noisy_lj_model(model: LennardJonesModel) -> NoisyModelWrapper:
"""LJ model with noise added to forces/stress."""
return NoisyModelWrapper(model)
@@ -123,7 +135,7 @@ def noisy_lj_model(model: UnbatchedLennardJonesModel) -> NoisyModelWrapper:
def p6bar_both_constraints() -> tuple[ts.SimState, FixSymmetry, Atoms, ASEFixSymmetry]:
"""P-6 structure with both TorchSim and ASE constraints (shared setup)."""
atoms = make_structure("p6bar")
- state = ts.io.atoms_to_state(atoms, CPU, DTYPE)
+ state = ts.io.atoms_to_state(atoms, DEVICE, DTYPE)
ts_constraint = FixSymmetry.from_state(state, symprec=SYMPREC)
ase_atoms = atoms.copy()
ase_refine_symmetry(ase_atoms, symprec=SYMPREC)
@@ -131,9 +143,6 @@ def p6bar_both_constraints() -> tuple[ts.SimState, FixSymmetry, Atoms, ASEFixSym
return state, ts_constraint, ase_atoms, ase_constraint
-# === Optimization helper ===
-
-
def run_optimization_check_symmetry(
state: ts.SimState,
model: ModelInterface,
@@ -168,9 +177,6 @@ def run_optimization_check_symmetry(
}
-# === Tests: Creation & Basics ===
-
-
class TestFixSymmetryCreation:
"""Tests for FixSymmetry creation and basic behavior."""
@@ -178,14 +184,15 @@ def test_from_state_batched(self) -> None:
"""Batched state with FCC + diamond gets correct ops, atom counts, and DOF."""
state = ts.io.atoms_to_state(
[make_structure("fcc"), make_structure("diamond")],
- CPU,
+ DEVICE,
DTYPE,
)
constraint = FixSymmetry.from_state(state, symprec=SYMPREC)
assert len(constraint.rotations) == 2
- assert constraint.rotations[0].shape[0] == 48 # cubic
- assert constraint.symm_maps[0].shape == (48, 1) # Cu: 1 atom
- assert constraint.symm_maps[1].shape == (48, 2) # Si: 2 atoms
+ assert constraint.rotations[0].shape[0] == 48 * REPEATS**3 # cubic
+ n_ops = 48 * REPEATS**3
+ assert constraint.symm_maps[0].shape == (n_ops, REPEATS**3)
+ assert constraint.symm_maps[1].shape == (n_ops, 2 * REPEATS**3)
assert torch.all(constraint.get_removed_dof(state) == 0)
def test_p1_identity_is_noop(self) -> None:
@@ -196,7 +203,7 @@ def test_p1_identity_is_noop(self) -> None:
cell=[[3.0, 0.1, 0.2], [0.15, 3.5, 0.1], [0.2, 0.15, 4.0]],
pbc=True,
)
- state = ts.io.atoms_to_state(atoms, CPU, DTYPE)
+ state = ts.io.atoms_to_state(atoms, DEVICE, DTYPE)
constraint = FixSymmetry.from_state(state, symprec=SYMPREC)
assert constraint.rotations[0].shape[0] == 1
@@ -217,7 +224,7 @@ def test_from_state_refine_symmetry(self, *, refine: bool) -> None:
atoms = make_structure("fcc")
rng = np.random.default_rng(42)
atoms.positions += rng.standard_normal(atoms.positions.shape) * 0.001
- state = ts.io.atoms_to_state(atoms, CPU, DTYPE)
+ state = ts.io.atoms_to_state(atoms, DEVICE, DTYPE)
orig_pos = state.positions.clone()
_ = FixSymmetry.from_state(state, symprec=SYMPREC, refine_symmetry_state=refine)
if not refine:
@@ -235,7 +242,7 @@ def test_refine_symmetry_produces_correct_spacegroup(
expected = SPACEGROUPS[structure_name]
rng = np.random.default_rng(42)
atoms.positions += rng.standard_normal(atoms.positions.shape) * 0.001
- state = ts.io.atoms_to_state(atoms, CPU, DTYPE)
+ state = ts.io.atoms_to_state(atoms, DEVICE, DTYPE)
refined_cell, refined_pos = refine_symmetry(
state.row_vector_cell[0],
@@ -250,23 +257,26 @@ def test_refine_symmetry_produces_correct_spacegroup(
assert datasets[0].number == expected
def test_cubic_forces_vanish(self) -> None:
- """Asymmetric force on single cubic atom symmetrizes to zero."""
+ """Asymmetric force on cubic atoms symmetrizes to zero."""
state = ts.io.atoms_to_state(
[make_structure("fcc"), make_structure("diamond")],
- CPU,
+ DEVICE,
DTYPE,
)
constraint = FixSymmetry.from_state(state, symprec=SYMPREC)
- forces = torch.tensor(
- [[1.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.0, 0.5, 0.5]],
- dtype=DTYPE,
- )
+ n_atoms = state.positions.shape[0]
+ # Random asymmetric forces on all atoms
+ forces = torch.randn(n_atoms, 3, device=DEVICE, dtype=DTYPE, generator=state.rng)
constraint.adjust_forces(state, forces)
- assert torch.allclose(forces[0], torch.zeros(3, dtype=DTYPE), atol=1e-10)
+ # All FCC atoms (first 27 in 3x3x3 supercell) should have zero forces
+ n_fcc = REPEATS**3
+ assert torch.allclose(
+ forces[:n_fcc], torch.zeros(n_fcc, 3, dtype=DTYPE), atol=1e-10
+ )
def test_large_deformation_clamped(self) -> None:
"""Per-step deformation > 0.25 is clamped rather than rejected."""
- state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE)
+ state = ts.io.atoms_to_state(make_structure("fcc"), DEVICE, DTYPE)
constraint = FixSymmetry.from_state(state, symprec=SYMPREC)
orig_cell = state.cell.clone()
new_cell = state.cell.clone() * 1.5 # 50% strain, well over 0.25
@@ -282,7 +292,7 @@ def test_large_deformation_clamped(self) -> None:
def test_nan_deformation_raises(self) -> None:
"""NaN in proposed cell raises RuntimeError instead of propagating."""
- state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE)
+ state = ts.io.atoms_to_state(make_structure("fcc"), DEVICE, DTYPE)
constraint = FixSymmetry.from_state(state, symprec=SYMPREC)
new_cell = state.cell.clone()
new_cell[0, 0, 0] = float("nan")
@@ -304,7 +314,7 @@ def test_init_mismatched_lengths_raises(self) -> None:
def test_adjust_skipped_when_disabled(self, method: str) -> None:
"""adjust_positions=False / adjust_cell=False leaves data unchanged."""
flag = method.replace("adjust_", "") # "positions" or "cell"
- state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE)
+ state = ts.io.atoms_to_state(make_structure("fcc"), DEVICE, DTYPE)
constraint = FixSymmetry.from_state(
state,
symprec=SYMPREC,
@@ -319,9 +329,6 @@ def test_adjust_skipped_when_disabled(self, method: str) -> None:
assert torch.equal(data, expected)
-# === Tests: Comparison with ASE ===
-
-
class TestFixSymmetryComparisonWithASE:
"""Compare TorchSim FixSymmetry with ASE's implementation on P-6 structure."""
@@ -384,15 +391,12 @@ def test_position_symmetrization_matches_ase(
assert np.allclose(new_pos_ts.numpy(), new_pos_ase, atol=1e-10)
-# === Tests: Merge, Select, Reindex ===
-
-
class TestFixSymmetryMergeSelectReindex:
"""Tests for reindex/merge API, select, and concatenation."""
def test_reindex_preserves_symmetry_data(self) -> None:
"""reindex shifts system_idx but preserves rotations and symm_maps."""
- state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE)
+ state = ts.io.atoms_to_state(make_structure("fcc"), DEVICE, DTYPE)
orig = FixSymmetry.from_state(state, symprec=SYMPREC)
shifted = orig.reindex(atom_offset=100, system_offset=5)
assert shifted.system_idx.item() == 5
@@ -401,8 +405,8 @@ def test_reindex_preserves_symmetry_data(self) -> None:
def test_merge_two_constraints(self) -> None:
"""Merge two single-system constraints via reindex + merge."""
- s1 = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE)
- s2 = ts.io.atoms_to_state(make_structure("diamond"), CPU, DTYPE)
+ s1 = ts.io.atoms_to_state(make_structure("fcc"), DEVICE, DTYPE)
+ s2 = ts.io.atoms_to_state(make_structure("diamond"), DEVICE, DTYPE)
c1 = FixSymmetry.from_state(s1)
c2 = FixSymmetry.from_state(s2).reindex(atom_offset=0, system_offset=1)
merged = FixSymmetry.merge([c1, c2])
@@ -417,9 +421,9 @@ def test_merge_multi_system_no_duplicate_indices(self) -> None:
make_structure("hcp"),
]
atoms_b = [make_structure("bcc"), make_structure("fcc")]
- c_a = FixSymmetry.from_state(ts.io.atoms_to_state(atoms_a, CPU, DTYPE))
+ c_a = FixSymmetry.from_state(ts.io.atoms_to_state(atoms_a, DEVICE, DTYPE))
c_b = FixSymmetry.from_state(
- ts.io.atoms_to_state(atoms_b, CPU, DTYPE),
+ ts.io.atoms_to_state(atoms_b, DEVICE, DTYPE),
).reindex(atom_offset=0, system_offset=3)
merged = FixSymmetry.merge([c_a, c_b])
assert merged.system_idx.tolist() == [0, 1, 2, 3, 4]
@@ -428,12 +432,12 @@ def test_system_constraint_merge_multi_system_via_concatenate(self) -> None:
"""Regression: merging multi-system FixCom via concatenate_states."""
s1 = ts.io.atoms_to_state(
[make_structure("fcc"), make_structure("diamond")],
- CPU,
+ DEVICE,
DTYPE,
)
s2 = ts.io.atoms_to_state(
[make_structure("bcc"), make_structure("hcp")],
- CPU,
+ DEVICE,
DTYPE,
)
s1.constraints = [FixCom(system_idx=torch.tensor([0, 1]))]
@@ -445,8 +449,8 @@ def test_system_constraint_merge_multi_system_via_concatenate(self) -> None:
def test_concatenate_states_with_fix_symmetry(self) -> None:
"""FixSymmetry survives concatenate_states and still symmetrizes correctly."""
- s1 = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE)
- s2 = ts.io.atoms_to_state(make_structure("diamond"), CPU, DTYPE)
+ s1 = ts.io.atoms_to_state(make_structure("fcc"), DEVICE, DTYPE)
+ s2 = ts.io.atoms_to_state(make_structure("diamond"), DEVICE, DTYPE)
s1.constraints = [FixSymmetry.from_state(s1, symprec=SYMPREC)]
s2.constraints = [FixSymmetry.from_state(s2, symprec=SYMPREC)]
combined = ts.concatenate_states([s1, s2])
@@ -454,64 +458,70 @@ def test_concatenate_states_with_fix_symmetry(self) -> None:
assert isinstance(constraint, FixSymmetry)
assert constraint.system_idx.tolist() == [0, 1]
assert len(constraint.rotations) == 2
- # Forces on single FCC atom should still vanish
- forces = torch.tensor(
- [[1.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.0, 0.5, 0.5]],
- dtype=DTYPE,
+ # Forces on FCC atoms should still vanish after symmetrization
+ n_atoms = combined.positions.shape[0]
+ n_fcc = REPEATS**3
+ forces = torch.randn(
+ n_atoms, 3, device=DEVICE, dtype=DTYPE, generator=combined.rng
)
constraint.adjust_forces(combined, forces)
- assert torch.allclose(forces[0], torch.zeros(3, dtype=DTYPE), atol=1e-10)
+ assert torch.allclose(
+ forces[:n_fcc], torch.zeros(n_fcc, 3, dtype=DTYPE), atol=1e-10
+ )
def test_select_sub_constraint(self) -> None:
"""Select second system from batched constraint."""
state = ts.io.atoms_to_state(
[make_structure("fcc"), make_structure("diamond")],
- CPU,
+ DEVICE,
DTYPE,
)
constraint = FixSymmetry.from_state(state, symprec=SYMPREC)
selected = constraint.select_sub_constraint(torch.tensor([1, 2]), sys_idx=1)
assert selected is not None
- assert selected.symm_maps[0].shape[1] == 2
+ # diamond has 2 atoms per unit cell
+ assert selected.symm_maps[0].shape[1] == 2 * REPEATS**3
assert selected.system_idx.item() == 0
def test_select_constraint_by_mask(self) -> None:
"""Select first system via system_mask."""
state = ts.io.atoms_to_state(
[make_structure("fcc"), make_structure("diamond")],
- CPU,
+ DEVICE,
DTYPE,
)
constraint = FixSymmetry.from_state(state, symprec=SYMPREC)
+ n_atoms = state.positions.shape[0]
+ atom_mask = torch.zeros(n_atoms, dtype=torch.bool)
+ atom_mask[0] = True # keep at least one atom from first system
selected = constraint.select_constraint(
- atom_mask=torch.tensor([True, False, False]),
+ atom_mask=atom_mask,
system_mask=torch.tensor([True, False]),
)
assert selected is not None
assert len(selected.rotations) == 1
- assert selected.rotations[0].shape[0] == 48
+ n_ops = 48 * REPEATS**3 # cubic x supercell translations
+ assert selected.rotations[0].shape[0] == n_ops
def test_select_returns_none_for_nonexistent(self) -> None:
"""select_sub_constraint and select_constraint return None when no match."""
state = ts.io.atoms_to_state(
[make_structure("fcc"), make_structure("diamond")],
- CPU,
+ DEVICE,
DTYPE,
)
constraint = FixSymmetry.from_state(state, symprec=SYMPREC)
+ n_atoms = state.positions.shape[0]
assert constraint.select_sub_constraint(torch.tensor([0]), sys_idx=99) is None
assert (
constraint.select_constraint(
- atom_mask=torch.zeros(3, dtype=torch.bool),
+ atom_mask=torch.zeros(n_atoms, dtype=torch.bool),
system_mask=torch.zeros(2, dtype=torch.bool),
)
is None
)
-# === Tests: build_symmetry_map chunked path ===
-
-
def test_build_symmetry_map_chunked_matches_vectorized() -> None:
"""Per-op loop gives same result as vectorized path."""
import torch_sim.symmetrize as sym_mod
@@ -521,11 +531,11 @@ def test_build_symmetry_map_chunked_matches_vectorized() -> None:
build_symmetry_map,
)
- state = ts.io.atoms_to_state(make_structure("p6bar"), CPU, DTYPE)
+ state = ts.io.atoms_to_state(make_structure("p6bar"), DEVICE, DTYPE)
cell = state.row_vector_cell[0]
frac = state.positions @ torch.linalg.inv(cell)
dataset = _moyo_dataset(cell, frac, state.atomic_numbers)
- rotations, translations = _extract_symmetry_ops(dataset, DTYPE, CPU)
+ rotations, translations = _extract_symmetry_ops(dataset, DTYPE, DEVICE)
old_threshold = sym_mod._SYMM_MAP_CHUNK_THRESHOLD # noqa: SLF001
try:
@@ -538,9 +548,6 @@ def test_build_symmetry_map_chunked_matches_vectorized() -> None:
assert torch.equal(vectorized, chunked)
-# === Tests: Optimization ===
-
-
class TestFixSymmetryWithOptimization:
"""Test FixSymmetry with actual optimization routines."""
@@ -560,7 +567,7 @@ def test_distorted_preserves_symmetry(
"""Compressed structure relaxes while preserving symmetry."""
atoms = make_structure(structure_name)
expected = SPACEGROUPS[structure_name]
- state = ts.io.atoms_to_state(atoms, CPU, DTYPE)
+ state = ts.io.atoms_to_state(atoms, DEVICE, DTYPE)
constraint = FixSymmetry.from_state(
state,
symprec=SYMPREC,
@@ -581,11 +588,11 @@ def test_distorted_preserves_symmetry(
@pytest.mark.parametrize("cell_filter", [ts.CellFilter.unit, ts.CellFilter.frechet])
def test_cell_filter_preserves_symmetry(
self,
- model: UnbatchedLennardJonesModel,
+ model: LennardJonesModel,
cell_filter: ts.CellFilter,
) -> None:
"""Cell filters with FixSymmetry preserve symmetry."""
- state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE)
+ state = ts.io.atoms_to_state(make_structure("fcc"), DEVICE, DTYPE)
constraint = FixSymmetry.from_state(state, symprec=SYMPREC)
state.constraints = [constraint]
initial = get_symmetry_datasets(state, symprec=SYMPREC)
@@ -607,7 +614,7 @@ def test_lbfgs_preserves_symmetry(
cell_filter: ts.CellFilter,
) -> None:
"""Regression: LBFGS must use set_constrained_cell for FixSymmetry support."""
- state = ts.io.atoms_to_state(make_structure("bcc"), CPU, DTYPE)
+ state = ts.io.atoms_to_state(make_structure("bcc"), DEVICE, DTYPE)
constraint = FixSymmetry.from_state(state, symprec=SYMPREC)
state.constraints = [constraint]
state.cell = state.cell * 0.95
@@ -630,10 +637,12 @@ def test_noisy_model_loses_symmetry_without_constraint(
self,
noisy_lj_model: NoisyModelWrapper,
) -> None:
- """Negative control: without FixSymmetry, noise breaks rotated BCC symmetry."""
- state = ts.io.atoms_to_state(make_structure("bcc_rotated"), CPU, DTYPE)
- noisier_model = NoisyModelWrapper(noisy_lj_model.model, noise_scale=5e-4)
- result = run_optimization_check_symmetry(state, noisier_model, constraint=None)
+ """Negative control: without FixSymmetry, Weibull noise breaks BCC symmetry."""
+ # Need supercell to reliably break symmetry. Previously test pinned to magic seed.
+ state = ts.io.atoms_to_state(
+ make_structure("bcc_rotated", repeats=max(REPEATS, 2)), DEVICE, DTYPE
+ )
+ result = run_optimization_check_symmetry(state, noisy_lj_model, constraint=None)
assert result["initial_spacegroups"][0] == 229
assert result["final_spacegroups"][0] != 229
@@ -646,7 +655,7 @@ def test_noisy_model_preserves_symmetry_with_constraint(
) -> None:
"""With FixSymmetry, noisy forces still preserve symmetry."""
name = "bcc_rotated" if rotated else "bcc"
- state = ts.io.atoms_to_state(make_structure(name), CPU, DTYPE)
+ state = ts.io.atoms_to_state(make_structure(name), DEVICE, DTYPE)
constraint = FixSymmetry.from_state(state, symprec=SYMPREC)
result = run_optimization_check_symmetry(
state,
@@ -665,7 +674,7 @@ def test_cumulative_strain_clamp_direct(self) -> None:
1. The cell doesn't drift beyond the strain envelope
2. Symmetry is preserved after many small steps
"""
- state = ts.io.atoms_to_state(make_structure("fcc"), CPU, DTYPE)
+ state = ts.io.atoms_to_state(make_structure("fcc", repeats=1), DEVICE, DTYPE)
constraint = FixSymmetry.from_state(state, symprec=SYMPREC)
constraint.max_cumulative_strain = 0.15
assert constraint.reference_cells is not None
diff --git a/tests/test_integrators.py b/tests/test_integrators.py
index b3d8c336..dd9b49d0 100644
--- a/tests/test_integrators.py
+++ b/tests/test_integrators.py
@@ -339,7 +339,7 @@ def test_nvt_nose_hoover(ar_double_sim_state: ts.SimState, lj_model: LennardJone
temperatures_list = [t.tolist() for t in temperatures_tensor.T]
assert torch.allclose(
temperatures_tensor[-1],
- torch.tensor([299.9910, 299.6800], dtype=dtype),
+ torch.tensor([300.0096, 299.7024], dtype=dtype),
)
energies_tensor = torch.stack(energies)
@@ -728,7 +728,7 @@ def test_npt_nose_hoover(ar_double_sim_state: ts.SimState, lj_model: LennardJone
temperatures_list = [t.tolist() for t in temperatures_tensor.T]
assert torch.allclose(
temperatures_tensor[-1],
- torch.tensor([297.8602, 297.5306], dtype=dtype),
+ torch.tensor([298.2752, 297.9444], dtype=dtype),
)
energies_tensor = torch.stack(energies)
@@ -1041,11 +1041,6 @@ def test_compute_cell_force_atoms_per_system():
assert abs(force_ratio - 8.0) / 8.0 < 0.1
-# ---------------------------------------------------------------------------
-# Reproducibility tests
-# ---------------------------------------------------------------------------
-
-
def test_nvt_langevin_reproducibility(
ar_double_sim_state: ts.SimState, lj_model: LennardJonesModel
):
diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py
index cae16fb3..e51ead76 100644
--- a/tests/test_optimizers.py
+++ b/tests/test_optimizers.py
@@ -1096,25 +1096,26 @@ def test_optimizer_batch_consistency(
lj_model: ModelInterface,
) -> None:
"""Test batched optimizer is consistent with individual optimizations."""
- generator = torch.Generator(device=ar_supercell_sim_state.device)
-
# Create two distinct initial states by cloning and perturbing
state1_orig = ar_supercell_sim_state.clone()
+ state1_orig.rng = 0
# Apply identical perturbations to state1_orig
# for state_item in [state1_orig, state2_orig]: # Old loop structure
- generator.manual_seed(43) # Reset seed for positions
state1_orig.positions += (
torch.randn(
- state1_orig.positions.shape, device=state1_orig.device, generator=generator
+ state1_orig.positions.shape,
+ device=state1_orig.device,
+ generator=state1_orig.rng,
)
* 0.1
)
if filter_func:
- generator.manual_seed(44) # Reset seed for cell
state1_orig.cell += (
torch.randn(
- state1_orig.cell.shape, device=state1_orig.device, generator=generator
+ state1_orig.cell.shape,
+ device=state1_orig.device,
+ generator=state1_orig.rng,
)
* 0.01
)
@@ -1177,7 +1178,7 @@ def energy_converged(e_current: torch.Tensor, e_prev: torch.Tensor) -> bool:
# Converge when all batch energies have converged
while not torch.allclose(e_current_batch, e_prev_batch, atol=1e-6):
e_prev_batch = e_current_batch.clone()
- batch_opt_state = step_fn_batch(model=lj_model, state=batch_opt_state)
+ batch_opt_state = step_fn_batch(model=lj_model, state=batch_opt_state, dt_max=0.3)
e_current_batch = batch_opt_state.energy.clone()
steps_batch += 1
if steps_batch > 1000:
diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py
index 0660d757..91323880 100644
--- a/torch_sim/constraints.py
+++ b/torch_sim/constraints.py
@@ -859,8 +859,6 @@ def from_state(
reference_cells=reference_cells,
)
- # === Symmetrization hooks ===
-
def adjust_forces(self, state: SimState, forces: torch.Tensor) -> None:
"""Symmetrize forces according to crystal symmetry."""
self._symmetrize_rank1(state, forces)
@@ -959,8 +957,6 @@ def _symmetrize_rank1(self, state: SimState, vectors: torch.Tensor) -> None:
self.symm_maps[ci],
)
- # === Constraint interface ===
-
def get_removed_dof(self, state: SimState) -> torch.Tensor:
"""Returns zero - constrains direction, not DOF count."""
return torch.zeros(state.n_systems, dtype=torch.long, device=state.device)
diff --git a/torch_sim/elastic.py b/torch_sim/elastic.py
index 94b95b79..944cdb81 100644
--- a/torch_sim/elastic.py
+++ b/torch_sim/elastic.py
@@ -1,4 +1,4 @@
-# ruff: noqa: RUF002, RUF003, PLC2401
+# ruff: noqa: RUF003, PLC2401
"""Calculation of elastic properties of crystals.
Primary Sources and References for Crystal Elasticity.
diff --git a/torch_sim/math.py b/torch_sim/math.py
index 5f7604f6..64a2edd7 100644
--- a/torch_sim/math.py
+++ b/torch_sim/math.py
@@ -1,6 +1,6 @@
"""Mathematical operations and utilities. Adapted from https://github.com/abhijeetgangan/torch_matfunc."""
-# ruff: noqa: FBT001, FBT002, RUF002
+# ruff: noqa: FBT001, FBT002
from typing import Final
diff --git a/torch_sim/models/lennard_jones.py b/torch_sim/models/lennard_jones.py
index 518355cc..c4b67c24 100644
--- a/torch_sim/models/lennard_jones.py
+++ b/torch_sim/models/lennard_jones.py
@@ -1,158 +1,92 @@
-"""Classical pairwise interatomic potential model.
+"""Lennard-Jones 12-6 potential model.
-This module implements the Lennard-Jones potential for molecular dynamics simulations.
-It provides efficient calculation of energies, forces, and stresses based on the
-classic 12-6 potential function. The implementation supports both full pairwise
-calculations and neighbor list-based optimizations.
+Thin wrapper around :class:`~torch_sim.models.pair_potential.PairPotentialModel` with
+the :func:`lennard_jones_pair` energy function baked in.
Example::
- # Create a Lennard-Jones model with default parameters
- model = LennardJonesModel(device=torch.device("cuda"))
-
- # Create a model with custom parameters
- model = LennardJonesModel(
- sigma=3.405, # Angstroms
- epsilon=0.01032, # eV
- cutoff=10.0, # Angstroms
- compute_stress=True,
- )
-
- # Calculate properties for a simulation state
- output = model(sim_state)
- energy = output["energy"]
- forces = output["forces"]
+ model = LennardJonesModel(sigma=3.405, epsilon=0.0104, cutoff=8.5)
+ results = model(sim_state)
"""
-from collections.abc import Callable
+from __future__ import annotations
+
+import functools
+from collections.abc import Callable # noqa: TC003
import torch
-import torch_sim as ts
-from torch_sim import transforms
-from torch_sim.models.interface import ModelInterface
+from torch_sim.models.pair_potential import PairPotentialModel
from torch_sim.neighbors import torchsim_nl
-DEFAULT_SIGMA = 1.0
-DEFAULT_EPSILON = 1.0
-
-
def lennard_jones_pair(
dr: torch.Tensor,
- sigma: float | torch.Tensor = DEFAULT_SIGMA,
- epsilon: float | torch.Tensor = DEFAULT_EPSILON,
+ zi: torch.Tensor, # noqa: ARG001
+ zj: torch.Tensor, # noqa: ARG001
+ sigma: torch.Tensor | float = 1.0,
+ epsilon: torch.Tensor | float = 1.0,
) -> torch.Tensor:
- """Calculate pairwise Lennard-Jones interaction energies between particles.
-
- Implements the standard 12-6 Lennard-Jones potential that combines short-range
- repulsion with longer-range attraction. The potential has a minimum at r=sigma.
+ """Lennard-Jones 12-6 pair energy.
- The functional form is:
- V(r) = 4*epsilon*[(sigma/r)^12 - (sigma/r)^6]
+ V(r) = 4ε[(σ/r)¹² - (σ/r)⁶]
Args:
- dr: Pairwise distances between particles. Shape: [n, m].
- sigma: Distance at which potential reaches its minimum. Either a scalar float
- or tensor of shape [n, m] for particle-specific interaction distances.
- epsilon: Depth of the potential well (energy scale). Either a scalar float
- or tensor of shape [n, m] for pair-specific interaction strengths.
+ dr: Pairwise distances, shape [n_pairs].
+ zi: Atomic numbers of first atoms (unused, for interface compatibility).
+ zj: Atomic numbers of second atoms (unused, for interface compatibility).
+ sigma: Length scale. Defaults to 1.0.
+ epsilon: Energy scale. Defaults to 1.0.
Returns:
- torch.Tensor: Pairwise Lennard-Jones interaction energies between particles.
- Shape: [n, m]. Each element [i,j] represents the interaction energy between
- particles i and j.
+ Pair energies, shape [n_pairs].
"""
- # Calculate inverse dr and its powers
- idr = sigma / dr
- idr2 = idr * idr
- idr6 = idr2 * idr2 * idr2
+ idr6 = (sigma / dr).pow(6)
idr12 = idr6 * idr6
-
- # Calculate potential energy
energy = 4.0 * epsilon * (idr12 - idr6)
-
- # Handle potential numerical instabilities and infinities
return torch.where(dr > 0, energy, torch.zeros_like(energy))
- # return torch.nan_to_num(energy, nan=0.0, posinf=0.0, neginf=0.0)
def lennard_jones_pair_force(
dr: torch.Tensor,
- sigma: float | torch.Tensor = DEFAULT_SIGMA,
- epsilon: float | torch.Tensor = DEFAULT_EPSILON,
+ sigma: torch.Tensor | float = 1.0,
+ epsilon: torch.Tensor | float = 1.0,
) -> torch.Tensor:
- """Calculate pairwise Lennard-Jones forces between particles.
-
- Implements the force derived from the 12-6 Lennard-Jones potential. The force
- is repulsive at short range and attractive at long range, with a zero-crossing
- at r=sigma.
-
- The functional form is:
- F(r) = 24*epsilon/r * [(2*sigma^12/r^12) - (sigma^6/r^6)]
+ """Lennard-Jones 12-6 pair force (negative gradient of energy).
- This is the negative gradient of the Lennard-Jones potential energy.
+ F(r) = 24ε/r [2(σ/r)¹² - (σ/r)⁶]
Args:
- dr: Pairwise distances between particles. Shape: [n, m].
- sigma: Distance at which force changes from repulsive to attractive.
- Either a scalar float or tensor of shape [n, m] for particle-specific
- interaction distances.
- epsilon: Energy scale of the interaction. Either a scalar float or tensor
- of shape [n, m] for pair-specific interaction strengths.
+ dr: Pairwise distances, shape [n_pairs].
+ sigma: Length scale. Defaults to 1.0.
+ epsilon: Energy scale. Defaults to 1.0.
Returns:
- torch.Tensor: Pairwise Lennard-Jones forces between particles. Shape: [n, m].
- Each element [i,j] represents the force magnitude between particles i and j.
- Positive values indicate repulsion, negative values indicate attraction.
+ Pair force magnitudes (positive = repulsive), shape [n_pairs].
"""
- # Calculate inverse dr and its powers
idr = sigma / dr
- idr2 = idr * idr
- idr6 = idr2 * idr2 * idr2
+ idr6 = idr.pow(6)
idr12 = idr6 * idr6
-
- # Calculate force (negative gradient of potential)
- # F = -24*epsilon/r * ((sigma/r)^6 - 2*(sigma/r)^12)
force = 24.0 * epsilon / dr * (2.0 * idr12 - idr6)
-
- # Handle potential numerical instabilities and infinities
return torch.where(dr > 0, force, torch.zeros_like(force))
-class UnbatchedLennardJonesModel(ModelInterface):
- """Unbatched Lennard-Jones model.
-
- Implements the Lennard-Jones 12-6 potential for molecular dynamics simulations.
- This implementation loops over systems in batched inputs and is intended for
- testing or baseline comparisons with the default batched model.
+class LennardJonesModel(PairPotentialModel):
+ """Lennard-Jones 12-6 pair potential model.
- Attributes:
- sigma (torch.Tensor): Length parameter controlling particle size/repulsion
- distance.
- epsilon (torch.Tensor): Energy parameter controlling interaction strength.
- cutoff (torch.Tensor): Distance cutoff for truncating potential calculation.
- device (torch.device): Device where calculations are performed.
- dtype (torch.dtype): Data type used for calculations.
- compute_forces (bool): Whether to compute atomic forces.
- compute_stress (bool): Whether to compute stress tensor.
- per_atom_energies (bool): Whether to compute per-atom energy decomposition.
- per_atom_stresses (bool): Whether to compute per-atom stress decomposition.
- neighbor_list_fn (Callable): Function used to construct neighbor lists.
+ Convenience subclass that fixes the pair function to :func:`lj_pair` so the
+ caller only needs to supply ``sigma`` and ``epsilon``.
Example::
- # Basic usage with default parameters
- lj_model = UnbatchedLennardJonesModel(device=torch.device("cuda"))
- results = lj_model(sim_state)
-
- # Custom parameterization for Argon
- ar_model = UnbatchedLennardJonesModel(
- sigma=3.405, # Å
- epsilon=0.0104, # eV
- cutoff=8.5, # Å
+ model = LennardJonesModel(
+ sigma=3.405,
+ epsilon=0.0104,
+ cutoff=2.5 * 3.405,
+ compute_forces=True,
compute_stress=True,
)
+ results = model(sim_state)
"""
def __init__(
@@ -160,360 +94,48 @@ def __init__(
sigma: float = 1.0,
epsilon: float = 1.0,
device: torch.device | None = None,
- dtype: torch.dtype = torch.float32,
- *, # Force keyword-only arguments
+ dtype: torch.dtype = torch.float64,
+ *,
compute_forces: bool = True,
compute_stress: bool = False,
per_atom_energies: bool = False,
per_atom_stresses: bool = False,
neighbor_list_fn: Callable = torchsim_nl,
cutoff: float | None = None,
+ retain_graph: bool = False,
) -> None:
- """Initialize the Lennard-Jones potential calculator.
-
- Creates a model with specified interaction parameters and computational flags.
- The model can be configured to compute different properties (forces, stresses)
- and use different optimization strategies.
-
- Args:
- sigma (float): Length parameter of the Lennard-Jones potential in distance
- units. Controls the size of particles. Defaults to 1.0.
- epsilon (float): Energy parameter of the Lennard-Jones potential in energy
- units. Controls the strength of the interaction. Defaults to 1.0.
- device (torch.device | None): Device to run computations on. If None, uses
- CPU. Defaults to None.
- dtype (torch.dtype): Data type for calculations. Defaults to torch.float32.
- compute_forces (bool): Whether to compute forces. Defaults to True.
- compute_stress (bool): Whether to compute stress tensor. Defaults to False.
- per_atom_energies (bool): Whether to compute per-atom energy decomposition.
- Defaults to False.
- per_atom_stresses (bool): Whether to compute per-atom stress decomposition.
- Defaults to False.
- neighbor_list_fn (Callable): Batched neighbor-list function to use when
- constructing interactions. Defaults to torchsim_nl.
- cutoff (float | None): Cutoff distance for interactions in distance units.
- If None, uses 2.5*sigma. Defaults to None.
-
- Example::
-
- # Model with custom parameters
- model = UnbatchedLennardJonesModel(
- sigma=3.405,
- epsilon=0.01032,
- device=torch.device("cuda"),
- dtype=torch.float64,
- compute_stress=True,
- per_atom_energies=True,
- cutoff=10.0,
- )
- """
- super().__init__()
- self._device = device or torch.device("cpu")
- self._dtype = dtype
- self._compute_forces = compute_forces
- self._compute_stress = compute_stress
- self.per_atom_energies = per_atom_energies
- self.per_atom_stresses = per_atom_stresses
- self.neighbor_list_fn = neighbor_list_fn
-
- # Convert parameters to tensors
- self.sigma = torch.tensor(sigma, dtype=dtype, device=self.device)
- self.cutoff = torch.tensor(cutoff or 2.5 * sigma, dtype=dtype, device=self.device)
- self.epsilon = torch.tensor(epsilon, dtype=dtype, device=self.device)
-
- def unbatched_forward(
- self,
- state: ts.SimState,
- ) -> dict[str, torch.Tensor]:
- """Compute Lennard-Jones properties for a single unbatched system.
-
- Internal implementation that processes a single, non-batched simulation state.
- This method handles the core computations of pair interactions, neighbor lists,
- and property calculations.
-
- Args:
- state (SimState): Single, non-batched simulation state containing atomic
- positions, cell vectors, and other system information.
-
- Returns:
- dict[str, torch.Tensor]: Computed properties:
- - "energy": Total potential energy (scalar)
- - "forces": Atomic forces with shape [n_atoms, 3] (if
- compute_forces=True)
- - "stress": Stress tensor with shape [3, 3] (if compute_stress=True)
- - "energies": Per-atom energies with shape [n_atoms] (if
- per_atom_energies=True)
- - "stresses": Per-atom stresses with shape [n_atoms, 3, 3] (if
- per_atom_stresses=True)
-
- Notes:
- Neighbor lists are always used to construct interacting pairs.
- """
- positions = state.positions
- cell = state.row_vector_cell
- cell = cell.squeeze()
-
- # Ensure system_idx exists (create if None for single system)
- system_idx = (
- state.system_idx
- if state.system_idx is not None
- else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device)
- )
-
- # Wrap positions into the unit cell
- wrapped_positions = (
- ts.transforms.pbc_wrap_batched(positions, state.cell, system_idx, state.pbc)
- if state.pbc.any()
- else positions
- )
-
- mapping, _, shifts_idx = self.neighbor_list_fn(
- positions=wrapped_positions,
- cell=cell,
- pbc=state.pbc,
- cutoff=self.cutoff,
- system_idx=system_idx,
- )
- # Pass shifts_idx directly - get_pair_displacements will convert them
- dr_vec, distances = transforms.get_pair_displacements(
- positions=wrapped_positions,
- cell=cell,
- pbc=state.pbc,
- pairs=(mapping[0], mapping[1]),
- shifts=shifts_idx,
- )
-
- # Calculate pair energies and apply cutoff
- pair_energies = lennard_jones_pair(
- distances, sigma=self.sigma, epsilon=self.epsilon
- )
- # Zero out energies beyond cutoff
- mask = distances < self.cutoff
- pair_energies = torch.where(mask, pair_energies, torch.zeros_like(pair_energies))
-
- # Initialize results with total energy (sum/2 to avoid double counting)
- results = {"energy": 0.5 * pair_energies.sum()}
-
- if self.per_atom_energies:
- atom_energies = torch.zeros(
- positions.shape[0], dtype=self.dtype, device=self.device
- )
- # Each atom gets half of the pair energy
- atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies)
- atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies)
- results["energies"] = atom_energies
-
- if self.compute_forces or self.compute_stress:
- # Calculate forces and apply cutoff
- pair_forces = lennard_jones_pair_force(
- distances, sigma=self.sigma, epsilon=self.epsilon
- )
- pair_forces = torch.where(mask, pair_forces, torch.zeros_like(pair_forces))
-
- # Project forces along displacement vectors
- force_vectors = (pair_forces / distances)[:, None] * dr_vec
-
- if self.compute_forces:
- # Initialize forces tensor
- forces = torch.zeros_like(positions)
- # Add force contributions (f_ij on i, -f_ij on j)
- forces.index_add_(0, mapping[0], -force_vectors)
- forces.index_add_(0, mapping[1], force_vectors)
- results["forces"] = forces
-
- if self.compute_stress and cell is not None:
- # Compute stress tensor
- stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors)
- volume = torch.abs(torch.linalg.det(cell))
-
- results["stress"] = -stress_per_pair.sum(dim=0) / volume
-
- if self.per_atom_stresses:
- atom_stresses = torch.zeros(
- (state.positions.shape[0], 3, 3),
- dtype=self.dtype,
- device=self.device,
- )
- atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair)
- atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair)
- results["stresses"] = atom_stresses / volume
-
- return results
-
- def forward(self, state: ts.SimState, **_kwargs: object) -> dict[str, torch.Tensor]:
- """Compute Lennard-Jones energies, forces, and stresses for a system.
-
- Main entry point for Lennard-Jones calculations that handles batched states by
- dispatching each system to the unbatched implementation and combining results.
+ """Initialize the Lennard-Jones model.
Args:
- state (SimState): Input state containing atomic positions, cell vectors,
- and other system information.
- **_kwargs: Unused; accepted for interface compatibility.
-
- Returns:
- dict[str, torch.Tensor]: Computed properties:
- - "energy": Potential energy with shape [n_systems]
- - "forces": Atomic forces with shape [n_atoms, 3] (if
- compute_forces=True)
- - "stress": Stress tensor with shape [n_systems, 3, 3] (if
- compute_stress=True)
- - "energies": Per-atom energies with shape [n_atoms] (if
- per_atom_energies=True)
- - "stresses": Per-atom stresses with shape [n_atoms, 3, 3] (if
- per_atom_stresses=True)
-
- Raises:
- ValueError: If system cannot be inferred for multi-cell systems.
-
- Example::
-
- # Compute properties for a simulation state
- model = UnbatchedLennardJonesModel(compute_stress=True)
- results = model(sim_state)
-
- energy = results["energy"] # Shape: [n_systems]
- forces = results["forces"] # Shape: [n_atoms, 3]
- stress = results["stress"] # Shape: [n_systems, 3, 3]
- energies = results["energies"] # Shape: [n_atoms]
- stresses = results["stresses"] # Shape: [n_atoms, 3, 3]
+ sigma: Length scale parameter. Defaults to 1.0.
+ epsilon: Energy scale parameter. Defaults to 1.0.
+ device: Device for computations. Defaults to CPU.
+ dtype: Floating-point dtype. Defaults to torch.float32.
+ compute_forces: Whether to compute atomic forces. Defaults to True.
+ compute_stress: Whether to compute the stress tensor. Defaults to False.
+ per_atom_energies: Whether to return per-atom energies. Defaults to False.
+ per_atom_stresses: Whether to return per-atom stresses. Defaults to False.
+ neighbor_list_fn: Neighbor-list constructor. Defaults to torchsim_nl.
+ cutoff: Interaction cutoff. Defaults to 2.5 * sigma.
+ retain_graph: Keep computation graph for differentiable simulation.
"""
- sim_state = state
-
- if sim_state.system_idx is None and sim_state.cell.shape[0] > 1:
- raise ValueError("System can only be inferred for batch size 1.")
-
- outputs = [
- self.unbatched_forward(sim_state[idx]) for idx in range(sim_state.n_systems)
- ]
- properties = outputs[0]
-
- # we always return tensors
- # per atom properties are returned as (atoms, ...) tensors
- # global properties are returned as shape (..., n) tensors
- results: dict[str, torch.Tensor] = {}
- for key in ("stress", "energy"):
- if key in properties:
- results[key] = torch.stack([out[key] for out in outputs])
- for key in ("forces", "energies", "stresses"):
- if key in properties:
- results[key] = torch.cat([out[key] for out in outputs], dim=0)
-
- return results
-
-
-class LennardJonesModel(UnbatchedLennardJonesModel):
- """Default vectorized Lennard-Jones model for batched systems.
-
- This class computes Lennard-Jones energies, forces, and stresses for all systems in
- a batch in one pass, avoiding Python loops over systems in the model forward path.
- Use this class for production runs.
- """
-
- def forward( # noqa: PLR0915
- self, state: ts.SimState, **_kwargs: object
- ) -> dict[str, torch.Tensor]:
- """Compute Lennard-Jones properties with batched tensor operations."""
- sim_state = state
-
- if sim_state.system_idx is None and sim_state.cell.shape[0] > 1:
- raise ValueError("System can only be inferred for batch size 1.")
-
- positions = sim_state.positions
- row_cell = sim_state.row_vector_cell
- pbc = sim_state.pbc
-
- system_idx = (
- sim_state.system_idx
- if sim_state.system_idx is not None
- else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device)
- )
-
- wrapped_positions = (
- ts.transforms.pbc_wrap_batched(positions, sim_state.cell, system_idx, pbc)
- if pbc.any()
- else positions
- )
-
- if pbc.ndim == 1:
- pbc_batched = pbc.unsqueeze(0).expand(sim_state.n_systems, -1)
- else:
- pbc_batched = pbc
-
- mapping, system_mapping, shifts_idx = self.neighbor_list_fn(
- positions=wrapped_positions,
- cell=row_cell,
- pbc=pbc_batched,
- cutoff=self.cutoff,
- system_idx=system_idx,
+ self.sigma = sigma
+ self.epsilon = epsilon
+ pair_fn = functools.partial(lennard_jones_pair, sigma=sigma, epsilon=epsilon)
+ super().__init__(
+ pair_fn=pair_fn,
+ cutoff=cutoff if cutoff is not None else 2.5 * sigma,
+ device=device,
+ dtype=dtype,
+ compute_forces=compute_forces,
+ compute_stress=compute_stress,
+ per_atom_energies=per_atom_energies,
+ per_atom_stresses=per_atom_stresses,
+ neighbor_list_fn=neighbor_list_fn,
+ reduce_to_half_list=True,
+ retain_graph=retain_graph,
)
- cell_shifts = transforms.compute_cell_shifts(row_cell, shifts_idx, system_mapping)
- dr_vec = (
- wrapped_positions[mapping[1]] - wrapped_positions[mapping[0]] + cell_shifts
- )
- distances = dr_vec.norm(dim=1)
-
- cutoff_mask = distances < self.cutoff
- pair_energies = lennard_jones_pair(
- distances, sigma=self.sigma, epsilon=self.epsilon
- )
- pair_energies = torch.where(
- cutoff_mask, pair_energies, torch.zeros_like(pair_energies)
- )
-
- n_systems = sim_state.n_systems
- results: dict[str, torch.Tensor] = {}
- energy = torch.zeros(n_systems, dtype=self.dtype, device=self.device)
- energy.index_add_(0, system_mapping, 0.5 * pair_energies)
- results["energy"] = energy
-
- if self.per_atom_energies:
- atom_energies = torch.zeros(
- positions.shape[0], dtype=self.dtype, device=self.device
- )
- atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies)
- atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies)
- results["energies"] = atom_energies
-
- if self.compute_forces or self.compute_stress:
- pair_forces = lennard_jones_pair_force(
- distances, sigma=self.sigma, epsilon=self.epsilon
- )
- pair_forces = torch.where(
- cutoff_mask, pair_forces, torch.zeros_like(pair_forces)
- )
- safe_distances = torch.where(
- distances > 0, distances, torch.ones_like(distances)
- )
- force_vectors = (pair_forces / safe_distances)[:, None] * dr_vec
-
- if self.compute_forces:
- forces = torch.zeros_like(positions)
- forces.index_add_(0, mapping[0], -force_vectors)
- forces.index_add_(0, mapping[1], force_vectors)
- results["forces"] = forces
-
- if self.compute_stress:
- stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors)
- volumes = torch.abs(torch.linalg.det(row_cell))
- stress = torch.zeros(
- (n_systems, 3, 3),
- dtype=self.dtype,
- device=self.device,
- )
- stress.index_add_(0, system_mapping, -stress_per_pair)
- results["stress"] = stress / volumes[:, None, None]
-
- if self.per_atom_stresses:
- atom_stresses = torch.zeros(
- (positions.shape[0], 3, 3),
- dtype=self.dtype,
- device=self.device,
- )
- atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair)
- atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair)
- atom_volumes = volumes[system_idx]
- results["stresses"] = atom_stresses / atom_volumes[:, None, None]
- return results
+# Keep old name as alias for backward compatibility
+UnbatchedLennardJonesModel = LennardJonesModel
diff --git a/torch_sim/models/morse.py b/torch_sim/models/morse.py
index 2c881808..1e786c2a 100644
--- a/torch_sim/models/morse.py
+++ b/torch_sim/models/morse.py
@@ -1,395 +1,144 @@
-"""Anharmonic interatomic potential for molecular dynamics.
+"""Morse potential model.
-This module implements the Morse potential for molecular dynamics simulations.
-The Morse potential provides a more realistic description of anharmonic bond
-behavior than simple harmonic potentials, capturing bond breaking and formation.
-It includes both energy and force calculations with support for neighbor lists.
+Thin wrapper around :class:`~torch_sim.models.pair_potential.PairPotentialModel` with
+the :func:`morse_pair` energy function baked in.
Example::
- # Create a Morse model with default parameters
- model = MorseModel(device=torch.device("cuda"))
-
- # Calculate properties for a simulation state
- output = model(sim_state)
- energy = output["energy"]
- forces = output["forces"]
+ model = MorseModel(sigma=2.55, epsilon=0.436, alpha=1.359, cutoff=6.0)
+ results = model(sim_state)
+"""
-Notes:
- The Morse potential follows the form:
- V(r) = D_e * (1 - exp(-a(r-r_e)))^2
+from __future__ import annotations
- Where:
- - D_e (epsilon) is the well depth (dissociation energy)
- - r_e (sigma) is the equilibrium bond distance
- - a (alpha) controls the width of the potential well
-"""
+import functools
+from collections.abc import Callable # noqa: TC003
import torch
-import torch_sim as ts
-from torch_sim import transforms
-from torch_sim.models.interface import ModelInterface
+from torch_sim.models.pair_potential import PairPotentialModel
from torch_sim.neighbors import torchsim_nl
-DEFAULT_SIGMA = 1.0
-DEFAULT_EPSILON = 5.0
-DEFAULT_ALPHA = 5.0
-
-
def morse_pair(
dr: torch.Tensor,
- sigma: float | torch.Tensor = DEFAULT_SIGMA,
- epsilon: float | torch.Tensor = DEFAULT_EPSILON,
- alpha: float | torch.Tensor = DEFAULT_ALPHA,
+ zi: torch.Tensor, # noqa: ARG001
+ zj: torch.Tensor, # noqa: ARG001
+ sigma: torch.Tensor | float = 1.0,
+ epsilon: torch.Tensor | float = 5.0,
+ alpha: torch.Tensor | float = 5.0,
) -> torch.Tensor:
- """Calculate pairwise Morse potential energies between particles.
-
- Implements the Morse potential that combines short-range repulsion with
- longer-range attraction. The potential has a minimum at r=sigma and approaches
- -epsilon as r→∞.
+ """Morse pair energy.
- The functional form is:
- V(r) = epsilon * (1 - exp(-alpha*(r-sigma)))^2 - epsilon
+ V(r) = ε(1 - exp(-α(r - σ)))² - ε
Args:
- dr: Pairwise distances between particles. Shape: [n, m].
- sigma: Distance at which potential reaches its minimum. Either a scalar float
- or tensor of shape [n, m] for particle-specific equilibrium distances.
- epsilon: Depth of the potential well (energy scale). Either a scalar float
- or tensor of shape [n, m] for pair-specific interaction strengths.
- alpha: Controls the width of the potential well. Larger values give a narrower
- well. Either a scalar float or tensor of shape [n, m].
+ dr: Pairwise distances, shape [n_pairs].
+ zi: Atomic numbers of first atoms (unused).
+ zj: Atomic numbers of second atoms (unused).
+ sigma: Equilibrium bond distance. Defaults to 1.0.
+ epsilon: Well depth / dissociation energy. Defaults to 5.0.
+ alpha: Width parameter. Defaults to 5.0.
Returns:
- torch.Tensor: Pairwise Morse interaction energies between particles.
- Shape: [n, m]. Each element [i,j] represents the interaction energy between
- particles i and j.
+ Pair energies, shape [n_pairs].
"""
- # Calculate potential energy
energy = epsilon * (1.0 - torch.exp(-alpha * (dr - sigma))).pow(2) - epsilon
-
- # Handle potential numerical instabilities
return torch.where(dr > 0, energy, torch.zeros_like(energy))
def morse_pair_force(
dr: torch.Tensor,
- sigma: float | torch.Tensor = DEFAULT_SIGMA,
- epsilon: float | torch.Tensor = DEFAULT_EPSILON,
- alpha: float | torch.Tensor = DEFAULT_ALPHA,
+ zi: torch.Tensor, # noqa: ARG001
+ zj: torch.Tensor, # noqa: ARG001
+ sigma: torch.Tensor | float = 1.0,
+ epsilon: torch.Tensor | float = 5.0,
+ alpha: torch.Tensor | float = 5.0,
) -> torch.Tensor:
- """Calculate pairwise Morse forces between particles.
-
- Implements the force derived from the Morse potential. The force changes
- from repulsive to attractive at r=sigma.
-
- The functional form is:
- F(r) = 2*alpha*epsilon * exp(-alpha*(r-sigma)) * (1 - exp(-alpha*(r-sigma)))
+ """Morse pair force (negative gradient of energy).
- This is the negative gradient of the Morse potential energy.
+ F(r) = -2αε exp(-α(r-σ)) (1 - exp(-α(r-σ)))
Args:
- dr: Pairwise distances between particles. Shape: [n, m].
- sigma: Distance at which force changes from repulsive to attractive.
- Either a scalar float or tensor of shape [n, m].
- epsilon: Energy scale of the interaction. Either a scalar float or tensor
- of shape [n, m].
- alpha: Controls the force range and stiffness. Either a scalar float or
- tensor of shape [n, m].
+ dr: Pairwise distances.
+ zi: Atomic numbers of first atoms (unused).
+ zj: Atomic numbers of second atoms (unused).
+ sigma: Equilibrium distance. Defaults to 1.0.
+ epsilon: Well depth. Defaults to 5.0.
+ alpha: Width parameter. Defaults to 5.0.
Returns:
- torch.Tensor: Pairwise Morse forces between particles. Shape: [n, m].
- Positive values indicate repulsion, negative values indicate attraction.
+ Pair force magnitudes.
"""
exp_term = torch.exp(-alpha * (dr - sigma))
force = -2.0 * alpha * epsilon * exp_term * (1.0 - exp_term)
-
- # Handle potential numerical instabilities
return torch.where(dr > 0, force, torch.zeros_like(force))
-class MorseModel(ModelInterface):
- """Morse potential energy and force calculator.
-
- Implements the Morse potential for molecular dynamics simulations. This model
- is particularly useful for modeling covalent bonds as it can accurately describe
- bond stretching, breaking, and anharmonic behavior. Unlike the Lennard-Jones
- potential, Morse is often better for cases where accurate dissociation energy
- and bond dynamics are important.
+class MorseModel(PairPotentialModel):
+ """Morse pair potential model.
- Attributes:
- sigma (torch.Tensor): Equilibrium bond length (r_e) in distance units.
- epsilon (torch.Tensor): Dissociation energy (D_e) in energy units.
- alpha (torch.Tensor): Parameter controlling the width/steepness of the potential.
- cutoff (torch.Tensor): Distance cutoff for truncating potential calculation.
- device (torch.device): Device where calculations are performed.
- dtype (torch.dtype): Data type used for calculations.
- compute_forces (bool): Whether to compute atomic forces.
- compute_stress (bool): Whether to compute stress tensor.
- per_atom_energies (bool): Whether to compute per-atom energy decomposition.
- per_atom_stresses (bool): Whether to compute per-atom stress decomposition.
- use_neighbor_list (bool): Whether to use neighbor list optimization.
+ Convenience subclass that fixes the pair function to :func:`morse_pair` so the
+ caller only needs to supply ``sigma``, ``epsilon``, and ``alpha``.
- Examples:
- ```py
- # Basic usage with default parameters
- morse_model = MorseModel(device=torch.device("cuda"))
- results = morse_model(sim_state)
+ Example::
- # Model parameterized for O-H bonds in water, atomic units
- oh_model = MorseModel(
- sigma=0.96,
- epsilon=4.52,
- alpha=2.0,
+ model = MorseModel(
+ sigma=2.55,
+ epsilon=0.436,
+ alpha=1.359,
+ cutoff=6.0,
compute_forces=True,
- compute_stress=True,
)
- ```
+ results = model(sim_state)
"""
def __init__(
self,
- sigma: float | torch.Tensor = 1.0,
- epsilon: float | torch.Tensor = 5.0,
- alpha: float | torch.Tensor = 5.0,
+ sigma: float = 1.0,
+ epsilon: float = 5.0,
+ alpha: float = 5.0,
device: torch.device | None = None,
- dtype: torch.dtype = torch.float32,
- *, # Force keyword-only arguments
- compute_forces: bool = False,
+ dtype: torch.dtype = torch.float64,
+ *,
+ compute_forces: bool = True,
compute_stress: bool = False,
per_atom_energies: bool = False,
per_atom_stresses: bool = False,
- use_neighbor_list: bool = True,
- cutoff: float | torch.Tensor | None = None,
+ neighbor_list_fn: Callable = torchsim_nl,
+ cutoff: float | None = None,
+ retain_graph: bool = False,
) -> None:
- """Initialize the Morse potential calculator.
-
- Creates a model with specified interaction parameters and computational flags.
- The Morse potential is defined by three key parameters: sigma (equilibrium
- distance), epsilon (dissociation energy), and alpha (width control).
+ """Initialize the Morse potential model.
Args:
- sigma (float): Equilibrium bond distance (r_e) in distance units.
- Defaults to 1.0.
- epsilon (float): Dissociation energy (D_e) in energy units.
- Defaults to 5.0.
- alpha (float): Controls the width/steepness of the potential well.
- Larger values create a narrower well. Defaults to 5.0.
- device (torch.device | None): Device to run computations on. If None, uses
- CPU. Defaults to None.
- dtype (torch.dtype): Data type for calculations. Defaults to torch.float32.
- compute_forces (bool): Whether to compute forces. Defaults to False.
- compute_stress (bool): Whether to compute stress tensor. Defaults to False.
- per_atom_energies (bool): Whether to compute per-atom energy decomposition.
- Defaults to False.
- per_atom_stresses (bool): Whether to compute per-atom stress decomposition.
- Defaults to False.
- use_neighbor_list (bool): Whether to use a neighbor list for optimization.
- Significantly faster for large systems. Defaults to True.
- cutoff (float | None): Cutoff distance for interactions in distance units.
- If None, uses 2.5*sigma. Defaults to None.
-
- Examples:
- ```py
- # Basic model with default parameters
- model = MorseModel()
-
- # Model for diatomic hydrogen
- model = MorseModel(
- sigma=0.74, # Å
- epsilon=4.75, # eV
- alpha=1.94, # Steepness parameter
- compute_forces=True,
- )
- ```
-
- Notes:
- The alpha parameter can be related to the harmonic force constant k and
- dissociation energy D_e by: alpha = sqrt(k/(2*D_e))
- """
- super().__init__()
- self._device = device or torch.device("cpu")
- self._dtype = dtype
- self._compute_forces = compute_forces
- self._compute_stress = compute_stress
- self._per_atom_energies = per_atom_energies
- self._per_atom_stresses = per_atom_stresses
- self.use_neighbor_list = use_neighbor_list
- # Convert parameters to tensors
- self.sigma = torch.as_tensor(sigma, dtype=self.dtype, device=self.device)
- self.cutoff = torch.as_tensor(
- cutoff or 2.5 * sigma, dtype=self.dtype, device=self.device
- )
- self.epsilon = torch.as_tensor(epsilon, dtype=self.dtype, device=self.device)
- self.alpha = torch.as_tensor(alpha, dtype=self.dtype, device=self.device)
-
- def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]:
- """Compute Morse potential properties for a single unbatched system.
-
- Internal implementation that processes a single, non-batched simulation state.
- This method handles the core computations of pair interactions, including
- neighbor list construction, distance calculations, and property computation.
-
- Args:
- state (SimState): Single, non-batched simulation state containing atomic
- positions, cell vectors, and other system information.
-
- Returns:
- dict[str, torch.Tensor]: Computed properties:
- - "energy": Total potential energy (scalar)
- - "forces": Atomic forces with shape [n_atoms, 3] (if
- compute_forces=True)
- - "stress": Stress tensor with shape [3, 3] (if compute_stress=True)
- - "energies": Per-atom energies with shape [n_atoms] (if
- per_atom_energies=True)
- - "stresses": Per-atom stresses with shape [n_atoms, 3, 3] (if
- per_atom_stresses=True)
-
- Notes:
- This method can work with both neighbor list and full pairwise calculations.
- In both cases, interactions are truncated at the cutoff distance.
+ sigma: Equilibrium bond distance. Defaults to 1.0.
+ epsilon: Well depth / dissociation energy. Defaults to 5.0.
+ alpha: Width parameter. Defaults to 5.0.
+ device: Device for computations. Defaults to CPU.
+ dtype: Floating-point dtype. Defaults to torch.float32.
+ compute_forces: Whether to compute atomic forces. Defaults to True.
+ compute_stress: Whether to compute the stress tensor. Defaults to False.
+ per_atom_energies: Whether to return per-atom energies. Defaults to False.
+ per_atom_stresses: Whether to return per-atom stresses. Defaults to False.
+ neighbor_list_fn: Neighbor-list constructor. Defaults to torchsim_nl.
+ cutoff: Interaction cutoff. Defaults to 2.5 * sigma.
+ retain_graph: Keep computation graph for differentiable simulation.
"""
- positions = state.positions
- cell = state.row_vector_cell
- cell = cell.squeeze()
- pbc = state.pbc
-
- # Wrap positions into the unit cell
- wrapped_positions = (
- ts.transforms.pbc_wrap_batched(positions, state.cell, state.system_idx, pbc)
- if pbc.any()
- else positions
+ self.sigma = sigma
+ self.epsilon = epsilon
+ self.alpha = alpha
+ pair_fn = functools.partial(morse_pair, sigma=sigma, epsilon=epsilon, alpha=alpha)
+ super().__init__(
+ pair_fn=pair_fn,
+ cutoff=cutoff if cutoff is not None else 2.5 * sigma,
+ device=device,
+ dtype=dtype,
+ compute_forces=compute_forces,
+ compute_stress=compute_stress,
+ per_atom_energies=per_atom_energies,
+ per_atom_stresses=per_atom_stresses,
+ neighbor_list_fn=neighbor_list_fn,
+ reduce_to_half_list=True,
+ retain_graph=retain_graph,
)
-
- if self.use_neighbor_list:
- mapping, _, shifts_idx = torchsim_nl(
- positions=wrapped_positions,
- cell=cell,
- pbc=pbc,
- cutoff=self.cutoff,
- system_idx=state.system_idx,
- )
- # Pass shifts_idx directly - get_pair_displacements will convert them
- dr_vec, distances = transforms.get_pair_displacements(
- positions=wrapped_positions,
- cell=cell,
- pbc=pbc,
- pairs=(mapping[0], mapping[1]),
- shifts=shifts_idx,
- )
- else:
- dr_vec, distances = transforms.get_pair_displacements(
- positions=wrapped_positions,
- cell=cell,
- pbc=pbc,
- )
- mask = torch.eye(
- wrapped_positions.shape[0], dtype=torch.bool, device=self.device
- )
- distances = distances.masked_fill(mask, float("inf"))
- mask = distances < self.cutoff
- i, j = torch.where(mask)
- mapping = torch.stack([j, i])
- dr_vec = dr_vec[mask]
- distances = distances[mask]
-
- # Calculate pair energies and apply cutoff
- pair_energies = morse_pair(
- distances, sigma=self.sigma, epsilon=self.epsilon, alpha=self.alpha
- )
- mask = distances < self.cutoff
- pair_energies = torch.where(mask, pair_energies, torch.zeros_like(pair_energies))
-
- # Initialize results with total energy (sum/2 to avoid double counting)
- results = {"energy": 0.5 * pair_energies.sum()}
-
- if self._per_atom_energies:
- atom_energies = torch.zeros(
- positions.shape[0], dtype=self.dtype, device=self.device
- )
- atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies)
- atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies)
- results["energies"] = atom_energies
-
- if self.compute_forces or self.compute_stress:
- pair_forces = morse_pair_force(
- distances, sigma=self.sigma, epsilon=self.epsilon, alpha=self.alpha
- )
- pair_forces = torch.where(mask, pair_forces, torch.zeros_like(pair_forces))
-
- force_vectors = (pair_forces / distances)[:, None] * dr_vec
-
- if self.compute_forces:
- forces = torch.zeros_like(state.positions)
- forces.index_add_(0, mapping[0], -force_vectors)
- forces.index_add_(0, mapping[1], force_vectors)
- results["forces"] = forces
-
- if self.compute_stress and state.cell is not None:
- stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors)
- volume = torch.abs(torch.linalg.det(state.cell))
-
- results["stress"] = -stress_per_pair.sum(dim=0) / volume
-
- if self._per_atom_stresses:
- atom_stresses = torch.zeros(
- (state.positions.shape[0], 3, 3),
- dtype=self.dtype,
- device=self.device,
- )
- atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair)
- atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair)
- results["stresses"] = atom_stresses / volume
-
- return results
-
- def forward(self, state: ts.SimState, **_kwargs: object) -> dict[str, torch.Tensor]:
- """Compute Morse potential energies, forces, and stresses for a system.
-
- Main entry point for Morse potential calculations that handles batched states
- by dispatching each batch to the unbatched implementation and combining results.
-
- Args:
- state (SimState): Input state containing atomic positions, cell vectors,
- and other system information.
- **_kwargs: Unused; accepted for interface compatibility.
-
- Returns:
- dict[str, torch.Tensor]: Computed properties:
- - "energy": Potential energy with shape [n_systems]
- - "forces": Atomic forces with shape [n_atoms, 3]
- (if compute_forces=True)
- - "stress": Stress tensor with shape [n_systems, 3, 3]
- (if compute_stress=True)
- - May include additional outputs based on configuration
-
- Raises:
- ValueError: If batch cannot be inferred for multi-cell systems.
-
- Examples:
- ```py
- # Compute properties for a simulation state
- model = MorseModel(compute_forces=True)
- results = model(sim_state)
-
- energy = results["energy"] # Shape: [n_systems]
- forces = results["forces"] # Shape: [n_atoms, 3]
- ```
- """
- outputs = [self.unbatched_forward(state[i]) for i in range(state.n_systems)]
- properties = outputs[0]
-
- # we always return tensors
- # per atom properties are returned as (atoms, ...) tensors
- # global properties are returned as shape (..., n) tensors
- results: dict[str, torch.Tensor] = {}
- for key in ("stress", "energy"):
- if key in properties:
- results[key] = torch.stack([out[key] for out in outputs])
- for key in ("forces",):
- if key in properties:
- results[key] = torch.cat([out[key] for out in outputs], dim=0)
-
- return results
diff --git a/torch_sim/models/pair_potential.py b/torch_sim/models/pair_potential.py
index a537e6e9..1db1c2c9 100644
--- a/torch_sim/models/pair_potential.py
+++ b/torch_sim/models/pair_potential.py
@@ -1,4 +1,4 @@
-"""General batched pair potential model and standard pair interaction functions.
+"""General batched pair potential and pair forces models.
This module provides :class:`PairPotentialModel`, a flexible wrapper that turns any
pairwise energy function into a full TorchSim model with forces (via autograd) and
@@ -9,26 +9,63 @@
(e.g. the asymmetric particle-life interaction) that cannot be expressed as the
gradient of a scalar energy.
-Standard pair energy functions (all JIT-compatible):
+The pair function signature required by :class:`PairPotentialModel` is:
+``pair_fn(distances, atomic_numbers_i, atomic_numbers_j) -> pair_energies``,
+where all arguments are 1-D tensors of length n_pairs and the return value is a
+1-D tensor of pair energies. Additional parameters (e.g., ``sigma``, ``epsilon``)
+can be bound using :func:`functools.partial`.
+
+Notes:
+ - The ``cutoff`` parameter determines the neighbor list construction range.
+ Pairs beyond the cutoff are excluded from energy/force calculations. If your
+ potential has its own natural cutoff (e.g., WCA potential), ensure the model's
+ ``cutoff`` is at least as large.
+ - The ``atomic_numbers_i`` and ``atomic_numbers_j`` arguments are provided for
+ type-dependent potentials, but can be ignored (e.g., with ``# noqa: ARG001``)
+ for type-independent potentials like Lennard-Jones.
+ - The ``dtype`` of the SimState must match the model's ``dtype``. The model will
+ raise a ``TypeError`` if they don't match.
+ - Use ``reduce_to_half_list=True`` for symmetric potentials to halve computation
+ time. Only use ``False`` for asymmetric interactions or when you need the
+ full neighbor list for other purposes.
-* :func:`lj_pair` — Lennard-Jones 12-6
-* :func:`morse_pair` — Morse potential
-* :func:`soft_sphere_pair` — soft-sphere repulsion
-* :func:`particle_life_pair_force` — asymmetric particle-life force (use with
- :class:`PairForcesModel`)
Example::
- from torch_sim.models.pair_potential import PairPotentialModel, lj_pair
+ from torch_sim.models.pair_potential import PairPotentialModel
+ from torch_sim import io
+ from ase.build import bulk
import functools
+ import torch
+
+
+ def bmhtf_pair(dr, zi, zj, A, B, C, D, sigma):
+ # Born-Meyer-Huggins-Tosi-Fumi (BMHTF) potential for ionic crystals
+ # V(r) = A * exp(B * (sigma - r)) - C/r^6 - D/r^8
+ exp_term = A * torch.exp(B * (sigma - dr))
+ r6_term = C / dr.pow(6)
+ r8_term = D / dr.pow(8)
+ energy = exp_term - r6_term - r8_term
+ return torch.where(dr > 0, energy, torch.zeros_like(energy))
+
+
+ # Na-Cl interaction parameters
+ fn = functools.partial(
+ bmhtf_pair,
+ A=20.3548,
+ B=3.1546,
+ C=674.4793,
+ D=837.0770,
+ sigma=2.755,
+ )
+ model = PairPotentialModel(pair_fn=fn, cutoff=10.0)
- fn = functools.partial(lj_pair, sigma=1.0, epsilon=1.0)
- model = PairPotentialModel(pair_fn=fn, cutoff=2.5)
+ # Create NaCl structure using ASE
+ nacl_atoms = bulk("NaCl", "rocksalt", a=5.64)
+ sim_state = io.atoms_to_state(nacl_atoms, device=torch.device("cpu"))
results = model(sim_state)
"""
-# ruff: noqa: RUF002
-
from __future__ import annotations
from typing import TYPE_CHECKING
@@ -46,208 +83,6 @@
from torch_sim.state import SimState
-@torch.jit.script
-def lj_pair(
- dr: torch.Tensor,
- zi: torch.Tensor, # noqa: ARG001
- zj: torch.Tensor, # noqa: ARG001
- sigma: float = 1.0,
- epsilon: float = 1.0,
-) -> torch.Tensor:
- """Lennard-Jones 12-6 pair energy.
-
- V(r) = 4ε[(σ/r)¹² - (σ/r)⁶]
-
- Args:
- dr: Pairwise distances, shape [n_pairs].
- zi: Atomic numbers of first atoms (unused, for interface compatibility).
- zj: Atomic numbers of second atoms (unused, for interface compatibility).
- sigma: Length scale. Defaults to 1.0.
- epsilon: Energy scale. Defaults to 1.0.
-
- Returns:
- Pair energies, shape [n_pairs].
- """
- idr6 = (sigma / dr).pow(6)
- return 4.0 * epsilon * (idr6 * idr6 - idr6)
-
-
-@torch.jit.script
-def morse_pair(
- dr: torch.Tensor,
- zi: torch.Tensor, # noqa: ARG001
- zj: torch.Tensor, # noqa: ARG001
- sigma: float = 1.0,
- epsilon: float = 5.0,
- alpha: float = 5.0,
-) -> torch.Tensor:
- """Morse pair energy.
-
- V(r) = ε(1 - exp(-α(r - σ)))² - ε
-
- Args:
- dr: Pairwise distances, shape [n_pairs].
- zi: Atomic numbers of first atoms (unused).
- zj: Atomic numbers of second atoms (unused).
- sigma: Equilibrium bond distance. Defaults to 1.0.
- epsilon: Well depth / dissociation energy. Defaults to 5.0.
- alpha: Width parameter. Defaults to 5.0.
-
- Returns:
- Pair energies, shape [n_pairs].
- """
- return epsilon * (1.0 - torch.exp(-alpha * (dr - sigma))).pow(2) - epsilon
-
-
-@torch.jit.script
-def soft_sphere_pair(
- dr: torch.Tensor,
- zi: torch.Tensor, # noqa: ARG001
- zj: torch.Tensor, # noqa: ARG001
- sigma: float = 1.0,
- epsilon: float = 1.0,
- alpha: float = 2.0,
-) -> torch.Tensor:
- """Soft-sphere repulsive pair energy (zero beyond sigma).
-
- V(r) = ε/α * (1 - r/σ)^α for r < σ, else 0
-
- Args:
- dr: Pairwise distances, shape [n_pairs].
- zi: Atomic numbers of first atoms (unused).
- zj: Atomic numbers of second atoms (unused).
- sigma: Interaction diameter / cutoff. Defaults to 1.0.
- epsilon: Energy scale. Defaults to 1.0.
- alpha: Repulsion exponent. Defaults to 2.0.
-
- Returns:
- Pair energies, shape [n_pairs].
- """
- energy = epsilon / alpha * (1.0 - dr / sigma).pow(alpha)
- return torch.where(dr < sigma, energy, torch.zeros_like(energy))
-
-
-@torch.jit.script
-def particle_life_pair_force(
- dr: torch.Tensor,
- zi: torch.Tensor, # noqa: ARG001
- zj: torch.Tensor, # noqa: ARG001
- A: float = 1.0,
- beta: float = 0.3,
- sigma: float = 1.0,
-) -> torch.Tensor:
- """Asymmetric particle-life scalar force magnitude.
-
- This is a *force* function (not an energy), intended for use with
- :class:`PairForcesModel`.
-
- Args:
- dr: Pairwise distances, shape [n_pairs].
- zi: Atomic numbers of first atoms (unused).
- zj: Atomic numbers of second atoms (unused).
- A: Interaction amplitude. Defaults to 1.0.
- beta: Inner radius. Defaults to 0.3.
- sigma: Outer radius / cutoff. Defaults to 1.0.
-
- Returns:
- Scalar force magnitudes, shape [n_pairs].
- """
- inner_mask = dr < beta
- outer_mask = (dr >= beta) & (dr < sigma)
- inner_force = dr / beta - 1.0
- outer_force = A * (1.0 - torch.abs(2.0 * dr - 1.0 - beta) / (1.0 - beta))
- return torch.where(inner_mask, inner_force, torch.zeros_like(dr)) + torch.where(
- outer_mask, outer_force, torch.zeros_like(dr)
- )
-
-
-class MultiSoftSpherePairFn(torch.nn.Module):
- """Species-dependent soft-sphere pair energy function.
-
- Holds per-species-pair parameter matrices and looks up sigma, epsilon, and alpha
- for each interacting pair via their atomic numbers. Pass an instance to
- :class:`PairPotentialModel`.
-
- Example::
-
- fn = MultiSoftSpherePairFn(
- atomic_numbers=torch.tensor([18, 36]), # Ar and Kr
- sigma_matrix=torch.tensor([[3.4, 3.6], [3.6, 3.7]]),
- epsilon_matrix=torch.tensor([[0.01, 0.012], [0.012, 0.014]]),
- )
- model = PairPotentialModel(pair_fn=fn, cutoff=float(fn.sigma_matrix.max()))
- """
-
- def __init__(
- self,
- atomic_numbers: torch.Tensor,
- sigma_matrix: torch.Tensor,
- epsilon_matrix: torch.Tensor,
- alpha_matrix: torch.Tensor | None = None,
- ) -> None:
- """Initialize species-dependent soft-sphere parameters.
-
- Args:
- atomic_numbers: 1-D tensor of the unique atomic numbers present, used to
- map ``zi``/``zj`` to row/column indices. Shape: [n_species].
- sigma_matrix: Symmetric matrix of interaction diameters. Shape:
- [n_species, n_species].
- epsilon_matrix: Symmetric matrix of energy scales. Shape:
- [n_species, n_species].
- alpha_matrix: Symmetric matrix of repulsion exponents. If None, defaults
- to 2.0 for all pairs. Shape: [n_species, n_species].
- """
- super().__init__()
- self.z_to_idx: torch.Tensor
- self.atomic_numbers: torch.Tensor
- self.sigma_matrix: torch.Tensor
- self.epsilon_matrix: torch.Tensor
- self.alpha_matrix: torch.Tensor
-
- n = len(atomic_numbers)
- if sigma_matrix.shape != (n, n):
- raise ValueError(f"sigma_matrix must have shape ({n}, {n})")
- if epsilon_matrix.shape != (n, n):
- raise ValueError(f"epsilon_matrix must have shape ({n}, {n})")
- if alpha_matrix is not None and alpha_matrix.shape != (n, n):
- raise ValueError(f"alpha_matrix must have shape ({n}, {n})")
-
- self.register_buffer("atomic_numbers", atomic_numbers)
- self.register_buffer("sigma_matrix", sigma_matrix)
- self.register_buffer("epsilon_matrix", epsilon_matrix)
- self.register_buffer(
- "alpha_matrix",
- alpha_matrix if alpha_matrix is not None else torch.full((n, n), 2.0),
- )
- # Build a lookup table: atomic_number -> species index
- max_z = int(atomic_numbers.max().item()) + 1
- z_to_idx = torch.full((max_z,), -1, dtype=torch.long)
- for idx, z in enumerate(atomic_numbers.tolist()):
- z_to_idx[int(z)] = idx
- self.register_buffer("z_to_idx", z_to_idx)
-
- def forward(
- self, dr: torch.Tensor, zi: torch.Tensor, zj: torch.Tensor
- ) -> torch.Tensor:
- """Compute per-pair soft-sphere energies using species lookup.
-
- Args:
- dr: Pairwise distances, shape [n_pairs].
- zi: Atomic numbers of first atoms, shape [n_pairs].
- zj: Atomic numbers of second atoms, shape [n_pairs].
-
- Returns:
- Pair energies, shape [n_pairs].
- """
- idx_i = self.z_to_idx[zi]
- idx_j = self.z_to_idx[zj]
- sigma = self.sigma_matrix[idx_i, idx_j]
- epsilon = self.epsilon_matrix[idx_i, idx_j]
- alpha = self.alpha_matrix[idx_i, idx_j]
- energy = epsilon / alpha * (1.0 - dr / sigma).pow(alpha)
- return torch.where(dr < sigma, energy, torch.zeros_like(energy))
-
-
def full_to_half_list(
mapping: torch.Tensor,
system_mapping: torch.Tensor,
@@ -398,7 +233,7 @@ def _virial_stress(
volumes = torch.abs(torch.linalg.det(row_cell))
stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors)
stress = torch.zeros((n_systems, 3, 3), dtype=dtype, device=device)
- stress.index_add_(0, system_mapping, -stress_per_pair)
+ stress = stress.index_add(0, system_mapping, -stress_per_pair)
stress = stress / volumes[:, None, None]
return stress, stress_per_pair, volumes
@@ -428,17 +263,18 @@ def _accumulate_stress(
dtype,
device,
)
- stress_scale = 2.0 if half else 1.0
+ stress_scale = 1.0 if half else 0.5
out: dict[str, torch.Tensor] = {"stress": stress * stress_scale}
if per_atom:
- # Half list: each pair once → weight 1.0 per endpoint.
- # Full list: each pair twice (i→j and j→i) → weight 0.5 per endpoint.
- w = 1.0 if half else 0.5
+ # Each endpoint (i and j) gets half the pair's contribution.
+ # Half list: each unique pair appears once → w = 0.5.
+ # Full list: each unique pair appears twice → w = 0.25.
+ w = 0.5 if half else 0.25
n_atoms = positions.shape[0]
atom_stresses = torch.zeros((n_atoms, 3, 3), dtype=dtype, device=device)
- atom_stresses.index_add_(0, mapping[0], -w * stress_per_pair)
- atom_stresses.index_add_(0, mapping[1], -w * stress_per_pair)
+ atom_stresses = atom_stresses.index_add(0, mapping[0], -w * stress_per_pair)
+ atom_stresses = atom_stresses.index_add(0, mapping[1], -w * stress_per_pair)
out["stresses"] = atom_stresses / volumes[system_idx, None, None]
return out
@@ -451,7 +287,10 @@ class PairPotentialModel(ModelInterface):
callable of the form ``pair_fn(distances, atomic_numbers_i, atomic_numbers_j) ->
pair_energies``, where all arguments are 1-D tensors of length n_pairs and the
return value is a 1-D tensor of pair energies. Forces are obtained analytically
- via autograd.
+ via autograd by differentiating the energy with respect to positions.
+
+ When stress is computed, it uses the virial formula: σ = -1/V Σ_{ij} r_ij ⊗ f_ij,
+ where r_ij is the pair displacement vector and f_ij is the force vector.
Example::
@@ -467,16 +306,17 @@ def lj_fn(dr, zi, zj):
def __init__(
self,
pair_fn: Callable,
+ *,
cutoff: float,
device: torch.device | None = None,
- dtype: torch.dtype = torch.float32,
- *,
+ dtype: torch.dtype = torch.float64,
compute_forces: bool = True,
compute_stress: bool = False,
per_atom_energies: bool = False,
per_atom_stresses: bool = False,
neighbor_list_fn: Callable = torchsim_nl,
reduce_to_half_list: bool = False,
+ retain_graph: bool = False,
) -> None:
"""Initialize the pair potential model.
@@ -496,6 +336,10 @@ def __init__(
before computing interactions. Halves pair operations and makes
accumulation patterns unambiguous. Only valid for symmetric pair
functions; do not use for asymmetric interactions. Defaults to False.
+ retain_graph: If True, keep the computation graph after computing forces
+ so that the energy can still be differentiated w.r.t. model parameters
+ (e.g. for differentiable simulation / meta-optimization).
+ Defaults to False.
"""
super().__init__()
self._device = device or torch.device("cpu")
@@ -508,6 +352,7 @@ def __init__(
self.neighbor_list_fn = neighbor_list_fn
self.cutoff = torch.tensor(cutoff, dtype=dtype, device=self._device)
self.reduce_to_half_list = reduce_to_half_list
+ self.retain_graph = retain_graph
def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]:
"""Compute pair-potential properties with batched tensor operations.
@@ -520,7 +365,17 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
dict with keys ``"energy"`` (shape ``[n_systems]``), optionally
``"forces"`` (``[n_atoms, 3]``), ``"stress"`` (``[n_systems, 3, 3]``),
``"energies"`` (``[n_atoms]``), ``"stresses"`` (``[n_atoms, 3, 3]``).
+
+ Raises:
+ TypeError: If the SimState's dtype does not match the model's dtype.
"""
+ if state.dtype != self._dtype:
+ raise TypeError(
+ f"SimState dtype {state.dtype} does not match model dtype {self._dtype}. "
+ f"Either set the model dtype to {state.dtype} or convert the SimState "
+ f"to {self._dtype} using sim_state.to(dtype={self._dtype})."
+ )
+ dtype = self._dtype
half = self.reduce_to_half_list
(
positions,
@@ -555,16 +410,16 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
ew = 1.0 if half else 0.5
results: dict[str, torch.Tensor] = {}
- energy = torch.zeros(n_systems, dtype=self._dtype, device=self._device)
- energy.index_add_(0, system_mapping, ew * pair_energies)
+ energy = torch.zeros(n_systems, dtype=dtype, device=self._device)
+ energy = energy.index_add(0, system_mapping, ew * pair_energies)
results["energy"] = energy
if self.per_atom_energies:
atom_energies = torch.zeros(
- positions.shape[0], dtype=self._dtype, device=self._device
+ positions.shape[0], dtype=dtype, device=self._device
)
- atom_energies.index_add_(0, mapping[0], ew * pair_energies)
- atom_energies.index_add_(0, mapping[1], ew * pair_energies)
+ atom_energies = atom_energies.index_add(0, mapping[0], ew * pair_energies)
+ atom_energies = atom_energies.index_add(0, mapping[1], ew * pair_energies)
results["energies"] = atom_energies
if need_grad:
@@ -572,6 +427,7 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
pair_energies.sum(),
dist_for_grad,
create_graph=False,
+ retain_graph=self.retain_graph,
)
safe_dist = torch.where(distances > 0, distances, torch.ones_like(distances))
# force_vectors = -dV/dr * r̂_ij: positive (repulsive) pushes j away from i.
@@ -581,13 +437,13 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
forces = torch.zeros_like(positions)
if half:
# Half list: each pair once → apply Newton's third law explicitly.
- forces.index_add_(0, mapping[0], -force_vectors)
- forces.index_add_(0, mapping[1], force_vectors)
+ forces = forces.index_add(0, mapping[0], -force_vectors)
+ forces = forces.index_add(0, mapping[1], force_vectors)
else:
# Full list: atom i appears as mapping[0] for every i→j pair,
# covering all its neighbors. mapping[1] accumulation would
# double-count, so we only accumulate on the source atom.
- forces.index_add_(0, mapping[0], -force_vectors)
+ forces = forces.index_add(0, mapping[0], -force_vectors)
results["forces"] = forces
if self._compute_stress:
@@ -601,14 +457,17 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
force_vectors,
row_cell,
n_systems,
- self._dtype,
+ dtype,
self._device,
half=half,
per_atom=self.per_atom_stresses,
)
)
- return {k: v.detach() for k, v in results.items()}
+ if not self.retain_graph:
+ results = {k: v.detach() for k, v in results.items()}
+
+ return results
class PairForcesModel(ModelInterface):
@@ -622,12 +481,17 @@ class PairForcesModel(ModelInterface):
Forces are accumulated as:
F_i += -f_ij * r̂_ij, F_j += +f_ij * r̂_ij
+ Note:
+ Unlike :class:`PairPotentialModel`, this class does not compute energies
+ (returns zeros) since there is no underlying energy function. Use
+ :class:`PairPotentialModel` when your interaction can be expressed as an
+ energy function, as it provides automatic force computation via autograd
+ and is generally more efficient.
+
Example::
- from torch_sim.models.pair_potential import (
- PairForcesModel,
- particle_life_pair_force,
- )
+ from torch_sim.models.particle_life import particle_life_pair_force
+ from torch_sim.models.pair_potential import PairForcesModel
import functools
fn = functools.partial(particle_life_pair_force, A=1.0, beta=0.3, sigma=1.0)
@@ -638,10 +502,10 @@ class PairForcesModel(ModelInterface):
def __init__(
self,
force_fn: Callable,
+ *,
cutoff: float,
device: torch.device | None = None,
- dtype: torch.dtype = torch.float32,
- *,
+ dtype: torch.dtype = torch.float64,
compute_stress: bool = False,
per_atom_stresses: bool = False,
neighbor_list_fn: Callable = torchsim_nl,
@@ -686,7 +550,17 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
``"forces"`` (shape ``[n_atoms, 3]``), and optionally ``"stress"``
(shape ``[n_systems, 3, 3]``) and ``"stresses"``
(shape ``[n_atoms, 3, 3]``).
+
+ Raises:
+ TypeError: If the SimState's dtype does not match the model's dtype.
"""
+ if state.dtype != self._dtype:
+ raise TypeError(
+ f"SimState dtype {state.dtype} does not match model dtype {self._dtype}. "
+ f"Either set the model dtype to {state.dtype} or convert the SimState "
+ f"to {self._dtype} using sim_state.to(dtype={self._dtype})."
+ )
+ dtype = self._dtype
half = self.reduce_to_half_list
(
positions,
@@ -715,11 +589,11 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
force_vectors = (pair_forces / safe_dist)[:, None] * dr_vec
forces = torch.zeros_like(positions)
- forces.index_add_(0, mapping[0], -force_vectors)
- forces.index_add_(0, mapping[1], force_vectors)
+ forces = forces.index_add(0, mapping[0], -force_vectors)
+ forces = forces.index_add(0, mapping[1], force_vectors)
results: dict[str, torch.Tensor] = {
- "energy": torch.zeros(n_systems, dtype=self._dtype, device=self._device),
+ "energy": torch.zeros(n_systems, dtype=dtype, device=self._device),
"forces": forces,
}
@@ -734,7 +608,7 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
force_vectors,
row_cell,
n_systems,
- self._dtype,
+ dtype,
self._device,
half=half,
per_atom=self.per_atom_stresses,
diff --git a/torch_sim/models/particle_life.py b/torch_sim/models/particle_life.py
index c227f272..d5e53767 100644
--- a/torch_sim/models/particle_life.py
+++ b/torch_sim/models/particle_life.py
@@ -1,94 +1,75 @@
-"""Particle life model for computing forces between particles."""
+"""Particle life model.
-import torch
-
-import torch_sim as ts
-from torch_sim import transforms
-from torch_sim.models.interface import ModelInterface
-from torch_sim.neighbors import torchsim_nl
+Thin wrapper around :class:`~torch_sim.models.pair_potential.PairForcesModel` with
+the :func:`particle_life_pair_force` force function
+baked in.
+Example::
-DEFAULT_BETA = torch.tensor(0.3)
-DEFAULT_SIGMA = torch.tensor(1.0)
+ model = ParticleLifeModel(sigma=1.0, epsilon=1.0, beta=0.3, cutoff=1.0)
+ results = model(sim_state)
+"""
+from __future__ import annotations
-def asymmetric_particle_pair_force(
- dr: torch.Tensor,
- A: torch.Tensor,
- beta: torch.Tensor = DEFAULT_BETA,
- sigma: torch.Tensor = DEFAULT_SIGMA,
-) -> torch.Tensor:
- """Asymmetric interaction between particles.
-
- Args:
- dr: A tensor of shape [n, m] of pairwise distances between particles.
- A: Interaction scale. Either a float scalar or a tensor of shape [n, m].
- beta: Inner radius of the interaction. Either a float scalar or tensor of
- shape [n, m].
- sigma: Outer radius of the interaction. Either a float scalar or tensor of
- shape [n, m].
+import functools
+from collections.abc import Callable # noqa: TC003
- Returns:
- torch.Tensor: Energies with shape [n, m].
- """
- inner_mask = dr < beta
- outer_mask = (dr < sigma) & (dr > beta)
-
- def inner_force_fn(dr: torch.Tensor) -> torch.Tensor:
- return dr / beta - 1
-
- def intermediate_force_fn(dr: torch.Tensor) -> torch.Tensor:
- return A * (1 - torch.abs(2 * dr - 1 - beta) / (1 - beta))
+import torch
- return torch.where(inner_mask, inner_force_fn(dr), 0) + torch.where(
- outer_mask,
- intermediate_force_fn(dr),
- 0,
- )
+from torch_sim.models.pair_potential import PairForcesModel
+from torch_sim.neighbors import torchsim_nl
-def asymmetric_particle_pair_force_jit(
+def particle_life_pair_force(
dr: torch.Tensor,
- A: torch.Tensor,
- beta: torch.Tensor = DEFAULT_BETA,
- sigma: torch.Tensor = DEFAULT_SIGMA,
+ zi: torch.Tensor, # noqa: ARG001
+ zj: torch.Tensor, # noqa: ARG001
+ A: torch.Tensor | float = 1.0,
+ beta: torch.Tensor | float = 0.3,
+ sigma: torch.Tensor | float = 1.0,
) -> torch.Tensor:
- """Asymmetric interaction between particles.
+ """Asymmetric particle-life scalar force magnitude.
+
+ This is a *force* function (not an energy), intended for use with
+ :class:`PairForcesModel`.
Args:
- dr: A tensor of shape [n, m] of pairwise distances between particles.
- A: Interaction scale. Either a float scalar or a tensor of shape [n, m].
- beta: Inner radius of the interaction. Either a float scalar or tensor of
- shape [n, m].
- sigma: Outer radius of the interaction. Either a float scalar or tensor of
- shape [n, m].
+ dr: Pairwise distances, shape [n_pairs].
+ zi: Atomic numbers of first atoms (unused).
+ zj: Atomic numbers of second atoms (unused).
+ A: Interaction amplitude. Defaults to 1.0.
+ beta: Inner radius. Defaults to 0.3.
+ sigma: Outer radius / cutoff. Defaults to 1.0.
Returns:
- torch.Tensor: Energies with shape [n, m].
+ Scalar force magnitudes, shape [n_pairs].
"""
inner_mask = dr < beta
- outer_mask = (dr < sigma) & (dr > beta)
-
- # Calculate inner forces directly
- inner_forces = torch.where(inner_mask, dr / beta - 1, torch.zeros_like(dr))
-
- # Calculate outer forces directly
- outer_forces = torch.where(
- outer_mask,
- A * (1 - torch.abs(2 * dr - 1 - beta) / (1 - beta)),
- torch.zeros_like(dr),
+ outer_mask = (dr >= beta) & (dr < sigma)
+ inner_force = dr / beta - 1.0
+ outer_force = A * (1.0 - torch.abs(2.0 * dr - 1.0 - beta) / (1.0 - beta))
+ return torch.where(inner_mask, inner_force, torch.zeros_like(dr)) + torch.where(
+ outer_mask, outer_force, torch.zeros_like(dr)
)
- return inner_forces + outer_forces
+class ParticleLifeModel(PairForcesModel):
+ """Asymmetric particle-life force model.
-class ParticleLifeModel(ModelInterface):
- """Calculator for asymmetric particle interaction.
+ Convenience subclass that fixes the force function to
+ :func:`particle_life_pair_force` so the caller only needs to supply
+ ``sigma``, ``epsilon`` (amplitude), and ``beta``.
- This model implements an asymmetric interaction between particles based on
- distance-dependent forces. The interaction is defined by three parameters:
- sigma, epsilon, and beta.
+ Example::
+ model = ParticleLifeModel(
+ sigma=1.0,
+ epsilon=1.0,
+ beta=0.3,
+ cutoff=1.0,
+ )
+ results = model(sim_state)
"""
def __init__(
@@ -97,192 +78,45 @@ def __init__(
epsilon: float = 1.0,
beta: float = 0.3,
device: torch.device | None = None,
- dtype: torch.dtype = torch.float32,
- *, # Force keyword-only arguments
- compute_forces: bool = False,
+ dtype: torch.dtype = torch.float64,
+ *,
compute_stress: bool = False,
- per_atom_energies: bool = False,
per_atom_stresses: bool = False,
- use_neighbor_list: bool = True,
+ neighbor_list_fn: Callable = torchsim_nl,
cutoff: float | None = None,
+ **kwargs: object, # noqa: ARG002
) -> None:
- """Initialize the calculator.
+ """Initialize the particle life model.
Args:
- sigma: Outer radius of the interaction.
- epsilon: Interaction scale.
- beta: Inner radius of the interaction.
- device: Device for computation.
- dtype: Data type for tensors.
- compute_forces: Whether to compute forces.
- compute_stress: Whether to compute stress tensor.
- per_atom_energies: Whether to compute per-atom energies.
- per_atom_stresses: Whether to compute per-atom stresses.
- use_neighbor_list: Whether to use neighbor list optimization.
- cutoff: Interaction cutoff distance. Defaults to 2.5 * sigma.
+ sigma: Outer radius of the interaction. Defaults to 1.0.
+ epsilon: Interaction amplitude (``A`` parameter). Defaults to 1.0.
+ beta: Inner radius of the interaction. Defaults to 0.3.
+ device: Device for computations. Defaults to CPU.
+ dtype: Floating-point dtype. Defaults to torch.float32.
+ compute_forces: Accepted for backward compatibility (always True).
+ compute_stress: Whether to compute stress tensor. Defaults to False.
+ per_atom_energies: Accepted for backward compatibility (ignored — this
+ is a force-only model, energy is always zero).
+ per_atom_stresses: Whether to return per-atom stresses. Defaults to False.
+ use_neighbor_list: Accepted for backward compatibility (ignored).
+ neighbor_list_fn: Neighbor-list constructor. Defaults to torchsim_nl.
+ cutoff: Interaction cutoff. Defaults to 2.5 * sigma.
+ **kwargs: Additional keyword arguments.
"""
- super().__init__()
- self._device = device or torch.device("cpu")
- self._dtype = dtype
-
- self._compute_forces = compute_forces
- self._compute_stress = compute_stress
- self._per_atom_energies = per_atom_energies
- self._per_atom_stresses = per_atom_stresses
-
- self.use_neighbor_list = use_neighbor_list
-
- # Convert parameters to tensors
- self.sigma = torch.tensor(sigma, dtype=self.dtype, device=self.device)
- self.cutoff = torch.tensor(
- cutoff or 2.5 * sigma, dtype=self.dtype, device=self.device
+ self.sigma_param = sigma
+ self.epsilon = epsilon
+ self.beta = beta
+ force_fn = functools.partial(
+ particle_life_pair_force, A=epsilon, beta=beta, sigma=sigma
)
- self.epsilon = torch.tensor(epsilon, dtype=self.dtype, device=self.device)
- self.beta = torch.tensor(beta, dtype=self.dtype, device=self.device)
-
- def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]:
- """Compute energies and forces for a single unbatched system.
-
- Internal implementation that processes a single, non-batched simulation state.
- This method handles the core computations of pair interactions, neighbor lists,
- and property calculations.
-
- Args:
- state: Single, non-batched simulation state containing atomic positions,
- cell vectors, and other system information.
-
- Returns:
- A dictionary containing the energy, forces, and stresses
- """
- positions = state.positions
- cell = state.row_vector_cell
-
- if cell.dim() == 3: # Check if there is an extra batch dimension
- cell = cell.squeeze(0) # Squeeze the first dimension
-
- # Ensure system_idx exists (create if None for single system)
- system_idx = (
- state.system_idx
- if state.system_idx is not None
- else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device)
+ super().__init__(
+ force_fn=force_fn,
+ cutoff=cutoff if cutoff is not None else 2.5 * sigma,
+ device=device,
+ dtype=dtype,
+ compute_stress=compute_stress,
+ per_atom_stresses=per_atom_stresses,
+ neighbor_list_fn=neighbor_list_fn,
+ reduce_to_half_list=False,
)
-
- # Wrap positions into the unit cell
- wrapped_positions = (
- ts.transforms.pbc_wrap_batched(positions, state.cell, system_idx, state.pbc)
- if state.pbc.any()
- else positions
- )
-
- if self.use_neighbor_list:
- mapping, _, shifts_idx = torchsim_nl(
- positions=wrapped_positions,
- cell=cell,
- pbc=state.pbc,
- cutoff=self.cutoff,
- system_idx=system_idx,
- )
- # Pass shifts_idx directly - get_pair_displacements will convert them
- dr_vec, distances = transforms.get_pair_displacements(
- positions=wrapped_positions,
- cell=cell,
- pbc=state.pbc,
- pairs=(mapping[0], mapping[1]),
- shifts=shifts_idx,
- )
- else:
- # Get all pairwise displacements
- dr_vec, distances = transforms.get_pair_displacements(
- positions=wrapped_positions,
- cell=cell,
- pbc=state.pbc,
- )
- # Mask out self-interactions
- mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device)
- distances = distances.masked_fill(mask, float("inf"))
- # Apply cutoff
- mask = distances < self.cutoff
- # Get valid pairs - match neighbor list convention for pair order
- i, j = torch.where(mask)
- mapping = torch.stack([j, i])
- # Get valid displacements and distances
- dr_vec = dr_vec[mask]
- distances = distances[mask]
-
- # Zero out energies beyond cutoff
- mask = distances < self.cutoff
-
- # Initialize results with total energy (sum/2 to avoid double counting)
- results: dict[str, torch.Tensor] = {
- "energy": torch.tensor(0.0, dtype=self.dtype, device=self.device),
- }
-
- # Calculate forces and apply cutoff
- pair_forces = asymmetric_particle_pair_force_jit(
- dr=distances, A=self.epsilon, sigma=self.sigma, beta=self.beta
- )
- pair_forces = torch.where(mask, pair_forces, torch.zeros_like(pair_forces))
-
- # Project forces along displacement vectors
- force_vectors = (pair_forces / distances)[:, None] * dr_vec
-
- # Initialize forces tensor
- forces = torch.zeros_like(state.positions)
- # Add force contributions (f_ij on i, -f_ij on j)
- forces.index_add_(0, mapping[0], -force_vectors)
- forces.index_add_(0, mapping[1], force_vectors)
- results["forces"] = forces
-
- return results
-
- def forward(self, state: ts.SimState, **_kwargs: object) -> dict[str, torch.Tensor]:
- """Compute particle life energies and forces for a system.
-
- Main entry point for particle life calculations that handles batched states by
- dispatching each batch to the unbatched implementation and combining results.
-
- Args:
- state: Input state containing atomic positions, cell vectors, and other
- system information. Can be a SimState object or a dictionary with the
- same keys.
- **_kwargs: Unused; accepted for interface compatibility.
-
- Returns:
- dict[str, torch.Tensor]: Computed properties:
- - "energy": Potential energy with shape [n_systems]
- - "forces": Atomic forces with shape [n_atoms, 3] (if
- compute_forces=True)
- - "stress": Stress tensor with shape [n_systems, 3, 3] (if
- compute_stress=True)
- - "energies": Per-atom energies with shape [n_atoms] (if
- per_atom_energies=True)
- - "stresses": Per-atom stresses with shape [n_atoms, 3, 3] (if
- per_atom_stresses=True)
-
- Raises:
- ValueError: If batch cannot be inferred for multi-cell systems.
- """
- sim_state = state
-
- if sim_state.system_idx is None and sim_state.cell.shape[0] > 1:
- raise ValueError(
- "system_idx can only be inferred if there is only one system."
- )
-
- outputs = [
- self.unbatched_forward(sim_state[idx]) for idx in range(sim_state.n_systems)
- ]
- properties = outputs[0]
-
- # we always return tensors
- # per atom properties are returned as (atoms, ...) tensors
- # global properties are returned as shape (..., n) tensors
- results: dict[str, torch.Tensor] = {}
- for key in ("stress", "energy"):
- if key in properties:
- results[key] = torch.stack([out[key] for out in outputs])
- for key in ("forces", "energies", "stresses"):
- if key in properties:
- results[key] = torch.cat([out[key] for out in outputs], dim=0)
-
- return results
diff --git a/torch_sim/models/soft_sphere.py b/torch_sim/models/soft_sphere.py
index 6ec7f26e..598eff88 100644
--- a/torch_sim/models/soft_sphere.py
+++ b/torch_sim/models/soft_sphere.py
@@ -1,616 +1,324 @@
-"""Soft sphere model for computing energies, forces and stresses.
+"""Soft sphere potential model.
-This module provides implementations of soft sphere potentials for molecular dynamics
-simulations. Soft sphere potentials are repulsive interatomic potentials that model
-the core repulsion between atoms, avoiding the infinite repulsion of hard sphere models
-while maintaining computational efficiency.
+Thin wrapper around :class:`~torch_sim.models.pair_potential.PairPotentialModel` with
+the :func:`soft_sphere_pair` energy function baked in.
The soft sphere potential has the form:
- V(r) = epsilon * (sigma/r)^alpha
-Where:
-
-* r is the distance between particles
-* sigma is the effective diameter of the particles
-* epsilon controls the energy scale
-* alpha determines the steepness of the repulsion (typically alpha >= 2)
-
-Soft sphere models are particularly useful for:
-
-* Granular matter simulations
-* Modeling excluded volume effects
-* Initial equilibration of dense systems
-* Coarse-grained molecular dynamics
+ V(r) = ε/α * (1 - r/σ)^α for r < σ, else 0
Example::
- # Create a soft sphere model with default parameters
- model = SoftSphereModel()
-
- # Calculate properties for a simulation state
+ model = SoftSphereModel(sigma=1.0, epsilon=1.0, alpha=2.0)
results = model(sim_state)
- energy = results["energy"]
- forces = results["forces"]
# For multiple species with different interaction parameters
multi_model = SoftSphereMultiModel(
- species=particle_types,
- sigma_matrix=size_matrix,
- epsilon_matrix=strength_matrix,
+ atomic_numbers=torch.tensor([18, 36]),
+ sigma_matrix=torch.tensor([[1.0, 0.8], [0.8, 0.6]]),
+ epsilon_matrix=torch.tensor([[1.0, 0.5], [0.5, 2.0]]),
)
results = multi_model(sim_state)
"""
-import torch
+from __future__ import annotations
-import torch_sim as ts
-from torch_sim import transforms
-from torch_sim.models.interface import ModelInterface
-from torch_sim.neighbors import torchsim_nl
+import functools
+from collections.abc import Callable # noqa: TC003
+import torch
-DEFAULT_SIGMA = torch.tensor(1.0)
-DEFAULT_EPSILON = torch.tensor(1.0)
-DEFAULT_ALPHA = torch.tensor(2.0)
+from torch_sim.models.pair_potential import PairPotentialModel
+from torch_sim.neighbors import torchsim_nl
def soft_sphere_pair(
dr: torch.Tensor,
- sigma: float | torch.Tensor = DEFAULT_SIGMA,
- epsilon: float | torch.Tensor = DEFAULT_EPSILON,
- alpha: float | torch.Tensor = DEFAULT_ALPHA,
+ zi: torch.Tensor, # noqa: ARG001
+ zj: torch.Tensor, # noqa: ARG001
+ sigma: torch.Tensor | float = 1.0,
+ epsilon: torch.Tensor | float = 1.0,
+ alpha: torch.Tensor | float = 2.0,
) -> torch.Tensor:
- """Calculate pairwise repulsive energies between soft spheres with finite-range
- interactions.
+ """Soft-sphere repulsive pair energy (zero beyond sigma).
- Computes a soft-core repulsive potential between particle pairs based on
- their separation distance, size, and interaction parameters. The potential
- goes to zero at finite range.
+ V(r) = ε/α * (1 - r/σ)^α for r < σ, else 0
Args:
- dr: Pairwise distances between particles. Shape: [n, m].
- sigma: Particle diameters. Either a scalar float or tensor of shape [n, m]
- for particle-specific sizes.
- epsilon: Energy scale of the interaction. Either a scalar float or tensor
- of shape [n, m] for pair-specific interaction strengths.
- alpha: Stiffness exponent controlling the interaction decay. Either a scalar
- float or tensor of shape [n, m].
+ dr: Pairwise distances, shape [n_pairs].
+ zi: Atomic numbers of first atoms (unused).
+ zj: Atomic numbers of second atoms (unused).
+ sigma: Interaction diameter / cutoff. Defaults to 1.0.
+ epsilon: Energy scale. Defaults to 1.0.
+ alpha: Repulsion exponent. Defaults to 2.0.
Returns:
- torch.Tensor: Pairwise interaction energies between particles. Shape: [n, m].
- Each element [i,j] represents the repulsive energy between particles i and j.
+ Pair energies, shape [n_pairs].
"""
-
- def fn(dr: torch.Tensor) -> torch.Tensor:
- return epsilon / alpha * (1.0 - (dr / sigma)).pow(alpha)
-
- # Create mask for distances within cutoff i.e sigma
- mask = dr < sigma
-
- # Use transforms.safe_mask to compute energies only where mask is True
- return transforms.safe_mask(mask, fn, dr)
+ energy = epsilon / alpha * (1.0 - dr / sigma).pow(alpha)
+ return torch.where(dr < sigma, energy, torch.zeros_like(energy))
def soft_sphere_pair_force(
dr: torch.Tensor,
- sigma: torch.Tensor = DEFAULT_SIGMA,
- epsilon: torch.Tensor = DEFAULT_EPSILON,
- alpha: torch.Tensor = DEFAULT_ALPHA,
+ zi: torch.Tensor, # noqa: ARG001
+ zj: torch.Tensor, # noqa: ARG001
+ sigma: torch.Tensor | float = 1.0,
+ epsilon: torch.Tensor | float = 1.0,
+ alpha: torch.Tensor | float = 2.0,
) -> torch.Tensor:
- """Computes the pairwise repulsive forces between soft spheres with finite range.
+ """Soft-sphere pair force (negative gradient of energy).
- This function implements a soft-core repulsive interaction that smoothly goes to zero
- at the cutoff distance sigma. The force magnitude is controlled by epsilon and its
- stiffness by alpha.
+ F(r) = (ε/σ) (1 - r/σ)^(α-1) for r < σ, else 0
Args:
- dr: A tensor of shape [n, m] containing pairwise distances between particles,
- where n and m represent different particle indices.
- sigma: Particle diameter defining the interaction cutoff distance. Can be either
- a float scalar or a tensor of shape [n, m] for particle-specific diameters.
- epsilon: Energy scale of the interaction. Can be either a float scalar or a
- tensor of shape [n, m] for particle-specific interaction strengths.
- alpha: Exponent controlling the stiffness of the repulsion. Higher values create
- a harder repulsion. Can be either a float scalar or a tensor of shape [n, m].
+ dr: Pairwise distances.
+ zi: Atomic numbers of first atoms (unused).
+ zj: Atomic numbers of second atoms (unused).
+ sigma: Interaction diameter. Defaults to 1.0.
+ epsilon: Energy scale. Defaults to 1.0.
+ alpha: Repulsion exponent. Defaults to 2.0.
Returns:
- torch.Tensor: Forces between particle pairs with shape [n, m]. Forces are zero
- for distances greater than sigma.
+ Pair force magnitudes.
"""
+ force = (epsilon / sigma) * (1.0 - (dr / sigma)).pow(alpha - 1)
+ mask = dr < sigma
+ return torch.where(mask, force, torch.zeros_like(force))
- def fn(dr: torch.Tensor) -> torch.Tensor:
- return (epsilon / sigma) * (1.0 - (dr / sigma)).pow(alpha - 1)
- # Create mask for distances within cutoff i.e sigma
- mask = dr < sigma
+class MultiSoftSpherePairFn(torch.nn.Module):
+ """Species-dependent soft-sphere pair energy function.
- # Use transforms.safe_mask to compute energies only where mask is True
- return transforms.safe_mask(mask, fn, dr)
-
-
-class SoftSphereModel(ModelInterface):
- """Calculator for soft sphere potential energies and forces.
-
- Implements a model for computing properties based on the soft sphere potential,
- which describes purely repulsive interactions between particles. This potential
- is useful for modeling systems where particles should not overlap but don't have
- attractive interactions, such as granular materials and some colloidal systems.
-
- The potential energy between particles i and j is:
- V_ij(r) = epsilon * (sigma/r)^alpha
-
- Attributes:
- sigma (torch.Tensor): Effective particle diameter in distance units.
- epsilon (torch.Tensor): Energy scale parameter in energy units.
- alpha (torch.Tensor): Exponent controlling repulsion steepness (typically ≥ 2).
- cutoff (torch.Tensor): Cutoff distance for interactions.
- use_neighbor_list (bool): Whether to use neighbor list optimization.
- _device (torch.device): Computation device (CPU/GPU).
- _dtype (torch.dtype): Data type for tensor calculations.
- _compute_forces (bool): Whether to compute forces.
- _compute_stress (bool): Whether to compute stress tensor.
- per_atom_energies (bool): Whether to compute per-atom energy decomposition.
- per_atom_stresses (bool): Whether to compute per-atom stress decomposition.
-
- Examples:
- ```py
- # Basic usage with default parameters
- model = SoftSphereModel()
- results = model(sim_state)
+ Holds per-species-pair parameter matrices and looks up sigma, epsilon, and alpha
+ for each interacting pair via their atomic numbers. Pass an instance to
+ :class:`PairPotentialModel`.
- # Custom parameters for colloidal system
- colloid_model = SoftSphereModel(
- sigma=2.0, # particle diameter in nm
- epsilon=10.0, # energy scale in kJ/mol
- alpha=12.0, # steep repulsion for hard colloids
- compute_stress=True,
- )
+ Example::
- # Get forces for a system with periodic boundary conditions
- results = colloid_model(
- ts.SimState(
- positions=positions,
- cell=box_vectors,
- pbc=torch.tensor([True, True, True]),
- )
+ fn = MultiSoftSpherePairFn(
+ atomic_numbers=torch.tensor([18, 36]), # Ar and Kr
+ sigma_matrix=torch.tensor([[3.4, 3.6], [3.6, 3.7]]),
+ epsilon_matrix=torch.tensor([[0.01, 0.012], [0.012, 0.014]]),
)
- forces = results["forces"] # shape: [n_particles, 3]
- ```
+ model = PairPotentialModel(pair_fn=fn, cutoff=float(fn.sigma_matrix.max()))
"""
def __init__(
self,
- sigma: float | torch.Tensor = 1.0,
- epsilon: float | torch.Tensor = 1.0,
- alpha: float | torch.Tensor = 2.0,
- device: torch.device | None = None,
- dtype: torch.dtype = torch.float32,
- *, # Force keyword-only arguments
- compute_forces: bool = True,
- compute_stress: bool = False,
- per_atom_energies: bool = False,
- per_atom_stresses: bool = False,
- use_neighbor_list: bool = True,
- cutoff: float | torch.Tensor | None = None,
+ atomic_numbers: torch.Tensor,
+ sigma_matrix: torch.Tensor,
+ epsilon_matrix: torch.Tensor,
+ alpha_matrix: torch.Tensor | None = None,
) -> None:
- """Initialize the soft sphere model.
-
- Creates a soft sphere model with specified parameters for particle interactions
- and computation options.
+ """Initialize species-dependent soft-sphere parameters.
Args:
- sigma (float): Effective particle diameter. Determines the distance
- scale of the interaction. Defaults to 1.0.
- epsilon (float): Energy scale parameter. Controls the strength of
- the repulsion. Defaults to 1.0.
- alpha (float): Exponent controlling repulsion steepness. Higher values
- create steeper, more hard-sphere-like repulsion. Defaults to 2.0.
- device (torch.device | None): Device for computations. If None, uses CPU.
- Defaults to None.
- dtype (torch.dtype): Data type for calculations. Defaults to torch.float32.
- compute_forces (bool): Whether to compute forces. Defaults to True.
- compute_stress (bool): Whether to compute stress tensor. Defaults to False.
- per_atom_energies (bool): Whether to compute per-atom energy decomposition.
- Defaults to False.
- per_atom_stresses (bool): Whether to compute per-atom stress decomposition.
- Defaults to False.
- use_neighbor_list (bool): Whether to use a neighbor list for optimization.
- Significantly faster for large systems. Defaults to True.
- cutoff (float | None): Cutoff distance for interactions. If None, uses
- the value of sigma. Defaults to None.
-
- Examples:
- ```py
- # Default model
- model = SoftSphereModel()
-
- # WCA-like repulsive potential (derived from Lennard-Jones)
- wca_model = SoftSphereModel(
- sigma=1.0,
- epsilon=1.0,
- alpha=12.0, # Steep repulsion similar to r^-12 term in LJ
- cutoff=2 ** (1 / 6), # WCA cutoff at minimum of LJ potential
- )
- ```
+ atomic_numbers: 1-D tensor of the unique atomic numbers present, used to
+ map ``zi``/``zj`` to row/column indices. Shape: [n_species].
+ sigma_matrix: Symmetric matrix of interaction diameters. Shape:
+ [n_species, n_species].
+ epsilon_matrix: Symmetric matrix of energy scales. Shape:
+ [n_species, n_species].
+ alpha_matrix: Symmetric matrix of repulsion exponents. If None, defaults
+ to 2.0 for all pairs. Shape: [n_species, n_species].
"""
super().__init__()
- self._device = device or torch.device("cpu")
- self._dtype = dtype
- self._compute_forces = compute_forces
- self._compute_stress = compute_stress
- self.per_atom_energies = per_atom_energies
- self.per_atom_stresses = per_atom_stresses
- self.use_neighbor_list = use_neighbor_list
-
- # Convert interaction parameters to tensors with proper dtype/device
- self.sigma = torch.as_tensor(sigma, dtype=dtype, device=self.device)
- self.cutoff = torch.as_tensor(cutoff or sigma, dtype=dtype, device=self.device)
- self.epsilon = torch.as_tensor(epsilon, dtype=dtype, device=self.device)
- self.alpha = torch.as_tensor(alpha, dtype=dtype, device=self.device)
-
- def unbatched_forward(self, state: ts.SimState) -> dict[str, torch.Tensor]:
- """Compute energies and forces for a single unbatched system.
-
- Internal implementation that processes a single, non-batched simulation state.
- This method handles the core computations for pair interactions, including
- neighbor list construction, distance calculations, and property computation.
+ n = len(atomic_numbers)
+ if sigma_matrix.shape != (n, n):
+ raise ValueError(f"sigma_matrix must have shape ({n}, {n})")
+ if epsilon_matrix.shape != (n, n):
+ raise ValueError(f"epsilon_matrix must have shape ({n}, {n})")
+ if alpha_matrix is not None and alpha_matrix.shape != (n, n):
+ raise ValueError(f"alpha_matrix must have shape ({n}, {n})")
+
+ self.register_buffer("atomic_numbers", atomic_numbers)
+ self.sigma_matrix = sigma_matrix
+ self.epsilon_matrix = epsilon_matrix
+ self.alpha_matrix = (
+ alpha_matrix if alpha_matrix is not None else torch.full((n, n), 2.0)
+ )
+ max_z = int(atomic_numbers.max().item()) + 1
+ z_to_idx = torch.full((max_z,), -1, dtype=torch.long)
+ for idx, z in enumerate(atomic_numbers.tolist()):
+ z_to_idx[int(z)] = idx
+ self.z_to_idx: torch.Tensor
+ self.register_buffer("z_to_idx", z_to_idx)
+
+ def forward(
+ self, dr: torch.Tensor, zi: torch.Tensor, zj: torch.Tensor
+ ) -> torch.Tensor:
+ """Compute per-pair soft-sphere energies using species lookup.
Args:
- state (SimState): Single, non-batched simulation state containing atomic
- positions, cell vectors, and other system information.
+ dr: Pairwise distances, shape [n_pairs].
+ zi: Atomic numbers of first atoms, shape [n_pairs].
+ zj: Atomic numbers of second atoms, shape [n_pairs].
Returns:
- dict[str, torch.Tensor]: Computed properties:
- - "energy": Total potential energy (scalar)
- - "forces": Atomic forces with shape [n_atoms, 3] (if
- compute_forces=True)
- - "stress": Stress tensor with shape [3, 3] (if compute_stress=True)
- - "energies": Per-atom energies with shape [n_atoms] (if
- per_atom_energies=True)
- - "stresses": Per-atom stresses with shape [n_atoms, 3, 3] (if
- per_atom_stresses=True)
-
- Notes:
- This method can work with both neighbor list and full pairwise calculations.
- The soft sphere potential is purely repulsive, and forces are truncated at
- the cutoff distance.
+ Pair energies, shape [n_pairs].
"""
- positions = state.positions
- cell = state.row_vector_cell
- cell = cell.squeeze()
- pbc = state.pbc
-
- # Ensure system_idx exists (create if None for single system)
- system_idx = (
- state.system_idx
- if state.system_idx is not None
- else torch.zeros(positions.shape[0], dtype=torch.long, device=self.device)
- )
+ idx_i = self.z_to_idx[zi]
+ idx_j = self.z_to_idx[zj]
+ sigma = self.sigma_matrix[idx_i, idx_j]
+ epsilon = self.epsilon_matrix[idx_i, idx_j]
+ alpha = self.alpha_matrix[idx_i, idx_j]
+ energy = epsilon / alpha * (1.0 - dr / sigma).pow(alpha)
+ return torch.where(dr < sigma, energy, torch.zeros_like(energy))
- # Wrap positions into the unit cell
- wrapped_positions = (
- ts.transforms.pbc_wrap_batched(positions, state.cell, system_idx, pbc)
- if pbc.any()
- else positions
- )
- if self.use_neighbor_list:
- mapping, _, shifts_idx = torchsim_nl(
- positions=wrapped_positions,
- cell=cell,
- pbc=pbc,
- cutoff=self.cutoff,
- system_idx=system_idx,
- )
- # Pass shifts_idx directly - get_pair_displacements will convert them
- dr_vec, distances = transforms.get_pair_displacements(
- positions=wrapped_positions,
- cell=cell,
- pbc=pbc,
- pairs=(mapping[0], mapping[1]),
- shifts=shifts_idx,
- )
-
- else:
- # Direct N^2 computation of all pairs
- dr_vec, distances = transforms.get_pair_displacements(
- positions=wrapped_positions,
- cell=cell,
- pbc=pbc,
- )
- # Remove self-interactions and apply cutoff
- mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device)
- distances = distances.masked_fill(mask, float("inf"))
- mask = distances < self.cutoff
-
- # Get valid pairs and their displacements
- i, j = torch.where(mask)
- mapping = torch.stack([j, i])
- dr_vec = dr_vec[mask]
- distances = distances[mask]
-
- # Calculate pair energies using soft sphere potential
- pair_energies = soft_sphere_pair(
- distances, sigma=self.sigma, epsilon=self.epsilon, alpha=self.alpha
+DEFAULT_SIGMA = torch.tensor(1.0)
+DEFAULT_EPSILON = torch.tensor(1.0)
+DEFAULT_ALPHA = torch.tensor(2.0)
+
+
+class SoftSphereModel(PairPotentialModel):
+ """Soft-sphere repulsive pair potential model.
+
+ Convenience subclass that fixes the pair function to :func:`soft_sphere_pair`
+ so the caller only needs to supply ``sigma``, ``epsilon``, and ``alpha``.
+
+ Example::
+
+ model = SoftSphereModel(
+ sigma=3.405,
+ epsilon=0.0104,
+ alpha=2.0,
+ compute_forces=True,
)
+ results = model(sim_state)
+ """
- # Initialize results with total energy (divide by 2 to avoid double counting)
- results = {"energy": 0.5 * pair_energies.sum()}
-
- if self.per_atom_energies:
- # Compute per-atom energy contributions
- atom_energies = torch.zeros(
- positions.shape[0], dtype=self.dtype, device=self.device
- )
- # Each atom gets half of the pair energy
- atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies)
- atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies)
- results["energies"] = atom_energies
-
- if self.compute_forces or self.compute_stress:
- # Calculate pair forces
- pair_forces = soft_sphere_pair_force(
- distances, sigma=self.sigma, epsilon=self.epsilon, alpha=self.alpha
- )
-
- # Project scalar forces onto displacement vectors
- force_vectors = (pair_forces / distances)[:, None] * dr_vec
-
- if self.compute_forces:
- # Compute atomic forces by accumulating pair contributions
- forces = torch.zeros_like(positions)
- forces.index_add_(0, mapping[0], -force_vectors)
- forces.index_add_(0, mapping[1], force_vectors)
- results["forces"] = forces
-
- if self.compute_stress and cell is not None:
- # Compute stress tensor using virial formula
- stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors)
- volume = torch.abs(torch.linalg.det(cell))
-
- results["stress"] = -stress_per_pair.sum(dim=0) / volume
-
- if self.per_atom_stresses:
- # Compute per-atom stress contributions
- atom_stresses = torch.zeros(
- (positions.shape[0], 3, 3), dtype=self.dtype, device=self.device
- )
- atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair)
- atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair)
- results["stresses"] = atom_stresses / volume
-
- return results
-
- def forward(self, state: ts.SimState, **_kwargs) -> dict[str, torch.Tensor]:
- """Compute soft sphere potential energies, forces, and stresses for a system.
-
- Main entry point for soft sphere potential calculations that handles batched
- states by dispatching each system to the unbatched implementation and combining
- results.
+ def __init__(
+ self,
+ sigma: float = 1.0,
+ epsilon: float = 1.0,
+ alpha: float = 2.0,
+ device: torch.device | None = None,
+ dtype: torch.dtype = torch.float64,
+ *,
+ compute_forces: bool = True,
+ compute_stress: bool = False,
+ per_atom_energies: bool = False,
+ per_atom_stresses: bool = False,
+ neighbor_list_fn: Callable = torchsim_nl,
+ use_neighbor_list: bool = True, # noqa: ARG002
+ cutoff: float | None = None,
+ retain_graph: bool = False,
+ ) -> None:
+ """Initialize the soft sphere model.
Args:
- state (SimState): Input state containing atomic positions, cell vectors,
- and other system information.
- **_kwargs: Unused; accepted for interface compatibility.
-
- Returns:
- dict[str, torch.Tensor]: Computed properties:
- - "energy": Potential energy with shape [n_systems]
- - "forces": Atomic forces with shape [n_atoms, 3]
- (if compute_forces=True)
- - "stress": Stress tensor with shape [n_systems, 3, 3]
- (if compute_stress=True)
- - May include additional outputs based on configuration
-
- Raises:
- ValueError: If system indices cannot be inferred for multi-cell systems.
-
- Examples:
- ```py
- # Compute properties for a simulation state
- model = SoftSphereModel(compute_forces=True)
- results = model(sim_state)
-
- energy = results["energy"] # Shape: [n_systems]
- forces = results["forces"] # Shape: [n_atoms, 3]
- ```
+ sigma: Effective particle diameter. Defaults to 1.0.
+ epsilon: Energy scale parameter. Defaults to 1.0.
+ alpha: Repulsion exponent. Defaults to 2.0.
+ device: Device for computations. Defaults to CPU.
+ dtype: Floating-point dtype. Defaults to torch.float32.
+ compute_forces: Whether to compute atomic forces. Defaults to True.
+ compute_stress: Whether to compute the stress tensor. Defaults to False.
+ per_atom_energies: Whether to return per-atom energies. Defaults to False.
+ per_atom_stresses: Whether to return per-atom stresses. Defaults to False.
+ neighbor_list_fn: Neighbor-list constructor. Defaults to torchsim_nl.
+ use_neighbor_list: Accepted for backward compatibility (ignored).
+ cutoff: Interaction cutoff. Defaults to sigma.
+ retain_graph: Keep computation graph for differentiable simulation.
"""
- # Handle System indices if not provided
- if state.system_idx is None and state.cell.shape[0] > 1:
- raise ValueError(
- "system_idx can only be inferred if there is only one system"
- )
-
- outputs = [self.unbatched_forward(state[i]) for i in range(state.n_systems)]
- properties = outputs[0]
-
- # Combine results
- results: dict[str, torch.Tensor] = {}
- for key in ("stress", "energy"):
- if key in properties:
- results[key] = torch.stack([out[key] for out in outputs])
- for key in ("forces", "energies", "stresses"):
- if key in properties:
- results[key] = torch.cat([out[key] for out in outputs], dim=0)
-
- return results
-
-
-class SoftSphereMultiModel(ModelInterface):
- """Calculator for systems with multiple particle types.
-
- Extends the basic soft sphere model to support multiple particle types with
- different interaction parameters for each pair of particle types. This enables
- simulation of heterogeneous systems like mixtures, composites, or biomolecular
- systems with different interaction strengths between different components.
-
- This model maintains matrices of interaction parameters (sigma, epsilon, alpha)
- where each element [i,j] represents the parameter for interactions between
- particle types i and j.
-
- Attributes:
- species (torch.Tensor): Particle type indices for each particle in the system.
- sigma_matrix (torch.Tensor): Matrix of distance parameters for each pair of types.
- Shape: [n_types, n_types].
- epsilon_matrix (torch.Tensor): Matrix of energy scale parameters for each pair.
- Shape: [n_types, n_types].
- alpha_matrix (torch.Tensor): Matrix of exponents for each pair of types.
- Shape: [n_types, n_types].
- cutoff (torch.Tensor): Maximum interaction distance.
- compute_forces (bool): Whether to compute forces.
- compute_stress (bool): Whether to compute stress tensor.
- per_atom_energies (bool): Whether to compute per-atom energy decomposition.
- per_atom_stresses (bool): Whether to compute per-atom stress decomposition.
- use_neighbor_list (bool): Whether to use neighbor list optimization.
- periodic (bool): Whether to use periodic boundary conditions.
- _device (torch.device): Computation device (CPU/GPU).
- _dtype (torch.dtype): Data type for tensor calculations.
-
- Examples:
- ```py
- # Create a binary mixture with different interaction parameters
- # Define interaction matrices (size 2x2 for binary system)
- sigma_matrix = torch.tensor(
- [
- [1.0, 0.8], # Type 0-0 and 0-1 interactions
- [0.8, 0.6], # Type 1-0 and 1-1 interactions
- ]
- )
+ self.sigma = sigma
+ self.epsilon = epsilon
+ self.alpha = alpha
- epsilon_matrix = torch.tensor(
- [
- [1.0, 0.5], # Type 0-0 and 0-1 interactions
- [0.5, 2.0], # Type 1-0 and 1-1 interactions
- ]
+ pair_fn = functools.partial(
+ soft_sphere_pair, sigma=sigma, epsilon=epsilon, alpha=alpha
)
+ super().__init__(
+ pair_fn=pair_fn,
+ cutoff=cutoff if cutoff is not None else sigma,
+ device=device,
+ dtype=dtype,
+ compute_forces=compute_forces,
+ compute_stress=compute_stress,
+ per_atom_energies=per_atom_energies,
+ per_atom_stresses=per_atom_stresses,
+ neighbor_list_fn=neighbor_list_fn,
+ reduce_to_half_list=True,
+ retain_graph=retain_graph,
+ )
+
- # Particle type assignments (0 or 1 for each particle)
- species = torch.tensor([0, 0, 1, 1, 0, 1])
+class SoftSphereMultiModel(PairPotentialModel):
+ """Multi-species soft-sphere potential model.
+
+ Uses :class:`MultiSoftSpherePairFn` internally
+ to look up per-species-pair parameters from matrices.
+
+ Example::
- # Create the model
model = SoftSphereMultiModel(
- species=species,
- sigma_matrix=sigma_matrix,
- epsilon_matrix=epsilon_matrix,
+ atomic_numbers=torch.tensor([18, 36]),
+ sigma_matrix=torch.tensor([[1.0, 0.8], [0.8, 0.6]]),
+ epsilon_matrix=torch.tensor([[1.0, 0.5], [0.5, 2.0]]),
compute_forces=True,
)
-
- # Compute properties
- results = model(simulation_state)
- ```
+ results = model(sim_state)
"""
def __init__(
self,
- n_species: int,
+ atomic_numbers: torch.Tensor,
sigma_matrix: torch.Tensor | None = None,
epsilon_matrix: torch.Tensor | None = None,
alpha_matrix: torch.Tensor | None = None,
device: torch.device | None = None,
dtype: torch.dtype = torch.float64,
- *, # Force keyword-only arguments
+ *,
pbc: torch.Tensor | bool = True,
compute_forces: bool = True,
compute_stress: bool = False,
per_atom_energies: bool = False,
per_atom_stresses: bool = False,
use_neighbor_list: bool = True,
+ neighbor_list_fn: Callable = torchsim_nl,
cutoff: float | None = None,
+ retain_graph: bool = False,
) -> None:
- """Initialize a soft sphere model for multi-component systems.
-
- Creates a model for systems with multiple particle types, each with potentially
- different interaction parameters.
+ """Initialize the multi-species soft sphere model.
Args:
- n_species (int): Number of particle types.
- sigma_matrix (torch.Tensor | None): Matrix of distance parameters for
- each pair of types. Shape [n_types, n_types]. If None, uses default
- value 1.0 for all pairs. Defaults to None.
- epsilon_matrix (torch.Tensor | None): Matrix of energy scale parameters
- for each pair of types. Shape [n_types, n_types]. If None, uses
- default value 1.0 for all pairs. Defaults to None.
- alpha_matrix (torch.Tensor | None): Matrix of exponents for each pair.
- Shape [n_types, n_types]. If None, uses default value 2.0 for all
- pairs. Defaults to None.
- device (torch.device | None): Device for computations. If None, uses CPU.
- Defaults to None.
- dtype (torch.dtype): Data type for calculations. Defaults to torch.float32.
- pbc (torch.Tensor | bool): Boolean tensor of shape (3,) indicating periodic
- boundary conditions in each axis. If None, all axes are assumed to be
- periodic. Defaults to True.
- compute_forces (bool): Whether to compute forces. Defaults to True.
- compute_stress (bool): Whether to compute stress tensor. Defaults to False.
- per_atom_energies (bool): Whether to compute per-atom energy decomposition.
- Defaults to False.
- per_atom_stresses (bool): Whether to compute per-atom stress decomposition.
- Defaults to False.
- use_neighbor_list (bool): Whether to use a neighbor list for optimization.
- Defaults to True.
- cutoff (float | None): Cutoff distance for interactions. If None, uses
- the maximum value from sigma_matrix. Defaults to None.
-
- Examples:
- ```py
- # Binary polymer mixture with different interactions
- # Polymer A (type 0): larger, softer particles
- # Polymer B (type 1): smaller, harder particles
-
- # Create species assignment (100 particles total)
- species = torch.cat(
- [
- torch.zeros(50, dtype=torch.long), # 50 particles of type 0
- torch.ones(50, dtype=torch.long), # 50 particles of type 1
- ]
- )
-
- # Interaction matrices
- sigma = torch.tensor(
- [
- [1.2, 1.0], # A-A and A-B interactions
- [1.0, 0.8], # B-A and B-B interactions
- ]
- )
-
- epsilon = torch.tensor(
- [
- [1.0, 1.5], # A-A and A-B interactions
- [1.5, 2.0], # B-A and B-B interactions
- ]
- )
-
- # Create model with mixing rules
- model = SoftSphereMultiModel(
- species=species,
- sigma_matrix=sigma,
- epsilon_matrix=epsilon,
- compute_forces=True,
- )
- ```
-
- Notes:
- The interaction matrices must be symmetric for physical consistency
- (e.g., interaction of type 0 with type 1 should be the same as type 1
- with type 0).
+ atomic_numbers: Atomic numbers of atoms in the system. May contain
+ duplicates; only the sorted unique values are used to define
+ species and determine matrix dimensions.
+ sigma_matrix: Symmetric matrix of interaction diameters.
+ Shape [n_species, n_species]. Defaults to 1.0 for all pairs.
+ epsilon_matrix: Symmetric matrix of energy scales.
+ Shape [n_species, n_species]. Defaults to 1.0 for all pairs.
+ alpha_matrix: Symmetric matrix of repulsion exponents.
+ Shape [n_species, n_species]. Defaults to 2.0 for all pairs.
+ device: Device for computations. Defaults to CPU.
+ dtype: Floating-point dtype. Defaults to torch.float64.
+ pbc: Periodic boundary conditions (kept for backward compat). Defaults
+ to True.
+ compute_forces: Whether to compute atomic forces. Defaults to True.
+ compute_stress: Whether to compute the stress tensor. Defaults to False.
+ per_atom_energies: Whether to return per-atom energies. Defaults to False.
+ per_atom_stresses: Whether to return per-atom stresses. Defaults to False.
+ use_neighbor_list: Accepted for backward compatibility (a neighbor list
+ is always used internally). Defaults to True.
+ neighbor_list_fn: Neighbor-list constructor. Defaults to torchsim_nl.
+ cutoff: Interaction cutoff. Defaults to max of sigma_matrix.
+ retain_graph: Keep computation graph for differentiable simulation.
"""
- super().__init__()
- self._device = device or torch.device("cpu")
- self._dtype = dtype
self.pbc = torch.tensor([pbc] * 3) if isinstance(pbc, bool) else pbc
- self._compute_forces = compute_forces
- self._compute_stress = compute_stress
- self.per_atom_energies = per_atom_energies
- self.per_atom_stresses = per_atom_stresses
self.use_neighbor_list = use_neighbor_list
+ unique_z = torch.unique(atomic_numbers).sort().values.long()
+ n_species = len(unique_z)
self.n_species = n_species
- # Initialize parameter matrices with defaults if not provided
- default_sigma = DEFAULT_SIGMA.to(device=self.device, dtype=self.dtype)
- default_epsilon = DEFAULT_EPSILON.to(device=self.device, dtype=self.dtype)
- default_alpha = DEFAULT_ALPHA.to(device=self.device, dtype=self.dtype)
+ _device = device or torch.device("cpu")
+ default_sigma = DEFAULT_SIGMA.to(device=_device, dtype=dtype)
+ default_epsilon = DEFAULT_EPSILON.to(device=_device, dtype=dtype)
+ default_alpha = DEFAULT_ALPHA.to(device=_device, dtype=dtype)
- # Validate matrix shapes match number of species
if sigma_matrix is not None and sigma_matrix.shape != (n_species, n_species):
raise ValueError(f"sigma_matrix must have shape ({n_species}, {n_species})")
if epsilon_matrix is not None and epsilon_matrix.shape != (
@@ -621,246 +329,49 @@ def __init__(
if alpha_matrix is not None and alpha_matrix.shape != (n_species, n_species):
raise ValueError(f"alpha_matrix must have shape ({n_species}, {n_species})")
- # Create parameter matrices, using defaults if not provided
self.sigma_matrix = (
sigma_matrix
if sigma_matrix is not None
else default_sigma
- * torch.ones((n_species, n_species), dtype=dtype, device=device)
+ * torch.ones((n_species, n_species), dtype=dtype, device=_device)
)
self.epsilon_matrix = (
epsilon_matrix
if epsilon_matrix is not None
else default_epsilon
- * torch.ones((n_species, n_species), dtype=dtype, device=device)
+ * torch.ones((n_species, n_species), dtype=dtype, device=_device)
)
self.alpha_matrix = (
alpha_matrix
if alpha_matrix is not None
else default_alpha
- * torch.ones((n_species, n_species), dtype=dtype, device=device)
+ * torch.ones((n_species, n_species), dtype=dtype, device=_device)
)
- # Ensure parameter matrices are symmetric (required for energy conservation)
for matrix_name in ("sigma_matrix", "epsilon_matrix", "alpha_matrix"):
matrix = getattr(self, matrix_name)
if not torch.allclose(matrix, matrix.T):
raise ValueError(f"{matrix_name} is not symmetric")
- # Set interaction cutoff distance
- self.cutoff = torch.tensor(
- cutoff or float(self.sigma_matrix.max()), dtype=dtype, device=device
- )
-
- def unbatched_forward(
- self,
- state: ts.SimState,
- ) -> dict[str, torch.Tensor]:
- """Compute energies and forces for a single unbatched system with multiple
- species.
-
- Internal implementation that processes a single, non-batched simulation state.
- This method handles all pair interactions between particles of different types
- using the appropriate interaction parameters from the parameter matrices.
+ _cutoff = cutoff or float(self.sigma_matrix.detach().max())
- Args:
- state (SimState): Single, non-batched simulation state containing atomic
- positions, cell vectors, and other system information. Species indices
- are read from ``state.atomic_numbers``.
-
- Returns:
- dict[str, torch.Tensor]: Computed properties:
- - "energy": Total potential energy (scalar)
- - "forces": Atomic forces with shape [n_atoms, 3]
- (if compute_forces=True)
- - "stress": Stress tensor with shape [3, 3]
- (if compute_stress=True)
- - "energies": Per-atom energies with shape [n_atoms]
- (if per_atom_energies=True)
- - "stresses": Per-atom stresses with shape [n_atoms, 3, 3]
- (if per_atom_stresses=True)
-
- Notes:
- This method supports both neighbor list optimization and full pairwise
- calculations based on the use_neighbor_list parameter. For each pair of
- particles, it looks up the appropriate parameters based on the species
- of the two particles.
- """
- species_idx = state.atomic_numbers.to(device=self.device, dtype=torch.long)
-
- positions = state.positions
- cell = state.row_vector_cell
- cell = cell.squeeze()
-
- # Compute neighbor list or full distance matrix
- if self.use_neighbor_list:
- # Get neighbor list for efficient computation
- # Ensure system_idx exists (create if None for single system)
- system_idx = torch.zeros(
- positions.shape[0], dtype=torch.long, device=self.device
- )
- mapping, _, shifts_idx = torchsim_nl(
- positions=positions,
- cell=cell,
- pbc=self.pbc,
- cutoff=self.cutoff,
- system_idx=system_idx,
- )
- # Pass shifts_idx directly - get_pair_displacements will convert them
- dr_vec, distances = transforms.get_pair_displacements(
- positions=positions,
- cell=cell,
- pbc=self.pbc,
- pairs=(mapping[0], mapping[1]),
- shifts=shifts_idx,
- )
-
- else:
- # Direct N^2 computation of all pairs
- dr_vec, distances = transforms.get_pair_displacements(
- positions=positions,
- cell=cell,
- pbc=self.pbc,
- )
- # Remove self-interactions and apply cutoff
- mask = torch.eye(positions.shape[0], dtype=torch.bool, device=self.device)
- distances = distances.masked_fill(mask, float("inf"))
- mask = distances < self.cutoff
-
- # Get valid pairs and their displacements
- i, j = torch.where(mask)
- mapping = torch.stack([j, i])
- dr_vec = dr_vec[mask]
- distances = distances[mask]
-
- # Look up species-specific parameters for each interacting pair
- pair_species_1 = species_idx[mapping[0]] # Species of first atom in pair
- pair_species_2 = species_idx[mapping[1]] # Species of second atom in pair
-
- # Get interaction parameters from parameter matrices
- pair_sigmas = self.sigma_matrix[pair_species_1, pair_species_2]
- pair_epsilons = self.epsilon_matrix[pair_species_1, pair_species_2]
- pair_alphas = self.alpha_matrix[pair_species_1, pair_species_2]
-
- # Calculate pair energies using species-specific parameters
- pair_energies = soft_sphere_pair(
- distances, sigma=pair_sigmas, epsilon=pair_epsilons, alpha=pair_alphas
+ pair_fn = MultiSoftSpherePairFn(
+ atomic_numbers=unique_z.to(device=_device),
+ sigma_matrix=self.sigma_matrix,
+ epsilon_matrix=self.epsilon_matrix,
+ alpha_matrix=self.alpha_matrix,
)
- # Initialize results with total energy (divide by 2 to avoid double counting)
- results = {"energy": 0.5 * pair_energies.sum()}
-
- if self.per_atom_energies:
- # Compute per-atom energy contributions
- atom_energies = torch.zeros(
- positions.shape[0], dtype=self.dtype, device=self.device
- )
- # Each atom gets half of the pair energy
- atom_energies.index_add_(0, mapping[0], 0.5 * pair_energies)
- atom_energies.index_add_(0, mapping[1], 0.5 * pair_energies)
- results["energies"] = atom_energies
-
- if self.compute_forces or self.compute_stress:
- # Calculate pair forces
- pair_forces = soft_sphere_pair_force(
- distances, sigma=pair_sigmas, epsilon=pair_epsilons, alpha=pair_alphas
- )
-
- # Project scalar forces onto displacement vectors
- force_vectors = (pair_forces / distances)[:, None] * dr_vec
-
- if self.compute_forces:
- # Compute atomic forces by accumulating pair contributions
- forces = torch.zeros_like(positions)
- forces.index_add_(0, mapping[0], -force_vectors)
- forces.index_add_(0, mapping[1], force_vectors)
- results["forces"] = forces
-
- if self.compute_stress and cell is not None:
- # Compute stress tensor using virial formula
- stress_per_pair = torch.einsum("...i,...j->...ij", dr_vec, force_vectors)
- volume = torch.abs(torch.linalg.det(cell))
-
- results["stress"] = -stress_per_pair.sum(dim=0) / volume
-
- if self.per_atom_stresses:
- # Compute per-atom stress contributions
- atom_stresses = torch.zeros(
- (positions.shape[0], 3, 3), dtype=self.dtype, device=self.device
- )
- atom_stresses.index_add_(0, mapping[0], -0.5 * stress_per_pair)
- atom_stresses.index_add_(0, mapping[1], -0.5 * stress_per_pair)
- results["stresses"] = atom_stresses / volume
-
- return results
-
- def forward(self, state: ts.SimState, **_kwargs) -> dict[str, torch.Tensor]:
- """Compute soft sphere potential properties for multi-component systems.
-
- Main entry point for multi-species soft sphere calculations that handles
- batched states by dispatching each batch to the unbatched implementation
- and combining results.
-
- Args:
- state (SimState): Input state containing atomic positions, cell vectors,
- and other system information.
- **_kwargs: Unused; accepted for interface compatibility.
-
- Returns:
- dict[str, torch.Tensor]: Computed properties:
- - "energy": Potential energy with shape [n_systems]
- - "forces": Atomic forces with shape [n_atoms, 3]
- (if compute_forces=True)
- - "stress": Stress tensor with shape [n_systems, 3, 3]
- (if compute_stress=True)
- - May include additional outputs based on configuration
-
- Raises:
- ValueError: If batch cannot be inferred for multi-cell systems or if
- species information is missing.
-
- Examples:
- ```py
- # Create model for binary mixture
- model = SoftSphereMultiModel(
- species=particle_types,
- sigma_matrix=distance_matrix,
- epsilon_matrix=strength_matrix,
- compute_forces=True,
- )
-
- # Calculate properties
- results = model(simulation_state)
- energy = results["energy"]
- forces = results["forces"]
- ```
-
- Notes:
- This method requires species information either provided during initialization
- or included in the state object's metadata.
- """
- if state.pbc != self.pbc:
- raise ValueError("PBC mismatch between model and state")
-
- # Handle system indices if not provided
- if state.system_idx is None and state.cell.shape[0] > 1:
- raise ValueError(
- "system_idx can only be inferred if there is only one system"
- )
-
- outputs = [
- self.unbatched_forward(state[sys_idx]) for sys_idx in range(state.n_systems)
- ]
- properties = outputs[0]
-
- # Combine results
- results: dict[str, torch.Tensor] = {}
- for key in ("stress", "energy", "forces", "energies", "stresses"):
- if key in properties:
- results[key] = torch.stack([out[key] for out in outputs])
-
- for key in ("forces", "energies", "stresses"):
- if key in properties:
- results[key] = torch.cat([out[key] for out in outputs], dim=0)
-
- return results
+ super().__init__(
+ pair_fn=pair_fn,
+ cutoff=_cutoff,
+ device=device,
+ dtype=dtype,
+ compute_forces=compute_forces,
+ compute_stress=compute_stress,
+ per_atom_energies=per_atom_energies,
+ per_atom_stresses=per_atom_stresses,
+ neighbor_list_fn=neighbor_list_fn,
+ reduce_to_half_list=True,
+ retain_graph=retain_graph,
+ )
diff --git a/torch_sim/workflows/a2c.py b/torch_sim/workflows/a2c.py
index 1ee49088..81d22078 100644
--- a/torch_sim/workflows/a2c.py
+++ b/torch_sim/workflows/a2c.py
@@ -282,7 +282,6 @@ def random_packed_structure(
device=device,
dtype=dtype,
compute_forces=True,
- use_neighbor_list=True,
)
# Dummy atomic numbers
@@ -402,15 +401,12 @@ def random_packed_structure_multi(
# Convert fractional to cartesian coordinates
positions_cart = torch.matmul(positions, cell)
- # Initialize multi-species soft sphere potential calculator
- n_species = len(element_counts)
model = SoftSphereMultiModel(
- n_species=n_species,
+ atomic_numbers=species_idx,
sigma_matrix=diameter_matrix,
device=device,
dtype=dtype,
compute_forces=True,
- use_neighbor_list=True,
)
state_dict = ts.SimState(