Skip to article frontmatter | Skip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

ODE parameter/state estimation

Part 6.4: State and Parameter Estimation for ODEs

Estimating initial conditions (state estimation) and physical parameters (parameter estimation) from noisy observations of the Lorenz63 system.

Stack: equinox (model) · diffrax (ODE solver) · optax (optimizer) · coordax (state wrapping)

import time as _time

import coordax as cx
import jax
import jax.numpy as jnp

try:
    import equinox as eqx
    has_eqx = True
except ImportError:
    has_eqx = False
    print("equinox not installed.")

try:
    import diffrax
    has_diffrax = True
except ImportError:
    has_diffrax = False
    print("diffrax not installed.")

try:
    import optax
    has_optax = True
except ImportError:
    has_optax = False
    print("optax not installed.")

# NOTE: cx.field() is preferred over cx.wrap() in coordax >= 0.2

Lorenz63 System

dx/dt = σ(y − x)
dy/dt = x(ρ − z) − y
dz/dt = xy − βz
# Ground-truth values that the estimation problems below try to recover:
# Lorenz's standard coefficients and a unit initial condition.
SIGMA_TRUE, RHO_TRUE, BETA_TRUE = 10.0, 28.0, 8.0 / 3.0
X0_TRUE = jnp.ones(3)  # [x0, y0, z0] = [1.0, 1.0, 1.0]

# Coordinate for the 3-component state; tick values 0., 1., 2. stand for x, y, z.
state_axis = cx.LabeledAxis('state', jnp.arange(3, dtype=float))
# dims: ('state',) | shape: (3,)

print(f"True params: σ={SIGMA_TRUE}, ρ={RHO_TRUE}, β={BETA_TRUE:.4f}")
print(f"True IC:     x0={X0_TRUE}")
True params: σ=10.0, ρ=28.0, β=2.6667
True IC:     x0=[1. 1. 1.]

Lorenz63 Model (Equinox + Diffrax)

if has_eqx and has_diffrax:
    class Lorenz63(eqx.Module):
        """Lorenz63 ODE model.

        Fields:
            sigma, rho, beta: the three Lorenz coefficients. They are regular
                (non-static) fields, so when they are given as JAX arrays they
                are pytree leaves that Equinox's filtered transforms can
                differentiate and optax can update.
            state_axis: coordax coordinate for the state dimension. It is
                marked static, so it is not a pytree leaf.
        """
        sigma: float
        rho: float
        beta: float
        state_axis: cx.Coordinate = eqx.field(static=True)

        def vector_field(self, t, y, args=None):
            """Return d[x, y, z]/dt evaluated at state ``y``.

            ``t`` and ``args`` match the ``diffrax.ODETerm`` call signature
            but are not used.
            """
            px, py, pz = y[0], y[1], y[2]
            dx_dt = self.sigma * (py - px)
            dy_dt = px * (self.rho - pz) - py
            dz_dt = px * py - self.beta * pz
            return jnp.array([dx_dt, dy_dt, dz_dt])

        def __call__(self, t_span, y0):
            """Integrate from ``y0`` and return the states sampled at ``t_span``.

            Args:
                t_span: 1-D array of output times. Integration runs from
                    ``t_span[0]`` to ``t_span[-1]``, and the solution is saved
                    at every entry.
                y0: initial state [x, y, z].

            Returns:
                ``sol.ys``, with one row per time in ``t_span``. This is
                (41, 3) on the observation grid used below.
            """
            controller = diffrax.PIDController(rtol=1e-6, atol=1e-8)
            save_points = diffrax.SaveAt(ts=t_span)
            solution = diffrax.diffeqsolve(
                terms=diffrax.ODETerm(self.vector_field),
                solver=diffrax.Tsit5(),
                t0=t_span[0],
                t1=t_span[-1],
                dt0=None,  # initial step chosen by the adaptive controller
                y0=y0,
                saveat=save_points,
                stepsize_controller=controller,
                max_steps=20000,
                adjoint=diffrax.RecursiveCheckpointAdjoint(),
            )
            # units: dimensionless — trajectory, one row per saved time
            return solution.ys

    model_true = Lorenz63(
        sigma=SIGMA_TRUE,
        rho=RHO_TRUE,
        beta=BETA_TRUE,
        state_axis=state_axis,
    )
    print("Lorenz63 model created.")
Lorenz63 model created.

Generate Synthetic Observations

# Observation times: 41 evenly spaced samples on [0, 2]. Both branches below
# use the same grid, so it is built once here.
# ~2 Lyapunov times — long enough to see dynamics, short enough for gradients to stay informative
t_obs = jnp.linspace(0, 2.0, 41)
time_axis = cx.LabeledAxis('time', t_obs)

if not (has_eqx and has_diffrax):
    # No ODE model available: use zero placeholder observations so that the
    # coordax objects below still exist.
    observations = jnp.zeros((t_obs.shape[0], 3))
    observations_field = cx.field(observations, time_axis, state_axis)
else:
    # Noise-free reference trajectory | shape (41, 3) | axes (time, state=[x,y,z])
    traj_true = model_true(t_obs, X0_TRUE)
    # Corrupt it with independent standard-normal noise scaled to std 0.5.
    noise_level = 0.5
    key = jax.random.PRNGKey(42)
    observations = traj_true + jax.random.normal(key, traj_true.shape) * noise_level

    observations_field = cx.field(observations, time_axis, state_axis)
    # shape: (41, 3) | dims: ('time','state') | units: dimensionless — noisy observations
    print(f"Observations: {observations_field.dims}, shape: {observations_field.shape}")
    # SNR is an amplitude ratio: std of the clean trajectory / noise std.
    print(f"True traj std: {float(jnp.std(traj_true)):.2f},  SNR: {float(jnp.std(traj_true)/noise_level):.1f}")
Observations: ('time', 'state'), shape: (41, 3)
True traj std: 16.25,  SNR: 32.5

Problem 1: State Estimation (Recover Initial Condition)

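Loss: L(x₀) = (1/(3N)) Σᵢ ||x(tᵢ; x₀, θ_true) − x_obs(tᵢ)||². This is the mean squared error over all N observation times and all 3 state components, which is exactly what `jnp.mean` over the (N, 3) residual array computes.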

if has_eqx and has_diffrax and has_optax:
    def state_loss(x0, model, obs, t_span):
        """Mean squared misfit of the trajectory launched from ``x0``.

        Args:
            x0: candidate initial state [x, y, z] (the quantity being estimated).
            model: Lorenz63 instance whose parameters are held fixed.
            obs: observations on ``t_span``, same shape as the trajectory.
            t_span: observation times.

        Returns:
            Scalar ``jnp.mean`` of the squared residuals, averaged over every
            time *and* every state component.
        """
        traj = model(t_span, x0)
        # traj: shape (n_times, 3) | units: dimensionless  — forward trajectory
        return jnp.mean((traj - obs) ** 2)

    x0_guess = jnp.array([2.0, 2.0, 2.0])
    # shape: (3,) | units: dimensionless  — perturbed initial guess (true: [1,1,1])
    init_loss = state_loss(x0_guess, model_true, observations, t_obs)
    # With noisy data the loss at the true IC is not zero. It is the floor the
    # optimizer should approach, and is reported again after optimization.
    true_loss = state_loss(X0_TRUE,  model_true, observations, t_obs)
    print(f"Loss at x0_guess={x0_guess}: {float(init_loss):.3f}")
    print(f"Loss at x0_true ={X0_TRUE}:  {float(true_loss):.3f}")

    # Gradient w.r.t. initial condition
    grad_fn = jax.value_and_grad(state_loss, argnums=0)
    val, grad = grad_fn(x0_guess, model_true, observations, t_obs)
    # grad: shape (3,) | units: ∂L/∂x₀  — gradient of MSE loss w.r.t. initial condition
    print(f"Gradient of loss w.r.t. x0: {grad}")

    # Optimize with optax (Adam).
    # Adam's per-coordinate step is roughly bounded by the learning rate, and
    # the z-component of x0 gets a much smaller gradient than x/y. With a
    # 300-step budget the loss was still falling and z0 was far from its true
    # value, so the budget is raised. Compare the final loss with true_loss to
    # judge convergence.
    n_steps_state = 1000
    log_every_state = 100
    optimizer = optax.adam(1e-2)
    x0_opt = x0_guess
    opt_state = optimizer.init(x0_opt)

    @jax.jit
    def opt_step(x0, opt_state):
        """One Adam step on ``state_loss``; returns (new x0, new opt state, loss at x0)."""
        loss, grad = jax.value_and_grad(state_loss)(x0, model_true, observations, t_obs)
        updates, new_opt_state = optimizer.update(grad, opt_state)
        new_x0 = optax.apply_updates(x0, updates)
        return new_x0, new_opt_state, loss

    losses_state = []
    t0 = _time.perf_counter()
    for i in range(n_steps_state):
        x0_opt, opt_state, loss = opt_step(x0_opt, opt_state)
        # float() pulls the value to the host, so each step finishes before timing continues.
        losses_state.append(float(loss))
        if (i + 1) % log_every_state == 0:
            print(f"  iter {i+1:3d}: loss={float(loss):.4f}, x0={x0_opt}")
    elapsed = _time.perf_counter() - t0
    # The first call to opt_step triggers JIT compilation, which is inside the timed region.
    print(f"Optimization: {elapsed:.1f}s for {n_steps_state} steps (incl. JIT compilation)")
    print(f"Last recorded loss={losses_state[-1]:.4f} (loss at true x0: {float(true_loss):.4f})")
    print(f"True x0={X0_TRUE}, recovered x0={x0_opt}")
    print(f"Error: {float(jnp.linalg.norm(x0_opt - X0_TRUE)):.3f}")
Loss at x0_guess=[2. 2. 2.]: 15.265
Loss at x0_true =[1. 1. 1.]:  0.191
Gradient of loss w.r.t. x0: [12.135993  10.012697  -2.1368177]
  iter  25: loss=9.6705, x0=[1.751183  1.7511531 2.2452395]
  iter  50: loss=4.8385, x0=[1.5133518 1.5131625 2.459534 ]
  iter  75: loss=1.8110, x0=[1.3091147 1.3087074 2.6061773]
  iter 100: loss=0.7051, x0=[1.1700063 1.169633  2.652397 ]
  iter 125: loss=0.5305, x0=[1.108884  1.109028  2.6025007]
  iter 150: loss=0.4834, x0=[1.0947616 1.0957208 2.503615 ]
  iter 175: loss=0.4395, x0=[1.0904061 1.0922536 2.3961523]
  iter 200: loss=0.4003, x0=[1.0840853 1.0868418 2.2923956]
  iter 225: loss=0.3656, x0=[1.076989  1.0806665 2.1925259]
  iter 250: loss=0.3352, x0=[1.0703448 1.0749434 2.0964162]
  iter 275: loss=0.3090, x0=[1.0641315 1.0696449 2.0047286]
  iter 300: loss=0.2867, x0=[1.05827   1.0646863 1.9179724]
Optimization: 3.3s for 300 steps
True x0=[1. 1. 1.], recovered x0=[1.05827   1.0646863 1.9179724]
Error: 0.922

Problem 2: Parameter Estimation (Recover σ, ρ, β)

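Loss: L(θ) = (1/(3N)) Σᵢ ||x(tᵢ; x0_true, θ) − x_obs(tᵢ)||², with θ = (σ, ρ, β). This is the same mean squared error as in Problem 1 (averaged over all N times and 3 state components). Here the initial condition is fixed at its true value and only the parameters are optimized.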

if has_eqx and has_diffrax and has_optax:
    # Perturbed initial model. The coefficients are wrapped in jnp.array
    # because Equinox's filtered transforms differentiate only inexact
    # (floating-point) arrays. Plain Python floats would be treated as
    # constants and receive no gradient.
    model_init = Lorenz63(sigma=jnp.array(9.0), rho=jnp.array(26.5),
                           beta=jnp.array(2.4), state_axis=state_axis)
    print(f"Initial params: σ={float(model_init.sigma)}, ρ={float(model_init.rho)}, β={float(model_init.beta):.4f}")

    def param_loss(model, obs, t_span, x0):
        """Mean squared misfit of ``model``'s trajectory started from ``x0``.

        ``model`` is the first argument, so ``eqx.filter_value_and_grad``
        differentiates with respect to its floating-point array leaves
        (sigma, rho, beta). The mean runs over every time and every state
        component.
        """
        traj = model(t_span, x0)
        return jnp.mean((traj - obs)**2)

    loss_val_init = param_loss(model_init, observations, t_obs, X0_TRUE)
    print(f"Initial parameter loss: {float(loss_val_init):.4f}")

    # Build the optimizer state with the same filter that
    # eqx.filter_value_and_grad uses for differentiation (inexact arrays).
    # This keeps the Adam state aligned with the leaves that receive gradients.
    n_steps_param = 500
    log_every_param = 25
    optimizer_p = optax.adam(2e-2)
    model_opt = model_init
    opt_state_p = optimizer_p.init(eqx.filter(model_opt, eqx.is_inexact_array))

    @eqx.filter_jit
    def param_opt_step(model, opt_state, obs, t_span, x0):
        """One Adam step on the model parameters; returns (model, opt state, loss)."""
        loss, grads = eqx.filter_value_and_grad(param_loss)(model, obs, t_span, x0)
        # grads already holds None for every leaf that was not differentiated,
        # so it can be handed to optax without re-filtering.
        updates, new_opt_state = optimizer_p.update(grads, opt_state)
        new_model = eqx.apply_updates(model, updates)
        return new_model, new_opt_state, loss

    losses_param = []
    t0 = _time.perf_counter()
    for i in range(n_steps_param):
        model_opt, opt_state_p, loss = param_opt_step(
            model_opt, opt_state_p, observations, t_obs, X0_TRUE
        )
        losses_param.append(float(loss))
        if (i + 1) % log_every_param == 0:
            print(f"  iter {i+1:3d}: loss={float(loss):.4f}  "
                  f"σ={float(model_opt.sigma):.2f}  "
                  f"ρ={float(model_opt.rho):.2f}  "
                  f"β={float(model_opt.beta):.4f}")
    elapsed = _time.perf_counter() - t0
    # The first call to param_opt_step triggers JIT compilation, which is inside the timed region.
    print(f"Optimization: {elapsed:.1f}s for {n_steps_param} steps (incl. JIT compilation)")
    print(f"True:      σ={SIGMA_TRUE}, ρ={RHO_TRUE}, β={BETA_TRUE:.4f}")
    print(f"Recovered: σ={float(model_opt.sigma):.4f}, "
          f"ρ={float(model_opt.rho):.4f}, β={float(model_opt.beta):.4f}")
Initial params: σ=9.0, ρ=26.5, β=2.4000
Initial parameter loss: 4.0564
  iter  25: loss=1.6989  σ=9.49  ρ=26.99  β=2.5830
  iter  50: loss=0.6217  σ=9.86  ρ=27.39  β=2.6454
  iter  75: loss=0.2901  σ=10.06  ρ=27.66  β=2.6700
  iter 100: loss=0.2163  σ=10.13  ρ=27.82  β=2.6744
  iter 125: loss=0.2004  σ=10.12  ρ=27.90  β=2.6749
  iter 150: loss=0.1948  σ=10.09  ρ=27.94  β=2.6733
  iter 175: loss=0.1923  σ=10.06  ρ=27.96  β=2.6717
  iter 200: loss=0.1911  σ=10.04  ρ=27.98  β=2.6705
  iter 225: loss=0.1907  σ=10.03  ρ=27.99  β=2.6698
  iter 250: loss=0.1906  σ=10.02  ρ=28.00  β=2.6693
  iter 275: loss=0.1905  σ=10.02  ρ=28.00  β=2.6690
  iter 300: loss=0.1905  σ=10.01  ρ=28.00  β=2.6689
  iter 325: loss=0.1905  σ=10.01  ρ=28.00  β=2.6688
  iter 350: loss=0.1905  σ=10.01  ρ=28.00  β=2.6688
  iter 375: loss=0.1905  σ=10.01  ρ=28.00  β=2.6688
  iter 400: loss=0.1905  σ=10.01  ρ=28.00  β=2.6687
  iter 425: loss=0.1905  σ=10.01  ρ=28.00  β=2.6687
  iter 450: loss=0.1905  σ=10.01  ρ=28.00  β=2.6687
  iter 475: loss=0.1905  σ=10.01  ρ=28.00  β=2.6687
  iter 500: loss=0.1905  σ=10.01  ρ=28.00  β=2.6687
Optimization: 15.0s for 500 steps
True:      σ=10.0, ρ=28.0, β=2.6667
Recovered: σ=10.0100, ρ=28.0046, β=2.6687

Coordax State Wrapping

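Wrap the re-simulated trajectory and the observations as coordax Fields, so that per-variable statistics and residuals are computed with named ('time', 'state') dimensions rather than bare array axes.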

if has_eqx and has_diffrax:
    # Re-simulate from the Problem-1 estimate when it exists (it needs optax);
    # otherwise fall back to the true initial condition.
    if has_optax:
        traj_opt_x0 = model_true(t_obs, x0_opt)
    else:
        traj_opt_x0 = model_true(t_obs, X0_TRUE)
    traj_field = cx.field(traj_opt_x0, time_axis, state_axis)
    # shape: (41, 3) | dims: ('time','state') | units: dimensionless
    print(f"Trajectory field: {traj_field.dims}, shape: {traj_field.shape}")

    # Per-variable statistics over time. untag('time') turns 'time' into a
    # positional axis, which axis=0 reduces inside each lambda. cmap maps that
    # reduction over the still-named 'state' dimension.
    traj_mean = cx.cmap(lambda series: jnp.mean(series, axis=0))(traj_field.untag('time'))
    traj_std  = cx.cmap(lambda series: jnp.std(series,  axis=0))(traj_field.untag('time'))
    # both: shape (3,) | dims: ('state',) — one value each for x, y, z
    print(f"Trajectory mean (x,y,z): {traj_mean.data}")
    print(f"Trajectory std  (x,y,z): {traj_std.data}")

    # observations_field already wraps the same `observations` array on the
    # same (time_axis, state_axis) coordinates, so reuse it instead of rebuilding.
    obs_field = observations_field
    residual_field = traj_field - obs_field
    # shape: (41, 3) | dims: ('time','state') | units: dimensionless  — prediction − obs
    print(f"Residual field: {residual_field.dims}, shape: {residual_field.shape}")
    print(f"RMS residual: {float(jnp.sqrt(jnp.mean(residual_field.data**2))):.4f}")
Trajectory field: ('time', 'state'), shape: (41, 3)
Trajectory mean (x,y,z): [-3.6464005 -4.109796  24.636417 ]
Trajectory std  (x,y,z): [7.847485 9.112147 9.818682]
Residual field: ('time', 'state'), shape: (41, 3)
RMS residual: 0.5347

Key Patterns

  • Use equinox for structured, differentiable model classes
  • Use diffrax for ODE solving inside the loss function
  • Use optax for gradient-based optimization
  • Use coordax to wrap state trajectories with meaningful dimension labels
  • jax.value_and_grad / eqx.filter_value_and_grad for gradients
  • @jax.jit / @eqx.filter_jit for compilation