From 047ba5c207c1e5b64ccf185a1326ca5a0ecb136e Mon Sep 17 00:00:00 2001 From: hzha0455 Date: Tue, 7 Oct 2025 21:23:23 +1100 Subject: [PATCH 1/2] Add numeric singularity mitigation integration --- odetoolbox/singularity_analysis_mitigation.py | 103 +++++++++++++ odetoolbox/system_of_shapes.py | 93 +++++++++++- tests/test_analysis_mitigation.py | 140 ++++++++++++++++++ 3 files changed, 332 insertions(+), 4 deletions(-) create mode 100644 odetoolbox/singularity_analysis_mitigation.py create mode 100644 tests/test_analysis_mitigation.py diff --git a/odetoolbox/singularity_analysis_mitigation.py b/odetoolbox/singularity_analysis_mitigation.py new file mode 100644 index 00000000..1e881cef --- /dev/null +++ b/odetoolbox/singularity_analysis_mitigation.py @@ -0,0 +1,103 @@ +# odetoolbox/singularity_analysis_mitigation.py +from __future__ import annotations +from dataclasses import dataclass +from typing import Callable, Dict, Iterable, List, Tuple, Optional +import numpy as np +from scipy.linalg import expm + +# ---- Type definitions ---- +Params = Dict[str, float] +Cond = Tuple[str, str] +AFunction = Callable[[Params], np.ndarray] + + +# ---- Data structures ---- +@dataclass +class BranchResult: + """Represents the result of one branch: a condition and its corresponding propagator matrix.""" + condition: Optional[Cond] + P: np.ndarray + +@dataclass +class BranchPack: + """Container for all branches, including the default one.""" + default: BranchResult + branches: List[BranchResult] + + +# ---- Utility functions ---- +def _ensure_square(A: np.ndarray) -> None: + """Ensure that the matrix A is square.""" + if A.ndim != 2 or A.shape[0] != A.shape[1]: + raise ValueError("A must be a square 2D array") + +def _matrix_exp(A: np.ndarray, h: float) -> np.ndarray: + """Compute the matrix exponential P = exp(-h * A).""" + _ensure_square(A) + return expm(-h * A) + +def _tie_params(base: Params, cond: Cond, mode: str = "left") -> Params: + """ + Force parameters in cond = (a, b) to become equal according to the specified mode. + + Modes: + - 'left' : overwrite b with a's value + - 'right' : overwrite a with b's value + - 'avg' : replace both with their average value + """ + a, b = cond + if a not in base or b not in base: + raise KeyError(f"Unknown parameter in condition: {cond}") + out = dict(base) + if mode == "left": + out[b] = out[a] + elif mode == "right": + out[a] = out[b] + elif mode == "avg": + v = 0.5 * (out[a] + out[b]) + out[a] = out[b] = v + else: + raise ValueError("mode must be one of: 'left', 'right', or 'avg'") + return out + + +# ---- Main process---- +def build_transition_branches_numeric( + A_fn: AFunction, + h: float, + base_params: Params, + param_names: Iterable[str], + conditions: Optional[Iterable[Cond]] = None, + tie_mode: str = "left", +) -> BranchPack: + """ + Construct default and conditional transition branches numerically. + + Steps: + 1) Given a system x' = A x, build A using A_fn(base_params). + 2) Compute the default propagator P0 = exp(-h * A). + 3) If no conditions are provided, generate all pairwise equality candidates. + 4) For each condition, make the corresponding parameters equal, + rebuild A' = A_fn(params'), and compute P' = exp(-h * A'). + 5) Collect all branches into a BranchPack object. + """ + # Step 1–2: default branch + A0 = np.asarray(A_fn(dict(base_params)), dtype=float) + P0 = _matrix_exp(A0, h) + default = BranchResult(condition=None, P=P0) + + # Step 3: generate pairwise equality conditions if none provided + names = list(param_names) + if conditions is None: + conditions = [(names[i], names[j]) for i in range(len(names)) for j in range(i + 1, len(names))] + + # Step 4–5: compute each conditional branch + branches: List[BranchResult] = [] + for cond in conditions: + params_ = _tie_params(base_params, cond, mode=tie_mode) + A_ = np.asarray(A_fn(params_), dtype=float) + P_ = _matrix_exp(A_, h) + branches.append(BranchResult(condition=tuple(sorted(cond)), P=P_)) + + return BranchPack(default=default, branches=branches) + diff --git a/odetoolbox/system_of_shapes.py b/odetoolbox/system_of_shapes.py index 8e970e29..c3f079fe 100644 --- a/odetoolbox/system_of_shapes.py +++ b/odetoolbox/system_of_shapes.py @@ -33,6 +33,11 @@ from .shapes import Shape from .singularity_detection import SingularityDetection, SingularityDetectionException from .sympy_helpers import _custom_simplify_expr, _is_zero +from odetoolbox.singularity_analysis_mitigation import ( + build_transition_branches_numeric as _build_branches_numeric, + BranchResult as _BR, +) + class GetBlockDiagonalException(Exception): @@ -233,9 +238,8 @@ def generate_propagator_solver(self, disable_singularity_detection: bool = False P = self._generate_propagator_matrix(self.A_) - # - # singularity detection - # + + if not disable_singularity_detection: try: @@ -316,6 +320,87 @@ def generate_propagator_solver(self, disable_singularity_detection: bool = False "state_variables": all_state_symbols, "initial_values": initial_values} +#------ + try: + # 1) Re-run singularity detection + _conds_pairs = [] + try: + _conds = SingularityDetection.find_propagator_singularities(P, self.A_) + # Flatten sets like {a=b, c=d} into a list of tuples: [("a","b"), ("c","d"), ...] + for _set in _conds or []: + for _eq in _set: + _L, _R = getattr(_eq, "lhs", None), getattr(_eq, "rhs", None) + if isinstance(_L, sympy.Symbol) and isinstance(_R, sympy.Symbol): + _conds_pairs.append((str(_L), str(_R))) + except Exception: + _conds_pairs = [] + + # If no detectable parameter equality conditions are found, skip branching + if _conds_pairs: + # 2) Collect parameter symbols + _Hsym = sympy.Symbol(Config().output_timestep_symbol, real=True) + _xset = set(self.x_) if hasattr(self, "x_") else set() + _A_syms = sorted( + [s for s in self.A_.free_symbols if s not in _xset and s != _Hsym], + key=lambda s: str(s) + ) + _param_names = [str(s) for s in _A_syms] + + # 3) Retrieve base parameter values + def _get_param_val(_s: sympy.Symbol) -> float: + _k = str(_s) + if hasattr(self, "parameter_values") and _k in getattr(self, "parameter_values"): + return float(self.parameter_values[_k]) + if hasattr(self, "_params") and _k in getattr(self, "_params"): + return float(self._params[_k]) + # Missing parameter: raise an error to skip mitigation + raise KeyError(f"Missing numeric value for parameter: {_k}") + + _base_params = {str(s): _get_param_val(s) for s in _A_syms} + + # 4) Construct numerical function A + _A_lmb = sympy.lambdify(_A_syms, self.A_, modules="numpy") + + def _A_fn(_params_dict): + _vals = [_params_dict[name] for name in _param_names] + return np.array(_A_lmb(*_vals), dtype=float) + + # 5) Determine timestep h: prefer self.h / self.dt; otherwise use H symbol or fallback to 1.0 + if hasattr(self, "h"): + _h_val = float(self.h) + elif hasattr(self, "dt"): + _h_val = float(self.dt) + else: + _h_val = float(_base_params.get(str(_Hsym), 1.0)) + + # 6) Build numeric transition branches + _pack = _build_branches_numeric( + A_fn=_A_fn, + h=_h_val, + base_params=_base_params, + param_names=_param_names, + conditions=[tuple(sorted(p)) for p in _conds_pairs], + tie_mode="left", # Can also be "right" or "avg" + ) + + # 7) Insert branch results into solver_dict + def _to_list(_P): + # Convert NumPy array to list for JSON serialization + return _P.tolist() + + def _pack_branch(_br: _BR): + _cond = None if _br.condition is None else {"eq": [_br.condition[0], _br.condition[1]]} + return {"condition": _cond, "P": _to_list(_br.P)} + + solver_dict["branching"] = { + "default": _pack_branch(_pack.default), + "branches": [_pack_branch(_b) for _b in _pack.branches], + "tie_mode": "left", + } + except Exception as _mit_err: + # Mitigation failure does not affect the original return + logging.debug(f"[mitigation] numeric branching skipped: {_mit_err}") + return solver_dict @@ -451,4 +536,4 @@ def from_shapes(cls, shapes: List[Shape], parameters=None): i += shape.order - return SystemOfShapes(x, A, b, c, shapes) + return SystemOfShapes(x, A, b, c, shapes) \ No newline at end of file diff --git a/tests/test_analysis_mitigation.py b/tests/test_analysis_mitigation.py new file mode 100644 index 00000000..5ffb0285 --- /dev/null +++ b/tests/test_analysis_mitigation.py @@ -0,0 +1,140 @@ +# tests/test_integration_branch_flow_numeric.py +from __future__ import annotations +import numpy as np +import pytest +from numpy.testing import assert_allclose + +from odetoolbox.singularity_analysis_mitigation import ( + build_transition_branches_numeric, + BranchPack, +) + + +# Utilities (numeric version) + +def jordan_expected_numeric(alpha: float, h: float) -> np.ndarray: + """ + Expected Jordan-form propagator for the degenerate case a = b. + + System: + A = [[ alpha, 0, 0], + [ 1 , alpha, 0], + [ 0 , 1 , alpha]] + + Theoretical propagator: + P = exp(-h*A) = exp(-h*alpha) * [[1, 0, 0], + [h, 1, 0], + [h*h/2, h, 1]] + """ + E = np.exp(-h * alpha) + return E * np.array([[1.0, 0.0, 0.0], + [h, 1.0, 0.0], + [0.5*h*h, h, 1.0]], dtype=float) + + +TOL = dict(rtol=1e-10, atol=1e-12) + + + +# 1) Direct verification: when a = b, the branch propagator + +def test_numeric_branch_yields_jordan(): + # Equivalent system to the original symbolic test, but defined as a numeric function A_fn. + # Original: x' = -a x, y' = x - a y, z' = y - b z + # Numeric version: A = [[ a, 0, 0], + # [-1, a, 0], + # [ 0, -1, b]] + # This uses P = exp(-h*A), equivalent in meaning to the symbolic exp(-hA) form. + def A_fn(p): + a, b = p["a"], p["b"] + return np.array([[a, 0.0, 0.0], + [-1.0, a, 0.0], + [0.0, -1.0, b ]], dtype=float) + + h = 0.2 + base = {"a": 1.3, "b": 2.1} + + pack: BranchPack = build_transition_branches_numeric( + A_fn=A_fn, + h=h, + base_params=base, + param_names=["a", "b"], + conditions=[("a", "b")], # Only test the condition a = b + tie_mode="left", # Overwrite b with a’s value + ) + + # Extract the a=b branch + assert len(pack.branches) == 1 + br = pack.branches[0] + assert tuple(br.condition) == ("a", "b") + + # Expected Jordan-form propagator for a=b=base["a"] + alpha = base["a"] + P_expected = jordan_expected_numeric(alpha, h) + + # Numerical equivalence check + assert_allclose(br.P, P_expected, **TOL) + + + +# 2) Auto-generated “pairwise equality” condition test + +def test_numeric_branch_auto_conditions(): + def A_fn(p): + # Simple 3×3 upper-triangular chain structure + t1, t2, t3 = p["t1"], p["t2"], p["t3"] + return np.array([[t1, 0.0, 0.0], + [-1.0, t2, 0.0], + [0.0, -1.0, t3]], dtype=float) + + h = 0.05 + base = {"t1": 1.0, "t2": 1.0, "t3": 2.0} + + pack = build_transition_branches_numeric( + A_fn=A_fn, + h=h, + base_params=base, + param_names=["t1", "t2", "t3"], + conditions=None, + tie_mode="left", + ) + + # The automatically generated condition list should include ('t1', 't2') + conds = [tuple(c.condition) for c in pack.branches] + assert ("t1", "t2") in conds + + # Find the ('t1','t2') branch and verify that its top-left 2×2 block + # matches the analytical 2×2 Jordan block for parameter alpha = t1. + br = next(c for c in pack.branches if tuple(c.condition) == ("t1", "t2")) + alpha = base["t1"] + P_expected_2x2 = np.exp(-h*alpha) * np.array([[1.0, 0.0], + [h, 1.0]]) + assert_allclose(br.P[:2, :2], P_expected_2x2, **TOL) + + +# 3) Basic robustness test: shape and near-identity behavior + +def test_numeric_default_shape_and_sanity(): + def A_fn(p): + return np.array([[p["a"], 1.0], + [0.0, p["b"]]], dtype=float) + + h = 0.1 + base = {"a": 1.0, "b": 1.0} + + pack = build_transition_branches_numeric( + A_fn=A_fn, h=h, base_params=base, + param_names=["a", "b"], conditions=[("a","b")] + ) + + P0 = pack.default.P + assert P0.shape == (2, 2) + + # For small h, exp(-hA) ≈ I - hA; this is a mild sanity check (not a proof) + I = np.eye(2) + approx = I - h * A_fn(base) + # Only check magnitude consistency, not exact equality + assert np.linalg.norm(P0 - approx) < 1.0 + + + From 37909eb88b2ced7baeb8c13843f5788af6c2117a Mon Sep 17 00:00:00 2001 From: hzha0455 Date: Mon, 13 Oct 2025 20:48:32 +1100 Subject: [PATCH 2/2] Problem fixing --- odetoolbox/singularity_analysis_mitigation.py | 279 +++++++++++++----- odetoolbox/system_of_shapes.py | 170 ++++++----- tests/test_analysis_mitigation.py | 120 +++++--- 3 files changed, 387 insertions(+), 182 deletions(-) diff --git a/odetoolbox/singularity_analysis_mitigation.py b/odetoolbox/singularity_analysis_mitigation.py index 1e881cef..73eed39b 100644 --- a/odetoolbox/singularity_analysis_mitigation.py +++ b/odetoolbox/singularity_analysis_mitigation.py @@ -1,103 +1,250 @@ # odetoolbox/singularity_analysis_mitigation.py from __future__ import annotations -from dataclasses import dataclass -from typing import Callable, Dict, Iterable, List, Tuple, Optional +from dataclasses import dataclass, asdict +from typing import Callable, Dict, Iterable, List, Tuple, Optional, Set, FrozenSet, Any +from itertools import product +import json + import numpy as np from scipy.linalg import expm -# ---- Type definitions ---- -Params = Dict[str, float] -Cond = Tuple[str, str] -AFunction = Callable[[Params], np.ndarray] +# Type definitions +Params = Dict[str, float] +Cond = Tuple[str, str] +AFunction = Callable[[Params], np.ndarray] +SymRenderer = Callable[[FrozenSet[Cond], str], Any] + + +# Config (lightweight; no external deps) +class Config: + """ + Minimal configuration shim to avoid if-else on timestep symbols. + You may inject your own symbol string via build_* API. This class + exists only to provide a default. + """ + @staticmethod + def default_timestep_symbol() -> str: + return "h" -# ---- Data structures ---- -@dataclass -class BranchResult: - """Represents the result of one branch: a condition and its corresponding propagator matrix.""" - condition: Optional[Cond] - P: np.ndarray + +# Data structures +@dataclass(frozen=True) +class CaseResult: + """ + One case under a set of active equality constraints. + - active_equalities: frozenset of normalized pairs (min(a,b), max(a,b)) + - P: optional numeric preview of the propagator (expm(-h*A)) + - symbolic: optional symbolic/IR object produced by a renderer (no sympy required) + """ + active_equalities: FrozenSet[Cond] + P: Optional[np.ndarray] = None + symbolic: Optional[Any] = None @dataclass -class BranchPack: - """Container for all branches, including the default one.""" - default: BranchResult - branches: List[BranchResult] +class CaseSet: + """ + Container of the default (no constraints) and all expanded cases. + """ + timestep_symbol: str + default: CaseResult + cases: List[CaseResult] + expanded: bool + truncated: bool + max_cases: Optional[int] -# ---- Utility functions ---- +# Utilities def _ensure_square(A: np.ndarray) -> None: - """Ensure that the matrix A is square.""" if A.ndim != 2 or A.shape[0] != A.shape[1]: raise ValueError("A must be a square 2D array") def _matrix_exp(A: np.ndarray, h: float) -> np.ndarray: - """Compute the matrix exponential P = exp(-h * A).""" _ensure_square(A) return expm(-h * A) -def _tie_params(base: Params, cond: Cond, mode: str = "left") -> Params: +def _normalize_conds(conditions: Iterable[Cond]) -> List[Cond]: """ - Force parameters in cond = (a, b) to become equal according to the specified mode. - - Modes: - - 'left' : overwrite b with a's value - - 'right' : overwrite a with b's value - - 'avg' : replace both with their average value + Normalize conditions into sorted, deduplicated pairs so that ("b","a") -> ("a","b"). """ - a, b = cond - if a not in base or b not in base: - raise KeyError(f"Unknown parameter in condition: {cond}") - out = dict(base) - if mode == "left": - out[b] = out[a] - elif mode == "right": - out[a] = out[b] - elif mode == "avg": - v = 0.5 * (out[a] + out[b]) - out[a] = out[b] = v - else: - raise ValueError("mode must be one of: 'left', 'right', or 'avg'") + norm = sorted({tuple(sorted((a, b))) for (a, b) in conditions}) + # ensure no self-equalities like ("a","a") + for a, b in norm: + if a == b: + raise ValueError(f"Invalid equality condition ({a},{b}): identical names are not allowed.") + return norm + +def _expand_truth_table( + conds: List[Cond], + include_empty: bool = False, + max_cases: Optional[int] = None +) -> List[FrozenSet[Cond]]: + """ + Return all truth assignments as the set of 'true' equalities. + If include_empty is False, the empty set is skipped. + """ + if not conds: + return [frozenset()] if include_empty else [] + + true_sets: List[FrozenSet[Cond]] = [] + for bits in product([False, True], repeat=len(conds)): + active = frozenset(conds[i] for i, flag in enumerate(bits) if flag) + if not include_empty and len(active) == 0: + continue + true_sets.append(active) + if max_cases is not None and len(true_sets) >= max_cases: + break + return true_sets + +def _apply_numeric_ties(base_params: Params, true_eqs: FrozenSet[Cond], tie_mode: str = "left") -> Params: + """ + Numeric-layer tie for a set of equalities: + - 'left' : set right := left (b := a) + - 'right' : set left := right (a := b) + """ + out = dict(base_params) + for (a, b) in sorted(true_eqs): + if a not in out or b not in out: + raise KeyError(f"Missing parameter for tie: {a} or {b}") + if tie_mode == "left": + out[b] = out[a] + elif tie_mode == "right": + out[a] = out[b] + else: + raise ValueError("tie_mode must be 'left' or 'right'") return out +def _infer_all_pairwise(param_names: Iterable[str]) -> List[Cond]: + names = list(param_names) + conds: List[Cond] = [] + for i in range(len(names)): + for j in range(i + 1, len(names)): + a, b = names[i], names[j] + if a != b: + conds.append(tuple(sorted((a, b)))) + return sorted(set(conds)) + -# ---- Main process---- -def build_transition_branches_numeric( +# Main API +def build_transition_cases( A_fn: AFunction, h: float, base_params: Params, - param_names: Iterable[str], + *, + # condition sources conditions: Optional[Iterable[Cond]] = None, + param_names: Optional[Iterable[str]] = None, + # expansion behaviour + expand_truth_table: bool = True, + max_cases: Optional[int] = 4096, + # numeric preview behaviour + numeric_preview: bool = True, tie_mode: str = "left", -) -> BranchPack: + # symbolic behaviour + timestep_symbol: Optional[str] = None, + symbolic_renderer: Optional[SymRenderer] = None, +) -> CaseSet: """ - Construct default and conditional transition branches numerically. - - Steps: - 1) Given a system x' = A x, build A using A_fn(base_params). - 2) Compute the default propagator P0 = exp(-h * A). - 3) If no conditions are provided, generate all pairwise equality candidates. - 4) For each condition, make the corresponding parameters equal, - rebuild A' = A_fn(params'), and compute P' = exp(-h * A'). - 5) Collect all branches into a BranchPack object. + Construct the default case plus all expanded cases from + the truth-table of provided conditions. + + - No sympy reliance. If you need a symbolic object, supply `symbolic_renderer`. + - Numeric preview uses expm(-h*A(params_tied)) for each case. + - Conditions may be explicitly provided. If omitted, and `param_names` is given, + we infer all pairwise equalities. + + Returns: + CaseSet with: + - default: active_equalities = empty set + - cases: list of CaseResult with active_equalities carrying the constraints + - expanded/truncated/max_cases metadata """ - # Step 1–2: default branch + step_sym = timestep_symbol or Config.default_timestep_symbol() + + #default case A0 = np.asarray(A_fn(dict(base_params)), dtype=float) - P0 = _matrix_exp(A0, h) - default = BranchResult(condition=None, P=P0) + P0 = _matrix_exp(A0, h) if numeric_preview else None + default = CaseResult(active_equalities=frozenset(), P=P0, + symbolic=(symbolic_renderer(frozenset(), step_sym) if symbolic_renderer else None)) - # Step 3: generate pairwise equality conditions if none provided - names = list(param_names) + #canonicalize conditions if conditions is None: - conditions = [(names[i], names[j]) for i in range(len(names)) for j in range(i + 1, len(names))] + if not param_names: + # no conditions at all + return CaseSet( + timestep_symbol=step_sym, + default=default, + cases=[], + expanded=False, + truncated=False, + max_cases=None, + ) + else: + conditions = _infer_all_pairwise(param_names) + + conds = _normalize_conds(conditions) + + # expansion list + if not expand_truth_table: + # one-at-a-time activation + actives = [frozenset([c]) for c in conds] + truncated = False + else: + actives = _expand_truth_table(conds, include_empty=False, max_cases=max_cases) + truncated = (max_cases is not None and len(actives) >= max_cases) + + #build cases + cases: List[CaseResult] = [] + for active_eqs in actives: + params_ = _apply_numeric_ties(base_params, active_eqs, tie_mode=tie_mode) if numeric_preview else base_params + A_ = np.asarray(A_fn(params_), dtype=float) if numeric_preview else None + P_ = _matrix_exp(A_, h) if (numeric_preview and A_ is not None) else None + sym_obj = symbolic_renderer(active_eqs, step_sym) if symbolic_renderer else None + cases.append(CaseResult(active_equalities=active_eqs, P=P_, symbolic=sym_obj)) + + return CaseSet( + timestep_symbol=step_sym, + default=default, + cases=cases, + expanded=True, + truncated=truncated, + max_cases=max_cases if expand_truth_table else None, + ) + +# JSON encoding helpers +def _ndarray_to_list(arr: Optional[np.ndarray]) -> Optional[List[List[float]]]: + if arr is None: + return None + return arr.tolist() + +def case_set_to_json(case_set: CaseSet) -> str: + payload = { + "timestep_symbol": case_set.timestep_symbol, + "expanded": case_set.expanded, + "truncated": case_set.truncated, + "max_cases": case_set.max_cases, + "default": { + "active_equalities": sorted(list(case_set.default.active_equalities)), + "P": _ndarray_to_list(case_set.default.P), + "symbolic": case_set.default.symbolic if _is_jsonable(case_set.default.symbolic) else None, + }, + "cases": [ + { + "active_equalities": sorted(list(cr.active_equalities)), + "P": _ndarray_to_list(cr.P), + "symbolic": cr.symbolic if _is_jsonable(cr.symbolic) else None, + } + for cr in case_set.cases + ], + } + return json.dumps(payload, ensure_ascii=False, separators=(",", ":")) + +def _is_jsonable(x: Any) -> bool: + try: + json.dumps(x) + return True + except Exception: + return False - # Step 4–5: compute each conditional branch - branches: List[BranchResult] = [] - for cond in conditions: - params_ = _tie_params(base_params, cond, mode=tie_mode) - A_ = np.asarray(A_fn(params_), dtype=float) - P_ = _matrix_exp(A_, h) - branches.append(BranchResult(condition=tuple(sorted(cond)), P=P_)) - return BranchPack(default=default, branches=branches) diff --git a/odetoolbox/system_of_shapes.py b/odetoolbox/system_of_shapes.py index c3f079fe..24521d1f 100644 --- a/odetoolbox/system_of_shapes.py +++ b/odetoolbox/system_of_shapes.py @@ -34,8 +34,8 @@ from .singularity_detection import SingularityDetection, SingularityDetectionException from .sympy_helpers import _custom_simplify_expr, _is_zero from odetoolbox.singularity_analysis_mitigation import ( - build_transition_branches_numeric as _build_branches_numeric, - BranchResult as _BR, + build_transition_cases as _build_cases, + CaseResult as _CR, ) @@ -246,7 +246,7 @@ def generate_propagator_solver(self, disable_singularity_detection: bool = False conditions = SingularityDetection.find_propagator_singularities(P, self.A_) if conditions: - # if there is one or more condition under which the solution goes to infinity... + # if there is one or more condition under which the solution goes to infinity logging.warning("Under certain conditions, the propagator matrix is singular (contains infinities).") logging.warning("List of all conditions that result in a division by zero:") @@ -277,7 +277,7 @@ def generate_propagator_solver(self, disable_singularity_detection: bool = False P_sym[row, col] = sympy.parsing.sympy_parser.parse_expr(sym_str, global_dict=Shape._sympy_globals) P_expr[sym_str] = P[row, col] if row != col and not _is_zero(self.b_[col]): - # the ODE for x_[row] depends on the inhomogeneous ODE of x_[col]. We can't solve this analytically in the general case (even though some specific cases might admit a solution) + # the ODE for x_[row] depends on the inhomogeneous ODE of x_[col]. We can't solve this analytically in the general case raise PropagatorGenerationException("the ODE for " + str(self.x_[row]) + " depends on the inhomogeneous ODE of " + str(self.x_[col]) + ". We can't solve this analytically in the general case (even though some specific cases might admit a solution)") update_expr_terms.append(sym_str + " * " + str(self.x_[col])) @@ -320,90 +320,116 @@ def generate_propagator_solver(self, disable_singularity_detection: bool = False "state_variables": all_state_symbols, "initial_values": initial_values} -#------ +#------numeric case expansion without sympy ------ try: - # 1) Re-run singularity detection - _conds_pairs = [] + # 1) Re-run singularity detection and extract parameter equalities + cond_pairs: List[Tuple[str, str]] = [] try: _conds = SingularityDetection.find_propagator_singularities(P, self.A_) - # Flatten sets like {a=b, c=d} into a list of tuples: [("a","b"), ("c","d"), ...] + # _conds may be a list of sets, each containing equality objects like "a=b" for _set in _conds or []: for _eq in _set: - _L, _R = getattr(_eq, "lhs", None), getattr(_eq, "rhs", None) - if isinstance(_L, sympy.Symbol) and isinstance(_R, sympy.Symbol): - _conds_pairs.append((str(_L), str(_R))) + L, R = getattr(_eq, "lhs", None), getattr(_eq, "rhs", None) + # Convert to strings, do not rely on sympy types + if L is not None and R is not None: + a, b = str(L), str(R) + if a != b: + cond_pairs.append(tuple(sorted((a, b)))) except Exception: - _conds_pairs = [] - - # If no detectable parameter equality conditions are found, skip branching - if _conds_pairs: - # 2) Collect parameter symbols - _Hsym = sympy.Symbol(Config().output_timestep_symbol, real=True) - _xset = set(self.x_) if hasattr(self, "x_") else set() - _A_syms = sorted( - [s for s in self.A_.free_symbols if s not in _xset and s != _Hsym], - key=lambda s: str(s) - ) - _param_names = [str(s) for s in _A_syms] - - # 3) Retrieve base parameter values - def _get_param_val(_s: sympy.Symbol) -> float: - _k = str(_s) - if hasattr(self, "parameter_values") and _k in getattr(self, "parameter_values"): - return float(self.parameter_values[_k]) - if hasattr(self, "_params") and _k in getattr(self, "_params"): - return float(self._params[_k]) - # Missing parameter: raise an error to skip mitigation - raise KeyError(f"Missing numeric value for parameter: {_k}") - - _base_params = {str(s): _get_param_val(s) for s in _A_syms} - - # 4) Construct numerical function A - _A_lmb = sympy.lambdify(_A_syms, self.A_, modules="numpy") - - def _A_fn(_params_dict): - _vals = [_params_dict[name] for name in _param_names] - return np.array(_A_lmb(*_vals), dtype=float) - - # 5) Determine timestep h: prefer self.h / self.dt; otherwise use H symbol or fallback to 1.0 - if hasattr(self, "h"): - _h_val = float(self.h) - elif hasattr(self, "dt"): - _h_val = float(self.dt) + cond_pairs = [] + + # 2) Skip if no parameter equalities are found + if cond_pairs: + # 2.1 Collect parameter names and values + + H_name = Config().output_timestep_symbol + state_names = set(str(s) for s in getattr(self, "x_", [])) + base_param_src = getattr(self, "parameter_values", {}) or {} + base_params = {} + for k, v in base_param_src.items(): + if k in state_names or k == H_name: + continue + try: + base_params[k] = float(v) + except Exception: + # Ignore non-numeric values for numeric preview + pass + + param_names = sorted(base_params.keys()) + + # 2.2 Get timestep value + if H_name in base_param_src: + try: + h_val = float(base_param_src[H_name]) + except Exception as e: + raise KeyError(f"Invalid numeric timestep '{H_name}': {base_param_src[H_name]}") from e else: - _h_val = float(_base_params.get(str(_Hsym), 1.0)) - - # 6) Build numeric transition branches - _pack = _build_branches_numeric( - A_fn=_A_fn, - h=_h_val, - base_params=_base_params, - param_names=_param_names, - conditions=[tuple(sorted(p)) for p in _conds_pairs], - tie_mode="left", # Can also be "right" or "avg" - ) + # If no timestep is provided, we will only output conditions/metadata + h_val = None + + # 2.3 Define A_fn for numeric preview if available + def _missing_A_fn(_): + raise NotImplementedError("Numeric A(params) is not available.") - # 7) Insert branch results into solver_dict - def _to_list(_P): - # Convert NumPy array to list for JSON serialization - return _P.tolist() + A_fn = None + for cand in ("A_numeric_fn", "build_A_numeric", "A_fn_numeric"): + if hasattr(self, cand) and callable(getattr(self, cand)): + A_fn = getattr(self, cand) + break - def _pack_branch(_br: _BR): - _cond = None if _br.condition is None else {"eq": [_br.condition[0], _br.condition[1]]} - return {"condition": _cond, "P": _to_list(_br.P)} + numeric_preview_enabled = (A_fn is not None and h_val is not None) + + # 2.4 Use unified truth-table expansion + from odetoolbox.singularity_analysis_mitigation import ( + build_transition_cases, + ) - solver_dict["branching"] = { - "default": _pack_branch(_pack.default), - "branches": [_pack_branch(_b) for _b in _pack.branches], - "tie_mode": "left", + case_set = build_transition_cases( + A_fn=A_fn if numeric_preview_enabled else (lambda _: np.zeros_like(P, dtype=float)), + h=h_val if numeric_preview_enabled else 0.0, + base_params=base_params, + conditions=[tuple(sorted(p)) for p in cond_pairs], + param_names=param_names, + expand_truth_table=True, + max_cases=4096, + numeric_preview=bool(numeric_preview_enabled), + tie_mode="left", + timestep_symbol=H_name, + symbolic_renderer=None, + ) + + # 2.5 Pack results into solver_dict + def _pack_case(cr): + return { + "active_equalities": sorted(list(cr.active_equalities)), + "P": (None if cr.P is None else cr.P.tolist()), + } + + solver_dict.setdefault("meta", {}) + solver_dict["meta"]["timestep_symbol"] = H_name + if h_val is not None: + solver_dict["meta"]["timestep_value"] = h_val + + solver_dict["condition_cases"] = { + "default_case": _pack_case(case_set.default), + "cases": [_pack_case(cr) for cr in case_set.cases], + "meta": { + "evaluation": "numeric" if numeric_preview_enabled else "structure-only", + "tie_mode": "left", + "expanded": case_set.expanded, + "truncated": case_set.truncated, + "max_cases": case_set.max_cases, + "truth_table_size": len(case_set.cases), + } } + except Exception as _mit_err: - # Mitigation failure does not affect the original return - logging.debug(f"[mitigation] numeric branching skipped: {_mit_err}") + logging.debug(f"[mitigation] case expansion skipped: {_mit_err}") return solver_dict + def generate_numeric_solver(self, state_variables=None): r""" Generate the symbolic expressions for numeric integration state updates; return as JSON. diff --git a/tests/test_analysis_mitigation.py b/tests/test_analysis_mitigation.py index 5ffb0285..c0ce9741 100644 --- a/tests/test_analysis_mitigation.py +++ b/tests/test_analysis_mitigation.py @@ -5,12 +5,46 @@ from numpy.testing import assert_allclose from odetoolbox.singularity_analysis_mitigation import ( - build_transition_branches_numeric, - BranchPack, + build_transition_cases, + CaseSet, ) +# ----------------- helpers ----------------- -# Utilities (numeric version) +def _call_build_transition_cases(**kwargs): + """ + Compatibility shim: + - If build_transition_cases returns a CaseSet, wrap it as (CaseSet, None) + - If it already returns (CaseSet, assignments), pass through. + """ + out = build_transition_cases(**kwargs) + if isinstance(out, CaseSet): + return out, None + # tuple or other iterable with CaseSet first + try: + cs, assignments = out + if isinstance(cs, CaseSet): + return cs, assignments + except Exception: + pass + raise TypeError("build_transition_cases must return CaseSet or (CaseSet, assignments)") + +def _get_condition_tuple(case) -> tuple[str, str] | None: + """ + Return ('a','b') if exactly one equality is active; otherwise None. + Works with either: + - case.condition (if provided by implementation), or + - case.active_equalities (frozenset of pairs) + """ + if hasattr(case, "condition") and case.condition is not None: + return tuple(case.condition) + if hasattr(case, "active_equalities"): + eqs = getattr(case, "active_equalities") + if isinstance(eqs, (set, frozenset)) and len(eqs) == 1: + return tuple(next(iter(eqs))) + return None + +# ----------------- utilities ----------------- def jordan_expected_numeric(alpha: float, h: float) -> np.ndarray: """ @@ -31,20 +65,14 @@ def jordan_expected_numeric(alpha: float, h: float) -> np.ndarray: [h, 1.0, 0.0], [0.5*h*h, h, 1.0]], dtype=float) - TOL = dict(rtol=1e-10, atol=1e-12) +# ----------------- tests ----------------- - -# 1) Direct verification: when a = b, the branch propagator - -def test_numeric_branch_yields_jordan(): - # Equivalent system to the original symbolic test, but defined as a numeric function A_fn. - # Original: x' = -a x, y' = x - a y, z' = y - b z - # Numeric version: A = [[ a, 0, 0], - # [-1, a, 0], - # [ 0, -1, b]] - # This uses P = exp(-h*A), equivalent in meaning to the symbolic exp(-hA) form. +def test_numeric_case_yields_jordan(): + """ + Check that when a=b, the propagated matrix matches the analytical Jordan form. + """ def A_fn(p): a, b = p["a"], p["b"] return np.array([[a, 0.0, 0.0], @@ -54,32 +82,34 @@ def A_fn(p): h = 0.2 base = {"a": 1.3, "b": 2.1} - pack: BranchPack = build_transition_branches_numeric( + case_set, _ = _call_build_transition_cases( A_fn=A_fn, h=h, base_params=base, param_names=["a", "b"], - conditions=[("a", "b")], # Only test the condition a = b - tie_mode="left", # Overwrite b with a’s value + conditions=[("a", "b")], + tie_mode="left", + expand_truth_table=False, ) - # Extract the a=b branch - assert len(pack.branches) == 1 - br = pack.branches[0] - assert tuple(br.condition) == ("a", "b") + # Extract the a=b case + assert len(case_set.cases) == 1 + case = case_set.cases[0] + assert _get_condition_tuple(case) == ("a", "b") # Expected Jordan-form propagator for a=b=base["a"] alpha = base["a"] P_expected = jordan_expected_numeric(alpha, h) # Numerical equivalence check - assert_allclose(br.P, P_expected, **TOL) - - + assert_allclose(case.P, P_expected, **TOL) -# 2) Auto-generated “pairwise equality” condition test -def test_numeric_branch_auto_conditions(): +def test_numeric_case_auto_conditions(): + """ + Verify that the automatically generated pairwise equality conditions + include ('t1', 't2') and that its sub-block matches the 2×2 Jordan form. + """ def A_fn(p): # Simple 3×3 upper-triangular chain structure t1, t2, t3 = p["t1"], p["t2"], p["t3"] @@ -88,33 +118,35 @@ def A_fn(p): [0.0, -1.0, t3]], dtype=float) h = 0.05 - base = {"t1": 1.0, "t2": 1.0, "t3": 2.0} + base = {"t1": 1.0, "t2": 1.0, "t3": 2.0} - pack = build_transition_branches_numeric( + case_set, _ = _call_build_transition_cases( A_fn=A_fn, h=h, base_params=base, param_names=["t1", "t2", "t3"], - conditions=None, + conditions=None, tie_mode="left", + expand_truth_table=False, ) # The automatically generated condition list should include ('t1', 't2') - conds = [tuple(c.condition) for c in pack.branches] + conds = [_get_condition_tuple(c) for c in case_set.cases] assert ("t1", "t2") in conds - # Find the ('t1','t2') branch and verify that its top-left 2×2 block + # Find the ('t1','t2') case and verify that its top-left 2×2 block # matches the analytical 2×2 Jordan block for parameter alpha = t1. - br = next(c for c in pack.branches if tuple(c.condition) == ("t1", "t2")) + c = next(c for c in case_set.cases if _get_condition_tuple(c) == ("t1", "t2")) alpha = base["t1"] P_expected_2x2 = np.exp(-h*alpha) * np.array([[1.0, 0.0], [h, 1.0]]) - assert_allclose(br.P[:2, :2], P_expected_2x2, **TOL) - + assert_allclose(c.P[:2, :2], P_expected_2x2, **TOL) -# 3) Basic robustness test: shape and near-identity behavior def test_numeric_default_shape_and_sanity(): + """ + For small h, exp(-hA) ≈ I - hA. This serves as a numerical sanity check. + """ def A_fn(p): return np.array([[p["a"], 1.0], [0.0, p["b"]]], dtype=float) @@ -122,19 +154,19 @@ def A_fn(p): h = 0.1 base = {"a": 1.0, "b": 1.0} - pack = build_transition_branches_numeric( - A_fn=A_fn, h=h, base_params=base, - param_names=["a", "b"], conditions=[("a","b")] + case_set, _ = _call_build_transition_cases( + A_fn=A_fn, + h=h, + base_params=base, + param_names=["a", "b"], + conditions=[("a", "b")], + expand_truth_table=False, ) - P0 = pack.default.P + P0 = case_set.default.P assert P0.shape == (2, 2) - # For small h, exp(-hA) ≈ I - hA; this is a mild sanity check (not a proof) + # For small h, exp(-hA) ≈ I - hA; mild consistency check I = np.eye(2) approx = I - h * A_fn(base) - # Only check magnitude consistency, not exact equality assert np.linalg.norm(P0 - approx) < 1.0 - - -