Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion dynestyx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
__version__ = version("dynestyx")

from dynestyx.discretizers import Discretizer, euler_maruyama
from dynestyx.handlers import plate, sample
from dynestyx.handlers import infer, plate, sample
from dynestyx.inference.filters import Filter
from dynestyx.inference.smoothers import Smoother
from dynestyx.models import (
Expand Down Expand Up @@ -33,6 +33,7 @@
SDESimulator,
Simulator,
)
from dynestyx.types import InferResult
from dynestyx.utils import flatten_draws

__all__ = [
Expand All @@ -56,6 +57,8 @@
"Filter",
"Smoother",
"flatten_draws",
"infer",
"InferResult",
"plate",
"sample",
"DiracIdentityObservation",
Expand Down
4 changes: 2 additions & 2 deletions dynestyx/discretizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from effectful.ops.syntax import ObjectInterpretation, implements
from jaxtyping import Array, Real

from dynestyx.handlers import HandlesSelf, _sample_intp
from dynestyx.handlers import HandlesSelf, _infer_intp
from dynestyx.models import (
DynamicalModel,
GaussianStateEvolution,
Expand Down Expand Up @@ -152,7 +152,7 @@ def __init__(self, discretize=euler_maruyama):
super().__init__()
self.discretize = discretize

@implements(_sample_intp)
@implements(_infer_intp)
def _sample_ds(
self,
name: str,
Expand Down
145 changes: 107 additions & 38 deletions dynestyx/handlers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Contains the `sample` primitive and `effectful` utilities for `dynestyx`."""
"""Contains the `sample` and `infer` primitives and `effectful` utilities for `dynestyx`."""

from typing import TypeVar

Expand All @@ -11,7 +11,7 @@
from dynestyx.models import (
DynamicalModel,
)
from dynestyx.types import FunctionOfTime
from dynestyx.types import FunctionOfTime, InferResult
from dynestyx.utils import (
_get_dynamics_with_t0,
_validate_control_dim,
Expand All @@ -22,7 +22,7 @@
T = TypeVar("T")


def sample(
def _validate_and_prepare(
name: str,
dynamics: DynamicalModel,
*,
Expand All @@ -35,36 +35,8 @@ def sample(
| Real[Array, "*ctrl_value_plate ctrl_time"]
| None = None,
predict_times: Real[Array, "*predict_time_plate predict_time"] | None = None,
**kwargs,
) -> FunctionOfTime:
"""
Samples from a dynamical model. This is the main primitive of dynestyx.

The `sample` primitive is meant to mimic the `numpyro.sample` primitive in usage,
but using a `DynamicalModel` instead of a `Distribution`.

The `sample` method calls `_sample_intp`, which is defined as a `defop` in `effectful`.
This is where any real "work" is done, after input validation.

Shape note:
Inside ``dsx.plate``, observation arrays use leading plate axes followed
by time and event axes, e.g. ``(N, T, obs_dim)``. Model parameters follow
the same leading-plate, trailing-event convention. See :class:`plate`
for the full plated-shape contract.

Parameters:
name: Name of the sample site.
dynamics: Dynamical model to sample from.
obs_times: Times at which to sample the observations.
obs_values: Values of the observations at the given times.
ctrl_times: Times at which to sample the controls.
ctrl_values: Values of the controls at the given times.
predict_times: Times at which to predict the observations.
**kwargs: Additional keyword arguments.

Returns:
FunctionOfTime: A function of time that samples from the dynamical model.
"""
) -> DynamicalModel:
"""Validate inputs and return dynamics with t0 resolved."""
# Rule: obs_times must be accompanied with obs_values, which should be the same length.
if obs_times is None and predict_times is None:
raise ValueError("At least one of obs_times or predict_times must be provided")
Expand Down Expand Up @@ -106,10 +78,50 @@ def sample(
_validate_control_dim(dynamics, ctrl_values)

# Initial dynamics may not have t0, which is then inferred from obs_times
dynamics_with_t0 = _get_dynamics_with_t0(dynamics, obs_times, predict_times)
return _get_dynamics_with_t0(dynamics, obs_times, predict_times)


def infer(
name: str,
dynamics: DynamicalModel,
*,
obs_times: Real[Array, "*obs_time_plate obs_time"] | None = None,
obs_values: Real[Array, "*obs_value_plate obs_time observation_dim"]
| Real[Array, "*obs_value_plate obs_time"]
| None = None,
ctrl_times: Real[Array, "*ctrl_time_plate ctrl_time"] | None = None,
ctrl_values: Real[Array, "*ctrl_value_plate ctrl_time control_dim"]
| Real[Array, "*ctrl_value_plate ctrl_time"]
| None = None,
predict_times: Real[Array, "*predict_time_plate predict_time"] | None = None,
**kwargs,
):
"""Run inference on a dynamical model without registering numpyro sites.

This is the numpyro-free entry point. When a Filter or Smoother handler
is active, returns an InferResult dataclass with marginal_loglik, states, etc.

# Pass to interpreted version of `sample` for inference.
return _sample_intp(
Parameters:
name: Name of the inference site.
dynamics: Dynamical model to infer.
obs_times: Times at which observations are available.
obs_values: Values of the observations at the given times.
ctrl_times: Times at which controls are applied.
ctrl_values: Values of the controls at the given times.
predict_times: Times at which to predict.
**kwargs: Additional keyword arguments.
"""
dynamics_with_t0 = _validate_and_prepare(
name,
dynamics,
obs_times=obs_times,
obs_values=obs_values,
ctrl_times=ctrl_times,
ctrl_values=ctrl_values,
predict_times=predict_times,
)

return _infer_intp(
name,
dynamics_with_t0,
obs_times=obs_times,
Expand All @@ -121,8 +133,65 @@ def sample(
)


def sample(
name: str,
dynamics: DynamicalModel,
*,
obs_times: Real[Array, "*obs_time_plate obs_time"] | None = None,
obs_values: Real[Array, "*obs_value_plate obs_time observation_dim"]
| Real[Array, "*obs_value_plate obs_time"]
| None = None,
ctrl_times: Real[Array, "*ctrl_time_plate ctrl_time"] | None = None,
ctrl_values: Real[Array, "*ctrl_value_plate ctrl_time control_dim"]
| Real[Array, "*ctrl_value_plate ctrl_time"]
| None = None,
predict_times: Real[Array, "*predict_time_plate predict_time"] | None = None,
**kwargs,
):
"""
Samples from a dynamical model. This is the main primitive of dynestyx.

The ``sample`` primitive is meant to mimic the ``numpyro.sample`` primitive
in usage, but using a ``DynamicalModel`` instead of a ``Distribution``.

Internally, ``sample`` calls ``dsx.infer(...)`` and then registers the
results as numpyro sites (``numpyro.factor``, ``numpyro.deterministic``).

Shape note:
Inside ``dsx.plate``, observation arrays use leading plate axes followed
by time and event axes, e.g. ``(N, T, obs_dim)``. Model parameters follow
the same leading-plate, trailing-event convention. See :class:`plate`
for the full plated-shape contract.

Parameters:
name: Name of the sample site.
dynamics: Dynamical model to sample from.
obs_times: Times at which to sample the observations.
obs_values: Values of the observations at the given times.
ctrl_times: Times at which to sample the controls.
ctrl_values: Values of the controls at the given times.
predict_times: Times at which to predict the observations.
**kwargs: Additional keyword arguments.
"""
result = infer(
name,
dynamics,
obs_times=obs_times,
obs_values=obs_values,
ctrl_times=ctrl_times,
ctrl_values=ctrl_values,
predict_times=predict_times,
**kwargs,
)

if isinstance(result, InferResult) and result._register_numpyro_sites is not None:
result._register_numpyro_sites(name)

return result


@defop
def _sample_intp(
def _infer_intp(
name: str,
dynamics: DynamicalModel,
*,
Expand Down Expand Up @@ -337,7 +406,7 @@ def __exit__(self, exc_type, exc, tb):
self._cm.__exit__(exc_type, exc, tb)
return self._numpyro_plate.__exit__(exc_type, exc, tb)

@implements(_sample_intp)
@implements(_infer_intp)
def _sample_ds(
self, name, dynamics, *, plate_shapes=(), **kwargs
) -> FunctionOfTime:
Expand Down
Loading
Loading