diff --git a/dynestyx/__init__.py b/dynestyx/__init__.py index 84435882..4949c92b 100644 --- a/dynestyx/__init__.py +++ b/dynestyx/__init__.py @@ -5,7 +5,7 @@ __version__ = version("dynestyx") from dynestyx.discretizers import Discretizer, euler_maruyama -from dynestyx.handlers import plate, sample +from dynestyx.handlers import infer, plate, sample from dynestyx.inference.filters import Filter from dynestyx.inference.smoothers import Smoother from dynestyx.models import ( @@ -33,6 +33,7 @@ SDESimulator, Simulator, ) +from dynestyx.types import InferResult from dynestyx.utils import flatten_draws __all__ = [ @@ -56,6 +57,8 @@ "Filter", "Smoother", "flatten_draws", + "infer", + "InferResult", "plate", "sample", "DiracIdentityObservation", diff --git a/dynestyx/discretizers.py b/dynestyx/discretizers.py index 1f0095fb..06031c58 100644 --- a/dynestyx/discretizers.py +++ b/dynestyx/discretizers.py @@ -3,7 +3,7 @@ from effectful.ops.syntax import ObjectInterpretation, implements from jaxtyping import Array, Real -from dynestyx.handlers import HandlesSelf, _sample_intp +from dynestyx.handlers import HandlesSelf, _infer_intp from dynestyx.models import ( DynamicalModel, GaussianStateEvolution, @@ -152,7 +152,7 @@ def __init__(self, discretize=euler_maruyama): super().__init__() self.discretize = discretize - @implements(_sample_intp) + @implements(_infer_intp) def _sample_ds( self, name: str, diff --git a/dynestyx/handlers.py b/dynestyx/handlers.py index b9f31804..d5882554 100644 --- a/dynestyx/handlers.py +++ b/dynestyx/handlers.py @@ -1,4 +1,4 @@ -"""Contains the `sample` primitive and `effectful` utilities for `dynestyx`.""" +"""Contains the `sample` and `infer` primitives and `effectful` utilities for `dynestyx`.""" from typing import TypeVar @@ -11,7 +11,7 @@ from dynestyx.models import ( DynamicalModel, ) -from dynestyx.types import FunctionOfTime +from dynestyx.types import FunctionOfTime, InferResult from dynestyx.utils import ( _get_dynamics_with_t0, _validate_control_dim, @@ -22,7 +22,7 @@ T = TypeVar("T") -def sample( +def _validate_and_prepare( name: str, dynamics: DynamicalModel, *, @@ -35,36 +35,8 @@ def sample( | Real[Array, "*ctrl_value_plate ctrl_time"] | None = None, predict_times: Real[Array, "*predict_time_plate predict_time"] | None = None, - **kwargs, -) -> FunctionOfTime: - """ - Samples from a dynamical model. This is the main primitive of dynestyx. - - The `sample` primitive is meant to mimic the `numpyro.sample` primitive in usage, - but using a `DynamicalModel` instead of a `Distribution`. - - The `sample` method calls `_sample_intp`, which is defined as a `defop` in `effectful`. - This is where any real "work" is done, after input validation. - - Shape note: - Inside ``dsx.plate``, observation arrays use leading plate axes followed - by time and event axes, e.g. ``(N, T, obs_dim)``. Model parameters follow - the same leading-plate, trailing-event convention. See :class:`plate` - for the full plated-shape contract. - - Parameters: - name: Name of the sample site. - dynamics: Dynamical model to sample from. - obs_times: Times at which to sample the observations. - obs_values: Values of the observations at the given times. - ctrl_times: Times at which to sample the controls. - ctrl_values: Values of the controls at the given times. - predict_times: Times at which to predict the observations. - **kwargs: Additional keyword arguments. - - Returns: - FunctionOfTime: A function of time that samples from the dynamical model. - """ +) -> DynamicalModel: + """Validate inputs and return dynamics with t0 resolved.""" # Rule: obs_times must be accompanied with obs_values, which should be the same length. if obs_times is None and predict_times is None: raise ValueError("At least one of obs_times or predict_times must be provided") @@ -106,10 +78,50 @@ def sample( _validate_control_dim(dynamics, ctrl_values) # Initial dynamics may not have t0, which is then inferred from obs_times - dynamics_with_t0 = _get_dynamics_with_t0(dynamics, obs_times, predict_times) + return _get_dynamics_with_t0(dynamics, obs_times, predict_times) + + +def infer( + name: str, + dynamics: DynamicalModel, + *, + obs_times: Real[Array, "*obs_time_plate obs_time"] | None = None, + obs_values: Real[Array, "*obs_value_plate obs_time observation_dim"] + | Real[Array, "*obs_value_plate obs_time"] + | None = None, + ctrl_times: Real[Array, "*ctrl_time_plate ctrl_time"] | None = None, + ctrl_values: Real[Array, "*ctrl_value_plate ctrl_time control_dim"] + | Real[Array, "*ctrl_value_plate ctrl_time"] + | None = None, + predict_times: Real[Array, "*predict_time_plate predict_time"] | None = None, + **kwargs, +): + """Run inference on a dynamical model without registering numpyro sites. + + This is the numpyro-free entry point. When a Filter or Smoother handler + is active, returns an InferResult dataclass with marginal_loglik, states, etc. - # Pass to interpreted version of `sample` for inference. - return _sample_intp( + Parameters: + name: Name of the inference site. + dynamics: Dynamical model to infer. + obs_times: Times at which observations are available. + obs_values: Values of the observations at the given times. + ctrl_times: Times at which controls are applied. + ctrl_values: Values of the controls at the given times. + predict_times: Times at which to predict. + **kwargs: Additional keyword arguments. + """ + dynamics_with_t0 = _validate_and_prepare( + name, + dynamics, + obs_times=obs_times, + obs_values=obs_values, + ctrl_times=ctrl_times, + ctrl_values=ctrl_values, + predict_times=predict_times, + ) + + return _infer_intp( name, dynamics_with_t0, obs_times=obs_times, @@ -121,8 +133,65 @@ def sample( ) +def sample( + name: str, + dynamics: DynamicalModel, + *, + obs_times: Real[Array, "*obs_time_plate obs_time"] | None = None, + obs_values: Real[Array, "*obs_value_plate obs_time observation_dim"] + | Real[Array, "*obs_value_plate obs_time"] + | None = None, + ctrl_times: Real[Array, "*ctrl_time_plate ctrl_time"] | None = None, + ctrl_values: Real[Array, "*ctrl_value_plate ctrl_time control_dim"] + | Real[Array, "*ctrl_value_plate ctrl_time"] + | None = None, + predict_times: Real[Array, "*predict_time_plate predict_time"] | None = None, + **kwargs, +): + """ + Samples from a dynamical model. This is the main primitive of dynestyx. + + The ``sample`` primitive is meant to mimic the ``numpyro.sample`` primitive + in usage, but using a ``DynamicalModel`` instead of a ``Distribution``. + + Internally, ``sample`` calls ``dsx.infer(...)`` and then registers the + results as numpyro sites (``numpyro.factor``, ``numpyro.deterministic``). + + Shape note: + Inside ``dsx.plate``, observation arrays use leading plate axes followed + by time and event axes, e.g. ``(N, T, obs_dim)``. Model parameters follow + the same leading-plate, trailing-event convention. See :class:`plate` + for the full plated-shape contract. + + Parameters: + name: Name of the sample site. + dynamics: Dynamical model to sample from. + obs_times: Times at which to sample the observations. + obs_values: Values of the observations at the given times. + ctrl_times: Times at which to sample the controls. + ctrl_values: Values of the controls at the given times. + predict_times: Times at which to predict the observations. + **kwargs: Additional keyword arguments. + """ + result = infer( + name, + dynamics, + obs_times=obs_times, + obs_values=obs_values, + ctrl_times=ctrl_times, + ctrl_values=ctrl_values, + predict_times=predict_times, + **kwargs, + ) + + if isinstance(result, InferResult) and result._register_numpyro_sites is not None: + result._register_numpyro_sites(name) + + return result + + @defop -def _sample_intp( +def _infer_intp( name: str, dynamics: DynamicalModel, *, @@ -337,7 +406,7 @@ def __exit__(self, exc_type, exc, tb): self._cm.__exit__(exc_type, exc, tb) return self._numpyro_plate.__exit__(exc_type, exc, tb) - @implements(_sample_intp) + @implements(_infer_intp) def _sample_ds( self, name, dynamics, *, plate_shapes=(), **kwargs ) -> FunctionOfTime: diff --git a/dynestyx/inference/filters.py b/dynestyx/inference/filters.py index 7e2f8528..4b8b84de 100644 --- a/dynestyx/inference/filters.py +++ b/dynestyx/inference/filters.py @@ -1,5 +1,6 @@ import dataclasses import math +from abc import ABC, abstractmethod from typing import cast import jax @@ -10,7 +11,7 @@ from effectful.ops.syntax import ObjectInterpretation, implements from jaxtyping import Array, PRNGKeyArray, Real -from dynestyx.handlers import HandlesSelf, _sample_intp +from dynestyx.handlers import HandlesSelf, _infer_intp from dynestyx.inference.checkers import ( _validate_batched_plate_alignment, _validate_missing_observation_support, @@ -56,17 +57,21 @@ from dynestyx.inference.integrations.cuthbert.discrete import ( run_discrete_filter as run_cuthbert_discrete, ) +from dynestyx.inference.numpyro_sites import ( + register_filter_sites, + register_hmm_filter_sites, +) from dynestyx.inference.plate_utils import _array_plate_axis, _make_plate_in_axes from dynestyx.models import DynamicalModel -from dynestyx.types import FunctionOfTime +from dynestyx.types import FunctionOfTime, InferResult type SSMType = ContDiscreteNonlinearGaussianSSM | ContDiscreteNonlinearSSM -class BaseLogFactorAdder(ObjectInterpretation, HandlesSelf): +class BaseLogFactorAdder(ObjectInterpretation, HandlesSelf, ABC): """Base for filter handlers.""" - @implements(_sample_intp) + @implements(_infer_intp) def _sample_ds( self, name: str, @@ -96,8 +101,9 @@ def _sample_ds( **kwargs, ) - # Filter consumes obs_times and obs_values, so they are passed forward as None - return fwd( + # Filter consumes obs_times and obs_values, so they are passed forward as None. + # fwd() lets handlers above (e.g. Simulator) use filtered_dists for rollout. + fwd( name, dynamics, plate_shapes=plate_shapes, @@ -110,6 +116,9 @@ def _sample_ds( **kwargs, ) + return self._build_infer_result(name, filtered_dists) + + @abstractmethod def _add_log_factors( self, name: str, @@ -125,9 +134,12 @@ def _add_log_factors( | Real[Array, "*ctrl_value_plate ctrl_time"] | None = None, **kwargs, - ) -> list[numpyro.distributions.Distribution] | None: - # Inheritors should implement this method. - raise NotImplementedError() + ) -> list[numpyro.distributions.Distribution] | None: ... + + @abstractmethod + def _build_infer_result( + self, name: str, filtered_dists: list | None + ) -> InferResult: ... def _default_filter_config(dynamics: DynamicalModel): @@ -193,6 +205,13 @@ class Filter(BaseLogFactorAdder): """ filter_config: BaseFilterConfig | None = None + marginal_loglik: jax.Array | None = dataclasses.field( + default=None, repr=False, init=False + ) + filtered_states: object = dataclasses.field(default=None, repr=False, init=False) + _filter_config_used: BaseFilterConfig | None = dataclasses.field( + default=None, repr=False, init=False + ) def _add_log_factors( self, @@ -210,17 +229,10 @@ def _add_log_factors( | None = None, **kwargs, ) -> list[numpyro.distributions.Distribution] | None: - """ - Add the marginal log likelihood as a numpyro factor. - - Args: - name: Name of the factor. - dynamics: Dynamical model to filter. - plate_shapes: Tuple of plate sizes from enclosing dsx.plate contexts. - obs_times: Observation times. - obs_values: Observed values. - ctrl_times: Control times (optional). - ctrl_values: Control values (optional). + """Run filtering and store the marginal log-likelihood. + + Pure computation — no numpyro side effects. Site registration + happens via the callback in InferResult when called through dsx.sample. """ if obs_times is None or obs_values is None: raise ValueError("obs_times and obs_values are required for filtering.") @@ -237,7 +249,19 @@ def _add_log_factors( mode="filter", ) - key = numpyro.prng_key() if config.crn_seed is None else config.crn_seed + # Resolve PRNG key: use explicit seed from config, fall back to numpyro + # context (inside a seeded model), or None (deterministic filters don't need one). + if config.crn_seed is not None: + key = config.crn_seed + else: + import warnings # noqa: PLC0415 + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + key = numpyro.prng_key() + except Exception: + key = None if plate_shapes: return self._add_log_factors_batched( @@ -261,7 +285,32 @@ def _add_log_factors( "inside `Filter()`. " f"Got {type(config).__name__}; valid continuous-time config types: {valid}." ) - return _filter_continuous_time( + marginal_loglik, states, filtered_dists = _filter_continuous_time( + name, + dynamics, + config, # type: ignore[arg-type] + key=key, + obs_times=obs_times, + obs_values=obs_values, + ctrl_times=ctrl_times, + ctrl_values=ctrl_values, + **kwargs, + ) + elif isinstance(config, HMMConfigs): + loglik, log_filt_seq, filtered_dists = _filter_hmm( + name, + dynamics, + cast(HMMConfig, config), + obs_times=obs_times, + obs_values=obs_values, + ctrl_times=ctrl_times, + ctrl_values=ctrl_values, + **kwargs, + ) + marginal_loglik = loglik + states = log_filt_seq + elif isinstance(config, DiscreteTimeConfigs): + marginal_loglik, states, filtered_dists = _filter_discrete_time( name, dynamics, config, # type: ignore[arg-type] @@ -273,35 +322,52 @@ def _add_log_factors( **kwargs, ) else: - if isinstance(config, HMMConfigs): - return _filter_hmm( - name, - dynamics, + valid = [c.__name__ for c in HMMConfigs + DiscreteTimeConfigs] + raise ValueError( + f"Invalid filter config: {type(config).__name__}. " + f"Valid config types: {valid}" + ) + + self.marginal_loglik = marginal_loglik + self.filtered_states = states + self._filter_config_used = config + + return filtered_dists + + def _build_infer_result( + self, name: str, filtered_dists: list | None + ) -> InferResult: + """Construct InferResult with a deferred numpyro registration callback.""" + marginal_loglik = self.marginal_loglik + states = self.filtered_states + config = self._filter_config_used + _is_batched = ( + isinstance(marginal_loglik, jax.Array) and marginal_loglik.ndim > 0 + ) + + def _register(site_name: str) -> None: + if marginal_loglik is None or config is None: + return + if _is_batched: + # TODO: support per-field recording for batched (plate) states + numpyro.factor(f"{site_name}_marginal_log_likelihood", marginal_loglik) + numpyro.deterministic(f"{site_name}_marginal_loglik", marginal_loglik) + elif isinstance(config, HMMConfigs): + register_hmm_filter_sites( + site_name, + marginal_loglik, + cast(jax.Array, states), cast(HMMConfig, config), - obs_times=obs_times, - obs_values=obs_values, - ctrl_times=ctrl_times, - ctrl_values=ctrl_values, - **kwargs, - ) - elif isinstance(config, DiscreteTimeConfigs): - return _filter_discrete_time( - name, - dynamics, - config, # type: ignore[arg-type] - key=key, - obs_times=obs_times, - obs_values=obs_values, - ctrl_times=ctrl_times, - ctrl_values=ctrl_values, - **kwargs, ) else: - valid = [c.__name__ for c in HMMConfigs + DiscreteTimeConfigs] - raise ValueError( - f"Invalid filter config: {type(config).__name__}. " - f"Valid config types: {valid}" - ) + register_filter_sites(site_name, marginal_loglik, states, config) + + return InferResult( + marginal_loglik=marginal_loglik, + states=states, + dists=filtered_dists, + _register_numpyro_sites=_register, + ) def _add_log_factors_batched( self, @@ -448,8 +514,9 @@ def compute_output(dyn, ot, ov, ct, cv, k): else: raise ValueError(f"Unsupported batched output kind: {output_kind}") - numpyro.factor(f"{name}_marginal_log_likelihood", marginal_logliks) - numpyro.deterministic(f"{name}_marginal_loglik", marginal_logliks) + self.marginal_loglik = marginal_logliks + self.filtered_states = outputs + self._filter_config_used = config if output_kind == "continuous": particle_mode = isinstance(config, ContinuousTimeDPFConfig) @@ -503,7 +570,7 @@ def _filter_discrete_time( | Real[Array, "*ctrl_value_plate ctrl_time"] | None = None, **kwargs, -) -> list[numpyro.distributions.Distribution]: +) -> tuple[jax.Array, object, list[numpyro.distributions.Distribution]]: """Discrete-time marginal likelihood via cuthbert or cd-dynamax. Filter type inferred from config class: KFConfig, EKFConfig, UKFConfig @@ -560,7 +627,7 @@ def _filter_continuous_time( | Real[Array, "*ctrl_value_plate ctrl_time"] | None = None, **kwargs, -) -> list[numpyro.distributions.Distribution]: +) -> tuple[jax.Array, object, list[numpyro.distributions.Distribution]]: """Continuous-time marginal likelihood via CD-Dynamax. Supports: EnKF, DPF, EKF, UKF (inferred from config type). diff --git a/dynestyx/inference/hmm_filters.py b/dynestyx/inference/hmm_filters.py index 4b0ae8f4..31e24876 100644 --- a/dynestyx/inference/hmm_filters.py +++ b/dynestyx/inference/hmm_filters.py @@ -4,7 +4,6 @@ import jax import jax.numpy as jnp -import numpyro import numpyro.distributions as dist from jax import lax from jax.scipy.special import logsumexp @@ -13,7 +12,6 @@ from dynestyx.inference.filter_configs import HMMConfig from dynestyx.models import DynamicalModel from dynestyx.models.core import DiscreteStateTransition -from dynestyx.utils import _should_record_field def enumerate_latent_states(dynamics: DynamicalModel) -> Int[Array, " n_states"]: @@ -222,7 +220,7 @@ def _filter_hmm( ctrl_times: Real[Array, "*ctrl_time_plate ctrl_time"] | None = None, ctrl_values: Real[Array, "*ctrl_value_plate obs_time control_dim"] | None = None, **kwargs, -) -> list[dist.Distribution]: +) -> tuple[jax.Array, Float[Array, "*plate time n_states"], list[dist.Distribution]]: """Exact HMM marginal likelihood via forward filtering. Args: @@ -235,8 +233,12 @@ def _filter_hmm( ctrl_values: Control values (optional). Returns: - List of Categorical distributions p(x_t | y_{1:t}) at each obs time, - for use with Filter + DiscreteTimeSimulator rollout. + tuple of: + - loglik: scalar marginal log-likelihood log p(y_{1:T}). + - log_filt_seq: log filtering probabilities log p(x_t | y_{1:t}), + shape (time, n_states). + - filtered_dists: list of Categorical distributions p(x_t | y_{1:t}) + at each obs time, for use with Filter + DiscreteTimeSimulator rollout. """ loglik, log_filt_seq = compute_hmm_filter( dynamics, @@ -245,29 +247,8 @@ def _filter_hmm( ctrl_values=ctrl_values, ) - numpyro.factor(f"{name}_marginal_log_likelihood", loglik) - numpyro.deterministic(f"{name}_marginal_loglik", loglik) - - record_max_elems = filter_config.record_max_elems - - if _should_record_field( - filter_config.record_log_filtered, log_filt_seq.shape, record_max_elems - ): - numpyro.deterministic( - f"{name}_log_filtered_states", - log_filt_seq, # (T, K) - ) - - if _should_record_field( - filter_config.record_filtered, log_filt_seq.shape, record_max_elems - ): - numpyro.deterministic( - f"{name}_filtered_states", - jnp.exp(log_filt_seq), # (T, K) - ) - - # Return filtered distributions for Filter + DiscreteTimeSimulator rollout - return [ + filtered_dists = [ dist.Categorical(probs=jnp.exp(log_filt_seq[i])) for i in range(log_filt_seq.shape[0]) ] + return loglik, log_filt_seq, filtered_dists diff --git a/dynestyx/inference/integrations/cd_dynamax/continuous_filter.py b/dynestyx/inference/integrations/cd_dynamax/continuous_filter.py index 6fc63b8b..79e49d14 100644 --- a/dynestyx/inference/integrations/cd_dynamax/continuous_filter.py +++ b/dynestyx/inference/integrations/cd_dynamax/continuous_filter.py @@ -2,7 +2,7 @@ import jax import jax.numpy as jnp -import numpyro +import numpyro.distributions as dist from cd_dynamax import ( ContDiscreteLinearGaussianSSM, ContDiscreteNonlinearGaussianSSM, @@ -19,14 +19,12 @@ ContinuousTimeEnKFConfig, ContinuousTimeKFConfig, ContinuousTimeUKFConfig, - _config_to_record_kwargs, ) from dynestyx.inference.integrations.cd_dynamax.utils import ( dsx_to_cd_dynamax, dsx_to_cdlgssm_params, ) from dynestyx.models import DynamicalModel -from dynestyx.utils import _should_record_field type SSMType = ContDiscreteNonlinearGaussianSSM | ContDiscreteNonlinearSSM @@ -111,41 +109,6 @@ def _config_to_cd_dynamax_filter_kwargs( return base -def _add_filter_sites( - name: str, - filter_config: ContinuousTimeFilterConfig, - filtered, -) -> None: - """Add marginal log-likelihood factor and filtered state deterministic sites.""" - record_kwargs = _config_to_record_kwargs(filter_config) - numpyro.factor(f"{name}_marginal_log_likelihood", filtered.marginal_loglik) - numpyro.deterministic(f"{name}_marginal_loglik", filtered.marginal_loglik) - - max_elems = record_kwargs["record_max_elems"] - means_shape = filtered.filtered_means.shape - cov_shape = filtered.filtered_covariances.shape - add_mean = _should_record_field( - record_kwargs["record_filtered_states_mean"], means_shape, max_elems - ) - add_cov = _should_record_field( - record_kwargs["record_filtered_states_cov"], cov_shape, max_elems - ) - add_cov_diag = _should_record_field( - record_kwargs["record_filtered_states_cov_diag"], - (cov_shape[0], cov_shape[1]), - max_elems, - ) - if add_mean: - numpyro.deterministic(f"{name}_filtered_states_mean", filtered.filtered_means) - if add_cov: - numpyro.deterministic( - f"{name}_filtered_states_cov", filtered.filtered_covariances - ) - if add_cov_diag: - diag_cov = jnp.diagonal(filtered.filtered_covariances, axis1=1, axis2=2) - numpyro.deterministic(f"{name}_filtered_states_cov_diag", diag_cov) - - def _run_linear_kf( dynamics: DynamicalModel, obs_times, @@ -240,8 +203,20 @@ def run_continuous_filter( ctrl_times=None, ctrl_values=None, **kwargs, -) -> list[numpyro.distributions.Distribution]: - """Run continuous-time filter via CD-Dynamax.""" +) -> tuple[jax.Array, object, list[dist.Distribution]]: + """Run continuous-time filter via CD-Dynamax. + + Pure computation — no numpyro side-effects. Callers are responsible for + registering numpyro.factor / numpyro.deterministic if needed. + + Returns: + tuple of: + - marginal_loglik: scalar marginal log-likelihood log p(y_{1:T}). + - filtered_posterior: CD-Dynamax posterior object with filtered_means, + filtered_covariances, and marginal_loglik attributes. + - filtered_dists: list of MultivariateNormal distributions p(x_t | y_{1:t}) + at each obs time, for posterior rollout. + """ filtered = compute_continuous_filter( dynamics, filter_config, @@ -252,15 +227,14 @@ def run_continuous_filter( ctrl_values=ctrl_values, ) - _add_filter_sites(name, filter_config, filtered) - - return _posterior_sequence_to_dists( + filtered_dists = _posterior_sequence_to_dists( filtered, means_attr="filtered_means", covariances_attr="filtered_covariances", particle_mode=isinstance(filter_config, ContinuousTimeDPFConfig), missing_message="Filtered means/covariances unexpectedly None for non-DPF config", ) + return filtered.marginal_loglik, filtered, filtered_dists __all__ = [ diff --git a/dynestyx/inference/integrations/cd_dynamax/continuous_smoother.py b/dynestyx/inference/integrations/cd_dynamax/continuous_smoother.py index f3f55f81..9a8e8c21 100644 --- a/dynestyx/inference/integrations/cd_dynamax/continuous_smoother.py +++ b/dynestyx/inference/integrations/cd_dynamax/continuous_smoother.py @@ -2,7 +2,7 @@ import jax import jax.numpy as jnp -import numpyro +import numpyro.distributions as dist from cd_dynamax import ( ContDiscreteNonlinearGaussianSSM, EKFHyperParams, @@ -19,51 +19,14 @@ from dynestyx.inference.smoother_configs import ( ContinuousTimeEKFSmootherConfig, ContinuousTimeKFSmootherConfig, - _config_to_smoother_record_kwargs, ) from dynestyx.models import DynamicalModel -from dynestyx.utils import _should_record_field ContinuousTimeSmootherConfig = ( ContinuousTimeKFSmootherConfig | ContinuousTimeEKFSmootherConfig ) -def _add_smoother_sites( - name: str, - smoother_config: ContinuousTimeSmootherConfig, - smoothed, -) -> None: - """Add marginal log-likelihood factor and smoothed state deterministic sites.""" - record_kwargs = _config_to_smoother_record_kwargs(smoother_config) - numpyro.factor(f"{name}_marginal_log_likelihood", smoothed.marginal_loglik) - numpyro.deterministic(f"{name}_marginal_loglik", smoothed.marginal_loglik) - - max_elems = record_kwargs["record_max_elems"] - means = smoothed.smoothed_means - covs = smoothed.smoothed_covariances - means_shape = means.shape - cov_shape = covs.shape - add_mean = _should_record_field( - record_kwargs["record_smoothed_states_mean"], means_shape, max_elems - ) - add_cov = _should_record_field( - record_kwargs["record_smoothed_states_cov"], cov_shape, max_elems - ) - add_cov_diag = _should_record_field( - record_kwargs["record_smoothed_states_cov_diag"], - (cov_shape[0], cov_shape[1]), - max_elems, - ) - if add_mean: - numpyro.deterministic(f"{name}_smoothed_states_mean", means) - if add_cov: - numpyro.deterministic(f"{name}_smoothed_states_cov", covs) - if add_cov_diag: - diag_cov = jnp.diagonal(covs, axis1=1, axis2=2) - numpyro.deterministic(f"{name}_smoothed_states_cov_diag", diag_cov) - - def compute_continuous_smoother( dynamics: DynamicalModel, smoother_config: ContinuousTimeSmootherConfig, @@ -162,8 +125,15 @@ def run_continuous_smoother( ctrl_times=None, ctrl_values=None, **kwargs, -) -> list[numpyro.distributions.Distribution]: - """Run continuous-time smoother via CD-Dynamax.""" +) -> tuple[jax.Array, object, list[dist.Distribution]]: + """Run continuous-time smoother via CD-Dynamax. + + Pure computation — no numpyro side-effects. Callers are responsible for + registering numpyro.factor / numpyro.deterministic if needed. + + Returns: + tuple: (marginal_loglik, smoothed_posterior, smoothed_dists). + """ smoothed = compute_continuous_smoother( dynamics, smoother_config, @@ -174,14 +144,14 @@ def run_continuous_smoother( ctrl_values=ctrl_values, ) - _add_smoother_sites(name, smoother_config, smoothed) - return _posterior_sequence_to_dists( + smoothed_dists = _posterior_sequence_to_dists( smoothed, means_attr="smoothed_means", covariances_attr="smoothed_covariances", particle_mode=False, missing_message="Smoothed means/covariances unexpectedly None.", ) + return smoothed.marginal_loglik, smoothed, smoothed_dists __all__ = ["compute_continuous_smoother", "run_continuous_smoother"] diff --git a/dynestyx/inference/integrations/cd_dynamax/discrete_filter.py b/dynestyx/inference/integrations/cd_dynamax/discrete_filter.py index 61398247..2af896f3 100644 --- a/dynestyx/inference/integrations/cd_dynamax/discrete_filter.py +++ b/dynestyx/inference/integrations/cd_dynamax/discrete_filter.py @@ -2,10 +2,8 @@ import jax import jax.numpy as jnp -import numpyro import numpyro.distributions as dist from cd_dynamax.dynamax.linear_gaussian_ssm.inference import ( - PosteriorGSSMFiltered, lgssm_filter, ) from cd_dynamax.dynamax.linear_gaussian_ssm.models import LinearGaussianSSM @@ -23,7 +21,6 @@ EKFConfig, KFConfig, UKFConfig, - _config_to_record_kwargs, ) from dynestyx.inference.integrations.cd_dynamax.utils import gaussian_to_nlgssm_params from dynestyx.inference.integrations.utils import squeeze_leading_singletons @@ -32,7 +29,6 @@ LinearGaussianObservation, LinearGaussianStateEvolution, ) -from dynestyx.utils import _should_record_field def _lti_to_lgssm_params(dynamics: DynamicalModel): @@ -127,36 +123,6 @@ def compute_cd_dynamax_discrete_filter( ) -def _add_kf_sites( - name: str, posterior: PosteriorGSSMFiltered, record_kwargs: dict -) -> None: - """Add filtered means/covariances as deterministic sites (dynamax KF posterior).""" - max_elems = record_kwargs["record_max_elems"] - if posterior.filtered_means is None: - return - means = posterior.filtered_means - covs = posterior.filtered_covariances - t1, state_dim = means.shape - add_mean = _should_record_field( - record_kwargs["record_filtered_states_mean"], means.shape, max_elems - ) - add_cov = _should_record_field( - record_kwargs["record_filtered_states_cov"], - (t1, state_dim, state_dim), - max_elems, - ) - add_cov_diag = _should_record_field( - record_kwargs["record_filtered_states_cov_diag"], (t1, state_dim), max_elems - ) - if add_mean: - numpyro.deterministic(f"{name}_filtered_states_mean", means) - if add_cov and covs is not None: - numpyro.deterministic(f"{name}_filtered_states_cov", covs) - if add_cov_diag and covs is not None: - diag_cov = jnp.diagonal(covs, axis1=1, axis2=2) - numpyro.deterministic(f"{name}_filtered_states_cov_diag", diag_cov) - - def run_discrete_filter( name: str, dynamics: DynamicalModel, @@ -167,8 +133,15 @@ def run_discrete_filter( ctrl_times=None, ctrl_values=None, **kwargs, -) -> list[dist.Distribution]: - """Run discrete-time filter via cd-dynamax (KF, EKF, UKF).""" +) -> tuple[jax.Array, object, list[dist.Distribution]]: + """Run discrete-time filter via cd-dynamax (KF, EKF, UKF). + + Pure computation — no numpyro side-effects. Callers are responsible for + registering numpyro.factor / numpyro.deterministic if needed. + + Returns: + tuple: (marginal_loglik, posterior, filtered_dists). + """ posterior = compute_cd_dynamax_discrete_filter( dynamics, filter_config, @@ -178,19 +151,14 @@ def run_discrete_filter( ctrl_values=ctrl_values, ) - record_kwargs = _config_to_record_kwargs(filter_config) - marginal_loglik = posterior.marginal_loglik - numpyro.factor(f"{name}_marginal_log_likelihood", marginal_loglik) - numpyro.deterministic(f"{name}_marginal_loglik", marginal_loglik) - _add_kf_sites(name, posterior, record_kwargs) - - return _posterior_sequence_to_dists( + filtered_dists = _posterior_sequence_to_dists( posterior, means_attr="filtered_means", covariances_attr="filtered_covariances", particle_mode=False, missing="empty", ) + return posterior.marginal_loglik, posterior, filtered_dists __all__ = [ diff --git a/dynestyx/inference/integrations/cd_dynamax/discrete_smoother.py b/dynestyx/inference/integrations/cd_dynamax/discrete_smoother.py index 1686ccc1..002d7437 100644 --- a/dynestyx/inference/integrations/cd_dynamax/discrete_smoother.py +++ b/dynestyx/inference/integrations/cd_dynamax/discrete_smoother.py @@ -1,8 +1,6 @@ """Discrete-time smoothers via cd-dynamax (dynamax): KF, EKF, UKF.""" import jax -import jax.numpy as jnp -import numpyro import numpyro.distributions as dist from cd_dynamax.dynamax.linear_gaussian_ssm.inference import lgssm_smoother from cd_dynamax.dynamax.nonlinear_gaussian_ssm.inference_ekf import ( @@ -26,10 +24,8 @@ from dynestyx.inference.integrations.cd_dynamax.utils import gaussian_to_nlgssm_params from dynestyx.inference.smoother_configs import ( BaseSmootherConfig, - _config_to_smoother_record_kwargs, ) from dynestyx.models import DynamicalModel -from dynestyx.utils import _should_record_field def compute_cd_dynamax_discrete_smoother( @@ -68,34 +64,6 @@ def compute_cd_dynamax_discrete_smoother( ) -def _add_smoother_sites(name: str, posterior, record_kwargs: dict) -> None: - """Add smoothed means/covariances as deterministic sites.""" - max_elems = record_kwargs["record_max_elems"] - means = posterior.smoothed_means - covs = posterior.smoothed_covariances - if means is None or covs is None: - return - t1, state_dim = means.shape - add_mean = _should_record_field( - record_kwargs["record_smoothed_states_mean"], means.shape, max_elems - ) - add_cov = _should_record_field( - record_kwargs["record_smoothed_states_cov"], - (t1, state_dim, state_dim), - max_elems, - ) - add_cov_diag = _should_record_field( - record_kwargs["record_smoothed_states_cov_diag"], (t1, state_dim), max_elems - ) - if add_mean: - numpyro.deterministic(f"{name}_smoothed_states_mean", means) - if add_cov: - numpyro.deterministic(f"{name}_smoothed_states_cov", covs) - if add_cov_diag: - diag_cov = jnp.diagonal(covs, axis1=1, axis2=2) - numpyro.deterministic(f"{name}_smoothed_states_cov_diag", diag_cov) - - def run_discrete_smoother( name: str, dynamics: DynamicalModel, @@ -106,8 +74,15 @@ def run_discrete_smoother( ctrl_times=None, ctrl_values=None, **kwargs, -) -> list[dist.Distribution]: - """Run discrete-time smoother via cd-dynamax (KF, EKF, UKF).""" +) -> tuple[jax.Array, object, list[dist.Distribution]]: + """Run discrete-time smoother via cd-dynamax (KF, EKF, UKF). + + Pure computation — no numpyro side-effects. Callers are responsible for + registering numpyro.factor / numpyro.deterministic if needed. + + Returns: + tuple: (marginal_loglik, posterior, smoothed_dists). + """ posterior = compute_cd_dynamax_discrete_smoother( dynamics, filter_config, @@ -117,19 +92,14 @@ def run_discrete_smoother( ctrl_values=ctrl_values, ) - numpyro.factor(f"{name}_marginal_log_likelihood", posterior.marginal_loglik) - numpyro.deterministic(f"{name}_marginal_loglik", posterior.marginal_loglik) - _add_smoother_sites( - name, posterior, _config_to_smoother_record_kwargs(filter_config) - ) - - return _posterior_sequence_to_dists( + smoothed_dists = _posterior_sequence_to_dists( posterior, means_attr="smoothed_means", covariances_attr="smoothed_covariances", particle_mode=False, missing="empty", ) + return posterior.marginal_loglik, posterior, smoothed_dists __all__ = ["compute_cd_dynamax_discrete_smoother", "run_discrete_smoother"] diff --git a/dynestyx/inference/integrations/cuthbert/discrete_filter.py b/dynestyx/inference/integrations/cuthbert/discrete_filter.py index 19ad1459..7f718211 100644 --- a/dynestyx/inference/integrations/cuthbert/discrete_filter.py +++ b/dynestyx/inference/integrations/cuthbert/discrete_filter.py @@ -3,7 +3,6 @@ import jax import jax.numpy as jnp -import numpyro import numpyro.distributions as dist from cuthbert import filter as cuthbert_filter from cuthbert.enkf import ensemble_kalman_filter @@ -15,7 +14,6 @@ stop_gradient_decorator, systematic, ) -from numpyro.distributions import Distribution from dynestyx.inference.distribution_utils import _cholesky_state_sequence_to_dists from dynestyx.inference.filter_configs import ( @@ -24,10 +22,8 @@ EnKFConfig, KFConfig, PFConfig, - _config_to_record_kwargs, ) from dynestyx.inference.integrations.utils import ( - covariance_from_cholesky, squeeze_leading_singletons, ) from dynestyx.models import ( @@ -36,7 +32,6 @@ LinearGaussianObservation, LinearGaussianStateEvolution, ) -from dynestyx.utils import _should_record_field class CuthbertInputs(NamedTuple): @@ -100,10 +95,10 @@ def _probe_state_independent_observation_noise( probe_u = jnp.zeros(()) probe_t = jnp.zeros(()) try: - probe_d0: Distribution | None = obs_model( + probe_d0: dist.Distribution | None = obs_model( jnp.zeros((state_dim,)), probe_u, probe_t ) - probe_d1: Distribution | None = obs_model( + probe_d1: dist.Distribution | None = obs_model( jnp.ones((state_dim,)), probe_u, probe_t ) except Exception: @@ -260,15 +255,18 @@ def run_discrete_filter( ctrl_times=None, ctrl_values=None, **kwargs, -) -> list[dist.Distribution]: +) -> tuple[jax.Array, object, list[dist.Distribution]]: """Run discrete-time filter via cuthbert (Kalman, Taylor KF, particle filter). + Pure computation — no numpyro side-effects. Callers are responsible for + registering numpyro.factor / numpyro.deterministic if needed. + Returns: - list[dist.Distribution]: Filtered state distributions at each obs time. + tuple: (marginal_loglik, raw_states, filtered_dists). """ obs_len = int(obs_values.shape[0]) if obs_len == 0: - return [] + return jnp.array(0.0), None, [] marginal_loglik, states = compute_cuthbert_filter( dynamics, @@ -279,19 +277,11 @@ def run_discrete_filter( ctrl_times=ctrl_times, ctrl_values=ctrl_values, ) - record_kwargs = _config_to_record_kwargs(filter_config) - - numpyro.factor(f"{name}_marginal_log_likelihood", marginal_loglik) - numpyro.deterministic(f"{name}_marginal_loglik", marginal_loglik) - - if isinstance(filter_config, PFConfig): - _add_sites_pf(name, states, record_kwargs) - else: - _add_sites_gaussian_filter(name, states, record_kwargs) - return _cholesky_state_sequence_to_dists( + filtered_dists = _cholesky_state_sequence_to_dists( states, particle_mode=isinstance(filter_config, PFConfig), ) + return marginal_loglik, states, filtered_dists def _cuthbert_filter_pf(dynamics: DynamicalModel, filter_kwargs: dict | None = None): @@ -595,104 +585,3 @@ def log_potential(x): ignore_nan_dims=True, ) return kf - - -def _add_sites_pf( - name: str, states: particle_filter.ParticleFilterState, record_kwargs: dict -): - log_weights = states.log_weights - particles = states.particles - if particles.ndim == 2: - particles = particles[..., None] - max_elems = record_kwargs["record_max_elems"] - t_len, n_particles, state_dim = particles.shape - - add_particles = _should_record_field( - record_kwargs["record_filtered_particles"], particles.shape, max_elems - ) - add_log_weights = _should_record_field( - record_kwargs["record_filtered_log_weights"], log_weights.shape, max_elems - ) - add_mean = _should_record_field( - record_kwargs["record_filtered_states_mean"], (t_len, state_dim), max_elems - ) - add_filtered_states_cov = _should_record_field( - record_kwargs["record_filtered_states_cov"], - (t_len, state_dim, state_dim), - max_elems, - ) - add_filtered_states_cov_diag = _should_record_field( - record_kwargs["record_filtered_states_cov_diag"], (t_len, state_dim), max_elems - ) - - need_filtered_means = ( - add_mean or add_filtered_states_cov or add_filtered_states_cov_diag - ) - - if need_filtered_means: - w = jax.nn.softmax(log_weights, axis=1)[..., None] # (T+1, n_particles, 1) - filtered_means = jnp.sum(particles * w, axis=1) # (T+1, state_dim) - - if add_filtered_states_cov or add_filtered_states_cov_diag: - second_mom = jnp.einsum( - "...tnj,...tnk,...tn->...tjk", particles, particles, w.squeeze(-1) - ) - filtered_covariances = second_mom - jnp.einsum( - "...tj,...tk->...tjk", filtered_means, filtered_means - ) - - if add_particles: - numpyro.deterministic(f"{name}_filtered_particles", particles) - if add_log_weights: - numpyro.deterministic(f"{name}_filtered_log_weights", log_weights) - if add_mean: - numpyro.deterministic(f"{name}_filtered_states_mean", filtered_means) - if add_filtered_states_cov: - numpyro.deterministic(f"{name}_filtered_states_cov", filtered_covariances) - if add_filtered_states_cov_diag: - diag_cov = jnp.diagonal(filtered_covariances, axis1=1, axis2=2) - numpyro.deterministic(f"{name}_filtered_states_cov_diag", diag_cov) - - -def _add_sites_gaussian_filter( - name: str, - states: kalman.KalmanFilterState - | taylor.LinearizedKalmanFilterState - | ensemble_kalman_filter.EnKFState, - record_kwargs: dict, -): - max_elems = record_kwargs["record_max_elems"] - mean = states.mean - chol_cov = states.chol_cov - t_len, state_dim, _ = chol_cov.shape - - add_mean = _should_record_field( - record_kwargs["record_filtered_states_mean"], mean.shape, max_elems - ) - add_chol_cov = _should_record_field( - record_kwargs["record_filtered_states_chol_cov"], - chol_cov.shape, - max_elems, - ) - add_filtered_states_cov = _should_record_field( - record_kwargs["record_filtered_states_cov"], - (t_len, state_dim, state_dim), - max_elems, - ) - add_filtered_states_cov_diag = _should_record_field( - record_kwargs["record_filtered_states_cov_diag"], (t_len, state_dim), max_elems - ) - - if add_mean: - numpyro.deterministic(f"{name}_filtered_states_mean", mean) - if add_chol_cov: - numpyro.deterministic(f"{name}_filtered_states_chol_cov", chol_cov) - - if add_filtered_states_cov or add_filtered_states_cov_diag: - filtered_cov = covariance_from_cholesky(chol_cov) - - if add_filtered_states_cov: - numpyro.deterministic(f"{name}_filtered_states_cov", filtered_cov) - if add_filtered_states_cov_diag: - diag_cov = jnp.diagonal(filtered_cov, axis1=1, axis2=2) - numpyro.deterministic(f"{name}_filtered_states_cov_diag", diag_cov) diff --git a/dynestyx/inference/integrations/cuthbert/discrete_smoother.py b/dynestyx/inference/integrations/cuthbert/discrete_smoother.py index 05fdcdb3..5229711d 100644 --- a/dynestyx/inference/integrations/cuthbert/discrete_smoother.py +++ b/dynestyx/inference/integrations/cuthbert/discrete_smoother.py @@ -6,7 +6,6 @@ import jax import jax.numpy as jnp -import numpyro import numpyro.distributions as dist from cuthbert import smoother as cuthbert_smoother from cuthbert.gaussian import kalman, taylor @@ -22,21 +21,18 @@ compute_cuthbert_filter, ) from dynestyx.inference.integrations.utils import ( - covariance_from_cholesky, squeeze_leading_singletons, ) from dynestyx.inference.smoother_configs import ( EKFSmootherConfig, KFSmootherConfig, PFSmootherConfig, - _config_to_smoother_record_kwargs, ) from dynestyx.models import ( DynamicalModel, LinearGaussianObservation, LinearGaussianStateEvolution, ) -from dynestyx.utils import _should_record_field CuthbertSmootherConfig = KFSmootherConfig | EKFSmootherConfig | PFSmootherConfig @@ -250,96 +246,6 @@ def compute_cuthbert_smoother( return marginal_loglik, smoothed_states -def _add_sites_pf(name: str, states, record_kwargs: dict): - log_weights = states.log_weights - particles = states.particles - if particles.ndim == 2: - particles = particles[..., None] - max_elems = record_kwargs["record_max_elems"] - t1, _, state_dim = particles.shape - - add_particles = _should_record_field( - record_kwargs["record_smoothed_particles"], particles.shape, max_elems - ) - add_log_weights = _should_record_field( - record_kwargs["record_smoothed_log_weights"], log_weights.shape, max_elems - ) - add_mean = _should_record_field( - record_kwargs["record_smoothed_states_mean"], (t1, state_dim), max_elems - ) - add_smoothed_states_cov = _should_record_field( - record_kwargs["record_smoothed_states_cov"], - (t1, state_dim, state_dim), - max_elems, - ) - add_smoothed_states_cov_diag = _should_record_field( - record_kwargs["record_smoothed_states_cov_diag"], (t1, state_dim), max_elems - ) - - need_means = add_mean or add_smoothed_states_cov or add_smoothed_states_cov_diag - if need_means: - w = jax.nn.softmax(log_weights, axis=1)[..., None] - smoothed_means = jnp.sum(particles * w, axis=1) - - if add_smoothed_states_cov or add_smoothed_states_cov_diag: - second_mom = jnp.einsum( - "...tnj,...tnk,...tn->...tjk", particles, particles, w.squeeze(-1) - ) - smoothed_cov = second_mom - jnp.einsum( - "...tj,...tk->...tjk", smoothed_means, smoothed_means - ) - - if add_particles: - numpyro.deterministic(f"{name}_smoothed_particles", particles) - if add_log_weights: - numpyro.deterministic(f"{name}_smoothed_log_weights", log_weights) - if add_mean: - numpyro.deterministic(f"{name}_smoothed_states_mean", smoothed_means) - if add_smoothed_states_cov: - numpyro.deterministic(f"{name}_smoothed_states_cov", smoothed_cov) - if add_smoothed_states_cov_diag: - diag_cov = jnp.diagonal(smoothed_cov, axis1=1, axis2=2) - numpyro.deterministic(f"{name}_smoothed_states_cov_diag", diag_cov) - - -def _add_sites_taylor_kf(name: str, states, record_kwargs: dict): - max_elems = record_kwargs["record_max_elems"] - mean = states.mean - chol_cov = states.chol_cov - t1, state_dim, _ = chol_cov.shape - - add_mean = _should_record_field( - record_kwargs["record_smoothed_states_mean"], mean.shape, max_elems - ) - add_chol_cov = _should_record_field( - record_kwargs["record_smoothed_states_chol_cov"], - chol_cov.shape, - max_elems, - ) - add_smoothed_states_cov = _should_record_field( - record_kwargs["record_smoothed_states_cov"], - (t1, state_dim, state_dim), - max_elems, - ) - add_smoothed_states_cov_diag = _should_record_field( - record_kwargs["record_smoothed_states_cov_diag"], (t1, state_dim), max_elems - ) - - if add_mean: - numpyro.deterministic(f"{name}_smoothed_states_mean", mean) - if add_chol_cov: - numpyro.deterministic(f"{name}_smoothed_states_chol_cov", chol_cov) - - if add_smoothed_states_cov or add_smoothed_states_cov_diag: - smoothed_cov = covariance_from_cholesky(chol_cov) - - if add_smoothed_states_cov: - numpyro.deterministic(f"{name}_smoothed_states_cov", smoothed_cov) - if add_smoothed_states_cov_diag: - diag_cov = jnp.diagonal(smoothed_cov, axis1=1, axis2=2) - numpyro.deterministic(f"{name}_smoothed_states_cov_diag", diag_cov) - - def run_discrete_smoother( name: str, dynamics: DynamicalModel, @@ -351,11 +257,15 @@ def run_discrete_smoother( ctrl_times=None, ctrl_values=None, **kwargs, -) -> list[dist.Distribution]: - """Run discrete-time smoother via cuthbert.""" +) -> tuple[jax.Array, object, list[dist.Distribution]]: + """Run discrete-time smoother via cuthbert. + + Returns: + tuple: (marginal_loglik, raw_states, smoothed_dists). + """ t1 = int(obs_values.shape[0]) if t1 == 0: - return [] + return jnp.array(0.0), None, [] marginal_loglik, states = compute_cuthbert_smoother( dynamics, @@ -366,19 +276,11 @@ def run_discrete_smoother( ctrl_times=ctrl_times, ctrl_values=ctrl_values, ) - record_kwargs = _config_to_smoother_record_kwargs(smoother_config) - - numpyro.factor(f"{name}_marginal_log_likelihood", marginal_loglik) - numpyro.deterministic(f"{name}_marginal_loglik", marginal_loglik) - - if isinstance(smoother_config, PFSmootherConfig): - _add_sites_pf(name, states, record_kwargs) - else: - _add_sites_taylor_kf(name, states, record_kwargs) - return _cholesky_state_sequence_to_dists( + smoothed_dists = _cholesky_state_sequence_to_dists( states, particle_mode=isinstance(smoother_config, PFSmootherConfig), ) + return marginal_loglik, states, smoothed_dists __all__ = ["compute_cuthbert_smoother", "run_discrete_smoother"] diff --git a/dynestyx/inference/numpyro_sites.py b/dynestyx/inference/numpyro_sites.py new file mode 100644 index 00000000..0e0e92b7 --- /dev/null +++ b/dynestyx/inference/numpyro_sites.py @@ -0,0 +1,361 @@ +"""NumPyro site registration for filter and smoother outputs. + +All numpyro.factor and numpyro.deterministic calls live here, +keeping the integration backends (cuthbert, cd-dynamax) numpyro-free. +""" + +import jax +import jax.numpy as jnp +import numpyro + +from dynestyx.inference.filter_configs import ( + BaseFilterConfig, + ContinuousTimeConfigs, + HMMConfig, + PFConfig, + _config_to_record_kwargs, +) +from dynestyx.inference.integrations.utils import covariance_from_cholesky +from dynestyx.inference.smoother_configs import ( + BaseSmootherConfig, + PFSmootherConfig, + _config_to_smoother_record_kwargs, +) +from dynestyx.utils import _should_record_field + + +def register_filter_sites( + name: str, + marginal_loglik: jax.Array, + states: object, + filter_config: BaseFilterConfig, +) -> None: + """Register numpyro.factor and deterministic sites for a filter run.""" + numpyro.factor(f"{name}_marginal_log_likelihood", marginal_loglik) + numpyro.deterministic(f"{name}_marginal_loglik", marginal_loglik) + + if isinstance(filter_config, HMMConfig): + return + + record_kwargs = _config_to_record_kwargs(filter_config) + + if isinstance(filter_config, tuple(ContinuousTimeConfigs)): + _add_continuous_filter_sites(name, states, record_kwargs) + elif isinstance(filter_config, PFConfig): + _add_cuthbert_pf_sites(name, states, record_kwargs) + else: + _add_gaussian_filter_sites(name, states, filter_config, record_kwargs) + + +def register_hmm_filter_sites( + name: str, + loglik: jax.Array, + log_filt_seq: jax.Array, + filter_config: HMMConfig, +) -> None: + """Register numpyro sites for HMM filter output.""" + numpyro.factor(f"{name}_marginal_log_likelihood", loglik) + numpyro.deterministic(f"{name}_marginal_loglik", loglik) + + record_max_elems = filter_config.record_max_elems + + if _should_record_field( + filter_config.record_log_filtered, log_filt_seq.shape, record_max_elems + ): + numpyro.deterministic(f"{name}_log_filtered_states", log_filt_seq) + + if _should_record_field( + filter_config.record_filtered, log_filt_seq.shape, record_max_elems + ): + numpyro.deterministic(f"{name}_filtered_states", jnp.exp(log_filt_seq)) + + +def register_smoother_sites( + name: str, + marginal_loglik: jax.Array, + states: object, + smoother_config: BaseSmootherConfig, +) -> None: + """Register numpyro.factor and deterministic sites for a smoother run.""" + numpyro.factor(f"{name}_marginal_log_likelihood", marginal_loglik) + numpyro.deterministic(f"{name}_marginal_loglik", marginal_loglik) + + record_kwargs = _config_to_smoother_record_kwargs(smoother_config) + + if isinstance(smoother_config, PFSmootherConfig): + _add_cuthbert_pf_smoother_sites(name, states, record_kwargs) + elif hasattr(states, "smoothed_means"): + _add_cd_dynamax_smoother_sites(name, states, record_kwargs) + elif states is not None and hasattr(states, "mean"): + _add_cuthbert_gaussian_smoother_sites(name, states, record_kwargs) + + +def _add_continuous_filter_sites(name: str, filtered, record_kwargs: dict) -> None: + max_elems = record_kwargs["record_max_elems"] + means_shape = filtered.filtered_means.shape + cov_shape = filtered.filtered_covariances.shape + add_mean = _should_record_field( + record_kwargs["record_filtered_states_mean"], means_shape, max_elems + ) + add_cov = _should_record_field( + record_kwargs["record_filtered_states_cov"], cov_shape, max_elems + ) + add_cov_diag = _should_record_field( + record_kwargs["record_filtered_states_cov_diag"], + (cov_shape[0], cov_shape[1]), + max_elems, + ) + if add_mean: + numpyro.deterministic(f"{name}_filtered_states_mean", filtered.filtered_means) + if add_cov: + numpyro.deterministic( + f"{name}_filtered_states_cov", filtered.filtered_covariances + ) + if add_cov_diag: + diag_cov = jnp.diagonal(filtered.filtered_covariances, axis1=1, axis2=2) + numpyro.deterministic(f"{name}_filtered_states_cov_diag", diag_cov) + + +def _add_cuthbert_pf_sites(name: str, states, record_kwargs: dict) -> None: + log_weights = states.log_weights + particles = states.particles + if particles.ndim == 2: + particles = particles[..., None] + max_elems = record_kwargs["record_max_elems"] + t_len, n_particles, state_dim = particles.shape + + add_particles = _should_record_field( + record_kwargs["record_filtered_particles"], particles.shape, max_elems + ) + add_log_weights = _should_record_field( + record_kwargs["record_filtered_log_weights"], log_weights.shape, max_elems + ) + add_mean = _should_record_field( + record_kwargs["record_filtered_states_mean"], (t_len, state_dim), max_elems + ) + add_filtered_states_cov = _should_record_field( + record_kwargs["record_filtered_states_cov"], + (t_len, state_dim, state_dim), + max_elems, + ) + add_filtered_states_cov_diag = _should_record_field( + record_kwargs["record_filtered_states_cov_diag"], (t_len, state_dim), max_elems + ) + + need_filtered_means = ( + add_mean or add_filtered_states_cov or add_filtered_states_cov_diag + ) + + if need_filtered_means: + w = jax.nn.softmax(log_weights, axis=1)[..., None] + filtered_means = jnp.sum(particles * w, axis=1) + + if add_filtered_states_cov or add_filtered_states_cov_diag: + second_mom = jnp.einsum( + "...tnj,...tnk,...tn->...tjk", particles, particles, w.squeeze(-1) + ) + filtered_covariances = second_mom - jnp.einsum( + "...tj,...tk->...tjk", filtered_means, filtered_means + ) + + if add_particles: + numpyro.deterministic(f"{name}_filtered_particles", particles) + if add_log_weights: + numpyro.deterministic(f"{name}_filtered_log_weights", log_weights) + if add_mean: + numpyro.deterministic(f"{name}_filtered_states_mean", filtered_means) + if add_filtered_states_cov: + numpyro.deterministic(f"{name}_filtered_states_cov", filtered_covariances) + if add_filtered_states_cov_diag: + diag_cov = jnp.diagonal(filtered_covariances, axis1=1, axis2=2) + numpyro.deterministic(f"{name}_filtered_states_cov_diag", diag_cov) + + +def _add_gaussian_filter_sites( + name: str, states, filter_config: BaseFilterConfig, record_kwargs: dict +) -> None: + """Sites for cuthbert Gaussian filters (KF, EKF, EnKF) or cd-dynamax discrete.""" + max_elems = record_kwargs["record_max_elems"] + + if hasattr(states, "filtered_means"): + means = states.filtered_means + covs = states.filtered_covariances + if means is None: + return + t_len, state_dim = means.shape + add_mean = _should_record_field( + record_kwargs["record_filtered_states_mean"], means.shape, max_elems + ) + add_cov = _should_record_field( + record_kwargs["record_filtered_states_cov"], + (t_len, state_dim, state_dim), + max_elems, + ) + add_cov_diag = _should_record_field( + record_kwargs["record_filtered_states_cov_diag"], + (t_len, state_dim), + max_elems, + ) + if add_mean: + numpyro.deterministic(f"{name}_filtered_states_mean", means) + if add_cov and covs is not None: + numpyro.deterministic(f"{name}_filtered_states_cov", covs) + if add_cov_diag and covs is not None: + diag_cov = jnp.diagonal(covs, axis1=1, axis2=2) + numpyro.deterministic(f"{name}_filtered_states_cov_diag", diag_cov) + elif hasattr(states, "mean"): + mean = states.mean + chol_cov = states.chol_cov + t_len, state_dim, _ = chol_cov.shape + + add_mean = _should_record_field( + record_kwargs["record_filtered_states_mean"], mean.shape, max_elems + ) + add_chol_cov = _should_record_field( + record_kwargs["record_filtered_states_chol_cov"], + chol_cov.shape, + max_elems, + ) + add_filtered_states_cov = _should_record_field( + record_kwargs["record_filtered_states_cov"], + (t_len, state_dim, state_dim), + max_elems, + ) + add_filtered_states_cov_diag = _should_record_field( + record_kwargs["record_filtered_states_cov_diag"], + (t_len, state_dim), + max_elems, + ) + + if add_mean: + numpyro.deterministic(f"{name}_filtered_states_mean", mean) + if add_chol_cov: + numpyro.deterministic(f"{name}_filtered_states_chol_cov", chol_cov) + + if add_filtered_states_cov or add_filtered_states_cov_diag: + filtered_cov = covariance_from_cholesky(chol_cov) + + if add_filtered_states_cov: + numpyro.deterministic(f"{name}_filtered_states_cov", filtered_cov) + if add_filtered_states_cov_diag: + diag_cov = jnp.diagonal(filtered_cov, axis1=1, axis2=2) + numpyro.deterministic(f"{name}_filtered_states_cov_diag", diag_cov) + + +def _add_cuthbert_pf_smoother_sites(name: str, states, record_kwargs: dict) -> None: + log_weights = states.log_weights + particles = states.particles + if particles.ndim == 2: + particles = particles[..., None] + max_elems = record_kwargs["record_max_elems"] + t1, _, state_dim = particles.shape + + add_particles = _should_record_field( + record_kwargs["record_smoothed_particles"], particles.shape, max_elems + ) + add_log_weights = _should_record_field( + record_kwargs["record_smoothed_log_weights"], log_weights.shape, max_elems + ) + add_mean = _should_record_field( + record_kwargs["record_smoothed_states_mean"], (t1, state_dim), max_elems + ) + add_smoothed_states_cov = _should_record_field( + record_kwargs["record_smoothed_states_cov"], + (t1, state_dim, state_dim), + max_elems, + ) + add_smoothed_states_cov_diag = _should_record_field( + record_kwargs["record_smoothed_states_cov_diag"], (t1, state_dim), max_elems + ) + + need_means = add_mean or add_smoothed_states_cov or add_smoothed_states_cov_diag + if need_means: + w = jax.nn.softmax(log_weights, axis=1)[..., None] + smoothed_means = jnp.sum(particles * w, axis=1) + + if add_smoothed_states_cov or add_smoothed_states_cov_diag: + second_mom = jnp.einsum( + "...tnj,...tnk,...tn->...tjk", particles, particles, w.squeeze(-1) + ) + smoothed_cov = second_mom - jnp.einsum( + "...tj,...tk->...tjk", smoothed_means, smoothed_means + ) + + if add_particles: + numpyro.deterministic(f"{name}_smoothed_particles", particles) + if add_log_weights: + numpyro.deterministic(f"{name}_smoothed_log_weights", log_weights) + if add_mean: + numpyro.deterministic(f"{name}_smoothed_states_mean", smoothed_means) + if add_smoothed_states_cov: + numpyro.deterministic(f"{name}_smoothed_states_cov", smoothed_cov) + if add_smoothed_states_cov_diag: + diag_cov = jnp.diagonal(smoothed_cov, axis1=1, axis2=2) + numpyro.deterministic(f"{name}_smoothed_states_cov_diag", diag_cov) + + +def _add_cd_dynamax_smoother_sites(name: str, posterior, record_kwargs: dict) -> None: + max_elems = record_kwargs["record_max_elems"] + means = posterior.smoothed_means + covs = posterior.smoothed_covariances + if means is None or covs is None: + return + t1, state_dim = means.shape + add_mean = _should_record_field( + record_kwargs["record_smoothed_states_mean"], means.shape, max_elems + ) + add_cov = _should_record_field( + record_kwargs["record_smoothed_states_cov"], + (t1, state_dim, state_dim), + max_elems, + ) + add_cov_diag = _should_record_field( + record_kwargs["record_smoothed_states_cov_diag"], (t1, state_dim), max_elems + ) + if add_mean: + numpyro.deterministic(f"{name}_smoothed_states_mean", means) + if add_cov: + numpyro.deterministic(f"{name}_smoothed_states_cov", covs) + if add_cov_diag: + diag_cov = jnp.diagonal(covs, axis1=1, axis2=2) + numpyro.deterministic(f"{name}_smoothed_states_cov_diag", diag_cov) + + +def _add_cuthbert_gaussian_smoother_sites( + name: str, states, record_kwargs: dict +) -> None: + max_elems = record_kwargs["record_max_elems"] + mean = states.mean + chol_cov = states.chol_cov + t1, state_dim, _ = chol_cov.shape + + add_mean = _should_record_field( + record_kwargs["record_smoothed_states_mean"], mean.shape, max_elems + ) + add_chol_cov = _should_record_field( + record_kwargs["record_smoothed_states_chol_cov"], + chol_cov.shape, + max_elems, + ) + add_smoothed_states_cov = _should_record_field( + record_kwargs["record_smoothed_states_cov"], + (t1, state_dim, state_dim), + max_elems, + ) + add_smoothed_states_cov_diag = _should_record_field( + record_kwargs["record_smoothed_states_cov_diag"], (t1, state_dim), max_elems + ) + + if add_mean: + numpyro.deterministic(f"{name}_smoothed_states_mean", mean) + if add_chol_cov: + numpyro.deterministic(f"{name}_smoothed_states_chol_cov", chol_cov) + + if add_smoothed_states_cov or add_smoothed_states_cov_diag: + smoothed_cov = covariance_from_cholesky(chol_cov) + + if add_smoothed_states_cov: + numpyro.deterministic(f"{name}_smoothed_states_cov", smoothed_cov) + if add_smoothed_states_cov_diag: + diag_cov = jnp.diagonal(smoothed_cov, axis1=1, axis2=2) + numpyro.deterministic(f"{name}_smoothed_states_cov_diag", diag_cov) diff --git a/dynestyx/inference/smoothers.py b/dynestyx/inference/smoothers.py index e0658b4e..0c4bc667 100644 --- a/dynestyx/inference/smoothers.py +++ b/dynestyx/inference/smoothers.py @@ -1,5 +1,6 @@ import dataclasses import math +from abc import ABC, abstractmethod from typing import cast import equinox as eqx @@ -11,7 +12,7 @@ from effectful.ops.syntax import ObjectInterpretation, implements from jaxtyping import Array, PRNGKeyArray, Real -from dynestyx.handlers import HandlesSelf, _sample_intp +from dynestyx.handlers import HandlesSelf, _infer_intp from dynestyx.inference.checkers import ( _validate_batched_plate_alignment, _validate_missing_observation_support, @@ -37,6 +38,7 @@ from dynestyx.inference.integrations.cuthbert.discrete_smoother import ( run_discrete_smoother as run_cuthbert_discrete_smoother, ) +from dynestyx.inference.numpyro_sites import register_smoother_sites from dynestyx.inference.plate_utils import _array_plate_axis, _make_plate_in_axes from dynestyx.inference.smoother_configs import ( BaseSmootherConfig, @@ -50,7 +52,7 @@ UKFSmootherConfig, ) from dynestyx.models import DynamicalModel -from dynestyx.types import FunctionOfTime +from dynestyx.types import FunctionOfTime, InferResult DiscreteSmootherConfig = ( KFSmootherConfig | EKFSmootherConfig | UKFSmootherConfig | PFSmootherConfig @@ -101,10 +103,10 @@ def _final_obs_times_for_rollout( return obs_times[..., -1:] -class BaseSmootherLogFactorAdder(ObjectInterpretation, HandlesSelf): +class BaseSmootherLogFactorAdder(ObjectInterpretation, HandlesSelf, ABC): """Base class for smoother handlers.""" - @implements(_sample_intp) + @implements(_infer_intp) def _sample_ds( self, name: str, @@ -148,7 +150,8 @@ def _sample_ds( smoothed_times = None smoothed_dists = None - return fwd( + # fwd() lets handlers above (e.g. Simulator) use smoothed_dists for rollout. + fwd( name, dynamics, plate_shapes=plate_shapes, @@ -165,6 +168,9 @@ def _sample_ds( **kwargs, ) + return self._build_infer_result(name, smoothed_dists) + + @abstractmethod def _add_log_factors( self, name: str, @@ -180,8 +186,12 @@ def _add_log_factors( | Real[Array, "*ctrl_value_plate ctrl_time"] | None = None, **kwargs, - ) -> list[numpyro.distributions.Distribution] | None: - raise NotImplementedError() + ) -> list[numpyro.distributions.Distribution] | None: ... + + @abstractmethod + def _build_infer_result( + self, name: str, smoothed_dists: list | None + ) -> InferResult: ... @dataclasses.dataclass @@ -189,6 +199,13 @@ class Smoother(BaseSmootherLogFactorAdder): r"""Performs Bayesian smoothing to compute the smoothing distribution p(x_t | y_{1:T}).""" smoother_config: SmootherAnyConfig | None = None + marginal_loglik: jax.Array | None = dataclasses.field( + default=None, repr=False, init=False + ) + smoothed_states: object = dataclasses.field(default=None, repr=False, init=False) + _smoother_config_used: BaseSmootherConfig | None = dataclasses.field( + default=None, repr=False, init=False + ) def _add_log_factors( self, @@ -229,12 +246,20 @@ def _add_log_factors( mode="smoother", ) + # Resolve PRNG key: use explicit seed from config, fall back to numpyro + # context (inside a seeded model), or None (deterministic smoothers don't need one). typed_config = config - key = ( - numpyro.prng_key() - if typed_config.crn_seed is None - else typed_config.crn_seed - ) + if typed_config.crn_seed is not None: + key = typed_config.crn_seed + else: + import warnings # noqa: PLC0415 + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + key = numpyro.prng_key() + except Exception: + key = None if plate_shapes: return self._add_log_factors_batched( @@ -257,7 +282,7 @@ def _add_log_factors( f"Valid continuous-time config types: {valid}" ) continuous_config = cast(ContinuousSmootherConfig, typed_config) - return _smooth_continuous_time( + marginal_loglik, states, smoothed_dists = _smooth_continuous_time( name, dynamics, continuous_config, @@ -268,25 +293,58 @@ def _add_log_factors( ctrl_values=ctrl_values, **kwargs, ) - - if not isinstance(typed_config, DiscreteTimeSmootherConfigs): + elif not isinstance(typed_config, DiscreteTimeSmootherConfigs): valid = _valid_smoother_config_names(continuous_time=False) raise ValueError( f"Invalid smoother config: {type(typed_config).__name__}. " f"Valid discrete-time config types: {valid}" ) - discrete_config = cast(DiscreteSmootherConfig, typed_config) + else: + discrete_config = cast(DiscreteSmootherConfig, typed_config) + marginal_loglik, states, smoothed_dists = _smooth_discrete_time( + name, + dynamics, + discrete_config, + key=key, + obs_times=obs_times, + obs_values=obs_values, + ctrl_times=ctrl_times, + ctrl_values=ctrl_values, + **kwargs, + ) - return _smooth_discrete_time( - name, - dynamics, - discrete_config, - key=key, - obs_times=obs_times, - obs_values=obs_values, - ctrl_times=ctrl_times, - ctrl_values=ctrl_values, - **kwargs, + self.marginal_loglik = marginal_loglik + self.smoothed_states = states + self._smoother_config_used = typed_config + + return smoothed_dists + + def _build_infer_result( + self, name: str, smoothed_dists: list | None + ) -> InferResult: + """Construct InferResult with a deferred numpyro registration callback.""" + marginal_loglik = self.marginal_loglik + states = self.smoothed_states + config = self._smoother_config_used + _is_batched = ( + isinstance(marginal_loglik, jax.Array) and marginal_loglik.ndim > 0 + ) + + def _register(site_name: str) -> None: + if marginal_loglik is None or config is None: + return + if _is_batched: + # TODO: support per-field recording for batched (plate) states + numpyro.factor(f"{site_name}_marginal_log_likelihood", marginal_loglik) + numpyro.deterministic(f"{site_name}_marginal_loglik", marginal_loglik) + else: + register_smoother_sites(site_name, marginal_loglik, states, config) + + return InferResult( + marginal_loglik=marginal_loglik, + states=states, + dists=smoothed_dists, + _register_numpyro_sites=_register, ) def _add_log_factors_batched( @@ -419,8 +477,9 @@ def compute_output(dyn, ot, ov, ct, cv, k): else: raise ValueError(f"Unsupported batched output kind: {output_kind}") - numpyro.factor(f"{name}_marginal_log_likelihood", marginal_logliks) - numpyro.deterministic(f"{name}_marginal_loglik", marginal_logliks) + self.marginal_loglik = marginal_logliks + self.smoothed_states = outputs + self._smoother_config_used = config if output_kind == "continuous": return _posterior_sequence_to_dists( @@ -468,7 +527,7 @@ def _smooth_discrete_time( | Real[Array, "*ctrl_value_plate ctrl_time"] | None = None, **kwargs, -) -> list[numpyro.distributions.Distribution]: +) -> tuple[jax.Array, object, list[numpyro.distributions.Distribution]]: """Discrete-time marginal likelihood via cuthbert or cd-dynamax smoothers.""" if isinstance(smoother_config, UKFSmootherConfig) and ( @@ -488,7 +547,7 @@ def _smooth_discrete_time( ) if smoother_config.filter_source == "cd_dynamax": - return run_cd_dynamax_discrete_smoother( + marginal_loglik, states, smoothed_dists = run_cd_dynamax_discrete_smoother( name, dynamics, smoother_config, @@ -498,15 +557,14 @@ def _smooth_discrete_time( ctrl_values=ctrl_values, **kwargs, ) - - if smoother_config.filter_source == "cuthbert": + elif smoother_config.filter_source == "cuthbert": if isinstance(smoother_config, UKFSmootherConfig): raise ValueError( "UKF smoothing is not available in cuthbert. " "Use UKFSmootherConfig(filter_source='cd_dynamax') or a cuthbert-supported smoother " "(KFSmootherConfig, EKFSmootherConfig, PFSmootherConfig)." ) - return run_cuthbert_discrete_smoother( + marginal_loglik, states, smoothed_dists = run_cuthbert_discrete_smoother( name, dynamics, smoother_config, @@ -517,8 +575,10 @@ def _smooth_discrete_time( ctrl_values=ctrl_values, **kwargs, ) + else: + raise ValueError(f"Unknown filter source: {smoother_config.filter_source}") - raise ValueError(f"Unknown filter source: {smoother_config.filter_source}") + return marginal_loglik, states, smoothed_dists def _smooth_continuous_time( @@ -535,14 +595,14 @@ def _smooth_continuous_time( | Real[Array, "*ctrl_value_plate ctrl_time"] | None = None, **kwargs, -) -> list[numpyro.distributions.Distribution]: +) -> tuple[jax.Array, object, list[numpyro.distributions.Distribution]]: """Continuous-time marginal likelihood via CD-Dynamax smoothers.""" if smoother_config.filter_source != "cd_dynamax": raise ValueError( f"{type(smoother_config).__name__} supports only filter_source='cd_dynamax'." ) - return run_continuous_smoother( + marginal_loglik, smoothed, smoothed_dists = run_continuous_smoother( name, dynamics, smoother_config, @@ -553,6 +613,7 @@ def _smooth_continuous_time( ctrl_values=ctrl_values, **kwargs, ) + return marginal_loglik, smoothed, smoothed_dists __all__ = [ diff --git a/dynestyx/simulators.py b/dynestyx/simulators.py index f3673843..1bd23a45 100644 --- a/dynestyx/simulators.py +++ b/dynestyx/simulators.py @@ -19,7 +19,7 @@ from jaxtyping import Real from numpyro.contrib.control_flow import scan as nscan -from dynestyx.handlers import HandlesSelf, _sample_intp +from dynestyx.handlers import HandlesSelf, _infer_intp from dynestyx.inference.integrations.utils import WeightedParticles from dynestyx.models import ( DeterministicContinuousTimeStateEvolution, @@ -553,7 +553,7 @@ def _run_plated_simulation( stacked[key] = flat.reshape(*plate_shapes, *values[0].shape) return stacked - @implements(_sample_intp) + @implements(_infer_intp) def _sample_ds( self, name: str, diff --git a/dynestyx/types.py b/dynestyx/types.py index 9e5f445b..037eefe2 100644 --- a/dynestyx/types.py +++ b/dynestyx/types.py @@ -1,7 +1,10 @@ """Shared typing helpers for dynamical systems.""" +import dataclasses +from collections.abc import Callable from typing import Protocol, runtime_checkable +import jax import jax.numpy as jnp from jaxtyping import Array, Real @@ -14,6 +17,30 @@ def __call__( raise NotImplementedError() +@dataclasses.dataclass +class InferResult: + """Result of dsx.infer — the numpyro-free inference primitive. + + Carries all outputs from the handler stack (Filter, Smoother, etc.) + without registering any numpyro sites. + """ + + marginal_loglik: jax.Array | None = None + states: object = None + dists: list | None = None + _register_numpyro_sites: Callable[[str], None] | None = dataclasses.field( + default=None, repr=False + ) + + def __call__( + self, t: float | int | Real[Array, ""] + ) -> Real[Array, " state_dim"] | Real[Array, ""]: + raise NotImplementedError( + "InferResult is not callable as a FunctionOfTime. " + "Access .marginal_loglik, .states, or .dists instead." + ) + + def as_scalar_time_array( value: float | int | Array, *, name: str, dtype=None ) -> Real[Array, ""]: diff --git a/tests/test_filter_standalone.py b/tests/test_filter_standalone.py new file mode 100644 index 00000000..f4ba6c88 --- /dev/null +++ b/tests/test_filter_standalone.py @@ -0,0 +1,125 @@ +"""Tests for dsx.infer (standalone, no numpyro) and dsx.sample (numpyro model).""" + +import jax +import jax.numpy as jnp +import jax.random as jr +import optax + +import dynestyx as dsx +from dynestyx.inference.filter_configs import EnKFConfig, KFConfig +from dynestyx.inference.filters import Filter +from dynestyx.types import InferResult + + +def _make_lti_dynamics(alpha): + return dsx.LTI_discrete( + A=jnp.array([[alpha, 0.1], [0.1, 0.8]]), + Q=0.1 * jnp.eye(2), + H=jnp.array([[1.0, 0.0]]), + R=jnp.array([[0.25]]), + ) + + +def _make_data(): + obs_times = jnp.arange(0.0, 10.0, 1.0) + key = jr.PRNGKey(42) + obs_values = jr.normal(key, (len(obs_times), 1)) + return obs_times, obs_values + + +# --- dsx.infer tests (standalone, no numpyro) --- + + +def test_infer_returns_infer_result(): + """dsx.infer returns an InferResult with marginal_loglik.""" + obs_times, obs_values = _make_data() + dynamics = _make_lti_dynamics(0.5) + + with Filter(filter_config=KFConfig(filter_source="cuthbert")): + result = dsx.infer("f", dynamics, obs_times=obs_times, obs_values=obs_values) + + assert isinstance(result, InferResult) + assert result.marginal_loglik is not None + assert jnp.isfinite(result.marginal_loglik) + assert result.states is not None + assert result.dists is not None + + +def test_infer_enkf_with_crn_seed(): + """dsx.infer works with EnKF and explicit crn_seed.""" + obs_times, obs_values = _make_data() + dynamics = _make_lti_dynamics(0.5) + + with Filter( + filter_config=EnKFConfig(n_particles=16, crn_seed=jr.PRNGKey(0)), + ): + result = dsx.infer("f", dynamics, obs_times=obs_times, obs_values=obs_values) + + assert isinstance(result, InferResult) + assert result.marginal_loglik is not None + assert jnp.isfinite(result.marginal_loglik) + + +def test_infer_optax_mle(): + """Use dsx.infer + optax to do MLE without numpyro.""" + obs_times, obs_values = _make_data() + + def neg_loglik(alpha): + dynamics = _make_lti_dynamics(alpha) + with Filter(filter_config=KFConfig(filter_source="cuthbert")): + result = dsx.infer( + "f", dynamics, obs_times=obs_times, obs_values=obs_values + ) + return -result.marginal_loglik + + optimizer = optax.adam(1e-2) + alpha = jnp.array(0.3) + opt_state = optimizer.init(alpha) + + initial_loss = neg_loglik(alpha) + grad_fn = jax.grad(neg_loglik) + + for _ in range(20): + grads = grad_fn(alpha) + updates, opt_state = optimizer.update(grads, opt_state) + alpha = optax.apply_updates(alpha, updates) + + final_loss = neg_loglik(alpha) + assert final_loss < initial_loss + + +def test_infer_does_not_register_numpyro_sites(): + """dsx.infer does NOT register numpyro sites — that's dsx.sample's job.""" + from numpyro.handlers import seed, trace + + obs_times, obs_values = _make_data() + dynamics = _make_lti_dynamics(0.5) + + with trace() as tr, seed(rng_seed=jr.PRNGKey(0)): + with Filter(filter_config=KFConfig(filter_source="cuthbert")): + result = dsx.infer( + "f", dynamics, obs_times=obs_times, obs_values=obs_values + ) + + assert isinstance(result, InferResult) + assert result.marginal_loglik is not None + assert "f_marginal_loglik" not in tr + assert "f_marginal_log_likelihood" not in tr + + +# --- dsx.sample tests (numpyro model) --- + + +def test_sample_registers_numpyro_sites(): + """dsx.sample registers numpyro sites via the callback.""" + from numpyro.handlers import seed, trace + + obs_times, obs_values = _make_data() + dynamics = _make_lti_dynamics(0.5) + + with trace() as tr, seed(rng_seed=jr.PRNGKey(0)): + with Filter(filter_config=KFConfig(filter_source="cuthbert")): + dsx.sample("f", dynamics, obs_times=obs_times, obs_values=obs_values) + + assert "f_marginal_loglik" in tr + assert jnp.isfinite(tr["f_marginal_loglik"]["value"]) diff --git a/tests/test_filters.py b/tests/test_filters.py index 048aa0f5..77638ead 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -209,15 +209,14 @@ def test_cuthbert_filtered_distribution_shapes_match_observations(filter_config) obs_times, obs_values = _make_discrete_lti_data() dynamics = _make_discrete_lti_dynamics() - with trace(), seed(rng_seed=jr.PRNGKey(1)): - filtered_dists = run_cuthbert_discrete_filter( - "f", - dynamics, - filter_config, - key=jr.PRNGKey(2), - obs_times=obs_times, - obs_values=obs_values, - ) + _marginal_loglik, _states, filtered_dists = run_cuthbert_discrete_filter( + "f", + dynamics, + filter_config, + key=jr.PRNGKey(2), + obs_times=obs_times, + obs_values=obs_values, + ) assert len(filtered_dists) == len(obs_times) for filtered_dist in filtered_dists: @@ -293,15 +292,14 @@ def test_cuthbert_enkf_accepts_callable_independent_normal_observation(): ), ) - with trace(), seed(rng_seed=jr.PRNGKey(1)): - filtered_dists = run_cuthbert_discrete_filter( - "f", - dynamics, - EnKFConfig(n_particles=16, filter_source="cuthbert"), - key=jr.PRNGKey(2), - obs_times=obs_times, - obs_values=obs_values, - ) + _marginal_loglik, _states, filtered_dists = run_cuthbert_discrete_filter( + "f", + dynamics, + EnKFConfig(n_particles=16, filter_source="cuthbert"), + key=jr.PRNGKey(2), + obs_times=obs_times, + obs_values=obs_values, + ) assert len(filtered_dists) == len(obs_times) assert all(d.event_shape == (dynamics.state_dim,) for d in filtered_dists) diff --git a/tests/test_smoother_standalone.py b/tests/test_smoother_standalone.py new file mode 100644 index 00000000..64b62021 --- /dev/null +++ b/tests/test_smoother_standalone.py @@ -0,0 +1,110 @@ +"""Tests for dsx.infer with Smoother (standalone, no numpyro).""" + +import jax +import jax.numpy as jnp +import jax.random as jr +import optax + +import dynestyx as dsx +from dynestyx.inference.smoother_configs import KFSmootherConfig +from dynestyx.inference.smoothers import Smoother +from dynestyx.types import InferResult + + +def _make_lti_dynamics(alpha): + return dsx.LTI_discrete( + A=jnp.array([[alpha, 0.1], [0.1, 0.8]]), + Q=0.1 * jnp.eye(2), + H=jnp.array([[1.0, 0.0]]), + R=jnp.array([[0.25]]), + ) + + +def _make_data(): + obs_times = jnp.arange(0.0, 10.0, 1.0) + key = jr.PRNGKey(42) + obs_values = jr.normal(key, (len(obs_times), 1)) + return obs_times, obs_values + + +# --- dsx.infer tests (standalone, no numpyro) --- + + +def test_infer_smoother_returns_infer_result(): + """dsx.infer with Smoother returns an InferResult.""" + obs_times, obs_values = _make_data() + dynamics = _make_lti_dynamics(0.5) + + with Smoother(smoother_config=KFSmootherConfig(filter_source="cuthbert")): + result = dsx.infer("f", dynamics, obs_times=obs_times, obs_values=obs_values) + + assert isinstance(result, InferResult) + assert result.marginal_loglik is not None + assert jnp.isfinite(result.marginal_loglik) + assert result.states is not None + assert result.dists is not None + + +def test_infer_smoother_optax_mle(): + """Use dsx.infer + Smoother + optax to do MLE without numpyro.""" + obs_times, obs_values = _make_data() + + def neg_loglik(alpha): + dynamics = _make_lti_dynamics(alpha) + with Smoother(smoother_config=KFSmootherConfig(filter_source="cuthbert")): + result = dsx.infer( + "f", dynamics, obs_times=obs_times, obs_values=obs_values + ) + return -result.marginal_loglik + + optimizer = optax.adam(1e-2) + alpha = jnp.array(0.3) + opt_state = optimizer.init(alpha) + + initial_loss = neg_loglik(alpha) + grad_fn = jax.grad(neg_loglik) + + for _ in range(20): + grads = grad_fn(alpha) + updates, opt_state = optimizer.update(grads, opt_state) + alpha = optax.apply_updates(alpha, updates) + + final_loss = neg_loglik(alpha) + assert final_loss < initial_loss + + +def test_infer_smoother_does_not_register_numpyro_sites(): + """dsx.infer with Smoother does NOT register numpyro sites.""" + from numpyro.handlers import seed, trace + + obs_times, obs_values = _make_data() + dynamics = _make_lti_dynamics(0.5) + + with trace() as tr, seed(rng_seed=jr.PRNGKey(0)): + with Smoother(smoother_config=KFSmootherConfig(filter_source="cuthbert")): + result = dsx.infer( + "f", dynamics, obs_times=obs_times, obs_values=obs_values + ) + + assert isinstance(result, InferResult) + assert result.marginal_loglik is not None + assert "f_marginal_loglik" not in tr + assert "f_marginal_log_likelihood" not in tr + + +# --- dsx.sample tests (numpyro model) --- + + +def test_sample_smoother_registers_numpyro_sites(): + """dsx.sample with Smoother registers numpyro sites via the callback.""" + from numpyro.handlers import seed, trace + + obs_times, obs_values = _make_data() + dynamics = _make_lti_dynamics(0.5) + + with trace() as tr, seed(rng_seed=jr.PRNGKey(0)): + with Smoother(smoother_config=KFSmootherConfig(filter_source="cuthbert")): + dsx.sample("f", dynamics, obs_times=obs_times, obs_values=obs_values) + + assert "f_marginal_loglik" in tr + assert jnp.isfinite(tr["f_marginal_loglik"]["value"])