Skip to content

dsx.infer for filtering and smoothing without numpyro#238

Open
theorashid wants to merge 4 commits into
BasisResearch:dsx-infer-stagingfrom
theorashid:refactor/decouple-numpyro-filters-smoothers
Open

dsx.infer for filtering and smoothing without numpyro#238
theorashid wants to merge 4 commits into
BasisResearch:dsx-infer-stagingfrom
theorashid:refactor/decouple-numpyro-filters-smoothers

Conversation

@theorashid
Copy link
Copy Markdown

This is for #213.

So far it is just for Filter and Smoother. Simulator will have to come.

Basic idea is that Filter/Smoother no longer call numpyro.factor inline, it happens in dsx.sample. test_filters and test_smoothers still pass, happy days.

# non-numpyro path: returns InferResult directly (no side effects)
with Filter(filter_config=KFConfig(...)):
    result = dsx.infer("f", dynamics, obs_times=t, obs_values=y)
loss = -result.marginal_loglik

# numpyro path (unchanged from before):
with Filter(filter_config=KFConfig(...)):
    dsx.sample("f", dynamics, obs_times=t, obs_values=y)

In the background, sample just calls infer and then registers sites. dsx.infer returns InferResult which carries marginal_loglik, states, dists, and a private _register_numpyro_sites callback.

def sample(...):
    result = infer(...)
    if isinstance(result, InferResult) and result._register_numpyro_sites is not None:
        result._register_numpyro_sites(name)
    return result

The main work can be seen in the test_*_standalone.py, which I based from my cuthbert-models repo. This is how I would see it being used.

def test_infer_optax_mle():
    """Use dsx.infer + optax to do MLE without numpyro."""
    obs_times, obs_values = _make_data()

    def neg_loglik(alpha):
        dynamics = _make_lti_dynamics(alpha)
        with Filter(filter_config=KFConfig(filter_source="cuthbert")):
            result = dsx.infer(
                "f", dynamics, obs_times=obs_times, obs_values=obs_values
            )
        return -result.marginal_loglik

    optimizer = optax.adam(1e-2)
    alpha = jnp.array(0.3)
    opt_state = optimizer.init(alpha)

    initial_loss = neg_loglik(alpha)
    grad_fn = jax.grad(neg_loglik)

    for _ in range(20):
        grads = grad_fn(alpha)
        updates, opt_state = optimizer.update(grads, opt_state)
        alpha = optax.apply_updates(alpha, updates)

    final_loss = neg_loglik(alpha)
    assert final_loss < initial_loss

how these changes were made

Of course, this was largely done by burning tokens and handholding so that it matches the design that I (and then with feedback from both of you) wanted. I am not as familiar with the internals of the library, so if there are other places beyond the diff that you think these changes might affect, let me know where to look – I was relying a bit on existing tests not breaking.

smaller design things

I largely tried to keep everything that was there before to reduce the size of the refactor. But:

  • I did make BaseLogFactorAdder (and the Filter equivalent) an ABC because I think it made sense.
  • Plate, as before, does not register per-field sites.
  • Maybe we should rename _sample_intp to _infer_intp.
  • InferResult a __call__ shim to satisfy FunctionOfTime protocol (needed because some model functions return dsx.sample(...)). Effectful uses the @defop return annotation to decide the type of fwd() returns. InferResult.__call__ raises NotImplementedError to satisfy the protocol — we can't change it without reworking Simulator.

Decompose dsx.sample into dsx.infer (pure computation, returns InferResult)
+ numpyro site registration (via callback). Filter and Smoother are now
numpyro-free: they compute results and return InferResult with a deferred
_register_numpyro_sites callback that dsx.sample fires. All integration
backends return (marginal_loglik, states, dists) tuples with no side effects.
@theorashid
Copy link
Copy Markdown
Author

On the testing suite, sometimes tests/test_hierarchical_simulator_discretizer_smokes.py, tests/test_science/test_discrete_time_l63_mcmc.py are flaky. tests/test_science/test_hmm.py::test_mcmc_inference breaks for TypeError: only 0-dimensional arrays can be converted to Python scalars. test_science is pretty slow in general

