From bb9b69df2061c81762816aa84bdab73f0f795b67 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Mon, 9 Mar 2026 12:47:52 +0100 Subject: [PATCH 1/7] feat(units): add expression parser with SI dimensional analysis Replace the old string-matching unit conversion with a Shunting-Yard expression parser that supports *, /, ^, and parentheses. Unit code extracted to units.hpp/cpp per review feedback. Micro sign handling uses explicit base_units entries instead of global normalization. --- docs/src/torch/reference/cxx/misc.rst | 4 +- docs/src/torch/reference/misc.rst | 110 +++- metatomic-torch/CHANGELOG.md | 14 + metatomic-torch/CMakeLists.txt | 2 + .../include/metatomic/torch/model.hpp | 19 +- .../include/metatomic/torch/units.hpp | 61 ++ metatomic-torch/src/model.cpp | 237 +------- metatomic-torch/src/register.cpp | 5 +- metatomic-torch/src/system.cpp | 2 +- metatomic-torch/src/units.cpp | 558 ++++++++++++++++++ metatomic-torch/tests/models.cpp | 16 +- metatomic-torch/tests/units.cpp | 274 +++++++++ .../metatomic/torch/__init__.py | 37 +- .../metatomic/torch/documentation.py | 29 +- .../metatomic_torch/metatomic/torch/model.py | 22 +- python/metatomic_torch/tests/units.py | 192 +++++- 16 files changed, 1247 insertions(+), 335 deletions(-) create mode 100644 metatomic-torch/include/metatomic/torch/units.hpp create mode 100644 metatomic-torch/src/units.cpp create mode 100644 metatomic-torch/tests/units.cpp diff --git a/docs/src/torch/reference/cxx/misc.rst b/docs/src/torch/reference/cxx/misc.rst index e6c1aeda..6dfc15b5 100644 --- a/docs/src/torch/reference/cxx/misc.rst +++ b/docs/src/torch/reference/cxx/misc.rst @@ -1,7 +1,9 @@ Miscellaneous ============= -.. doxygenfunction:: metatomic_torch::unit_conversion_factor +.. doxygenfunction:: metatomic_torch::unit_conversion_factor(const std::string &from_unit, const std::string &to_unit) + +.. 
doxygenfunction:: metatomic_torch::unit_conversion_factor(const std::string &quantity, const std::string &from_unit, const std::string &to_unit) .. doxygenfunction:: metatomic_torch::pick_device diff --git a/docs/src/torch/reference/misc.rst b/docs/src/torch/reference/misc.rst index e2cf5df9..bef45bd9 100644 --- a/docs/src/torch/reference/misc.rst +++ b/docs/src/torch/reference/misc.rst @@ -9,34 +9,82 @@ Miscellaneous .. _known-quantities-units: -Known quantities and units --------------------------- - -The following quantities and units can be used with metatomic models. Adding new -units and quantities is very easy, please contact us if you need something else! -In the mean time, you can create :py:class:`metatomic.torch.ModelOutput` with -quantities that are not in this table. A warning will be issued and no unit -conversion will be performed. - -When working with one of the quantities in this table, the unit you use must be -one of the registered unit. - -+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------+ -| quantity | units | -+================+======================================================================================================================================================+ -| **length** | ``angstrom`` (``A``), ``Bohr``, ``meter``, ``centimeter`` (``cm``), ``millimeter`` (``mm``), ``micrometer`` (``um``, ``µm``), ``nanometer`` (``nm``) | -+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------+ -| **energy** | ``eV``, ``meV``, ``Hartree``, ``kcal/mol``, ``kJ/mol``, ``Joule`` (``J``), ``Rydberg`` (``Ry``) | -+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------+ -| **force** | ``eV/Angstrom`` 
(``eV/A``), ``Hartree/Bohr`` | -+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------+ -| **pressure** | ``eV/Angstrom^3`` (``eV/A^3``) | -+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------+ -| **momentum** | ``u*A/fs``, ``u*A/ps``, ``(eV*u)^(1/2)``, ``kg*m/s``, ``hbar/Bohr`` | -+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------+ -| **mass** | ``u`` (``Dalton``), ``kg`` (``kilogram``), ``g`` (``gram``), ``electron_mass`` (``m_e``) | -+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------+ -| **velocity** | ``nm/fs``, ``A/fs``, ``m/s``, ``nm/ps``, ``Bohr*Hartree/hbar`` | -+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------+ -| **charge** | ``e``, ``Coulomb`` (``C``) | -+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------+ +Unit expression parser +---------------------- + +``unit_conversion_factor`` accepts arbitrary unit expressions built from base +tokens combined with ``*``, ``/``, ``^``, and parentheses. For example: + +- ``"kJ/mol"`` +- ``"eV/Angstrom^3"`` +- ``"(eV*u)^(1/2)"`` +- ``"Hartree/Bohr"`` + +Dimensional compatibility is verified automatically; no ``quantity`` parameter +is needed. Token lookup is case-insensitive, and whitespace is ignored. + +Base unit tokens +^^^^^^^^^^^^^^^^ + +.. 
list-table:: Supported Unit Tokens + :header-rows: 1 + + * - Dimension + - Tokens + - Notes + * - **Length** + - ``angstrom`` (``a``), ``bohr``, ``nm`` (``nanometer``), ``meter`` (``m``), ``cm`` (``centimeter``), ``mm`` (``millimeter``), ``um`` (``micrometer``) + - + * - **Energy** + - ``ev``, ``mev``, ``hartree``, ``ry`` (``rydberg``), ``joule`` (``j``), ``kcal``, ``kj`` + - ``kcal`` and ``kj`` are bare (not per-mol); write ``kcal/mol`` for the per-mole unit + * - **Time** + - ``s`` (``second``), ``ms`` (``millisecond``), ``us`` (``microsecond``), ``ns`` (``nanosecond``), ``ps`` (``picosecond``), ``fs`` (``femtosecond``) + - + * - **Mass** + - ``u`` (``dalton``), ``kg`` (``kilogram``), ``g`` (``gram``), ``electron_mass`` (``m_e``) + - + * - **Charge** + - ``e``, ``coulomb`` (``c``) + - + * - **Dimensionless** + - ``mol`` + - Avogadro scaling factor + * - **Derived** + - ``hbar`` + - :math:`\hbar` in SI (:math:`M L^2 T^{-1}`) + +Known quantities +^^^^^^^^^^^^^^^^ + +When setting ``quantity`` on a :py:class:`~metatomic.torch.ModelOutput`, the +following names are recognized. The parser will check that the unit expression +has dimensions matching the expected quantity. + +.. list-table:: Physical Dimensions + :header-rows: 1 + + * - quantity + - expected dimension + * - **length** + - :math:`L` + * - **energy** + - :math:`M L^2 T^{-2}` + * - **force** + - :math:`M L T^{-2}` + * - **pressure** + - :math:`M L^{-1} T^{-2}` + * - **momentum** + - :math:`M L T^{-1}` + * - **mass** + - :math:`M` + * - **velocity** + - :math:`L T^{-1}` + * - **charge** + - :math:`Q` + +.. note:: + + The 3-argument form ``unit_conversion_factor(quantity, from_unit, to_unit)`` + is deprecated. Use the 2-argument form instead. The ``quantity`` parameter is + ignored by the parser; dimensional compatibility is checked automatically. 
diff --git a/metatomic-torch/CHANGELOG.md b/metatomic-torch/CHANGELOG.md index 00947737..e786886c 100644 --- a/metatomic-torch/CHANGELOG.md +++ b/metatomic-torch/CHANGELOG.md @@ -17,6 +17,20 @@ a changelog](https://keepachangelog.com/en/1.1.0/) format. This project follows ## [Unreleased](https://github.com/metatensor/metatomic/) +### Added + +- Unit expression parser supporting compound expressions (`kJ/mol/A^2`, + `(eV*u)^(1/2)`, etc.) with automatic dimensional validation +- 2-argument `unit_conversion_factor(from_unit, to_unit)` that parses + arbitrary unit expressions and checks dimensional compatibility + +### Changed + +- `validate_unit` now accepts arbitrary parseable expressions and validates + their dimensions against the expected quantity +- 3-argument `unit_conversion_factor(quantity, from_unit, to_unit)` is + deprecated; the `quantity` parameter is ignored + ## [Version 0.1.11](https://github.com/metatensor/metatomic/releases/tag/metatomic-torch-v0.1.11) - 2026-02-27 ### Added diff --git a/metatomic-torch/CMakeLists.txt b/metatomic-torch/CMakeLists.txt index 8f253344..c661cea5 100644 --- a/metatomic-torch/CMakeLists.txt +++ b/metatomic-torch/CMakeLists.txt @@ -96,6 +96,7 @@ find_package(Torch 2.3 REQUIRED) set(METATOMIC_TORCH_HEADERS "include/metatomic/torch/system.hpp" "include/metatomic/torch/model.hpp" + "include/metatomic/torch/units.hpp" "include/metatomic/torch.hpp" "include/metatomic/torch/outputs.hpp" ) @@ -104,6 +105,7 @@ set(METATOMIC_TORCH_SOURCE "src/misc.cpp" "src/system.cpp" "src/model.cpp" + "src/units.cpp" "src/outputs.cpp" "src/register.cpp" "src/internal/shared_libraries.cpp" diff --git a/metatomic-torch/include/metatomic/torch/model.hpp b/metatomic-torch/include/metatomic/torch/model.hpp index f2f9bdeb..d8cc5db4 100644 --- a/metatomic-torch/include/metatomic/torch/model.hpp +++ b/metatomic-torch/include/metatomic/torch/model.hpp @@ -9,6 +9,7 @@ #include #include "metatomic/torch/exports.h" +#include "metatomic/torch/units.hpp" 
namespace metatomic_torch { @@ -28,16 +29,6 @@ class ModelMetadataHolder; /// TorchScript will always manipulate `ModelMetadataHolder` through a `torch::intrusive_ptr` using ModelMetadata = torch::intrusive_ptr; -/// Check that a given physical quantity is valid and known. This is -/// intentionally not exported with `METATOMIC_TORCH_EXPORT`, and is only -/// intended for internal use. -bool valid_quantity(const std::string& quantity); - -/// Check that a given unit is valid and known for some physical quantity. This -/// is intentionally not exported with `METATOMIC_TORCH_EXPORT`, and is only -/// intended for internal use. -void validate_unit(const std::string& quantity, const std::string& unit); - /// Information about one of the quantity a model can compute class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder { @@ -326,14 +317,6 @@ METATOMIC_TORCH_EXPORT metatensor_torch::Module load_atomistic_model( c10::optional extensions_directory = c10::nullopt ); -/// Get the multiplicative conversion factor to use to convert from unit `from` -/// to unit `to`. Both should be units for the given physical `quantity`. -METATOMIC_TORCH_EXPORT double unit_conversion_factor( - const std::string& quantity, - const std::string& from_unit, - const std::string& to_unit -); - } #endif diff --git a/metatomic-torch/include/metatomic/torch/units.hpp b/metatomic-torch/include/metatomic/torch/units.hpp new file mode 100644 index 00000000..06ef5cdc --- /dev/null +++ b/metatomic-torch/include/metatomic/torch/units.hpp @@ -0,0 +1,61 @@ +#ifndef METATOMIC_TORCH_UNITS_HPP +#define METATOMIC_TORCH_UNITS_HPP + +#include + +#include "metatomic/torch/exports.h" + +namespace metatomic_torch { + +/// Check that a given physical quantity is valid and known. This is +/// intentionally not exported with `METATOMIC_TORCH_EXPORT`, and is only +/// intended for internal use. 
+bool valid_quantity(const std::string& quantity); + +/// Check that a given unit is valid and known for some physical quantity. This +/// is intentionally not exported with `METATOMIC_TORCH_EXPORT`, and is only +/// intended for internal use. +void validate_unit(const std::string& quantity, const std::string& unit); + +/// Get the multiplicative conversion factor to use to convert from +/// `from_unit` to `to_unit`. Both units are parsed as expressions (e.g. +/// "kJ/mol/A^2", "(eV*u)^(1/2)") and their dimensions must match. +/// +/// Unit expressions are built from base tokens combined with `*`, `/`, `^`, +/// and parentheses. Token lookup is case-insensitive, and whitespace is +/// ignored. For example: +/// +/// - `"kJ/mol"` -- energy per mole +/// - `"eV/Angstrom^3"` -- pressure +/// - `"(eV*u)^(1/2)"` -- momentum (fractional powers) +/// - `"Hartree/Bohr"` -- force in atomic units +/// +/// Supported tokens: +/// +/// - **Length**: angstrom (A), bohr, nanometer (nm), meter (m), +/// centimeter (cm), millimeter (mm), micrometer (um, µm) +/// - **Energy**: eV, meV, Hartree, Ry (rydberg), Joule (J), kcal, kJ +/// (note: kcal and kJ are bare; write kcal/mol for per-mole) +/// - **Time**: second (s), millisecond (ms), microsecond (us, µs), +/// nanosecond (ns), picosecond (ps), femtosecond (fs) +/// - **Mass**: dalton (u), kilogram (kg), gram (g), electron_mass (m_e) +/// - **Charge**: e, coulomb (c) +/// - **Dimensionless**: mol +/// - **Derived**: hbar +METATOMIC_TORCH_EXPORT double unit_conversion_factor( + const std::string& from_unit, + const std::string& to_unit +); + +/// Deprecated 3-argument overload. The `quantity` parameter is ignored; +/// dimensional compatibility is checked by the parser. Emits a one-time +/// deprecation warning. 
+METATOMIC_TORCH_EXPORT double unit_conversion_factor( + const std::string& quantity, + const std::string& from_unit, + const std::string& to_unit +); + +} + +#endif diff --git a/metatomic-torch/src/model.cpp b/metatomic-torch/src/model.cpp index a7b6bee2..eff9312a 100644 --- a/metatomic-torch/src/model.cpp +++ b/metatomic-torch/src/model.cpp @@ -1,10 +1,9 @@ #include -#include -#include #include -#include #include +#include +#include #include #include @@ -200,7 +199,7 @@ void ModelCapabilitiesHolder::set_dtype(std::string dtype) { } double ModelCapabilitiesHolder::engine_interaction_range(const std::string& engine_length_unit) const { - return interaction_range * unit_conversion_factor("length", length_unit_, engine_length_unit); + return interaction_range * unit_conversion_factor(length_unit_, engine_length_unit); } std::string ModelCapabilitiesHolder::to_json() const { @@ -975,233 +974,3 @@ metatensor_torch::Module metatomic_torch::load_atomistic_model( return metatensor_torch::Module(model); } - -/******************************************************************************/ -/******************************************************************************/ - -/// remove all whitespace in a string (i.e. `kcal / mol` => `kcal/mol`) -static std::string remove_spaces(std::string value) { - auto new_end = std::remove_if(value.begin(), value.end(), - [](unsigned char c){ return std::isspace(c); } - ); - value.erase(new_end, value.end()); - return value; -} - - -/// Lower case string, to be used as a key in Quantity.conversion (we want -/// "Angstrom" and "angstrom" to be equivalent). 
-class LowercaseString { -public: - LowercaseString(std::string init): original_(std::move(init)) { - std::transform(original_.begin(), original_.end(), std::back_inserter(lowercase_), &::tolower); - } - - LowercaseString(const char* init): LowercaseString(std::string(init)) {} - - operator std::string&() { - return lowercase_; - } - operator std::string const&() const { - return lowercase_; - } - - const std::string& original() const { - return original_; - } - - bool operator==(const LowercaseString& other) const { - return this->lowercase_ == other.lowercase_; - } - -private: - std::string original_; - std::string lowercase_; -}; - -template <> -struct std::hash { - size_t operator()(const LowercaseString& k) const { - return std::hash()(k); - } -}; - -/// Information for unit conversion for this physical quantity -struct Quantity { - /// the quantity name - std::string name; - - /// baseline unit for this quantity - std::string baseline; - /// set of conversion from the key to the baseline unit - std::unordered_map conversions; - std::unordered_map alternatives; - - std::string normalize_unit(const std::string& original_unit) { - if (original_unit.empty()) { - return original_unit; - } - - std::string unit = remove_spaces(original_unit); - auto alternative = this->alternatives.find(unit); - if (alternative != this->alternatives.end()) { - unit = alternative->second; - } - - if (this->conversions.find(unit) == this->conversions.end()) { - auto valid_units = std::vector(); - for (const auto& it: this->conversions) { - valid_units.emplace_back(it.first.original()); - } - - C10_THROW_ERROR(ValueError, - "unknown unit '" + original_unit + "' for " + name + ", " - "only [" + torch::str(valid_units) + "] are supported" - ); - } - - return unit; - } - - double conversion(const std::string& from_unit, const std::string& to_unit) { - auto from = this->normalize_unit(from_unit); - auto to = this->normalize_unit(to_unit); - - if (from.empty() || to.empty()) { - return 1.0; 
- } - - return this->conversions.at(to) / this->conversions.at(from); - } -}; - -static std::map KNOWN_QUANTITIES = { - {"length", Quantity{/* name */ "length", /* baseline */ "Angstrom", { - {"Angstrom", 1.0}, - {"Bohr", 1.8897261258369282}, - {"meter", 1e-10}, - {"centimeter", 1e-8}, - {"millimeter", 1e-7}, - {"micrometer", 0.0001}, - {"nanometer", 0.1}, - }, { - // alternative names - {"A", "Angstrom"}, - {"cm", "centimeter"}, - {"mm", "millimeter"}, - {"um", "micrometer"}, - {"µm", "micrometer"}, - {"nm", "nanometer"}, - }}}, - {"energy", Quantity{/* name */ "energy", /* baseline */ "eV", { - {"eV", 1.0}, - {"meV", 1000.0}, - {"Hartree", 0.03674932247495664}, - {"kcal/mol", 23.060548012069496}, - {"kJ/mol", 96.48533288249877}, - {"Joule", 1.60218e-19}, - {"Rydberg", 0.07349864435130857}, - }, { - // alternative names - {"J", "Joule"}, - {"Ry", "Rydberg"}, - }}}, - {"force", Quantity{/* name */ "force", /* baseline */ "eV/Angstrom", { - {"eV/Angstrom", 1.0}, - {"Hartree/Bohr", 0.019446904} - }, { - // alternative names - {"eV/A", "eV/Angstrom"}, - }}}, - {"pressure", Quantity{/* name */ "pressure", /* baseline */ "eV/Angstrom^3", { - {"eV/Angstrom^3", 1.0}, - }, { - // alternative names - {"eV/A^3", "eV/Angstrom^3"}, - }}}, - {"momentum", Quantity{/* name */ "momentum", /* baseline */ "u * A / fs", { - {"u*A/fs", 1.0}, - {"u*A/ps", 1000.0}, - {"(eV*u)^(1/2)", 10.1805057179}, - {"kg*m/s", 1.6605390666e-22}, - {"hbar/Bohr", 83.32476}, - }, { - // alternative names - }}}, - {"mass", Quantity{/* name */ "mass", /* baseline */ "u ", { - {"u", 1.0}, - {"kilogram", 1.66053906892e-27}, - {"gram", 1.66053906892e-24}, - {"electron_mass", 1822.8885}, - }, { - // alternative names - {"Dalton", "u"}, - {"kg", "kilogram"}, - {"g", "gram"}, - {"electron_mass", "m_e"}, - }}}, - {"velocity", Quantity{/* name */ "velocity", /* baseline */ "nm/fs", { - {"nm/fs", 1.0}, - {"A/fs", 1e1}, - {"m/s", 1e6}, - {"nm/ps", 1e3}, - {"(eV/u)^(1/2)", 101.80506}, - {"Bohr*Hartree/hbar", 
0.45710289}, - }, { - // alternative names - }}}, - {"charge", Quantity{/* name */ "charge", /* baseline */ "e", { - {"e", 1.0}, - {"Coulomb", 1.602176634e-19}, - }, { - // alternative names - {"C", "Coulomb"}, - }}}, -}; - -bool metatomic_torch::valid_quantity(const std::string& quantity) { - if (quantity.empty()) { - return false; - } - - if (KNOWN_QUANTITIES.find(quantity) == KNOWN_QUANTITIES.end()) { - auto valid_quantities = std::vector(); - for (const auto& it: KNOWN_QUANTITIES) { - valid_quantities.emplace_back(it.first); - } - - static std::unordered_set ALREADY_WARNED = {}; - if (ALREADY_WARNED.insert(quantity).second) { - TORCH_WARN( - "unknown quantity '", quantity, "', only [", - torch::str(valid_quantities), "] are supported" - ); - } - return false; - } else { - return true; - } -} - - -void metatomic_torch::validate_unit(const std::string& quantity, const std::string& unit) { - if (quantity.empty() || unit.empty()) { - return; - } - - if (valid_quantity(quantity)) { - KNOWN_QUANTITIES.at(quantity).normalize_unit(unit); - } -} - -double metatomic_torch::unit_conversion_factor( - const std::string& quantity, - const std::string& from_unit, - const std::string& to_unit -) { - if (valid_quantity(quantity)) { - return KNOWN_QUANTITIES.at(quantity).conversion(from_unit, to_unit); - } else { - return 1.0; - } -} diff --git a/metatomic-torch/src/register.cpp b/metatomic-torch/src/register.cpp index f965b896..38f830d6 100644 --- a/metatomic-torch/src/register.cpp +++ b/metatomic-torch/src/register.cpp @@ -256,7 +256,10 @@ TORCH_LIBRARY(metatomic, m) { m.def("pick_output(str requested_output, Dict(str, __torch__.torch.classes.metatomic.ModelOutput) outputs, str? 
desired_variant = None) -> str", pick_output); m.def("read_model_metadata(str path) -> __torch__.torch.classes.metatomic.ModelMetadata", read_model_metadata); - m.def("unit_conversion_factor(str quantity, str from_unit, str to_unit) -> float", unit_conversion_factor); + m.def("unit_conversion_factor(str quantity, str from_unit, str to_unit) -> float", + static_cast(&unit_conversion_factor)); + m.def("unit_conversion_factor_v2(str from_unit, str to_unit) -> float", + static_cast(&unit_conversion_factor)); // manually construct the schema for "check_atomistic_model(str path) -> ()", // so we can set AliasAnalysisKind to CONSERVATIVE. In turn, this make it so diff --git a/metatomic-torch/src/system.cpp b/metatomic-torch/src/system.cpp index 4df00ec3..ffdedaef 100644 --- a/metatomic-torch/src/system.cpp +++ b/metatomic-torch/src/system.cpp @@ -52,7 +52,7 @@ void NeighborListOptionsHolder::set_length_unit(std::string length_unit) { } double NeighborListOptionsHolder::engine_cutoff(const std::string& engine_length_unit) const { - return cutoff_ * unit_conversion_factor("length", length_unit_, engine_length_unit); + return cutoff_ * unit_conversion_factor(length_unit_, engine_length_unit); } std::string NeighborListOptionsHolder::repr() const { diff --git a/metatomic-torch/src/units.cpp b/metatomic-torch/src/units.cpp new file mode 100644 index 00000000..5ec3c775 --- /dev/null +++ b/metatomic-torch/src/units.cpp @@ -0,0 +1,558 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "metatomic/torch/units.hpp" + +/******************************************************************************/ +/*** Unit expression parser with SI-based dimensional analysis ***/ +/******************************************************************************/ + +/// Physical dimension vector: [Length, Time, Mass, Charge, Temperature] +/// Exponents are double to support fractional powers like (eV*u)^(1/2). 
+struct Dimension {
+    /// Number of base dimensions tracked, in the order L, T, M, Q, Th.
+    static constexpr size_t SIZE = 5;
+    /// Absolute tolerance used when comparing exponents; needed because
+    /// fractional powers such as `^(1/2)` produce non-integer doubles.
+    static constexpr double TOLERANCE = 1e-10;
+
+    /// Exponent of each base dimension (all zero means dimensionless).
+    std::array<double, SIZE> exponents = {};
+
+    /// Dimension of a product: exponents add.
+    Dimension operator*(const Dimension& other) const {
+        Dimension result;
+        for (size_t i = 0; i < SIZE; ++i) {
+            result.exponents[i] = exponents[i] + other.exponents[i];
+        }
+        return result;
+    }
+
+    /// Dimension of a quotient: exponents subtract.
+    Dimension operator/(const Dimension& other) const {
+        Dimension result;
+        for (size_t i = 0; i < SIZE; ++i) {
+            result.exponents[i] = exponents[i] - other.exponents[i];
+        }
+        return result;
+    }
+
+    /// Dimension raised to the power `p`: every exponent is scaled by `p`.
+    Dimension pow(double p) const {
+        Dimension result;
+        for (size_t i = 0; i < SIZE; ++i) {
+            result.exponents[i] = exponents[i] * p;
+        }
+        return result;
+    }
+
+    /// Two dimensions are equal when every exponent agrees within
+    /// `TOLERANCE`.
+    bool operator==(const Dimension& other) const {
+        for (size_t i = 0; i < SIZE; ++i) {
+            if (std::fabs(exponents[i] - other.exponents[i]) > TOLERANCE) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    bool operator!=(const Dimension& other) const {
+        return !(*this == other);
+    }
+
+    /// Human readable form used in error messages, e.g.
+    /// `[L=1,T=-2,M=1,Q=0,Th=0]`.
+    std::string to_string() const {
+        static const char* names[] = {"L", "T", "M", "Q", "Th"};
+        std::string result = "[";
+        for (size_t i = 0; i < SIZE; ++i) {
+            if (i > 0) result += ",";
+            result += names[i];
+            result += "=";
+            // format as integer if close to integer, else as decimal
+            double v = exponents[i];
+            if (std::fabs(v - std::round(v)) < TOLERANCE) {
+                // `long long` keeps the conversion well-defined for any
+                // exponent a unit expression can realistically produce
+                result += std::to_string(static_cast<long long>(std::round(v)));
+            } else {
+                result += std::to_string(v);
+            }
+        }
+        result += "]";
+        return result;
+    }
+};
+
+/// A parsed unit value: SI conversion factor and physical dimension.
+struct UnitValue {
+    /// multiplicative factor expressing one of this unit in SI base units
+    double factor;
+    /// physical dimension carried by the unit
+    Dimension dim;
+};
+
+// Dimension constants for readability
+//                                        L   T  M  Q  Th
+static const Dimension DIM_LENGTH      = {{ 1, 0, 0, 0, 0 }};
+static const Dimension DIM_TIME        = {{ 0, 1, 0, 0, 0 }};
+static const Dimension DIM_MASS        = {{ 0, 0, 1, 0, 0 }};
+static const Dimension DIM_CHARGE      = {{ 0, 0, 0, 1, 0 }};
+static const Dimension DIM_TEMPERATURE = {{ 0, 0, 0, 0, 1 }};
+static const Dimension DIM_ENERGY      = {{ 2, -2, 1, 0, 0 }};
+static const Dimension DIM_NONE        = {{ 0, 0, 0, 0, 0 }};
+
+/// Return a lowercased copy of `s` (the argument is taken by value and
+/// modified in place). Bytes are passed to `std::tolower` as `unsigned char`
+/// so that non-ASCII bytes never reach it as negative values.
+static std::string to_lower(std::string s) {
+    std::transform(s.begin(), s.end(), s.begin(),
+        [](unsigned char c) { return static_cast<char>(std::tolower(c)); }
+    );
+    return s;
+}
+
+/// All base units with SI factors and dimensions.
+/// Factors are expressed in SI base units (m, s, kg, C, K).
+/// Case-insensitive lookup: tokens are lowercased before searching.
+///
+/// Because keys are lowercase, spellings that differ only by case cannot be
+/// told apart: for instance `MeV` resolves to the `mev` entry below, which is
+/// milli-electronvolt (1e-3 eV), not mega-electronvolt.
+static const std::unordered_map<std::string, UnitValue>& base_units() {
+    static const auto units = std::unordered_map<std::string, UnitValue>{
+        // --- Length ---
+        {"angstrom", {1e-10, DIM_LENGTH}},
+        {"a", {1e-10, DIM_LENGTH}},
+        {"bohr", {5.29177210903e-11, DIM_LENGTH}},
+        {"nm", {1e-9, DIM_LENGTH}},
+        {"nanometer", {1e-9, DIM_LENGTH}},
+        {"meter", {1.0, DIM_LENGTH}},
+        {"m", {1.0, DIM_LENGTH}},
+        {"cm", {1e-2, DIM_LENGTH}},
+        {"centimeter", {1e-2, DIM_LENGTH}},
+        {"mm", {1e-3, DIM_LENGTH}},
+        {"millimeter", {1e-3, DIM_LENGTH}},
+        {"um", {1e-6, DIM_LENGTH}},
+        // UTF-8 encoding of the micro sign (U+00B5) followed by "m"
+        {"\xc2\xb5m", {1e-6, DIM_LENGTH}},
+        {"micrometer", {1e-6, DIM_LENGTH}},
+
+        // --- Energy ---
+        {"ev", {1.602176634e-19, DIM_ENERGY}},
+        {"mev", {1.602176634e-22, DIM_ENERGY}},
+        {"hartree", {4.3597447222071e-18, DIM_ENERGY}},
+        {"ry", {2.1798723611e-18, DIM_ENERGY}},
+        {"rydberg", {2.1798723611e-18, DIM_ENERGY}},
+        {"joule", {1.0, DIM_ENERGY}},
+        {"j", {1.0, DIM_ENERGY}},
+        {"kcal", {4184.0, DIM_ENERGY}},
+        {"kj", {1000.0, DIM_ENERGY}},
+
+        // --- Time ---
+        {"s", {1.0, DIM_TIME}},
+        {"second", {1.0, DIM_TIME}},
+        {"ms", {1e-3, DIM_TIME}},
+        {"millisecond", {1e-3, DIM_TIME}},
+        {"us", {1e-6, DIM_TIME}},
+        // UTF-8 encoding of the micro sign (U+00B5) followed by "s"
+        {"\xc2\xb5s", {1e-6, DIM_TIME}},
+        {"microsecond", {1e-6, DIM_TIME}},
+        {"ns", {1e-9, DIM_TIME}},
+        {"nanosecond", {1e-9, DIM_TIME}},
+        {"ps", {1e-12, DIM_TIME}},
+        {"picosecond", {1e-12, DIM_TIME}},
+        {"fs", {1e-15, DIM_TIME}},
+        {"femtosecond", {1e-15, DIM_TIME}},
+
+        // --- Mass ---
+        {"u", {1.66053906660e-27, DIM_MASS}},
+        {"dalton", {1.66053906660e-27, DIM_MASS}},
+        {"kg", {1.0, DIM_MASS}},
+        {"kilogram", {1.0, DIM_MASS}},
+        {"g", {1e-3, DIM_MASS}},
+        {"gram", {1e-3, DIM_MASS}},
+        {"electron_mass", {9.1093837015e-31, DIM_MASS}},
+        {"m_e", {9.1093837015e-31, DIM_MASS}},
+
+        // --- Charge ---
+        {"e", {1.602176634e-19, DIM_CHARGE}},
+        {"coulomb", {1.0, DIM_CHARGE}},
+        {"c", {1.0, DIM_CHARGE}},
+
+        // --- Dimensionless ---
+        // `mol` is a pure number, so that `kJ/mol` divides by it
+        {"mol", {6.02214076e23, DIM_NONE}},
+
+        // --- Derived ---
+        // L^2 T^-1 M
+        {"hbar", {1.054571817e-34, {{2, -1, 1, 0, 0}}}},
+    };
+    return units;
+}
+
+// ---- Tokenizer ----
+
+enum class TokenType {
+    LParen, RParen, Mul, Div, Pow, Value
+};
+
+/// One lexical element of a unit expression: an operator, a parenthesis, or a
+/// `Value` (unit name or numeric literal, resolved later by `read_expr`).
+struct Token {
+    TokenType type;
+    std::string value; // only meaningful for Value tokens
+
+    /// Binding strength used by the Shunting-Yard algorithm; parentheses get
+    /// the lowest operator precedence so that nothing pops through them.
+    int precedence() const {
+        switch (type) {
+        case TokenType::LParen:
+        case TokenType::RParen: return 0;
+        case TokenType::Mul:
+        case TokenType::Div: return 10;
+        case TokenType::Pow: return 20;
+        default: return -1;
+        }
+    }
+
+    /// Textual form of the token, used when reporting errors.
+    std::string as_str() const {
+        switch (type) {
+        case TokenType::LParen: return "(";
+        case TokenType::RParen: return ")";
+        case TokenType::Mul: return "*";
+        case TokenType::Div: return "/";
+        case TokenType::Pow: return "^";
+        case TokenType::Value: return value;
+        }
+        return "?";
+    }
+};
+
+/// Split a unit expression into tokens. The characters `* / ^ ( )` are
+/// emitted as their own tokens; every other non-whitespace run between them
+/// becomes a single `Value` token. Whitespace is dropped, including inside a
+/// value, so `"k J"` yields the same token as `"kJ"`.
+static std::vector<Token> tokenize(const std::string& unit) {
+    std::vector<Token> tokens;
+    std::string current;
+
+    for (size_t i = 0; i < unit.size(); ++i) {
+        auto byte = static_cast<unsigned char>(unit[i]);
+
+        // Handle UTF-8 micro sign (U+00B5): 0xC2 0xB5
+        // Both bytes are copied verbatim so that they can never be taken for
+        // whitespace by a locale-dependent `std::isspace`.
+        if (byte == 0xC2 && i + 1 < unit.size()
+            && static_cast<unsigned char>(unit[i + 1]) == 0xB5)
+        {
+            current += unit[i];
+            current += unit[i + 1];
+            ++i; // skip second byte
+            continue;
+        }
+
+        char ch = unit[i];
+        if (ch == '*' || ch == '/' || ch == '^' || ch == '(' || ch == ')') {
+            // an operator terminates the value being accumulated, if any
+            if (!current.empty()) {
+                tokens.push_back({TokenType::Value, current});
+                current.clear();
+            }
+            TokenType t;
+            switch (ch) {
+            case '*': t = TokenType::Mul; break;
+            case '/': t = TokenType::Div; break;
+            case '^': t = TokenType::Pow; break;
+            case '(': t = TokenType::LParen; break;
+            case ')': t = TokenType::RParen; break;
+            default: t = TokenType::Value; break; // unreachable
+            }
+            tokens.push_back({t, std::string(1, ch)});
+        } else if (!std::isspace(byte)) {
+            current += ch;
+        }
+    }
+    if (!current.empty()) {
+        tokens.push_back({TokenType::Value, current});
+    }
+    return tokens;
+}
+
+// ---- Shunting-Yard ----
+
+/// Convert infix tokens to Reverse Polish Notation using the Shunting-Yard
+/// algorithm. All operators are treated as left-associative.
+///
+/// Throws `ValueError` when parentheses are unbalanced. Other malformed input
+/// (e.g. a missing operand) is detected later, when the RPN stream is read.
+static std::vector<Token> shunting_yard(const std::vector<Token>& tokens) {
+    std::vector<Token> output;
+    std::vector<Token> operators;
+
+    for (const auto& token : tokens) {
+        switch (token.type) {
+        case TokenType::Value:
+            output.push_back(token);
+            break;
+        case TokenType::Mul:
+        case TokenType::Div:
+        case TokenType::Pow: {
+            while (!operators.empty()) {
+                const auto& top = operators.back();
+                // left-associative: pop while top >= current
+                // (an LParen on the stack has precedence 0 and stops this)
+                if (token.precedence() <= top.precedence()) {
+                    output.push_back(operators.back());
+                    operators.pop_back();
+                } else {
+                    break;
+                }
+            }
+            operators.push_back(token);
+            break;
+        }
+        case TokenType::LParen:
+            operators.push_back(token);
+            break;
+        case TokenType::RParen: {
+            // flush everything back to the matching opening parenthesis
+            while (!operators.empty() && operators.back().type != TokenType::LParen) {
+                output.push_back(operators.back());
+                operators.pop_back();
+            }
+            if (operators.empty() || operators.back().type != TokenType::LParen) {
+                C10_THROW_ERROR(ValueError,
+                    "unit expression has unbalanced parentheses"
+                );
+            }
+            operators.pop_back(); // discard LParen
+            break;
+        }
+        }
+    }
+
+    // whatever is left must be plain operators; a parenthesis here was
+    // never closed
+    while (!operators.empty()) {
+        if (operators.back().type == TokenType::LParen ||
+            operators.back().type == TokenType::RParen) {
+            C10_THROW_ERROR(ValueError,
+                "unit expression has unbalanced parentheses"
+            );
+        }
+        output.push_back(operators.back());
+        operators.pop_back();
+    }
+
+    return output;
+}
+
+// ---- AST evaluator ----
+//
+// Departure from lumol: Pow exponent is a full sub-expression (not just i32)
+// to handle ^(1/2). The exponent sub-expression must be dimensionless; its
+// factor value becomes the exponent.
+
+struct UnitExpr;
+using UnitExprPtr = std::unique_ptr<UnitExpr>;
+
+/// Node of the unit expression tree: either a leaf value or a binary
+/// operation over two sub-expressions.
+struct UnitExpr {
+    struct Val { UnitValue value; };
+    struct Mul { UnitExprPtr lhs, rhs; };
+    struct Div { UnitExprPtr lhs, rhs; };
+    struct Pow { UnitExprPtr base, exponent; };
+
+    std::variant<Val, Mul, Div, Pow> data;
+
+    /// Recursively evaluate this node into an SI factor and a dimension.
+    /// Throws `ValueError` if the exponent of a power is not dimensionless.
+    UnitValue eval() const {
+        return std::visit([](const auto& v) -> UnitValue {
+            using T = std::decay_t<decltype(v)>;
+            if constexpr (std::is_same_v<T, Val>) {
+                return v.value;
+            } else if constexpr (std::is_same_v<T, Mul>) {
+                auto l = v.lhs->eval();
+                auto r = v.rhs->eval();
+                return {l.factor * r.factor, l.dim * r.dim};
+            } else if constexpr (std::is_same_v<T, Div>) {
+                auto l = v.lhs->eval();
+                auto r = v.rhs->eval();
+                return {l.factor / r.factor, l.dim / r.dim};
+            } else if constexpr (std::is_same_v<T, Pow>) {
+                auto b = v.base->eval();
+                auto e = v.exponent->eval();
+                if (e.dim != DIM_NONE) {
+                    C10_THROW_ERROR(ValueError,
+                        "exponent in unit expression must be dimensionless, "
+                        "got dimension " + e.dim.to_string()
+                    );
+                }
+                // the same exponent applies to the factor and the dimension
+                return {std::pow(b.factor, e.factor), b.dim.pow(e.factor)};
+            }
+        }, data);
+    }
+};
+
+/// Read one expression from the RPN stream (recursive, pops from the back).
+static UnitExprPtr read_expr(std::vector<Token>& stream) {
+    // The stream is in Reverse Polish Notation and is consumed from the back:
+    // the last token is the root of the (sub-)expression, and the operands of
+    // an operator are read right-hand side first.
+    if (stream.empty()) {
+        C10_THROW_ERROR(ValueError,
+            "malformed unit expression: missing a value"
+        );
+    }
+
+    auto token = stream.back();
+    stream.pop_back();
+
+    switch (token.type) {
+    case TokenType::Value: {
+        // known unit names take priority over numeric literals
+        const auto& units = base_units();
+        auto it = units.find(to_lower(token.value));
+        if (it != units.end()) {
+            auto expr = std::make_unique<UnitExpr>();
+            expr->data = UnitExpr::Val{it->second};
+            return expr;
+        }
+
+        // Otherwise the token must be a numeric literal (dimensionless).
+        // `std::stod` stops at the first character it can not use, so the
+        // number of consumed characters is checked to reject tokens such as
+        // "2x" instead of silently reading them as 2. Non-finite values
+        // ("nan", "inf") are rejected too: they are not usable as a factor.
+        size_t consumed = 0;
+        double number = 0.0;
+        bool is_number = false;
+        try {
+            number = std::stod(token.value, &consumed);
+            is_number = consumed == token.value.size() && std::isfinite(number);
+        } catch (const std::exception&) {
+            // not a number at all, or out of the range of `double`
+            is_number = false;
+        }
+
+        if (!is_number) {
+            C10_THROW_ERROR(ValueError,
+                "unknown unit '" + token.value + "'"
+            );
+        }
+
+        auto expr = std::make_unique<UnitExpr>();
+        expr->data = UnitExpr::Val{{number, DIM_NONE}};
+        return expr;
+    }
+    case TokenType::Mul: {
+        auto rhs = read_expr(stream);
+        auto lhs = read_expr(stream);
+        auto expr = std::make_unique<UnitExpr>();
+        expr->data = UnitExpr::Mul{std::move(lhs), std::move(rhs)};
+        return expr;
+    }
+    case TokenType::Div: {
+        auto rhs = read_expr(stream);
+        auto lhs = read_expr(stream);
+        auto expr = std::make_unique<UnitExpr>();
+        expr->data = UnitExpr::Div{std::move(lhs), std::move(rhs)};
+        return expr;
+    }
+    case TokenType::Pow: {
+        // Exponent is a full sub-expression (supports ^(1/2))
+        auto exponent = read_expr(stream);
+        auto base = read_expr(stream);
+        auto expr = std::make_unique<UnitExpr>();
+        expr->data = UnitExpr::Pow{std::move(base), std::move(exponent)};
+        return expr;
+    }
+    default:
+        // parentheses are never emitted into the RPN stream
+        C10_THROW_ERROR(ValueError,
+            "unexpected token in unit expression: " + token.as_str()
+        );
+    }
+}
+
+/// Parse a unit expression string and return the evaluated UnitValue.
+static UnitValue parse_unit_expression(const std::string& unit) {
+    // an empty string evaluates to the dimensionless value 1
+    if (unit.empty()) {
+        return {1.0, DIM_NONE};
+    }
+
+    auto tokens = tokenize(unit);
+    // same result when the tokenizer produced nothing
+    if (tokens.empty()) {
+        return {1.0, DIM_NONE};
+    }
+
+    // infix tokens -> RPN -> expression tree; read_expr pops the tokens it
+    // uses from the back of `rpn`
+    auto rpn = shunting_yard(tokens);
+    auto ast = read_expr(rpn);
+
+    // anything still in `rpn` was not part of the expression just read:
+    // report the leftover tokens
+    if (!rpn.empty()) {
+        std::string remaining;
+        for (const auto& t : rpn) {
+            if (!remaining.empty()) remaining += " ";
+            remaining += t.as_str();
+        }
+        C10_THROW_ERROR(ValueError,
+            "malformed unit expression: leftover tokens '" + remaining + "'"
+        );
+    }
+
+    return ast->eval();
+}
+
+// ---- Quantity dimension map (for validate_unit) ----
+
+// Expected dimension for every known quantity name.
+//
+// NOTE(review): the template arguments of `std::unordered_map` are missing in
+// this listing (twice below); the mapped type is not visible here, so they
+// have to be restored from the original source.
+//
+// Judging from the per-entry comments, the first three exponents are those of
+// length, time and mass (in that order); the meaning of the last two is not
+// visible here -- confirm against the definition of the dimension type.
+static const std::unordered_map& quantity_dims() {
+    static const auto dims = std::unordered_map{
+        {"length", DIM_LENGTH},
+        {"energy", DIM_ENERGY},
+        {"force", {{1, -2, 1, 0, 0}}},     // energy/length
+        {"pressure", {{-1, -2, 1, 0, 0}}}, // energy/length^3
+        {"momentum", {{1, -1, 1, 0, 0}}},  // mass*length/time
+        {"mass", DIM_MASS},
+        {"velocity", {{1, -1, 0, 0, 0}}},  // length/time
+        {"charge", DIM_CHARGE},
+    };
+    return dims;
+}
+
+// ---- Public API ----
+
+/// 2-argument unit_conversion_factor: parse both expressions, check dimensions
+/// match, and return from_factor / to_factor.
+double metatomic_torch::unit_conversion_factor(
+    const std::string& from_unit,
+    const std::string& to_unit
+) {
+    // an empty unit on either side disables the conversion
+    if (from_unit.empty() || to_unit.empty()) {
+        return 1.0;
+    }
+
+    auto from = parse_unit_expression(from_unit);
+    auto to = parse_unit_expression(to_unit);
+
+    if (from.dim != to.dim) {
+        C10_THROW_ERROR(ValueError,
+            "dimension mismatch in unit conversion: '" + from_unit +
+            "' has dimension " + from.dim.to_string() +
+            " but '" + to_unit + "' has dimension " + to.dim.to_string()
+        );
+    }
+
+    return from.factor / to.factor;
+}
+
+/// Check whether `quantity` is one of the known quantity names.
+///
+/// An empty name is rejected silently. An unknown name is rejected with a
+/// warning listing the supported names; this warning is emitted only the
+/// first time a given name is seen.
+bool metatomic_torch::valid_quantity(const std::string& quantity) {
+    if (quantity.empty()) {
+        return false;
+    }
+
+    const auto& dims = quantity_dims();
+    if (dims.find(quantity) != dims.end()) {
+        return true;
+    }
+
+    static std::unordered_set<std::string> ALREADY_WARNED = {};
+    if (ALREADY_WARNED.insert(quantity).second) {
+        // the sorted list of names is only needed for the message, so it is
+        // only built when the warning is actually emitted
+        auto valid_quantities = std::vector<std::string>();
+        valid_quantities.reserve(dims.size());
+        for (const auto& it: dims) {
+            valid_quantities.emplace_back(it.first);
+        }
+        std::sort(valid_quantities.begin(), valid_quantities.end());
+
+        TORCH_WARN(
+            "unknown quantity '", quantity, "', only [",
+            torch::str(valid_quantities), "] are supported"
+        );
+    }
+    return false;
+}
+
+
+/// Check that `unit` is a well-formed unit expression and, when `quantity` is
+/// a known quantity name, that the dimension of the expression matches it.
+///
+/// Nothing is checked if either string is empty. Throws ValueError on a
+/// malformed expression or on a dimension that does not match the quantity.
+void metatomic_torch::validate_unit(const std::string& quantity, const std::string& unit) {
+    if (quantity.empty() || unit.empty()) {
+        return;
+    }
+
+    // Always try to parse the expression (catches syntax errors)
+    auto parsed = parse_unit_expression(unit);
+
+    // If the quantity is known, verify dimensions match
+    const auto& dims = quantity_dims();
+    auto it = dims.find(quantity);
+    if (it != dims.end()) {
+        if (parsed.dim != it->second) {
+            C10_THROW_ERROR(ValueError,
+                "unit '" + unit + "' has dimension " + parsed.dim.to_string() +
+                " which is incompatible with quantity '" + quantity +
+                "' (expected " + it->second.to_string() + ")"
+            );
+        }
+    }
+}
+
+/// Deprecated 3-argument overload: ignores quantity, delegates to 2-arg.
+double metatomic_torch::unit_conversion_factor(
+    const std::string& quantity,
+    const std::string& from_unit,
+    const std::string& to_unit
+) {
+    // `quantity` is accepted for backward compatibility only and is not read.
+    // The once_flag is static, so the deprecation warning below is issued by
+    // the first call only, not by every call.
+    static std::once_flag warn_flag;
+    std::call_once(warn_flag, [&]() {
+        TORCH_WARN(
+            "the 3-argument unit_conversion_factor(quantity, from, to) is "
+            "deprecated; use the 2-argument unit_conversion_factor(from, to) "
+            "instead. The quantity parameter is no longer needed."
+        );
+    });
+
+    // explicitly qualified call to the 2-argument overload
+    return metatomic_torch::unit_conversion_factor(from_unit, to_unit);
+}
diff --git a/metatomic-torch/tests/models.cpp b/metatomic-torch/tests/models.cpp index 0ae77300..605958d6 100644 --- a/metatomic-torch/tests/models.cpp +++ b/metatomic-torch/tests/models.cpp @@ -55,7 +55,7 @@ TEST_CASE("Models metadata") { ); CHECK_THROWS_WITH(options->set_length_unit("unknown"), - StartsWith("unknown unit 'unknown' for length") + StartsWith("unknown unit token 'unknown'") ); } @@ -103,7 +103,7 @@ TEST_CASE("Models metadata") { ); CHECK_THROWS_WITH(output->set_unit("unknown"), - StartsWith("unknown unit 'unknown' for length") + StartsWith("unknown unit token 'unknown'") ); struct WarningHandler: public torch::WarningHandler { @@ -135,8 +135,8 @@ TEST_CASE("Models metadata") { auto output = torch::make_intrusive(); output->per_atom = true; - output->set_quantity("something"); - output->set_unit("something"); + output->set_quantity("energy"); + output->set_unit("eV"); options->outputs.insert("output_2", output); const auto* expected = R"({ @@ -156,8 +156,8 @@ TEST_CASE("Models metadata") { "description": "", "explicit_gradients": [], "per_atom": true, - "quantity": "something", - "unit": "something" + "quantity": "energy", + "unit": "eV" } }, "selected_atoms": null @@ -205,7 +205,7 @@ TEST_CASE("Models metadata") { ); CHECK_THROWS_WITH(options->set_length_unit("unknown"), - StartsWith("unknown unit 'unknown' for length") + StartsWith("unknown unit token 'unknown'") ); } @@ -295,7 +295,7 @@ TEST_CASE("Models metadata") { );
CHECK_THROWS_WITH(capabilities->set_length_unit("unknown"), - StartsWith("unknown unit 'unknown' for length") + StartsWith("unknown unit token 'unknown'") ); auto capabilities_variants = torch::make_intrusive(); diff --git a/metatomic-torch/tests/units.cpp b/metatomic-torch/tests/units.cpp new file mode 100644 index 00000000..79e61359 --- /dev/null +++ b/metatomic-torch/tests/units.cpp @@ -0,0 +1,274 @@ +#include +#include + +#include + +#include "metatomic/torch.hpp" + +#include +using Catch::Detail::Approx; +using Catch::Matchers::Contains; +using Catch::Matchers::StartsWith; + +class DeprecationWarningHandler : public torch::WarningHandler { +public: + std::vector messages; + + void process(const torch::Warning& warning) override { + messages.push_back(warning.msg()); + } +}; + +// ---- Simple conversions ---- + +TEST_CASE("Simple length conversions") { + // Angstrom <-> Bohr + double a_to_bohr = metatomic_torch::unit_conversion_factor("Angstrom", "Bohr"); + CHECK(a_to_bohr == Approx(1.8897259886).epsilon(1e-6)); + + double bohr_to_a = metatomic_torch::unit_conversion_factor("Bohr", "Angstrom"); + CHECK(bohr_to_a == Approx(0.529177210903).epsilon(1e-6)); + + // Angstrom <-> nm + double a_to_nm = metatomic_torch::unit_conversion_factor("Angstrom", "nm"); + CHECK(a_to_nm == Approx(0.1).epsilon(1e-12)); + + // Angstrom <-> meter + double a_to_m = metatomic_torch::unit_conversion_factor("Angstrom", "meter"); + CHECK(a_to_m == Approx(1e-10).epsilon(1e-20)); +} + +TEST_CASE("Simple energy conversions") { + // eV <-> meV + double ev_to_mev = metatomic_torch::unit_conversion_factor("eV", "meV"); + CHECK(ev_to_mev == Approx(1000.0).epsilon(1e-10)); + + // eV <-> Hartree + double ev_to_hartree = metatomic_torch::unit_conversion_factor("eV", "Hartree"); + CHECK(ev_to_hartree == Approx(0.0367493).epsilon(1e-4)); + + // eV <-> Rydberg + double ev_to_ry = metatomic_torch::unit_conversion_factor("eV", "Ry"); + CHECK(ev_to_ry == Approx(0.0734987).epsilon(1e-4)); +} + +// ---- 
Compound expressions ---- + +TEST_CASE("Compound unit expressions") { + // eV/Angstrom <-> Hartree/Bohr (force) + double f_conv = metatomic_torch::unit_conversion_factor("eV/Angstrom", "Hartree/Bohr"); + CHECK(f_conv == Approx(0.0194469).epsilon(1e-3)); + + // kJ/mol <-> kcal/mol (energy) + double e_conv = metatomic_torch::unit_conversion_factor("kJ/mol", "kcal/mol"); + CHECK(e_conv == Approx(1000.0 / 4184.0).epsilon(1e-6)); + + // eV/Angstrom^3 (pressure) -- identity + double p_id = metatomic_torch::unit_conversion_factor("eV/Angstrom^3", "eV/A^3"); + CHECK(p_id == Approx(1.0).epsilon(1e-10)); +} + +// ---- Fractional powers ---- + +TEST_CASE("Fractional power expressions") { + // (eV*u)^(1/2) should have dimension of momentum: M*L/T + // Compare to u*A/fs + double conv = metatomic_torch::unit_conversion_factor("(eV*u)^(1/2)", "u*A/fs"); + // Cross-check: sqrt(eV_SI * u_SI) / (u_SI * A_SI / fs_SI) + double ev_si = 1.602176634e-19; + double u_si = 1.66053906660e-27; + double a_si = 1e-10; + double fs_si = 1e-15; + double expected = std::sqrt(ev_si * u_si) / (u_si * a_si / fs_si); + CHECK(conv == Approx(expected).epsilon(1e-4)); + + // (eV/u)^(1/2) has dimension of velocity: L/T + double v_conv = metatomic_torch::unit_conversion_factor("(eV/u)^(1/2)", "A/fs"); + double v_expected = std::sqrt(ev_si / u_si) / (a_si / fs_si); + CHECK(v_conv == Approx(v_expected).epsilon(1e-4)); +} + +// ---- Case insensitivity ---- + +TEST_CASE("Case insensitive unit lookup") { + double c1 = metatomic_torch::unit_conversion_factor("eV", "hartree"); + double c2 = metatomic_torch::unit_conversion_factor("EV", "HARTREE"); + double c3 = metatomic_torch::unit_conversion_factor("Ev", "Hartree"); + CHECK(c1 == Approx(c2).epsilon(1e-12)); + CHECK(c1 == Approx(c3).epsilon(1e-12)); +} + +// ---- Whitespace handling ---- + +TEST_CASE("Whitespace in unit expressions") { + double c1 = metatomic_torch::unit_conversion_factor("eV / Angstrom", "Hartree/Bohr"); + double c2 = 
metatomic_torch::unit_conversion_factor("eV/Angstrom", "Hartree/Bohr");
+    CHECK(c1 == Approx(c2).epsilon(1e-12));
+
+    double c3 = metatomic_torch::unit_conversion_factor("( eV * u ) ^ ( 1 / 2 )", "u*A/fs");
+    double c4 = metatomic_torch::unit_conversion_factor("(eV*u)^(1/2)", "u*A/fs");
+    CHECK(c3 == Approx(c4).epsilon(1e-12));
+}
+
+// ---- Dimension mismatch errors ----
+
+TEST_CASE("Dimension mismatch error") {
+    CHECK_THROWS_WITH(
+        metatomic_torch::unit_conversion_factor("eV", "Angstrom"),
+        Contains("dimension mismatch")
+    );
+}
+
+// ---- Unknown token errors ----
+
+TEST_CASE("Unknown unit token error") {
+    CHECK_THROWS_WITH(
+        metatomic_torch::unit_conversion_factor("foobar", "eV"),
+        Contains("unknown unit token")
+    );
+}
+
+// ---- Malformed expression errors ----
+
+TEST_CASE("Malformed expression errors") {
+    CHECK_THROWS_WITH(
+        metatomic_torch::unit_conversion_factor("eV*(", "eV"),
+        Contains("parentheses")
+    );
+}
+
+// ---- Empty string handling ----
+
+TEST_CASE("Empty unit string returns 1.0") {
+    CHECK(metatomic_torch::unit_conversion_factor("", "eV") == 1.0);
+    CHECK(metatomic_torch::unit_conversion_factor("eV", "") == 1.0);
+    CHECK(metatomic_torch::unit_conversion_factor("", "") == 1.0);
+}
+
+// ---- Backward compatibility: 3-arg API ----
+
+TEST_CASE("3-arg API backward compatibility") {
+    // Route warnings to a local handler so the deprecation message does not
+    // end up in the test output.
+    DeprecationWarningHandler handler;
+    torch::WarningUtils::WarningHandlerGuard guard(&handler);
+    // RAII guard instead of set_warnAlways(true): the previous global setting
+    // is restored when the test ends and does not leak into other tests.
+    torch::WarningUtils::WarnAlways warn_always(true);
+
+    double conv = metatomic_torch::unit_conversion_factor("energy", "eV", "meV");
+    CHECK(conv == Approx(1000.0).epsilon(1e-10));
+}
+
+// ---- mol handling (dimensionless scaling) ----
+
+TEST_CASE("mol as dimensionless scaling factor") {
+    // kJ/mol to eV: kJ_SI / mol_SI / eV_SI
+    double conv = metatomic_torch::unit_conversion_factor("kJ/mol", "eV");
+    double kj_si = 1000.0;
+    double mol_si = 6.02214076e23;
+    double ev_si = 1.602176634e-19;
+    double expected = (kj_si / mol_si) / ev_si;
+    CHECK(conv ==
Approx(expected).epsilon(1e-6)); +} + +// ---- Hbar derived unit ---- + +TEST_CASE("hbar/Bohr momentum conversion") { + double conv = metatomic_torch::unit_conversion_factor("hbar/Bohr", "u*A/fs"); + // hbar_SI / Bohr_SI / (u_SI * A_SI / fs_SI) + double hbar_si = 1.054571817e-34; + double bohr_si = 5.29177210903e-11; + double u_si = 1.66053906660e-27; + double a_si = 1e-10; + double fs_si = 1e-15; + double expected = (hbar_si / bohr_si) / (u_si * a_si / fs_si); + CHECK(conv == Approx(expected).epsilon(1e-3)); +} + +// ---- Identity conversions ---- + +TEST_CASE("Identity conversions") { + CHECK(metatomic_torch::unit_conversion_factor("eV", "eV") == Approx(1.0)); + CHECK(metatomic_torch::unit_conversion_factor("Angstrom", "Angstrom") == Approx(1.0)); + CHECK(metatomic_torch::unit_conversion_factor("u*A/fs", "u*A/fs") == Approx(1.0)); +} + +// ---- Time unit conversions ---- + +TEST_CASE("Time unit conversions") { + // second -> femtosecond + double s_to_fs = metatomic_torch::unit_conversion_factor("s", "fs"); + CHECK(s_to_fs == Approx(1e15).epsilon(1e-6)); + + // second -> picosecond + double s_to_ps = metatomic_torch::unit_conversion_factor("second", "ps"); + CHECK(s_to_ps == Approx(1e12).epsilon(1e-6)); + + // nanosecond -> femtosecond + double ns_to_fs = metatomic_torch::unit_conversion_factor("ns", "fs"); + CHECK(ns_to_fs == Approx(1e6).epsilon(1e-6)); + + // microsecond -> nanosecond + double us_to_ns = metatomic_torch::unit_conversion_factor("us", "ns"); + CHECK(us_to_ns == Approx(1e3).epsilon(1e-6)); +} + +// ---- Micro sign handling ---- + +TEST_CASE("Micro sign (U+00B5) handling") { + double c1 = metatomic_torch::unit_conversion_factor("um", "Angstrom"); + double c2 = metatomic_torch::unit_conversion_factor("\xC2\xB5m", "Angstrom"); + CHECK(c1 == Approx(c2).epsilon(1e-12)); + + // Standalone micro sign -> u (Dalton) + double c3 = metatomic_torch::unit_conversion_factor("u", "kg"); + double c4 = metatomic_torch::unit_conversion_factor("\xC2\xB5", "kg"); + 
CHECK(c3 == Approx(c4).epsilon(1e-12));
+}
+
+// ---- Quantity-unit dimension validation ----
+
+TEST_CASE("ModelOutput rejects mismatched quantity and unit") {
+    // energy quantity with a force unit
+    CHECK_THROWS_WITH(
+        torch::make_intrusive(
+            "energy", "eV/A", false, std::vector{}, ""
+        ),
+        Contains("incompatible with quantity")
+    );
+
+    // force quantity with an energy unit
+    CHECK_THROWS_WITH(
+        torch::make_intrusive(
+            "force", "eV", false, std::vector{}, ""
+        ),
+        Contains("incompatible with quantity")
+    );
+
+    // length quantity with a pressure unit
+    CHECK_THROWS_WITH(
+        torch::make_intrusive(
+            "length", "eV/A^3", false, std::vector{}, ""
+        ),
+        Contains("incompatible with quantity")
+    );
+}
+
+TEST_CASE("ModelOutput accepts matching quantity and unit") {
+    // These should not throw
+    torch::make_intrusive(
+        "energy", "eV", false, std::vector{}, ""
+    );
+    torch::make_intrusive(
+        "force", "eV/A", false, std::vector{}, ""
+    );
+    torch::make_intrusive(
+        "pressure", "eV/A^3", false, std::vector{}, ""
+    );
+    torch::make_intrusive(
+        "length", "Angstrom", false, std::vector{}, ""
+    );
+    torch::make_intrusive(
+        "momentum", "u*A/fs", false, std::vector{}, ""
+    );
+    torch::make_intrusive(
+        "velocity", "A/fs", false, std::vector{}, ""
+    );
+}
diff --git a/python/metatomic_torch/metatomic/torch/__init__.py b/python/metatomic_torch/metatomic/torch/__init__.py index 8855e1d6..2e72ad0b 100644 --- a/python/metatomic_torch/metatomic/torch/__init__.py +++ b/python/metatomic_torch/metatomic/torch/__init__.py @@ -43,7 +43,50 @@ _check_outputs = torch.ops.metatomic._check_outputs register_autograd_neighbors = torch.ops.metatomic.register_autograd_neighbors - unit_conversion_factor = torch.ops.metatomic.unit_conversion_factor
+
+    _unit_conversion_factor_v2 = torch.ops.metatomic.unit_conversion_factor_v2
+    _unit_conversion_factor_v1 = torch.ops.metatomic.unit_conversion_factor
+
+    def unit_conversion_factor(*args, **kwargs):
+        """Unit conversion factor supporting both 2-arg and 3-arg signatures.
+
+        2-arg: ``unit_conversion_factor(from_unit, to_unit)``
+        3-arg (deprecated): ``unit_conversion_factor(quantity, from_unit, to_unit)``
+        """
+        import warnings
+
+        # ``to_unit`` given by keyword while ``from_unit`` is positional
+        mixed = "to_unit" in kwargs and "from_unit" not in kwargs
+
+        # Work out which calling convention was used; ``deprecated`` is set
+        # whenever a ``quantity`` was passed, positionally or by keyword.
+        if "from_unit" in kwargs and "to_unit" in kwargs:
+            # keyword call: positional arguments, if any, are not consulted
+            deprecated = "quantity" in kwargs
+            from_unit, to_unit = kwargs["from_unit"], kwargs["to_unit"]
+        elif not kwargs and len(args) in (2, 3):
+            # positional call: the units are always the last two arguments
+            deprecated = len(args) == 3
+            from_unit, to_unit = args[-2:]
+        elif mixed and len(args) in (1, 2):
+            # e.g. ``unit_conversion_factor("eV", to_unit="meV")``, which used
+            # to be rejected with a ``TypeError``
+            deprecated = len(args) == 2 or "quantity" in kwargs
+            from_unit, to_unit = args[-1], kwargs["to_unit"]
+        else:
+            raise TypeError(
+                "unit_conversion_factor() expects 2 or 3 positional arguments"
+            )
+
+        if deprecated:
+            warnings.warn(
+                "the 3-argument unit_conversion_factor(quantity, from, to) is "
+                "deprecated; use the 2-argument form instead",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+        return _unit_conversion_factor_v2(from_unit, to_unit)
+
 pick_device = torch.ops.metatomic.pick_device pick_output = torch.ops.metatomic.pick_output diff --git a/python/metatomic_torch/metatomic/torch/documentation.py b/python/metatomic_torch/metatomic/torch/documentation.py index e0787ab2..fd37d417 100644 --- a/python/metatomic_torch/metatomic/torch/documentation.py +++ b/python/metatomic_torch/metatomic/torch/documentation.py @@ -522,15 +522,30 @@ def register_autograd_neighbors( """ -def unit_conversion_factor(quantity: str, from_unit: str, to_unit: str): +def unit_conversion_factor(*args, **kwargs): """ - Get the multiplicative conversion factor from ``from_unit`` to ``to_unit``. Both - units must be valid and known for the given physical ``quantity``. The set of valid - quantities and units is available :ref:`here `. + Get the multiplicative conversion factor from ``from_unit`` to ``to_unit``.
- :param quantity: name of the physical quantity - :param from_unit: current unit of the data - :param to_unit: target unit of the data + Supports two calling conventions: + + - **2-argument** (preferred): + ``unit_conversion_factor(from_unit, to_unit)`` + - **3-argument** (deprecated): + ``unit_conversion_factor(quantity, from_unit, to_unit)`` + + Both ``from_unit`` and ``to_unit`` are parsed as unit expressions supporting + compound forms like ``"kJ/mol/A^2"`` or ``"(eV*u)^(1/2)"``. The parser + validates that both expressions have matching physical dimensions. + + The ``quantity`` parameter in the 3-argument form is ignored (dimensional + compatibility is checked by the parser). A deprecation warning will be + emitted if the 3-argument form is used. + + The set of recognized base unit tokens is available :ref:`here + `. + + :param from_unit: current unit of the data (expression string) + :param to_unit: target unit of the data (expression string) """ diff --git a/python/metatomic_torch/metatomic/torch/model.py b/python/metatomic_torch/metatomic/torch/model.py index c2a9f528..b6da08c9 100644 --- a/python/metatomic_torch/metatomic/torch/model.py +++ b/python/metatomic_torch/metatomic/torch/model.py @@ -21,7 +21,6 @@ _check_outputs, check_atomistic_model, load_model_extensions, - unit_conversion_factor, ) from . 
import __version__ as metatomic_version from ._extensions import _collect_extensions @@ -517,10 +516,9 @@ def forward( f"'{requested.quantity}'" ) - conversion = unit_conversion_factor( - quantity=declared.quantity, - from_unit=declared.unit, - to_unit=requested.unit, + conversion = torch.ops.metatomic.unit_conversion_factor_v2( + declared.unit, + requested.unit, ) if conversion != 1.0: @@ -934,10 +932,9 @@ def _convert_systems_units( # no conversion for positions/cell/NL conversion = 1.0 else: - conversion = unit_conversion_factor( - quantity="length", - from_unit=system_length_unit, - to_unit=model_length_unit, + conversion = torch.ops.metatomic.unit_conversion_factor_v2( + system_length_unit, + model_length_unit, ) new_systems: List[System] = [] @@ -974,10 +971,9 @@ def _convert_systems_units( unit = tensor.get_info("unit") if requested.quantity != "" and unit is not None: - conversion = unit_conversion_factor( - quantity=requested.quantity, - from_unit=unit, - to_unit=requested.unit, + conversion = torch.ops.metatomic.unit_conversion_factor_v2( + unit, + requested.unit, ) else: conversion = 1.0 diff --git a/python/metatomic_torch/tests/units.py b/python/metatomic_torch/tests/units.py index 9b47894b..a8238ac1 100644 --- a/python/metatomic_torch/tests/units.py +++ b/python/metatomic_torch/tests/units.py @@ -1,36 +1,135 @@ +import math +import warnings + import ase.units +import pytest from metatomic.torch import ModelOutput, unit_conversion_factor -def test_conversion_length(): - length_angstrom = 1.0 - length_nm = unit_conversion_factor("length", "angstrom", "nm") * length_angstrom - assert length_nm == 0.1 +# ---- Backward compat: 3-arg still works (with deprecation warning) ---- -def test_conversion_energy(): - energy_ev = 1.0 - energy_mev = unit_conversion_factor("energy", "ev", "mev") * energy_ev - assert energy_mev == 1000.0 +def test_conversion_length_3arg(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + 
length_angstrom = 1.0 + length_nm = unit_conversion_factor("length", "angstrom", "nm") * length_angstrom + assert length_nm == pytest.approx(0.1) + +def test_conversion_energy_3arg(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + energy_ev = 1.0 + energy_mev = unit_conversion_factor("energy", "ev", "mev") * energy_ev + assert energy_mev == pytest.approx(1000.0) -def test_units(): - def length_conversion(unit): - return unit_conversion_factor("length", "angstrom", unit) - assert length_conversion("bohr") == ase.units.Ang / ase.units.Bohr - assert length_conversion("nm") == ase.units.Ang / ase.units.nm - assert length_conversion("nanometer") == ase.units.Ang / ase.units.nm +# ---- 2-arg API ---- + + +def test_conversion_length(): + assert unit_conversion_factor("angstrom", "nm") == pytest.approx(0.1) + assert unit_conversion_factor("angstrom", "Bohr") == pytest.approx( + 1.8897259886, rel=1e-6 + ) + assert unit_conversion_factor("Angstrom", "meter") == pytest.approx( + 1e-10, rel=1e-10 + ) - def energy_conversion(unit): - return unit_conversion_factor("energy", "ev", unit) - assert energy_conversion("Hartree") == ase.units.eV / ase.units.Hartree +def test_conversion_energy(): + assert unit_conversion_factor("eV", "meV") == pytest.approx(1000.0) + assert unit_conversion_factor("eV", "Hartree") == pytest.approx(0.0367493, rel=1e-3) + + +def test_units_vs_ase(): + assert unit_conversion_factor("angstrom", "bohr") == pytest.approx( + ase.units.Ang / ase.units.Bohr, rel=1e-6 + ) + assert unit_conversion_factor("angstrom", "nm") == pytest.approx( + ase.units.Ang / ase.units.nm + ) + assert unit_conversion_factor("angstrom", "nanometer") == pytest.approx( + ase.units.Ang / ase.units.nm + ) + + assert unit_conversion_factor("ev", "Hartree") == pytest.approx( + ase.units.eV / ase.units.Hartree, rel=1e-6 + ) kcal_mol = ase.units.kcal / ase.units.mol - assert energy_conversion("kcal/mol") == ase.units.eV / kcal_mol + assert 
unit_conversion_factor("ev", "kcal/mol") == pytest.approx( + ase.units.eV / kcal_mol, rel=1e-4 + ) kJ_mol = ase.units.kJ / ase.units.mol - assert energy_conversion("kJ/mol") == ase.units.eV / kJ_mol + assert unit_conversion_factor("ev", "kJ/mol") == pytest.approx( + ase.units.eV / kJ_mol, rel=1e-4 + ) + + +# ---- Compound expressions ---- + + +def test_compound_expressions(): + # Force: eV/Angstrom -> Hartree/Bohr + conv = unit_conversion_factor("eV/Angstrom", "Hartree/Bohr") + assert conv == pytest.approx(0.0194469, rel=1e-3) + + # kJ/mol -> kcal/mol + conv = unit_conversion_factor("kJ/mol", "kcal/mol") + assert conv == pytest.approx(1000.0 / 4184.0, rel=1e-6) + + # Pressure identity + conv = unit_conversion_factor("eV/Angstrom^3", "eV/A^3") + assert conv == pytest.approx(1.0) + + +# ---- Fractional powers ---- + + +def test_fractional_powers(): + # (eV*u)^(1/2) -> u*A/fs + conv = unit_conversion_factor("(eV*u)^(1/2)", "u*A/fs") + ev_si = 1.602176634e-19 + u_si = 1.66053906660e-27 + a_si = 1e-10 + fs_si = 1e-15 + expected = math.sqrt(ev_si * u_si) / (u_si * a_si / fs_si) + assert conv == pytest.approx(expected, rel=1e-3) + + # (eV/u)^(1/2) -> A/fs + conv = unit_conversion_factor("(eV/u)^(1/2)", "A/fs") + expected = math.sqrt(ev_si / u_si) / (a_si / fs_si) + assert conv == pytest.approx(expected, rel=1e-3) + + +# ---- Dimension mismatch ---- + + +def test_dimension_mismatch(): + with pytest.raises((ValueError, RuntimeError), match="dimension mismatch"): + unit_conversion_factor("eV", "Angstrom") + + +# ---- Unknown token ---- + + +def test_unknown_token(): + with pytest.raises((ValueError, RuntimeError), match="unknown unit token"): + unit_conversion_factor("foobar", "eV") + + +# ---- Empty string ---- + + +def test_empty_string(): + assert unit_conversion_factor("", "eV") == 1.0 + assert unit_conversion_factor("eV", "") == 1.0 + assert unit_conversion_factor("", "") == 1.0 + + +# ---- Valid units (ModelOutput creation still works) ---- def test_valid_units(): @@ 
-45,7 +144,7 @@ def test_valid_units(): ModelOutput(quantity="length", unit="mm") ModelOutput(quantity="length", unit=" micrometer") ModelOutput(quantity="length", unit="um") - ModelOutput(quantity="length", unit="µm") + ModelOutput(quantity="length", unit="\u00b5m") ModelOutput(quantity="length", unit="nanometer") ModelOutput(quantity="length", unit="nm ") @@ -67,3 +166,56 @@ def test_valid_units(): ModelOutput(quantity="momentum", unit="u * A/ fs") ModelOutput(quantity="momentum", unit=" (eV*u )^(1/ 2 )") + + ModelOutput(quantity="velocity", unit="A/fs") + ModelOutput(quantity="velocity", unit="A/s") + + +# ---- Time units ---- + + +def test_time_units(): + assert unit_conversion_factor("s", "fs") == pytest.approx(1e15) + assert unit_conversion_factor("second", "ps") == pytest.approx(1e12) + assert unit_conversion_factor("ns", "fs") == pytest.approx(1e6) + assert unit_conversion_factor("us", "ns") == pytest.approx(1e3) + assert unit_conversion_factor("ms", "us") == pytest.approx(1e3) + + +# ---- Micro sign as standalone (Dalton) ---- + + +def test_micro_sign_standalone(): + # standalone U+00B5 normalizes to 'u' = Dalton + assert unit_conversion_factor("\u00b5", "kg") == pytest.approx( + unit_conversion_factor("u", "kg") + ) + + +# ---- Quantity-unit mismatch ---- + + +def test_quantity_unit_mismatch(): + # energy quantity with force unit + with pytest.raises((ValueError, RuntimeError), match="incompatible with quantity"): + ModelOutput(quantity="energy", unit="eV/A") + + # force quantity with energy unit + with pytest.raises((ValueError, RuntimeError), match="incompatible with quantity"): + ModelOutput(quantity="force", unit="eV") + + # length quantity with pressure unit + with pytest.raises((ValueError, RuntimeError), match="incompatible with quantity"): + ModelOutput(quantity="length", unit="eV/A^3") + + +# ---- Deprecation warning for 3-arg ---- + + +def test_3arg_deprecation_warning(): + with warnings.catch_warnings(record=True) as w: + 
warnings.simplefilter("always") + unit_conversion_factor("energy", "eV", "meV") + assert len(w) >= 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "deprecated" in str(w[0].message).lower() From bff84281adc979adabf11d237f73549a4bd0cfd7 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Mon, 9 Mar 2026 12:48:10 +0100 Subject: [PATCH 2/7] test(units): address review feedback on tests Remove standalone micro sign test (no longer normalizes globally), add micro sign microsecond tests instead. Update error message assertions ("unknown unit" not "unknown unit token"). Add comment explaining why models.cpp test uses valid quantity/unit strings. --- metatomic-torch/tests/models.cpp | 10 ++++++---- metatomic-torch/tests/units.cpp | 8 ++++---- python/metatomic_torch/tests/units.py | 12 ++++++------ 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/metatomic-torch/tests/models.cpp b/metatomic-torch/tests/models.cpp index 605958d6..55df429b 100644 --- a/metatomic-torch/tests/models.cpp +++ b/metatomic-torch/tests/models.cpp @@ -55,7 +55,7 @@ TEST_CASE("Models metadata") { ); CHECK_THROWS_WITH(options->set_length_unit("unknown"), - StartsWith("unknown unit token 'unknown'") + StartsWith("unknown unit 'unknown'") ); } @@ -103,7 +103,7 @@ TEST_CASE("Models metadata") { ); CHECK_THROWS_WITH(output->set_unit("unknown"), - StartsWith("unknown unit token 'unknown'") + StartsWith("unknown unit 'unknown'") ); struct WarningHandler: public torch::WarningHandler { @@ -135,6 +135,8 @@ TEST_CASE("Models metadata") { auto output = torch::make_intrusive(); output->per_atom = true; + // Use valid quantity/unit since the parser validates unit expressions + // at set_unit() time (previously any string was accepted) output->set_quantity("energy"); output->set_unit("eV"); options->outputs.insert("output_2", output); @@ -205,7 +207,7 @@ TEST_CASE("Models metadata") { ); CHECK_THROWS_WITH(options->set_length_unit("unknown"), - StartsWith("unknown unit token 
'unknown'") + StartsWith("unknown unit 'unknown'") ); } @@ -295,7 +297,7 @@ TEST_CASE("Models metadata") { ); CHECK_THROWS_WITH(capabilities->set_length_unit("unknown"), - StartsWith("unknown unit token 'unknown'") + StartsWith("unknown unit 'unknown'") ); auto capabilities_variants = torch::make_intrusive(); diff --git a/metatomic-torch/tests/units.cpp b/metatomic-torch/tests/units.cpp index 79e61359..49493dc8 100644 --- a/metatomic-torch/tests/units.cpp +++ b/metatomic-torch/tests/units.cpp @@ -124,7 +124,7 @@ TEST_CASE("Dimension mismatch error") { TEST_CASE("Unknown unit token error") { CHECK_THROWS_WITH( metatomic_torch::unit_conversion_factor("foobar", "eV"), - Contains("unknown unit token") + Contains("unknown unit") ); } @@ -217,9 +217,9 @@ TEST_CASE("Micro sign (U+00B5) handling") { double c2 = metatomic_torch::unit_conversion_factor("\xC2\xB5m", "Angstrom"); CHECK(c1 == Approx(c2).epsilon(1e-12)); - // Standalone micro sign -> u (Dalton) - double c3 = metatomic_torch::unit_conversion_factor("u", "kg"); - double c4 = metatomic_torch::unit_conversion_factor("\xC2\xB5", "kg"); + // µs -> ns (microsecond via micro sign) + double c3 = metatomic_torch::unit_conversion_factor("us", "ns"); + double c4 = metatomic_torch::unit_conversion_factor("\xC2\xB5s", "ns"); CHECK(c3 == Approx(c4).epsilon(1e-12)); } diff --git a/python/metatomic_torch/tests/units.py b/python/metatomic_torch/tests/units.py index a8238ac1..6ccc60fe 100644 --- a/python/metatomic_torch/tests/units.py +++ b/python/metatomic_torch/tests/units.py @@ -116,7 +116,7 @@ def test_dimension_mismatch(): def test_unknown_token(): - with pytest.raises((ValueError, RuntimeError), match="unknown unit token"): + with pytest.raises((ValueError, RuntimeError), match="unknown unit"): unit_conversion_factor("foobar", "eV") @@ -182,13 +182,13 @@ def test_time_units(): assert unit_conversion_factor("ms", "us") == pytest.approx(1e3) -# ---- Micro sign as standalone (Dalton) ---- +# ---- Micro sign for microsecond ---- 
-def test_micro_sign_standalone(): - # standalone U+00B5 normalizes to 'u' = Dalton - assert unit_conversion_factor("\u00b5", "kg") == pytest.approx( - unit_conversion_factor("u", "kg") +def test_micro_sign_microsecond(): + # \u00b5s -> ns (microsecond via micro sign) + assert unit_conversion_factor("\u00b5s", "ns") == pytest.approx( + unit_conversion_factor("us", "ns") ) From eb8e28aa6b8f8145de2643958f0a7b667d78084f Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Mon, 9 Mar 2026 12:48:20 +0100 Subject: [PATCH 3/7] docs(units): move unit docs to Doxygen, simplify misc.rst Move full unit expression documentation to the Doxygen comment on unit_conversion_factor in units.hpp (renders via autofunction). Remove redundant standalone sections from misc.rst, keeping only the known-quantities table and deprecation note. --- docs/src/torch/reference/misc.rst | 47 +------------------------------ 1 file changed, 1 insertion(+), 46 deletions(-) diff --git a/docs/src/torch/reference/misc.rst b/docs/src/torch/reference/misc.rst index bef45bd9..7021832f 100644 --- a/docs/src/torch/reference/misc.rst +++ b/docs/src/torch/reference/misc.rst @@ -9,53 +9,8 @@ Miscellaneous .. _known-quantities-units: -Unit expression parser ----------------------- - -``unit_conversion_factor`` accepts arbitrary unit expressions built from base -tokens combined with ``*``, ``/``, ``^``, and parentheses. For example: - -- ``"kJ/mol"`` -- ``"eV/Angstrom^3"`` -- ``"(eV*u)^(1/2)"`` -- ``"Hartree/Bohr"`` - -Dimensional compatibility is verified automatically; no ``quantity`` parameter -is needed. Token lookup is case-insensitive, and whitespace is ignored. - -Base unit tokens -^^^^^^^^^^^^^^^^ - -.. 
list-table:: Supported Unit Tokens - :header-rows: 1 - - * - Dimension - - Tokens - - Notes - * - **Length** - - ``angstrom`` (``a``), ``bohr``, ``nm`` (``nanometer``), ``meter`` (``m``), ``cm`` (``centimeter``), ``mm`` (``millimeter``), ``um`` (``micrometer``) - - - * - **Energy** - - ``ev``, ``mev``, ``hartree``, ``ry`` (``rydberg``), ``joule`` (``j``), ``kcal``, ``kj`` - - ``kcal`` and ``kj`` are bare (not per-mol); write ``kcal/mol`` for the per-mole unit - * - **Time** - - ``s`` (``second``), ``ms`` (``millisecond``), ``us`` (``microsecond``), ``ns`` (``nanosecond``), ``ps`` (``picosecond``), ``fs`` (``femtosecond``) - - - * - **Mass** - - ``u`` (``dalton``), ``kg`` (``kilogram``), ``g`` (``gram``), ``electron_mass`` (``m_e``) - - - * - **Charge** - - ``e``, ``coulomb`` (``c``) - - - * - **Dimensionless** - - ``mol`` - - Avogadro scaling factor - * - **Derived** - - ``hbar`` - - :math:`\hbar` in SI (:math:`M L^2 T^{-1}`) - Known quantities -^^^^^^^^^^^^^^^^ +---------------- When setting ``quantity`` on a :py:class:`~metatomic.torch.ModelOutput`, the following names are recognized. The parser will check that the unit expression From 1aa7e0e076e5f883c8d2427e2f2d0ab5b6480058 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Mon, 9 Mar 2026 12:50:45 +0100 Subject: [PATCH 4/7] fix: address remaining review items Remove validate_unit from CHANGELOG (not public API). Move unit_conversion_factor docstring from documentation.py to the Python function in __init__.py (documentation.py is only for C++ ops). --- metatomic-torch/CHANGELOG.md | 2 -- .../metatomic/torch/__init__.py | 27 ++++++++++++++++--- .../metatomic/torch/documentation.py | 27 ------------------- 3 files changed, 24 insertions(+), 32 deletions(-) diff --git a/metatomic-torch/CHANGELOG.md b/metatomic-torch/CHANGELOG.md index e786886c..679850b1 100644 --- a/metatomic-torch/CHANGELOG.md +++ b/metatomic-torch/CHANGELOG.md @@ -26,8 +26,6 @@ a changelog](https://keepachangelog.com/en/1.1.0/) format. 
This project follows ### Changed -- `validate_unit` now accepts arbitrary parseable expressions and validates - their dimensions against the expected quantity - 3-argument `unit_conversion_factor(quantity, from_unit, to_unit)` is deprecated; the `quantity` parameter is ignored diff --git a/python/metatomic_torch/metatomic/torch/__init__.py b/python/metatomic_torch/metatomic/torch/__init__.py index 2e72ad0b..f24bb09d 100644 --- a/python/metatomic_torch/metatomic/torch/__init__.py +++ b/python/metatomic_torch/metatomic/torch/__init__.py @@ -48,10 +48,31 @@ _unit_conversion_factor_v1 = torch.ops.metatomic.unit_conversion_factor def unit_conversion_factor(*args, **kwargs): - """Unit conversion factor supporting both 2-arg and 3-arg signatures. + """ + Get the multiplicative conversion factor from ``from_unit`` to + ``to_unit``. + + Supports two calling conventions: + + - **2-argument** (preferred): + ``unit_conversion_factor(from_unit, to_unit)`` + - **3-argument** (deprecated): + ``unit_conversion_factor(quantity, from_unit, to_unit)`` + + Both ``from_unit`` and ``to_unit`` are parsed as unit expressions + supporting compound forms like ``"kJ/mol/A^2"`` or + ``"(eV*u)^(1/2)"``. The parser validates that both expressions have + matching physical dimensions. + + The ``quantity`` parameter in the 3-argument form is ignored + (dimensional compatibility is checked by the parser). A deprecation + warning will be emitted if the 3-argument form is used. + + The set of recognized base unit tokens is available :ref:`here + <known-quantities-units>`. 
- 2-arg: ``unit_conversion_factor(from_unit, to_unit)`` - 3-arg (deprecated): ``unit_conversion_factor(quantity, from_unit, to_unit)`` + :param from_unit: current unit of the data (expression string) + :param to_unit: target unit of the data (expression string) """ import warnings diff --git a/python/metatomic_torch/metatomic/torch/documentation.py b/python/metatomic_torch/metatomic/torch/documentation.py index fd37d417..d83c9eea 100644 --- a/python/metatomic_torch/metatomic/torch/documentation.py +++ b/python/metatomic_torch/metatomic/torch/documentation.py @@ -522,33 +522,6 @@ def register_autograd_neighbors( """ -def unit_conversion_factor(*args, **kwargs): - """ - Get the multiplicative conversion factor from ``from_unit`` to ``to_unit``. - - Supports two calling conventions: - - - **2-argument** (preferred): - ``unit_conversion_factor(from_unit, to_unit)`` - - **3-argument** (deprecated): - ``unit_conversion_factor(quantity, from_unit, to_unit)`` - - Both ``from_unit`` and ``to_unit`` are parsed as unit expressions supporting - compound forms like ``"kJ/mol/A^2"`` or ``"(eV*u)^(1/2)"``. The parser - validates that both expressions have matching physical dimensions. - - The ``quantity`` parameter in the 3-argument form is ignored (dimensional - compatibility is checked by the parser). A deprecation warning will be - emitted if the 3-argument form is used. - - The set of recognized base unit tokens is available :ref:`here - <known-quantities-units>`. 
- - :param from_unit: current unit of the data (expression string) - :param to_unit: target unit of the data (expression string) - """ - - def pick_device(model_devices: List[str], desired_device: Optional[str]) -> str: """ Select the best device according to the list of ``model_devices`` from a model, the From 006284d29d0a3f25184f36b38d354de49b698b49 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Mon, 9 Mar 2026 12:52:58 +0100 Subject: [PATCH 5/7] =?UTF-8?q?style:=20use=20literal=20=C2=B5=20instead?= =?UTF-8?q?=20of=20\u00b5=20escape=20in=20Python=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/metatomic_torch/tests/units.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/metatomic_torch/tests/units.py b/python/metatomic_torch/tests/units.py index 6ccc60fe..39baac16 100644 --- a/python/metatomic_torch/tests/units.py +++ b/python/metatomic_torch/tests/units.py @@ -144,7 +144,7 @@ def test_valid_units(): ModelOutput(quantity="length", unit="mm") ModelOutput(quantity="length", unit=" micrometer") ModelOutput(quantity="length", unit="um") - ModelOutput(quantity="length", unit="\u00b5m") + ModelOutput(quantity="length", unit="µm") ModelOutput(quantity="length", unit="nanometer") ModelOutput(quantity="length", unit="nm ") @@ -186,8 +186,8 @@ def test_time_units(): def test_micro_sign_microsecond(): - # \u00b5s -> ns (microsecond via micro sign) - assert unit_conversion_factor("\u00b5s", "ns") == pytest.approx( + # µs -> ns (microsecond via micro sign) + assert unit_conversion_factor("µs", "ns") == pytest.approx( unit_conversion_factor("us", "ns") ) From 7eb3d41a1e491b281475672c914f7a15c74c2589 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Mon, 9 Mar 2026 12:57:39 +0100 Subject: [PATCH 6/7] =?UTF-8?q?style:=20use=20literal=20=C2=B5=20instead?= =?UTF-8?q?=20of=20hex=20escapes=20in=20C++=20sources?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit --- metatomic-torch/include/metatomic/torch/units.hpp | 4 ++-- metatomic-torch/src/units.cpp | 4 ++-- metatomic-torch/tests/units.cpp | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/metatomic-torch/include/metatomic/torch/units.hpp b/metatomic-torch/include/metatomic/torch/units.hpp index 06ef5cdc..4049be91 100644 --- a/metatomic-torch/include/metatomic/torch/units.hpp +++ b/metatomic-torch/include/metatomic/torch/units.hpp @@ -33,10 +33,10 @@ void validate_unit(const std::string& quantity, const std::string& unit); /// Supported tokens: /// /// - **Length**: angstrom (A), bohr, nanometer (nm), meter (m), -/// centimeter (cm), millimeter (mm), micrometer (um, \xC2\xB5m) +/// centimeter (cm), millimeter (mm), micrometer (um, µm) /// - **Energy**: eV, meV, Hartree, Ry (rydberg), Joule (J), kcal, kJ /// (note: kcal and kJ are bare; write kcal/mol for per-mole) -/// - **Time**: second (s), millisecond (ms), microsecond (us, \xC2\xB5s), +/// - **Time**: second (s), millisecond (ms), microsecond (us, µs), /// nanosecond (ns), picosecond (ps), femtosecond (fs) /// - **Mass**: dalton (u), kilogram (kg), gram (g), electron_mass (m_e) /// - **Charge**: e, coulomb (c) diff --git a/metatomic-torch/src/units.cpp b/metatomic-torch/src/units.cpp index 5ec3c775..812bd528 100644 --- a/metatomic-torch/src/units.cpp +++ b/metatomic-torch/src/units.cpp @@ -121,7 +121,7 @@ static const std::unordered_map& base_units() { {"mm", {1e-3, DIM_LENGTH}}, {"millimeter",{1e-3, DIM_LENGTH}}, {"um", {1e-6, DIM_LENGTH}}, - {"\xc2\xb5m", {1e-6, DIM_LENGTH}}, + {"µm", {1e-6, DIM_LENGTH}}, {"micrometer",{1e-6, DIM_LENGTH}}, // --- Energy --- @@ -141,7 +141,7 @@ static const std::unordered_map& base_units() { {"ms", {1e-3, DIM_TIME}}, {"millisecond", {1e-3, DIM_TIME}}, {"us", {1e-6, DIM_TIME}}, - {"\xc2\xb5s", {1e-6, DIM_TIME}}, + {"µs", {1e-6, DIM_TIME}}, {"microsecond", {1e-6, DIM_TIME}}, {"ns", {1e-9, DIM_TIME}}, {"nanosecond",{1e-9, 
DIM_TIME}}, diff --git a/metatomic-torch/tests/units.cpp b/metatomic-torch/tests/units.cpp index 49493dc8..33bdb689 100644 --- a/metatomic-torch/tests/units.cpp +++ b/metatomic-torch/tests/units.cpp @@ -214,12 +214,12 @@ TEST_CASE("Time unit conversions") { TEST_CASE("Micro sign (U+00B5) handling") { double c1 = metatomic_torch::unit_conversion_factor("um", "Angstrom"); - double c2 = metatomic_torch::unit_conversion_factor("\xC2\xB5m", "Angstrom"); + double c2 = metatomic_torch::unit_conversion_factor("µm", "Angstrom"); CHECK(c1 == Approx(c2).epsilon(1e-12)); // µs -> ns (microsecond via micro sign) double c3 = metatomic_torch::unit_conversion_factor("us", "ns"); - double c4 = metatomic_torch::unit_conversion_factor("\xC2\xB5s", "ns"); + double c4 = metatomic_torch::unit_conversion_factor("µs", "ns"); CHECK(c3 == Approx(c4).epsilon(1e-12)); } From 4c8ab769be4726d52d625dbfa2f56088c39e3719 Mon Sep 17 00:00:00 2001 From: Rohit Goswami Date: Mon, 9 Mar 2026 13:19:58 +0100 Subject: [PATCH 7/7] fix: ASCII-only tolower and restore documentation.py stub to_lower() now skips non-ASCII bytes, preventing macOS locale from mangling UTF-8 micro sign (0xB5 -> 'u'). Restore unit_conversion_factor in documentation.py since __init__.py imports it for Sphinx builds. --- metatomic-torch/src/units.cpp | 5 +++- .../metatomic/torch/documentation.py | 29 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/metatomic-torch/src/units.cpp b/metatomic-torch/src/units.cpp index 812bd528..bb920137 100644 --- a/metatomic-torch/src/units.cpp +++ b/metatomic-torch/src/units.cpp @@ -98,7 +98,10 @@ static const Dimension DIM_NONE = {{ 0, 0, 0, 0, 0 }}; /// Lowercase a string in place and return it. static std::string to_lower(std::string s) { std::transform(s.begin(), s.end(), s.begin(), - [](unsigned char c) { return static_cast<char>(std::tolower(c)); } + [](unsigned char c) { + // Only lowercase ASCII; preserve UTF-8 continuation bytes + return c < 128 ? 
static_cast<char>(std::tolower(c)) : static_cast<char>(c); + } ); return s; } diff --git a/python/metatomic_torch/metatomic/torch/documentation.py b/python/metatomic_torch/metatomic/torch/documentation.py index d83c9eea..cb42058b 100644 --- a/python/metatomic_torch/metatomic/torch/documentation.py +++ b/python/metatomic_torch/metatomic/torch/documentation.py @@ -522,6 +522,35 @@ def register_autograd_neighbors( """ +def unit_conversion_factor(from_unit: str, to_unit: str) -> float: + """ + Get the multiplicative conversion factor from ``from_unit`` to + ``to_unit``. + + Supports two calling conventions: + + - **2-argument** (preferred): + ``unit_conversion_factor(from_unit, to_unit)`` + - **3-argument** (deprecated): + ``unit_conversion_factor(quantity, from_unit, to_unit)`` + + Both ``from_unit`` and ``to_unit`` are parsed as unit expressions + supporting compound forms like ``"kJ/mol/A^2"`` or + ``"(eV*u)^(1/2)"``. The parser validates that both expressions have + matching physical dimensions. + + The ``quantity`` parameter in the 3-argument form is ignored + (dimensional compatibility is checked by the parser). A deprecation + warning will be emitted if the 3-argument form is used. + + The set of recognized base unit tokens is available :ref:`here + <known-quantities-units>`. + + :param from_unit: current unit of the data (expression string) + :param to_unit: target unit of the data (expression string) + """ + + def pick_device(model_devices: List[str], desired_device: Optional[str]) -> str: """ Select the best device according to the list of ``model_devices`` from a model, the