diff --git a/docs/src/torch/reference/cxx/misc.rst b/docs/src/torch/reference/cxx/misc.rst index e6c1aeda..6dfc15b5 100644 --- a/docs/src/torch/reference/cxx/misc.rst +++ b/docs/src/torch/reference/cxx/misc.rst @@ -1,7 +1,9 @@ Miscellaneous ============= -.. doxygenfunction:: metatomic_torch::unit_conversion_factor +.. doxygenfunction:: metatomic_torch::unit_conversion_factor(const std::string &from_unit, const std::string &to_unit) + +.. doxygenfunction:: metatomic_torch::unit_conversion_factor(const std::string &quantity, const std::string &from_unit, const std::string &to_unit) .. doxygenfunction:: metatomic_torch::pick_device diff --git a/docs/src/torch/reference/misc.rst b/docs/src/torch/reference/misc.rst index e2cf5df9..7021832f 100644 --- a/docs/src/torch/reference/misc.rst +++ b/docs/src/torch/reference/misc.rst @@ -9,34 +9,37 @@ Miscellaneous .. _known-quantities-units: -Known quantities and units --------------------------- - -The following quantities and units can be used with metatomic models. Adding new -units and quantities is very easy, please contact us if you need something else! -In the mean time, you can create :py:class:`metatomic.torch.ModelOutput` with -quantities that are not in this table. A warning will be issued and no unit -conversion will be performed. - -When working with one of the quantities in this table, the unit you use must be -one of the registered unit. 
- -+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------+ -| quantity | units | -+================+======================================================================================================================================================+ -| **length** | ``angstrom`` (``A``), ``Bohr``, ``meter``, ``centimeter`` (``cm``), ``millimeter`` (``mm``), ``micrometer`` (``um``, ``µm``), ``nanometer`` (``nm``) | -+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------+ -| **energy** | ``eV``, ``meV``, ``Hartree``, ``kcal/mol``, ``kJ/mol``, ``Joule`` (``J``), ``Rydberg`` (``Ry``) | -+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------+ -| **force** | ``eV/Angstrom`` (``eV/A``), ``Hartree/Bohr`` | -+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------+ -| **pressure** | ``eV/Angstrom^3`` (``eV/A^3``) | -+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------+ -| **momentum** | ``u*A/fs``, ``u*A/ps``, ``(eV*u)^(1/2)``, ``kg*m/s``, ``hbar/Bohr`` | -+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------+ -| **mass** | ``u`` (``Dalton``), ``kg`` (``kilogram``), ``g`` (``gram``), ``electron_mass`` (``m_e``) | -+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------+ -| **velocity** | 
``nm/fs``, ``A/fs``, ``m/s``, ``nm/ps``, ``Bohr*Hartree/hbar`` | -+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------+ -| **charge** | ``e``, ``Coulomb`` (``C``) | -+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------+ +Known quantities +---------------- + +When setting ``quantity`` on a :py:class:`~metatomic.torch.ModelOutput`, the +following names are recognized. The parser will check that the unit expression +has dimensions matching the expected quantity. + +.. list-table:: Physical Dimensions + :header-rows: 1 + + * - quantity + - expected dimension + * - **length** + - :math:`L` + * - **energy** + - :math:`M L^2 T^{-2}` + * - **force** + - :math:`M L T^{-2}` + * - **pressure** + - :math:`M L^{-1} T^{-2}` + * - **momentum** + - :math:`M L T^{-1}` + * - **mass** + - :math:`M` + * - **velocity** + - :math:`L T^{-1}` + * - **charge** + - :math:`Q` + +.. note:: + + The 3-argument form ``unit_conversion_factor(quantity, from_unit, to_unit)`` + is deprecated. Use the 2-argument form instead. The ``quantity`` parameter is + ignored by the parser; dimensional compatibility is checked automatically. diff --git a/metatomic-torch/CHANGELOG.md b/metatomic-torch/CHANGELOG.md index 00947737..679850b1 100644 --- a/metatomic-torch/CHANGELOG.md +++ b/metatomic-torch/CHANGELOG.md @@ -17,6 +17,18 @@ a changelog](https://keepachangelog.com/en/1.1.0/) format. This project follows ## [Unreleased](https://github.com/metatensor/metatomic/) +### Added + +- Unit expression parser supporting compound expressions (`kJ/mol/A^2`, + `(eV*u)^(1/2)`, etc.) 
with automatic dimensional validation +- 2-argument `unit_conversion_factor(from_unit, to_unit)` that parses + arbitrary unit expressions and checks dimensional compatibility + +### Changed + +- 3-argument `unit_conversion_factor(quantity, from_unit, to_unit)` is + deprecated; the `quantity` parameter is ignored + ## [Version 0.1.11](https://github.com/metatensor/metatomic/releases/tag/metatomic-torch-v0.1.11) - 2026-02-27 ### Added diff --git a/metatomic-torch/CMakeLists.txt b/metatomic-torch/CMakeLists.txt index 8f253344..c661cea5 100644 --- a/metatomic-torch/CMakeLists.txt +++ b/metatomic-torch/CMakeLists.txt @@ -96,6 +96,7 @@ find_package(Torch 2.3 REQUIRED) set(METATOMIC_TORCH_HEADERS "include/metatomic/torch/system.hpp" "include/metatomic/torch/model.hpp" + "include/metatomic/torch/units.hpp" "include/metatomic/torch.hpp" "include/metatomic/torch/outputs.hpp" ) @@ -104,6 +105,7 @@ set(METATOMIC_TORCH_SOURCE "src/misc.cpp" "src/system.cpp" "src/model.cpp" + "src/units.cpp" "src/outputs.cpp" "src/register.cpp" "src/internal/shared_libraries.cpp" diff --git a/metatomic-torch/include/metatomic/torch/model.hpp b/metatomic-torch/include/metatomic/torch/model.hpp index f2f9bdeb..d8cc5db4 100644 --- a/metatomic-torch/include/metatomic/torch/model.hpp +++ b/metatomic-torch/include/metatomic/torch/model.hpp @@ -9,6 +9,7 @@ #include #include "metatomic/torch/exports.h" +#include "metatomic/torch/units.hpp" namespace metatomic_torch { @@ -28,16 +29,6 @@ class ModelMetadataHolder; /// TorchScript will always manipulate `ModelMetadataHolder` through a `torch::intrusive_ptr` using ModelMetadata = torch::intrusive_ptr; -/// Check that a given physical quantity is valid and known. This is -/// intentionally not exported with `METATOMIC_TORCH_EXPORT`, and is only -/// intended for internal use. -bool valid_quantity(const std::string& quantity); - -/// Check that a given unit is valid and known for some physical quantity. 
This -/// is intentionally not exported with `METATOMIC_TORCH_EXPORT`, and is only -/// intended for internal use. -void validate_unit(const std::string& quantity, const std::string& unit); - /// Information about one of the quantity a model can compute class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder { @@ -326,14 +317,6 @@ METATOMIC_TORCH_EXPORT metatensor_torch::Module load_atomistic_model( c10::optional extensions_directory = c10::nullopt ); -/// Get the multiplicative conversion factor to use to convert from unit `from` -/// to unit `to`. Both should be units for the given physical `quantity`. -METATOMIC_TORCH_EXPORT double unit_conversion_factor( - const std::string& quantity, - const std::string& from_unit, - const std::string& to_unit -); - } #endif diff --git a/metatomic-torch/include/metatomic/torch/units.hpp b/metatomic-torch/include/metatomic/torch/units.hpp new file mode 100644 index 00000000..4049be91 --- /dev/null +++ b/metatomic-torch/include/metatomic/torch/units.hpp @@ -0,0 +1,61 @@ +#ifndef METATOMIC_TORCH_UNITS_HPP +#define METATOMIC_TORCH_UNITS_HPP + +#include + +#include "metatomic/torch/exports.h" + +namespace metatomic_torch { + +/// Check that a given physical quantity is valid and known. This is +/// intentionally not exported with `METATOMIC_TORCH_EXPORT`, and is only +/// intended for internal use. +bool valid_quantity(const std::string& quantity); + +/// Check that a given unit is valid and known for some physical quantity. This +/// is intentionally not exported with `METATOMIC_TORCH_EXPORT`, and is only +/// intended for internal use. +void validate_unit(const std::string& quantity, const std::string& unit); + +/// Get the multiplicative conversion factor to use to convert from +/// `from_unit` to `to_unit`. Both units are parsed as expressions (e.g. +/// "kJ/mol/A^2", "(eV*u)^(1/2)") and their dimensions must match. 
+/// +/// Unit expressions are built from base tokens combined with `*`, `/`, `^`, +/// and parentheses. Token lookup is case-insensitive, and whitespace is +/// ignored. For example: +/// +/// - `"kJ/mol"` -- energy per mole +/// - `"eV/Angstrom^3"` -- pressure +/// - `"(eV*u)^(1/2)"` -- momentum (fractional powers) +/// - `"Hartree/Bohr"` -- force in atomic units +/// +/// Supported tokens: +/// +/// - **Length**: angstrom (A), bohr, nanometer (nm), meter (m), +/// centimeter (cm), millimeter (mm), micrometer (um, µm) +/// - **Energy**: eV, meV, Hartree, Ry (rydberg), Joule (J), kcal, kJ +/// (note: kcal and kJ are bare; write kcal/mol for per-mole) +/// - **Time**: second (s), millisecond (ms), microsecond (us, µs), +/// nanosecond (ns), picosecond (ps), femtosecond (fs) +/// - **Mass**: dalton (u), kilogram (kg), gram (g), electron_mass (m_e) +/// - **Charge**: e, coulomb (c) +/// - **Dimensionless**: mol +/// - **Derived**: hbar +METATOMIC_TORCH_EXPORT double unit_conversion_factor( + const std::string& from_unit, + const std::string& to_unit +); + +/// Deprecated 3-argument overload. The `quantity` parameter is ignored; +/// dimensional compatibility is checked by the parser. Emits a one-time +/// deprecation warning. 
+METATOMIC_TORCH_EXPORT double unit_conversion_factor( + const std::string& quantity, + const std::string& from_unit, + const std::string& to_unit +); + +} + +#endif diff --git a/metatomic-torch/src/model.cpp b/metatomic-torch/src/model.cpp index a7b6bee2..eff9312a 100644 --- a/metatomic-torch/src/model.cpp +++ b/metatomic-torch/src/model.cpp @@ -1,10 +1,9 @@ #include -#include -#include #include -#include #include +#include +#include #include #include @@ -200,7 +199,7 @@ void ModelCapabilitiesHolder::set_dtype(std::string dtype) { } double ModelCapabilitiesHolder::engine_interaction_range(const std::string& engine_length_unit) const { - return interaction_range * unit_conversion_factor("length", length_unit_, engine_length_unit); + return interaction_range * unit_conversion_factor(length_unit_, engine_length_unit); } std::string ModelCapabilitiesHolder::to_json() const { @@ -975,233 +974,3 @@ metatensor_torch::Module metatomic_torch::load_atomistic_model( return metatensor_torch::Module(model); } - -/******************************************************************************/ -/******************************************************************************/ - -/// remove all whitespace in a string (i.e. `kcal / mol` => `kcal/mol`) -static std::string remove_spaces(std::string value) { - auto new_end = std::remove_if(value.begin(), value.end(), - [](unsigned char c){ return std::isspace(c); } - ); - value.erase(new_end, value.end()); - return value; -} - - -/// Lower case string, to be used as a key in Quantity.conversion (we want -/// "Angstrom" and "angstrom" to be equivalent). 
-class LowercaseString { -public: - LowercaseString(std::string init): original_(std::move(init)) { - std::transform(original_.begin(), original_.end(), std::back_inserter(lowercase_), &::tolower); - } - - LowercaseString(const char* init): LowercaseString(std::string(init)) {} - - operator std::string&() { - return lowercase_; - } - operator std::string const&() const { - return lowercase_; - } - - const std::string& original() const { - return original_; - } - - bool operator==(const LowercaseString& other) const { - return this->lowercase_ == other.lowercase_; - } - -private: - std::string original_; - std::string lowercase_; -}; - -template <> -struct std::hash { - size_t operator()(const LowercaseString& k) const { - return std::hash()(k); - } -}; - -/// Information for unit conversion for this physical quantity -struct Quantity { - /// the quantity name - std::string name; - - /// baseline unit for this quantity - std::string baseline; - /// set of conversion from the key to the baseline unit - std::unordered_map conversions; - std::unordered_map alternatives; - - std::string normalize_unit(const std::string& original_unit) { - if (original_unit.empty()) { - return original_unit; - } - - std::string unit = remove_spaces(original_unit); - auto alternative = this->alternatives.find(unit); - if (alternative != this->alternatives.end()) { - unit = alternative->second; - } - - if (this->conversions.find(unit) == this->conversions.end()) { - auto valid_units = std::vector(); - for (const auto& it: this->conversions) { - valid_units.emplace_back(it.first.original()); - } - - C10_THROW_ERROR(ValueError, - "unknown unit '" + original_unit + "' for " + name + ", " - "only [" + torch::str(valid_units) + "] are supported" - ); - } - - return unit; - } - - double conversion(const std::string& from_unit, const std::string& to_unit) { - auto from = this->normalize_unit(from_unit); - auto to = this->normalize_unit(to_unit); - - if (from.empty() || to.empty()) { - return 1.0; 
- } - - return this->conversions.at(to) / this->conversions.at(from); - } -}; - -static std::map KNOWN_QUANTITIES = { - {"length", Quantity{/* name */ "length", /* baseline */ "Angstrom", { - {"Angstrom", 1.0}, - {"Bohr", 1.8897261258369282}, - {"meter", 1e-10}, - {"centimeter", 1e-8}, - {"millimeter", 1e-7}, - {"micrometer", 0.0001}, - {"nanometer", 0.1}, - }, { - // alternative names - {"A", "Angstrom"}, - {"cm", "centimeter"}, - {"mm", "millimeter"}, - {"um", "micrometer"}, - {"µm", "micrometer"}, - {"nm", "nanometer"}, - }}}, - {"energy", Quantity{/* name */ "energy", /* baseline */ "eV", { - {"eV", 1.0}, - {"meV", 1000.0}, - {"Hartree", 0.03674932247495664}, - {"kcal/mol", 23.060548012069496}, - {"kJ/mol", 96.48533288249877}, - {"Joule", 1.60218e-19}, - {"Rydberg", 0.07349864435130857}, - }, { - // alternative names - {"J", "Joule"}, - {"Ry", "Rydberg"}, - }}}, - {"force", Quantity{/* name */ "force", /* baseline */ "eV/Angstrom", { - {"eV/Angstrom", 1.0}, - {"Hartree/Bohr", 0.019446904} - }, { - // alternative names - {"eV/A", "eV/Angstrom"}, - }}}, - {"pressure", Quantity{/* name */ "pressure", /* baseline */ "eV/Angstrom^3", { - {"eV/Angstrom^3", 1.0}, - }, { - // alternative names - {"eV/A^3", "eV/Angstrom^3"}, - }}}, - {"momentum", Quantity{/* name */ "momentum", /* baseline */ "u * A / fs", { - {"u*A/fs", 1.0}, - {"u*A/ps", 1000.0}, - {"(eV*u)^(1/2)", 10.1805057179}, - {"kg*m/s", 1.6605390666e-22}, - {"hbar/Bohr", 83.32476}, - }, { - // alternative names - }}}, - {"mass", Quantity{/* name */ "mass", /* baseline */ "u ", { - {"u", 1.0}, - {"kilogram", 1.66053906892e-27}, - {"gram", 1.66053906892e-24}, - {"electron_mass", 1822.8885}, - }, { - // alternative names - {"Dalton", "u"}, - {"kg", "kilogram"}, - {"g", "gram"}, - {"electron_mass", "m_e"}, - }}}, - {"velocity", Quantity{/* name */ "velocity", /* baseline */ "nm/fs", { - {"nm/fs", 1.0}, - {"A/fs", 1e1}, - {"m/s", 1e6}, - {"nm/ps", 1e3}, - {"(eV/u)^(1/2)", 101.80506}, - {"Bohr*Hartree/hbar", 
0.45710289}, - }, { - // alternative names - }}}, - {"charge", Quantity{/* name */ "charge", /* baseline */ "e", { - {"e", 1.0}, - {"Coulomb", 1.602176634e-19}, - }, { - // alternative names - {"C", "Coulomb"}, - }}}, -}; - -bool metatomic_torch::valid_quantity(const std::string& quantity) { - if (quantity.empty()) { - return false; - } - - if (KNOWN_QUANTITIES.find(quantity) == KNOWN_QUANTITIES.end()) { - auto valid_quantities = std::vector(); - for (const auto& it: KNOWN_QUANTITIES) { - valid_quantities.emplace_back(it.first); - } - - static std::unordered_set ALREADY_WARNED = {}; - if (ALREADY_WARNED.insert(quantity).second) { - TORCH_WARN( - "unknown quantity '", quantity, "', only [", - torch::str(valid_quantities), "] are supported" - ); - } - return false; - } else { - return true; - } -} - - -void metatomic_torch::validate_unit(const std::string& quantity, const std::string& unit) { - if (quantity.empty() || unit.empty()) { - return; - } - - if (valid_quantity(quantity)) { - KNOWN_QUANTITIES.at(quantity).normalize_unit(unit); - } -} - -double metatomic_torch::unit_conversion_factor( - const std::string& quantity, - const std::string& from_unit, - const std::string& to_unit -) { - if (valid_quantity(quantity)) { - return KNOWN_QUANTITIES.at(quantity).conversion(from_unit, to_unit); - } else { - return 1.0; - } -} diff --git a/metatomic-torch/src/register.cpp b/metatomic-torch/src/register.cpp index f965b896..38f830d6 100644 --- a/metatomic-torch/src/register.cpp +++ b/metatomic-torch/src/register.cpp @@ -256,7 +256,10 @@ TORCH_LIBRARY(metatomic, m) { m.def("pick_output(str requested_output, Dict(str, __torch__.torch.classes.metatomic.ModelOutput) outputs, str? 
desired_variant = None) -> str", pick_output); m.def("read_model_metadata(str path) -> __torch__.torch.classes.metatomic.ModelMetadata", read_model_metadata); - m.def("unit_conversion_factor(str quantity, str from_unit, str to_unit) -> float", unit_conversion_factor); + m.def("unit_conversion_factor(str quantity, str from_unit, str to_unit) -> float", + static_cast(&unit_conversion_factor)); + m.def("unit_conversion_factor_v2(str from_unit, str to_unit) -> float", + static_cast(&unit_conversion_factor)); // manually construct the schema for "check_atomistic_model(str path) -> ()", // so we can set AliasAnalysisKind to CONSERVATIVE. In turn, this make it so diff --git a/metatomic-torch/src/system.cpp b/metatomic-torch/src/system.cpp index 4df00ec3..ffdedaef 100644 --- a/metatomic-torch/src/system.cpp +++ b/metatomic-torch/src/system.cpp @@ -52,7 +52,7 @@ void NeighborListOptionsHolder::set_length_unit(std::string length_unit) { } double NeighborListOptionsHolder::engine_cutoff(const std::string& engine_length_unit) const { - return cutoff_ * unit_conversion_factor("length", length_unit_, engine_length_unit); + return cutoff_ * unit_conversion_factor(length_unit_, engine_length_unit); } std::string NeighborListOptionsHolder::repr() const { diff --git a/metatomic-torch/src/units.cpp b/metatomic-torch/src/units.cpp new file mode 100644 index 00000000..bb920137 --- /dev/null +++ b/metatomic-torch/src/units.cpp @@ -0,0 +1,561 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "metatomic/torch/units.hpp" + +/******************************************************************************/ +/*** Unit expression parser with SI-based dimensional analysis ***/ +/******************************************************************************/ + +/// Physical dimension vector: [Length, Time, Mass, Charge, Temperature] +/// Exponents are double to support fractional powers like (eV*u)^(1/2). 
+struct Dimension {
+    /// Number of base dimensions tracked: Length, Time, Mass, Charge,
+    /// Temperature (in this order).
+    static constexpr size_t SIZE = 5;
+    /// Two exponents closer than this are considered equal. Exponents are
+    /// produced by floating point arithmetic (e.g. `2 * (1/2)`), so exact
+    /// comparison would be too strict.
+    static constexpr double TOLERANCE = 1e-10;
+
+    /// Exponent of each base dimension, all zero (dimensionless) by default
+    std::array<double, SIZE> exponents = {};
+
+    /// Dimension of a product of two quantities: exponents are added
+    Dimension operator*(const Dimension& other) const {
+        Dimension result;
+        for (size_t i = 0; i < SIZE; ++i) {
+            result.exponents[i] = exponents[i] + other.exponents[i];
+        }
+        return result;
+    }
+
+    /// Dimension of a ratio of two quantities: exponents are subtracted
+    Dimension operator/(const Dimension& other) const {
+        Dimension result;
+        for (size_t i = 0; i < SIZE; ++i) {
+            result.exponents[i] = exponents[i] - other.exponents[i];
+        }
+        return result;
+    }
+
+    /// Dimension of a quantity raised to the power `p`: exponents are
+    /// multiplied by `p`
+    Dimension pow(double p) const {
+        Dimension result;
+        for (size_t i = 0; i < SIZE; ++i) {
+            result.exponents[i] = exponents[i] * p;
+        }
+        return result;
+    }
+
+    /// Two dimensions are equal when all exponents agree within `TOLERANCE`
+    bool operator==(const Dimension& other) const {
+        for (size_t i = 0; i < SIZE; ++i) {
+            if (std::fabs(exponents[i] - other.exponents[i]) > TOLERANCE) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    bool operator!=(const Dimension& other) const {
+        return !(*this == other);
+    }
+
+    /// Human readable representation used in error messages, e.g.
+    /// `[L=1,T=-1,M=0,Q=0,Th=0]` for a velocity.
+    std::string to_string() const {
+        static const char* names[SIZE] = {"L", "T", "M", "Q", "Th"};
+        std::string result = "[";
+        for (size_t i = 0; i < SIZE; ++i) {
+            if (i > 0) result += ",";
+            result += names[i];
+            result += "=";
+            // format as integer if close to integer, else as decimal
+            double v = exponents[i];
+            if (std::fabs(v - std::round(v)) < TOLERANCE) {
+                result += std::to_string(static_cast<long long>(std::round(v)));
+            } else {
+                result += std::to_string(v);
+            }
+        }
+        result += "]";
+        return result;
+    }
+};
+
+/// A parsed unit value: SI conversion factor and physical dimension.
+struct UnitValue {
+    /// multiplicative factor converting a value in this unit to SI
+    double factor;
+    /// physical dimension of the unit
+    Dimension dim;
+};
+
+// Dimension constants for readability
+//                                         L   T  M  Q  Th
+static const Dimension DIM_LENGTH      = {{ 1,  0, 0, 0, 0 }};
+static const Dimension DIM_TIME        = {{ 0,  1, 0, 0, 0 }};
+static const Dimension DIM_MASS        = {{ 0,  0, 1, 0, 0 }};
+static const Dimension DIM_CHARGE      = {{ 0,  0, 0, 1, 0 }};
+static const Dimension DIM_TEMPERATURE = {{ 0,  0, 0, 0, 1 }};
+static const Dimension DIM_ENERGY      = {{ 2, -2, 1, 0, 0 }};
+static const Dimension DIM_NONE        = {{ 0,  0, 0, 0, 0 }};
+
+/// Return a copy of `s` where ASCII letters are lowercased. The argument is
+/// taken by value, so the caller's string is left untouched.
+static std::string to_lower(std::string s) {
+    std::transform(s.begin(), s.end(), s.begin(),
+        [](unsigned char c) {
+            // Only lowercase ASCII; bytes >= 128 belong to multi-byte UTF-8
+            // sequences (e.g. the micro sign in "µm") and are kept as-is
+            return c < 128 ? static_cast<char>(std::tolower(c)) : static_cast<char>(c);
+        }
+    );
+    return s;
+}
+
+/// All base units with SI factors and dimensions.
+/// Factors are expressed in SI base units (m, s, kg, C, K).
+/// Case-insensitive lookup: tokens are lowercased before searching.
+//
+// NOTE(review): because keys are matched after lowercasing, "MeV" resolves
+// to the "mev" entry below (milli-eV), "A" is angstrom and "C" is coulomb.
+static const std::unordered_map<std::string, UnitValue>& base_units() {
+    // function-local static: built once, on first use
+    static const auto units = std::unordered_map<std::string, UnitValue>{
+        // --- Length ---
+        {"angstrom",   {1e-10, DIM_LENGTH}},
+        {"a",          {1e-10, DIM_LENGTH}},
+        {"bohr",       {5.29177210903e-11, DIM_LENGTH}},
+        {"nm",         {1e-9, DIM_LENGTH}},
+        {"nanometer",  {1e-9, DIM_LENGTH}},
+        {"meter",      {1.0, DIM_LENGTH}},
+        {"m",          {1.0, DIM_LENGTH}},
+        {"cm",         {1e-2, DIM_LENGTH}},
+        {"centimeter", {1e-2, DIM_LENGTH}},
+        {"mm",         {1e-3, DIM_LENGTH}},
+        {"millimeter", {1e-3, DIM_LENGTH}},
+        {"um",         {1e-6, DIM_LENGTH}},
+        {"µm",         {1e-6, DIM_LENGTH}},
+        {"micrometer", {1e-6, DIM_LENGTH}},
+
+        // --- Energy ---
+        {"ev",      {1.602176634e-19, DIM_ENERGY}},
+        {"mev",     {1.602176634e-22, DIM_ENERGY}},
+        {"hartree", {4.3597447222071e-18, DIM_ENERGY}},
+        {"ry",      {2.1798723611e-18, DIM_ENERGY}},
+        {"rydberg", {2.1798723611e-18, DIM_ENERGY}},
+        {"joule",   {1.0, DIM_ENERGY}},
+        {"j",       {1.0, DIM_ENERGY}},
+        {"kcal",    {4184.0, DIM_ENERGY}},
+        {"kj",      {1000.0, DIM_ENERGY}},
+
+        // --- Time ---
+        {"s",           {1.0, DIM_TIME}},
+        {"second",      {1.0, DIM_TIME}},
+        {"ms",          {1e-3, DIM_TIME}},
+        {"millisecond", {1e-3, DIM_TIME}},
+        {"us",          {1e-6, DIM_TIME}},
+        {"µs",          {1e-6, DIM_TIME}},
+        {"microsecond", {1e-6, DIM_TIME}},
+        {"ns",          {1e-9, DIM_TIME}},
+        {"nanosecond",  {1e-9, DIM_TIME}},
+        {"ps",          {1e-12, DIM_TIME}},
+        {"picosecond",  {1e-12, DIM_TIME}},
+        {"fs",          {1e-15, DIM_TIME}},
+        {"femtosecond", {1e-15, DIM_TIME}},
+
+        // --- Mass ---
+        {"u",             {1.66053906660e-27, DIM_MASS}},
+        {"dalton",        {1.66053906660e-27, DIM_MASS}},
+        {"kg",            {1.0, DIM_MASS}},
+        {"kilogram",      {1.0, DIM_MASS}},
+        {"g",             {1e-3, DIM_MASS}},
+        {"gram",          {1e-3, DIM_MASS}},
+        {"electron_mass", {9.1093837015e-31, DIM_MASS}},
+        {"m_e",           {9.1093837015e-31, DIM_MASS}},
+
+        // --- Charge ---
+        {"e",       {1.602176634e-19, DIM_CHARGE}},
+        {"coulomb", {1.0, DIM_CHARGE}},
+        {"c",       {1.0, DIM_CHARGE}},
+
+        // --- Dimensionless ---
+        // "mol" is a pure number, so that "kJ/mol" is an energy
+        {"mol", {6.02214076e23, DIM_NONE}},
+
+        // --- Derived ---
+        // hbar has the dimension of energy * time
+        {"hbar", {1.054571817e-34, {{2, -1, 1, 0, 0}}}},
+    };
+    return units;
+}
+
+// ---- Tokenizer ----
+
+/// Kind of token found in a unit expression. `Value` is anything which is
+/// not an operator or a parenthesis (unit names and numeric literals).
+enum class TokenType {
+    LParen, RParen, Mul, Div, Pow, Value
+};
+
+struct Token {
+    TokenType type;
+    std::string value;  // only meaningful for Value tokens
+
+    /// Binding strength of this token when used as an operator: `^` binds
+    /// tighter than `*` and `/`. Parentheses get 0 so that they are never
+    /// popped by an operator in the shunting-yard loop; values get -1.
+    int precedence() const {
+        switch (type) {
+            case TokenType::LParen:
+            case TokenType::RParen: return 0;
+            case TokenType::Mul:
+            case TokenType::Div:    return 10;
+            case TokenType::Pow:    return 20;
+            default:                return -1;
+        }
+    }
+
+    /// Text of this token, used when reporting errors
+    std::string as_str() const {
+        switch (type) {
+            case TokenType::LParen: return "(";
+            case TokenType::RParen: return ")";
+            case TokenType::Mul:    return "*";
+            case TokenType::Div:    return "/";
+            case TokenType::Pow:    return "^";
+            case TokenType::Value:  return value;
+        }
+        return "?";
+    }
+};
+
+/// Split a unit expression into tokens. The characters `*`, `/`, `^`, `(`
+/// and `)` are single-character tokens; whitespace is dropped; every other
+/// run of characters becomes a single `Value` token (whitespace does not
+/// separate values: "e V" gives the same token as "eV").
+static std::vector<Token> tokenize(const std::string& unit) {
+    std::vector<Token> tokens;
+    std::string current;
+
+    for (size_t i = 0; i < unit.size(); ++i) {
+        auto byte = static_cast<unsigned char>(unit[i]);
+
+        // Handle UTF-8 micro sign (U+00B5): 0xC2 0xB5. Both bytes are copied
+        // together, without going through the isspace() check below.
+        if (byte == 0xC2 && i + 1 < unit.size()
+            && static_cast<unsigned char>(unit[i + 1]) == 0xB5)
+        {
+            current += unit[i];
+            current += unit[i + 1];
+            ++i;  // skip second byte
+            continue;
+        }
+
+        char ch = unit[i];
+        if (ch == '*' || ch == '/' || ch == '^' || ch == '(' || ch == ')') {
+            // an operator ends the value being accumulated, if any
+            if (!current.empty()) {
+                tokens.push_back({TokenType::Value, current});
+                current.clear();
+            }
+            TokenType t;
+            switch (ch) {
+                case '*': t = TokenType::Mul; break;
+                case '/': t = TokenType::Div; break;
+                case '^': t = TokenType::Pow; break;
+                case '(': t = TokenType::LParen; break;
+                case ')': t = TokenType::RParen; break;
+                default:  t = TokenType::Value; break;  // unreachable
+            }
+            tokens.push_back({t, std::string(1, ch)});
+        } else if (!std::isspace(byte)) {
+            current += ch;
+        }
+    }
+    if (!current.empty()) {
+        tokens.push_back({TokenType::Value, current});
+    }
+    return tokens;
+}
+
+// ---- Shunting-Yard ----
+
+/// Convert infix tokens to Reverse Polish Notation using the Shunting-Yard
+/// algorithm. All operators are treated as left-associative.
+///
+/// Since `^` is left-associative here, `a^b^c` is read as `(a^b)^c`.
+/// Throws a `ValueError` if parentheses are unbalanced. Other malformed
+/// inputs (e.g. a missing operand) are not detected at this stage.
+static std::vector<Token> shunting_yard(const std::vector<Token>& tokens) {
+    std::vector<Token> output;
+    std::vector<Token> operators;
+
+    for (const auto& token : tokens) {
+        switch (token.type) {
+            case TokenType::Value:
+                output.push_back(token);
+                break;
+            case TokenType::Mul:
+            case TokenType::Div:
+            case TokenType::Pow: {
+                while (!operators.empty()) {
+                    const auto& top = operators.back();
+                    // left-associative: pop while top >= current. A '(' on
+                    // the stack has precedence 0 and stops the loop.
+                    if (token.precedence() <= top.precedence()) {
+                        output.push_back(operators.back());
+                        operators.pop_back();
+                    } else {
+                        break;
+                    }
+                }
+                operators.push_back(token);
+                break;
+            }
+            case TokenType::LParen:
+                operators.push_back(token);
+                break;
+            case TokenType::RParen: {
+                // flush everything up to the matching '('
+                while (!operators.empty() && operators.back().type != TokenType::LParen) {
+                    output.push_back(operators.back());
+                    operators.pop_back();
+                }
+                if (operators.empty() || operators.back().type != TokenType::LParen) {
+                    C10_THROW_ERROR(ValueError,
+                        "unit expression has unbalanced parentheses"
+                    );
+                }
+                operators.pop_back();  // discard LParen
+                break;
+            }
+        }
+    }
+
+    // flush the remaining operators; a leftover parenthesis was never closed
+    while (!operators.empty()) {
+        if (operators.back().type == TokenType::LParen ||
+            operators.back().type == TokenType::RParen) {
+            C10_THROW_ERROR(ValueError,
+                "unit expression has unbalanced parentheses"
+            );
+        }
+        output.push_back(operators.back());
+        operators.pop_back();
+    }
+
+    return output;
+}
+
+// ---- AST evaluator ----
+//
+// Departure from lumol: Pow exponent is a full sub-expression (not just i32)
+// to handle ^(1/2). The exponent sub-expression must be dimensionless; its
+// factor value becomes the exponent.
+
+struct UnitExpr;
+using UnitExprPtr = std::unique_ptr<UnitExpr>;
+
+/// Node of the expression tree built from a unit expression. A node is either
+/// a leaf value (`Val`) or a binary operation owning its two operands.
+struct UnitExpr {
+    struct Val { UnitValue value; };
+    struct Mul { UnitExprPtr lhs, rhs; };
+    struct Div { UnitExprPtr lhs, rhs; };
+    struct Pow { UnitExprPtr base, exponent; };
+
+    std::variant<Val, Mul, Div, Pow> data;
+
+    /// Recursively evaluate this expression, combining both the conversion
+    /// factors and the dimensions. Throws a `ValueError` if the exponent of a
+    /// power is not dimensionless.
+    UnitValue eval() const {
+        return std::visit([](const auto& v) -> UnitValue {
+            using T = std::decay_t<decltype(v)>;
+            if constexpr (std::is_same_v<T, Val>) {
+                return v.value;
+            } else if constexpr (std::is_same_v<T, Mul>) {
+                auto l = v.lhs->eval();
+                auto r = v.rhs->eval();
+                return {l.factor * r.factor, l.dim * r.dim};
+            } else if constexpr (std::is_same_v<T, Div>) {
+                auto l = v.lhs->eval();
+                auto r = v.rhs->eval();
+                return {l.factor / r.factor, l.dim / r.dim};
+            } else if constexpr (std::is_same_v<T, Pow>) {
+                auto b = v.base->eval();
+                auto e = v.exponent->eval();
+                if (e.dim != DIM_NONE) {
+                    C10_THROW_ERROR(ValueError,
+                        "exponent in unit expression must be dimensionless, "
+                        "got dimension " + e.dim.to_string()
+                    );
+                }
+                // the numeric value of the exponent expression is the power
+                return {std::pow(b.factor, e.factor), b.dim.pow(e.factor)};
+            }
+        }, data);
+    }
+};
+
+/// Read one expression from the RPN stream (recursive, pops from the back).
+///
+/// `Value` tokens are first looked up (case-insensitively) in `base_units()`,
+/// then parsed as a dimensionless numeric literal. Throws a `ValueError` for
+/// tokens which are neither, and when an operator is missing an operand.
+static UnitExprPtr read_expr(std::vector<Token>& stream) {
+    if (stream.empty()) {
+        C10_THROW_ERROR(ValueError,
+            "malformed unit expression: missing a value"
+        );
+    }
+
+    auto token = stream.back();
+    stream.pop_back();
+
+    switch (token.type) {
+        case TokenType::Value: {
+            auto lower = to_lower(token.value);
+            const auto& units = base_units();
+            auto it = units.find(lower);
+            if (it != units.end()) {
+                auto expr = std::make_unique<UnitExpr>();
+                expr->data = UnitExpr::Val{it->second};
+                return expr;
+            }
+
+            // try parsing as a numeric literal (dimensionless). std::stod
+            // stops at the first character it can not use, so we require the
+            // whole token to be consumed: otherwise "2cm" would silently be
+            // read as the number 2. Non-finite values ("nan", "inf") are not
+            // meaningful conversion factors and are rejected as well.
+            bool is_number = false;
+            double val = 0.0;
+            try {
+                size_t consumed = 0;
+                val = std::stod(token.value, &consumed);
+                is_number = consumed == token.value.size() && std::isfinite(val);
+            } catch (...) {
+                // std::stod could not parse the token, or the value is out
+                // of range for a double: this is not a number
+            }
+
+            if (!is_number) {
+                C10_THROW_ERROR(ValueError,
+                    "unknown unit '" + token.value + "'"
+                );
+            }
+
+            auto expr = std::make_unique<UnitExpr>();
+            expr->data = UnitExpr::Val{{val, DIM_NONE}};
+            return expr;
+        }
+        case TokenType::Mul: {
+            // in RPN the right operand is the closest to the operator
+            auto rhs = read_expr(stream);
+            auto lhs = read_expr(stream);
+            auto expr = std::make_unique<UnitExpr>();
+            expr->data = UnitExpr::Mul{std::move(lhs), std::move(rhs)};
+            return expr;
+        }
+        case TokenType::Div: {
+            auto rhs = read_expr(stream);
+            auto lhs = read_expr(stream);
+            auto expr = std::make_unique<UnitExpr>();
+            expr->data = UnitExpr::Div{std::move(lhs), std::move(rhs)};
+            return expr;
+        }
+        case TokenType::Pow: {
+            // Exponent is a full sub-expression (supports ^(1/2))
+            auto exponent = read_expr(stream);
+            auto base = read_expr(stream);
+            auto expr = std::make_unique<UnitExpr>();
+            expr->data = UnitExpr::Pow{std::move(base), std::move(exponent)};
+            return expr;
+        }
+        default:
+            C10_THROW_ERROR(ValueError,
+                "unexpected token in unit expression: " + token.as_str()
+            );
+    }
+}
+
+/// Parse a unit expression string and return the evaluated UnitValue.
+///
+/// An empty expression (or one containing only whitespace) evaluates to the
+/// dimensionless value 1. Throws a `ValueError` if the expression is
+/// malformed or contains unknown units.
+static UnitValue parse_unit_expression(const std::string& unit) {
+    if (unit.empty()) {
+        return {1.0, DIM_NONE};
+    }
+
+    auto tokens = tokenize(unit);
+    if (tokens.empty()) {
+        return {1.0, DIM_NONE};
+    }
+
+    auto rpn = shunting_yard(tokens);
+    auto ast = read_expr(rpn);
+
+    // read_expr consumes exactly one full expression; anything left means
+    // that two values were not joined by an operator
+    if (!rpn.empty()) {
+        std::string remaining;
+        for (const auto& t : rpn) {
+            if (!remaining.empty()) remaining += " ";
+            remaining += t.as_str();
+        }
+        C10_THROW_ERROR(ValueError,
+            "malformed unit expression: leftover tokens '" + remaining + "'"
+        );
+    }
+
+    return ast->eval();
+}
+
+// ---- Quantity dimension map (for validate_unit) ----
+
+static const std::unordered_map<std::string, Dimension>& quantity_dims() {
+    static const auto dims = std::unordered_map<std::string, Dimension>{
+        {"length", DIM_LENGTH},
+        {"energy", DIM_ENERGY},
+        {"force", {{1, -2, 1, 0, 0}}},  // energy/length
+        {"pressure", {{-1, -2, 1, 0, 0}}},  // energy/length^3
+        {"momentum", {{1, -1, 1, 0, 0}}},  // mass*length/time
+        {"mass", DIM_MASS},
+        {"velocity", {{1, -1, 0, 0, 0}}},  // length/time
+        {"charge", DIM_CHARGE},
+    };
+    return
dims; +} + +// ---- Public API ---- + +/// 2-argument unit_conversion_factor: parse both expressions, check dimensions +/// match, and return from_factor / to_factor. +double metatomic_torch::unit_conversion_factor( + const std::string& from_unit, + const std::string& to_unit +) { + if (from_unit.empty() || to_unit.empty()) { + return 1.0; + } + + auto from = parse_unit_expression(from_unit); + auto to = parse_unit_expression(to_unit); + + if (from.dim != to.dim) { + C10_THROW_ERROR(ValueError, + "dimension mismatch in unit conversion: '" + from_unit + + "' has dimension " + from.dim.to_string() + + " but '" + to_unit + "' has dimension " + to.dim.to_string() + ); + } + + return from.factor / to.factor; +} + +bool metatomic_torch::valid_quantity(const std::string& quantity) { + if (quantity.empty()) { + return false; + } + + const auto& dims = quantity_dims(); + if (dims.find(quantity) == dims.end()) { + auto valid_quantities = std::vector(); + for (const auto& it: dims) { + valid_quantities.emplace_back(it.first); + } + std::sort(valid_quantities.begin(), valid_quantities.end()); + + static std::unordered_set ALREADY_WARNED = {}; + if (ALREADY_WARNED.insert(quantity).second) { + TORCH_WARN( + "unknown quantity '", quantity, "', only [", + torch::str(valid_quantities), "] are supported" + ); + } + return false; + } else { + return true; + } +} + + +void metatomic_torch::validate_unit(const std::string& quantity, const std::string& unit) { + if (quantity.empty() || unit.empty()) { + return; + } + + // Always try to parse the expression (catches syntax errors) + auto parsed = parse_unit_expression(unit); + + // If the quantity is known, verify dimensions match + const auto& dims = quantity_dims(); + auto it = dims.find(quantity); + if (it != dims.end()) { + if (parsed.dim != it->second) { + C10_THROW_ERROR(ValueError, + "unit '" + unit + "' has dimension " + parsed.dim.to_string() + + " which is incompatible with quantity '" + quantity + + "' (expected " + 
it->second.to_string() + ")" + ); + } + } +} + +/// Deprecated 3-argument overload: ignores quantity, delegates to 2-arg. +double metatomic_torch::unit_conversion_factor( + const std::string& quantity, + const std::string& from_unit, + const std::string& to_unit +) { + static std::once_flag warn_flag; + std::call_once(warn_flag, [&]() { + TORCH_WARN( + "the 3-argument unit_conversion_factor(quantity, from, to) is " + "deprecated; use the 2-argument unit_conversion_factor(from, to) " + "instead. The quantity parameter is no longer needed." + ); + }); + + return metatomic_torch::unit_conversion_factor(from_unit, to_unit); +} diff --git a/metatomic-torch/tests/models.cpp b/metatomic-torch/tests/models.cpp index 0ae77300..55df429b 100644 --- a/metatomic-torch/tests/models.cpp +++ b/metatomic-torch/tests/models.cpp @@ -55,7 +55,7 @@ TEST_CASE("Models metadata") { ); CHECK_THROWS_WITH(options->set_length_unit("unknown"), - StartsWith("unknown unit 'unknown' for length") + StartsWith("unknown unit 'unknown'") ); } @@ -103,7 +103,7 @@ TEST_CASE("Models metadata") { ); CHECK_THROWS_WITH(output->set_unit("unknown"), - StartsWith("unknown unit 'unknown' for length") + StartsWith("unknown unit 'unknown'") ); struct WarningHandler: public torch::WarningHandler { @@ -135,8 +135,10 @@ TEST_CASE("Models metadata") { auto output = torch::make_intrusive(); output->per_atom = true; - output->set_quantity("something"); - output->set_unit("something"); + // Use valid quantity/unit since the parser validates unit expressions + // at set_unit() time (previously any string was accepted) + output->set_quantity("energy"); + output->set_unit("eV"); options->outputs.insert("output_2", output); const auto* expected = R"({ @@ -156,8 +158,8 @@ TEST_CASE("Models metadata") { "description": "", "explicit_gradients": [], "per_atom": true, - "quantity": "something", - "unit": "something" + "quantity": "energy", + "unit": "eV" } }, "selected_atoms": null @@ -205,7 +207,7 @@ TEST_CASE("Models 
metadata") { ); CHECK_THROWS_WITH(options->set_length_unit("unknown"), - StartsWith("unknown unit 'unknown' for length") + StartsWith("unknown unit 'unknown'") ); } @@ -295,7 +297,7 @@ TEST_CASE("Models metadata") { ); CHECK_THROWS_WITH(capabilities->set_length_unit("unknown"), - StartsWith("unknown unit 'unknown' for length") + StartsWith("unknown unit 'unknown'") ); auto capabilities_variants = torch::make_intrusive(); diff --git a/metatomic-torch/tests/units.cpp b/metatomic-torch/tests/units.cpp new file mode 100644 index 00000000..33bdb689 --- /dev/null +++ b/metatomic-torch/tests/units.cpp @@ -0,0 +1,274 @@ +#include +#include + +#include + +#include "metatomic/torch.hpp" + +#include +using Catch::Detail::Approx; +using Catch::Matchers::Contains; +using Catch::Matchers::StartsWith; + +class DeprecationWarningHandler : public torch::WarningHandler { +public: + std::vector messages; + + void process(const torch::Warning& warning) override { + messages.push_back(warning.msg()); + } +}; + +// ---- Simple conversions ---- + +TEST_CASE("Simple length conversions") { + // Angstrom <-> Bohr + double a_to_bohr = metatomic_torch::unit_conversion_factor("Angstrom", "Bohr"); + CHECK(a_to_bohr == Approx(1.8897259886).epsilon(1e-6)); + + double bohr_to_a = metatomic_torch::unit_conversion_factor("Bohr", "Angstrom"); + CHECK(bohr_to_a == Approx(0.529177210903).epsilon(1e-6)); + + // Angstrom <-> nm + double a_to_nm = metatomic_torch::unit_conversion_factor("Angstrom", "nm"); + CHECK(a_to_nm == Approx(0.1).epsilon(1e-12)); + + // Angstrom <-> meter + double a_to_m = metatomic_torch::unit_conversion_factor("Angstrom", "meter"); + CHECK(a_to_m == Approx(1e-10).epsilon(1e-20)); +} + +TEST_CASE("Simple energy conversions") { + // eV <-> meV + double ev_to_mev = metatomic_torch::unit_conversion_factor("eV", "meV"); + CHECK(ev_to_mev == Approx(1000.0).epsilon(1e-10)); + + // eV <-> Hartree + double ev_to_hartree = metatomic_torch::unit_conversion_factor("eV", "Hartree"); + 
CHECK(ev_to_hartree == Approx(0.0367493).epsilon(1e-4)); + + // eV <-> Rydberg + double ev_to_ry = metatomic_torch::unit_conversion_factor("eV", "Ry"); + CHECK(ev_to_ry == Approx(0.0734987).epsilon(1e-4)); +} + +// ---- Compound expressions ---- + +TEST_CASE("Compound unit expressions") { + // eV/Angstrom <-> Hartree/Bohr (force) + double f_conv = metatomic_torch::unit_conversion_factor("eV/Angstrom", "Hartree/Bohr"); + CHECK(f_conv == Approx(0.0194469).epsilon(1e-3)); + + // kJ/mol <-> kcal/mol (energy) + double e_conv = metatomic_torch::unit_conversion_factor("kJ/mol", "kcal/mol"); + CHECK(e_conv == Approx(1000.0 / 4184.0).epsilon(1e-6)); + + // eV/Angstrom^3 (pressure) -- identity + double p_id = metatomic_torch::unit_conversion_factor("eV/Angstrom^3", "eV/A^3"); + CHECK(p_id == Approx(1.0).epsilon(1e-10)); +} + +// ---- Fractional powers ---- + +TEST_CASE("Fractional power expressions") { + // (eV*u)^(1/2) should have dimension of momentum: M*L/T + // Compare to u*A/fs + double conv = metatomic_torch::unit_conversion_factor("(eV*u)^(1/2)", "u*A/fs"); + // Cross-check: sqrt(eV_SI * u_SI) / (u_SI * A_SI / fs_SI) + double ev_si = 1.602176634e-19; + double u_si = 1.66053906660e-27; + double a_si = 1e-10; + double fs_si = 1e-15; + double expected = std::sqrt(ev_si * u_si) / (u_si * a_si / fs_si); + CHECK(conv == Approx(expected).epsilon(1e-4)); + + // (eV/u)^(1/2) has dimension of velocity: L/T + double v_conv = metatomic_torch::unit_conversion_factor("(eV/u)^(1/2)", "A/fs"); + double v_expected = std::sqrt(ev_si / u_si) / (a_si / fs_si); + CHECK(v_conv == Approx(v_expected).epsilon(1e-4)); +} + +// ---- Case insensitivity ---- + +TEST_CASE("Case insensitive unit lookup") { + double c1 = metatomic_torch::unit_conversion_factor("eV", "hartree"); + double c2 = metatomic_torch::unit_conversion_factor("EV", "HARTREE"); + double c3 = metatomic_torch::unit_conversion_factor("Ev", "Hartree"); + CHECK(c1 == Approx(c2).epsilon(1e-12)); + CHECK(c1 == 
Approx(c3).epsilon(1e-12)); +} + +// ---- Whitespace handling ---- + +TEST_CASE("Whitespace in unit expressions") { + double c1 = metatomic_torch::unit_conversion_factor("eV / Angstrom", "Hartree/Bohr"); + double c2 = metatomic_torch::unit_conversion_factor("eV/Angstrom", "Hartree/Bohr"); + CHECK(c1 == Approx(c2).epsilon(1e-12)); + + double c3 = metatomic_torch::unit_conversion_factor("( eV * u ) ^ ( 1 / 2 )", "u*A/fs"); + double c4 = metatomic_torch::unit_conversion_factor("(eV*u)^(1/2)", "u*A/fs"); + CHECK(c3 == Approx(c4).epsilon(1e-12)); +} + +// ---- Dimension mismatch errors ---- + +TEST_CASE("Dimension mismatch error") { + CHECK_THROWS_WITH( + metatomic_torch::unit_conversion_factor("eV", "Angstrom"), + Contains("dimension mismatch") + ); +} + +// ---- Unknown token errors ---- + +TEST_CASE("Unknown unit token error") { + CHECK_THROWS_WITH( + metatomic_torch::unit_conversion_factor("foobar", "eV"), + Contains("unknown unit") + ); +} + +// ---- Malformed expression errors ---- + +TEST_CASE("Malformed expression errors") { + CHECK_THROWS_WITH( + metatomic_torch::unit_conversion_factor("eV*(", "eV"), + Contains("parentheses") + ); +} + +// ---- Empty string handling ---- + +TEST_CASE("Empty unit string returns 1.0") { + CHECK(metatomic_torch::unit_conversion_factor("", "eV") == 1.0); + CHECK(metatomic_torch::unit_conversion_factor("eV", "") == 1.0); + CHECK(metatomic_torch::unit_conversion_factor("", "") == 1.0); +} + +// ---- Backward compatibility: 3-arg API ---- + +TEST_CASE("3-arg API backward compatibility") { + DeprecationWarningHandler handler; + torch::WarningUtils::WarningHandlerGuard guard(&handler); + torch::WarningUtils::set_warnAlways(true); + + double conv = metatomic_torch::unit_conversion_factor("energy", "eV", "meV"); + CHECK(conv == Approx(1000.0).epsilon(1e-10)); +} + +// ---- mol handling (dimensionless scaling) ---- + +TEST_CASE("mol as dimensionless scaling factor") { + // kJ/mol to eV: kJ_SI / mol_SI / eV_SI + double conv = 
metatomic_torch::unit_conversion_factor("kJ/mol", "eV"); + double kj_si = 1000.0; + double mol_si = 6.02214076e23; + double ev_si = 1.602176634e-19; + double expected = (kj_si / mol_si) / ev_si; + CHECK(conv == Approx(expected).epsilon(1e-6)); +} + +// ---- Hbar derived unit ---- + +TEST_CASE("hbar/Bohr momentum conversion") { + double conv = metatomic_torch::unit_conversion_factor("hbar/Bohr", "u*A/fs"); + // hbar_SI / Bohr_SI / (u_SI * A_SI / fs_SI) + double hbar_si = 1.054571817e-34; + double bohr_si = 5.29177210903e-11; + double u_si = 1.66053906660e-27; + double a_si = 1e-10; + double fs_si = 1e-15; + double expected = (hbar_si / bohr_si) / (u_si * a_si / fs_si); + CHECK(conv == Approx(expected).epsilon(1e-3)); +} + +// ---- Identity conversions ---- + +TEST_CASE("Identity conversions") { + CHECK(metatomic_torch::unit_conversion_factor("eV", "eV") == Approx(1.0)); + CHECK(metatomic_torch::unit_conversion_factor("Angstrom", "Angstrom") == Approx(1.0)); + CHECK(metatomic_torch::unit_conversion_factor("u*A/fs", "u*A/fs") == Approx(1.0)); +} + +// ---- Time unit conversions ---- + +TEST_CASE("Time unit conversions") { + // second -> femtosecond + double s_to_fs = metatomic_torch::unit_conversion_factor("s", "fs"); + CHECK(s_to_fs == Approx(1e15).epsilon(1e-6)); + + // second -> picosecond + double s_to_ps = metatomic_torch::unit_conversion_factor("second", "ps"); + CHECK(s_to_ps == Approx(1e12).epsilon(1e-6)); + + // nanosecond -> femtosecond + double ns_to_fs = metatomic_torch::unit_conversion_factor("ns", "fs"); + CHECK(ns_to_fs == Approx(1e6).epsilon(1e-6)); + + // microsecond -> nanosecond + double us_to_ns = metatomic_torch::unit_conversion_factor("us", "ns"); + CHECK(us_to_ns == Approx(1e3).epsilon(1e-6)); +} + +// ---- Micro sign handling ---- + +TEST_CASE("Micro sign (U+00B5) handling") { + double c1 = metatomic_torch::unit_conversion_factor("um", "Angstrom"); + double c2 = metatomic_torch::unit_conversion_factor("µm", "Angstrom"); + CHECK(c1 == 
Approx(c2).epsilon(1e-12)); + + // µs -> ns (microsecond via micro sign) + double c3 = metatomic_torch::unit_conversion_factor("us", "ns"); + double c4 = metatomic_torch::unit_conversion_factor("µs", "ns"); + CHECK(c3 == Approx(c4).epsilon(1e-12)); +} + +// ---- Quantity-unit dimension validation ---- + +TEST_CASE("ModelOutput rejects mismatched quantity and unit") { + // energy quantity with a force unit + CHECK_THROWS_WITH( + torch::make_intrusive( + "energy", "eV/A", false, std::vector{}, "" + ), + Contains("incompatible with quantity") + ); + + // force quantity with an energy unit + CHECK_THROWS_WITH( + torch::make_intrusive( + "force", "eV", false, std::vector{}, "" + ), + Contains("incompatible with quantity") + ); + + // length quantity with a pressure unit + CHECK_THROWS_WITH( + torch::make_intrusive( + "length", "eV/A^3", false, std::vector{}, "" + ), + Contains("incompatible with quantity") + ); +} + +TEST_CASE("ModelOutput accepts matching quantity and unit") { + // These should not throw + torch::make_intrusive( + "energy", "eV", false, std::vector{}, "" + ); + torch::make_intrusive( + "force", "eV/A", false, std::vector{}, "" + ); + torch::make_intrusive( + "pressure", "eV/A^3", false, std::vector{}, "" + ); + torch::make_intrusive( + "length", "Angstrom", false, std::vector{}, "" + ); + torch::make_intrusive( + "momentum", "u*A/fs", false, std::vector{}, "" + ); + torch::make_intrusive( + "velocity", "A/fs", false, std::vector{}, "" + ); +} diff --git a/python/metatomic_torch/metatomic/torch/__init__.py b/python/metatomic_torch/metatomic/torch/__init__.py index 8855e1d6..f24bb09d 100644 --- a/python/metatomic_torch/metatomic/torch/__init__.py +++ b/python/metatomic_torch/metatomic/torch/__init__.py @@ -43,7 +43,63 @@ _check_outputs = torch.ops.metatomic._check_outputs register_autograd_neighbors = torch.ops.metatomic.register_autograd_neighbors - unit_conversion_factor = torch.ops.metatomic.unit_conversion_factor + + _unit_conversion_factor_v2 = 
torch.ops.metatomic.unit_conversion_factor_v2 + _unit_conversion_factor_v1 = torch.ops.metatomic.unit_conversion_factor + + def unit_conversion_factor(*args, **kwargs): + """ + Get the multiplicative conversion factor from ``from_unit`` to + ``to_unit``. + + Supports two calling conventions: + + - **2-argument** (preferred): + ``unit_conversion_factor(from_unit, to_unit)`` + - **3-argument** (deprecated): + ``unit_conversion_factor(quantity, from_unit, to_unit)`` + + Both ``from_unit`` and ``to_unit`` are parsed as unit expressions + supporting compound forms like ``"kJ/mol/A^2"`` or + ``"(eV*u)^(1/2)"``. The parser validates that both expressions have + matching physical dimensions. + + The ``quantity`` parameter in the 3-argument form is ignored + (dimensional compatibility is checked by the parser). A deprecation + warning will be emitted if the 3-argument form is used. + + The set of recognized base unit tokens is available :ref:`here + `. + + :param from_unit: current unit of the data (expression string) + :param to_unit: target unit of the data (expression string) + """ + import warnings + + if len(args) == 2 and not kwargs: + return _unit_conversion_factor_v2(args[0], args[1]) + elif len(args) == 3 and not kwargs: + warnings.warn( + "the 3-argument unit_conversion_factor(quantity, from, to) is " + "deprecated; use the 2-argument form instead", + DeprecationWarning, + stacklevel=2, + ) + return _unit_conversion_factor_v2(args[1], args[2]) + elif "from_unit" in kwargs and "to_unit" in kwargs: + if "quantity" in kwargs: + warnings.warn( + "the 3-argument unit_conversion_factor(quantity, from, to)" + " is deprecated; use the 2-argument form instead", + DeprecationWarning, + stacklevel=2, + ) + return _unit_conversion_factor_v2(kwargs["from_unit"], kwargs["to_unit"]) + else: + raise TypeError( + "unit_conversion_factor() expects 2 or 3 positional arguments" + ) + pick_device = torch.ops.metatomic.pick_device pick_output = torch.ops.metatomic.pick_output diff 
--git a/python/metatomic_torch/metatomic/torch/documentation.py b/python/metatomic_torch/metatomic/torch/documentation.py index e0787ab2..cb42058b 100644 --- a/python/metatomic_torch/metatomic/torch/documentation.py +++ b/python/metatomic_torch/metatomic/torch/documentation.py @@ -522,15 +522,32 @@ def register_autograd_neighbors( """ -def unit_conversion_factor(quantity: str, from_unit: str, to_unit: str): +def unit_conversion_factor(from_unit: str, to_unit: str) -> float: """ - Get the multiplicative conversion factor from ``from_unit`` to ``to_unit``. Both - units must be valid and known for the given physical ``quantity``. The set of valid - quantities and units is available :ref:`here `. + Get the multiplicative conversion factor from ``from_unit`` to + ``to_unit``. - :param quantity: name of the physical quantity - :param from_unit: current unit of the data - :param to_unit: target unit of the data + Supports two calling conventions: + + - **2-argument** (preferred): + ``unit_conversion_factor(from_unit, to_unit)`` + - **3-argument** (deprecated): + ``unit_conversion_factor(quantity, from_unit, to_unit)`` + + Both ``from_unit`` and ``to_unit`` are parsed as unit expressions + supporting compound forms like ``"kJ/mol/A^2"`` or + ``"(eV*u)^(1/2)"``. The parser validates that both expressions have + matching physical dimensions. + + The ``quantity`` parameter in the 3-argument form is ignored + (dimensional compatibility is checked by the parser). A deprecation + warning will be emitted if the 3-argument form is used. + + The set of recognized base unit tokens is available :ref:`here + `. 
+ + :param from_unit: current unit of the data (expression string) + :param to_unit: target unit of the data (expression string) """ diff --git a/python/metatomic_torch/metatomic/torch/model.py b/python/metatomic_torch/metatomic/torch/model.py index c2a9f528..b6da08c9 100644 --- a/python/metatomic_torch/metatomic/torch/model.py +++ b/python/metatomic_torch/metatomic/torch/model.py @@ -21,7 +21,6 @@ _check_outputs, check_atomistic_model, load_model_extensions, - unit_conversion_factor, ) from . import __version__ as metatomic_version from ._extensions import _collect_extensions @@ -517,10 +516,9 @@ def forward( f"'{requested.quantity}'" ) - conversion = unit_conversion_factor( - quantity=declared.quantity, - from_unit=declared.unit, - to_unit=requested.unit, + conversion = torch.ops.metatomic.unit_conversion_factor_v2( + declared.unit, + requested.unit, ) if conversion != 1.0: @@ -934,10 +932,9 @@ def _convert_systems_units( # no conversion for positions/cell/NL conversion = 1.0 else: - conversion = unit_conversion_factor( - quantity="length", - from_unit=system_length_unit, - to_unit=model_length_unit, + conversion = torch.ops.metatomic.unit_conversion_factor_v2( + system_length_unit, + model_length_unit, ) new_systems: List[System] = [] @@ -974,10 +971,9 @@ def _convert_systems_units( unit = tensor.get_info("unit") if requested.quantity != "" and unit is not None: - conversion = unit_conversion_factor( - quantity=requested.quantity, - from_unit=unit, - to_unit=requested.unit, + conversion = torch.ops.metatomic.unit_conversion_factor_v2( + unit, + requested.unit, ) else: conversion = 1.0 diff --git a/python/metatomic_torch/tests/units.py b/python/metatomic_torch/tests/units.py index 9b47894b..39baac16 100644 --- a/python/metatomic_torch/tests/units.py +++ b/python/metatomic_torch/tests/units.py @@ -1,36 +1,135 @@ +import math +import warnings + import ase.units +import pytest from metatomic.torch import ModelOutput, unit_conversion_factor -def 
test_conversion_length(): - length_angstrom = 1.0 - length_nm = unit_conversion_factor("length", "angstrom", "nm") * length_angstrom - assert length_nm == 0.1 +# ---- Backward compat: 3-arg still works (with deprecation warning) ---- -def test_conversion_energy(): - energy_ev = 1.0 - energy_mev = unit_conversion_factor("energy", "ev", "mev") * energy_ev - assert energy_mev == 1000.0 +def test_conversion_length_3arg(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + length_angstrom = 1.0 + length_nm = unit_conversion_factor("length", "angstrom", "nm") * length_angstrom + assert length_nm == pytest.approx(0.1) + +def test_conversion_energy_3arg(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + energy_ev = 1.0 + energy_mev = unit_conversion_factor("energy", "ev", "mev") * energy_ev + assert energy_mev == pytest.approx(1000.0) -def test_units(): - def length_conversion(unit): - return unit_conversion_factor("length", "angstrom", unit) - assert length_conversion("bohr") == ase.units.Ang / ase.units.Bohr - assert length_conversion("nm") == ase.units.Ang / ase.units.nm - assert length_conversion("nanometer") == ase.units.Ang / ase.units.nm +# ---- 2-arg API ---- + + +def test_conversion_length(): + assert unit_conversion_factor("angstrom", "nm") == pytest.approx(0.1) + assert unit_conversion_factor("angstrom", "Bohr") == pytest.approx( + 1.8897259886, rel=1e-6 + ) + assert unit_conversion_factor("Angstrom", "meter") == pytest.approx( + 1e-10, rel=1e-10 + ) - def energy_conversion(unit): - return unit_conversion_factor("energy", "ev", unit) - assert energy_conversion("Hartree") == ase.units.eV / ase.units.Hartree +def test_conversion_energy(): + assert unit_conversion_factor("eV", "meV") == pytest.approx(1000.0) + assert unit_conversion_factor("eV", "Hartree") == pytest.approx(0.0367493, rel=1e-3) + + +def test_units_vs_ase(): + assert unit_conversion_factor("angstrom", "bohr") == 
pytest.approx( + ase.units.Ang / ase.units.Bohr, rel=1e-6 + ) + assert unit_conversion_factor("angstrom", "nm") == pytest.approx( + ase.units.Ang / ase.units.nm + ) + assert unit_conversion_factor("angstrom", "nanometer") == pytest.approx( + ase.units.Ang / ase.units.nm + ) + + assert unit_conversion_factor("ev", "Hartree") == pytest.approx( + ase.units.eV / ase.units.Hartree, rel=1e-6 + ) kcal_mol = ase.units.kcal / ase.units.mol - assert energy_conversion("kcal/mol") == ase.units.eV / kcal_mol + assert unit_conversion_factor("ev", "kcal/mol") == pytest.approx( + ase.units.eV / kcal_mol, rel=1e-4 + ) kJ_mol = ase.units.kJ / ase.units.mol - assert energy_conversion("kJ/mol") == ase.units.eV / kJ_mol + assert unit_conversion_factor("ev", "kJ/mol") == pytest.approx( + ase.units.eV / kJ_mol, rel=1e-4 + ) + + +# ---- Compound expressions ---- + + +def test_compound_expressions(): + # Force: eV/Angstrom -> Hartree/Bohr + conv = unit_conversion_factor("eV/Angstrom", "Hartree/Bohr") + assert conv == pytest.approx(0.0194469, rel=1e-3) + + # kJ/mol -> kcal/mol + conv = unit_conversion_factor("kJ/mol", "kcal/mol") + assert conv == pytest.approx(1000.0 / 4184.0, rel=1e-6) + + # Pressure identity + conv = unit_conversion_factor("eV/Angstrom^3", "eV/A^3") + assert conv == pytest.approx(1.0) + + +# ---- Fractional powers ---- + + +def test_fractional_powers(): + # (eV*u)^(1/2) -> u*A/fs + conv = unit_conversion_factor("(eV*u)^(1/2)", "u*A/fs") + ev_si = 1.602176634e-19 + u_si = 1.66053906660e-27 + a_si = 1e-10 + fs_si = 1e-15 + expected = math.sqrt(ev_si * u_si) / (u_si * a_si / fs_si) + assert conv == pytest.approx(expected, rel=1e-3) + + # (eV/u)^(1/2) -> A/fs + conv = unit_conversion_factor("(eV/u)^(1/2)", "A/fs") + expected = math.sqrt(ev_si / u_si) / (a_si / fs_si) + assert conv == pytest.approx(expected, rel=1e-3) + + +# ---- Dimension mismatch ---- + + +def test_dimension_mismatch(): + with pytest.raises((ValueError, RuntimeError), match="dimension mismatch"): + 
unit_conversion_factor("eV", "Angstrom") + + +# ---- Unknown token ---- + + +def test_unknown_token(): + with pytest.raises((ValueError, RuntimeError), match="unknown unit"): + unit_conversion_factor("foobar", "eV") + + +# ---- Empty string ---- + + +def test_empty_string(): + assert unit_conversion_factor("", "eV") == 1.0 + assert unit_conversion_factor("eV", "") == 1.0 + assert unit_conversion_factor("", "") == 1.0 + + +# ---- Valid units (ModelOutput creation still works) ---- def test_valid_units(): @@ -67,3 +166,56 @@ def test_valid_units(): ModelOutput(quantity="momentum", unit="u * A/ fs") ModelOutput(quantity="momentum", unit=" (eV*u )^(1/ 2 )") + + ModelOutput(quantity="velocity", unit="A/fs") + ModelOutput(quantity="velocity", unit="A/s") + + +# ---- Time units ---- + + +def test_time_units(): + assert unit_conversion_factor("s", "fs") == pytest.approx(1e15) + assert unit_conversion_factor("second", "ps") == pytest.approx(1e12) + assert unit_conversion_factor("ns", "fs") == pytest.approx(1e6) + assert unit_conversion_factor("us", "ns") == pytest.approx(1e3) + assert unit_conversion_factor("ms", "us") == pytest.approx(1e3) + + +# ---- Micro sign for microsecond ---- + + +def test_micro_sign_microsecond(): + # µs -> ns (microsecond via micro sign) + assert unit_conversion_factor("µs", "ns") == pytest.approx( + unit_conversion_factor("us", "ns") + ) + + +# ---- Quantity-unit mismatch ---- + + +def test_quantity_unit_mismatch(): + # energy quantity with force unit + with pytest.raises((ValueError, RuntimeError), match="incompatible with quantity"): + ModelOutput(quantity="energy", unit="eV/A") + + # force quantity with energy unit + with pytest.raises((ValueError, RuntimeError), match="incompatible with quantity"): + ModelOutput(quantity="force", unit="eV") + + # length quantity with pressure unit + with pytest.raises((ValueError, RuntimeError), match="incompatible with quantity"): + ModelOutput(quantity="length", unit="eV/A^3") + + +# ---- Deprecation 
warning for 3-arg ---- + + +def test_3arg_deprecation_warning(): + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + unit_conversion_factor("energy", "eV", "meV") + assert len(w) >= 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "deprecated" in str(w[0].message).lower()