I was able to work around my issue by switching to the UKF, but I think there is still something here that needs explaining. I'm running a particle filter in a loop, feeding the weighted particles from each call back into the next call as its initial condition. The particles seem to collapse in a single step, and they also fail to move in time after that. I thought this was caused by my MuJoCo setup, but it also happens in a simple system:
import jax
import jax.numpy as jnp
import jax.random as rax
jax.config.update("jax_traceback_filtering", "off")
import dynestyx as dsx
import numpy as np
import numpyro.distributions as dist
from dynestyx import DynamicalModel, Filter
from dynestyx.inference.filters import PFConfig
from numpyro.infer import Predictive
from dynestyx.inference.integrations.utils import particles_to_delta_mixtures
# Dimensions of the toy system: latent state, observation vector, control vector.
state_size, observation_size, control_size = 13, 17, 5
def fuzz(x, scale=0.1):
    """Wrap ``x`` in a float32 Normal distribution with standard deviation ``scale``.

    In this script it serves as the initial-state distribution, the
    additive-noise transition density and the observation density.

    Args:
        x: Mean of the distribution; cast to float32.
        scale: Standard deviation, cast to float32. Defaults to 0.1, the value
            that used to be hard-coded. Against standard-normal observations,
            a scale this small makes the log-likelihood differences between
            particles very large. That can put essentially all particle
            weight on a single particle.

    Returns:
        A ``numpyro.distributions.Normal`` with mean ``x`` and scale ``scale``.
    """
    return dist.Normal(jnp.astype(x, "float32"), jnp.array(scale, "float32"))
def step(state, rng, n_particles=11):
    """Run one particle-filter call and return the filtered particle cloud.

    Steps:

    1. Draw a random observation and control.
    2. Build a ``DynamicalModel`` whose initial condition is ``state``.
    3. Run a particle filter through ``Predictive`` at a single time point.
    4. Convert the filtered particles and log weights into a delta mixture
       that can seed the next call.

    Args:
        state: Distribution used as the model's initial condition. On the
            first call this is a Normal; after that it is the delta mixture
            returned by the previous call.
        rng: JAX PRNG key for this step.
        n_particles: Number of particles used by the particle filter
            (default 11, the value that used to be hard-coded).

    Returns:
        The single distribution produced by ``particles_to_delta_mixtures``.
    """
    # Split into three independent keys. The original passed `rng` itself to
    # Predictive after splitting it, so the filter's randomness was not
    # independent of the obs/ctrl draws.
    obs_key, ctrl_key, filter_key = rax.split(rng, 3)
    obs = rax.normal(obs_key, shape=[observation_size])
    ctrl = rax.normal(ctrl_key, shape=[control_size])

    def state_evolution(x, u, t_now, t_next):
        return fuzz(2 * x + 1 + u.sum())

    def observation_model(x, u, t):
        # The state enters only through x.sum(). The observation noise
        # (fuzz's 0.1 scale) is small next to the standard-normal `obs`
        # drawn above. Log-weights therefore differ hugely between particles,
        # and nearly all weight can end up on one particle.
        return fuzz(np.ones([observation_size]) + x.sum() + u.sum())

    def dsx_model(**kwargs):
        dynamics = DynamicalModel(
            control_dim=control_size,
            initial_condition=state,
            state_evolution=state_evolution,
            observation_model=observation_model,
        )
        return dsx.sample("f", dynamics, **kwargs)

    # NOTE(review): every call filters over the same single time point t = 1.
    # Dynestyx may treat the first observation time as the time of
    # `initial_condition`. If so, `state_evolution` is never applied here:
    # particles are only reweighted and never propagated forward.
    # Confirm against the dynestyx docs.
    t1 = jnp.ones([1])
    with Filter(PFConfig(n_particles=n_particles)):
        predictive = Predictive(dsx_model, num_samples=1)
        pred = predictive(
            filter_key,
            obs_times=t1,
            obs_values=obs[None],
            ctrl_times=t1,
            ctrl_values=ctrl[None],
            predict_times=t1,
        )

    # squeeze Predictive num_samples axis (batch_ndims=0 doesn't seem to do it)
    x = jnp.squeeze(pred["f_filtered_particles"], axis=0)
    w = jnp.squeeze(pred["f_filtered_log_weights"], axis=0)
    [px1] = particles_to_delta_mixtures(x, w)
    return px1
# Driver: chain `step` calls, seeding each one with the previous call's
# filtered particles, then plot how the particles evolve.
n_steps = 7
state0 = fuzz(np.zeros([state_size]))
states = [state0]
for t in range(n_steps):
    states.append(step(states[-1], rax.key(t)))

# states[0] is the initial Normal and has no `.particles`, so skip it.
# Presumably shaped (n_steps, n_particles, state_size) -- confirm.
particles = np.stack([state.particles for state in states[1:]])

from matplotlib import pyplot as plt

plt.figure()
plt.plot(jnp.diff(particles, axis=0)[:, 0])
plt.title("particle 0 time derivative")
plt.figure()
plt.plot(particles.var(axis=1))
plt.title("variance across particles")
# Without this, a non-interactive run (`python script.py`) never displays
# the figures.
plt.show()

I was able to work around my issue by switching to the UKF, but I think there is still something here that needs explaining. I'm running a particle filter in a loop, feeding the weighted particles from each call back into the next call as its initial condition. The particles seem to collapse in a single step, and they also fail to move in time after that. I thought this was caused by my MuJoCo setup, but it also happens in a simple system: