Skip to content

Extend Array API compliance#9

Draft
cboulay wants to merge 2 commits intodevfrom
array_api
Draft

Extend Array API compliance#9
cboulay wants to merge 2 commits intodevfrom
array_api

Conversation

@cboulay
Copy link
Member

@cboulay cboulay commented Feb 7, 2026

Summary

Extends Array API compliance across ezmsg-learn, following the pattern already established in process/ssr.py. This enables the same code to run on NumPy, CuPy, PyTorch, and other Array API-compatible backends via array_api_compat.get_namespace().

Phase 1: Kalman Filter (model/refit_kalman.py, process/refit_kalman.py)

  • fit(), predict(), update(): All linear algebra now uses Array API (xp.linalg.inv, xp.linalg.matrix_transpose, xp.linalg.pinv, xp_create(xp.eye, ...)).
  • _compute_gain(): DARE solver remains NumPy (no Array API equivalent for scipy.linalg.solve_discrete_are); results are converted back to the source namespace via xp_asarray.
  • refit(): Per-sample mutation loop stays NumPy (small-vector np.linalg.norm, scalar indexing); final H/Q computation converted to xp.
  • _reset_state() / _process(): Derive xp/dev from message.data; output arrays created with xp_create.
  • Removed redundant .copy() calls (predict/update already return new arrays).

Phase 2: Incremental CCA (model/cca.py)

  • Replaced all scipy.linalg.inv(scipy.linalg.sqrtm(...)) calls (4 occurrences) with _inv_sqrtm_spd() — an eigendecomposition-based inverse square root using only Array API ops (eigh, clip, sqrt, @, matrix_transpose).
  • Removed scipy.linalg dependency entirely from this module.
  • np.linalg.norm -> xp.linalg.matrix_norm, np.clip(scalar) -> Python max(min(...)), .any() -> bool(xp.any(...)).
  • Added ref_array parameter to initialize() for namespace derivation.
  • New test file tests/unit/test_cca.py with 8 tests including numerical equivalence validation against scipy.

Phase 3: Tier 2 Processors

  • process/slda.py: np.moveaxis -> xp.permute_dims; NumPy boundary before sklearn.predict_proba.
  • process/adaptive_linear_regressor.py: np.any(np.isnan(...)) -> xp.any(xp.isnan(...)), np.moveaxis -> xp.permute_dims; NumPy boundary before sklearn/river calls.
  • dim_reduce/adaptive_decomp.py: np.prod -> math.prod, .reshape -> xp.reshape; NumPy boundary before sklearn partial_fit/transform, convert back to source namespace after transform.

Not changed

  • process/rnn.py (PyTorch-native)
  • process/sklearn.py (generic wrapper, unknown models)
  • process/linear_regressor.py, process/sgd.py (trivial: just np.isnan guard + sklearn call)

Test plan

  • All 19 existing Kalman filter tests pass
  • All 4 existing SLDA tests pass
  • All 4 existing adaptive linear regressor tests pass
  • All 8 existing adaptive decomp tests pass
  • 8 new CCA tests pass (including _inv_sqrtm_spd vs scipy.linalg.sqrtm numerical equivalence)
  • Full suite: 277 passed, 2 skipped (CUDA), 0 failures

@cboulay cboulay marked this pull request as draft February 7, 2026 06:46
@cboulay
Copy link
Member Author

cboulay commented Feb 7, 2026

Waiting to merge until a new ezmsg-sigproc release, which is waiting on a new ezmsg release.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant