Dynamics — ODEs, PDEs, and parameter estimation
Once state lives in a Field and derivatives are coordinate-aware, the step
to a full dynamical-system workflow is small: wire the RHS into an ODE
solver, differentiate through it, and optimize. The three notebooks here
integrate a 1-D advection-diffusion PDE with `diffrax` (Kidger, 2021),
then invert it for unknown parameters and initial states using
`optax` and `jax.value_and_grad`.
Forward problem¶
The target PDE is the linear 1-D advection-diffusion equation

$$\partial_t u = -c\,\partial_x u + \kappa\,\partial_x^2 u,$$

discretized in space with the finite-difference operators from the
derivatives section and integrated in time
with `diffrax.Tsit5` and a PID step controller.
Writing the RHS as $f(u, \theta)$ with $\theta = (c, \kappa)$, the
`diffeqsolve` call is an end-to-end-differentiable function of the state
and the parameters — this is the lever neural-ODE-style frameworks
(Chen et al., 2018) pull to do gradient-based inversion without manually
assembling tangent equations.
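To make the forward pipeline concrete, here is a minimal sketch under stated assumptions: a periodic unit domain, a Gaussian initial condition, and plain `jnp.roll` stencils standing in for the `Field`-based operators from the derivatives section. The grid size, tolerances, and values of (c, κ) are illustrative, not the notebooks' settings.

```python
import diffrax
import jax.numpy as jnp

# Illustrative setup (assumed): periodic unit domain, uniform grid.
n = 256
dx = 1.0 / n
x = jnp.arange(n) * dx
u0 = jnp.exp(-100.0 * (x - 0.5) ** 2)  # hypothetical Gaussian initial state

def rhs(t, u, args):
    """Semi-discrete advection-diffusion RHS: -c*du/dx + kappa*d2u/dx2."""
    c, kappa = args
    du = (jnp.roll(u, -1) - jnp.roll(u, 1)) / (2.0 * dx)        # central d/dx
    d2u = (jnp.roll(u, -1) - 2.0 * u + jnp.roll(u, 1)) / dx**2  # central d2/dx2
    return -c * du + kappa * d2u

sol = diffrax.diffeqsolve(
    diffrax.ODETerm(rhs),
    diffrax.Tsit5(),                      # explicit 5th-order Runge-Kutta
    t0=0.0, t1=1.0, dt0=1e-3,
    y0=u0,
    args=(1.0, 1e-3),                     # (c, kappa), illustrative values
    stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
    saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, 1.0, 11)),
)
# sol.ys has shape (11, n): the state at each requested save time.
```

Because every operation inside `rhs` and `diffeqsolve` is a JAX primitive, the whole solve composes with `jax.jit`, `jax.vmap`, and `jax.grad` unchanged.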
Inverse problems¶
Given noisy observations $y_k$ at times $t_k$, the two inverse problems in this section are variations on a least-squares objective:

$$L(\theta, u_0) = \sum_k \left\lVert u(t_k; \theta, u_0) - y_k \right\rVert^2 + R(\theta, u_0),$$

where $\theta = (c, \kappa)$ are the PDE parameters, $u_0$ is the
(unknown) initial state, and $R$ is an optional prior /
regularizer. Closed-form gradients are infeasible — the solver is implicit
in $u(t_k; \theta, u_0)$ — so `jax.value_and_grad` through `diffeqsolve` + `optax.adam` is the
only reasonable workflow at this scale. The equivalence with ensemble
Kalman-type inversion (Evensen, 2009) is worth noting: both target
the same MAP point, but gradient optimization is cheaper when derivatives
are available and the forward model is smooth.
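A minimal twin-experiment sketch of that loop, reusing `rhs`, `u0`, and `n` from the forward sketch above; the `solve` wrapper, noise level, and optimizer settings are illustrative assumptions rather than the notebooks' code. The `adjoint` argument is diffrax's default and is spelled out only to connect with the discussion below.

```python
import jax
import jax.numpy as jnp
import diffrax
import optax

ts = jnp.linspace(0.0, 1.0, 11)

def solve(theta, u0):
    """Wrap diffeqsolve so the PDE parameters are an explicit argument."""
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(rhs), diffrax.Tsit5(),
        t0=0.0, t1=1.0, dt0=1e-3, y0=u0,
        args=(theta[0], theta[1]),                     # (c, kappa)
        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
        saveat=diffrax.SaveAt(ts=ts),
        adjoint=diffrax.RecursiveCheckpointAdjoint(),  # diffrax's default
    )
    return sol.ys

# Twin experiment: synthetic observations from known parameters plus noise.
theta_true = jnp.array([1.0, 1e-3])
noise = 1e-2 * jax.random.normal(jax.random.PRNGKey(0), (len(ts), n))
y_obs = solve(theta_true, u0) + noise

def loss(theta):
    return jnp.mean((solve(theta, u0) - y_obs) ** 2)

opt = optax.adam(1e-2)
theta = jnp.array([0.5, 5e-4])  # deliberately wrong initial guess
opt_state = opt.init(theta)

@jax.jit
def step(theta, opt_state):
    value, grads = jax.value_and_grad(loss)(theta)  # grad through the solver
    updates, opt_state = opt.update(grads, opt_state)
    return optax.apply_updates(theta, updates), opt_state, value

for _ in range(200):
    theta, opt_state, value = step(theta, opt_state)
# theta should approach theta_true; only then move on to real observations.
```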
Numerical considerations¶
- Adjoint choice. `diffrax.RecursiveCheckpointAdjoint` is the default here — it checkpoints the forward solve and recomputes segments during backpropagation, trading compute for memory while keeping exact discretize-then-optimize gradients. Storing the whole trajectory for pure reverse-mode avoids the recomputation but allocates memory proportional to the number of steps; the continuous adjoint ("optimize-then-discretize") is memory-cheap but less accurate. Pick based on trajectory length vs. state size.
- Step controller. A stiff or near-stiff RHS (large κ on a fine grid) triggers aggressive step-size shrinking under the PID controller — gradient quality degrades long before the forward solve fails. If the loss is flat but the gradient is noisy, tighten `rtol`/`atol` by an order of magnitude and re-check.
- Parameter scaling. c and κ differ by orders of magnitude at realistic values; unscaled `adam` spends most of its budget on the larger one. Reparameterize as (log c, log κ) or use per-parameter learning rates (see the sketch after this list).
- Initial-state identifiability. Joint estimation is identifiable only when the observation window resolves the relevant diffusion/advection timescales. For windows short relative to those timescales, the problem collapses to state estimation; for windows long relative to them, upstream information is wiped out. The notebooks pick observation schedules inside the identifiable regime.
- Validation. Always run the “twin experiment” first — simulate with known parameters, recover them. If the twin fails, the real data will too. All three notebooks follow this pattern.
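For the parameter-scaling point above, a minimal sketch of the log-reparameterization, reusing the hypothetical `loss` and optimizer loop from the twin-experiment sketch:

```python
# Optimize phi = (log c, log kappa) so adam sees comparably scaled entries;
# `loss` is the (assumed) objective from the twin-experiment sketch above.
def loss_log(phi):
    return loss(jnp.exp(phi))

phi = jnp.log(jnp.array([0.5, 5e-4]))  # initial guess, now in log-space
# Run the same optax.adam loop on loss_log and recover theta = jnp.exp(phi).
# A unit adam step now perturbs c and kappa by the same relative amount,
# regardless of their absolute scales.
```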
Notebooks¶
- `08_ode_integration` — forward solve of advection-diffusion with `diffrax`; wrapping state as a `Field` for coordinate-aware RHS evaluation; a conservation check on the integrated trajectory.
- `09_ode_parameter_state_estimation` — joint recovery of the parameters (c, κ) and the initial state from noisy observations via `optax.adam` through `diffeqsolve`. Twin experiment.
- `10_pde_parameter_estimation` — learning PDE parameters using `equinox.Module` to keep the parameter/state split clean and the `diffrax` solve inside `__call__`.
References¶
- Kidger, P. (2021). On Neural Differential Equations [PhD thesis, University of Oxford]. https://arxiv.org/abs/2202.02435
- Chen, R. T. Q., Rubanova, Y., Bettencourt, J., & Duvenaud, D. (2018). Neural Ordinary Differential Equations. Advances in Neural Information Processing Systems (NeurIPS).
- Evensen, G. (2009). Data Assimilation: The Ensemble Kalman Filter (2nd ed.). Springer. https://doi.org/10.1007/978-3-642-03711-5