diff --git a/docs/src/outputs/heat_fluxes.rst b/docs/src/outputs/heat_fluxes.rst new file mode 100644 index 00000000..a45632ef --- /dev/null +++ b/docs/src/outputs/heat_fluxes.rst @@ -0,0 +1,65 @@ +.. _heat-fluxes-output: + +Heat Fluxes +^^^^^^^^^^^ + +Heat fluxes are associated with the ``"heat_flux"`` or +``"heat_flux/"`` name (see :ref:`output-variants`), and must have the +following metadata: + +.. list-table:: Metadata for heat fluxes + :widths: 2 3 7 + :header-rows: 1 + + * - Metadata + - Names + - Description + + * - keys + - ``"_"`` + - the keys must have a single dimension named ``"_"``, with a single + entry set to ``0``. Heat fluxes are always a + :py:class:`metatensor.torch.TensorMap` with a single block. + + * - samples + - ``["system"]`` + - the samples must be named ``["system"]``, since + heat fluxes are always not per-atom. + + ``"system"`` must range from 0 to the number of systems given as an input + to the model. ``"atom"`` must range between 0 and the number of + atoms/particles in the corresponding system. If ``selected_atoms`` is + provided, then only the selected atoms for each system should be part of + the samples. + + * - components + - ``"xyz"`` + - heat fluxes must have a single component dimension named + ``"xyz"``, with three entries set to ``0``, ``1``, and ``2``. The + heat fluxes are always 3D vectors, and the order of the + components is x, y, z. + + * - properties + - ``"heat_flux"`` + - heat fluxes must have a single property dimension named + ``"heat_flux"``, with a single entry set to ``0``. + +The following simulation engine can use the ``"heat_flux"`` output. + +.. grid:: 1 3 3 3 + + .. grid-item-card:: + :text-align: center + :padding: 1 + :link: engine-ase + :link-type: ref + + |ase-logo| + + .. grid-item-card:: + :text-align: center + :padding: 1 + :link: engine-ipi + :link-type: ref + + |ipi-logo| diff --git a/docs/src/outputs/index.rst b/docs/src/outputs/index.rst index 8b6ec11f..002a6cd3 100644 --- a/docs/src/outputs/index.rst +++ b/docs/src/outputs/index.rst @@ -25,6 +25,7 @@ schema they need and add a new section to these pages. momenta velocities charges + heat_fluxes features variants @@ -131,6 +132,15 @@ quantities, i.e. quantities with a well-defined physical meaning. Atomic charges, e.g. formal or partial charges on atoms + .. grid-item-card:: Heat fluxes + :link: heat-fluxes-output + :link-type: ref + + .. image:: /../static/images/charges-output.png + + Heat fluxes, i.e. the amount of energy transferred per unit time, i.e. + :math:`\sum_i E_i \times \vec v_i` + Machine learning quantities ^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/src/outputs/momenta.rst b/docs/src/outputs/momenta.rst index cc74eb52..8d3a1a76 100644 --- a/docs/src/outputs/momenta.rst +++ b/docs/src/outputs/momenta.rst @@ -59,3 +59,11 @@ The following simulation engine can provide ``"momenta"`` as inputs to the model :link-type: ref |ase-logo| + + .. grid-item-card:: + :text-align: center + :padding: 1 + :link: engine-ipi + :link-type: ref + + |ipi-logo| diff --git a/docs/src/outputs/velocities.rst b/docs/src/outputs/velocities.rst index 7e0834f9..9249bef8 100644 --- a/docs/src/outputs/velocities.rst +++ b/docs/src/outputs/velocities.rst @@ -55,3 +55,12 @@ The following simulation engine can provide ``"velocities"`` as inputs to the mo :link-type: ref |ase-logo| + + .. grid-item-card:: + :text-align: center + :padding: 1 + :link: engine-ipi + :link-type: ref + + |ipi-logo| + \ No newline at end of file diff --git a/metatomic-torch/src/misc.cpp b/metatomic-torch/src/misc.cpp index ab4a2498..647b685e 100644 --- a/metatomic-torch/src/misc.cpp +++ b/metatomic-torch/src/misc.cpp @@ -427,6 +427,7 @@ inline std::unordered_set KNOWN_INPUTS_OUTPUTS = { "velocities", "masses", "charges", + "heat_flux", }; std::tuple details::validate_name_and_check_variant( diff --git a/metatomic-torch/src/model.cpp b/metatomic-torch/src/model.cpp index a7b6bee2..93c07b4d 100644 --- a/metatomic-torch/src/model.cpp +++ b/metatomic-torch/src/model.cpp @@ -1157,6 +1157,12 @@ static std::map KNOWN_QUANTITIES = { // alternative names {"C", "Coulomb"}, }}}, + {"heat_flux", Quantity{/* name */ "heat_flux", /* baseline */ "eV*Angstrom/fs", { + {"eV*Angstrom/fs", 1.0}, + }, { + // alternative names + {"eV*A/fs", "eV*Angstrom/fs"}, + }}} }; bool metatomic_torch::valid_quantity(const std::string& quantity) { diff --git a/metatomic-torch/src/outputs.cpp b/metatomic-torch/src/outputs.cpp index 86aa23f1..546a1d01 100644 --- a/metatomic-torch/src/outputs.cpp +++ b/metatomic-torch/src/outputs.cpp @@ -584,6 +584,38 @@ static void check_charges( validate_no_gradients("charges", charges_block); } +/// Check output metadata for heat flux. +static void check_heat_flux( + const TensorMap& value, + const std::vector& systems, + const ModelOutput& request +) { + // Ensure the output contains a single block with the expected key + validate_single_block("heat_flux", value); + + // Check samples values from systems + validate_atomic_samples("heat_flux", value, systems, request, torch::nullopt); + + auto tensor_options = torch::TensorOptions().device(value->device()); + auto heat_flux_block = TensorMapHolder::block_by_id(value, 0); + std::vector expected_component { + torch::make_intrusive( + "xyz", + torch::tensor({{0}, {1}, {2}}, tensor_options) + ) + }; + validate_components("heat_flux", heat_flux_block->components(), expected_component); + + auto expected_properties = torch::make_intrusive( + "heat_flux", + torch::tensor({{0}}, tensor_options) + ); + validate_properties("heat_flux", heat_flux_block, expected_properties); + + // Should not have any gradients + validate_no_gradients("heat_flux", heat_flux_block); +} + void metatomic_torch::check_outputs( const std::vector& systems, const c10::Dict& requested, @@ -654,6 +686,8 @@ void metatomic_torch::check_outputs( check_velocities(value, systems, request); } else if (base == "charges") { check_charges(value, systems, request); + } else if (base == "heat_flux") { + check_heat_flux(value, systems, request); } else if (name.find("::") != std::string::npos) { // this is a non-standard output, there is nothing to check } else { diff --git a/metatomic-torch/tests/models.cpp b/metatomic-torch/tests/models.cpp index 0ae77300..9b812a25 100644 --- a/metatomic-torch/tests/models.cpp +++ b/metatomic-torch/tests/models.cpp @@ -110,7 +110,7 @@ TEST_CASE("Models metadata") { virtual ~WarningHandler() override = default; void process(const torch::Warning& warning) override { auto expected = std::string( - "unknown quantity 'unknown', only [charge energy force " + "unknown quantity 'unknown', only [charge energy force heat_flux " "length mass momentum pressure velocity] are supported" ); CHECK(warning.msg() == expected); diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index a4919d65..96b85f51 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -935,8 +935,7 @@ def _get_ase_input( tensor.set_info("quantity", infos["quantity"]) tensor.set_info("unit", infos["unit"]) - tensor = tensor.to(dtype=dtype, device=device) - return tensor + return tensor.to(dtype=dtype, device=device) def _ase_to_torch_data(atoms, dtype, device): diff --git a/python/metatomic_torch/metatomic/torch/heat_flux.py b/python/metatomic_torch/metatomic/torch/heat_flux.py new file mode 100644 index 00000000..02dd2f9a --- /dev/null +++ b/python/metatomic_torch/metatomic/torch/heat_flux.py @@ -0,0 +1,442 @@ +from typing import Dict, List, Optional + +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from vesin.metatomic import compute_requested_neighbors_from_options + +from metatomic.torch import ( + AtomisticModel, + ModelEvaluationOptions, + ModelOutput, + NeighborListOptions, + System, + pick_output, +) + + +def _wrap_positions(positions: torch.Tensor, cell: torch.Tensor) -> torch.Tensor: + """ + Wrap positions into the periodic cell. + """ + fractional_positions = positions @ cell.inverse() + fractional_positions = fractional_positions - torch.floor(fractional_positions) + wrapped_positions = fractional_positions @ cell + + return wrapped_positions + + +def _check_close_to_cell_boundary( + cell: torch.Tensor, positions: torch.Tensor, cutoff: float +) -> torch.Tensor: + """ + Detect atoms that lie within a cutoff distance (in our context, the interaction + range of the model) from the periodic cell boundaries, + i.e. have interactions with atoms at the opposite end of the cell. + """ + inv_cell = cell.inverse() + recip = inv_cell.T + norms = torch.linalg.norm(recip, dim=1) + heights = 1.0 / norms + if heights.min() < cutoff: + raise ValueError( + "Cell is too small compared to cutoff = " + str(cutoff) + ". " + "Ensure that all cell vectors are at least this length. Currently, the" + " minimum cell vector length is " + str(heights.min()) + "." + ) + + normals = recip / norms[:, None] + norm_coords = positions @ normals.T + collisions = torch.hstack( + [norm_coords <= cutoff, norm_coords >= heights - cutoff], + ).to(device=positions.device) + + return collisions[ + :, [0, 3, 1, 4, 2, 5] # reorder to (x_lo, x_hi, y_lo, y_hi, z_lo, z_hi) + ] + + +def _collisions_to_replicas(collisions: torch.Tensor) -> torch.Tensor: + """ + Convert boundary-collision flags into a boolean mask over all periodic image + displacements in {0, +1, -1}^3. e.g. for an atom colliding with the x_lo and y_hi + boundaries, we need the replicas at (1, 0, 0), (0, -1, 0), (1, -1, 0) image cells. + + collisions: [N, 6]: has collisions with (x_lo, x_hi, y_lo, y_hi, z_lo, z_hi) + + returns: [N, 3, 3, 3] boolean mask over image displacements in {0, +1, -1}^3 + 0: no replica needed along that axis + 1: +1 replica needed along that axis (i.e., near low boundary, a replica is + placed just outside the high boundary) + 2: -1 replica needed along that axis (i.e., near high boundary, a replica is + placed just outside the low boundary) + axis order: x, y, z + """ + origin = torch.full( + (len(collisions),), True, dtype=torch.bool, device=collisions.device + ) + axs = torch.vstack([origin, collisions[:, 0], collisions[:, 1]]) + ays = torch.vstack([origin, collisions[:, 2], collisions[:, 3]]) + azs = torch.vstack([origin, collisions[:, 4], collisions[:, 5]]) + # leverage broadcasting + outs = axs[:, None, None] & ays[None, :, None] & azs[None, None, :] + outs = torch.movedim(outs, -1, 0) + outs[:, 0, 0, 0] = False # not close to any boundary -> no replica needed + return outs.to(device=collisions.device) + + +def _generate_replica_atoms( + types: torch.Tensor, + positions: torch.Tensor, + cell: torch.Tensor, + replicas: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + For atoms near the low boundary (x_lo/y_lo/z_lo), generate their images shifted + by +1 cell vector (i.e., placed just outside the high boundary). + For atoms near the high boundary (x_hi/y_hi/z_hi), generate images shifted by −1 + cell vector. + """ + replicas = torch.argwhere(replicas) + replica_idx = replicas[:, 0] + replica_offsets = torch.tensor( + [0, 1, -1], device=positions.device, dtype=positions.dtype + )[replicas[:, 1:]] + replica_positions = positions[replica_idx] + replica_offsets @ cell + + return replica_idx, types[replica_idx], replica_positions + + +def _unfold_system(metatomic_system: System, cutoff: float) -> System: + """ + Unfold a periodic system by generating replica atoms for those near the cell + boundaries within the specified cutoff distance. + The unfolded system has no periodic boundary conditions. + """ + + if not metatomic_system.pbc.any(): + raise ValueError("Unfolding systems is only supported for periodic systems.") + wrapped_positions = _wrap_positions( + metatomic_system.positions, metatomic_system.cell + ) + collisions = _check_close_to_cell_boundary( + metatomic_system.cell, wrapped_positions, cutoff + ) + replicas = _collisions_to_replicas(collisions) + replica_idx, replica_types, replica_positions = _generate_replica_atoms( + metatomic_system.types, wrapped_positions, metatomic_system.cell, replicas + ) + unfolded_types = torch.cat( + [ + metatomic_system.types, + replica_types, + ] + ) + unfolded_positions = torch.cat( + [ + wrapped_positions, + replica_positions, + ] + ) + unfolded_idx = torch.cat( + [ + torch.arange(len(metatomic_system.types), device=metatomic_system.device), + replica_idx, + ] + ) + unfolded_n_atoms = len(unfolded_types) + masses_block = metatomic_system.get_data("masses").block() + velocities_block = metatomic_system.get_data("velocities").block() + unfolded_masses = masses_block.values[unfolded_idx] + unfolded_velocities = velocities_block.values[unfolded_idx] + unfolded_masses_block = TensorBlock( + values=unfolded_masses, + samples=Labels( + ["atoms"], + torch.arange(unfolded_n_atoms, device=metatomic_system.device).reshape( + -1, 1 + ), + ), + components=masses_block.components, + properties=masses_block.properties, + ) + unfolded_velocities_block = TensorBlock( + values=unfolded_velocities, + samples=Labels( + ["atoms"], + torch.arange(unfolded_n_atoms, device=metatomic_system.device).reshape( + -1, 1 + ), + ), + components=velocities_block.components, + properties=velocities_block.properties, + ) + unfolded_system = System( + types=unfolded_types, + positions=unfolded_positions, + cell=torch.tensor( + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + dtype=unfolded_positions.dtype, + device=metatomic_system.device, + ), + pbc=torch.tensor([False, False, False], device=metatomic_system.device), + ) + unfolded_system.add_data( + "masses", + TensorMap( + Labels("_", torch.tensor([[0]], device=metatomic_system.device)), + [unfolded_masses_block], + ), + ) + unfolded_system.add_data( + "velocities", + TensorMap( + Labels("_", torch.tensor([[0]], device=metatomic_system.device)), + [unfolded_velocities_block], + ), + ) + return unfolded_system.to(metatomic_system.dtype, metatomic_system.device) + + +class HeatFluxWrapper(torch.nn.Module): + """ + :py:class:`HeatFluxWrapper` is a wrapper around an :py:class:`AtomisticModel` that + computes the heat flux of a system using the unfolded system approach. The heat flux + is computed using the atomic energies (eV), positions(Å), masses(u), + velocities(Å/fs), and the energy gradients. + + The unfolded system is generated by creating replica atoms for those near the cell + boundaries within the interaction range of the model wrapped. The wrapper adds the + heat flux to the model's outputs under the key "heat_flux". + + For more details on the heat flux calculation, see `Langer, M. F., et al., Heat flux + for semilocal machine-learning potentials. (2023). Physical Review B, 108, L100302.` + + To use this wrapper, here's a helper code snippet: + + >>> import torch + >>> def create_heat_flux_model(model: AtomisticModel) -> AtomisticModel: + ... metadata = ModelMetadata() # your model's metadata here + ... wrapper = torch.jit.script(HeatFluxWrapper(model.eval())) + ... capabilities = model.capabilities() + ... outputs = capabilities.outputs.copy() + ... outputs[wrapper._hf_variant] = ModelOutput( + ... quantity="heat_flux", + ... unit="", + ... explicit_gradients=[], + ... per_atom=False, + ... ) + ... new_cap = ModelCapabilities( + ... outputs=outputs, + ... atomic_types=capabilities.atomic_types, + ... interaction_range=capabilities.interaction_range, + ... length_unit=capabilities.length_unit, + ... supported_devices=capabilities.supported_devices, + ... dtype=capabilities.dtype, + ... ) + ... heat_model = AtomisticModel( + ... wrapper.eval(), metadata, capabilities=new_cap + ... ).to(device="cpu") + ... return heat_model + + """ + + def __init__( + self, model: AtomisticModel, variants: Optional[Dict[str, Optional[str]]] = None + ): + """ + :param model: the :py:class:`AtomisticModel` to wrap, which should be able to + compute atomic energies and their gradients with respect to positions + :param variants: a dictionary of variants to use for each output, e.g. + ``{"energy": "pbe"}``, in which case the "pbe" energy output is used to compute + the heat flux. Defaults to ``None``, in which case the default energy output is + used to compute the heat flux. + """ + super().__init__() + + assert isinstance(model, AtomisticModel) + self._model = model.module + self._interaction_range = model.capabilities().interaction_range + if model.capabilities().length_unit not in ["Angstrom", "A"]: + raise NotImplementedError( + "HeatFluxWrapper only supports models with length unit 'Angstrom'" + ) + + self._requested_neighbor_lists = model.requested_neighbor_lists() + self._requested_inputs = { + "masses": ModelOutput(quantity="mass", unit="u", per_atom=True), + "velocities": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True), + } + + variants = variants or {} + default_variant = variants.get("energy") + + resolved_variants = { + key: variants.get(key, default_variant) + for key in [ + "energy", + "energy_uncertainty", + "non_conservative_forces", + "non_conservative_stress", + ] + } + + outputs = model.capabilities().outputs.copy() + has_energy = any( + "energy" == key or key.startswith("energy/") for key in outputs.keys() + ) + if has_energy: + self._energy_key = pick_output( + "energy", outputs, resolved_variants["energy"] + ) + else: + raise ValueError( + "The wrapped model must be able to compute energy outputs to use " + "HeatFluxWrapper." + ) + if outputs[self._energy_key].unit != "eV": + raise ValueError( + "HeatFluxWrapper can only be used with energy outputs in eV" + ) + energies_output = ModelOutput( + quantity="energy", unit=outputs[self._energy_key].unit, per_atom=True + ) + + hf_output = ModelOutput( + quantity="heat_flux", + unit="", + explicit_gradients=[], + per_atom=False, + ) + self._hf_variant = "heat_flux" + ( + "" if default_variant is None else "/" + default_variant + ) + outputs[self._hf_variant] = hf_output + self._unfolded_run_options = ModelEvaluationOptions( + length_unit=model.capabilities().length_unit, + outputs={self._energy_key: energies_output}, + selected_atoms=None, + ) + + @torch.jit.export + def requested_neighbor_lists(self) -> List[NeighborListOptions]: + return self._requested_neighbor_lists + + @torch.jit.export + def requested_inputs(self) -> Dict[str, ModelOutput]: + return self._requested_inputs + + def _barycenter_and_atomic_energies(self, system: System, n_atoms: int): + energy_block = self._model( + [system], + self._unfolded_run_options.outputs, + self._unfolded_run_options.selected_atoms, + )[self._energy_key].block(0) + atom_indices = energy_block.samples.column("atom").to(torch.long) + sorted_order = torch.argsort(atom_indices) + atomic_e = energy_block.values.flatten()[sorted_order] + + total_e = atomic_e[:n_atoms].sum() + r_aux = system.positions.detach() + barycenter = (atomic_e[:n_atoms, None] * r_aux[:n_atoms]).sum(dim=0) + + return barycenter, atomic_e, total_e + + def _calc_unfolded_heat_flux(self, system: System) -> torch.Tensor: + n_atoms = len(system.positions) + unfolded_system = _unfold_system(system, self._interaction_range).to( + system.device + ) + compute_requested_neighbors_from_options( + [unfolded_system], + self.requested_neighbor_lists(), + self._unfolded_run_options.length_unit, + False, + ) + velocities: torch.Tensor = ( + unfolded_system.get_data("velocities").block().values.reshape(-1, 3) + ) + masses: torch.Tensor = ( + unfolded_system.get_data("masses").block().values.reshape(-1) + ) + barycenter, atomic_e, total_e = self._barycenter_and_atomic_energies( + unfolded_system, n_atoms + ) + + term1 = torch.zeros( + (3), device=system.positions.device, dtype=system.positions.dtype + ) + for i in range(3): + grad_i = torch.autograd.grad( + [barycenter[i]], + [unfolded_system.positions], + retain_graph=True, + create_graph=False, + )[0] + grad_i = torch.jit._unwrap_optional(grad_i) + term1[i] = (grad_i * velocities).sum() + + go = torch.jit.annotate( + Optional[List[Optional[torch.Tensor]]], [torch.ones_like(total_e)] + ) + grads = torch.autograd.grad( + [total_e], + [unfolded_system.positions], + grad_outputs=go, + )[0] + grads = torch.jit._unwrap_optional(grads) + term2 = ( + unfolded_system.positions * (grads * velocities).sum(dim=1, keepdim=True) + ).sum(dim=0) + + hf_pot = term1 - term2 + + hf_conv = ( + ( + atomic_e[:n_atoms] + + 0.5 + * masses[:n_atoms] + * torch.linalg.norm(velocities[:n_atoms], dim=1) ** 2 + * 103.6427 # u*A^2/fs^2 to eV + )[:, None] + * velocities[:n_atoms] + ).sum(dim=0) + + return hf_pot + hf_conv + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + ) -> Dict[str, TensorMap]: + outputs_wo_heat_flux = outputs.copy() + if self._hf_variant in outputs: + del outputs_wo_heat_flux[self._hf_variant] + + if len(outputs_wo_heat_flux) == 0: + results = torch.jit.annotate(Dict[str, TensorMap], {}) + else: + results = self._model(systems, outputs_wo_heat_flux, selected_atoms) + + if self._hf_variant not in outputs: + return results + + device = systems[0].device + heat_fluxes: List[torch.Tensor] = [] + for system in systems: + heat_fluxes.append(self._calc_unfolded_heat_flux(system)) + + samples = Labels( + ["system"], torch.arange(len(systems), device=device).reshape(-1, 1) + ) + + hf_block = TensorBlock( + values=torch.vstack(heat_fluxes).reshape(-1, 3, 1).to(device=device), + samples=samples, + components=[Labels(["xyz"], torch.arange(3, device=device).reshape(-1, 1))], + properties=Labels(["heat_flux"], torch.tensor([[0]], device=device)), + ) + results[self._hf_variant] = TensorMap( + Labels("_", torch.tensor([[0]], device=device)), [hf_block] + ) + return results diff --git a/python/metatomic_torch/metatomic/torch/model.py b/python/metatomic_torch/metatomic/torch/model.py index c2a9f528..d980d400 100644 --- a/python/metatomic_torch/metatomic/torch/model.py +++ b/python/metatomic_torch/metatomic/torch/model.py @@ -666,7 +666,21 @@ def _get_requested_inputs( already_requested = False for existing in requested: if existing == new_options: + if ( + requested[existing].quantity + == requested_inputs[new_options].quantity + and requested[existing].unit + != requested_inputs[new_options].unit + and requested[existing].per_atom + == requested_inputs[new_options].per_atom + ): + raise NotImplementedError( + f"Different units for the same quantity " + f"`{requested_inputs[new_options].quantity}` is not " + "supported." + ) already_requested = True + print(f"{new_options = }") if not already_requested: requested[new_options] = requested_inputs[new_options] diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index b8e8da2a..6f304ed7 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -892,7 +892,7 @@ def forward( class AdditionalInputModel(torch.nn.Module): - def __init__(self, inputs): + def __init__(self, inputs: Dict[str, ModelOutput]): super().__init__() self._requested_inputs = inputs @@ -911,6 +911,32 @@ def forward( } +class SimpleWrapperModel(torch.nn.Module): + def __init__(self, model: AtomisticModel, inputs: Dict[str, ModelOutput]): + super().__init__() + self._model = model.module + self._requested_inputs = inputs + self._capabilities = model.capabilities() + + def requested_inputs(self) -> Dict[str, ModelOutput]: + return self._requested_inputs + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + results = self._model(systems, outputs, selected_atoms) + results.update( + { + ("extra::" + input): systems[0].get_data(input) + for input in self._requested_inputs + } + ) + return results + + def test_additional_input(atoms): inputs = { "masses": ModelOutput(quantity="mass", unit="u", per_atom=True), @@ -951,6 +977,37 @@ def test_additional_input(atoms): assert np.allclose(values, expected) +def test_wrapper_asks_for_inputs_with_different_units(atoms): + inputs = { + "masses": ModelOutput(quantity="mass", unit="u", per_atom=True), + "velocities": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True), + "charges": ModelOutput(quantity="charge", unit="e", per_atom=True), + "ase::initial_charges": ModelOutput(quantity="charge", unit="e", per_atom=True), + } + outputs = {("extra::" + n): inputs[n] for n in inputs} + capabilities = ModelCapabilities( + outputs=outputs, + atomic_types=[28], + interaction_range=0.0, + supported_devices=["cpu"], + dtype="float64", + ) + + model = AtomisticModel( + AdditionalInputModel(inputs).eval(), ModelMetadata(), capabilities + ) + + inputs_wrapper = { + "masses": ModelOutput(quantity="mass", unit="kg", per_atom=True), + } + wrapper = SimpleWrapperModel(model, inputs_wrapper) + with pytest.raises( + NotImplementedError, + match="Different units for the same quantity `mass` is not supported.", + ): + AtomisticModel(wrapper.eval(), ModelMetadata(), capabilities) + + @pytest.mark.parametrize("device,dtype", ALL_DEVICE_DTYPE) def test_mixed_pbc(model, device, dtype): """Test that the calculator works on a mixed-PBC system""" diff --git a/python/metatomic_torch/tests/test_heat_flux.py b/python/metatomic_torch/tests/test_heat_flux.py new file mode 100644 index 00000000..5fc9ee70 --- /dev/null +++ b/python/metatomic_torch/tests/test_heat_flux.py @@ -0,0 +1,436 @@ +import metatomic_lj_test +import numpy as np +import pytest +import torch +from ase import Atoms +from ase.md.velocitydistribution import MaxwellBoltzmannDistribution +from metatensor.torch import Labels, TensorBlock, TensorMap + +from metatomic.torch import ( + AtomisticModel, + ModelCapabilities, + ModelMetadata, + ModelOutput, + System, +) +from metatomic.torch.ase_calculator import MetatomicCalculator +from metatomic.torch.heat_flux import ( + HeatFluxWrapper, + _check_close_to_cell_boundary, + _collisions_to_replicas, + _generate_replica_atoms, + _unfold_system, + _wrap_positions, +) + + +@pytest.fixture +def model(): + return metatomic_lj_test.lennard_jones_model( + atomic_type=18, + cutoff=7.0, + sigma=3.405, + epsilon=0.01032, + length_unit="Angstrom", + energy_unit="eV", + with_extension=False, + ) + + +@pytest.fixture +def model_in_kcal_per_mol(): + return metatomic_lj_test.lennard_jones_model( + atomic_type=18, + cutoff=7.0, + sigma=3.405, + epsilon=0.01032, + length_unit="Angstrom", + energy_unit="kcal/mol", + with_extension=False, + ) + + +@pytest.fixture +def atoms(): + cell = np.array([[6.0, 0.0, 0.0], [0.0, 6.0, 0.0], [0.0, 0.0, 6.0]]) + positions = np.array([[3.0, 3.0, 3.0]]) + atoms = Atoms("Ar", scaled_positions=positions, cell=cell, pbc=True).repeat( + (2, 2, 2) + ) + MaxwellBoltzmannDistribution( + atoms, temperature_K=300, rng=np.random.default_rng(42) + ) + return atoms + + +def _make_scalar_tensormap(values: torch.Tensor, property_name: str) -> TensorMap: + block = TensorBlock( + values=values, + samples=Labels( + ["atoms"], + torch.arange(values.shape[0], device=values.device).reshape(-1, 1), + ), + components=[], + properties=Labels([property_name], torch.tensor([[0]], device=values.device)), + ) + return TensorMap(Labels("_", torch.tensor([[0]], device=values.device)), [block]) + + +def _make_velocity_tensormap(values: torch.Tensor) -> TensorMap: + block = TensorBlock( + values=values, + samples=Labels( + ["atoms"], + torch.arange(values.shape[0], device=values.device).reshape(-1, 1), + ), + components=[ + Labels( + ["xyz"], + torch.arange(3, device=values.device).reshape(-1, 1), + ) + ], + properties=Labels(["velocity"], torch.tensor([[0]], device=values.device)), + ) + return TensorMap(Labels("_", torch.tensor([[0]], device=values.device)), [block]) + + +def _make_system_with_data(positions: torch.Tensor, cell: torch.Tensor) -> System: + types = torch.tensor([1] * len(positions), dtype=torch.int32) + system = System( + types=types, + positions=positions, + cell=cell, + pbc=torch.tensor([True, True, True]), + ) + masses = torch.ones((len(positions), 1), dtype=positions.dtype) + velocities = torch.zeros((len(positions), 3, 1), dtype=positions.dtype) + system.add_data("masses", _make_scalar_tensormap(masses, "mass")) + system.add_data("velocities", _make_velocity_tensormap(velocities)) + return system + + +class _DummyCapabilities: + """Reusable stub for ``model.capabilities()``.""" + + def __init__(self, energy_unit: str = "eV"): + self.outputs = {"energy": ModelOutput(quantity="energy", unit=energy_unit)} + self.length_unit = "A" + self.interaction_range = 1.0 + + +class _ZeroDummyModel: + """Dummy model returning zero energies. Accepts an optional *energy_unit*.""" + + def __init__(self, energy_unit: str = "eV"): + self._capabilities = _DummyCapabilities(energy_unit) + self.module = None + + def capabilities(self): + return self._capabilities + + def __call__(self, systems, options, check_consistency): + values = torch.zeros((len(systems), 1), dtype=systems[0].positions.dtype) + block = TensorBlock( + values=values, + samples=Labels( + ["system"], + torch.arange(len(systems), device=values.device).reshape(-1, 1), + ), + components=[], + properties=Labels(["energy"], torch.tensor([[0]], device=values.device)), + ) + return { + "energy": TensorMap( + Labels("_", torch.tensor([[0]], device=values.device)), [block] + ) + } + + +def test_wrap_positions_cubic_matches_expected(): + cell = torch.eye(3) * 2.0 + positions = torch.tensor([[-0.1, 0.0, 0.0], [2.1, 1.0, -0.5]]) + wrapped = _wrap_positions(positions, cell) + expected = torch.tensor([[1.9, 0.0, 0.0], [0.1, 1.0, 1.5]]) + assert torch.allclose(wrapped, expected) + + +def test_check_close_to_cell_boundary_cubic_axis_order(): + cell = torch.eye(3) * 2.0 + positions = torch.tensor([[0.1, 1.0, 1.9]]) + collisions = _check_close_to_cell_boundary(cell, positions, cutoff=0.2) + assert collisions.shape == (1, 6) + assert collisions[0].tolist() == [True, False, False, False, False, True] + + +def test_generate_replica_atoms_cubic_offsets(): + types = torch.tensor([1]) + positions = torch.tensor([[0.1, 1.0, 1.0]]) + cell = torch.eye(3) * 2.0 + collisions = torch.tensor([[True, False, False, False, False, False]]) + replicas = _collisions_to_replicas(collisions) + replica_idx, replica_types, replica_positions = _generate_replica_atoms( + types, positions, cell, replicas + ) + assert replica_idx.tolist() == [0] + assert replica_types.tolist() == [1] + assert torch.allclose( + replica_positions, positions + torch.tensor([[2.0, 0.0, 0.0]]) + ) + + +def test_wrap_positions_triclinic_fractional_bounds_and_shift(): + cell = torch.tensor( + [ + [2.0, 0.3, 0.2], + [0.1, 1.7, 0.4], + [0.2, 0.5, 1.9], + ] + ) + positions = torch.tensor( + [ + [-0.1, 0.0, 0.0], + [2.1, 1.6, -0.5], + [4.2, -0.2, 6.1], + ] + ) + inv_cell = cell.inverse() + wrapped = _wrap_positions(positions, cell) + fractional_before = torch.einsum("iv,vk->ik", positions, inv_cell) + fractional_after = torch.einsum("iv,vk->ik", wrapped, inv_cell) + + assert torch.all(fractional_after >= 0) + assert torch.all(fractional_after < 1) + + delta_frac = fractional_after - fractional_before + rounded = torch.round(delta_frac) + assert torch.allclose(delta_frac, rounded, atol=1e-6, rtol=0) + assert torch.allclose(rounded, -torch.floor(fractional_before), atol=1e-6, rtol=0) + + +def test_check_close_to_cell_boundary_triclinic_targets(): + cell = torch.tensor( + [ + [2.0, 0.3, 0.2], + [0.1, 1.7, 0.4], + [0.2, 0.5, 1.9], + ] + ) + cutoff = 0.2 + inv_cell = cell.inverse() + recip = inv_cell.T + norms = torch.linalg.norm(recip, dim=1) + heights = 1.0 / norms + norm_vectors = recip / norms[:, None] + + target = torch.stack( + [ + torch.tensor([0.05, 0.6, 0.6]), + torch.tensor([heights[0] - 0.05, 0.05, heights[2] - 0.1]), + torch.tensor([0.3, heights[1] - 0.05, 0.1]), + ] + ) + positions = target @ torch.inverse(norm_vectors).T + + collisions = _check_close_to_cell_boundary(cell, positions, cutoff=cutoff) + + expected_low = target <= cutoff + expected_high = target >= heights - cutoff + expected = torch.hstack([expected_low, expected_high]) + expected = expected[:, [0, 3, 1, 4, 2, 5]] + + assert torch.equal(collisions, expected) + + +def test_check_close_to_cell_boundary_raises_on_small_cell(): + cell = torch.eye(3) * 1.0 + positions = torch.zeros((1, 3)) + with pytest.raises(ValueError, match="Cell is too small"): + _check_close_to_cell_boundary(cell, positions, cutoff=1.1) + + +def test_collisions_to_replicas_combines_displacements(): + collisions = torch.tensor([[True, False, False, True, False, False]]) + replicas = _collisions_to_replicas(collisions) + assert replicas.shape == (1, 3, 3, 3) + assert replicas[0, 0, 0, 0].item() is False + + nonzero = torch.nonzero(replicas, as_tuple=False) + expected = { + (0, 1, 0, 0), + (0, 0, 2, 0), + (0, 1, 2, 0), + } + assert {tuple(row.tolist()) for row in nonzero} == expected + + +def test_generate_replica_atoms_triclinic_offsets(): + cell = torch.tensor( + [ + [2.0, 0.3, 0.2], + [0.1, 1.7, 0.4], + [0.2, 0.5, 1.9], + ] + ) + types = torch.tensor([1]) + positions = torch.tensor([[0.2, 0.4, 0.6]]) + collisions = torch.tensor([[True, False, True, False, True, False]]) + replicas = _collisions_to_replicas(collisions) + replica_idx, replica_types, replica_positions = _generate_replica_atoms( + types, positions, cell, replicas + ) + + assert replica_idx.tolist() == [0, 0, 0, 0, 0, 0, 0] + assert replica_types.tolist() == [1, 1, 1, 1, 1, 1, 1] + + expected_offsets = [ + cell[0], + cell[1], + cell[2], + cell[0] + cell[1], + cell[0] + cell[2], + cell[1] + cell[2], + cell[0] + cell[1] + cell[2], + ] + expected_positions = [positions[0] + offset for offset in expected_offsets] + + for expected in expected_positions: + assert any( + torch.allclose(expected, actual, atol=1e-6, rtol=0) + for actual in replica_positions + ) + + +def test_unfold_system_adds_replica_and_data(): + cell = torch.eye(3) * 2.0 + positions = torch.tensor([[0.1, 1.0, 1.0]]) + system = _make_system_with_data(positions, cell) + unfolded = _unfold_system(system, cutoff=0.1) + assert len(unfolded.positions) == 2 + assert torch.all(unfolded.pbc == torch.tensor([False, False, False])) + assert torch.allclose(unfolded.cell, torch.zeros_like(unfolded.cell)) + + masses = unfolded.get_data("masses").block().values + velocities = unfolded.get_data("velocities").block().values + assert masses.shape[0] == 2 + assert velocities.shape[0] == 2 + + assert torch.allclose(unfolded.positions[0], positions[0]) + assert torch.allclose( + unfolded.positions[1], positions[0] + torch.tensor([2.0, 0.0, 0.0]) + ) + + +def test_unfold_system_no_replicas_for_interior_atoms(): + """Atoms well inside the cell should produce no replicas.""" + cell = torch.eye(3) * 10.0 + positions = torch.tensor([[5.0, 5.0, 5.0], [3.0, 4.0, 6.0]]) + system = _make_system_with_data(positions, cell) + unfolded = _unfold_system(system, cutoff=1.0) + + assert len(unfolded.positions) == 2 + assert torch.allclose(unfolded.positions, _wrap_positions(positions, cell)) + + +def test_unfold_system_triclinic_cell(): + """Unfolding should work for triclinic cells and propagate all data.""" + cell = torch.tensor( + [ + [4.0, 0.6, 0.4], + [0.2, 3.4, 0.8], + [0.4, 1.0, 3.8], + ] + ) + # One atom near the origin (close to low boundaries), one in the interior + positions = torch.tensor( + [ + [0.05, 0.05, 0.05], + [2.0, 1.7, 1.9], + ] + ) + system = _make_system_with_data(positions, cell) + unfolded = _unfold_system(system, cutoff=0.3) + + # The near-origin atom should generate at least one replica + assert len(unfolded.positions) > 2 + assert torch.all(unfolded.pbc == torch.tensor([False, False, False])) + assert torch.allclose(unfolded.cell, torch.zeros_like(unfolded.cell)) + assert torch.all(unfolded.types == 1) + assert unfolded.get_data("masses").block().values.shape[0] == len( + unfolded.positions + ) + assert unfolded.get_data("velocities").block().values.shape[0] == len( + unfolded.positions + ) + + +def test_heat_flux_wrapper_rejects_non_eV_energy(model_in_kcal_per_mol): + with pytest.raises(ValueError, match="energy outputs in eV"): + HeatFluxWrapper(model_in_kcal_per_mol) + + +def test_heat_flux_wrapper_requested_inputs(model): + wrapper = HeatFluxWrapper(model) + requested = wrapper.requested_inputs() + assert set(requested.keys()) == {"masses", "velocities"} + + +@pytest.mark.parametrize("use_script", [True, False]) +@pytest.mark.parametrize( + "use_variant, expected", + [ + (True, [[9.0147e-05], [-2.6166e-04], [-1.9002e-04]]), + (False, [[8.8238e-05], [-2.5559e-04], [-2.0570e-04]]), + ], +) +def test_heat_flux_wrapper_calc_heat_flux( + model, atoms, expected, use_script, use_variant +): + metadata = ModelMetadata() + wrapper = HeatFluxWrapper( + model.eval(), variants=({"energy": "doubled"} if use_variant else None) + ) + cap = model.capabilities() + outputs = cap.outputs.copy() + outputs[wrapper._hf_variant] = ModelOutput( + quantity="heat_flux", + unit="", + explicit_gradients=[], + per_atom=False, + ) + + new_cap = ModelCapabilities( + outputs=outputs, + atomic_types=cap.atomic_types, + interaction_range=cap.interaction_range, + length_unit=cap.length_unit, + supported_devices=cap.supported_devices, + dtype=cap.dtype, + ) + + if use_script: + wrapper = torch.jit.script(wrapper) + + heat_model = AtomisticModel(wrapper.eval(), metadata, capabilities=new_cap).to( + device="cpu" + ) + calc = MetatomicCalculator( + heat_model, + device="cpu", + additional_outputs={ + wrapper._hf_variant: ModelOutput( + quantity="heat_flux", + unit="", + explicit_gradients=[], + per_atom=False, + ) + }, + check_consistency=True, + ) + atoms.calc = calc + atoms.get_potential_energy() + assert wrapper._hf_variant in atoms.calc.additional_outputs + results = atoms.calc.additional_outputs[wrapper._hf_variant].block().values + assert torch.allclose( + results, + torch.tensor(expected, dtype=results.dtype), + ) diff --git a/tox.ini b/tox.ini index cd0a281d..d6ef0d9b 100644 --- a/tox.ini +++ b/tox.ini @@ -141,6 +141,7 @@ deps = torch=={env:METATOMIC_TESTS_TORCH_VERSION:2.10}.* numpy {env:METATOMIC_TESTS_NUMPY_VERSION_PIN} vesin + vesin-torch ase # for metatensor-lj-test setuptools-scm