Regression masterclass — three patterns for parameter handling

A recurring friction when mixing Equinox modules (Kidger & Garcia, 2021) with NumPyro (Phan et al., 2019) is where hyperparameters live: Equinox wants a PyTree of leaves, while NumPyro wants named random sites. Pyrox supports three compatible idioms, and this masterclass solves the same regression problem end-to-end in each so you can A/B them on ergonomics.

The three patterns live on a spectrum:

| Pattern | Where the params live | Who wires them in | Boilerplate |
| --- | --- | --- | --- |
| Pattern 1: `eqx.tree_at` + raw NumPyro | Equinox module leaves | user does the substitution by hand | most |
| Pattern 2: `PyroxModule` + `pyrox_sample` | module leaves + a NumPyro plate | pyrox wires the sampling | medium |
| Pattern 3: `Parameterized` + `PyroxParam` + native `pyrox.gp` | leaves carry prior metadata | pyrox resolves everything | least |

All three are mathematically equivalent — they sample the same posterior. The choice is about how much machinery you want pyrox to do for you.
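To make the spectrum concrete, here is a minimal Pattern 1 sketch using only public Equinox and NumPyro APIs. The `RBFKernel` module, its field names, and the LogNormal/HalfNormal hyperpriors are illustrative assumptions for this page, not pyrox code: hyperparameters are sampled as named sites, then substituted into the module leaves by hand with `eqx.tree_at`.

```python
import equinox as eqx
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

class RBFKernel(eqx.Module):
    # Illustrative kernel module; field names are assumptions, not pyrox's API.
    lengthscale: jnp.ndarray
    amplitude: jnp.ndarray

    def __call__(self, x1, x2):
        # k(x, x') = alpha^2 * exp(-(x - x')^2 / (2 * ell^2))
        sq = (x1[:, None] - x2[None, :]) ** 2
        return self.amplitude**2 * jnp.exp(-0.5 * sq / self.lengthscale**2)

def model(x, y, kernel_template):
    # Named NumPyro sites for the hyperpriors on theta = (ell, alpha, sigma_n).
    ell = numpyro.sample("lengthscale", dist.LogNormal(0.0, 1.0))
    alpha = numpyro.sample("amplitude", dist.LogNormal(0.0, 1.0))
    sigma_n = numpyro.sample("noise", dist.HalfNormal(1.0))
    # Pattern 1: substitute the sampled values into the module leaves by hand.
    kernel = eqx.tree_at(
        lambda k: (k.lengthscale, k.amplitude), kernel_template, (ell, alpha)
    )
    cov = kernel(x, x) + sigma_n**2 * jnp.eye(x.shape[0])
    numpyro.sample("y", dist.MultivariateNormal(jnp.zeros(x.shape[0]), cov), obs=y)
```

Patterns 2 and 3 replace the manual `tree_at` substitution with the pyrox wrappers listed in the table above.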

Model

All three notebooks target the same regression problem: an exact GP prior $f \sim \mathcal{GP}(0, k_\theta)$ with Gaussian noise $y_i = f(x_i) + \varepsilon_i$, $\varepsilon_i \sim \mathcal{N}(0, \sigma_n^2)$, and hyperparameters $\theta = (\ell, \alpha, \sigma_n)$ (length-scale, signal amplitude, noise scale). We place hyperpriors on each and infer $p(\theta \mid \mathbf{y})$.

The conjugate marginal likelihood (after integrating out ff) is

$$
\log p(\mathbf{y} \mid \theta) = -\tfrac{1}{2}\, \mathbf{y}^\top \bigl(K_{XX}(\theta) + \sigma_n^2 I\bigr)^{-1} \mathbf{y} - \tfrac{1}{2} \log\bigl\lvert K_{XX}(\theta) + \sigma_n^2 I\bigr\rvert - \tfrac{N}{2}\log 2\pi,
$$

and the joint density evaluated during NUTS is

$$
\log p(\theta, \mathbf{y}) = \log p(\mathbf{y} \mid \theta) + \log p(\theta).
$$

Each pattern differs only in how $\theta$ gets into the Equinox kernel module that evaluates $K_{XX}(\theta)$.
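Because all three patterns expose the same joint density, inference itself is plain NumPyro NUTS regardless of which idiom supplied $\theta$. A sketch of running it against the Pattern 1 model above, where `x_train` and `y_train` are assumed data arrays and the chain settings are arbitrary:

```python
from jax import random
from numpyro.infer import MCMC, NUTS

# Template module whose leaves get overwritten by the sampled hyperparameters.
template = RBFKernel(lengthscale=jnp.array(1.0), amplitude=jnp.array(1.0))

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), x_train, y_train, template)
mcmc.print_summary()  # posterior over lengthscale, amplitude, noise
```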

Numerical considerations
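One standard stabilization, sketched below in JAX (an assumption about what the notebooks do, not pyrox source), is to evaluate the marginal likelihood through a Cholesky factorization of $K_{XX}(\theta) + \sigma_n^2 I$ rather than forming an explicit inverse and determinant.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

def log_marginal_likelihood(K, sigma_n, y, jitter=1e-6):
    # log p(y | theta) for y ~ N(0, K + sigma_n^2 I), computed via Cholesky.
    N = y.shape[0]
    A = K + (sigma_n**2 + jitter) * jnp.eye(N)  # jitter guards positive-definiteness
    c, lower = cho_factor(A, lower=True)
    quad = y @ cho_solve((c, lower), y)           # y^T A^{-1} y
    logdet = 2.0 * jnp.sum(jnp.log(jnp.diag(c)))  # log|A| from the Cholesky diagonal
    return -0.5 * (quad + logdet + N * jnp.log(2.0 * jnp.pi))
```

Compared with computing $(K_{XX}(\theta) + \sigma_n^2 I)^{-1}$ and the determinant separately, this does a single $O(N^3)$ factorization and is better behaved numerically.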

Notebooks

References
  1. Kidger, P., & Garcia, C. (2021). Equinox: Neural Networks in JAX via Callable PyTrees and Filtered Transformations. Differentiable Programming Workshop at NeurIPS.
  2. Phan, D., Pradhan, N., & Jankowiak, M. (2019). Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro. arXiv:1912.11554.
  3. Kingma, D. P., & Welling, M. (2014). Auto-Encoding Variational Bayes. International Conference on Learning Representations (ICLR).
  4. Blei, D. M., Kucukelbir, A., & McAuliffe, J. D. (2017). Variational Inference: A Review for Statisticians. Journal of the American Statistical Association, 112(518), 859–877. https://doi.org/10.1080/01621459.2017.1285773
  5. Ranganath, R., Gerrish, S., & Blei, D. M. (2014). Black Box Variational Inference. International Conference on Artificial Intelligence and Statistics (AISTATS).
  6. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic Differentiation Variational Inference. Journal of Machine Learning Research, 18(14), 1–45.