diff --git a/odetoolbox/singularity_analysis_mitigation.py b/odetoolbox/singularity_analysis_mitigation.py new file mode 100644 index 00000000..73eed39b --- /dev/null +++ b/odetoolbox/singularity_analysis_mitigation.py @@ -0,0 +1,250 @@ +# odetoolbox/singularity_analysis_mitigation.py +from __future__ import annotations +from dataclasses import dataclass, asdict +from typing import Callable, Dict, Iterable, List, Tuple, Optional, Set, FrozenSet, Any +from itertools import product +import json + +import numpy as np +from scipy.linalg import expm + + +# Type definitions +Params = Dict[str, float] +Cond = Tuple[str, str] +AFunction = Callable[[Params], np.ndarray] +SymRenderer = Callable[[FrozenSet[Cond], str], Any] + + +# Config (lightweight; no external deps) +class Config: + """ + Minimal configuration shim to avoid if-else on timestep symbols. + You may inject your own symbol string via build_* API. This class + exists only to provide a default. + """ + @staticmethod + def default_timestep_symbol() -> str: + return "h" + + +# Data structures +@dataclass(frozen=True) +class CaseResult: + """ + One case under a set of active equality constraints. + - active_equalities: frozenset of normalized pairs (min(a,b), max(a,b)) + - P: optional numeric preview of the propagator (expm(-h*A)) + - symbolic: optional symbolic/IR object produced by a renderer (no sympy required) + """ + active_equalities: FrozenSet[Cond] + P: Optional[np.ndarray] = None + symbolic: Optional[Any] = None + +@dataclass +class CaseSet: + """ + Container of the default (no constraints) and all expanded cases. + """ + timestep_symbol: str + default: CaseResult + cases: List[CaseResult] + expanded: bool + truncated: bool + max_cases: Optional[int] + + +# Utilities +def _ensure_square(A: np.ndarray) -> None: + if A.ndim != 2 or A.shape[0] != A.shape[1]: + raise ValueError("A must be a square 2D array") + +def _matrix_exp(A: np.ndarray, h: float) -> np.ndarray: + _ensure_square(A) + return expm(-h * A) + +def _normalize_conds(conditions: Iterable[Cond]) -> List[Cond]: + """ + Normalize conditions into sorted, deduplicated pairs so that ("b","a") -> ("a","b"). + """ + norm = sorted({tuple(sorted((a, b))) for (a, b) in conditions}) + # ensure no self-equalities like ("a","a") + for a, b in norm: + if a == b: + raise ValueError(f"Invalid equality condition ({a},{b}): identical names are not allowed.") + return norm + +def _expand_truth_table( + conds: List[Cond], + include_empty: bool = False, + max_cases: Optional[int] = None +) -> List[FrozenSet[Cond]]: + """ + Return all truth assignments as the set of 'true' equalities. + If include_empty is False, the empty set is skipped. + """ + if not conds: + return [frozenset()] if include_empty else [] + + true_sets: List[FrozenSet[Cond]] = [] + for bits in product([False, True], repeat=len(conds)): + active = frozenset(conds[i] for i, flag in enumerate(bits) if flag) + if not include_empty and len(active) == 0: + continue + true_sets.append(active) + if max_cases is not None and len(true_sets) >= max_cases: + break + return true_sets + +def _apply_numeric_ties(base_params: Params, true_eqs: FrozenSet[Cond], tie_mode: str = "left") -> Params: + """ + Numeric-layer tie for a set of equalities: + - 'left' : set right := left (b := a) + - 'right' : set left := right (a := b) + """ + out = dict(base_params) + for (a, b) in sorted(true_eqs): + if a not in out or b not in out: + raise KeyError(f"Missing parameter for tie: {a} or {b}") + if tie_mode == "left": + out[b] = out[a] + elif tie_mode == "right": + out[a] = out[b] + else: + raise ValueError("tie_mode must be 'left' or 'right'") + return out + +def _infer_all_pairwise(param_names: Iterable[str]) -> List[Cond]: + names = list(param_names) + conds: List[Cond] = [] + for i in range(len(names)): + for j in range(i + 1, len(names)): + a, b = names[i], names[j] + if a != b: + conds.append(tuple(sorted((a, b)))) + return sorted(set(conds)) + + +# Main API +def build_transition_cases( + A_fn: AFunction, + h: float, + base_params: Params, + *, + # condition sources + conditions: Optional[Iterable[Cond]] = None, + param_names: Optional[Iterable[str]] = None, + # expansion behaviour + expand_truth_table: bool = True, + max_cases: Optional[int] = 4096, + # numeric preview behaviour + numeric_preview: bool = True, + tie_mode: str = "left", + # symbolic behaviour + timestep_symbol: Optional[str] = None, + symbolic_renderer: Optional[SymRenderer] = None, +) -> CaseSet: + """ + Construct the default case plus all expanded cases from + the truth-table of provided conditions. + + - No sympy reliance. If you need a symbolic object, supply `symbolic_renderer`. + - Numeric preview uses expm(-h*A(params_tied)) for each case. + - Conditions may be explicitly provided. If omitted, and `param_names` is given, + we infer all pairwise equalities. + + Returns: + CaseSet with: + - default: active_equalities = empty set + - cases: list of CaseResult with active_equalities carrying the constraints + - expanded/truncated/max_cases metadata + """ + step_sym = timestep_symbol or Config.default_timestep_symbol() + + #default case + A0 = np.asarray(A_fn(dict(base_params)), dtype=float) + P0 = _matrix_exp(A0, h) if numeric_preview else None + default = CaseResult(active_equalities=frozenset(), P=P0, + symbolic=(symbolic_renderer(frozenset(), step_sym) if symbolic_renderer else None)) + + #canonicalize conditions + if conditions is None: + if not param_names: + # no conditions at all + return CaseSet( + timestep_symbol=step_sym, + default=default, + cases=[], + expanded=False, + truncated=False, + max_cases=None, + ) + else: + conditions = _infer_all_pairwise(param_names) + + conds = _normalize_conds(conditions) + + # expansion list + if not expand_truth_table: + # one-at-a-time activation + actives = [frozenset([c]) for c in conds] + truncated = False + else: + actives = _expand_truth_table(conds, include_empty=False, max_cases=max_cases) + truncated = (max_cases is not None and len(actives) >= max_cases) + + #build cases + cases: List[CaseResult] = [] + for active_eqs in actives: + params_ = _apply_numeric_ties(base_params, active_eqs, tie_mode=tie_mode) if numeric_preview else base_params + A_ = np.asarray(A_fn(params_), dtype=float) if numeric_preview else None + P_ = _matrix_exp(A_, h) if (numeric_preview and A_ is not None) else None + sym_obj = symbolic_renderer(active_eqs, step_sym) if symbolic_renderer else None + cases.append(CaseResult(active_equalities=active_eqs, P=P_, symbolic=sym_obj)) + + return CaseSet( + timestep_symbol=step_sym, + default=default, + cases=cases, + expanded=True, + truncated=truncated, + max_cases=max_cases if expand_truth_table else None, + ) + +# JSON encoding helpers +def _ndarray_to_list(arr: Optional[np.ndarray]) -> Optional[List[List[float]]]: + if arr is None: + return None + return arr.tolist() + +def case_set_to_json(case_set: CaseSet) -> str: + payload = { + "timestep_symbol": case_set.timestep_symbol, + "expanded": case_set.expanded, + "truncated": case_set.truncated, + "max_cases": case_set.max_cases, + "default": { + "active_equalities": sorted(list(case_set.default.active_equalities)), + "P": _ndarray_to_list(case_set.default.P), + "symbolic": case_set.default.symbolic if _is_jsonable(case_set.default.symbolic) else None, + }, + "cases": [ + { + "active_equalities": sorted(list(cr.active_equalities)), + "P": _ndarray_to_list(cr.P), + "symbolic": cr.symbolic if _is_jsonable(cr.symbolic) else None, + } + for cr in case_set.cases + ], + } + return json.dumps(payload, ensure_ascii=False, separators=(",", ":")) + +def _is_jsonable(x: Any) -> bool: + try: + json.dumps(x) + return True + except Exception: + return False + + + diff --git a/odetoolbox/system_of_shapes.py b/odetoolbox/system_of_shapes.py index 8e970e29..24521d1f 100644 --- a/odetoolbox/system_of_shapes.py +++ b/odetoolbox/system_of_shapes.py @@ -33,6 +33,11 @@ from .shapes import Shape from .singularity_detection import SingularityDetection, SingularityDetectionException from .sympy_helpers import _custom_simplify_expr, _is_zero +from odetoolbox.singularity_analysis_mitigation import ( + build_transition_cases as _build_cases, + CaseResult as _CR, +) + class GetBlockDiagonalException(Exception): @@ -233,16 +238,15 @@ def generate_propagator_solver(self, disable_singularity_detection: bool = False P = self._generate_propagator_matrix(self.A_) - # - # singularity detection - # + + if not disable_singularity_detection: try: conditions = SingularityDetection.find_propagator_singularities(P, self.A_) if conditions: - # if there is one or more condition under which the solution goes to infinity... + # if there is one or more condition under which the solution goes to infinity logging.warning("Under certain conditions, the propagator matrix is singular (contains infinities).") logging.warning("List of all conditions that result in a division by zero:") @@ -273,7 +277,7 @@ def generate_propagator_solver(self, disable_singularity_detection: bool = False P_sym[row, col] = sympy.parsing.sympy_parser.parse_expr(sym_str, global_dict=Shape._sympy_globals) P_expr[sym_str] = P[row, col] if row != col and not _is_zero(self.b_[col]): - # the ODE for x_[row] depends on the inhomogeneous ODE of x_[col]. We can't solve this analytically in the general case (even though some specific cases might admit a solution) + # the ODE for x_[row] depends on the inhomogeneous ODE of x_[col]. We can't solve this analytically in the general case raise PropagatorGenerationException("the ODE for " + str(self.x_[row]) + " depends on the inhomogeneous ODE of " + str(self.x_[col]) + ". We can't solve this analytically in the general case (even though some specific cases might admit a solution)") update_expr_terms.append(sym_str + " * " + str(self.x_[col])) @@ -316,9 +320,116 @@ def generate_propagator_solver(self, disable_singularity_detection: bool = False "state_variables": all_state_symbols, "initial_values": initial_values} +#------numeric case expansion without sympy ------ + try: + # 1) Re-run singularity detection and extract parameter equalities + cond_pairs: List[Tuple[str, str]] = [] + try: + _conds = SingularityDetection.find_propagator_singularities(P, self.A_) + # _conds may be a list of sets, each containing equality objects like "a=b" + for _set in _conds or []: + for _eq in _set: + L, R = getattr(_eq, "lhs", None), getattr(_eq, "rhs", None) + # Convert to strings, do not rely on sympy types + if L is not None and R is not None: + a, b = str(L), str(R) + if a != b: + cond_pairs.append(tuple(sorted((a, b)))) + except Exception: + cond_pairs = [] + + # 2) Skip if no parameter equalities are found + if cond_pairs: + # 2.1 Collect parameter names and values + + H_name = Config().output_timestep_symbol + state_names = set(str(s) for s in getattr(self, "x_", [])) + base_param_src = getattr(self, "parameter_values", {}) or {} + base_params = {} + for k, v in base_param_src.items(): + if k in state_names or k == H_name: + continue + try: + base_params[k] = float(v) + except Exception: + # Ignore non-numeric values for numeric preview + pass + + param_names = sorted(base_params.keys()) + + # 2.2 Get timestep value + if H_name in base_param_src: + try: + h_val = float(base_param_src[H_name]) + except Exception as e: + raise KeyError(f"Invalid numeric timestep '{H_name}': {base_param_src[H_name]}") from e + else: + # If no timestep is provided, we will only output conditions/metadata + h_val = None + + # 2.3 Define A_fn for numeric preview if available + def _missing_A_fn(_): + raise NotImplementedError("Numeric A(params) is not available.") + + A_fn = None + for cand in ("A_numeric_fn", "build_A_numeric", "A_fn_numeric"): + if hasattr(self, cand) and callable(getattr(self, cand)): + A_fn = getattr(self, cand) + break + + numeric_preview_enabled = (A_fn is not None and h_val is not None) + + # 2.4 Use unified truth-table expansion + from odetoolbox.singularity_analysis_mitigation import ( + build_transition_cases, + ) + + case_set = build_transition_cases( + A_fn=A_fn if numeric_preview_enabled else (lambda _: np.zeros_like(P, dtype=float)), + h=h_val if numeric_preview_enabled else 0.0, + base_params=base_params, + conditions=[tuple(sorted(p)) for p in cond_pairs], + param_names=param_names, + expand_truth_table=True, + max_cases=4096, + numeric_preview=bool(numeric_preview_enabled), + tie_mode="left", + timestep_symbol=H_name, + symbolic_renderer=None, + ) + + # 2.5 Pack results into solver_dict + def _pack_case(cr): + return { + "active_equalities": sorted(list(cr.active_equalities)), + "P": (None if cr.P is None else cr.P.tolist()), + } + + solver_dict.setdefault("meta", {}) + solver_dict["meta"]["timestep_symbol"] = H_name + if h_val is not None: + solver_dict["meta"]["timestep_value"] = h_val + + solver_dict["condition_cases"] = { + "default_case": _pack_case(case_set.default), + "cases": [_pack_case(cr) for cr in case_set.cases], + "meta": { + "evaluation": "numeric" if numeric_preview_enabled else "structure-only", + "tie_mode": "left", + "expanded": case_set.expanded, + "truncated": case_set.truncated, + "max_cases": case_set.max_cases, + "truth_table_size": len(case_set.cases), + } + } + + except Exception as _mit_err: + logging.debug(f"[mitigation] case expansion skipped: {_mit_err}") + return solver_dict + def generate_numeric_solver(self, state_variables=None): r""" Generate the symbolic expressions for numeric integration state updates; return as JSON. @@ -451,4 +562,4 @@ def from_shapes(cls, shapes: List[Shape], parameters=None): i += shape.order - return SystemOfShapes(x, A, b, c, shapes) + return SystemOfShapes(x, A, b, c, shapes) \ No newline at end of file diff --git a/tests/test_analysis_mitigation.py b/tests/test_analysis_mitigation.py new file mode 100644 index 00000000..c0ce9741 --- /dev/null +++ b/tests/test_analysis_mitigation.py @@ -0,0 +1,172 @@ +# tests/test_integration_branch_flow_numeric.py +from __future__ import annotations +import numpy as np +import pytest +from numpy.testing import assert_allclose + +from odetoolbox.singularity_analysis_mitigation import ( + build_transition_cases, + CaseSet, +) + +# ----------------- helpers ----------------- + +def _call_build_transition_cases(**kwargs): + """ + Compatibility shim: + - If build_transition_cases returns a CaseSet, wrap it as (CaseSet, None) + - If it already returns (CaseSet, assignments), pass through. + """ + out = build_transition_cases(**kwargs) + if isinstance(out, CaseSet): + return out, None + # tuple or other iterable with CaseSet first + try: + cs, assignments = out + if isinstance(cs, CaseSet): + return cs, assignments + except Exception: + pass + raise TypeError("build_transition_cases must return CaseSet or (CaseSet, assignments)") + +def _get_condition_tuple(case) -> tuple[str, str] | None: + """ + Return ('a','b') if exactly one equality is active; otherwise None. + Works with either: + - case.condition (if provided by implementation), or + - case.active_equalities (frozenset of pairs) + """ + if hasattr(case, "condition") and case.condition is not None: + return tuple(case.condition) + if hasattr(case, "active_equalities"): + eqs = getattr(case, "active_equalities") + if isinstance(eqs, (set, frozenset)) and len(eqs) == 1: + return tuple(next(iter(eqs))) + return None + +# ----------------- utilities ----------------- + +def jordan_expected_numeric(alpha: float, h: float) -> np.ndarray: + """ + Expected Jordan-form propagator for the degenerate case a = b. + + System: + A = [[ alpha, 0, 0], + [ 1 , alpha, 0], + [ 0 , 1 , alpha]] + + Theoretical propagator: + P = exp(-h*A) = exp(-h*alpha) * [[1, 0, 0], + [h, 1, 0], + [h*h/2, h, 1]] + """ + E = np.exp(-h * alpha) + return E * np.array([[1.0, 0.0, 0.0], + [h, 1.0, 0.0], + [0.5*h*h, h, 1.0]], dtype=float) + +TOL = dict(rtol=1e-10, atol=1e-12) + +# ----------------- tests ----------------- + +def test_numeric_case_yields_jordan(): + """ + Check that when a=b, the propagated matrix matches the analytical Jordan form. + """ + def A_fn(p): + a, b = p["a"], p["b"] + return np.array([[a, 0.0, 0.0], + [-1.0, a, 0.0], + [0.0, -1.0, b ]], dtype=float) + + h = 0.2 + base = {"a": 1.3, "b": 2.1} + + case_set, _ = _call_build_transition_cases( + A_fn=A_fn, + h=h, + base_params=base, + param_names=["a", "b"], + conditions=[("a", "b")], + tie_mode="left", + expand_truth_table=False, + ) + + # Extract the a=b case + assert len(case_set.cases) == 1 + case = case_set.cases[0] + assert _get_condition_tuple(case) == ("a", "b") + + # Expected Jordan-form propagator for a=b=base["a"] + alpha = base["a"] + P_expected = jordan_expected_numeric(alpha, h) + + # Numerical equivalence check + assert_allclose(case.P, P_expected, **TOL) + + +def test_numeric_case_auto_conditions(): + """ + Verify that the automatically generated pairwise equality conditions + include ('t1', 't2') and that its sub-block matches the 2×2 Jordan form. + """ + def A_fn(p): + # Simple 3×3 upper-triangular chain structure + t1, t2, t3 = p["t1"], p["t2"], p["t3"] + return np.array([[t1, 0.0, 0.0], + [-1.0, t2, 0.0], + [0.0, -1.0, t3]], dtype=float) + + h = 0.05 + base = {"t1": 1.0, "t2": 1.0, "t3": 2.0} + + case_set, _ = _call_build_transition_cases( + A_fn=A_fn, + h=h, + base_params=base, + param_names=["t1", "t2", "t3"], + conditions=None, + tie_mode="left", + expand_truth_table=False, + ) + + # The automatically generated condition list should include ('t1', 't2') + conds = [_get_condition_tuple(c) for c in case_set.cases] + assert ("t1", "t2") in conds + + # Find the ('t1','t2') case and verify that its top-left 2×2 block + # matches the analytical 2×2 Jordan block for parameter alpha = t1. + c = next(c for c in case_set.cases if _get_condition_tuple(c) == ("t1", "t2")) + alpha = base["t1"] + P_expected_2x2 = np.exp(-h*alpha) * np.array([[1.0, 0.0], + [h, 1.0]]) + assert_allclose(c.P[:2, :2], P_expected_2x2, **TOL) + + +def test_numeric_default_shape_and_sanity(): + """ + For small h, exp(-hA) ≈ I - hA. This serves as a numerical sanity check. + """ + def A_fn(p): + return np.array([[p["a"], 1.0], + [0.0, p["b"]]], dtype=float) + + h = 0.1 + base = {"a": 1.0, "b": 1.0} + + case_set, _ = _call_build_transition_cases( + A_fn=A_fn, + h=h, + base_params=base, + param_names=["a", "b"], + conditions=[("a", "b")], + expand_truth_table=False, + ) + + P0 = case_set.default.P + assert P0.shape == (2, 2) + + # For small h, exp(-hA) ≈ I - hA; mild consistency check + I = np.eye(2) + approx = I - h * A_fn(base) + assert np.linalg.norm(P0 - approx) < 1.0