Regression masterclass — three patterns for parameter handling
A recurring friction when mixing Equinox modules (Kidger & Garcia, 2021) with NumPyro (Phan et al., 2019) is where hyperparameters live. Equinox wants a PyTree of leaves; NumPyro wants named random sites. Pyrox settles on three compatible idioms, and this masterclass solves the same regression problem end-to-end in each, so you can A/B them on ergonomics.
The three patterns live on a spectrum:
| Pattern | Where the params live | What wraps them | Boilerplate |
|---|---|---|---|
| Pattern 1 — `eqx.tree_at` + raw NumPyro | Equinox module leaves | the user does the substitution by hand | most |
| Pattern 2 — `PyroxModule` + `pyrox_sample` | module leaves + a NumPyro plate | pyrox wires up the sampling | medium |
| Pattern 3 — `Parameterized` + `PyroxParam` + native `pyrox.gp` | leaves carry prior metadata | pyrox resolves everything | least |
All three are mathematically equivalent — they sample the same posterior. The choice is about how much machinery you want pyrox to do for you.
Model
All three notebooks target the same regression problem: an exact GP prior with Gaussian noise, $f \sim \mathcal{GP}(0, k_\theta)$, $y_i = f(x_i) + \varepsilon_i$ with $\varepsilon_i \sim \mathcal{N}(0, \sigma_n^2)$, and hyperparameters $\theta = (\ell, \sigma_f, \sigma_n)$ (length-scale, signal amplitude, noise scale). We place hyperpriors on each and infer the posterior $p(\theta \mid \mathbf{y})$.

The conjugate marginal likelihood (after integrating out $f$) is

$$p(\mathbf{y} \mid \theta) = \mathcal{N}\bigl(\mathbf{y} \mid \mathbf{0},\; K_\theta + \sigma_n^2 I\bigr),$$

and the joint density evaluated during NUTS is

$$p(\theta, \mathbf{y}) = p(\theta)\,\mathcal{N}\bigl(\mathbf{y} \mid \mathbf{0},\; K_\theta + \sigma_n^2 I\bigr).$$

Each pattern differs only in how $\theta$ gets into the Equinox kernel module that evaluates $k_\theta(x, x')$.
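To make the cost structure concrete, here is a minimal sketch of $\log p(\mathbf{y} \mid \theta)$ in plain JAX, assuming 1-D inputs and a squared-exponential kernel; the helper names `rbf_kernel` and `log_marginal_likelihood` are illustrative, not part of the pyrox API.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def rbf_kernel(x1, x2, lengthscale, sigma_f):
    # Squared-exponential kernel k_theta(x, x') for 1-D inputs.
    sq = (x1[:, None] - x2[None, :]) ** 2
    return sigma_f**2 * jnp.exp(-0.5 * sq / lengthscale**2)


def log_marginal_likelihood(x, y, lengthscale, sigma_f, sigma_n):
    n = x.shape[0]
    K = rbf_kernel(x, x, lengthscale, sigma_f) + sigma_n**2 * jnp.eye(n)
    # One Cholesky per gradient evaluation: the O(n^3) step.
    L, lower = cho_factor(K, lower=True)
    alpha = cho_solve((L, lower), y)
    return (
        -0.5 * y @ alpha                          # -0.5 * y^T (K + sigma_n^2 I)^{-1} y
        - jnp.sum(jnp.log(jnp.diagonal(L)))       # -0.5 * log|K + sigma_n^2 I|
        - 0.5 * n * jnp.log(2.0 * jnp.pi)
    )
```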
Numerical considerations
- **Constrained hyperparameters.** $\ell$, $\sigma_f$, and $\sigma_n$ are all positive. We sample on the unconstrained real line using NumPyro's `TransformedDistribution` (typically exp or softplus bijectors) so HMC/NUTS gradients are finite at the origin; see the first sketch after this list.
- **Reparameterization.** Heavy-tailed priors on scales (e.g. `LogNormal`) interact badly with ill-conditioned likelihoods. When posteriors pile up near zero, a non-centered reparameterization (Kingma & Welling (2014)-style whitening) dramatically improves NUTS step-size adaptation; see the first sketch after this list.
- **Leapfrog / NUTS cost.** The dominant cost per log-posterior gradient is the Cholesky of $K_\theta + \sigma_n^2 I$, at $\mathcal{O}(n^3)$ floating-point ops. Pre-computing the kernel once at evaluation time and re-using its Cholesky is how `pyrox.gp` stays fast.
- **Three patterns, same posterior.** All three code patterns produce the same $p(\theta \mid \mathbf{y})$ up to numerical noise; pyrox's CI pins this with a regression test. Pattern choice is therefore about code ergonomics, not sampling efficiency.
- **Black-box VI as an alternative.** For bigger $n$, or when NUTS is slow, swap `NUTS` for `SVI` with an `AutoNormal` guide: a mean-field Gaussian approximation fit by maximizing the ELBO (Blei et al., 2017; Ranganath et al., 2014; Kucukelbir et al., 2017); see the second sketch after this list.
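As a rough illustration of the first two bullets, the following NumPyro snippet samples a positive length-scale through a `TransformedDistribution` and whitens it with `TransformReparam`; the site name, the standard-Normal base, and the exp bijector are illustrative assumptions, not pyrox's actual defaults.

```python
import numpyro
import numpyro.distributions as dist
from numpyro.distributions.transforms import ExpTransform
from numpyro.handlers import reparam
from numpyro.infer.reparam import TransformReparam


def hyperpriors():
    # A log-normal prior written as a TransformedDistribution so that
    # TransformReparam can sample its base Normal on the unconstrained line
    # (a non-centered / "whitened" parameterization).
    lengthscale = numpyro.sample(
        "lengthscale",
        dist.TransformedDistribution(dist.Normal(0.0, 1.0), ExpTransform()),
    )
    return lengthscale


# NUTS now sees an auxiliary "lengthscale_base" site on the real line,
# and the constrained value is recovered deterministically.
reparam_model = reparam(hyperpriors, config={"lengthscale": TransformReparam()})
```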
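And a sketch of the VI swap from the last bullet, assuming a NumPyro `model(x, y)` exposing the hyperparameter sites; the optimizer, learning rate, and step counts are arbitrary placeholders.

```python
import jax.random as jr
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.optim import Adam

# Mean-field Gaussian guide over all latent sites, fit by maximizing the ELBO.
guide = AutoNormal(model)
svi = SVI(model, guide, Adam(step_size=1e-2), Trace_ELBO())
svi_result = svi.run(jr.PRNGKey(0), 5_000, x, y)

# Draw approximate posterior samples from the fitted guide.
posterior_samples = guide.sample_posterior(
    jr.PRNGKey(1), svi_result.params, sample_shape=(1_000,)
)
```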
Notebooks
- `regression_masterclass_treeat` — Pattern 1: `eqx.tree_at` substitutes sampled scalars into the Equinox kernel by pointer (sketched below). Most transparent, most boilerplate.
- `regression_masterclass_pyrox_sample` — Pattern 2: `PyroxModule` + `pyrox_sample` register the hyperparameter sites on behalf of the module. Mid-stack.
- `regression_masterclass_parameterized` — Pattern 3: `Parameterized` + `PyroxParam` attach prior metadata directly to the Equinox leaves, and `pyrox.gp` resolves sampling + constrained transforms automatically. Fewest user-facing moving parts.
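For orientation, here is a minimal sketch of the Pattern 1 idiom, assuming a small hand-rolled `RBFKernel` Equinox module; the module, site names, and `LogNormal` priors are illustrative and not the notebook's exact code.

```python
import equinox as eqx
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


class RBFKernel(eqx.Module):
    lengthscale: jnp.ndarray
    sigma_f: jnp.ndarray

    def __call__(self, x1, x2):
        sq = (x1[:, None] - x2[None, :]) ** 2
        return self.sigma_f**2 * jnp.exp(-0.5 * sq / self.lengthscale**2)


def model(x, y, kernel_template: RBFKernel):
    # Named NumPyro sites for the hyperparameters.
    lengthscale = numpyro.sample("lengthscale", dist.LogNormal(0.0, 1.0))
    sigma_f = numpyro.sample("sigma_f", dist.LogNormal(0.0, 1.0))
    sigma_n = numpyro.sample("sigma_n", dist.LogNormal(0.0, 1.0))

    # Pattern 1: substitute the sampled scalars into the Equinox module "by pointer".
    kernel = eqx.tree_at(
        lambda k: (k.lengthscale, k.sigma_f),
        kernel_template,
        replace=(lengthscale, sigma_f),
    )

    # Exact GP marginal likelihood with Gaussian noise.
    K = kernel(x, x) + sigma_n**2 * jnp.eye(x.shape[0])
    numpyro.sample(
        "y",
        dist.MultivariateNormal(jnp.zeros(x.shape[0]), covariance_matrix=K),
        obs=y,
    )
```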
References
- Kidger, P., & Garcia, C. (2021). Equinox: Neural Networks in JAX via Callable PyTrees and Filtered Transformations. Differentiable Programming Workshop at NeurIPS.
- Phan, D., Pradhan, N., & Jankowiak, M. (2019). Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro. arXiv:1912.11554.
- Kingma, D. P., & Welling, M. (2014). Auto-Encoding Variational Bayes. International Conference on Learning Representations (ICLR).
- Blei, D. M., Kucukelbir, A., & McAuliffe, J. D. (2017). Variational Inference: A Review for Statisticians. Journal of the American Statistical Association, 112(518), 859–877. 10.1080/01621459.2017.1285773
- Ranganath, R., Gerrish, S., & Blei, D. M. (2014). Black Box Variational Inference. International Conference on Artificial Intelligence and Statistics (AISTATS).
- Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic Differentiation Variational Inference. Journal of Machine Learning Research, 18(14), 1–45.