-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsimulation.py
More file actions
90 lines (70 loc) · 2.88 KB
/
simulation.py
File metadata and controls
90 lines (70 loc) · 2.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""Simulation of Ornstein-Uhlenbeck process with hidden variable U_t."""
import jax
import jax.numpy as jnp
import jax.random as jr
from jax import vmap
def simulate_ou(key, x0, params, dt, N, U_values):
    """Simulate one OU trajectory whose mean level depends on U_t.

    Each step samples from the exact Gaussian transition density of the
    Ornstein-Uhlenbeck process over an interval of length ``dt``::

        X_i | X_{i-1} ~ Normal(
            X_{i-1} * exp(-lam*dt) + mu_i * (1 - exp(-lam*dt)),
            sigma**2 / (2*lam) * (1 - exp(-2*lam*dt)),
        )

    with ``mu_i = a * U_values[i] + b``. Step ``i`` (for i = 1..N) reads the
    hidden variable at index ``i``, i.e. at the end of the step interval.

    Args:
        key: JAX random key for reproducibility.
        x0: Initial value X_0.
        params: Tuple ``(lambda, a, b, sigma)`` of process parameters.
            ``lambda`` must be non-zero because the transition variance
            divides by it.
        dt: Time step size.
        N: Number of time steps.
        U_values: Hidden variable values at each time point (length N+1).

    Returns:
        Array of trajectory values X_t (length N+1); element 0 is ``x0``.
    """
    lam, a, b, sigma = params
    # Accept plain sequences too: indexing with the traced step index inside
    # the scan requires an array.
    U_values = jnp.asarray(U_values)

    # Quantities that do not change from step to step are computed once here
    # instead of inside the scan body.
    decay = jnp.exp(-lam * dt)
    # 1 - exp(-z) is evaluated as -expm1(-z): the naive form cancels to
    # exactly 0 for small z in float32, which would silently drop the mean
    # reversion and the noise.
    mean_weight = -jnp.expm1(-lam * dt)
    D = sigma**2 / 2
    var = D / lam * (-jnp.expm1(-2 * lam * dt))
    noise_scale = jnp.sqrt(var)

    def ou_step(carry, i):
        x, step_key = carry
        # Carry the key forward and spend a fresh subkey on this step's draw.
        step_key, subkey = jr.split(step_key)
        mu_t = a * U_values[i] + b
        mean = x * decay + mu_t * mean_weight
        x_new = mean + jr.normal(subkey) * noise_scale
        return (x_new, step_key), x_new

    init = (x0, key)
    (_, _), xs = jax.lax.scan(ou_step, init, jnp.arange(1, N + 1))
    # The scan yields X_1..X_N; prepend the initial value to get N+1 points.
    xs = jnp.concatenate([jnp.array([x0]), xs])
    return xs
def sim_multi_OU(key, x0, params, dt, N, U_values, n_samples):
    """Draw several independent OU trajectories in one vectorized call.

    The input key is split into ``n_samples`` keys and ``simulate_ou`` is
    mapped over them with ``vmap``; every trajectory shares the same initial
    value, parameters, step size and hidden-variable path.

    Args:
        key: JAX random key for reproducibility.
        x0: Initial value X_0.
        params: Tuple ``(lambda, a, b, sigma)`` of process parameters.
        dt: Time step size.
        N: Number of time steps.
        U_values: Hidden variable values at each time point (length N+1).
        n_samples: Number of trajectories to generate.

    Returns:
        Array of trajectories with shape (n_samples, N+1).
    """

    def _single_path(path_key):
        # Only the random key differs between trajectories.
        return simulate_ou(path_key, x0, params, dt, N, U_values)

    per_path_keys = jr.split(key, n_samples)
    batched_simulate = vmap(_single_path)
    return batched_simulate(per_path_keys)
def generate_U_values(key, num_segments, N, scale=1.0, offset=1.0):
    """Generate a piecewise-constant hidden variable U_t.

    Draws one normally distributed value per segment and holds it for
    ``N // num_segments`` consecutive time points. Time points left over
    after the last full segment (there is always at least one, since the
    output has N+1 entries) take the value of the last segment.

    If ``num_segments`` exceeds ``N`` the segment length is clamped to 1, so
    each time point takes its own segment value and surplus segment values
    are not used.

    Args:
        key: JAX random key for reproducibility.
        num_segments: Number of constant segments (must be >= 1).
        N: Total number of time steps.
        scale: Standard deviation of random segment values.
        offset: Mean of random segment values.

    Returns:
        Array of U_t values (length N+1).

    Raises:
        ValueError: If ``num_segments`` is less than 1.
    """
    if num_segments < 1:
        raise ValueError(f"num_segments must be >= 1, got {num_segments}")

    # Clamp to 1 so that N < num_segments cannot produce zero-length segments
    # (which would leave nothing to extend up to N+1 points).
    segment_length = max(N // num_segments, 1)
    segment_values = jr.normal(key, shape=(num_segments,)) * scale + offset

    # Map every time index to its segment; indices past the last full segment
    # are clamped onto the final segment, matching an edge-pad of the tail.
    segment_index = jnp.minimum(
        jnp.arange(N + 1) // segment_length, num_segments - 1
    )
    return segment_values[segment_index]