Ensemble methods — MAP and VI via `ensemble_step`
Ensemble inference represents the posterior with a particle set rather than a parametric distribution. Instead of sampling a Markov chain or optimizing variational parameters, you push the particles around with an update rule that mimics the effect of the posterior gradient. The family traces back to the ensemble Kalman filter Evensen (2003) and its "inverse problems" relatives — ensemble Kalman inversion (EKI) Iglesias et al. (2013), the ensemble Kalman sampler (EKS) Garbuno-Iñigo et al. (2020), and calibrate-emulate-sample pipelines Cleary et al. (2021). Pyrox bundles these into two layers — the `ensemble_step` primitive and the `EnsembleMAP` / `EnsembleVI` runners — that drop into Equinox + NumPyro models with no custom plumbing.
Model
Given an unnormalized log posterior $\log \pi(\theta)$, $\theta \in \mathbb{R}^d$, we maintain $J$ particles $\theta^{(1)}, \dots, \theta^{(J)}$. Let $\bar\theta = \tfrac{1}{J}\sum_j \theta^{(j)}$ be the ensemble mean and $C = \tfrac{1}{J}\sum_j \big(\theta^{(j)} - \bar\theta\big)\big(\theta^{(j)} - \bar\theta\big)^{\top}$ the ensemble covariance.
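As a minimal sketch (assuming the particles are stacked in a `(J, d)` JAX array; the helper name is illustrative, not Pyrox's API), the ensemble statistics are:

```python
import jax.numpy as jnp

def ensemble_stats(particles):
    """Ensemble mean and covariance of a (J, d) stack of particles."""
    mean = particles.mean(axis=0)                        # (d,) ensemble mean
    centered = particles - mean                          # deviations from the mean
    cov = centered.T @ centered / particles.shape[0]     # (d, d) ensemble covariance
    return mean, cov
```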
MAP via ensemble gradient
For maximum-a-posteriori estimation, each step pre-conditions the log-posterior gradient with the ensemble covariance and applies a standard gradient step:

$$\theta^{(j)} \leftarrow \theta^{(j)} + \eta \, C \, \nabla_\theta \log \pi\big(\theta^{(j)}\big).$$
The ensemble covariance acts as a data-driven metric tensor, so the update is invariant to linear reparameterizations of θ — a practical advantage over vanilla SGD on constrained parameters.
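A minimal sketch of this update, assuming a scalar-valued `log_post(theta)` and the `ensemble_stats` helper above (the real `ensemble_step` signature may differ):

```python
import jax

def ensemble_map_step(particles, log_post, step_size=1e-2):
    """One covariance-preconditioned gradient-ascent step for every particle."""
    _, cov = ensemble_stats(particles)
    grads = jax.vmap(jax.grad(log_post))(particles)   # (J, d) per-particle gradients
    # C is symmetric, so grads @ cov applies C to each particle's gradient.
    return particles + step_size * grads @ cov
```

Because every particle shares the same covariance, a single `vmap`ped gradient call per iteration is all that is needed.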
VI via ensemble-Langevin dynamics
Adding a random-walk term turns the step into a sampler. The EKS / interacting-Langevin update Garbuno-Iñigo et al. (2020) is:

$$\theta^{(j)} \leftarrow \theta^{(j)} + \eta \, C \, \nabla_\theta \log \pi\big(\theta^{(j)}\big) + \sqrt{2\eta} \, C^{1/2} \xi^{(j)}, \qquad \xi^{(j)} \sim \mathcal{N}(0, I).$$
In the $\eta \to 0$, $J \to \infty$ limit this is a Langevin diffusion targeting $\pi$; in practice a modest $J$ (tens to a few hundred) is already competitive with mean-field VI Blei et al. (2017) on problems where the posterior is curved or multi-modal.
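A sketch of the discretized update under the same assumptions as the MAP step above, with a jittered Cholesky factor standing in for $C^{1/2}$:

```python
import jax
import jax.numpy as jnp

def ensemble_vi_step(key, particles, log_post, step_size=1e-2, jitter=1e-6):
    """One Euler-Maruyama step of the ensemble-preconditioned Langevin dynamics."""
    _, cov = ensemble_stats(particles)
    grads = jax.vmap(jax.grad(log_post))(particles)          # (J, d)
    drift = step_size * grads @ cov                           # eta * C * grad log pi
    # Jitter keeps the Cholesky factor well defined when C is rank-deficient.
    chol = jnp.linalg.cholesky(cov + jitter * jnp.eye(cov.shape[0]))
    noise = jax.random.normal(key, particles.shape)           # xi ~ N(0, I), one per particle
    return particles + drift + jnp.sqrt(2.0 * step_size) * noise @ chol.T
```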
Numerical considerations
- Ensemble size $J$. The ensemble covariance has rank at most $J - 1$ and is singular when $J \le d$. Two practical fixes: localization (zero out off-diagonal entries beyond a correlation cutoff — inherited from geophysical data assimilation) or Tikhonov regularization $C + \lambda I$. The pyrox primitives expose both via hooks (see the sketch after this list).
- Step size $\eta$. A constant $\eta$ is fine for MAP if the loss landscape is well-conditioned; for VI, a decaying $\eta_t \to 0$ is required for asymptotic correctness. Warmup + cosine-decay schedules are a safe default (sketched after this list).
- Jitter in $C^{1/2}$. The noise term requires a Cholesky factor of $C$, which is rank-deficient when $J \le d$. Add $\varepsilon I$ before factorizing and document $\varepsilon$ as a hyperparameter.
- Parallelism. Each particle's gradient evaluation is independent and `jax.vmap`s trivially. The `ensemble_step` primitive already does this; `EnsembleMAP` / `EnsembleVI` compose it with an `optax` optimizer and log history.
- When not to use it. Ensembles shine on moderate-dimensional, physics-inspired problems where gradients are cheap but the posterior is nasty. For flat, high-dimensional neural-net posteriors, NUTS + reparam or mean-field VI Kucukelbir et al. (2017) usually wins.
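For the regularization and step-size points above, a sketch of what a covariance hook and an `optax` schedule might look like (the hook name and hyperparameter values are illustrative, not Pyrox's actual API):

```python
import jax.numpy as jnp
import optax

def tikhonov_hook(cov, lam=1e-3):
    """Hypothetical covariance hook: Tikhonov regularization C + lambda * I."""
    return cov + lam * jnp.eye(cov.shape[0])

# Warmup + cosine-decay step-size schedule, the "safe default" suggested above.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,      # start from zero during warmup
    peak_value=1e-2,     # peak step size reached after warmup
    warmup_steps=100,
    decay_steps=2_000,   # total steps over which the cosine decay runs
)
eta_t = schedule(500)    # step size at iteration 500
```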
Notebooks
- `ensemble_primitives_tutorial` — three ways to drive the low-level `ensemble_step` primitive: as a function, as a jit-compiled loop, and inside an `optax` update.
- `ensemble_runner_tutorial` — the higher-level `EnsembleMAP` and `EnsembleVI` classes: same runtime, with history logging, config handling, and drop-in use against any NumPyro model.
References
- Evensen, G. (2003). The Ensemble Kalman Filter: Theoretical Formulation and Practical Implementation. Ocean Dynamics, 53, 343–367. 10.1007/s10236-003-0036-9
- Iglesias, M. A., Law, K. J. H., & Stuart, A. M. (2013). Ensemble Kalman Methods for Inverse Problems. Inverse Problems, 29(4), 045001. 10.1088/0266-5611/29/4/045001
- Garbuno-Iñigo, A., Hoffmann, F., Li, W., & Stuart, A. M. (2020). Interacting Langevin Diffusions: Gradient Structure and Ensemble Kalman Sampler. SIAM Journal on Applied Dynamical Systems, 19(1), 412–441. 10.1137/19M1251655
- Cleary, E., Garbuno-Iñigo, A., Lan, S., Schneider, T., & Stuart, A. M. (2021). Calibrate, Emulate, Sample. Journal of Computational Physics, 424, 109716. 10.1016/j.jcp.2020.109716
- Blei, D. M., Kucukelbir, A., & McAuliffe, J. D. (2017). Variational Inference: A Review for Statisticians. Journal of the American Statistical Association, 112(518), 859–877. 10.1080/01621459.2017.1285773
- Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic Differentiation Variational Inference. Journal of Machine Learning Research, 18(14), 1–45.