@DanWaxman
Copy link
Copy Markdown
Collaborator

Thanks Theo! This seems directionally about right, I like the implementation strategy!

In the background, sample just calls infer and then registers sites. dsx.infer returns InferResult which carries marginal_loglik, states, dists, and a private _register_numpyro_sites callback.

That makes sense! It will be a tiny bit tricky to get working with Simulators, I think, but that's probably okay. One tricky part is you're allowed to stack simulators with filters, i.e., with Simulator(), Filter(): dsx.sample(...); this allows one to sample from the filtering/posterior predictive. We'll need to pass the corresponding InferResult and append the necessary information.

I am not as familiar with the internals of the library, so if there are other places beyond the diff that you think these changes might affect, let me know where to look

I don't have any off the top of my head, besides the aforementioned interaction with Simulators. The other main test will be running all the notebooks in the documentation from scratch and making sure the results are qualitatively similar, but I think that can come in a bit.

I did make BaseLogFactorAdder (and the Filter equivalent) an ABC because I think it made sense.

Agreed, thanks!

Maybe we should rename _sample_intp to _infer_intp.

Sure, I more-or-less agree.

sometimes tests/test_hierarchical_simulator_discretizer_smokes.py, tests/test_science/test_discrete_time_l63_mcmc.py are flaky

I'm surprised the simulator/discretizer smokes are flaky, I haven't run into that before... but it's not super surprising to me that the discrete_time_l63_mcmc is flaky. I think those were written before we made EnKF the discrete-time default.

test_science is pretty slow in general

Right... I think the test_science suite has somewhat fallen into disuse, and our tests in general are a bit of a mess (though I think with decent coverage -- just disorganized). It's been sitting on the backlog for a bit. I wouldn't worry too much about particularly slow test_science tests as long as docs in the notebooks are looking okay.


At the risk of getting ahead of myself on a draft PR, I think it makes sense to set up a staging branch for this. Then, we can try to land this PR in the staging branch; worry about Simulators and its various interactions afterwards; then worry about the documentation lift that this implies.

@theorashid
Copy link
Copy Markdown
Author

I think the test_science suite has somewhat fallen into disuse

Right, now I've seen the workflows I can see they just run the tests ignoring test_science.

I can give Simulator a go if you want it in this PR, but I'm wary of keeping the size small so it's easy for you to review – up to you.

@DanWaxman DanWaxman changed the base branch from main to dsx-infer-staging May 29, 2026 13:28
@DanWaxman
Copy link
Copy Markdown
Collaborator

I can give Simulator a go if you want it in this PR, but I'm wary of keeping the size small so it's easy for you to review – up to you.

I think it makes sense to keep the PRs small, but also want to minimize the amount of half-working features on the upstream. So I've changed the base branch to dsx-infer-staging, where we can work bit-by-bit in implementing dsx.infer(...).

From that perspective, feel free to mark as ready to review whenever you feel it's ready and I'll take a close look :) thanks again!!

@theorashid theorashid marked this pull request as ready for review May 29, 2026 19:55
@theorashid
Copy link
Copy Markdown
Author

Just renamed _sample_intp to _infer_intp.

I think this is a good point. Smaller change, tests passing, not tooooo many files to check over. Then I'll use any feedback before doing the Simulator refactor.

@mattlevine22
Copy link
Copy Markdown
Collaborator

  1. Yeah, you can ignore test_science for now.
  2. Did you say tests/test_hierarchical_simulator_discretizer_smokes.py is having issues? That one should be a solid test, but it looks like everything passed in (green checks on your previous commits).
  3. I just ran tutorials 04 (discrete-time filter + simulator roll-out) and 06 (SDE filter + simulator roll-out), and both worked and look right locally, so that is a good sign!
  4. I think I prefer dsx.condition, as I worry that infer will sound "all-powerful" to some users trying to do parameter estimation. Infer will sound a bit weird when we use it for Simulator rollouts, but it is not CRAZY to say that a simulator rollout is "conditioned", even if only on obs_times.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants