From 80bfb0b39a6818e9eb1bd683040a39e981a3351d Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Mon, 26 Jan 2026 15:26:43 +0100 Subject: [PATCH 01/22] Add a prototype of the heat flux wrapper --- .../metatomic/torch/ase_calculator.py | 15 +- .../metatomic/torch/heat_flux.py | 298 ++++++++++++++++++ 2 files changed, 307 insertions(+), 6 deletions(-) create mode 100644 python/metatomic_torch/metatomic/torch/heat_flux.py diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index c66ab500..1466e4bd 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -99,6 +99,10 @@ def _get_charges(atoms: ase.Atoms) -> np.ndarray: }, } +IMPLEMENTED_PROPERTIES = [ + "heat_flux", +] + class MetatomicCalculator(ase.calculators.calculator.Calculator): """ @@ -293,9 +297,9 @@ def __init__( for name, output in additional_outputs.items(): assert isinstance(name, str) assert isinstance(output, torch.ScriptObject) - assert "explicit_gradients_setter" in output._method_names(), ( - "outputs must be ModelOutput instances" - ) + assert ( + "explicit_gradients_setter" in output._method_names() + ), "outputs must be ModelOutput instances" self._additional_output_requests = additional_outputs @@ -318,7 +322,7 @@ def __init__( # We do our own check to verify if a property is implemented in `calculate()`, # so we pretend to be able to compute all properties ASE knows about. 
- self.implemented_properties = ALL_ASE_PROPERTIES + self.implemented_properties = ALL_ASE_PROPERTIES + IMPLEMENTED_PROPERTIES self.additional_outputs: Dict[str, TensorMap] = {} """ @@ -933,8 +937,7 @@ def _get_ase_input( tensor.set_info("quantity", infos["quantity"]) tensor.set_info("unit", infos["unit"]) - tensor.to(dtype=dtype, device=device) - return tensor + return tensor.to(dtype=dtype, device=device) def _ase_to_torch_data(atoms, dtype, device): diff --git a/python/metatomic_torch/metatomic/torch/heat_flux.py b/python/metatomic_torch/metatomic/torch/heat_flux.py new file mode 100644 index 00000000..44197ab5 --- /dev/null +++ b/python/metatomic_torch/metatomic/torch/heat_flux.py @@ -0,0 +1,298 @@ +import torch + +from torch.autograd.functional import jvp +from typing import List, Dict, Optional +from vesin.metatomic import compute_requested_neighbors + + +from metatensor.torch import Labels, TensorBlock, TensorMap +from metatomic.torch import ( + AtomisticModel, + ModelEvaluationOptions, + ModelOutput, + System, +) + + +def wrap_positions(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: + fractional_positions = torch.einsum("iv,kv->ik", positions, cell.inverse()) + fractional_positions -= torch.floor(fractional_positions) + wrapped_positions = torch.einsum("iv,kv->ik", fractional_positions, cell) + + return wrapped_positions + + +def check_collisions( + cell: torch.Tensor, positions: torch.Tensor, cutoff: float +) -> tuple[torch.Tensor, torch.Tensor]: + inv_cell = cell.inverse() + norm_inv_cell = torch.linalg.norm(inv_cell, dim=1) + inv_cell /= norm_inv_cell[:, None] + norm_coords = torch.einsum("iv,kv->ik", positions, inv_cell) + cell_vec_lengths = torch.diag(cell @ inv_cell) + collisions = torch.hstack( + [norm_coords <= cutoff, norm_coords >= cell_vec_lengths - cutoff], + ).to(device=positions.device) + + return collisions[:, [0, 3, 1, 4, 2, 5]], norm_coords + + +def collisions_to_replicas(collisions: torch.Tensor) -> torch.Tensor: + """ + 
Convert collisions to replicas. + + collisions: [N, 6]: has collisions with (x_lo, x_hi, y_lo, y_hi, z_lo, z_hi) + """ + origin = torch.full( + (len(collisions),), True, dtype=torch.bool, device=collisions.device + ) + axs = torch.vstack([origin, collisions[:, 0], collisions[:, 1]]) + ays = torch.vstack([origin, collisions[:, 2], collisions[:, 3]]) + azs = torch.vstack([origin, collisions[:, 4], collisions[:, 5]]) + # leverage broadcasting + outs = axs[:, None, None] & ays[None, :, None] & azs[None, None, :] + outs = torch.movedim(outs, -1, 0) + outs[:, 0, 0, 0] = False + return outs.to(device=collisions.device) + + +def generate_replica_atoms( + types: torch.Tensor, + positions: torch.Tensor, + cell: torch.Tensor, + replicas: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + replicas = torch.argwhere(replicas) + replica_idx = replicas[:, 0] + replica_offsets = torch.tensor( + [0, 1, -1], device=positions.device, dtype=positions.dtype + )[replicas[:, 1:]] + replica_positions = positions[replica_idx] + replica_positions += torch.einsum("aA,iA->ia", cell, replica_offsets) + + return replica_idx, types[replica_idx], replica_positions + + +def unfold_system(metatomic_system: System, cutoff: float) -> System: + wrapped_positions = wrap_positions( + metatomic_system.positions, metatomic_system.cell + ) + collisions, _ = check_collisions( + metatomic_system.cell, wrapped_positions, cutoff + 0.5 + ) + replicas = collisions_to_replicas(collisions) + replica_idx, replica_types, replica_positions = generate_replica_atoms( + metatomic_system.types, wrapped_positions, metatomic_system.cell, replicas + ) + unfolded_types = torch.cat( + [ + metatomic_system.types, + replica_types, + ] + ) + unfolded_positions = torch.cat( + [ + wrapped_positions, + replica_positions, + ] + ) + unfolded_idx = torch.cat( + [ + torch.arange(len(metatomic_system.types), device=metatomic_system.device), + replica_idx, + ] + ) + unfolded_n_atoms = len(unfolded_types) + masses_block 
= metatomic_system.get_data("masses").block() + velocities_block = metatomic_system.get_data("velocities").block() + unfolded_masses = masses_block.values[unfolded_idx] + unfolded_velocities = velocities_block.values[unfolded_idx] + unfolded_masses_block = TensorBlock( + values=unfolded_masses, + samples=Labels( + ["atoms"], + torch.arange(unfolded_n_atoms, device=metatomic_system.device).reshape( + -1, 1 + ), + ), + components=masses_block.components, + properties=masses_block.properties, + ) + unfolded_velocities_block = TensorBlock( + values=unfolded_velocities, + samples=Labels( + ["atoms"], + torch.arange(unfolded_n_atoms, device=metatomic_system.device).reshape( + -1, 1 + ), + ), + components=velocities_block.components, + properties=velocities_block.properties, + ) + unfolded_system = System( + types=unfolded_types, + positions=unfolded_positions, + cell=torch.tensor( + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + dtype=unfolded_positions.dtype, + device=metatomic_system.device, + ), + pbc=torch.tensor([False, False, False], device=metatomic_system.device), + ) + unfolded_system.add_data( + "masses", + TensorMap( + Labels("_", torch.tensor([[0]], device=metatomic_system.device)), + [unfolded_masses_block], + ), + ) + unfolded_system.add_data( + "velocities", + TensorMap( + Labels("_", torch.tensor([[0]], device=metatomic_system.device)), + [unfolded_velocities_block], + ), + ) + return unfolded_system.to(metatomic_system.dtype, metatomic_system.device) + + +class HeatFluxWrapper(torch.nn.Module): + + def __init__(self, model: AtomisticModel): + super().__init__() + + self._model = model + # TODO: throw error if the simulation cell is smaller than double the interaction range + self._interaction_range = model.capabilities().interaction_range + + self._requested_inputs = { + "masses": ModelOutput(quantity="mass", unit="u", per_atom=True), + "velocities": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True), + } + + hf_output = ModelOutput( + 
quantity="heat_flux", + unit="", + explicit_gradients=[], + per_atom=False, + ) + outputs = self._model.capabilities().outputs.copy() + outputs["extra::heat_flux"] = hf_output + self._model.capabilities().outputs["extra::heat_flux"] = hf_output + + energies_output = ModelOutput( + quantity="energy", unit=outputs["energy"].unit, per_atom=True + ) + self._unfolded_run_options = ModelEvaluationOptions( + length_unit=self._model.capabilities().length_unit, + outputs={"energy": energies_output}, + selected_atoms=None, + ) + + def requested_inputs(self) -> Dict[str, ModelOutput]: + return self._requested_inputs + + def barycenter_and_atomic_energies(self, system: System, n_atoms: int): + atomic_e = self._model([system], self._unfolded_run_options, False)["energy"][ + 0 + ].values.flatten() + total_e = atomic_e[:n_atoms].sum() + r_aux = system.positions.detach() + barycenter = torch.einsum("i,ik->k", atomic_e[:n_atoms], r_aux[:n_atoms]) + + return barycenter, atomic_e, total_e + + def calc_unfolded_heat_flux(self, system: System) -> torch.Tensor: + n_atoms = len(system.positions) + unfolded_system = unfold_system(system, self._interaction_range).to("cpu") + compute_requested_neighbors( + unfolded_system, self._unfolded_run_options.length_unit, model=self._model + ) + unfolded_system = unfolded_system.to(system.device) + velocities: torch.Tensor = ( + unfolded_system.get_data("velocities").block().values.reshape(-1, 3) + ) + masses: torch.Tensor = ( + unfolded_system.get_data("masses").block().values.reshape(-1) + ) + barycenter, atomic_e, total_e = self.barycenter_and_atomic_energies( + unfolded_system, n_atoms + ) + + term1 = torch.zeros( + (3), device=system.positions.device, dtype=system.positions.dtype + ) + for i in range(3): + grad_i = torch.autograd.grad( + [barycenter[i]], + [unfolded_system.positions], + retain_graph=True, + create_graph=False, + )[0] + grad_i = torch.jit._unwrap_optional(grad_i) + term1[i] = (grad_i * velocities).sum() + + go = 
torch.jit.annotate( + Optional[List[Optional[torch.Tensor]]], [torch.ones_like(total_e)] + ) + grads = torch.autograd.grad( + [total_e], + [unfolded_system.positions], + grad_outputs=go, + )[0] + grads = torch.jit._unwrap_optional(grads) + term2 = ( + unfolded_system.positions * (grads * velocities).sum(dim=1, keepdim=True) + ).sum(dim=0) + + hf_pot = term1 - term2 + + hf_conv = ( + ( + atomic_e[:n_atoms] + + 0.5 + * masses[:n_atoms] + * torch.linalg.norm(velocities[:n_atoms], dim=1) ** 2 + * 103.6427 # u*A^2/fs^2 to eV + )[:, None] + * velocities[:n_atoms] + ).sum(dim=0) + + return hf_pot + hf_conv + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + ) -> Dict[str, TensorMap]: + + run_options = ModelEvaluationOptions( + length_unit=self._model.capabilities().length_unit, + outputs=outputs, + selected_atoms=None, + ) + results = self._model(systems, run_options, False) + + if "extra::heat_flux" not in outputs: + return results + + device = systems[0].device + heat_fluxes: List[torch.Tensor] = [] + for system in systems: + heat_fluxes.append(self.calc_unfolded_heat_flux(system)) + + samples = Labels( + ["system"], torch.arange(len(systems), device=device).reshape(-1, 1) + ) + + hf_block = TensorBlock( + values=torch.vstack(heat_fluxes).reshape(-1, 3, 1).to(device=device), + samples=samples, + components=[Labels(["xyz"], torch.arange(3, device=device).reshape(-1, 1))], + properties=Labels(["heat_flux"], torch.tensor([[0]], device=device)), + ) + results["extra::heat_flux"] = TensorMap( + Labels("_", torch.tensor([[0]], device=device)), [hf_block] + ) + return results From 2decdf49b7721c05bbdc86188baa9fcba71384cf Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Mon, 2 Feb 2026 12:16:26 +0100 Subject: [PATCH 02/22] Add documentations and agent-generated tests --- .../metatomic/torch/heat_flux.py | 55 ++- .../metatomic_torch/tests/test_heat_flux.py | 330 ++++++++++++++++++ 2 files changed, 378 
insertions(+), 7 deletions(-) create mode 100644 python/metatomic_torch/tests/test_heat_flux.py diff --git a/python/metatomic_torch/metatomic/torch/heat_flux.py b/python/metatomic_torch/metatomic/torch/heat_flux.py index 44197ab5..71317f71 100644 --- a/python/metatomic_torch/metatomic/torch/heat_flux.py +++ b/python/metatomic_torch/metatomic/torch/heat_flux.py @@ -15,6 +15,9 @@ def wrap_positions(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: + """ + Wrap positions into the periodic cell. + """ fractional_positions = torch.einsum("iv,kv->ik", positions, cell.inverse()) fractional_positions -= torch.floor(fractional_positions) wrapped_positions = torch.einsum("iv,kv->ik", fractional_positions, cell) @@ -23,25 +26,51 @@ def wrap_positions(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: def check_collisions( - cell: torch.Tensor, positions: torch.Tensor, cutoff: float + cell: torch.Tensor, positions: torch.Tensor, cutoff: float, skin: float ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Detect atoms that lie within a cutoff distance from the periodic cell boundaries, + i.e. have interactions with atoms at the opposite end of the cell. + """ inv_cell = cell.inverse() norm_inv_cell = torch.linalg.norm(inv_cell, dim=1) inv_cell /= norm_inv_cell[:, None] - norm_coords = torch.einsum("iv,kv->ik", positions, inv_cell) cell_vec_lengths = torch.diag(cell @ inv_cell) + if cell_vec_lengths.min() < 2 * (cutoff + skin): + raise ValueError( + f"Cell is too small compared to {(cutoff + skin) = }. " + "Ensure that all cell vectors are at least twice the length." 
+ ) + + cutoff += skin + norm_coords = torch.einsum("iv,kv->ik", positions, inv_cell) collisions = torch.hstack( [norm_coords <= cutoff, norm_coords >= cell_vec_lengths - cutoff], ).to(device=positions.device) - return collisions[:, [0, 3, 1, 4, 2, 5]], norm_coords + return ( + collisions[ + :, [0, 3, 1, 4, 2, 5] # reorder to (x_lo, x_hi, y_lo, y_hi, z_lo, z_hi) + ], + norm_coords, + ) def collisions_to_replicas(collisions: torch.Tensor) -> torch.Tensor: """ - Convert collisions to replicas. + Convert boundary-collision flags into a boolean mask over all periodic image + displacements in {0, +1, -1}^3. e.g. for an atom colliding with the x_lo and y_hi + boundaries, we need the replicas at (1, 0, 0), (0, -1, 0), (1, -1, 0) image cells. collisions: [N, 6]: has collisions with (x_lo, x_hi, y_lo, y_hi, z_lo, z_hi) + + returns: [N, 3, 3, 3] boolean mask over image displacements in {0, +1, -1}^3 + 0: no replica needed along that axis + 1: +1 replica needed along that axis (i.e., near low boundary, a replica is + placed just outside the high boundary) + 2: -1 replica needed along that axis (i.e., near high boundary, a replica is + placed just outside the low boundary) + axis order: x, y, z """ origin = torch.full( (len(collisions),), True, dtype=torch.bool, device=collisions.device @@ -52,7 +81,7 @@ def collisions_to_replicas(collisions: torch.Tensor) -> torch.Tensor: # leverage broadcasting outs = axs[:, None, None] & ays[None, :, None] & azs[None, None, :] outs = torch.movedim(outs, -1, 0) - outs[:, 0, 0, 0] = False + outs[:, 0, 0, 0] = False # not close to any boundary -> no replica needed return outs.to(device=collisions.device) @@ -62,6 +91,12 @@ def generate_replica_atoms( cell: torch.Tensor, replicas: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + For atoms near the low boundary (x_lo/y_lo/z_lo), generate their images shifted + by +1 cell vector (i.e., placed just outside the high boundary). 
+ For atoms near the high boundary (x_hi/y_hi/z_hi), generate images shifted by −1 + cell vector. + """ replicas = torch.argwhere(replicas) replica_idx = replicas[:, 0] replica_offsets = torch.tensor( @@ -73,12 +108,18 @@ def generate_replica_atoms( return replica_idx, types[replica_idx], replica_positions -def unfold_system(metatomic_system: System, cutoff: float) -> System: +def unfold_system(metatomic_system: System, cutoff: float, skin: float = 0.5) -> System: + """ + Unfold a periodic system by generating replica atoms for those near the cell + boundaries within the specified cutoff distance. + The unfolded system has no periodic boundary conditions. + """ + wrapped_positions = wrap_positions( metatomic_system.positions, metatomic_system.cell ) collisions, _ = check_collisions( - metatomic_system.cell, wrapped_positions, cutoff + 0.5 + metatomic_system.cell, wrapped_positions, cutoff, skin ) replicas = collisions_to_replicas(collisions) replica_idx, replica_types, replica_positions = generate_replica_atoms( diff --git a/python/metatomic_torch/tests/test_heat_flux.py b/python/metatomic_torch/tests/test_heat_flux.py new file mode 100644 index 00000000..37b1f0a9 --- /dev/null +++ b/python/metatomic_torch/tests/test_heat_flux.py @@ -0,0 +1,330 @@ +import pytest +import torch + +from metatensor.torch import Labels, TensorBlock, TensorMap + +from metatomic.torch import ModelOutput, System + +from metatomic.torch.heat_flux import ( + HeatFluxWrapper, + check_collisions, + collisions_to_replicas, + generate_replica_atoms, + unfold_system, + wrap_positions, +) + + +def _make_scalar_tensormap(values: torch.Tensor, property_name: str) -> TensorMap: + block = TensorBlock( + values=values, + samples=Labels( + ["atoms"], + torch.arange(values.shape[0], device=values.device).reshape(-1, 1), + ), + components=[], + properties=Labels([property_name], torch.tensor([[0]], device=values.device)), + ) + return TensorMap(Labels("_", torch.tensor([[0]], device=values.device)), 
[block]) + + +def _make_velocity_tensormap(values: torch.Tensor) -> TensorMap: + block = TensorBlock( + values=values, + samples=Labels( + ["atoms"], + torch.arange(values.shape[0], device=values.device).reshape(-1, 1), + ), + components=[ + Labels( + ["xyz"], + torch.arange(3, device=values.device).reshape(-1, 1), + ) + ], + properties=Labels(["velocity"], torch.tensor([[0]], device=values.device)), + ) + return TensorMap(Labels("_", torch.tensor([[0]], device=values.device)), [block]) + + +def _make_system_with_data(positions: torch.Tensor, cell: torch.Tensor) -> System: + types = torch.tensor([1] * len(positions), dtype=torch.int32) + system = System( + types=types, + positions=positions, + cell=cell, + pbc=torch.tensor([True, True, True]), + ) + masses = torch.ones((len(positions), 1), dtype=positions.dtype) + velocities = torch.zeros((len(positions), 3, 1), dtype=positions.dtype) + system.add_data("masses", _make_scalar_tensormap(masses, "mass")) + system.add_data("velocities", _make_velocity_tensormap(velocities)) + return system + + +def test_wrap_positions_cubic_matches_expected(): + cell = torch.eye(3) * 2.0 + positions = torch.tensor([[-0.1, 0.0, 0.0], [2.1, 1.0, -0.5]]) + wrapped = wrap_positions(positions, cell) + expected = torch.tensor([[1.9, 0.0, 0.0], [0.1, 1.0, 1.5]]) + assert torch.allclose(wrapped, expected) + + +def test_check_collisions_cubic_axis_order(): + cell = torch.eye(3) * 2.0 + positions = torch.tensor([[0.1, 1.0, 1.9]]) + collisions, norm_coords = check_collisions(cell, positions, cutoff=0.2, skin=0.0) + assert torch.allclose(norm_coords, positions) + assert collisions.shape == (1, 6) + assert collisions[0].tolist() == [True, False, False, False, False, True] + + +def test_generate_replica_atoms_cubic_offsets(): + types = torch.tensor([1]) + positions = torch.tensor([[0.1, 1.0, 1.0]]) + cell = torch.eye(3) * 2.0 + collisions = torch.tensor([[True, False, False, False, False, False]]) + replicas = collisions_to_replicas(collisions) + 
replica_idx, replica_types, replica_positions = generate_replica_atoms( + types, positions, cell, replicas + ) + assert replica_idx.tolist() == [0] + assert replica_types.tolist() == [1] + assert torch.allclose( + replica_positions, positions + torch.tensor([[2.0, 0.0, 0.0]]) + ) + + +def test_wrap_positions_triclinic_fractional_bounds_and_shift(): + cell = torch.tensor( + [ + [2.0, 0.3, 0.2], + [0.1, 1.7, 0.4], + [0.2, 0.5, 1.9], + ] + ) + positions = torch.tensor( + [ + [-0.1, 0.0, 0.0], + [2.1, 1.6, -0.5], + [4.2, -0.2, 6.1], + ] + ) + inv_cell = cell.inverse() + wrapped = wrap_positions(positions, cell) + fractional_before = torch.einsum("iv,kv->ik", positions, inv_cell) + fractional_after = torch.einsum("iv,kv->ik", wrapped, inv_cell) + + assert torch.all(fractional_after >= 0) + assert torch.all(fractional_after < 1) + + delta_frac = fractional_after - fractional_before + rounded = torch.round(delta_frac) + assert torch.allclose(delta_frac, rounded, atol=1e-6, rtol=0) + assert torch.allclose(rounded, -torch.floor(fractional_before), atol=1e-6, rtol=0) + + +def test_check_collisions_triclinic_targets(): + cell = torch.tensor( + [ + [2.0, 0.3, 0.2], + [0.1, 1.7, 0.4], + [0.2, 0.5, 1.9], + ] + ) + cutoff = 0.2 + inv_cell = cell.inverse() + inv_cell_norm = inv_cell / torch.linalg.norm(inv_cell, dim=1)[:, None] + cell_vec_lengths = torch.diag(cell @ inv_cell_norm) + + target = torch.stack( + [ + torch.tensor([0.05, 0.6, 0.6]), + torch.tensor([cell_vec_lengths[0] - 0.05, 0.05, cell_vec_lengths[2] - 0.1]), + torch.tensor([0.3, cell_vec_lengths[1] - 0.05, 0.1]), + ] + ) + positions = target @ torch.inverse(inv_cell_norm).T + + collisions, norm_coords = check_collisions(cell, positions, cutoff=cutoff, skin=0.0) + assert torch.allclose(norm_coords, target, atol=1e-6, rtol=0) + + expected_low = target <= cutoff + expected_high = target >= cell_vec_lengths - cutoff + expected = torch.hstack([expected_low, expected_high]) + expected = expected[:, [0, 3, 1, 4, 2, 5]] + + 
assert torch.equal(collisions, expected) + + +def test_check_collisions_raises_on_small_cell(): + cell = torch.eye(3) * 1.0 + positions = torch.zeros((1, 3)) + with pytest.raises(ValueError, match="Cell is too small"): + check_collisions(cell, positions, cutoff=0.4, skin=0.2) + + +def test_collisions_to_replicas_combines_displacements(): + collisions = torch.tensor([[True, False, False, True, False, False]]) + replicas = collisions_to_replicas(collisions) + assert replicas.shape == (1, 3, 3, 3) + assert replicas[0, 0, 0, 0].item() is False + + nonzero = torch.nonzero(replicas, as_tuple=False) + expected = { + (0, 1, 0, 0), + (0, 0, 2, 0), + (0, 1, 2, 0), + } + assert {tuple(row.tolist()) for row in nonzero} == expected + + +def test_generate_replica_atoms_triclinic_offsets(): + cell = torch.tensor( + [ + [2.0, 0.3, 0.2], + [0.1, 1.7, 0.4], + [0.2, 0.5, 1.9], + ] + ) + types = torch.tensor([1]) + positions = torch.tensor([[0.2, 0.4, 0.6]]) + collisions = torch.tensor([[True, False, False, True, False, False]]) + replicas = collisions_to_replicas(collisions) + replica_idx, replica_types, replica_positions = generate_replica_atoms( + types, positions, cell, replicas + ) + + assert replica_idx.tolist() == [0, 0, 0] + assert replica_types.tolist() == [1, 1, 1] + + expected_offsets = [cell[:, 0], -cell[:, 1], cell[:, 0] - cell[:, 1]] + expected_positions = [positions[0] + offset for offset in expected_offsets] + + for expected in expected_positions: + assert any(torch.allclose(expected, actual, atol=1e-6, rtol=0) for actual in replica_positions) + + +def test_unfold_system_adds_replica_and_data(): + cell = torch.eye(3) * 2.0 + positions = torch.tensor([[0.1, 1.0, 1.0]]) + system = _make_system_with_data(positions, cell) + unfolded = unfold_system(system, cutoff=0.1) + + assert len(unfolded.positions) == 2 + assert torch.all(unfolded.pbc == torch.tensor([False, False, False])) + assert torch.allclose(unfolded.cell, torch.zeros_like(unfolded.cell)) + + masses = 
unfolded.get_data("masses").block().values + velocities = unfolded.get_data("velocities").block().values + assert masses.shape[0] == 2 + assert velocities.shape[0] == 2 + + assert torch.allclose(unfolded.positions[0], positions[0]) + assert torch.allclose( + unfolded.positions[1], positions[0] + torch.tensor([2.0, 0.0, 0.0]) + ) + + +def test_heat_flux_wrapper_requested_inputs(): + class DummyCapabilities: + def __init__(self): + self.outputs = {"energy": ModelOutput(quantity="energy", unit="eV")} + self.length_unit = "A" + self.interaction_range = 1.0 + + class DummyModel: + def __init__(self): + self._capabilities = DummyCapabilities() + + def capabilities(self): + return self._capabilities + + def __call__(self, systems, options, check_consistency): + results = {} + if "energy" in options.outputs: + values = torch.zeros((len(systems), 1), dtype=systems[0].positions.dtype) + block = TensorBlock( + values=values, + samples=Labels( + ["system"], + torch.arange(len(systems), device=values.device).reshape(-1, 1), + ), + components=[], + properties=Labels(["energy"], torch.tensor([[0]], device=values.device)), + ) + results["energy"] = TensorMap( + Labels("_", torch.tensor([[0]], device=values.device)), [block] + ) + return results + + wrapper = HeatFluxWrapper(DummyModel()) + requested = wrapper.requested_inputs() + assert set(requested.keys()) == {"masses", "velocities"} + + +def test_heat_flux_wrapper_forward_adds_output(monkeypatch): + class DummyCapabilities: + def __init__(self): + self.outputs = {"energy": ModelOutput(quantity="energy", unit="eV")} + self.length_unit = "A" + self.interaction_range = 1.0 + + class DummyModel: + def __init__(self): + self._capabilities = DummyCapabilities() + + def capabilities(self): + return self._capabilities + + def __call__(self, systems, options, check_consistency): + values = torch.zeros((len(systems), 1), dtype=systems[0].positions.dtype) + block = TensorBlock( + values=values, + samples=Labels( + ["system"], + 
torch.arange(len(systems), device=values.device).reshape(-1, 1), + ), + components=[], + properties=Labels(["energy"], torch.tensor([[0]], device=values.device)), + ) + return { + "energy": TensorMap( + Labels("_", torch.tensor([[0]], device=values.device)), [block] + ) + } + + def _fake_hf(self, system): + return torch.tensor( + [1.0, 2.0, 3.0], device=system.device, dtype=system.positions.dtype + ) + + wrapper = HeatFluxWrapper(DummyModel()) + monkeypatch.setattr(HeatFluxWrapper, "calc_unfolded_heat_flux", _fake_hf) + + cell = torch.eye(3) + systems = [ + System( + types=torch.tensor([1], dtype=torch.int32), + positions=torch.zeros((1, 3)), + cell=cell, + pbc=torch.tensor([True, True, True]), + ), + System( + types=torch.tensor([1], dtype=torch.int32), + positions=torch.ones((1, 3)), + cell=cell, + pbc=torch.tensor([True, True, True]), + ), + ] + + outputs = { + "energy": ModelOutput(quantity="energy", unit="eV"), + "extra::heat_flux": ModelOutput(quantity="heat_flux", unit=""), + } + results = wrapper.forward(systems, outputs, None) + assert "extra::heat_flux" in results + hf_block = results["extra::heat_flux"].block() + assert hf_block.values.shape == (2, 3, 1) + assert torch.allclose( + hf_block.values[:, :, 0], torch.tensor([[1.0, 2.0, 3.0]] * 2) + ) From e4abd21a8199f6458ba607505214e9eedfd0ec7c Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Tue, 3 Feb 2026 16:15:42 +0100 Subject: [PATCH 03/22] More documentations and tests --- .../metatomic/torch/ase_calculator.py | 6 +- .../metatomic/torch/heat_flux.py | 47 ++++++-- .../metatomic_torch/tests/test_heat_flux.py | 107 ++++++++++++++++-- 3 files changed, 136 insertions(+), 24 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 1466e4bd..0787dc0a 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -297,9 +297,9 @@ def __init__( for 
name, output in additional_outputs.items(): assert isinstance(name, str) assert isinstance(output, torch.ScriptObject) - assert ( - "explicit_gradients_setter" in output._method_names() - ), "outputs must be ModelOutput instances" + assert "explicit_gradients_setter" in output._method_names(), ( + "outputs must be ModelOutput instances" + ) self._additional_output_requests = additional_outputs diff --git a/python/metatomic_torch/metatomic/torch/heat_flux.py b/python/metatomic_torch/metatomic/torch/heat_flux.py index 71317f71..5f2b5844 100644 --- a/python/metatomic_torch/metatomic/torch/heat_flux.py +++ b/python/metatomic_torch/metatomic/torch/heat_flux.py @@ -1,11 +1,9 @@ -import torch +from typing import Dict, List, Optional -from torch.autograd.functional import jvp -from typing import List, Dict, Optional +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap from vesin.metatomic import compute_requested_neighbors - -from metatensor.torch import Labels, TensorBlock, TensorMap from metatomic.torch import ( AtomisticModel, ModelEvaluationOptions, @@ -36,10 +34,13 @@ def check_collisions( norm_inv_cell = torch.linalg.norm(inv_cell, dim=1) inv_cell /= norm_inv_cell[:, None] cell_vec_lengths = torch.diag(cell @ inv_cell) - if cell_vec_lengths.min() < 2 * (cutoff + skin): + if cell_vec_lengths.min() < (cutoff + skin): raise ValueError( - f"Cell is too small compared to {(cutoff + skin) = }. " - "Ensure that all cell vectors are at least twice the length." + "Cell is too small compared to (cutoff + skin) = " + + str(cutoff + skin) + + ". " + "Ensure that all cell vectors are at least this length. Currently, the" + " minimum cell vector length is " + str(cell_vec_lengths.min()) + "." 
) cutoff += skin @@ -198,12 +199,32 @@ def unfold_system(metatomic_system: System, cutoff: float, skin: float = 0.5) -> class HeatFluxWrapper(torch.nn.Module): + """ + A wrapper around an AtomisticModel that computes the heat flux of a system using the + unfolded system approach. The heat flux is computed using the atomic energies (eV), + positions(Å), masses(u), velocities(Å/fs), and the energy gradients. - def __init__(self, model: AtomisticModel): + The unfolded system is generated by creating replica atoms for those near the cell + boundaries within the interaction range of the model wrapped. The wrapper adds the + heat flux to the model's outputs under the key "extra::heat_flux". + + For more details on the heat flux calculation, see `Langer, M. F., et al., Heat flux + for semilocal machine-learning potentials. (2023). Physical Review B, 108, L100302.` + """ + + def __init__(self, model: AtomisticModel, skin: float = 0.5): + """ + :param model: the :py:class:`AtomisticModel` to wrap, which should be able to + compute atomic energies and their gradients with respect to positions + :param skin: the skin parameter for unfolding the system. The wrapper will + generate replica atoms for those within (interaction_range + skin) distance from + the cell boundaries. A skin results in more replica atoms and thus higher + computational cost, but ensures that the heat flux is computed correctly. 
+ """ super().__init__() self._model = model - # TODO: throw error if the simulation cell is smaller than double the interaction range + self.skin = skin self._interaction_range = model.capabilities().interaction_range self._requested_inputs = { @@ -245,7 +266,9 @@ def barycenter_and_atomic_energies(self, system: System, n_atoms: int): def calc_unfolded_heat_flux(self, system: System) -> torch.Tensor: n_atoms = len(system.positions) - unfolded_system = unfold_system(system, self._interaction_range).to("cpu") + unfolded_system = unfold_system(system, self._interaction_range, self.skin).to( + "cpu" + ) compute_requested_neighbors( unfolded_system, self._unfolded_run_options.length_unit, model=self._model ) @@ -307,7 +330,6 @@ def forward( outputs: Dict[str, ModelOutput], selected_atoms: Optional[Labels], ) -> Dict[str, TensorMap]: - run_options = ModelEvaluationOptions( length_unit=self._model.capabilities().length_unit, outputs=outputs, @@ -321,6 +343,7 @@ def forward( device = systems[0].device heat_fluxes: List[torch.Tensor] = [] for system in systems: + system.positions.requires_grad_(True) heat_fluxes.append(self.calc_unfolded_heat_flux(system)) samples = Labels( diff --git a/python/metatomic_torch/tests/test_heat_flux.py b/python/metatomic_torch/tests/test_heat_flux.py index 37b1f0a9..cd957140 100644 --- a/python/metatomic_torch/tests/test_heat_flux.py +++ b/python/metatomic_torch/tests/test_heat_flux.py @@ -1,10 +1,19 @@ +import metatomic_lj_test +import numpy as np import pytest import torch - +from ase import Atoms +from ase.md.velocitydistribution import MaxwellBoltzmannDistribution from metatensor.torch import Labels, TensorBlock, TensorMap -from metatomic.torch import ModelOutput, System - +from metatomic.torch import ( + AtomisticModel, + ModelCapabilities, + ModelMetadata, + ModelOutput, + System, +) +from metatomic.torch.ase_calculator import MetatomicCalculator from metatomic.torch.heat_flux import ( HeatFluxWrapper, check_collisions, @@ -15,6 +24,32 
@@ ) +@pytest.fixture +def model(): + return metatomic_lj_test.lennard_jones_model( + atomic_type=18, + cutoff=7.0, + sigma=3.405, + epsilon=0.01032, + length_unit="Angstrom", + energy_unit="eV", + with_extension=False, + ) + + +@pytest.fixture +def atoms(): + n_atoms = 250 + cell = np.array([[20.3, 0.0, 0.0], [0.0, 20.3, 0.0], [0.0, 0.0, 20.3]]) + np.random.seed(42) + positions = np.random.random((n_atoms, 3)) * (1 + 2 * 0.1) - 0.1 + atoms = Atoms(f"Ar{n_atoms}", scaled_positions=positions, cell=cell, pbc=True) + MaxwellBoltzmannDistribution( + atoms, temperature_K=300, rng=np.random.default_rng(42) + ) + return atoms + + def _make_scalar_tensormap(values: torch.Tensor, property_name: str) -> TensorMap: block = TensorBlock( values=values, @@ -160,7 +195,7 @@ def test_check_collisions_raises_on_small_cell(): cell = torch.eye(3) * 1.0 positions = torch.zeros((1, 3)) with pytest.raises(ValueError, match="Cell is too small"): - check_collisions(cell, positions, cutoff=0.4, skin=0.2) + check_collisions(cell, positions, cutoff=0.9, skin=0.2) def test_collisions_to_replicas_combines_displacements(): @@ -201,7 +236,10 @@ def test_generate_replica_atoms_triclinic_offsets(): expected_positions = [positions[0] + offset for offset in expected_offsets] for expected in expected_positions: - assert any(torch.allclose(expected, actual, atol=1e-6, rtol=0) for actual in replica_positions) + assert any( + torch.allclose(expected, actual, atol=1e-6, rtol=0) + for actual in replica_positions + ) def test_unfold_system_adds_replica_and_data(): @@ -242,7 +280,9 @@ def capabilities(self): def __call__(self, systems, options, check_consistency): results = {} if "energy" in options.outputs: - values = torch.zeros((len(systems), 1), dtype=systems[0].positions.dtype) + values = torch.zeros( + (len(systems), 1), dtype=systems[0].positions.dtype + ) block = TensorBlock( values=values, samples=Labels( @@ -250,7 +290,9 @@ def __call__(self, systems, options, check_consistency): 
torch.arange(len(systems), device=values.device).reshape(-1, 1), ), components=[], - properties=Labels(["energy"], torch.tensor([[0]], device=values.device)), + properties=Labels( + ["energy"], torch.tensor([[0]], device=values.device) + ), ) results["energy"] = TensorMap( Labels("_", torch.tensor([[0]], device=values.device)), [block] @@ -285,7 +327,9 @@ def __call__(self, systems, options, check_consistency): torch.arange(len(systems), device=values.device).reshape(-1, 1), ), components=[], - properties=Labels(["energy"], torch.tensor([[0]], device=values.device)), + properties=Labels( + ["energy"], torch.tensor([[0]], device=values.device) + ), ) return { "energy": TensorMap( @@ -325,6 +369,51 @@ def _fake_hf(self, system): assert "extra::heat_flux" in results hf_block = results["extra::heat_flux"].block() assert hf_block.values.shape == (2, 3, 1) + assert torch.allclose(hf_block.values[:, :, 0], torch.tensor([[1.0, 2.0, 3.0]] * 2)) + + +def test_heat_flux_wrapper_calc_unfolded_heat_flux(model, atoms): + metadata = ModelMetadata() + wrapper = HeatFluxWrapper(model.eval()) + cap = wrapper._model.capabilities() + outputs = cap.outputs.copy() + outputs["extra::heat_flux"] = ModelOutput( + quantity="heat_flux", + unit="", + explicit_gradients=[], + per_atom=False, + ) + + new_cap = ModelCapabilities( + outputs=outputs, + atomic_types=cap.atomic_types, + interaction_range=cap.interaction_range, + length_unit=cap.length_unit, + supported_devices=cap.supported_devices, + dtype=cap.dtype, + ) + heat_model = AtomisticModel(wrapper.eval(), metadata, capabilities=new_cap).to( + device="cpu" + ) + calc = MetatomicCalculator( + heat_model, + device="cpu", + additional_outputs={ + "extra::heat_flux": ModelOutput( + quantity="heat_flux", + unit="", + explicit_gradients=[], + per_atom=False, + ) + }, + ) + atoms.calc = calc + atoms.get_potential_energy() + assert "extra::heat_flux" in atoms.calc.additional_outputs + results = 
atoms.calc.additional_outputs["extra::heat_flux"].block().values assert torch.allclose( - hf_block.values[:, :, 0], torch.tensor([[1.0, 2.0, 3.0]] * 2) + results, + torch.tensor( + [[5.50695568e12], [2.89550111e13], [-1.64821616e13]], dtype=results.dtype + ), ) From 0954a05b1182ad9198758c2d4a1bcd1fff3bf837 Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Wed, 4 Feb 2026 15:05:51 +0100 Subject: [PATCH 04/22] Minor --- python/metatomic_torch/metatomic/torch/heat_flux.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/metatomic_torch/metatomic/torch/heat_flux.py b/python/metatomic_torch/metatomic/torch/heat_flux.py index 5f2b5844..8bbd4cdf 100644 --- a/python/metatomic_torch/metatomic/torch/heat_flux.py +++ b/python/metatomic_torch/metatomic/torch/heat_flux.py @@ -241,7 +241,10 @@ def __init__(self, model: AtomisticModel, skin: float = 0.5): outputs = self._model.capabilities().outputs.copy() outputs["extra::heat_flux"] = hf_output self._model.capabilities().outputs["extra::heat_flux"] = hf_output - + if outputs["energy"].unit != "eV": + raise ValueError( + "HeatFluxWrapper can only be used with energy outputs in eV" + ) energies_output = ModelOutput( quantity="energy", unit=outputs["energy"].unit, per_atom=True ) From 3f7f07b4f9eed2022ca40f5a6af2a7419a19c586 Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Sat, 7 Feb 2026 18:29:43 +0100 Subject: [PATCH 05/22] Add `HardyHeatFluxWrapper` for comparison --- .../metatomic/torch/heat_flux.py | 146 ++++++++++++++++++ 1 file changed, 146 insertions(+) diff --git a/python/metatomic_torch/metatomic/torch/heat_flux.py b/python/metatomic_torch/metatomic/torch/heat_flux.py index 8bbd4cdf..bebe842a 100644 --- a/python/metatomic_torch/metatomic/torch/heat_flux.py +++ b/python/metatomic_torch/metatomic/torch/heat_flux.py @@ -363,3 +363,149 @@ def forward( Labels("_", torch.tensor([[0]], device=device)), [hf_block] ) return results + + +class HardyHeatFluxWrapper(torch.nn.Module): + """ + A wrapper around 
an AtomisticModel that computes the heat flux of a system using the + unfolded system approach. The heat flux is computed using the atomic energies (eV), + positions(Å), masses(u), velocities(Å/fs), and the energy gradients. + + The unfolded system is generated by creating replica atoms for those near the cell + boundaries within the interaction range of the model wrapped. The wrapper adds the + heat flux to the model's outputs under the key "extra::heat_flux". + + For more details on the heat flux calculation, see `Langer, M. F., et al., Heat flux + for semilocal machine-learning potentials. (2023). Physical Review B, 108, L100302.` + """ + + def __init__(self, model: AtomisticModel, skin: float = 0.5): + """ + :param model: the :py:class:`AtomisticModel` to wrap, which should be able to + compute atomic energies and their gradients with respect to positions + :param skin: the skin parameter for unfolding the system. The wrapper will + generate replica atoms for those within (interaction_range + skin) distance from + the cell boundaries. A skin results in more replica atoms and thus higher + computational cost, but ensures that the heat flux is computed correctly. 
+ """ + super().__init__() + + self._model = model + self.skin = skin + self._interaction_range = model.capabilities().interaction_range + + self._requested_inputs = { + "masses": ModelOutput(quantity="mass", unit="u", per_atom=True), + "velocities": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True), + } + + hf_output = ModelOutput( + quantity="heat_flux", + unit="", + explicit_gradients=[], + per_atom=False, + ) + outputs = self._model.capabilities().outputs.copy() + outputs["extra::heat_flux"] = hf_output + self._model.capabilities().outputs["extra::heat_flux"] = hf_output + if outputs["energy"].unit != "eV": + raise ValueError( + "HeatFluxWrapper can only be used with energy outputs in eV" + ) + energies_output = ModelOutput( + quantity="energy", unit=outputs["energy"].unit, per_atom=True + ) + self._unfolded_run_options = ModelEvaluationOptions( + length_unit=self._model.capabilities().length_unit, + outputs={"energy": energies_output}, + selected_atoms=None, + ) + + def requested_inputs(self) -> Dict[str, ModelOutput]: + return self._requested_inputs + + def barycenter_and_atomic_energies(self, system: System, n_atoms: int): + atomic_e = self._model([system], self._unfolded_run_options, False)["energy"][ + 0 + ].values.flatten() + total_e = atomic_e[:n_atoms].sum() + r_aux = system.positions.detach() + barycenter = torch.einsum("i,ik->k", atomic_e[:n_atoms], r_aux[:n_atoms]) + + return barycenter, atomic_e, total_e + + def calc_hardy_heat_flux(self, system: System) -> torch.Tensor: + n_atoms = len(system.positions) + velocities: torch.Tensor = ( + system.get_data("velocities").block().values.reshape(-1, 3) + ) + masses: torch.Tensor = ( + system.get_data("masses").block().values.reshape(-1) + ) + atomic_e = self._model([system], self.run_options, False)["energy"][ + 0 + ].values.flatten() + + hf_pot = 0 + for i, energy in enumerate(atomic_e): + hf_pot += ( + wrap_positions( + system.positions[i] - system.positions, system.cell + ) + * ( + 
torch.autograd.grad( + energy, system.positions, retain_graph=True if i != len(atomic_e) - 1 else False + )[0] + * velocities + ).sum(axis=1, keepdim=True) + ).sum(axis=0) + + hf_conv = ( + ( + atomic_e[:n_atoms] + + 0.5 + * masses[:n_atoms] + * torch.linalg.norm(velocities[:n_atoms], dim=1) ** 2 + * 103.6427 # u*A^2/fs^2 to eV + )[:, None] + * velocities[:n_atoms] + ).sum(dim=0) + + return hf_pot + hf_conv + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + ) -> Dict[str, TensorMap]: + run_options = ModelEvaluationOptions( + length_unit=self._model.capabilities().length_unit, + outputs=outputs, + selected_atoms=None, + ) + results = self._model(systems, run_options, False) + + if "extra::heat_flux" not in outputs: + return results + + device = systems[0].device + heat_fluxes: List[torch.Tensor] = [] + for system in systems: + system.positions.requires_grad_(True) + heat_fluxes.append(self.calc_hardy_heat_flux(system)) + + samples = Labels( + ["system"], torch.arange(len(systems), device=device).reshape(-1, 1) + ) + + hf_block = TensorBlock( + values=torch.vstack(heat_fluxes).reshape(-1, 3, 1).to(device=device), + samples=samples, + components=[Labels(["xyz"], torch.arange(3, device=device).reshape(-1, 1))], + properties=Labels(["heat_flux"], torch.tensor([[0]], device=device)), + ) + results["extra::heat_flux"] = TensorMap( + Labels("_", torch.tensor([[0]], device=device)), [hf_block] + ) + return results \ No newline at end of file From 7dfb8bf10474da752ce52e6b08f066d98015484e Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Tue, 10 Feb 2026 22:21:55 +0100 Subject: [PATCH 06/22] Fix bugs in non-orthogonal boxes --- .../metatomic/torch/heat_flux.py | 69 +++++-------- .../metatomic_torch/tests/test_heat_flux.py | 96 +++++++++++++++++-- 2 files changed, 112 insertions(+), 53 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/heat_flux.py 
b/python/metatomic_torch/metatomic/torch/heat_flux.py index bebe842a..9a422532 100644 --- a/python/metatomic_torch/metatomic/torch/heat_flux.py +++ b/python/metatomic_torch/metatomic/torch/heat_flux.py @@ -16,9 +16,9 @@ def wrap_positions(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: """ Wrap positions into the periodic cell. """ - fractional_positions = torch.einsum("iv,kv->ik", positions, cell.inverse()) + fractional_positions = torch.einsum("iv,vk->ik", positions, cell.inverse()) fractional_positions -= torch.floor(fractional_positions) - wrapped_positions = torch.einsum("iv,kv->ik", fractional_positions, cell) + wrapped_positions = torch.einsum("iv,vk->ik", fractional_positions, cell) return wrapped_positions @@ -31,22 +31,23 @@ def check_collisions( i.e. have interactions with atoms at the opposite end of the cell. """ inv_cell = cell.inverse() - norm_inv_cell = torch.linalg.norm(inv_cell, dim=1) - inv_cell /= norm_inv_cell[:, None] - cell_vec_lengths = torch.diag(cell @ inv_cell) - if cell_vec_lengths.min() < (cutoff + skin): + recip = inv_cell.T + norms = torch.linalg.norm(recip, dim=1) + heights = 1.0 / norms + if heights.min() < (cutoff + skin): raise ValueError( "Cell is too small compared to (cutoff + skin) = " + str(cutoff + skin) + ". " "Ensure that all cell vectors are at least this length. Currently, the" - " minimum cell vector length is " + str(cell_vec_lengths.min()) + "." + " minimum cell vector length is " + str(heights.min()) + "." 
) cutoff += skin - norm_coords = torch.einsum("iv,kv->ik", positions, inv_cell) + normals = recip / norms[:, None] + norm_coords = torch.einsum("iv,kv->ik", positions, normals) collisions = torch.hstack( - [norm_coords <= cutoff, norm_coords >= cell_vec_lengths - cutoff], + [norm_coords <= cutoff, norm_coords >= heights - cutoff], ).to(device=positions.device) return ( @@ -104,7 +105,7 @@ def generate_replica_atoms( [0, 1, -1], device=positions.device, dtype=positions.dtype )[replicas[:, 1:]] replica_positions = positions[replica_idx] - replica_positions += torch.einsum("aA,iA->ia", cell, replica_offsets) + replica_positions += torch.einsum("iA,Aa->ia", replica_offsets, cell) return replica_idx, types[replica_idx], replica_positions @@ -370,13 +371,6 @@ class HardyHeatFluxWrapper(torch.nn.Module): A wrapper around an AtomisticModel that computes the heat flux of a system using the unfolded system approach. The heat flux is computed using the atomic energies (eV), positions(Å), masses(u), velocities(Å/fs), and the energy gradients. - - The unfolded system is generated by creating replica atoms for those near the cell - boundaries within the interaction range of the model wrapped. The wrapper adds the - heat flux to the model's outputs under the key "extra::heat_flux". - - For more details on the heat flux calculation, see `Langer, M. F., et al., Heat flux - for semilocal machine-learning potentials. (2023). 
Physical Review B, 108, L100302.` """ def __init__(self, model: AtomisticModel, skin: float = 0.5): @@ -424,41 +418,30 @@ def __init__(self, model: AtomisticModel, skin: float = 0.5): def requested_inputs(self) -> Dict[str, ModelOutput]: return self._requested_inputs - def barycenter_and_atomic_energies(self, system: System, n_atoms: int): - atomic_e = self._model([system], self._unfolded_run_options, False)["energy"][ - 0 - ].values.flatten() - total_e = atomic_e[:n_atoms].sum() - r_aux = system.positions.detach() - barycenter = torch.einsum("i,ik->k", atomic_e[:n_atoms], r_aux[:n_atoms]) - - return barycenter, atomic_e, total_e - def calc_hardy_heat_flux(self, system: System) -> torch.Tensor: n_atoms = len(system.positions) velocities: torch.Tensor = ( system.get_data("velocities").block().values.reshape(-1, 3) ) - masses: torch.Tensor = ( - system.get_data("masses").block().values.reshape(-1) - ) - atomic_e = self._model([system], self.run_options, False)["energy"][ + masses: torch.Tensor = system.get_data("masses").block().values.reshape(-1) + atomic_e = self._model([system], self._unfolded_run_options, False)["energy"][ 0 ].values.flatten() - hf_pot = 0 + hf_pot = torch.zeros( + 3, dtype=system.positions.dtype, device=system.positions.device + ) for i, energy in enumerate(atomic_e): + grad = torch.autograd.grad( + [energy], + [system.positions], + retain_graph=True if i != len(atomic_e) - 1 else False, + )[0] + grad = torch.jit._unwrap_optional(grad) hf_pot += ( - wrap_positions( - system.positions[i] - system.positions, system.cell - ) - * ( - torch.autograd.grad( - energy, system.positions, retain_graph=True if i != len(atomic_e) - 1 else False - )[0] - * velocities - ).sum(axis=1, keepdim=True) - ).sum(axis=0) + wrap_positions(system.positions[i] - system.positions, system.cell) + * (grad * velocities).sum(dim=1, keepdim=True) + ).sum(dim=0) hf_conv = ( ( @@ -508,4 +491,4 @@ def forward( results["extra::heat_flux"] = TensorMap( Labels("_", 
torch.tensor([[0]], device=device)), [hf_block] ) - return results \ No newline at end of file + return results diff --git a/python/metatomic_torch/tests/test_heat_flux.py b/python/metatomic_torch/tests/test_heat_flux.py index cd957140..a8c8eb88 100644 --- a/python/metatomic_torch/tests/test_heat_flux.py +++ b/python/metatomic_torch/tests/test_heat_flux.py @@ -15,6 +15,7 @@ ) from metatomic.torch.ase_calculator import MetatomicCalculator from metatomic.torch.heat_flux import ( + HardyHeatFluxWrapper, HeatFluxWrapper, check_collisions, collisions_to_replicas, @@ -39,11 +40,11 @@ def model(): @pytest.fixture def atoms(): - n_atoms = 250 - cell = np.array([[20.3, 0.0, 0.0], [0.0, 20.3, 0.0], [0.0, 0.0, 20.3]]) - np.random.seed(42) - positions = np.random.random((n_atoms, 3)) * (1 + 2 * 0.1) - 0.1 - atoms = Atoms(f"Ar{n_atoms}", scaled_positions=positions, cell=cell, pbc=True) + cell = np.array([[6.0, 0.0, 0.0], [0.0, 6.0, 0.0], [0.0, 0.0, 6.0]]) + positions = np.array([[3.0, 3.0, 3.0]]) + atoms = Atoms(f"Ar", scaled_positions=positions, cell=cell, pbc=True).repeat( + (2, 2, 2) + ) MaxwellBoltzmannDistribution( atoms, temperature_K=300, rng=np.random.default_rng(42) ) @@ -304,6 +305,76 @@ def __call__(self, systems, options, check_consistency): assert set(requested.keys()) == {"masses", "velocities"} +def test_unfolded_energy_order_used_for_barycenter(): + class DummyCapabilities: + def __init__(self): + self.outputs = {"energy": ModelOutput(quantity="energy", unit="eV")} + self.length_unit = "A" + self.interaction_range = 1.0 + + class DummyModel: + def __init__(self): + self._capabilities = DummyCapabilities() + + def capabilities(self): + return self._capabilities + + def __call__(self, systems, options, check_consistency): + system = systems[0] + n_atoms = len(system.positions) + values = torch.arange( + n_atoms, dtype=system.positions.dtype, device=system.positions.device + ).reshape(-1, 1) + block = TensorBlock( + values=values, + samples=Labels( + ["atoms"], + 
torch.arange(n_atoms, device=values.device).reshape(-1, 1), + ), + components=[], + properties=Labels( + ["energy"], torch.tensor([[0]], device=values.device) + ), + ) + return { + "energy": TensorMap( + Labels("_", torch.tensor([[0]], device=values.device)), [block] + ) + } + + cell = torch.eye(3) * 10.0 + positions = torch.tensor( + [ + [0.05, 5.0, 5.0], # near x_lo -> one replica + [9.95, 5.5, 5.0], # near x_hi -> one replica + [0.05, 6.0, 5.5], # near x_lo -> one replica + ] + ) + system = _make_system_with_data(positions, cell) + unfolded = unfold_system(system, cutoff=0.1, skin=0.0) + n_atoms = len(system.positions) + assert len(unfolded.positions) == n_atoms * 2 + + wrapper = HeatFluxWrapper(DummyModel()) + barycenter, atomic_e, total_e = wrapper.barycenter_and_atomic_energies( + unfolded, n_atoms + ) + + expected_atomic_e = torch.arange( + len(unfolded.positions), + dtype=unfolded.positions.dtype, + device=unfolded.positions.device, + ) + expected_total_e = expected_atomic_e[:n_atoms].sum() + expected_barycenter = torch.einsum( + "i,ik->k", expected_atomic_e[:n_atoms], unfolded.positions[:n_atoms] + ) + + assert torch.allclose(atomic_e, expected_atomic_e) + assert torch.allclose(total_e, expected_total_e) + assert torch.allclose(barycenter, expected_barycenter) + + def test_heat_flux_wrapper_forward_adds_output(monkeypatch): class DummyCapabilities: def __init__(self): @@ -372,9 +443,16 @@ def _fake_hf(self, system): assert torch.allclose(hf_block.values[:, :, 0], torch.tensor([[1.0, 2.0, 3.0]] * 2)) -def test_heat_flux_wrapper_calc_unfolded_heat_flux(model, atoms): +@pytest.mark.parametrize( + "heat_flux,expected", + [ + (HardyHeatFluxWrapper, [[8.1053e-05], [-1.2710e-05], [-2.8778e-04]]), + # (HeatFluxWrapper, [[4.0898e-05], [-3.1652e-04], [-2.1660e-04]]), + ], +) +def test_heat_flux_wrapper_calc_heat_flux(heat_flux, expected, model, atoms): metadata = ModelMetadata() - wrapper = HeatFluxWrapper(model.eval()) + wrapper = heat_flux(model.eval()) cap = 
wrapper._model.capabilities() outputs = cap.outputs.copy() outputs["extra::heat_flux"] = ModelOutput( @@ -413,7 +491,5 @@ def test_heat_flux_wrapper_calc_unfolded_heat_flux(model, atoms): results = atoms.calc.additional_outputs["extra::heat_flux"].block().values assert torch.allclose( results, - torch.tensor( - [[5.50695568e12], [2.89550111e13], [-1.64821616e13]], dtype=results.dtype - ), + torch.tensor(expected, dtype=results.dtype), ) From e57623f4c662229ca065f9eb9a4d84d9f530df39 Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Thu, 12 Feb 2026 11:10:40 +0100 Subject: [PATCH 07/22] Update tests --- .../metatomic_torch/tests/test_heat_flux.py | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/python/metatomic_torch/tests/test_heat_flux.py b/python/metatomic_torch/tests/test_heat_flux.py index a8c8eb88..78c65d76 100644 --- a/python/metatomic_torch/tests/test_heat_flux.py +++ b/python/metatomic_torch/tests/test_heat_flux.py @@ -147,8 +147,8 @@ def test_wrap_positions_triclinic_fractional_bounds_and_shift(): ) inv_cell = cell.inverse() wrapped = wrap_positions(positions, cell) - fractional_before = torch.einsum("iv,kv->ik", positions, inv_cell) - fractional_after = torch.einsum("iv,kv->ik", wrapped, inv_cell) + fractional_before = torch.einsum("iv,vk->ik", positions, inv_cell) + fractional_after = torch.einsum("iv,vk->ik", wrapped, inv_cell) assert torch.all(fractional_after >= 0) assert torch.all(fractional_after < 1) @@ -169,23 +169,25 @@ def test_check_collisions_triclinic_targets(): ) cutoff = 0.2 inv_cell = cell.inverse() - inv_cell_norm = inv_cell / torch.linalg.norm(inv_cell, dim=1)[:, None] - cell_vec_lengths = torch.diag(cell @ inv_cell_norm) + recip = inv_cell.T + norms = torch.linalg.norm(recip, dim=1) + heights = 1.0 / norms + norm_vectors = recip / norms[:, None] target = torch.stack( [ torch.tensor([0.05, 0.6, 0.6]), - torch.tensor([cell_vec_lengths[0] - 0.05, 0.05, cell_vec_lengths[2] - 0.1]), - torch.tensor([0.3, 
cell_vec_lengths[1] - 0.05, 0.1]), + torch.tensor([heights[0] - 0.05, 0.05, heights[2] - 0.1]), + torch.tensor([0.3, heights[1] - 0.05, 0.1]), ] ) - positions = target @ torch.inverse(inv_cell_norm).T + positions = target @ torch.inverse(norm_vectors).T collisions, norm_coords = check_collisions(cell, positions, cutoff=cutoff, skin=0.0) assert torch.allclose(norm_coords, target, atol=1e-6, rtol=0) expected_low = target <= cutoff - expected_high = target >= cell_vec_lengths - cutoff + expected_high = target >= heights - cutoff expected = torch.hstack([expected_low, expected_high]) expected = expected[:, [0, 3, 1, 4, 2, 5]] @@ -224,16 +226,16 @@ def test_generate_replica_atoms_triclinic_offsets(): ) types = torch.tensor([1]) positions = torch.tensor([[0.2, 0.4, 0.6]]) - collisions = torch.tensor([[True, False, False, True, False, False]]) + collisions = torch.tensor([[True, False, True, False, True, False]]) replicas = collisions_to_replicas(collisions) replica_idx, replica_types, replica_positions = generate_replica_atoms( types, positions, cell, replicas ) - assert replica_idx.tolist() == [0, 0, 0] - assert replica_types.tolist() == [1, 1, 1] + assert replica_idx.tolist() == [0, 0, 0, 0, 0, 0, 0] + assert replica_types.tolist() == [1, 1, 1, 1, 1, 1, 1] - expected_offsets = [cell[:, 0], -cell[:, 1], cell[:, 0] - cell[:, 1]] + expected_offsets = [cell[0], cell[1], cell[2], cell[0] + cell[1], cell[0] + cell[2], cell[1] + cell[2], cell[0] + cell[1] + cell[2]] expected_positions = [positions[0] + offset for offset in expected_offsets] for expected in expected_positions: @@ -446,8 +448,8 @@ def _fake_hf(self, system): @pytest.mark.parametrize( "heat_flux,expected", [ - (HardyHeatFluxWrapper, [[8.1053e-05], [-1.2710e-05], [-2.8778e-04]]), - # (HeatFluxWrapper, [[4.0898e-05], [-3.1652e-04], [-2.1660e-04]]), + # (HardyHeatFluxWrapper, [[4.0898e-05], [-3.1652e-04], [-2.1660e-04]]), + (HeatFluxWrapper, [[8.1053e-05], [-1.2710e-05], [-2.8778e-04]]), ], ) def 
test_heat_flux_wrapper_calc_heat_flux(heat_flux, expected, model, atoms): From 79f3f78cdfca54bc41f117483339e6a10c12934a Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Wed, 18 Feb 2026 13:38:37 +0100 Subject: [PATCH 08/22] Slightly faster version of VJP calculation of `term1` --- .../metatomic/torch/heat_flux.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/heat_flux.py b/python/metatomic_torch/metatomic/torch/heat_flux.py index 9a422532..fb9298ae 100644 --- a/python/metatomic_torch/metatomic/torch/heat_flux.py +++ b/python/metatomic_torch/metatomic/torch/heat_flux.py @@ -286,19 +286,19 @@ def calc_unfolded_heat_flux(self, system: System) -> torch.Tensor: barycenter, atomic_e, total_e = self.barycenter_and_atomic_energies( unfolded_system, n_atoms ) - - term1 = torch.zeros( - (3), device=system.positions.device, dtype=system.positions.dtype + basis = torch.eye( + 3, device=barycenter.device, dtype=barycenter.dtype ) - for i in range(3): - grad_i = torch.autograd.grad( - [barycenter[i]], - [unfolded_system.positions], - retain_graph=True, - create_graph=False, - )[0] - grad_i = torch.jit._unwrap_optional(grad_i) - term1[i] = (grad_i * velocities).sum() + term1_grads = torch.autograd.grad( + [barycenter], + [unfolded_system.positions], + grad_outputs=[basis], + retain_graph=True, + create_graph=False, + is_grads_batched=True, + ) + term1_grads = torch.jit._unwrap_optional(term1_grads[0]) + term1 = (term1_grads * velocities.unsqueeze(0)).sum(dim=(1, 2)) go = torch.jit.annotate( Optional[List[Optional[torch.Tensor]]], [torch.ones_like(total_e)] From e4093c20b184e8109fd8678238820c2a462d3e9c Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Mon, 23 Feb 2026 00:05:32 +0100 Subject: [PATCH 09/22] Torchscript-compatibility and more tests --- .../metatomic/torch/heat_flux.py | 25 +- .../metatomic_torch/tests/test_heat_flux.py | 260 ++++++++++++------ 2 files changed, 189 insertions(+), 96 
deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/heat_flux.py b/python/metatomic_torch/metatomic/torch/heat_flux.py index fb9298ae..19fa757b 100644 --- a/python/metatomic_torch/metatomic/torch/heat_flux.py +++ b/python/metatomic_torch/metatomic/torch/heat_flux.py @@ -255,6 +255,7 @@ def __init__(self, model: AtomisticModel, skin: float = 0.5): selected_atoms=None, ) + @torch.jit.export def requested_inputs(self) -> Dict[str, ModelOutput]: return self._requested_inputs @@ -286,19 +287,19 @@ def calc_unfolded_heat_flux(self, system: System) -> torch.Tensor: barycenter, atomic_e, total_e = self.barycenter_and_atomic_energies( unfolded_system, n_atoms ) - basis = torch.eye( - 3, device=barycenter.device, dtype=barycenter.dtype - ) - term1_grads = torch.autograd.grad( - [barycenter], - [unfolded_system.positions], - grad_outputs=[basis], - retain_graph=True, - create_graph=False, - is_grads_batched=True, + + term1 = torch.zeros( + (3), device=system.positions.device, dtype=system.positions.dtype ) - term1_grads = torch.jit._unwrap_optional(term1_grads[0]) - term1 = (term1_grads * velocities.unsqueeze(0)).sum(dim=(1, 2)) + for i in range(3): + grad_i = torch.autograd.grad( + [barycenter[i]], + [unfolded_system.positions], + retain_graph=True, + create_graph=False, + )[0] + grad_i = torch.jit._unwrap_optional(grad_i) + term1[i] = (grad_i * velocities).sum() go = torch.jit.annotate( Optional[List[Optional[torch.Tensor]]], [torch.ones_like(total_e)] diff --git a/python/metatomic_torch/tests/test_heat_flux.py b/python/metatomic_torch/tests/test_heat_flux.py index 78c65d76..f62455a8 100644 --- a/python/metatomic_torch/tests/test_heat_flux.py +++ b/python/metatomic_torch/tests/test_heat_flux.py @@ -15,7 +15,6 @@ ) from metatomic.torch.ase_calculator import MetatomicCalculator from metatomic.torch.heat_flux import ( - HardyHeatFluxWrapper, HeatFluxWrapper, check_collisions, collisions_to_replicas, @@ -97,6 +96,44 @@ def _make_system_with_data(positions: 
torch.Tensor, cell: torch.Tensor) -> Syste return system +class _DummyCapabilities: + """Reusable stub for ``model.capabilities()``.""" + + def __init__(self, energy_unit: str = "eV"): + self.outputs = {"energy": ModelOutput(quantity="energy", unit=energy_unit)} + self.length_unit = "A" + self.interaction_range = 1.0 + + +class _ZeroDummyModel: + """Dummy model returning zero energies. Accepts an optional *energy_unit*.""" + + def __init__(self, energy_unit: str = "eV"): + self._capabilities = _DummyCapabilities(energy_unit) + + def capabilities(self): + return self._capabilities + + def __call__(self, systems, options, check_consistency): + values = torch.zeros((len(systems), 1), dtype=systems[0].positions.dtype) + block = TensorBlock( + values=values, + samples=Labels( + ["system"], + torch.arange(len(systems), device=values.device).reshape(-1, 1), + ), + components=[], + properties=Labels( + ["energy"], torch.tensor([[0]], device=values.device) + ), + ) + return { + "energy": TensorMap( + Labels("_", torch.tensor([[0]], device=values.device)), [block] + ) + } + + def test_wrap_positions_cubic_matches_expected(): cell = torch.eye(3) * 2.0 positions = torch.tensor([[-0.1, 0.0, 0.0], [2.1, 1.0, -0.5]]) @@ -201,6 +238,21 @@ def test_check_collisions_raises_on_small_cell(): check_collisions(cell, positions, cutoff=0.9, skin=0.2) +def test_skin_parameter_affects_collisions(): + """Increasing the skin should extend the effective detection range.""" + cell = torch.eye(3) * 2.0 + # atom at distance 0.3 from the low-x boundary + positions = torch.tensor([[0.3, 1.0, 1.0]]) + + # cutoff=0.2, skin=0.0 → effective range 0.2 < 0.3 → no collision + collisions_no_skin, _ = check_collisions(cell, positions, cutoff=0.2, skin=0.0) + assert not collisions_no_skin.any() + + # cutoff=0.2, skin=0.2 → effective range 0.4 > 0.3 → x_lo collision + collisions_with_skin, _ = check_collisions(cell, positions, cutoff=0.2, skin=0.2) + assert collisions_with_skin[0, 0].item() # x_lo + + def 
test_collisions_to_replicas_combines_displacements(): collisions = torch.tensor([[True, False, False, True, False, False]]) replicas = collisions_to_replicas(collisions) @@ -266,57 +318,66 @@ def test_unfold_system_adds_replica_and_data(): ) -def test_heat_flux_wrapper_requested_inputs(): - class DummyCapabilities: - def __init__(self): - self.outputs = {"energy": ModelOutput(quantity="energy", unit="eV")} - self.length_unit = "A" - self.interaction_range = 1.0 +def test_unfold_system_no_replicas_for_interior_atoms(): + """Atoms well inside the cell should produce no replicas.""" + cell = torch.eye(3) * 10.0 + positions = torch.tensor([[5.0, 5.0, 5.0], [3.0, 4.0, 6.0]]) + system = _make_system_with_data(positions, cell) + unfolded = unfold_system(system, cutoff=1.0, skin=0.0) - class DummyModel: - def __init__(self): - self._capabilities = DummyCapabilities() + assert len(unfolded.positions) == 2 + assert torch.allclose(unfolded.positions, wrap_positions(positions, cell)) - def capabilities(self): - return self._capabilities - def __call__(self, systems, options, check_consistency): - results = {} - if "energy" in options.outputs: - values = torch.zeros( - (len(systems), 1), dtype=systems[0].positions.dtype - ) - block = TensorBlock( - values=values, - samples=Labels( - ["system"], - torch.arange(len(systems), device=values.device).reshape(-1, 1), - ), - components=[], - properties=Labels( - ["energy"], torch.tensor([[0]], device=values.device) - ), - ) - results["energy"] = TensorMap( - Labels("_", torch.tensor([[0]], device=values.device)), [block] - ) - return results +def test_unfold_system_triclinic_cell(): + """Unfolding should work for triclinic cells and propagate all data.""" + cell = torch.tensor( + [ + [4.0, 0.6, 0.4], + [0.2, 3.4, 0.8], + [0.4, 1.0, 3.8], + ] + ) + # One atom near the origin (close to low boundaries), one in the interior + positions = torch.tensor( + [ + [0.05, 0.05, 0.05], + [2.0, 1.7, 1.9], + ] + ) + system = 
_make_system_with_data(positions, cell) + unfolded = unfold_system(system, cutoff=0.3, skin=0.0) + + # The near-origin atom should generate at least one replica + assert len(unfolded.positions) > 2 + assert torch.all(unfolded.pbc == torch.tensor([False, False, False])) + assert torch.allclose(unfolded.cell, torch.zeros_like(unfolded.cell)) + assert torch.all(unfolded.types == 1) + assert unfolded.get_data("masses").block().values.shape[0] == len( + unfolded.positions + ) + assert unfolded.get_data("velocities").block().values.shape[0] == len( + unfolded.positions + ) + + +def test_heat_flux_wrapper_rejects_non_eV_energy(): + with pytest.raises(ValueError, match="energy outputs in eV"): + HeatFluxWrapper(_ZeroDummyModel(energy_unit="kcal/mol")) - wrapper = HeatFluxWrapper(DummyModel()) + +def test_heat_flux_wrapper_requested_inputs(): + wrapper = HeatFluxWrapper(_ZeroDummyModel()) requested = wrapper.requested_inputs() assert set(requested.keys()) == {"masses", "velocities"} def test_unfolded_energy_order_used_for_barycenter(): - class DummyCapabilities: - def __init__(self): - self.outputs = {"energy": ModelOutput(quantity="energy", unit="eV")} - self.length_unit = "A" - self.interaction_range = 1.0 + class _ArangeDummyModel: + """Returns per-atom energies [0, 1, 2, …] so ordering can be verified.""" - class DummyModel: def __init__(self): - self._capabilities = DummyCapabilities() + self._capabilities = _DummyCapabilities() def capabilities(self): return self._capabilities @@ -357,7 +418,7 @@ def __call__(self, systems, options, check_consistency): n_atoms = len(system.positions) assert len(unfolded.positions) == n_atoms * 2 - wrapper = HeatFluxWrapper(DummyModel()) + wrapper = HeatFluxWrapper(_ArangeDummyModel()) barycenter, atomic_e, total_e = wrapper.barycenter_and_atomic_energies( unfolded, n_atoms ) @@ -378,44 +439,12 @@ def __call__(self, systems, options, check_consistency): def test_heat_flux_wrapper_forward_adds_output(monkeypatch): - class 
DummyCapabilities: - def __init__(self): - self.outputs = {"energy": ModelOutput(quantity="energy", unit="eV")} - self.length_unit = "A" - self.interaction_range = 1.0 - - class DummyModel: - def __init__(self): - self._capabilities = DummyCapabilities() - - def capabilities(self): - return self._capabilities - - def __call__(self, systems, options, check_consistency): - values = torch.zeros((len(systems), 1), dtype=systems[0].positions.dtype) - block = TensorBlock( - values=values, - samples=Labels( - ["system"], - torch.arange(len(systems), device=values.device).reshape(-1, 1), - ), - components=[], - properties=Labels( - ["energy"], torch.tensor([[0]], device=values.device) - ), - ) - return { - "energy": TensorMap( - Labels("_", torch.tensor([[0]], device=values.device)), [block] - ) - } - def _fake_hf(self, system): return torch.tensor( [1.0, 2.0, 3.0], device=system.device, dtype=system.positions.dtype ) - wrapper = HeatFluxWrapper(DummyModel()) + wrapper = HeatFluxWrapper(_ZeroDummyModel()) monkeypatch.setattr(HeatFluxWrapper, "calc_unfolded_heat_flux", _fake_hf) cell = torch.eye(3) @@ -445,16 +474,32 @@ def _fake_hf(self, system): assert torch.allclose(hf_block.values[:, :, 0], torch.tensor([[1.0, 2.0, 3.0]] * 2)) -@pytest.mark.parametrize( - "heat_flux,expected", - [ - # (HardyHeatFluxWrapper, [[4.0898e-05], [-3.1652e-04], [-2.1660e-04]]), - (HeatFluxWrapper, [[8.1053e-05], [-1.2710e-05], [-2.8778e-04]]), - ], -) -def test_heat_flux_wrapper_calc_heat_flux(heat_flux, expected, model, atoms): +def test_forward_without_heat_flux_returns_model_results(): + """When ``extra::heat_flux`` is not requested, forward should return model + results unchanged and *not* invoke the heat-flux computation.""" + wrapper = HeatFluxWrapper(_ZeroDummyModel()) + + cell = torch.eye(3) + systems = [ + System( + types=torch.tensor([1], dtype=torch.int32), + positions=torch.zeros((1, 3)), + cell=cell, + pbc=torch.tensor([True, True, True]), + ), + ] + outputs = {"energy": 
ModelOutput(quantity="energy", unit="eV")} + results = wrapper.forward(systems, outputs, None) + + assert "energy" in results + assert "extra::heat_flux" not in results + + +def test_heat_flux_wrapper_calc_heat_flux(model, atoms): + expected = [[8.1053e-05], [-1.2710e-05], [-2.8778e-04]] + metadata = ModelMetadata() - wrapper = heat_flux(model.eval()) + wrapper = HeatFluxWrapper(model.eval()) cap = wrapper._model.capabilities() outputs = cap.outputs.copy() outputs["extra::heat_flux"] = ModelOutput( @@ -495,3 +540,50 @@ def test_heat_flux_wrapper_calc_heat_flux(heat_flux, expected, model, atoms): results, torch.tensor(expected, dtype=results.dtype), ) + + +def test_torch_scriptability(model, atoms): + expected = [[8.1053e-05], [-1.2710e-05], [-2.8778e-04]] + metadata = ModelMetadata() + wrapper = HeatFluxWrapper(model.eval()) + cap = wrapper._model.capabilities() + outputs = cap.outputs.copy() + outputs["extra::heat_flux"] = ModelOutput( + quantity="heat_flux", + unit="", + explicit_gradients=[], + per_atom=False, + ) + + new_cap = ModelCapabilities( + outputs=outputs, + atomic_types=cap.atomic_types, + interaction_range=cap.interaction_range, + length_unit=cap.length_unit, + supported_devices=cap.supported_devices, + dtype=cap.dtype, + ) + scripted = torch.jit.script(wrapper) + heat_model = AtomisticModel(scripted.eval(), metadata, capabilities=new_cap).to( + device="cpu" + ) + calc = MetatomicCalculator( + heat_model, + device="cpu", + additional_outputs={ + "extra::heat_flux": ModelOutput( + quantity="heat_flux", + unit="", + explicit_gradients=[], + per_atom=False, + ) + }, + ) + atoms.calc = calc + atoms.get_potential_energy() + assert "extra::heat_flux" in atoms.calc.additional_outputs + results = atoms.calc.additional_outputs["extra::heat_flux"].block().values + assert torch.allclose( + results, + torch.tensor(expected, dtype=results.dtype), + ) From 155747a28fe842917808f5725d2eb47bc840f0f7 Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Mon, 23 Feb 2026 
09:13:09 +0100 Subject: [PATCH 10/22] Deps and minor fix --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index 24612bf9..229f9d9a 100644 --- a/tox.ini +++ b/tox.ini @@ -146,6 +146,7 @@ deps = torch=={env:METATOMIC_TESTS_TORCH_VERSION:2.10}.* numpy {env:METATOMIC_TESTS_NUMPY_VERSION_PIN} vesin + vesin-torch ase # for metatensor-lj-test setuptools-scm From ec04879556f3c0565514bf1a0bae2dbff2aade84 Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Tue, 24 Feb 2026 10:24:47 +0100 Subject: [PATCH 11/22] Improve the tests a little bit --- .../metatomic/torch/heat_flux.py | 145 ++---------------- .../metatomic_torch/tests/test_heat_flux.py | 70 +++------ 2 files changed, 29 insertions(+), 186 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/heat_flux.py b/python/metatomic_torch/metatomic/torch/heat_flux.py index 19fa757b..c92378aa 100644 --- a/python/metatomic_torch/metatomic/torch/heat_flux.py +++ b/python/metatomic_torch/metatomic/torch/heat_flux.py @@ -17,7 +17,7 @@ def wrap_positions(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: Wrap positions into the periodic cell. """ fractional_positions = torch.einsum("iv,vk->ik", positions, cell.inverse()) - fractional_positions -= torch.floor(fractional_positions) + fractional_positions = fractional_positions - torch.floor(fractional_positions) wrapped_positions = torch.einsum("iv,vk->ik", fractional_positions, cell) return wrapped_positions @@ -43,7 +43,7 @@ def check_collisions( " minimum cell vector length is " + str(heights.min()) + "." 
) - cutoff += skin + cutoff = cutoff + skin normals = recip / norms[:, None] norm_coords = torch.einsum("iv,kv->ik", positions, normals) collisions = torch.hstack( @@ -104,8 +104,7 @@ def generate_replica_atoms( replica_offsets = torch.tensor( [0, 1, -1], device=positions.device, dtype=positions.dtype )[replicas[:, 1:]] - replica_positions = positions[replica_idx] - replica_positions += torch.einsum("iA,Aa->ia", replica_offsets, cell) + replica_positions = positions[replica_idx] + torch.einsum("iA,Aa->ia", replica_offsets, cell) return replica_idx, types[replica_idx], replica_positions @@ -260,9 +259,13 @@ def requested_inputs(self) -> Dict[str, ModelOutput]: return self._requested_inputs def barycenter_and_atomic_energies(self, system: System, n_atoms: int): - atomic_e = self._model([system], self._unfolded_run_options, False)["energy"][ - 0 - ].values.flatten() + energy_block = self._model([system], self._unfolded_run_options, False)[ + "energy" + ].block(0) + atom_indices = energy_block.samples.column("atom").to(torch.long) + sorted_order = torch.argsort(atom_indices) + atomic_e = energy_block.values.flatten()[sorted_order] + total_e = atomic_e[:n_atoms].sum() r_aux = system.positions.detach() barycenter = torch.einsum("i,ik->k", atomic_e[:n_atoms], r_aux[:n_atoms]) @@ -365,131 +368,3 @@ def forward( Labels("_", torch.tensor([[0]], device=device)), [hf_block] ) return results - - -class HardyHeatFluxWrapper(torch.nn.Module): - """ - A wrapper around an AtomisticModel that computes the heat flux of a system using the - unfolded system approach. The heat flux is computed using the atomic energies (eV), - positions(Å), masses(u), velocities(Å/fs), and the energy gradients. - """ - - def __init__(self, model: AtomisticModel, skin: float = 0.5): - """ - :param model: the :py:class:`AtomisticModel` to wrap, which should be able to - compute atomic energies and their gradients with respect to positions - :param skin: the skin parameter for unfolding the system. 
The wrapper will - generate replica atoms for those within (interaction_range + skin) distance from - the cell boundaries. A skin results in more replica atoms and thus higher - computational cost, but ensures that the heat flux is computed correctly. - """ - super().__init__() - - self._model = model - self.skin = skin - self._interaction_range = model.capabilities().interaction_range - - self._requested_inputs = { - "masses": ModelOutput(quantity="mass", unit="u", per_atom=True), - "velocities": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True), - } - - hf_output = ModelOutput( - quantity="heat_flux", - unit="", - explicit_gradients=[], - per_atom=False, - ) - outputs = self._model.capabilities().outputs.copy() - outputs["extra::heat_flux"] = hf_output - self._model.capabilities().outputs["extra::heat_flux"] = hf_output - if outputs["energy"].unit != "eV": - raise ValueError( - "HeatFluxWrapper can only be used with energy outputs in eV" - ) - energies_output = ModelOutput( - quantity="energy", unit=outputs["energy"].unit, per_atom=True - ) - self._unfolded_run_options = ModelEvaluationOptions( - length_unit=self._model.capabilities().length_unit, - outputs={"energy": energies_output}, - selected_atoms=None, - ) - - def requested_inputs(self) -> Dict[str, ModelOutput]: - return self._requested_inputs - - def calc_hardy_heat_flux(self, system: System) -> torch.Tensor: - n_atoms = len(system.positions) - velocities: torch.Tensor = ( - system.get_data("velocities").block().values.reshape(-1, 3) - ) - masses: torch.Tensor = system.get_data("masses").block().values.reshape(-1) - atomic_e = self._model([system], self._unfolded_run_options, False)["energy"][ - 0 - ].values.flatten() - - hf_pot = torch.zeros( - 3, dtype=system.positions.dtype, device=system.positions.device - ) - for i, energy in enumerate(atomic_e): - grad = torch.autograd.grad( - [energy], - [system.positions], - retain_graph=True if i != len(atomic_e) - 1 else False, - )[0] - grad = 
torch.jit._unwrap_optional(grad) - hf_pot += ( - wrap_positions(system.positions[i] - system.positions, system.cell) - * (grad * velocities).sum(dim=1, keepdim=True) - ).sum(dim=0) - - hf_conv = ( - ( - atomic_e[:n_atoms] - + 0.5 - * masses[:n_atoms] - * torch.linalg.norm(velocities[:n_atoms], dim=1) ** 2 - * 103.6427 # u*A^2/fs^2 to eV - )[:, None] - * velocities[:n_atoms] - ).sum(dim=0) - - return hf_pot + hf_conv - - def forward( - self, - systems: List[System], - outputs: Dict[str, ModelOutput], - selected_atoms: Optional[Labels], - ) -> Dict[str, TensorMap]: - run_options = ModelEvaluationOptions( - length_unit=self._model.capabilities().length_unit, - outputs=outputs, - selected_atoms=None, - ) - results = self._model(systems, run_options, False) - - if "extra::heat_flux" not in outputs: - return results - - device = systems[0].device - heat_fluxes: List[torch.Tensor] = [] - for system in systems: - system.positions.requires_grad_(True) - heat_fluxes.append(self.calc_hardy_heat_flux(system)) - - samples = Labels( - ["system"], torch.arange(len(systems), device=device).reshape(-1, 1) - ) - - hf_block = TensorBlock( - values=torch.vstack(heat_fluxes).reshape(-1, 3, 1).to(device=device), - samples=samples, - components=[Labels(["xyz"], torch.arange(3, device=device).reshape(-1, 1))], - properties=Labels(["heat_flux"], torch.tensor([[0]], device=device)), - ) - results["extra::heat_flux"] = TensorMap( - Labels("_", torch.tensor([[0]], device=device)), [hf_block] - ) - return results diff --git a/python/metatomic_torch/tests/test_heat_flux.py b/python/metatomic_torch/tests/test_heat_flux.py index f62455a8..d379320c 100644 --- a/python/metatomic_torch/tests/test_heat_flux.py +++ b/python/metatomic_torch/tests/test_heat_flux.py @@ -391,8 +391,18 @@ def __call__(self, systems, options, check_consistency): block = TensorBlock( values=values, samples=Labels( - ["atoms"], - torch.arange(n_atoms, device=values.device).reshape(-1, 1), + ["system", "atom"], + torch.stack( 
+ [ + torch.zeros( + n_atoms, + dtype=torch.int32, + device=values.device, + ), + torch.arange(n_atoms, device=values.device), + ], + dim=1, + ), ), components=[], properties=Labels( @@ -476,7 +486,7 @@ def _fake_hf(self, system): def test_forward_without_heat_flux_returns_model_results(): """When ``extra::heat_flux`` is not requested, forward should return model - results unchanged and *not* invoke the heat-flux computation.""" + results unchanged and not invoke the heat-flux computation.""" wrapper = HeatFluxWrapper(_ZeroDummyModel()) cell = torch.eye(3) @@ -495,8 +505,9 @@ def test_forward_without_heat_flux_returns_model_results(): assert "extra::heat_flux" not in results -def test_heat_flux_wrapper_calc_heat_flux(model, atoms): - expected = [[8.1053e-05], [-1.2710e-05], [-2.8778e-04]] +@pytest.mark.parametrize("use_script", [True, False]) +def test_heat_flux_wrapper_calc_heat_flux(model, atoms, use_script): + expected = [[8.8238e-05], [-2.5559e-04], [-2.0570e-04]] metadata = ModelMetadata() wrapper = HeatFluxWrapper(model.eval()) @@ -517,54 +528,11 @@ def test_heat_flux_wrapper_calc_heat_flux(model, atoms): supported_devices=cap.supported_devices, dtype=cap.dtype, ) - heat_model = AtomisticModel(wrapper.eval(), metadata, capabilities=new_cap).to( - device="cpu" - ) - calc = MetatomicCalculator( - heat_model, - device="cpu", - additional_outputs={ - "extra::heat_flux": ModelOutput( - quantity="heat_flux", - unit="", - explicit_gradients=[], - per_atom=False, - ) - }, - ) - atoms.calc = calc - atoms.get_potential_energy() - assert "extra::heat_flux" in atoms.calc.additional_outputs - results = atoms.calc.additional_outputs["extra::heat_flux"].block().values - assert torch.allclose( - results, - torch.tensor(expected, dtype=results.dtype), - ) + if use_script: + wrapper = torch.jit.script(wrapper) -def test_torch_scriptability(model, atoms): - expected = [[8.1053e-05], [-1.2710e-05], [-2.8778e-04]] - metadata = ModelMetadata() - wrapper = 
HeatFluxWrapper(model.eval()) - cap = wrapper._model.capabilities() - outputs = cap.outputs.copy() - outputs["extra::heat_flux"] = ModelOutput( - quantity="heat_flux", - unit="", - explicit_gradients=[], - per_atom=False, - ) - - new_cap = ModelCapabilities( - outputs=outputs, - atomic_types=cap.atomic_types, - interaction_range=cap.interaction_range, - length_unit=cap.length_unit, - supported_devices=cap.supported_devices, - dtype=cap.dtype, - ) - scripted = torch.jit.script(wrapper) - heat_model = AtomisticModel(scripted.eval(), metadata, capabilities=new_cap).to( + heat_model = AtomisticModel(wrapper.eval(), metadata, capabilities=new_cap).to( device="cpu" ) calc = MetatomicCalculator( From 677cd19bb84464779f1531e1a88ae2e81ebe4d7b Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Tue, 24 Feb 2026 22:02:45 +0100 Subject: [PATCH 12/22] Let's go vesin-0.5.1 --- .../metatomic/torch/heat_flux.py | 46 +++++++++++-------- python/metatomic_torch/pyproject.toml | 1 + .../metatomic_torch/tests/test_heat_flux.py | 16 +++++-- 3 files changed, 38 insertions(+), 25 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/heat_flux.py b/python/metatomic_torch/metatomic/torch/heat_flux.py index c92378aa..3b06d86f 100644 --- a/python/metatomic_torch/metatomic/torch/heat_flux.py +++ b/python/metatomic_torch/metatomic/torch/heat_flux.py @@ -2,7 +2,7 @@ import torch from metatensor.torch import Labels, TensorBlock, TensorMap -from vesin.metatomic import compute_requested_neighbors +from vesin.metatomic import compute_requested_neighbors_from_options from metatomic.torch import ( AtomisticModel, @@ -12,7 +12,7 @@ ) -def wrap_positions(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: +def _wrap_positions(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: """ Wrap positions into the periodic cell. 
""" @@ -23,7 +23,7 @@ def wrap_positions(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: return wrapped_positions -def check_collisions( +def _check_collisions( cell: torch.Tensor, positions: torch.Tensor, cutoff: float, skin: float ) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -58,7 +58,7 @@ def check_collisions( ) -def collisions_to_replicas(collisions: torch.Tensor) -> torch.Tensor: +def _collisions_to_replicas(collisions: torch.Tensor) -> torch.Tensor: """ Convert boundary-collision flags into a boolean mask over all periodic image displacements in {0, +1, -1}^3. e.g. for an atom colliding with the x_lo and y_hi @@ -87,7 +87,7 @@ def collisions_to_replicas(collisions: torch.Tensor) -> torch.Tensor: return outs.to(device=collisions.device) -def generate_replica_atoms( +def _generate_replica_atoms( types: torch.Tensor, positions: torch.Tensor, cell: torch.Tensor, @@ -104,26 +104,30 @@ def generate_replica_atoms( replica_offsets = torch.tensor( [0, 1, -1], device=positions.device, dtype=positions.dtype )[replicas[:, 1:]] - replica_positions = positions[replica_idx] + torch.einsum("iA,Aa->ia", replica_offsets, cell) + replica_positions = positions[replica_idx] + torch.einsum( + "iA,Aa->ia", replica_offsets, cell + ) return replica_idx, types[replica_idx], replica_positions -def unfold_system(metatomic_system: System, cutoff: float, skin: float = 0.5) -> System: +def _unfold_system( + metatomic_system: System, cutoff: float, skin: float = 0.5 +) -> System: """ Unfold a periodic system by generating replica atoms for those near the cell boundaries within the specified cutoff distance. The unfolded system has no periodic boundary conditions. 
""" - wrapped_positions = wrap_positions( + wrapped_positions = _wrap_positions( metatomic_system.positions, metatomic_system.cell ) - collisions, _ = check_collisions( + collisions, _ = _check_collisions( metatomic_system.cell, wrapped_positions, cutoff, skin ) - replicas = collisions_to_replicas(collisions) - replica_idx, replica_types, replica_positions = generate_replica_atoms( + replicas = _collisions_to_replicas(collisions) + replica_idx, replica_types, replica_positions = _generate_replica_atoms( metatomic_system.types, wrapped_positions, metatomic_system.cell, replicas ) unfolded_types = torch.cat( @@ -258,7 +262,7 @@ def __init__(self, model: AtomisticModel, skin: float = 0.5): def requested_inputs(self) -> Dict[str, ModelOutput]: return self._requested_inputs - def barycenter_and_atomic_energies(self, system: System, n_atoms: int): + def _barycenter_and_atomic_energies(self, system: System, n_atoms: int): energy_block = self._model([system], self._unfolded_run_options, False)[ "energy" ].block(0) @@ -272,22 +276,24 @@ def barycenter_and_atomic_energies(self, system: System, n_atoms: int): return barycenter, atomic_e, total_e - def calc_unfolded_heat_flux(self, system: System) -> torch.Tensor: + def _calc_unfolded_heat_flux(self, system: System) -> torch.Tensor: n_atoms = len(system.positions) - unfolded_system = unfold_system(system, self._interaction_range, self.skin).to( - "cpu" + unfolded_system = _unfold_system(system, self._interaction_range, self.skin).to( + system.device ) - compute_requested_neighbors( - unfolded_system, self._unfolded_run_options.length_unit, model=self._model + compute_requested_neighbors_from_options( + [unfolded_system], + self._model.requested_neighbor_lists(), + self._unfolded_run_options.length_unit, + False, ) - unfolded_system = unfolded_system.to(system.device) velocities: torch.Tensor = ( unfolded_system.get_data("velocities").block().values.reshape(-1, 3) ) masses: torch.Tensor = ( 
unfolded_system.get_data("masses").block().values.reshape(-1) ) - barycenter, atomic_e, total_e = self.barycenter_and_atomic_energies( + barycenter, atomic_e, total_e = self._barycenter_and_atomic_energies( unfolded_system, n_atoms ) @@ -352,7 +358,7 @@ def forward( heat_fluxes: List[torch.Tensor] = [] for system in systems: system.positions.requires_grad_(True) - heat_fluxes.append(self.calc_unfolded_heat_flux(system)) + heat_fluxes.append(self._calc_unfolded_heat_flux(system)) samples = Labels( ["system"], torch.arange(len(systems), device=device).reshape(-1, 1) diff --git a/python/metatomic_torch/pyproject.toml b/python/metatomic_torch/pyproject.toml index 6f7de23b..da0a82fd 100644 --- a/python/metatomic_torch/pyproject.toml +++ b/python/metatomic_torch/pyproject.toml @@ -58,6 +58,7 @@ python_files = ["*.py"] testpaths = ["tests"] filterwarnings = [ "error", + "ignore:Found metatomic.torch.*but vesin.metatomic was only tested with:UserWarning", "ignore:`torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`:DeprecationWarning", "ignore:`torch.jit.save` is deprecated. Please switch to `torch.export`:DeprecationWarning", "ignore:`torch.jit.load` is deprecated. 
Please switch to `torch.export`:DeprecationWarning", diff --git a/python/metatomic_torch/tests/test_heat_flux.py b/python/metatomic_torch/tests/test_heat_flux.py index d379320c..0cfccf5e 100644 --- a/python/metatomic_torch/tests/test_heat_flux.py +++ b/python/metatomic_torch/tests/test_heat_flux.py @@ -41,7 +41,7 @@ def model(): def atoms(): cell = np.array([[6.0, 0.0, 0.0], [0.0, 6.0, 0.0], [0.0, 0.0, 6.0]]) positions = np.array([[3.0, 3.0, 3.0]]) - atoms = Atoms(f"Ar", scaled_positions=positions, cell=cell, pbc=True).repeat( + atoms = Atoms("Ar", scaled_positions=positions, cell=cell, pbc=True).repeat( (2, 2, 2) ) MaxwellBoltzmannDistribution( @@ -123,9 +123,7 @@ def __call__(self, systems, options, check_consistency): torch.arange(len(systems), device=values.device).reshape(-1, 1), ), components=[], - properties=Labels( - ["energy"], torch.tensor([[0]], device=values.device) - ), + properties=Labels(["energy"], torch.tensor([[0]], device=values.device)), ) return { "energy": TensorMap( @@ -287,7 +285,15 @@ def test_generate_replica_atoms_triclinic_offsets(): assert replica_idx.tolist() == [0, 0, 0, 0, 0, 0, 0] assert replica_types.tolist() == [1, 1, 1, 1, 1, 1, 1] - expected_offsets = [cell[0], cell[1], cell[2], cell[0] + cell[1], cell[0] + cell[2], cell[1] + cell[2], cell[0] + cell[1] + cell[2]] + expected_offsets = [ + cell[0], + cell[1], + cell[2], + cell[0] + cell[1], + cell[0] + cell[2], + cell[1] + cell[2], + cell[0] + cell[1] + cell[2], + ] expected_positions = [positions[0] + offset for offset in expected_offsets] for expected in expected_positions: From 2af67ec64c48fdca0c1c99625e3bf886f999983b Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Tue, 24 Feb 2026 22:23:36 +0100 Subject: [PATCH 13/22] Minor fix for warning filtering --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index b2465893..369ea899 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ filterwarnings = [ "ignore:ast.Str is 
deprecated and will be removed in Python 3.14:DeprecationWarning", "ignore:Attribute s is deprecated and will be removed in Python 3.14:DeprecationWarning", "ignore:ast.NameConstant is deprecated and will be removed in Python 3.14:DeprecationWarning", + "ignore:Found metatomic.torch.*but vesin.metatomic was only tested with:UserWarning", "ignore:`torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`:DeprecationWarning", "ignore:`torch.jit.save` is deprecated. Please switch to `torch.export`:DeprecationWarning", ] From 9eebb91b1dd8c55336925d11d813d2c86fc6bbbc Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Tue, 24 Feb 2026 22:42:46 +0100 Subject: [PATCH 14/22] Fix naming issues --- .../metatomic_torch/tests/test_heat_flux.py | 51 ++++++++++--------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/python/metatomic_torch/tests/test_heat_flux.py b/python/metatomic_torch/tests/test_heat_flux.py index 0cfccf5e..0a55ff04 100644 --- a/python/metatomic_torch/tests/test_heat_flux.py +++ b/python/metatomic_torch/tests/test_heat_flux.py @@ -16,11 +16,11 @@ from metatomic.torch.ase_calculator import MetatomicCalculator from metatomic.torch.heat_flux import ( HeatFluxWrapper, - check_collisions, - collisions_to_replicas, - generate_replica_atoms, - unfold_system, - wrap_positions, + _check_collisions, + _collisions_to_replicas, + _generate_replica_atoms, + _unfold_system, + _wrap_positions, ) @@ -135,7 +135,7 @@ def __call__(self, systems, options, check_consistency): def test_wrap_positions_cubic_matches_expected(): cell = torch.eye(3) * 2.0 positions = torch.tensor([[-0.1, 0.0, 0.0], [2.1, 1.0, -0.5]]) - wrapped = wrap_positions(positions, cell) + wrapped = _wrap_positions(positions, cell) expected = torch.tensor([[1.9, 0.0, 0.0], [0.1, 1.0, 1.5]]) assert torch.allclose(wrapped, expected) @@ -143,7 +143,7 @@ def test_wrap_positions_cubic_matches_expected(): def test_check_collisions_cubic_axis_order(): cell = torch.eye(3) * 2.0 
positions = torch.tensor([[0.1, 1.0, 1.9]]) - collisions, norm_coords = check_collisions(cell, positions, cutoff=0.2, skin=0.0) + collisions, norm_coords = _check_collisions(cell, positions, cutoff=0.2, skin=0.0) assert torch.allclose(norm_coords, positions) assert collisions.shape == (1, 6) assert collisions[0].tolist() == [True, False, False, False, False, True] @@ -154,8 +154,8 @@ def test_generate_replica_atoms_cubic_offsets(): positions = torch.tensor([[0.1, 1.0, 1.0]]) cell = torch.eye(3) * 2.0 collisions = torch.tensor([[True, False, False, False, False, False]]) - replicas = collisions_to_replicas(collisions) - replica_idx, replica_types, replica_positions = generate_replica_atoms( + replicas = _collisions_to_replicas(collisions) + replica_idx, replica_types, replica_positions = _generate_replica_atoms( types, positions, cell, replicas ) assert replica_idx.tolist() == [0] @@ -181,7 +181,7 @@ def test_wrap_positions_triclinic_fractional_bounds_and_shift(): ] ) inv_cell = cell.inverse() - wrapped = wrap_positions(positions, cell) + wrapped = _wrap_positions(positions, cell) fractional_before = torch.einsum("iv,vk->ik", positions, inv_cell) fractional_after = torch.einsum("iv,vk->ik", wrapped, inv_cell) @@ -218,7 +218,9 @@ def test_check_collisions_triclinic_targets(): ) positions = target @ torch.inverse(norm_vectors).T - collisions, norm_coords = check_collisions(cell, positions, cutoff=cutoff, skin=0.0) + collisions, norm_coords = _check_collisions( + cell, positions, cutoff=cutoff, skin=0.0 + ) assert torch.allclose(norm_coords, target, atol=1e-6, rtol=0) expected_low = target <= cutoff @@ -233,7 +235,7 @@ def test_check_collisions_raises_on_small_cell(): cell = torch.eye(3) * 1.0 positions = torch.zeros((1, 3)) with pytest.raises(ValueError, match="Cell is too small"): - check_collisions(cell, positions, cutoff=0.9, skin=0.2) + _check_collisions(cell, positions, cutoff=0.9, skin=0.2) def test_skin_parameter_affects_collisions(): @@ -243,17 +245,17 @@ def 
test_skin_parameter_affects_collisions(): positions = torch.tensor([[0.3, 1.0, 1.0]]) # cutoff=0.2, skin=0.0 → effective range 0.2 < 0.3 → no collision - collisions_no_skin, _ = check_collisions(cell, positions, cutoff=0.2, skin=0.0) + collisions_no_skin, _ = _check_collisions(cell, positions, cutoff=0.2, skin=0.0) assert not collisions_no_skin.any() # cutoff=0.2, skin=0.2 → effective range 0.4 > 0.3 → x_lo collision - collisions_with_skin, _ = check_collisions(cell, positions, cutoff=0.2, skin=0.2) + collisions_with_skin, _ = _check_collisions(cell, positions, cutoff=0.2, skin=0.2) assert collisions_with_skin[0, 0].item() # x_lo def test_collisions_to_replicas_combines_displacements(): collisions = torch.tensor([[True, False, False, True, False, False]]) - replicas = collisions_to_replicas(collisions) + replicas = _collisions_to_replicas(collisions) assert replicas.shape == (1, 3, 3, 3) assert replicas[0, 0, 0, 0].item() is False @@ -277,8 +279,8 @@ def test_generate_replica_atoms_triclinic_offsets(): types = torch.tensor([1]) positions = torch.tensor([[0.2, 0.4, 0.6]]) collisions = torch.tensor([[True, False, True, False, True, False]]) - replicas = collisions_to_replicas(collisions) - replica_idx, replica_types, replica_positions = generate_replica_atoms( + replicas = _collisions_to_replicas(collisions) + replica_idx, replica_types, replica_positions = _generate_replica_atoms( types, positions, cell, replicas ) @@ -307,8 +309,7 @@ def test_unfold_system_adds_replica_and_data(): cell = torch.eye(3) * 2.0 positions = torch.tensor([[0.1, 1.0, 1.0]]) system = _make_system_with_data(positions, cell) - unfolded = unfold_system(system, cutoff=0.1) - + unfolded = _unfold_system(system, cutoff=0.1) assert len(unfolded.positions) == 2 assert torch.all(unfolded.pbc == torch.tensor([False, False, False])) assert torch.allclose(unfolded.cell, torch.zeros_like(unfolded.cell)) @@ -329,10 +330,10 @@ def test_unfold_system_no_replicas_for_interior_atoms(): cell = torch.eye(3) * 
10.0 positions = torch.tensor([[5.0, 5.0, 5.0], [3.0, 4.0, 6.0]]) system = _make_system_with_data(positions, cell) - unfolded = unfold_system(system, cutoff=1.0, skin=0.0) + unfolded = _unfold_system(system, cutoff=1.0, skin=0.0) assert len(unfolded.positions) == 2 - assert torch.allclose(unfolded.positions, wrap_positions(positions, cell)) + assert torch.allclose(unfolded.positions, _wrap_positions(positions, cell)) def test_unfold_system_triclinic_cell(): @@ -352,7 +353,7 @@ def test_unfold_system_triclinic_cell(): ] ) system = _make_system_with_data(positions, cell) - unfolded = unfold_system(system, cutoff=0.3, skin=0.0) + unfolded = _unfold_system(system, cutoff=0.3, skin=0.0) # The near-origin atom should generate at least one replica assert len(unfolded.positions) > 2 @@ -430,12 +431,12 @@ def __call__(self, systems, options, check_consistency): ] ) system = _make_system_with_data(positions, cell) - unfolded = unfold_system(system, cutoff=0.1, skin=0.0) + unfolded = _unfold_system(system, cutoff=0.1, skin=0.0) n_atoms = len(system.positions) assert len(unfolded.positions) == n_atoms * 2 wrapper = HeatFluxWrapper(_ArangeDummyModel()) - barycenter, atomic_e, total_e = wrapper.barycenter_and_atomic_energies( + barycenter, atomic_e, total_e = wrapper._barycenter_and_atomic_energies( unfolded, n_atoms ) @@ -461,7 +462,7 @@ def _fake_hf(self, system): ) wrapper = HeatFluxWrapper(_ZeroDummyModel()) - monkeypatch.setattr(HeatFluxWrapper, "calc_unfolded_heat_flux", _fake_hf) + monkeypatch.setattr(HeatFluxWrapper, "_calc_unfolded_heat_flux", _fake_hf) cell = torch.eye(3) systems = [ From f78cfa79bff3472f73d1882dd9f12b302cca5b64 Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Fri, 27 Feb 2026 22:53:36 +0100 Subject: [PATCH 15/22] Replace `einsum` and some documentation improvements --- pyproject.toml | 1 - .../metatomic/torch/ase_calculator.py | 6 +-- .../metatomic/torch/heat_flux.py | 32 +++++++-------- python/metatomic_torch/pyproject.toml | 1 - 
.../metatomic_torch/tests/test_heat_flux.py | 40 ++++++++++++------- 5 files changed, 41 insertions(+), 39 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 423c8c23..1a67cd6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,6 @@ filterwarnings = [ "ignore:ast.Str is deprecated and will be removed in Python 3.14:DeprecationWarning", "ignore:Attribute s is deprecated and will be removed in Python 3.14:DeprecationWarning", "ignore:ast.NameConstant is deprecated and will be removed in Python 3.14:DeprecationWarning", - "ignore:Found metatomic.torch.*but vesin.metatomic was only tested with:UserWarning", "ignore:`torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`:DeprecationWarning", "ignore:`torch.jit.save` is deprecated. Please switch to `torch.export`:DeprecationWarning", "ignore:.*vesin.metatomic was only tested with metatomic.torch >=0.1.3,<0.2.*:UserWarning", diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 3b43b361..96b85f51 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -106,10 +106,6 @@ def _get_charges(atoms: ase.Atoms) -> np.ndarray: }, } -IMPLEMENTED_PROPERTIES = [ - "heat_flux", -] - class MetatomicCalculator(ase.calculators.calculator.Calculator): """ @@ -329,7 +325,7 @@ def __init__( # We do our own check to verify if a property is implemented in `calculate()`, # so we pretend to be able to compute all properties ASE knows about. 
- self.implemented_properties = ALL_ASE_PROPERTIES + IMPLEMENTED_PROPERTIES + self.implemented_properties = ALL_ASE_PROPERTIES self.additional_outputs: Dict[str, TensorMap] = {} """ diff --git a/python/metatomic_torch/metatomic/torch/heat_flux.py b/python/metatomic_torch/metatomic/torch/heat_flux.py index 3b06d86f..44410e3d 100644 --- a/python/metatomic_torch/metatomic/torch/heat_flux.py +++ b/python/metatomic_torch/metatomic/torch/heat_flux.py @@ -16,18 +16,19 @@ def _wrap_positions(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor """ Wrap positions into the periodic cell. """ - fractional_positions = torch.einsum("iv,vk->ik", positions, cell.inverse()) + fractional_positions = positions @ cell.inverse() fractional_positions = fractional_positions - torch.floor(fractional_positions) - wrapped_positions = torch.einsum("iv,vk->ik", fractional_positions, cell) + wrapped_positions = fractional_positions @ cell return wrapped_positions -def _check_collisions( +def _check_close_to_cell_boundary( cell: torch.Tensor, positions: torch.Tensor, cutoff: float, skin: float -) -> tuple[torch.Tensor, torch.Tensor]: +) -> torch.Tensor: """ - Detect atoms that lie within a cutoff distance from the periodic cell boundaries, + Detect atoms that lie within a cutoff distance (in our context, the interaction + range of the model + the skin) from the periodic cell boundaries, i.e. have interactions with atoms at the opposite end of the cell. 
""" inv_cell = cell.inverse() @@ -45,17 +46,14 @@ def _check_collisions( cutoff = cutoff + skin normals = recip / norms[:, None] - norm_coords = torch.einsum("iv,kv->ik", positions, normals) + norm_coords = positions @ normals.T collisions = torch.hstack( [norm_coords <= cutoff, norm_coords >= heights - cutoff], ).to(device=positions.device) - return ( - collisions[ - :, [0, 3, 1, 4, 2, 5] # reorder to (x_lo, x_hi, y_lo, y_hi, z_lo, z_hi) - ], - norm_coords, - ) + return collisions[ + :, [0, 3, 1, 4, 2, 5] # reorder to (x_lo, x_hi, y_lo, y_hi, z_lo, z_hi) + ] def _collisions_to_replicas(collisions: torch.Tensor) -> torch.Tensor: @@ -104,9 +102,7 @@ def _generate_replica_atoms( replica_offsets = torch.tensor( [0, 1, -1], device=positions.device, dtype=positions.dtype )[replicas[:, 1:]] - replica_positions = positions[replica_idx] + torch.einsum( - "iA,Aa->ia", replica_offsets, cell - ) + replica_positions = positions[replica_idx] + replica_offsets @ cell return replica_idx, types[replica_idx], replica_positions @@ -120,10 +116,12 @@ def _unfold_system( The unfolded system has no periodic boundary conditions. 
""" + if not metatomic_system.pbc.any(): + raise ValueError("Unfolding systems is only supported for periodic systems.") wrapped_positions = _wrap_positions( metatomic_system.positions, metatomic_system.cell ) - collisions, _ = _check_collisions( + collisions = _check_close_to_cell_boundary( metatomic_system.cell, wrapped_positions, cutoff, skin ) replicas = _collisions_to_replicas(collisions) @@ -272,7 +270,7 @@ def _barycenter_and_atomic_energies(self, system: System, n_atoms: int): total_e = atomic_e[:n_atoms].sum() r_aux = system.positions.detach() - barycenter = torch.einsum("i,ik->k", atomic_e[:n_atoms], r_aux[:n_atoms]) + barycenter = (atomic_e[:n_atoms, None] * r_aux[:n_atoms]).sum(dim=0) return barycenter, atomic_e, total_e diff --git a/python/metatomic_torch/pyproject.toml b/python/metatomic_torch/pyproject.toml index 19545c6b..5dbe99b7 100644 --- a/python/metatomic_torch/pyproject.toml +++ b/python/metatomic_torch/pyproject.toml @@ -58,7 +58,6 @@ python_files = ["*.py"] testpaths = ["tests"] filterwarnings = [ "error", - "ignore:Found metatomic.torch.*but vesin.metatomic was only tested with:UserWarning", "ignore:`torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`:DeprecationWarning", "ignore:`torch.jit.save` is deprecated. Please switch to `torch.export`:DeprecationWarning", "ignore:`torch.jit.load` is deprecated. 
Please switch to `torch.export`:DeprecationWarning", diff --git a/python/metatomic_torch/tests/test_heat_flux.py b/python/metatomic_torch/tests/test_heat_flux.py index 0a55ff04..0224a9e1 100644 --- a/python/metatomic_torch/tests/test_heat_flux.py +++ b/python/metatomic_torch/tests/test_heat_flux.py @@ -16,7 +16,7 @@ from metatomic.torch.ase_calculator import MetatomicCalculator from metatomic.torch.heat_flux import ( HeatFluxWrapper, - _check_collisions, + _check_close_to_cell_boundary, _collisions_to_replicas, _generate_replica_atoms, _unfold_system, @@ -140,11 +140,10 @@ def test_wrap_positions_cubic_matches_expected(): assert torch.allclose(wrapped, expected) -def test_check_collisions_cubic_axis_order(): +def test_check_close_to_cell_boundary_cubic_axis_order(): cell = torch.eye(3) * 2.0 positions = torch.tensor([[0.1, 1.0, 1.9]]) - collisions, norm_coords = _check_collisions(cell, positions, cutoff=0.2, skin=0.0) - assert torch.allclose(norm_coords, positions) + collisions = _check_close_to_cell_boundary(cell, positions, cutoff=0.2, skin=0.0) assert collisions.shape == (1, 6) assert collisions[0].tolist() == [True, False, False, False, False, True] @@ -194,7 +193,7 @@ def test_wrap_positions_triclinic_fractional_bounds_and_shift(): assert torch.allclose(rounded, -torch.floor(fractional_before), atol=1e-6, rtol=0) -def test_check_collisions_triclinic_targets(): +def test_check_close_to_cell_boundary_triclinic_targets(): cell = torch.tensor( [ [2.0, 0.3, 0.2], @@ -218,10 +217,7 @@ def test_check_collisions_triclinic_targets(): ) positions = target @ torch.inverse(norm_vectors).T - collisions, norm_coords = _check_collisions( - cell, positions, cutoff=cutoff, skin=0.0 - ) - assert torch.allclose(norm_coords, target, atol=1e-6, rtol=0) + collisions = _check_close_to_cell_boundary(cell, positions, cutoff=cutoff, skin=0.0) expected_low = target <= cutoff expected_high = target >= heights - cutoff @@ -231,11 +227,11 @@ def test_check_collisions_triclinic_targets(): 
assert torch.equal(collisions, expected) -def test_check_collisions_raises_on_small_cell(): +def test_check_close_to_cell_boundary_raises_on_small_cell(): cell = torch.eye(3) * 1.0 positions = torch.zeros((1, 3)) with pytest.raises(ValueError, match="Cell is too small"): - _check_collisions(cell, positions, cutoff=0.9, skin=0.2) + _check_close_to_cell_boundary(cell, positions, cutoff=0.9, skin=0.2) def test_skin_parameter_affects_collisions(): @@ -245,11 +241,15 @@ def test_skin_parameter_affects_collisions(): positions = torch.tensor([[0.3, 1.0, 1.0]]) # cutoff=0.2, skin=0.0 → effective range 0.2 < 0.3 → no collision - collisions_no_skin, _ = _check_collisions(cell, positions, cutoff=0.2, skin=0.0) + collisions_no_skin = _check_close_to_cell_boundary( + cell, positions, cutoff=0.2, skin=0.0 + ) assert not collisions_no_skin.any() # cutoff=0.2, skin=0.2 → effective range 0.4 > 0.3 → x_lo collision - collisions_with_skin, _ = _check_collisions(cell, positions, cutoff=0.2, skin=0.2) + collisions_with_skin = _check_close_to_cell_boundary( + cell, positions, cutoff=0.2, skin=0.2 + ) assert collisions_with_skin[0, 0].item() # x_lo @@ -436,8 +436,18 @@ def __call__(self, systems, options, check_consistency): assert len(unfolded.positions) == n_atoms * 2 wrapper = HeatFluxWrapper(_ArangeDummyModel()) - barycenter, atomic_e, total_e = wrapper._barycenter_and_atomic_energies( - unfolded, n_atoms + + # Verify atomic energy ordering by reproducing the inlined logic + # from _calc_unfolded_heat_flux + energy_block = _ArangeDummyModel()( + [unfolded], wrapper._unfolded_run_options, False + )["energy"].block(0) + atom_indices = energy_block.samples.column("atom").to(torch.long) + sorted_order = torch.argsort(atom_indices) + atomic_e = energy_block.values.flatten()[sorted_order] + total_e = atomic_e[:n_atoms].sum() + barycenter = torch.einsum( + "i,ik->k", atomic_e[:n_atoms], unfolded.positions[:n_atoms] ) expected_atomic_e = torch.arange( From 
ab9e8198da7ecefa943ae5e002c472287000cd87 Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Fri, 27 Feb 2026 23:46:23 +0100 Subject: [PATCH 16/22] Replace `self._model = model` with `self._model = model.module` and remove some tests --- .../metatomic/torch/heat_flux.py | 33 ++-- .../metatomic_torch/tests/test_heat_flux.py | 167 ++---------------- 2 files changed, 38 insertions(+), 162 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/heat_flux.py b/python/metatomic_torch/metatomic/torch/heat_flux.py index 44410e3d..eb657cd3 100644 --- a/python/metatomic_torch/metatomic/torch/heat_flux.py +++ b/python/metatomic_torch/metatomic/torch/heat_flux.py @@ -8,6 +8,7 @@ AtomisticModel, ModelEvaluationOptions, ModelOutput, + NeighborListOptions, System, ) @@ -225,10 +226,12 @@ def __init__(self, model: AtomisticModel, skin: float = 0.5): """ super().__init__() - self._model = model + assert isinstance(model, AtomisticModel) + self._model = model.module self.skin = skin self._interaction_range = model.capabilities().interaction_range + self._requested_neighbor_lists = model.requested_neighbor_lists() self._requested_inputs = { "masses": ModelOutput(quantity="mass", unit="u", per_atom=True), "velocities": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True), @@ -240,9 +243,8 @@ def __init__(self, model: AtomisticModel, skin: float = 0.5): explicit_gradients=[], per_atom=False, ) - outputs = self._model.capabilities().outputs.copy() + outputs = model.capabilities().outputs.copy() outputs["extra::heat_flux"] = hf_output - self._model.capabilities().outputs["extra::heat_flux"] = hf_output if outputs["energy"].unit != "eV": raise ValueError( "HeatFluxWrapper can only be used with energy outputs in eV" @@ -251,19 +253,25 @@ def __init__(self, model: AtomisticModel, skin: float = 0.5): quantity="energy", unit=outputs["energy"].unit, per_atom=True ) self._unfolded_run_options = ModelEvaluationOptions( - length_unit=self._model.capabilities().length_unit, + 
length_unit=model.capabilities().length_unit, outputs={"energy": energies_output}, selected_atoms=None, ) + @torch.jit.export + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return self._requested_neighbor_lists + @torch.jit.export def requested_inputs(self) -> Dict[str, ModelOutput]: return self._requested_inputs def _barycenter_and_atomic_energies(self, system: System, n_atoms: int): - energy_block = self._model([system], self._unfolded_run_options, False)[ - "energy" - ].block(0) + energy_block = self._model( + [system], + self._unfolded_run_options.outputs, + self._unfolded_run_options.selected_atoms, + )["energy"].block(0) atom_indices = energy_block.samples.column("atom").to(torch.long) sorted_order = torch.argsort(atom_indices) atomic_e = energy_block.values.flatten()[sorted_order] @@ -281,7 +289,7 @@ def _calc_unfolded_heat_flux(self, system: System) -> torch.Tensor: ) compute_requested_neighbors_from_options( [unfolded_system], - self._model.requested_neighbor_lists(), + self.requested_neighbor_lists(), self._unfolded_run_options.length_unit, False, ) @@ -342,12 +350,9 @@ def forward( outputs: Dict[str, ModelOutput], selected_atoms: Optional[Labels], ) -> Dict[str, TensorMap]: - run_options = ModelEvaluationOptions( - length_unit=self._model.capabilities().length_unit, - outputs=outputs, - selected_atoms=None, - ) - results = self._model(systems, run_options, False) + outputs_wo_heat_flux = outputs.copy() + del outputs_wo_heat_flux["extra::heat_flux"] + results = self._model(systems, outputs_wo_heat_flux, selected_atoms) if "extra::heat_flux" not in outputs: return results diff --git a/python/metatomic_torch/tests/test_heat_flux.py b/python/metatomic_torch/tests/test_heat_flux.py index 0224a9e1..5e5bcd5c 100644 --- a/python/metatomic_torch/tests/test_heat_flux.py +++ b/python/metatomic_torch/tests/test_heat_flux.py @@ -37,6 +37,19 @@ def model(): ) +@pytest.fixture +def model_in_kcal_per_mol(): + return 
metatomic_lj_test.lennard_jones_model( + atomic_type=18, + cutoff=7.0, + sigma=3.405, + epsilon=0.01032, + length_unit="Angstrom", + energy_unit="kcal/mol", + with_extension=False, + ) + + @pytest.fixture def atoms(): cell = np.array([[6.0, 0.0, 0.0], [0.0, 6.0, 0.0], [0.0, 0.0, 6.0]]) @@ -110,6 +123,7 @@ class _ZeroDummyModel: def __init__(self, energy_unit: str = "eV"): self._capabilities = _DummyCapabilities(energy_unit) + self.module = None def capabilities(self): return self._capabilities @@ -368,167 +382,24 @@ def test_unfold_system_triclinic_cell(): ) -def test_heat_flux_wrapper_rejects_non_eV_energy(): +def test_heat_flux_wrapper_rejects_non_eV_energy(model_in_kcal_per_mol): with pytest.raises(ValueError, match="energy outputs in eV"): - HeatFluxWrapper(_ZeroDummyModel(energy_unit="kcal/mol")) + HeatFluxWrapper(model_in_kcal_per_mol) -def test_heat_flux_wrapper_requested_inputs(): - wrapper = HeatFluxWrapper(_ZeroDummyModel()) +def test_heat_flux_wrapper_requested_inputs(model): + wrapper = HeatFluxWrapper(model) requested = wrapper.requested_inputs() assert set(requested.keys()) == {"masses", "velocities"} -def test_unfolded_energy_order_used_for_barycenter(): - class _ArangeDummyModel: - """Returns per-atom energies [0, 1, 2, …] so ordering can be verified.""" - - def __init__(self): - self._capabilities = _DummyCapabilities() - - def capabilities(self): - return self._capabilities - - def __call__(self, systems, options, check_consistency): - system = systems[0] - n_atoms = len(system.positions) - values = torch.arange( - n_atoms, dtype=system.positions.dtype, device=system.positions.device - ).reshape(-1, 1) - block = TensorBlock( - values=values, - samples=Labels( - ["system", "atom"], - torch.stack( - [ - torch.zeros( - n_atoms, - dtype=torch.int32, - device=values.device, - ), - torch.arange(n_atoms, device=values.device), - ], - dim=1, - ), - ), - components=[], - properties=Labels( - ["energy"], torch.tensor([[0]], device=values.device) - ), - ) - 
return { - "energy": TensorMap( - Labels("_", torch.tensor([[0]], device=values.device)), [block] - ) - } - - cell = torch.eye(3) * 10.0 - positions = torch.tensor( - [ - [0.05, 5.0, 5.0], # near x_lo -> one replica - [9.95, 5.5, 5.0], # near x_hi -> one replica - [0.05, 6.0, 5.5], # near x_lo -> one replica - ] - ) - system = _make_system_with_data(positions, cell) - unfolded = _unfold_system(system, cutoff=0.1, skin=0.0) - n_atoms = len(system.positions) - assert len(unfolded.positions) == n_atoms * 2 - - wrapper = HeatFluxWrapper(_ArangeDummyModel()) - - # Verify atomic energy ordering by reproducing the inlined logic - # from _calc_unfolded_heat_flux - energy_block = _ArangeDummyModel()( - [unfolded], wrapper._unfolded_run_options, False - )["energy"].block(0) - atom_indices = energy_block.samples.column("atom").to(torch.long) - sorted_order = torch.argsort(atom_indices) - atomic_e = energy_block.values.flatten()[sorted_order] - total_e = atomic_e[:n_atoms].sum() - barycenter = torch.einsum( - "i,ik->k", atomic_e[:n_atoms], unfolded.positions[:n_atoms] - ) - - expected_atomic_e = torch.arange( - len(unfolded.positions), - dtype=unfolded.positions.dtype, - device=unfolded.positions.device, - ) - expected_total_e = expected_atomic_e[:n_atoms].sum() - expected_barycenter = torch.einsum( - "i,ik->k", expected_atomic_e[:n_atoms], unfolded.positions[:n_atoms] - ) - - assert torch.allclose(atomic_e, expected_atomic_e) - assert torch.allclose(total_e, expected_total_e) - assert torch.allclose(barycenter, expected_barycenter) - - -def test_heat_flux_wrapper_forward_adds_output(monkeypatch): - def _fake_hf(self, system): - return torch.tensor( - [1.0, 2.0, 3.0], device=system.device, dtype=system.positions.dtype - ) - - wrapper = HeatFluxWrapper(_ZeroDummyModel()) - monkeypatch.setattr(HeatFluxWrapper, "_calc_unfolded_heat_flux", _fake_hf) - - cell = torch.eye(3) - systems = [ - System( - types=torch.tensor([1], dtype=torch.int32), - positions=torch.zeros((1, 3)), - 
cell=cell, - pbc=torch.tensor([True, True, True]), - ), - System( - types=torch.tensor([1], dtype=torch.int32), - positions=torch.ones((1, 3)), - cell=cell, - pbc=torch.tensor([True, True, True]), - ), - ] - - outputs = { - "energy": ModelOutput(quantity="energy", unit="eV"), - "extra::heat_flux": ModelOutput(quantity="heat_flux", unit=""), - } - results = wrapper.forward(systems, outputs, None) - assert "extra::heat_flux" in results - hf_block = results["extra::heat_flux"].block() - assert hf_block.values.shape == (2, 3, 1) - assert torch.allclose(hf_block.values[:, :, 0], torch.tensor([[1.0, 2.0, 3.0]] * 2)) - - -def test_forward_without_heat_flux_returns_model_results(): - """When ``extra::heat_flux`` is not requested, forward should return model - results unchanged and not invoke the heat-flux computation.""" - wrapper = HeatFluxWrapper(_ZeroDummyModel()) - - cell = torch.eye(3) - systems = [ - System( - types=torch.tensor([1], dtype=torch.int32), - positions=torch.zeros((1, 3)), - cell=cell, - pbc=torch.tensor([True, True, True]), - ), - ] - outputs = {"energy": ModelOutput(quantity="energy", unit="eV")} - results = wrapper.forward(systems, outputs, None) - - assert "energy" in results - assert "extra::heat_flux" not in results - - @pytest.mark.parametrize("use_script", [True, False]) def test_heat_flux_wrapper_calc_heat_flux(model, atoms, use_script): expected = [[8.8238e-05], [-2.5559e-04], [-2.0570e-04]] metadata = ModelMetadata() wrapper = HeatFluxWrapper(model.eval()) - cap = wrapper._model.capabilities() + cap = model.capabilities() outputs = cap.outputs.copy() outputs["extra::heat_flux"] = ModelOutput( quantity="heat_flux", From 2aa5e018aeed0e6caa4656bd8384c08e0e650e01 Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Sun, 1 Mar 2026 20:48:49 +0100 Subject: [PATCH 17/22] Minor fix --- python/metatomic_torch/metatomic/torch/heat_flux.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/metatomic_torch/metatomic/torch/heat_flux.py 
b/python/metatomic_torch/metatomic/torch/heat_flux.py index eb657cd3..23dec8cb 100644 --- a/python/metatomic_torch/metatomic/torch/heat_flux.py +++ b/python/metatomic_torch/metatomic/torch/heat_flux.py @@ -351,7 +351,8 @@ def forward( selected_atoms: Optional[Labels], ) -> Dict[str, TensorMap]: outputs_wo_heat_flux = outputs.copy() - del outputs_wo_heat_flux["extra::heat_flux"] + if "extra::heat_flux" in outputs: + del outputs_wo_heat_flux["extra::heat_flux"] results = self._model(systems, outputs_wo_heat_flux, selected_atoms) if "extra::heat_flux" not in outputs: From 2c725d19ecb4b476dc075668e228bd5380166b78 Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Sat, 7 Mar 2026 01:05:32 +0100 Subject: [PATCH 18/22] Remove `skin` --- .../metatomic/torch/heat_flux.py | 24 ++++++--------- .../metatomic_torch/tests/test_heat_flux.py | 29 ++++--------------- 2 files changed, 14 insertions(+), 39 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/heat_flux.py b/python/metatomic_torch/metatomic/torch/heat_flux.py index 23dec8cb..1158cfa3 100644 --- a/python/metatomic_torch/metatomic/torch/heat_flux.py +++ b/python/metatomic_torch/metatomic/torch/heat_flux.py @@ -25,27 +25,26 @@ def _wrap_positions(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor def _check_close_to_cell_boundary( - cell: torch.Tensor, positions: torch.Tensor, cutoff: float, skin: float + cell: torch.Tensor, positions: torch.Tensor, cutoff: float ) -> torch.Tensor: """ Detect atoms that lie within a cutoff distance (in our context, the interaction - range of the model + the skin) from the periodic cell boundaries, + range of the model) from the periodic cell boundaries, i.e. have interactions with atoms at the opposite end of the cell. 
""" inv_cell = cell.inverse() recip = inv_cell.T norms = torch.linalg.norm(recip, dim=1) heights = 1.0 / norms - if heights.min() < (cutoff + skin): + if heights.min() < cutoff: raise ValueError( - "Cell is too small compared to (cutoff + skin) = " - + str(cutoff + skin) + "Cell is too small compared to cutoff = " + + str(cutoff) + ". " "Ensure that all cell vectors are at least this length. Currently, the" " minimum cell vector length is " + str(heights.min()) + "." ) - cutoff = cutoff + skin normals = recip / norms[:, None] norm_coords = positions @ normals.T collisions = torch.hstack( @@ -109,7 +108,7 @@ def _generate_replica_atoms( def _unfold_system( - metatomic_system: System, cutoff: float, skin: float = 0.5 + metatomic_system: System, cutoff: float ) -> System: """ Unfold a periodic system by generating replica atoms for those near the cell @@ -123,7 +122,7 @@ def _unfold_system( metatomic_system.positions, metatomic_system.cell ) collisions = _check_close_to_cell_boundary( - metatomic_system.cell, wrapped_positions, cutoff, skin + metatomic_system.cell, wrapped_positions, cutoff ) replicas = _collisions_to_replicas(collisions) replica_idx, replica_types, replica_positions = _generate_replica_atoms( @@ -215,20 +214,15 @@ class HeatFluxWrapper(torch.nn.Module): for semilocal machine-learning potentials. (2023). Physical Review B, 108, L100302.` """ - def __init__(self, model: AtomisticModel, skin: float = 0.5): + def __init__(self, model: AtomisticModel): """ :param model: the :py:class:`AtomisticModel` to wrap, which should be able to compute atomic energies and their gradients with respect to positions - :param skin: the skin parameter for unfolding the system. The wrapper will - generate replica atoms for those within (interaction_range + skin) distance from - the cell boundaries. A skin results in more replica atoms and thus higher - computational cost, but ensures that the heat flux is computed correctly. 
""" super().__init__() assert isinstance(model, AtomisticModel) self._model = model.module - self.skin = skin self._interaction_range = model.capabilities().interaction_range self._requested_neighbor_lists = model.requested_neighbor_lists() @@ -284,7 +278,7 @@ def _barycenter_and_atomic_energies(self, system: System, n_atoms: int): def _calc_unfolded_heat_flux(self, system: System) -> torch.Tensor: n_atoms = len(system.positions) - unfolded_system = _unfold_system(system, self._interaction_range, self.skin).to( + unfolded_system = _unfold_system(system, self._interaction_range).to( system.device ) compute_requested_neighbors_from_options( diff --git a/python/metatomic_torch/tests/test_heat_flux.py b/python/metatomic_torch/tests/test_heat_flux.py index 5e5bcd5c..767144ff 100644 --- a/python/metatomic_torch/tests/test_heat_flux.py +++ b/python/metatomic_torch/tests/test_heat_flux.py @@ -157,7 +157,7 @@ def test_wrap_positions_cubic_matches_expected(): def test_check_close_to_cell_boundary_cubic_axis_order(): cell = torch.eye(3) * 2.0 positions = torch.tensor([[0.1, 1.0, 1.9]]) - collisions = _check_close_to_cell_boundary(cell, positions, cutoff=0.2, skin=0.0) + collisions = _check_close_to_cell_boundary(cell, positions, cutoff=0.2) assert collisions.shape == (1, 6) assert collisions[0].tolist() == [True, False, False, False, False, True] @@ -231,7 +231,7 @@ def test_check_close_to_cell_boundary_triclinic_targets(): ) positions = target @ torch.inverse(norm_vectors).T - collisions = _check_close_to_cell_boundary(cell, positions, cutoff=cutoff, skin=0.0) + collisions = _check_close_to_cell_boundary(cell, positions, cutoff=cutoff) expected_low = target <= cutoff expected_high = target >= heights - cutoff @@ -245,26 +245,7 @@ def test_check_close_to_cell_boundary_raises_on_small_cell(): cell = torch.eye(3) * 1.0 positions = torch.zeros((1, 3)) with pytest.raises(ValueError, match="Cell is too small"): - _check_close_to_cell_boundary(cell, positions, cutoff=0.9, skin=0.2) 
- - -def test_skin_parameter_affects_collisions(): - """Increasing the skin should extend the effective detection range.""" - cell = torch.eye(3) * 2.0 - # atom at distance 0.3 from the low-x boundary - positions = torch.tensor([[0.3, 1.0, 1.0]]) - - # cutoff=0.2, skin=0.0 → effective range 0.2 < 0.3 → no collision - collisions_no_skin = _check_close_to_cell_boundary( - cell, positions, cutoff=0.2, skin=0.0 - ) - assert not collisions_no_skin.any() - - # cutoff=0.2, skin=0.2 → effective range 0.4 > 0.3 → x_lo collision - collisions_with_skin = _check_close_to_cell_boundary( - cell, positions, cutoff=0.2, skin=0.2 - ) - assert collisions_with_skin[0, 0].item() # x_lo + _check_close_to_cell_boundary(cell, positions, cutoff=1.1) def test_collisions_to_replicas_combines_displacements(): @@ -344,7 +325,7 @@ def test_unfold_system_no_replicas_for_interior_atoms(): cell = torch.eye(3) * 10.0 positions = torch.tensor([[5.0, 5.0, 5.0], [3.0, 4.0, 6.0]]) system = _make_system_with_data(positions, cell) - unfolded = _unfold_system(system, cutoff=1.0, skin=0.0) + unfolded = _unfold_system(system, cutoff=1.0) assert len(unfolded.positions) == 2 assert torch.allclose(unfolded.positions, _wrap_positions(positions, cell)) @@ -367,7 +348,7 @@ def test_unfold_system_triclinic_cell(): ] ) system = _make_system_with_data(positions, cell) - unfolded = _unfold_system(system, cutoff=0.3, skin=0.0) + unfolded = _unfold_system(system, cutoff=0.3) # The near-origin atom should generate at least one replica assert len(unfolded.positions) > 2 From b1b81da36f82feb167c9e7cc9d3dc7959777ceae Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Sat, 7 Mar 2026 01:32:12 +0100 Subject: [PATCH 19/22] Add heat flux as a standard output, also add momentum and velocity as outputs of i-pi Fix CI Try again Again Add an example on how to use the wrapper Update documentation Try to fix docs again Hopefully the last fix --- docs/src/outputs/heat_fluxes.rst | 65 +++++++++++++++++++ docs/src/outputs/index.rst | 10 
+++ docs/src/outputs/momenta.rst | 8 +++ docs/src/outputs/velocities.rst | 9 +++ metatomic-torch/src/misc.cpp | 1 + metatomic-torch/src/model.cpp | 6 ++ metatomic-torch/src/outputs.cpp | 34 ++++++++++ metatomic-torch/tests/models.cpp | 2 +- .../metatomic/torch/heat_flux.py | 60 ++++++++++++----- .../metatomic_torch/tests/test_heat_flux.py | 9 +-- 10 files changed, 183 insertions(+), 21 deletions(-) create mode 100644 docs/src/outputs/heat_fluxes.rst diff --git a/docs/src/outputs/heat_fluxes.rst b/docs/src/outputs/heat_fluxes.rst new file mode 100644 index 00000000..a45632ef --- /dev/null +++ b/docs/src/outputs/heat_fluxes.rst @@ -0,0 +1,65 @@ +.. _heat-fluxes-output: + +Heat Fluxes +^^^^^^^^^^^ + +Heat fluxes are associated with the ``"heat_flux"`` or +``"heat_flux/"`` name (see :ref:`output-variants`), and must have the +following metadata: + +.. list-table:: Metadata for heat fluxes + :widths: 2 3 7 + :header-rows: 1 + + * - Metadata + - Names + - Description + + * - keys + - ``"_"`` + - the keys must have a single dimension named ``"_"``, with a single + entry set to ``0``. Heat fluxes are always a + :py:class:`metatensor.torch.TensorMap` with a single block. + + * - samples + - ``["system"]`` + - the samples must be named ``["system"]``, since + heat fluxes are never per-atom. + + ``"system"`` must range from 0 to the number of systems given as an input + to the model. Since the heat flux is a per-system quantity, there is + exactly one sample per system. If ``selected_atoms`` is + provided, the heat flux should be computed over only the selected atoms + of each system. + + * - components + - ``"xyz"`` + - heat fluxes must have a single component dimension named + ``"xyz"``, with three entries set to ``0``, ``1``, and ``2``. The + heat fluxes are always 3D vectors, and the order of the + components is x, y, z. 
+ + + * - properties + - ``"heat_flux"`` + - heat fluxes must have a single property dimension named + ``"heat_flux"``, with a single entry set to ``0``. + +The following simulation engines can use the ``"heat_flux"`` output. + +.. grid:: 1 3 3 3 + + .. grid-item-card:: + :text-align: center + :padding: 1 + :link: engine-ase + :link-type: ref + + |ase-logo| + + .. grid-item-card:: + :text-align: center + :padding: 1 + :link: engine-ipi + :link-type: ref + + |ipi-logo| diff --git a/docs/src/outputs/index.rst b/docs/src/outputs/index.rst index 8b6ec11f..002a6cd3 100644 --- a/docs/src/outputs/index.rst +++ b/docs/src/outputs/index.rst @@ -25,6 +25,7 @@ schema they need and add a new section to these pages. momenta velocities charges + heat_fluxes features variants @@ -131,6 +132,15 @@ quantities, i.e. quantities with a well-defined physical meaning. Atomic charges, e.g. formal or partial charges on atoms + .. grid-item-card:: Heat fluxes + :link: heat-fluxes-output + :link-type: ref + + .. image:: /../static/images/charges-output.png + + Heat fluxes, i.e. the amount of energy transferred per unit time, computed as + :math:`\sum_i E_i \times \vec v_i` + Machine learning quantities ^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/src/outputs/momenta.rst b/docs/src/outputs/momenta.rst index cc74eb52..8d3a1a76 100644 --- a/docs/src/outputs/momenta.rst +++ b/docs/src/outputs/momenta.rst @@ -59,3 +59,11 @@ The following simulation engine can provide ``"momenta"`` as inputs to the model :link-type: ref |ase-logo| + + .. grid-item-card:: + :text-align: center + :padding: 1 + :link: engine-ipi + :link-type: ref + + |ipi-logo| diff --git a/docs/src/outputs/velocities.rst b/docs/src/outputs/velocities.rst index 7e0834f9..9249bef8 100644 --- a/docs/src/outputs/velocities.rst +++ b/docs/src/outputs/velocities.rst @@ -55,3 +55,12 @@ The following simulation engine can provide ``"velocities"`` as inputs to the mo :link-type: ref |ase-logo| + + .. 
grid-item-card:: + :text-align: center + :padding: 1 + :link: engine-ipi + :link-type: ref + + |ipi-logo| + \ No newline at end of file diff --git a/metatomic-torch/src/misc.cpp b/metatomic-torch/src/misc.cpp index ab4a2498..647b685e 100644 --- a/metatomic-torch/src/misc.cpp +++ b/metatomic-torch/src/misc.cpp @@ -427,6 +427,7 @@ inline std::unordered_set KNOWN_INPUTS_OUTPUTS = { "velocities", "masses", "charges", + "heat_flux", }; std::tuple details::validate_name_and_check_variant( diff --git a/metatomic-torch/src/model.cpp b/metatomic-torch/src/model.cpp index a7b6bee2..93c07b4d 100644 --- a/metatomic-torch/src/model.cpp +++ b/metatomic-torch/src/model.cpp @@ -1157,6 +1157,12 @@ static std::map KNOWN_QUANTITIES = { // alternative names {"C", "Coulomb"}, }}}, + {"heat_flux", Quantity{/* name */ "heat_flux", /* baseline */ "eV*Angstrom/fs", { + {"eV*Angstrom/fs", 1.0}, + }, { + // alternative names + {"eV*A/fs", "eV*Angstrom/fs"}, + }}} }; bool metatomic_torch::valid_quantity(const std::string& quantity) { diff --git a/metatomic-torch/src/outputs.cpp b/metatomic-torch/src/outputs.cpp index 86aa23f1..546a1d01 100644 --- a/metatomic-torch/src/outputs.cpp +++ b/metatomic-torch/src/outputs.cpp @@ -584,6 +584,38 @@ static void check_charges( validate_no_gradients("charges", charges_block); } +/// Check output metadata for heat flux. 
+static void check_heat_flux( + const TensorMap& value, + const std::vector& systems, + const ModelOutput& request +) { + // Ensure the output contains a single block with the expected key + validate_single_block("heat_flux", value); + + // Check samples values from systems + validate_atomic_samples("heat_flux", value, systems, request, torch::nullopt); + + auto tensor_options = torch::TensorOptions().device(value->device()); + auto heat_flux_block = TensorMapHolder::block_by_id(value, 0); + std::vector expected_component { + torch::make_intrusive( + "xyz", + torch::tensor({{0}, {1}, {2}}, tensor_options) + ) + }; + validate_components("heat_flux", heat_flux_block->components(), expected_component); + + auto expected_properties = torch::make_intrusive( + "heat_flux", + torch::tensor({{0}}, tensor_options) + ); + validate_properties("heat_flux", heat_flux_block, expected_properties); + + // Should not have any gradients + validate_no_gradients("heat_flux", heat_flux_block); +} + void metatomic_torch::check_outputs( const std::vector& systems, const c10::Dict& requested, @@ -654,6 +686,8 @@ void metatomic_torch::check_outputs( check_velocities(value, systems, request); } else if (base == "charges") { check_charges(value, systems, request); + } else if (base == "heat_flux") { + check_heat_flux(value, systems, request); } else if (name.find("::") != std::string::npos) { // this is a non-standard output, there is nothing to check } else { diff --git a/metatomic-torch/tests/models.cpp b/metatomic-torch/tests/models.cpp index 0ae77300..9b812a25 100644 --- a/metatomic-torch/tests/models.cpp +++ b/metatomic-torch/tests/models.cpp @@ -110,7 +110,7 @@ TEST_CASE("Models metadata") { virtual ~WarningHandler() override = default; void process(const torch::Warning& warning) override { auto expected = std::string( - "unknown quantity 'unknown', only [charge energy force " + "unknown quantity 'unknown', only [charge energy force heat_flux " "length mass momentum pressure velocity] 
are supported" ); CHECK(warning.msg() == expected); diff --git a/python/metatomic_torch/metatomic/torch/heat_flux.py b/python/metatomic_torch/metatomic/torch/heat_flux.py index 1158cfa3..5a0e1e8f 100644 --- a/python/metatomic_torch/metatomic/torch/heat_flux.py +++ b/python/metatomic_torch/metatomic/torch/heat_flux.py @@ -38,9 +38,7 @@ def _check_close_to_cell_boundary( heights = 1.0 / norms if heights.min() < cutoff: raise ValueError( - "Cell is too small compared to cutoff = " - + str(cutoff) - + ". " + "Cell is too small compared to cutoff = " + str(cutoff) + ". " "Ensure that all cell vectors are at least this length. Currently, the" " minimum cell vector length is " + str(heights.min()) + "." ) @@ -107,9 +105,7 @@ def _generate_replica_atoms( return replica_idx, types[replica_idx], replica_positions -def _unfold_system( - metatomic_system: System, cutoff: float -) -> System: +def _unfold_system(metatomic_system: System, cutoff: float) -> System: """ Unfold a periodic system by generating replica atoms for those near the cell boundaries within the specified cutoff distance. @@ -202,16 +198,45 @@ def _unfold_system( class HeatFluxWrapper(torch.nn.Module): """ - A wrapper around an AtomisticModel that computes the heat flux of a system using the - unfolded system approach. The heat flux is computed using the atomic energies (eV), - positions(Å), masses(u), velocities(Å/fs), and the energy gradients. + :py:class:`HeatFluxWrapper` is a wrapper around an :py:class:`AtomisticModel` that + computes the heat flux of a system using the unfolded system approach. The heat flux + is computed using the atomic energies (eV), positions(Å), masses(u), + velocities(Å/fs), and the energy gradients. The unfolded system is generated by creating replica atoms for those near the cell boundaries within the interaction range of the model wrapped. The wrapper adds the - heat flux to the model's outputs under the key "extra::heat_flux". 
+ heat flux to the model's outputs under the key "heat_flux". For more details on the heat flux calculation, see `Langer, M. F., et al., Heat flux for semilocal machine-learning potentials. (2023). Physical Review B, 108, L100302.` + + To use this wrapper, here's a helper code snippet: + + >>> import torch + >>> def create_heat_flux_model(model: AtomisticModel) -> AtomisticModel: + ... metadata = ModelMetadata() # your model's metadata here + ... wrapper = torch.jit.script(HeatFluxWrapper(model.eval())) + ... capabilities = model.capabilities() + ... outputs = capabilities.outputs.copy() + ... outputs["heat_flux"] = ModelOutput( + ... quantity="heat_flux", + ... unit="", + ... explicit_gradients=[], + ... per_atom=False, + ... ) + ... new_cap = ModelCapabilities( + ... outputs=outputs, + ... atomic_types=capabilities.atomic_types, + ... interaction_range=capabilities.interaction_range, + ... length_unit=capabilities.length_unit, + ... supported_devices=capabilities.supported_devices, + ... dtype=capabilities.dtype, + ... ) + ... heat_model = AtomisticModel( + ... wrapper.eval(), metadata, capabilities=new_cap + ... ).to(device="cpu") + ... 
return heat_model + """ def __init__(self, model: AtomisticModel): @@ -345,17 +370,20 @@ def forward( selected_atoms: Optional[Labels], ) -> Dict[str, TensorMap]: outputs_wo_heat_flux = outputs.copy() - if "extra::heat_flux" in outputs: - del outputs_wo_heat_flux["extra::heat_flux"] - results = self._model(systems, outputs_wo_heat_flux, selected_atoms) + if "heat_flux" in outputs: + del outputs_wo_heat_flux["heat_flux"] + + if len(outputs_wo_heat_flux) == 0: + results = torch.jit.annotate(Dict[str, TensorMap], {}) + else: + results = self._model(systems, outputs_wo_heat_flux, selected_atoms) - if "extra::heat_flux" not in outputs: + if "heat_flux" not in outputs: return results device = systems[0].device heat_fluxes: List[torch.Tensor] = [] for system in systems: - system.positions.requires_grad_(True) heat_fluxes.append(self._calc_unfolded_heat_flux(system)) samples = Labels( @@ -368,7 +396,7 @@ def forward( components=[Labels(["xyz"], torch.arange(3, device=device).reshape(-1, 1))], properties=Labels(["heat_flux"], torch.tensor([[0]], device=device)), ) - results["extra::heat_flux"] = TensorMap( + results["heat_flux"] = TensorMap( Labels("_", torch.tensor([[0]], device=device)), [hf_block] ) return results diff --git a/python/metatomic_torch/tests/test_heat_flux.py b/python/metatomic_torch/tests/test_heat_flux.py index 767144ff..235dddb7 100644 --- a/python/metatomic_torch/tests/test_heat_flux.py +++ b/python/metatomic_torch/tests/test_heat_flux.py @@ -382,7 +382,7 @@ def test_heat_flux_wrapper_calc_heat_flux(model, atoms, use_script): wrapper = HeatFluxWrapper(model.eval()) cap = model.capabilities() outputs = cap.outputs.copy() - outputs["extra::heat_flux"] = ModelOutput( + outputs["heat_flux"] = ModelOutput( quantity="heat_flux", unit="", explicit_gradients=[], @@ -408,18 +408,19 @@ def test_heat_flux_wrapper_calc_heat_flux(model, atoms, use_script): heat_model, device="cpu", additional_outputs={ - "extra::heat_flux": ModelOutput( + "heat_flux": ModelOutput( 
quantity="heat_flux", unit="", explicit_gradients=[], per_atom=False, ) }, + check_consistency=True, ) atoms.calc = calc atoms.get_potential_energy() - assert "extra::heat_flux" in atoms.calc.additional_outputs - results = atoms.calc.additional_outputs["extra::heat_flux"].block().values + assert "heat_flux" in atoms.calc.additional_outputs + results = atoms.calc.additional_outputs["heat_flux"].block().values assert torch.allclose( results, torch.tensor(expected, dtype=results.dtype), From 881a2370d8bcd5754b8547786433016411184b88 Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Mon, 9 Mar 2026 22:56:57 +0100 Subject: [PATCH 20/22] Support energy variant --- .../metatomic/torch/heat_flux.py | 68 ++++++++++++++----- .../metatomic_torch/tests/test_heat_flux.py | 25 ++++--- 2 files changed, 69 insertions(+), 24 deletions(-) diff --git a/python/metatomic_torch/metatomic/torch/heat_flux.py b/python/metatomic_torch/metatomic/torch/heat_flux.py index 5a0e1e8f..b5bed20c 100644 --- a/python/metatomic_torch/metatomic/torch/heat_flux.py +++ b/python/metatomic_torch/metatomic/torch/heat_flux.py @@ -10,6 +10,7 @@ ModelOutput, NeighborListOptions, System, + pick_output, ) @@ -218,7 +219,7 @@ class HeatFluxWrapper(torch.nn.Module): ... wrapper = torch.jit.script(HeatFluxWrapper(model.eval())) ... capabilities = model.capabilities() ... outputs = capabilities.outputs.copy() - ... outputs["heat_flux"] = ModelOutput( + ... outputs[wrapper._hf_variant] = ModelOutput( ... quantity="heat_flux", ... unit="", ... explicit_gradients=[], @@ -239,10 +240,16 @@ class HeatFluxWrapper(torch.nn.Module): """ - def __init__(self, model: AtomisticModel): + def __init__( + self, model: AtomisticModel, variants: Optional[Dict[str, Optional[str]]] = None + ): """ :param model: the :py:class:`AtomisticModel` to wrap, which should be able to compute atomic energies and their gradients with respect to positions + :param variants: a dictionary of variants to use for each output, e.g. 
+ ``{"energy": "pbe"}``, in which case the "pbe" energy output is used to compute + the heat flux. Defaults to ``None``, in which case the default energy output is + used to compute the heat flux. """ super().__init__() @@ -256,24 +263,53 @@ def __init__(self, model: AtomisticModel): "velocities": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True), } + variants = variants or {} + default_variant = variants.get("energy") + + resolved_variants = { + key: variants.get(key, default_variant) + for key in [ + "energy", + "energy_uncertainty", + "non_conservative_forces", + "non_conservative_stress", + ] + } + + outputs = model.capabilities().outputs.copy() + has_energy = any( + "energy" == key or key.startswith("energy/") for key in outputs.keys() + ) + if has_energy: + self._energy_key = pick_output( + "energy", outputs, resolved_variants["energy"] + ) + else: + raise ValueError( + "The wrapped model must be able to compute energy outputs to use " + "HeatFluxWrapper." + ) + if outputs[self._energy_key].unit != "eV": + raise ValueError( + "HeatFluxWrapper can only be used with energy outputs in eV" + ) + energies_output = ModelOutput( + quantity="energy", unit=outputs[self._energy_key].unit, per_atom=True + ) + hf_output = ModelOutput( quantity="heat_flux", unit="", explicit_gradients=[], per_atom=False, ) - outputs = model.capabilities().outputs.copy() - outputs["extra::heat_flux"] = hf_output - if outputs["energy"].unit != "eV": - raise ValueError( - "HeatFluxWrapper can only be used with energy outputs in eV" - ) - energies_output = ModelOutput( - quantity="energy", unit=outputs["energy"].unit, per_atom=True + self._hf_variant = "heat_flux" + ( + "" if default_variant is None else "/" + default_variant ) + outputs[self._hf_variant] = hf_output self._unfolded_run_options = ModelEvaluationOptions( length_unit=model.capabilities().length_unit, - outputs={"energy": energies_output}, + outputs={self._energy_key: energies_output}, selected_atoms=None, ) @@ -290,7 
+326,7 @@ def _barycenter_and_atomic_energies(self, system: System, n_atoms: int): [system], self._unfolded_run_options.outputs, self._unfolded_run_options.selected_atoms, - )["energy"].block(0) + )[self._energy_key].block(0) atom_indices = energy_block.samples.column("atom").to(torch.long) sorted_order = torch.argsort(atom_indices) atomic_e = energy_block.values.flatten()[sorted_order] @@ -370,15 +406,15 @@ def forward( selected_atoms: Optional[Labels], ) -> Dict[str, TensorMap]: outputs_wo_heat_flux = outputs.copy() - if "heat_flux" in outputs: - del outputs_wo_heat_flux["heat_flux"] + if self._hf_variant in outputs: + del outputs_wo_heat_flux[self._hf_variant] if len(outputs_wo_heat_flux) == 0: results = torch.jit.annotate(Dict[str, TensorMap], {}) else: results = self._model(systems, outputs_wo_heat_flux, selected_atoms) - if "heat_flux" not in outputs: + if self._hf_variant not in outputs: return results device = systems[0].device @@ -396,7 +432,7 @@ def forward( components=[Labels(["xyz"], torch.arange(3, device=device).reshape(-1, 1))], properties=Labels(["heat_flux"], torch.tensor([[0]], device=device)), ) - results["heat_flux"] = TensorMap( + results[self._hf_variant] = TensorMap( Labels("_", torch.tensor([[0]], device=device)), [hf_block] ) return results diff --git a/python/metatomic_torch/tests/test_heat_flux.py b/python/metatomic_torch/tests/test_heat_flux.py index 235dddb7..5fc9ee70 100644 --- a/python/metatomic_torch/tests/test_heat_flux.py +++ b/python/metatomic_torch/tests/test_heat_flux.py @@ -375,14 +375,23 @@ def test_heat_flux_wrapper_requested_inputs(model): @pytest.mark.parametrize("use_script", [True, False]) -def test_heat_flux_wrapper_calc_heat_flux(model, atoms, use_script): - expected = [[8.8238e-05], [-2.5559e-04], [-2.0570e-04]] - +@pytest.mark.parametrize( + "use_variant, expected", + [ + (True, [[9.0147e-05], [-2.6166e-04], [-1.9002e-04]]), + (False, [[8.8238e-05], [-2.5559e-04], [-2.0570e-04]]), + ], +) +def 
test_heat_flux_wrapper_calc_heat_flux( + model, atoms, expected, use_script, use_variant +): metadata = ModelMetadata() - wrapper = HeatFluxWrapper(model.eval()) + wrapper = HeatFluxWrapper( + model.eval(), variants=({"energy": "doubled"} if use_variant else None) + ) cap = model.capabilities() outputs = cap.outputs.copy() - outputs["heat_flux"] = ModelOutput( + outputs[wrapper._hf_variant] = ModelOutput( quantity="heat_flux", unit="", explicit_gradients=[], @@ -408,7 +417,7 @@ def test_heat_flux_wrapper_calc_heat_flux(model, atoms, use_script): heat_model, device="cpu", additional_outputs={ - "heat_flux": ModelOutput( + wrapper._hf_variant: ModelOutput( quantity="heat_flux", unit="", explicit_gradients=[], @@ -419,8 +428,8 @@ def test_heat_flux_wrapper_calc_heat_flux(model, atoms, use_script): ) atoms.calc = calc atoms.get_potential_energy() - assert "heat_flux" in atoms.calc.additional_outputs - results = atoms.calc.additional_outputs["heat_flux"].block().values + assert wrapper._hf_variant in atoms.calc.additional_outputs + results = atoms.calc.additional_outputs[wrapper._hf_variant].block().values assert torch.allclose( results, torch.tensor(expected, dtype=results.dtype), From da968150d80315394836b03b06bbee2936217912 Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Mon, 9 Mar 2026 23:43:51 +0100 Subject: [PATCH 21/22] Add a test for wrapper asking for the same quantity as the inner model but with different unit --- .../metatomic_torch/metatomic/torch/model.py | 14 +++++ .../metatomic_torch/tests/ase_calculator.py | 59 ++++++++++++++++++- 2 files changed, 72 insertions(+), 1 deletion(-) diff --git a/python/metatomic_torch/metatomic/torch/model.py b/python/metatomic_torch/metatomic/torch/model.py index c2a9f528..d980d400 100644 --- a/python/metatomic_torch/metatomic/torch/model.py +++ b/python/metatomic_torch/metatomic/torch/model.py @@ -666,7 +666,21 @@ def _get_requested_inputs( already_requested = False for existing in requested: if existing == new_options: + 
if ( + requested[existing].quantity + == requested_inputs[new_options].quantity + and requested[existing].unit + != requested_inputs[new_options].unit + and requested[existing].per_atom + == requested_inputs[new_options].per_atom + ): + raise NotImplementedError( + f"Different units for the same quantity " + f"`{requested_inputs[new_options].quantity}` is not " + "supported." + ) already_requested = True + print(f"{new_options = }") if not already_requested: requested[new_options] = requested_inputs[new_options] diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index b8e8da2a..6f304ed7 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -892,7 +892,7 @@ def forward( class AdditionalInputModel(torch.nn.Module): - def __init__(self, inputs): + def __init__(self, inputs: Dict[str, ModelOutput]): super().__init__() self._requested_inputs = inputs @@ -911,6 +911,32 @@ def forward( } +class SimpleWrapperModel(torch.nn.Module): + def __init__(self, model: AtomisticModel, inputs: Dict[str, ModelOutput]): + super().__init__() + self._model = model.module + self._requested_inputs = inputs + self._capabilities = model.capabilities() + + def requested_inputs(self) -> Dict[str, ModelOutput]: + return self._requested_inputs + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + results = self._model(systems, outputs, selected_atoms) + results.update( + { + ("extra::" + input): systems[0].get_data(input) + for input in self._requested_inputs + } + ) + return results + + def test_additional_input(atoms): inputs = { "masses": ModelOutput(quantity="mass", unit="u", per_atom=True), @@ -951,6 +977,37 @@ def test_additional_input(atoms): assert np.allclose(values, expected) +def test_wrapper_asks_for_inputs_with_different_units(atoms): + inputs = { + "masses": 
ModelOutput(quantity="mass", unit="u", per_atom=True), + "velocities": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True), + "charges": ModelOutput(quantity="charge", unit="e", per_atom=True), + "ase::initial_charges": ModelOutput(quantity="charge", unit="e", per_atom=True), + } + outputs = {("extra::" + n): inputs[n] for n in inputs} + capabilities = ModelCapabilities( + outputs=outputs, + atomic_types=[28], + interaction_range=0.0, + supported_devices=["cpu"], + dtype="float64", + ) + + model = AtomisticModel( + AdditionalInputModel(inputs).eval(), ModelMetadata(), capabilities + ) + + inputs_wrapper = { + "masses": ModelOutput(quantity="mass", unit="kg", per_atom=True), + } + wrapper = SimpleWrapperModel(model, inputs_wrapper) + with pytest.raises( + NotImplementedError, + match="Different units for the same quantity `mass` is not supported.", + ): + AtomisticModel(wrapper.eval(), ModelMetadata(), capabilities) + + @pytest.mark.parametrize("device,dtype", ALL_DEVICE_DTYPE) def test_mixed_pbc(model, device, dtype): """Test that the calculator works on a mixed-PBC system""" From a63f4a4d6a4e83f3e273fe730e57a2463f549e9e Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Mon, 9 Mar 2026 23:47:07 +0100 Subject: [PATCH 22/22] Check the length unit of inner model --- python/metatomic_torch/metatomic/torch/heat_flux.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/metatomic_torch/metatomic/torch/heat_flux.py b/python/metatomic_torch/metatomic/torch/heat_flux.py index b5bed20c..02dd2f9a 100644 --- a/python/metatomic_torch/metatomic/torch/heat_flux.py +++ b/python/metatomic_torch/metatomic/torch/heat_flux.py @@ -256,6 +256,10 @@ def __init__( assert isinstance(model, AtomisticModel) self._model = model.module self._interaction_range = model.capabilities().interaction_range + if model.capabilities().length_unit not in ["Angstrom", "A"]: + raise NotImplementedError( + "HeatFluxWrapper only supports models with length unit 'Angstrom'" + ) 
self._requested_neighbor_lists = model.requested_neighbor_lists() self._requested_inputs = {