ODE parameter/state estimation
Part 6.4: State and Parameter Estimation for ODEs¶
Estimating initial conditions (state estimation) and physical parameters (parameter estimation) from noisy observations of the Lorenz63 system.
Stack: equinox (model) · diffrax (ODE solver) · optax (optimizer) · coordax (state wrapping)
import time as _time
import coordax as cx
import jax
import jax.numpy as jnp
try:
    import equinox as eqx
    has_eqx = True
except ImportError:
    has_eqx = False
    print("equinox not installed.")
try:
    import diffrax
    has_diffrax = True
except ImportError:
    has_diffrax = False
    print("diffrax not installed.")
try:
    import optax
    has_optax = True
except ImportError:
    has_optax = False
    print("optax not installed.")
# NOTE: cx.field() is preferred over cx.wrap() in coordax >= 0.2
Lorenz63 System¶
dx/dt = σ(y − x)
dy/dt = x(ρ − z) − y
dz/dt = xy − βz
SIGMA_TRUE = 10.0
RHO_TRUE = 28.0
BETA_TRUE = 8.0 / 3.0
X0_TRUE = jnp.array([1.0, 1.0, 1.0])
state_axis = cx.LabeledAxis('state', jnp.arange(3, dtype=float))
# dims: ('state',) | shape: (3,) — [x, y, z] state variables
print(f"True params: σ={SIGMA_TRUE}, ρ={RHO_TRUE}, β={BETA_TRUE:.4f}")
print(f"True IC: x0={X0_TRUE}")True params: σ=10.0, ρ=28.0, β=2.6667
True IC: x0=[1. 1. 1.]
Lorenz63 Model (Equinox + Diffrax)¶
if has_eqx and has_diffrax:
    class Lorenz63(eqx.Module):
        """Lorenz63 ODE model."""
        sigma: float
        rho: float
        beta: float
        state_axis: cx.Coordinate = eqx.field(static=True)

        def vector_field(self, t, y, args=None):
            x, yv, z = y[0], y[1], y[2]
            return jnp.array([
                self.sigma * (yv - x),
                x * (self.rho - z) - yv,
                x * yv - self.beta * z,
            ])

        def __call__(self, t_span, y0):
            sol = diffrax.diffeqsolve(
                terms=diffrax.ODETerm(self.vector_field),
                solver=diffrax.Tsit5(),
                t0=t_span[0],
                t1=t_span[-1],
                dt0=None,
                y0=y0,
                saveat=diffrax.SaveAt(ts=t_span),
                stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
                max_steps=20000,
                adjoint=diffrax.RecursiveCheckpointAdjoint(),
            )
            # sol.ys: shape (n_times, 3) | units: dimensionless — trajectory
            return sol.ys  # (n_times, 3)

    model_true = Lorenz63(sigma=SIGMA_TRUE, rho=RHO_TRUE,
                          beta=BETA_TRUE, state_axis=state_axis)
    print("Lorenz63 model created.")
Lorenz63 model created.
Generate Synthetic Observations¶
if has_eqx and has_diffrax:
    t_obs = jnp.linspace(0, 2.0, 41)  # 2 Lyapunov times — long enough to see dynamics, short enough for gradients to stay informative
    time_axis = cx.LabeledAxis('time', t_obs)
    traj_true = model_true(t_obs, X0_TRUE)
    # shape: (41, 3) | axes: (time, state=[x,y,z]) | units: dimensionless
    noise_level = 0.5
    key = jax.random.PRNGKey(42)
    observations = traj_true + jax.random.normal(key, traj_true.shape) * noise_level
    observations_field = cx.field(observations, time_axis, state_axis)
    # shape: (41, 3) | dims: ('time','state') | units: dimensionless — noisy observations
    print(f"Observations: {observations_field.dims}, shape: {observations_field.shape}")
    print(f"True traj std: {float(jnp.std(traj_true)):.2f}, SNR: {float(jnp.std(traj_true)/noise_level):.1f}")
else:
    t_obs = jnp.linspace(0, 2.0, 41)
    time_axis = cx.LabeledAxis('time', t_obs)
    observations = jnp.zeros((t_obs.shape[0], 3))
    observations_field = cx.field(observations, time_axis, state_axis)
Observations: ('time', 'state'), shape: (41, 3)
True traj std: 16.25, SNR: 32.5
Problem 1: State Estimation (Recover Initial Condition)¶
Loss: L(x₀) = (1/N) Σ ||x(tᵢ; x₀, θ_true) − x_obs(tᵢ)||²
if has_eqx and has_diffrax and has_optax:
    def state_loss(x0, model, obs, t_span):
        traj = model(t_span, x0)
        # traj: shape (n_times, 3) | units: dimensionless — forward trajectory
        return jnp.mean((traj - obs)**2)

    x0_guess = jnp.array([2.0, 2.0, 2.0])
    # shape: (3,) | units: dimensionless — perturbed initial guess (true: [1,1,1])
    init_loss = state_loss(x0_guess, model_true, observations, t_obs)
    true_loss = state_loss(X0_TRUE, model_true, observations, t_obs)
    print(f"Loss at x0_guess={x0_guess}: {float(init_loss):.3f}")
    print(f"Loss at x0_true ={X0_TRUE}: {float(true_loss):.3f}")

    # Gradient w.r.t. initial condition
    grad_fn = jax.value_and_grad(state_loss, argnums=0)
    val, grad = grad_fn(x0_guess, model_true, observations, t_obs)
    # grad: shape (3,) | units: ∂L/∂x₀ — gradient of MSE loss w.r.t. initial condition
    print(f"Gradient of loss w.r.t. x0: {grad}")

    # Optimize with optax (Adam)
    optimizer = optax.adam(1e-2)
    x0_opt = x0_guess
    opt_state = optimizer.init(x0_opt)

    @jax.jit
    def opt_step(x0, opt_state):
        loss, grad = jax.value_and_grad(state_loss)(x0, model_true, observations, t_obs)
        updates, new_opt_state = optimizer.update(grad, opt_state)
        new_x0 = optax.apply_updates(x0, updates)
        return new_x0, new_opt_state, loss

    losses_state = []
    t0 = _time.perf_counter()
    for i in range(300):
        x0_opt, opt_state, loss = opt_step(x0_opt, opt_state)
        losses_state.append(float(loss))
        if (i + 1) % 25 == 0:
            print(f"  iter {i+1:3d}: loss={float(loss):.4f}, x0={x0_opt}")
    elapsed = _time.perf_counter() - t0
    print(f"Optimization: {elapsed:.1f}s for 300 steps")
    print(f"True x0={X0_TRUE}, recovered x0={x0_opt}")
    print(f"Error: {float(jnp.linalg.norm(x0_opt - X0_TRUE)):.3f}")
Loss at x0_guess=[2. 2. 2.]: 15.265
Loss at x0_true =[1. 1. 1.]: 0.191
Gradient of loss w.r.t. x0: [12.135993 10.012697 -2.1368177]
iter 25: loss=9.6705, x0=[1.751183 1.7511531 2.2452395]
iter 50: loss=4.8385, x0=[1.5133518 1.5131625 2.459534 ]
iter 75: loss=1.8110, x0=[1.3091147 1.3087074 2.6061773]
iter 100: loss=0.7051, x0=[1.1700063 1.169633 2.652397 ]
iter 125: loss=0.5305, x0=[1.108884 1.109028 2.6025007]
iter 150: loss=0.4834, x0=[1.0947616 1.0957208 2.503615 ]
iter 175: loss=0.4395, x0=[1.0904061 1.0922536 2.3961523]
iter 200: loss=0.4003, x0=[1.0840853 1.0868418 2.2923956]
iter 225: loss=0.3656, x0=[1.076989 1.0806665 2.1925259]
iter 250: loss=0.3352, x0=[1.0703448 1.0749434 2.0964162]
iter 275: loss=0.3090, x0=[1.0641315 1.0696449 2.0047286]
iter 300: loss=0.2867, x0=[1.05827 1.0646863 1.9179724]
Optimization: 3.3s for 300 steps
True x0=[1. 1. 1.], recovered x0=[1.05827 1.0646863 1.9179724]
Error: 0.922
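Note that the z-component converges much more slowly than x and y (z is still near 1.92 after 300 steps). One option, sketched here but not run above, is to pair Adam with a decaying learning-rate schedule and a longer run:
if has_optax:
    # Hypothetical variant: decay the learning rate so the slow z-component can settle.
    schedule = optax.exponential_decay(init_value=1e-2, transition_steps=100, decay_rate=0.5)
    optimizer_decay = optax.adam(learning_rate=schedule)
    # The opt_step loop above would be reused with optimizer_decay in place of optimizer.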
Problem 2: Parameter Estimation (Recover σ, ρ, β)¶
Loss: L(θ) = (1/N) Σ ||x(tᵢ; x0_true, θ) − x_obs(tᵢ)||²
if has_eqx and has_diffrax and has_optax:
    # Perturbed initial model — use jnp.array so eqx.is_array recognizes them as differentiable
    model_init = Lorenz63(sigma=jnp.array(9.0), rho=jnp.array(26.5),
                          beta=jnp.array(2.4), state_axis=state_axis)
    print(f"Initial params: σ={float(model_init.sigma)}, ρ={float(model_init.rho)}, β={float(model_init.beta):.4f}")

    def param_loss(model, obs, t_span, x0):
        traj = model(t_span, x0)
        return jnp.mean((traj - obs)**2)

    # eqx.filter_value_and_grad differentiates only the inexact-array leaves (σ, ρ, β);
    # the static state_axis field is left untouched
    loss_val_init = param_loss(model_init, observations, t_obs, X0_TRUE)
    print(f"Initial parameter loss: {float(loss_val_init):.4f}")

    optimizer_p = optax.adam(2e-2)
    model_opt = model_init
    opt_state_p = optimizer_p.init(eqx.filter(model_opt, eqx.is_array))

    @eqx.filter_jit
    def param_opt_step(model, opt_state, obs, t_span, x0):
        loss, grads = eqx.filter_value_and_grad(param_loss)(model, obs, t_span, x0)
        updates, new_opt_state = optimizer_p.update(
            eqx.filter(grads, eqx.is_array), opt_state
        )
        new_model = eqx.apply_updates(model, updates)
        return new_model, new_opt_state, loss

    losses_param = []
    t0 = _time.perf_counter()
    for i in range(500):
        model_opt, opt_state_p, loss = param_opt_step(
            model_opt, opt_state_p, observations, t_obs, X0_TRUE
        )
        losses_param.append(float(loss))
        if (i + 1) % 25 == 0:
            print(f"  iter {i+1:3d}: loss={float(loss):.4f} "
                  f"σ={float(model_opt.sigma):.2f} "
                  f"ρ={float(model_opt.rho):.2f} "
                  f"β={float(model_opt.beta):.4f}")
    elapsed = _time.perf_counter() - t0
    print(f"Optimization: {elapsed:.1f}s for 500 steps")
    print(f"True: σ={SIGMA_TRUE}, ρ={RHO_TRUE}, β={BETA_TRUE:.4f}")
    print(f"Recovered: σ={float(model_opt.sigma):.4f}, "
          f"ρ={float(model_opt.rho):.4f}, β={float(model_opt.beta):.4f}")
Initial params: σ=9.0, ρ=26.5, β=2.4000
Initial parameter loss: 4.0564
iter 25: loss=1.6989 σ=9.49 ρ=26.99 β=2.5830
iter 50: loss=0.6217 σ=9.86 ρ=27.39 β=2.6454
iter 75: loss=0.2901 σ=10.06 ρ=27.66 β=2.6700
iter 100: loss=0.2163 σ=10.13 ρ=27.82 β=2.6744
iter 125: loss=0.2004 σ=10.12 ρ=27.90 β=2.6749
iter 150: loss=0.1948 σ=10.09 ρ=27.94 β=2.6733
iter 175: loss=0.1923 σ=10.06 ρ=27.96 β=2.6717
iter 200: loss=0.1911 σ=10.04 ρ=27.98 β=2.6705
iter 225: loss=0.1907 σ=10.03 ρ=27.99 β=2.6698
iter 250: loss=0.1906 σ=10.02 ρ=28.00 β=2.6693
iter 275: loss=0.1905 σ=10.02 ρ=28.00 β=2.6690
iter 300: loss=0.1905 σ=10.01 ρ=28.00 β=2.6689
iter 325: loss=0.1905 σ=10.01 ρ=28.00 β=2.6688
iter 350: loss=0.1905 σ=10.01 ρ=28.00 β=2.6688
iter 375: loss=0.1905 σ=10.01 ρ=28.00 β=2.6688
iter 400: loss=0.1905 σ=10.01 ρ=28.00 β=2.6687
iter 425: loss=0.1905 σ=10.01 ρ=28.00 β=2.6687
iter 450: loss=0.1905 σ=10.01 ρ=28.00 β=2.6687
iter 475: loss=0.1905 σ=10.01 ρ=28.00 β=2.6687
iter 500: loss=0.1905 σ=10.01 ρ=28.00 β=2.6687
Optimization: 15.0s for 500 steps
True: σ=10.0, ρ=28.0, β=2.6667
Recovered: σ=10.0100, ρ=28.0046, β=2.6687
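Since both the initial condition and the parameters are differentiable inputs to the same solve, the two problems can in principle be combined. A minimal sketch (not run above; joint_loss, params_j, and joint_step are names introduced here) treats the (model, x0) pair as one trainable pytree:
if has_eqx and has_diffrax and has_optax:
    def joint_loss(params, obs, t_span):
        model, x0 = params
        traj = model(t_span, x0)  # (n_times, 3) forward trajectory
        return jnp.mean((traj - obs) ** 2)

    params_j = (model_init, jnp.array([2.0, 2.0, 2.0]))  # perturbed params and IC
    optimizer_j = optax.adam(1e-2)
    opt_state_j = optimizer_j.init(eqx.filter(params_j, eqx.is_array))

    @eqx.filter_jit
    def joint_step(params, opt_state):
        loss, grads = eqx.filter_value_and_grad(joint_loss)(params, observations, t_obs)
        updates, opt_state = optimizer_j.update(eqx.filter(grads, eqx.is_array), opt_state)
        return eqx.apply_updates(params, updates), opt_state, loss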
Coordax State Wrapping¶
Demonstrate wrapping model states as coordax Fields for coordinate-aware analysis.
if has_eqx and has_diffrax:
    traj_opt_x0 = model_true(t_obs, x0_opt if has_optax else X0_TRUE)
    traj_field = cx.field(traj_opt_x0, time_axis, state_axis)
    # shape: (41, 3) | dims: ('time','state') | units: dimensionless
    print(f"Trajectory field: {traj_field.dims}, shape: {traj_field.shape}")

    # Compute trajectory statistics per state variable
    traj_mean = cx.cmap(lambda x: jnp.mean(x, axis=0))(traj_field.untag('time'))
    # shape: (3,) | dims: ('state',) | units: dimensionless — time-mean of (x,y,z)
    traj_std = cx.cmap(lambda x: jnp.std(x, axis=0))(traj_field.untag('time'))
    # shape: (3,) | dims: ('state',) | units: dimensionless — std dev of (x,y,z)
    print(f"Trajectory mean (x,y,z): {traj_mean.data}")
    print(f"Trajectory std (x,y,z): {traj_std.data}")

    # Observations as coordax Field
    obs_field = cx.field(observations, time_axis, state_axis)
    # shape: (41, 3) | dims: ('time','state') | units: dimensionless
    residual_field = traj_field - obs_field
    # shape: (41, 3) | dims: ('time','state') | units: dimensionless — prediction − obs
    print(f"Residual field: {residual_field.dims}, shape: {residual_field.shape}")
    print(f"RMS residual: {float(jnp.sqrt(jnp.mean(residual_field.data**2))):.4f}")
Trajectory field: ('time', 'state'), shape: (41, 3)
Trajectory mean (x,y,z): [-3.6464005 -4.109796 24.636417 ]
Trajectory std (x,y,z): [7.847485 9.112147 9.818682]
Residual field: ('time', 'state'), shape: (41, 3)
RMS residual: 0.5347
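The same cmap pattern extends to per-variable diagnostics. For example (an illustrative sketch following the calls above), the RMS residual for each state variable separately:
if has_eqx and has_diffrax:
    rms_per_state = cx.cmap(lambda r: jnp.sqrt(jnp.mean(r**2)))(residual_field.untag('time'))
    # shape: (3,) | dims: ('state',), RMS residual for each of (x, y, z)
    print(f"Per-state RMS residual: {rms_per_state.data}")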
Key Patterns¶
- Use equinox for structured, differentiable model classes
- Use diffrax for ODE solving inside the loss function
- Use optax for gradient-based optimization
- Use coordax to wrap state trajectories with meaningful dimension labels
- Use jax.value_and_grad / eqx.filter_value_and_grad for gradients
- Use @jax.jit / @eqx.filter_jit for compilation
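A compact skeleton tying these patterns together (illustrative only; make_loss and fit are names introduced here, mirroring the cells above):
if has_eqx and has_optax:
    def make_loss(obs, t_span, x0):
        def loss_fn(model):
            traj = model(t_span, x0)            # diffrax solve inside the loss
            return jnp.mean((traj - obs) ** 2)  # MSE against observations
        return loss_fn

    def fit(model, loss_fn, n_steps=500, lr=2e-2):
        opt = optax.adam(lr)
        opt_state = opt.init(eqx.filter(model, eqx.is_array))

        @eqx.filter_jit
        def step(model, opt_state):
            loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
            updates, opt_state = opt.update(eqx.filter(grads, eqx.is_array), opt_state)
            return eqx.apply_updates(model, updates), opt_state, loss

        for _ in range(n_steps):
            model, opt_state, loss = step(model, opt_state)
        return model, loss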