
Ensemble methods — MAP and VI via `ensemble_step`

Ensemble inference represents the posterior with a particle set $\{\theta^{(j)}\}_{j=1}^J$ rather than a parametric distribution. Instead of sampling a Markov chain or optimizing variational parameters, you push the particles around with an update rule that mimics the effect of the posterior gradient. The family traces back to the ensemble Kalman filter (Evensen, 2003) and its “inverse problems” relatives — ensemble Kalman inversion (EKI) (Iglesias et al., 2013), the ensemble Kalman sampler (EKS) (Garbuno-Iñigo et al., 2020), and calibrate-emulate-sample pipelines (Cleary et al., 2021). Pyrox bundles these into two primitives — `ensemble_step` and the `EnsembleMAP` / `EnsembleVI` runners — that drop into Equinox + NumPyro models with no custom plumbing.

Model

Given an unnormalized log posterior $\log \tilde p(\theta) = \log p(\theta) + \log p(\mathbf{y} \mid \theta)$, we maintain $J$ particles $\theta^{(j)}$. Let $\bar\theta = \tfrac{1}{J}\sum_j \theta^{(j)}$ be the ensemble mean and $C_\theta = \tfrac{1}{J-1}\sum_j (\theta^{(j)} - \bar\theta)(\theta^{(j)} - \bar\theta)^\top$ the ensemble covariance.
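As a concrete reference, here is a minimal JAX sketch of these two statistics for a particle array of shape $(J, D)$; the function name `ensemble_stats` is illustrative, not part of the Pyrox API:

```python
import jax.numpy as jnp

def ensemble_stats(particles):
    """Ensemble mean (D,) and covariance (D, D) of a (J, D) particle array."""
    theta_bar = particles.mean(axis=0)             # \bar{theta}
    centered = particles - theta_bar               # theta^(j) - \bar{theta}
    J = particles.shape[0]
    C_theta = centered.T @ centered / (J - 1)      # unbiased ensemble covariance
    return theta_bar, C_theta
```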

MAP via ensemble gradient

For maximum-a-posteriori estimation, each step preconditions the log-posterior gradient with the ensemble covariance and then takes a plain ascent step:

$$
\theta^{(j)}_{t+1} = \theta^{(j)}_t + \eta_t\, C_\theta\, \nabla_\theta \log \tilde p\bigl(\theta^{(j)}_t\bigr).
$$

The ensemble covariance acts as a data-driven metric tensor, so the update is invariant to linear reparameterizations of $\theta$ — a practical advantage over vanilla SGD on constrained parameters.
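A minimal sketch of one such step in plain JAX, assuming `log_post` maps a single particle to its unnormalized log posterior (the helper name `map_step` is illustrative):

```python
import jax
import jax.numpy as jnp

def map_step(particles, log_post, eta):
    """One ensemble-preconditioned gradient-ascent step for all J particles."""
    centered = particles - particles.mean(axis=0)
    C_theta = centered.T @ centered / (particles.shape[0] - 1)
    grads = jax.vmap(jax.grad(log_post))(particles)   # (J, D) per-particle gradients
    # C_theta is symmetric, so grads @ C_theta applies C_theta to each gradient row.
    return particles + eta * grads @ C_theta
```

Because the preconditioner is rebuilt from the current particles at every step, the metric adapts as the ensemble contracts toward the mode.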

VI via ensemble-Langevin dynamics

Adding a random-walk term turns the step into a sampler; this is the EKS / interacting-Langevin update of Garbuno-Iñigo et al. (2020):

$$
\theta^{(j)}_{t+1} = \theta^{(j)}_t + \eta_t\, C_\theta\, \nabla_\theta \log \tilde p\bigl(\theta^{(j)}_t\bigr) + \sqrt{2\eta_t\, C_\theta}\;\xi^{(j)}_t, \qquad \xi^{(j)}_t \overset{\text{iid}}{\sim} \mathcal{N}(0, I).
$$

In the $J \to \infty$, $\eta_t \to 0$ limit this is a Langevin diffusion targeting $p(\theta \mid \mathbf{y})$; in practice a modest $J$ (tens to a few hundred) is already competitive with mean-field VI (Blei et al., 2017) on problems where the posterior is curved or multi-modal.
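A sketch of the stochastic variant, again in plain JAX rather than the Pyrox API: the noise is shaped by a symmetric square root of $C_\theta$, computed here via an eigendecomposition with eigenvalues clipped at zero (an assumption that $C_\theta$ is positive semi-definite up to round-off):

```python
import jax
import jax.numpy as jnp

def vi_step(particles, log_post, eta, key):
    """One EKS-style Langevin step: preconditioned gradient plus shaped noise."""
    centered = particles - particles.mean(axis=0)
    C_theta = centered.T @ centered / (particles.shape[0] - 1)
    grads = jax.vmap(jax.grad(log_post))(particles)
    # Symmetric square root of C_theta, so that C_sqrt @ C_sqrt == C_theta.
    w, V = jnp.linalg.eigh(C_theta)
    C_sqrt = (V * jnp.sqrt(jnp.clip(w, 0.0))) @ V.T
    xi = jax.random.normal(key, particles.shape)      # xi^(j) ~ N(0, I)
    return particles + eta * grads @ C_theta + jnp.sqrt(2.0 * eta) * xi @ C_sqrt
```

Iterating `vi_step` with a fresh split of `key` at each step evolves the cloud toward an approximation of the posterior; dropping the noise term recovers the `map_step` update above.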

Numerical considerations

Notebooks

References
  1. Evensen, G. (2003). The Ensemble Kalman Filter: Theoretical Formulation and Practical Implementation. Ocean Dynamics, 53, 343–367. 10.1007/s10236-003-0036-9
  2. Iglesias, M. A., Law, K. J. H., & Stuart, A. M. (2013). Ensemble Kalman Methods for Inverse Problems. Inverse Problems, 29(4), 045001. 10.1088/0266-5611/29/4/045001
  3. Garbuno-Iñigo, A., Hoffmann, F., Li, W., & Stuart, A. M. (2020). Interacting Langevin Diffusions: Gradient Structure and Ensemble Kalman Sampler. SIAM Journal on Applied Dynamical Systems, 19(1), 412–441. 10.1137/19M1251655
  4. Cleary, E., Garbuno-Iñigo, A., Lan, S., Schneider, T., & Stuart, A. M. (2021). Calibrate, Emulate, Sample. Journal of Computational Physics, 424, 109716. 10.1016/j.jcp.2020.109716
  5. Blei, D. M., Kucukelbir, A., & McAuliffe, J. D. (2017). Variational Inference: A Review for Statisticians. Journal of the American Statistical Association, 112(518), 859–877. 10.1080/01621459.2017.1285773
  6. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic Differentiation Variational Inference. Journal of Machine Learning Research, 18(14), 1–45.