Skip to content

Particle filter collapse #185

@cooijmanstim

Description

@cooijmanstim

I was able to work around my issue by switching to the UKF, but I think there is still something here that needs explaining. I'm running the particle filter in a loop, feeding the weighted particles from each call back into the next call as its initial condition. The particles seem to collapse within a single step and then fail to move over time. I initially suspected my MuJoCo setup, but the same thing happens in a straightforward system:

import jax
import jax.numpy as jnp
import jax.random as rax

jax.config.update("jax_traceback_filtering", "off")

import dynestyx as dsx
import numpy as np
import numpyro.distributions as dist
from dynestyx import DynamicalModel, Filter
from dynestyx.inference.filters import PFConfig
from numpyro.infer import Predictive

from dynestyx.inference.integrations.utils import particles_to_delta_mixtures

# Dimensions of the latent state, the observation vector, and the control input.
state_size, observation_size, control_size = 13, 17, 5

def fuzz(x, scale=0.1):
    """Return a float32 Normal distribution centred on ``x``.

    Args:
        x: Location of the distribution; cast to float32.
        scale: Standard deviation shared by every element, cast to float32.
            Defaults to 0.1, the value this repro originally hard-coded.
            A small scale makes the resulting likelihood sharply peaked,
            which can push particle weights towards degeneracy.

    Returns:
        A ``numpyro.distributions.Normal`` with float32 ``loc`` and ``scale``.
    """
    return dist.Normal(jnp.astype(x, "float32"), jnp.array(scale, "float32"))

def step(state, rng):
    """Run one particle-filter update and return the filtered particle posterior.

    A random observation and control are drawn for a single time point
    (t = 1). The model is filtered with an 11-particle filter started from
    ``state``. The filtered particles and log-weights are then converted back
    into a distribution so that they can seed the next call.

    Args:
        state: Initial-condition distribution for this step. This is a Normal
            on the first call, or the output of a previous ``step`` call.
        rng: JAX PRNG key for this step.

    Returns:
        The single distribution produced by ``particles_to_delta_mixtures``
        from the filtered particles and log-weights.
    """
    # Use three independent keys. Passing ``rng`` itself to Predictive after
    # having split it would correlate the filter's randomness with the
    # observation/control draws (JAX keys must not be reused).
    obs_key, ctrl_key, filter_key = rax.split(rng, 3)
    obs = rax.normal(obs_key, shape=[observation_size])
    ctrl = rax.normal(ctrl_key, shape=[control_size])

    def state_evolution(x, u, t_now, t_next):
        return fuzz(2 * x + 1 + u.sum())

    def observation_model(x, u, t):
        # NOTE(review): every coordinate of the predicted mean is
        # 1 + sum(x) + sum(u), while ``obs`` is standard-normal noise, and the
        # noise scale is 0.1 across 17 dimensions. A likelihood this peaked
        # and this mis-specified would, on its own, likely put almost all of
        # the weight on one particle after a single update. Check this before
        # blaming the filter for the "collapse".
        return fuzz(np.ones([observation_size]) + x.sum() + u.sum())

    def dsx_model(**kwargs):
        dynamics = DynamicalModel(
            control_dim=control_size,
            initial_condition=state,
            state_evolution=state_evolution,
            observation_model=observation_model,
        )
        return dsx.sample("f", dynamics, **kwargs)

    # Observations, controls and predictions all use the single time t = 1.
    # NOTE(review): with no earlier time in the grid, it is not visible from
    # here whether the filter applies ``state_evolution`` before the first
    # update. If it does not, the particles would never move in time.
    # Confirm against the dynestyx filter implementation.
    t1 = jnp.ones([1])
    with Filter(PFConfig(n_particles=11)):
        predictive = Predictive(dsx_model, num_samples=1)
        pred = predictive(
            filter_key,
            obs_times=t1,
            obs_values=obs[None],
            ctrl_times=t1,
            ctrl_values=ctrl[None],
            predict_times=t1,
        )

    # Drop Predictive's leading num_samples axis (size 1). Passing
    # batch_ndims=0 did not appear to remove it.
    particles = jnp.squeeze(pred["f_filtered_particles"], axis=0)
    log_weights = jnp.squeeze(pred["f_filtered_log_weights"], axis=0)
    # Exactly one mixture is expected here, presumably one per filtered time point.
    [posterior] = particles_to_delta_mixtures(particles, log_weights)
    return posterior

# Start from a Normal centred on the origin. Then roll the filter forward
# seven times, seeding each call with the posterior from the previous one.
state0 = fuzz(np.zeros([state_size]))
states = [state0]
for t in range(7):
    next_state = step(states[-1], rax.key(t))
    states.append(next_state)

# Stack the particles of every filtered posterior. states[0] is the initial
# Normal and has no particles, so it is skipped.
# NOTE(review): the shape is presumably (n_steps, n_particles, state_size);
# confirm against the delta-mixture type.
particles = np.stack([state.particles for state in states[1:]])

from matplotlib import pyplot as plt

# NOTE(review): the filter may re-draw its initial particles from the mixture
# passed in. If it does, index 0 need not follow one particle's lineage across
# steps, so read this curve as illustrative only.
plt.figure()
plt.plot(np.diff(particles, axis=0)[:, 0])
plt.title("particle 0 time derivative")
plt.xlabel("step")

plt.figure()
plt.plot(particles.var(axis=1))
plt.title("variance across particles")
plt.xlabel("step")

# Without this, nothing is displayed when the repro runs as a plain script.
plt.show()
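[Two plot images were attached here, presumably the "particle 0 time derivative" and "variance across particles" figures produced by the script above.]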

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions