Exponential Family Gaussians¶
This notebook demonstrates the gaussx exponential family module -- working with Gaussians in natural parameter form.
What you'll learn:
- The exponential family form of the Gaussian: natural parameters $\eta_1$, $\eta_2$
- Converting between mean/covariance and natural parameters
- Log-partition function, sufficient statistics, Fisher information
- KL divergence via Bregman divergence in natural parameter space
- Why natural parameters are useful (conjugate updates, message passing)
1. Background¶
Any member of the exponential family can be written as:
$$ q(x \mid \eta) = h(x) \exp\!\bigl(\eta^\top T(x) - A(\eta)\bigr) $$
For a multivariate Gaussian with mean $\mu$ and precision $\Lambda = \Sigma^{-1}$, the natural parameters are:
$$ \eta_1 = \Lambda \mu, \qquad \eta_2 = -\tfrac{1}{2}\Lambda $$
The sufficient statistics are $T(x) = [x,\; x x^\top]$, and the log-partition function is:
$$ A(\eta) = -\tfrac{1}{4}\,\eta_1^\top \eta_2^{-1} \eta_1 - \tfrac{1}{2}\log\lvert -2\eta_2 \rvert + \tfrac{N}{2}\log(2\pi) $$
This encodes the normalization constant. Everything in gaussx's
_expfam module is built on these identities.
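Before touching gaussx, these identities can be sanity-checked with plain JAX autodiff: the gradient of $A$ with respect to $\eta_1$ is the expected sufficient statistic $\mathbb{E}[x] = \mu$, and the Hessian in $\eta_1$ is $\mathrm{Cov}[x] = \Sigma$. A standalone sketch (no gaussx, just the formula above):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

mu = jnp.array([1.0, -0.5, 2.0])
Sigma = jnp.array([[2.0, 0.5, 0.3],
                   [0.5, 1.0, 0.1],
                   [0.3, 0.1, 1.5]])
Lam = jnp.linalg.inv(Sigma)
eta1, eta2 = Lam @ mu, -0.5 * Lam

def A(eta1):
    # Log-partition as a function of eta1, holding eta2 fixed.
    N = eta1.shape[0]
    return (-0.25 * eta1 @ jnp.linalg.solve(eta2, eta1)
            - 0.5 * jnp.linalg.slogdet(-2.0 * eta2)[1]
            + 0.5 * N * jnp.log(2.0 * jnp.pi))

grad_A = jax.grad(A)(eta1)     # equals mu
hess_A = jax.hessian(A)(eta1)  # equals Sigma
```

Differentiating the log-partition is exactly the moment-generating property of exponential families; it is the identity the conversions in section 4 rely on.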
The Gaussian is the maximum-entropy distribution for a given mean and covariance (Jaynes, 1957), which makes its exponential family form central to several inference frameworks:
- Expectation propagation (Minka, 2001) — message passing operates directly in natural parameter space, where site updates are additive.
- Natural gradient methods (Amari, 1998) — the Fisher information metric gives the steepest descent direction in distribution space, and natural parameters make the Fisher matrix readily available.
- Variational inference (Wainwright & Jordan, 2008) — conjugate-computation variational inference exploits natural parameter additivity to perform closed-form coordinate updates.
2. Setup¶
from __future__ import annotations
import warnings
warnings.filterwarnings("ignore", message=r".*IProgress.*")
import jax
import jax.numpy as jnp
import lineax as lx
import gaussx
jax.config.update("jax_enable_x64", True)
We will work with a simple 3D Gaussian throughout.
N = 3
mu = jnp.array([1.0, -0.5, 2.0])
# A small positive-definite covariance matrix
Sigma_mat = jnp.array(
[
[2.0, 0.5, 0.3],
[0.5, 1.0, 0.1],
[0.3, 0.1, 1.5],
]
)
Sigma_op = lx.MatrixLinearOperator(Sigma_mat)
# Precision = Sigma^{-1}
Lambda_mat = jnp.linalg.inv(Sigma_mat)
Lambda_op = lx.MatrixLinearOperator(Lambda_mat)
print("mu =", mu)
print("Sigma =\n", Sigma_mat)
print("Lambda =\n", Lambda_mat)
mu = [ 1.  -0.5  2. ]
Sigma =
 [[2.  0.5 0.3]
  [0.5 1.  0.1]
  [0.3 0.1 1.5]]
Lambda =
 [[ 0.58546169 -0.28290766 -0.09823183]
  [-0.28290766  1.14341847 -0.01964637]
  [-0.09823183 -0.01964637  0.68762279]]
3. Construction¶
There are three ways to build a GaussianExpFam.
# Method 1: from mean and covariance
q1 = gaussx.GaussianExpFam.from_mean_cov(mu, Sigma_op)
# Method 2: from mean and precision
q2 = gaussx.GaussianExpFam.from_mean_prec(mu, Lambda_op)
# Method 3: directly from natural parameters
eta1_manual = Lambda_mat @ mu
eta2_manual = -0.5 * Lambda_mat
q3 = gaussx.GaussianExpFam(
eta1=eta1_manual,
eta2=lx.MatrixLinearOperator(eta2_manual),
)
print("eta1 (from_mean_cov) =", q1.eta1)
print("eta1 (from_mean_prec) =", q2.eta1)
print("eta1 (manual) =", q3.eta1)
print()
print("eta2 (from_mean_cov) =\n", q1.eta2.as_matrix())
print("eta2 (from_mean_prec) =\n", q2.eta2.as_matrix())
print("eta2 (manual) =\n", q3.eta2.as_matrix())
eta1 (from_mean_cov) = [ 0.53045187 -0.89390963  1.28683694]
eta1 (from_mean_prec) = [ 0.53045187 -0.89390963  1.28683694]
eta1 (manual) = [ 0.53045187 -0.89390963  1.28683694]

eta2 (from_mean_cov) =
 [[-0.29273084  0.14145383  0.04911591]
  [ 0.14145383 -0.57170923  0.00982318]
  [ 0.04911591  0.00982318 -0.34381139]]
eta2 (from_mean_prec) =
 [[-0.29273084  0.14145383  0.04911591]
  [ 0.14145383 -0.57170923  0.00982318]
  [ 0.04911591  0.00982318 -0.34381139]]
eta2 (manual) =
 [[-0.29273084  0.14145383  0.04911591]
  [ 0.14145383 -0.57170923  0.00982318]
  [ 0.04911591  0.00982318 -0.34381139]]
# Verify all three agree
assert jnp.allclose(q1.eta1, q2.eta1, atol=1e-12)
assert jnp.allclose(q1.eta1, q3.eta1, atol=1e-12)
assert jnp.allclose(q1.eta2.as_matrix(), q2.eta2.as_matrix(), atol=1e-12)
assert jnp.allclose(q1.eta2.as_matrix(), q3.eta2.as_matrix(), atol=1e-12)
print("All three constructors produce identical natural parameters.")
All three constructors produce identical natural parameters.
4. Conversions¶
Round-trip: natural $\to$ expectation $\to$ natural. The inverse map follows directly from the definitions: $\Sigma = -\tfrac{1}{2}\eta_2^{-1}$ and $\mu = \Sigma\,\eta_1$.
# Recover mean and covariance from the exp-fam object
mu_recovered, Sigma_recovered = gaussx.to_expectation(q1)
print("Original mu =", mu)
print("Recovered mu =", mu_recovered)
print("Match:", jnp.allclose(mu, mu_recovered, atol=1e-12))
print()
print("Original Sigma =\n", Sigma_mat)
print("Recovered Sigma =\n", Sigma_recovered.as_matrix())
print("Match:", jnp.allclose(Sigma_mat, Sigma_recovered.as_matrix(), atol=1e-10))
Original mu = [ 1.  -0.5  2. ]
Recovered mu = [ 1.  -0.5  2. ]
Match: True

Original Sigma =
 [[2.  0.5 0.3]
  [0.5 1.  0.1]
  [0.3 0.1 1.5]]
Recovered Sigma =
 [[2.  0.5 0.3]
  [0.5 1.  0.1]
  [0.3 0.1 1.5]]
Match: True
# Also verify to_natural produces matching eta1, eta2
eta1_fn, eta2_fn = gaussx.to_natural(mu, Sigma_op)
print("eta1 (to_natural) =", eta1_fn)
print("eta1 (object) =", q1.eta1)
print("Match:", jnp.allclose(eta1_fn, q1.eta1, atol=1e-12))
print()
print("eta2 (to_natural) =\n", eta2_fn.as_matrix())
print("eta2 (object) =\n", q1.eta2.as_matrix())
print("Match:", jnp.allclose(eta2_fn.as_matrix(), q1.eta2.as_matrix(), atol=1e-12))
eta1 (to_natural) = [ 0.53045187 -0.89390963  1.28683694]
eta1 (object) = [ 0.53045187 -0.89390963  1.28683694]
Match: True

eta2 (to_natural) =
 [[-0.29273084  0.14145383  0.04911591]
  [ 0.14145383 -0.57170923  0.00982318]
  [ 0.04911591  0.00982318 -0.34381139]]
eta2 (object) =
 [[-0.29273084  0.14145383  0.04911591]
  [ 0.14145383 -0.57170923  0.00982318]
  [ 0.04911591  0.00982318 -0.34381139]]
Match: True
5. Log-partition function¶
The log-partition $A(\eta)$ encodes the normalization constant. We verify the gaussx implementation against a manual computation.
A_gaussx = gaussx.log_partition(q1)
# Manual computation:
# A(eta) = -0.25 * eta1^T eta2^{-1} eta1 - 0.5 * log|-2 eta2| + N/2 * log(2pi)
eta2_inv = jnp.linalg.inv(q1.eta2.as_matrix())
quad_term = -0.25 * q1.eta1 @ eta2_inv @ q1.eta1
neg2_eta2 = -2.0 * q1.eta2.as_matrix()
logdet_term = -0.5 * jnp.linalg.slogdet(neg2_eta2)[1]
base_term = 0.5 * N * jnp.log(2.0 * jnp.pi)
A_manual = quad_term + logdet_term + base_term
print(f"A(eta) [gaussx]: {A_gaussx:.10f}")
print(f"A(eta) [manual]: {A_manual:.10f}")
print(f"Match: {jnp.allclose(A_gaussx, A_manual, atol=1e-10)}")
A(eta) [gaussx]: 4.9994211997
A(eta) [manual]: 4.9994211997
Match: True
6. Sufficient statistics¶
For the Gaussian, $T(x) = [x,\; x x^\top]$.
# Single vector
x = jnp.array([0.5, -1.0, 0.3])
t1, t2 = gaussx.sufficient_stats(x)
print("x =", x)
print("T_1(x) = x =", t1)
print("T_2(x) = x x^T =\n", t2)
print("Matches outer product:", jnp.allclose(t2, jnp.outer(x, x)))
x = [ 0.5 -1.   0.3]
T_1(x) = x = [ 0.5 -1.   0.3]
T_2(x) = x x^T =
 [[ 0.25 -0.5   0.15]
  [-0.5   1.   -0.3 ]
  [ 0.15 -0.3   0.09]]
Matches outer product: True
# Batch of vectors
X = jnp.array(
[
[0.5, -1.0, 0.3],
[1.0, 0.0, -0.5],
[2.0, 1.0, 1.0],
]
)
t1_batch, t2_batch = gaussx.sufficient_stats(X)
print(f"Batch input shape: {X.shape}")
print(f"T_1 shape: {t1_batch.shape}")
print(f"T_2 shape: {t2_batch.shape}")
for i in range(X.shape[0]):
ok = jnp.allclose(t2_batch[i], jnp.outer(X[i], X[i]))
print(f" Sample {i}: outer product match = {ok}")
Batch input shape: (3, 3)
T_1 shape: (3, 3)
T_2 shape: (3, 3, 3)
  Sample 0: outer product match = True
  Sample 1: outer product match = True
  Sample 2: outer product match = True
7. Fisher information¶
For a Gaussian, the Fisher information with respect to the mean $\mu$ (holding the covariance fixed) equals the precision matrix $\Lambda = \Sigma^{-1}$; this is what gaussx.fisher_info reports below.
More generally, for any exponential family the Fisher information is the Hessian of the log-partition function:
$$ F_{ij} = \frac{\partial^2 A(\eta)}{\partial \eta_i \,\partial \eta_j} $$
This connects the geometry of the natural parameter space to the curvature of $A(\eta)$, which is always convex (Barndorff-Nielsen, 1978).
F = gaussx.fisher_info(q1)
print("Fisher info =\n", F.as_matrix())
print("Precision =\n", Lambda_mat)
print("Match:", jnp.allclose(F.as_matrix(), Lambda_mat, atol=1e-10))
Fisher info =
 [[ 0.58546169 -0.28290766 -0.09823183]
  [-0.28290766  1.14341847 -0.01964637]
  [-0.09823183 -0.01964637  0.68762279]]
Precision =
 [[ 0.58546169 -0.28290766 -0.09823183]
  [-0.28290766  1.14341847 -0.01964637]
  [-0.09823183 -0.01964637  0.68762279]]
Match: True
8. KL divergence¶
gaussx.kl_divergence(q, p) computes $\text{KL}(q \| p)$ via the
Bregman divergence in natural parameter space:
$$ \text{KL}(q \| p) = A(\eta_p) - A(\eta_q) - (\eta_p - \eta_q)^\top \nabla A(\eta_q) $$
The Bregman divergence interpretation means KL can be computed using only the log-partition function and its gradient — no explicit matrix inversions needed beyond what is in the natural-to-expectation conversion. This is computationally advantageous when the precision has structure (e.g. Kronecker, block-diagonal, or sparse).
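As a standalone illustration of the Bregman identity (plain JAX, independent of gaussx's internals): for the Gaussian, $\nabla A(\eta_q)$ is the pair of expected sufficient statistics $\bigl(\mu_q,\; \Sigma_q + \mu_q\mu_q^\top\bigr)$, so the inner product splits into a vector term and a trace term:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def nat_params(mu, Sigma):
    Lam = jnp.linalg.inv(Sigma)
    return Lam @ mu, -0.5 * Lam

def log_partition(eta1, eta2):
    N = eta1.shape[0]
    return (-0.25 * eta1 @ jnp.linalg.solve(eta2, eta1)
            - 0.5 * jnp.linalg.slogdet(-2.0 * eta2)[1]
            + 0.5 * N * jnp.log(2.0 * jnp.pi))

def kl_bregman(mu_q, Sigma_q, mu_p, Sigma_p):
    """KL(q || p) as the Bregman divergence of the log-partition."""
    eq1, eq2 = nat_params(mu_q, Sigma_q)
    ep1, ep2 = nat_params(mu_p, Sigma_p)
    # grad A(eta_q) = expected sufficient statistics under q:
    #   dA/deta1 = E[x] = mu_q,  dA/deta2 = E[x x^T] = Sigma_q + mu_q mu_q^T
    E_xxT = Sigma_q + jnp.outer(mu_q, mu_q)
    inner = (ep1 - eq1) @ mu_q + jnp.sum((ep2 - eq2) * E_xxT)
    return log_partition(ep1, ep2) - log_partition(eq1, eq2) - inner
```

For example, $\text{KL}\bigl(\mathcal{N}(0, I) \,\|\, \mathcal{N}(m, I)\bigr)$ reduces to $\tfrac{1}{2}\lVert m \rVert^2$ under this formula, matching the closed form below.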
We verify against the standard formula:
$$ \text{KL}(q \| p) = \tfrac{1}{2}\bigl[ \text{tr}(\Sigma_p^{-1}\Sigma_q) + (\mu_p - \mu_q)^\top \Sigma_p^{-1}(\mu_p - \mu_q) - N + \log\tfrac{|\Sigma_p|}{|\Sigma_q|} \bigr] $$
# Create a second Gaussian
mu_p = jnp.array([0.0, 1.0, -1.0])
Sigma_p_mat = jnp.array(
[
[1.5, 0.2, 0.0],
[0.2, 2.0, 0.4],
[0.0, 0.4, 1.0],
]
)
Sigma_p_op = lx.MatrixLinearOperator(Sigma_p_mat)
p = gaussx.GaussianExpFam.from_mean_cov(mu_p, Sigma_p_op)
# gaussx KL
kl_gaussx = gaussx.kl_divergence(q1, p)
# Standard formula KL(q || p)
Lambda_p = jnp.linalg.inv(Sigma_p_mat)
diff = mu_p - mu
kl_standard = 0.5 * (
jnp.trace(Lambda_p @ Sigma_mat)
+ diff @ Lambda_p @ diff
- N
+ jnp.linalg.slogdet(Sigma_p_mat)[1]
- jnp.linalg.slogdet(Sigma_mat)[1]
)
print(f"KL(q || p) [gaussx]: {kl_gaussx:.10f}")
print(f"KL(q || p) [standard]: {kl_standard:.10f}")
print(f"Match: {jnp.allclose(kl_gaussx, kl_standard, atol=1e-8)}")
KL(q || p) [gaussx]: 7.2985079681
KL(q || p) [standard]: 7.2985079681
Match: True
9. Why natural parameters?¶
In natural parameter space, conjugate updates are just addition. When we combine a Gaussian prior with a Gaussian likelihood site, the posterior natural parameters are the sum of the prior and site natural parameters:
$$ \eta_{\text{post}} = \eta_{\text{prior}} + \eta_{\text{site}} $$
This makes message passing and variational inference extremely simple: no matrix inversions are needed for the update itself.
# Prior: our original Gaussian q1
prior = q1
print("Prior mean =", mu)
# A likelihood "site" -- e.g. from a single noisy observation
# Observation model: y = x + noise, noise ~ N(0, R)
R_mat = 0.1 * jnp.eye(N) # observation noise
y_obs = jnp.array([1.2, -0.3, 2.1]) # observed value
# Site natural parameters: eta1_site = R^{-1} y, eta2_site = -0.5 R^{-1}
R_inv = jnp.linalg.inv(R_mat)
site_eta1 = R_inv @ y_obs
site_eta2_mat = -0.5 * R_inv
print(f"Site eta1 = {site_eta1}")
Prior mean = [ 1. -0.5 2. ]
Site eta1 = [12. -3. 21.]
# Posterior = prior + site (just addition in natural parameter space!)
post_eta1 = prior.eta1 + site_eta1
post_eta2_mat = prior.eta2.as_matrix() + site_eta2_mat
posterior = gaussx.GaussianExpFam(
eta1=post_eta1,
eta2=lx.MatrixLinearOperator(post_eta2_mat),
)
# Recover posterior mean and covariance
mu_post, Sigma_post = gaussx.to_expectation(posterior)
print("Posterior mean =", mu_post)
print("Posterior cov =\n", Sigma_post.as_matrix())
Posterior mean = [ 1.19475983 -0.31540861  2.09569557]
Posterior cov =
 [[0.09454148 0.00240175 0.00087336]
  [0.00240175 0.08980037 0.00018715]
  [0.00087336 0.00018715 0.09357455]]
# Verify against the standard Gaussian conditioning formula:
# Lambda_post = Lambda_prior + R^{-1}
# eta1_post = Lambda_prior @ mu_prior + R^{-1} @ y
# mu_post = Sigma_post @ eta1_post
Lambda_post_expected = Lambda_mat + R_inv
Sigma_post_expected = jnp.linalg.inv(Lambda_post_expected)
mu_post_expected = Sigma_post_expected @ (Lambda_mat @ mu + R_inv @ y_obs)
print("Expected posterior mean =", mu_post_expected)
print("Mean match:", jnp.allclose(mu_post, mu_post_expected, atol=1e-10))
print()
print("Expected posterior cov =\n", Sigma_post_expected)
print(
"Cov match:", jnp.allclose(Sigma_post.as_matrix(), Sigma_post_expected, atol=1e-10)
)
Expected posterior mean = [ 1.19475983 -0.31540861  2.09569557]
Mean match: True

Expected posterior cov =
 [[0.09454148 0.00240175 0.00087336]
  [0.00240175 0.08980037 0.00018715]
  [0.00087336 0.00018715 0.09357455]]
Cov match: True
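The additivity also means a batch of independent observations can be absorbed in one shot: summing all site natural parameters gives the same posterior as conditioning on the observations one at a time. A minimal sketch in plain JAX (not using gaussx, assuming the same linear-Gaussian observation model as above):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Prior N(0, I) in 2D; three noisy observations y_k = x + noise, noise ~ N(0, R)
R_inv = jnp.linalg.inv(0.5 * jnp.eye(2))
ys = jnp.array([[1.0, 0.0], [0.8, 0.2], [1.2, -0.1]])

# Prior natural parameters (Lambda = I, mu = 0)
eta1_prior = jnp.zeros(2)       # Lambda @ mu
eta2_prior = -0.5 * jnp.eye(2)  # -0.5 * Lambda

# Batch absorb: add the sum of all site parameters at once
eta1_batch = eta1_prior + (R_inv @ ys.T).sum(axis=1)
eta2_batch = eta2_prior + ys.shape[0] * (-0.5 * R_inv)

# Sequential: condition on one observation at a time
eta1_seq, eta2_seq = eta1_prior, eta2_prior
for y in ys:
    eta1_seq = eta1_seq + R_inv @ y
    eta2_seq = eta2_seq - 0.5 * R_inv

# Both routes yield identical posterior natural parameters
```

Order independence falls straight out of addition being commutative, which is why EP site updates can be applied, removed, and re-applied in any order.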
10. Summary¶
| Concept | gaussx function |
|---|---|
| Build from mean/cov | GaussianExpFam.from_mean_cov(mu, Sigma) |
| Build from mean/prec | GaussianExpFam.from_mean_prec(mu, Lambda) |
| Natural $\to$ expectation | gaussx.to_expectation(q) |
| Expectation $\to$ natural | gaussx.to_natural(mu, Sigma) |
| Log-partition $A(\eta)$ | gaussx.log_partition(q) |
| Sufficient statistics | gaussx.sufficient_stats(x) |
| Fisher information | gaussx.fisher_info(q) |
| KL divergence | gaussx.kl_divergence(q, p) |
The key takeaway: natural parameters turn Bayesian updates into simple addition, making them the natural choice for message passing, variational inference, and conjugate update algorithms.
References¶
- Amari, S. (1998). Natural gradient works efficiently in learning. Neural Computation, 10(2), 251-276.
- Barndorff-Nielsen, O. (1978). Information and Exponential Families in Statistical Theory. Wiley.
- Jaynes, E. T. (1957). Information theory and statistical mechanics. Physical Review, 106(4), 620-630.
- Minka, T. P. (2001). A Family of Algorithms for Approximate Bayesian Inference. PhD thesis, MIT.
- Wainwright, M. J. & Jordan, M. I. (2008). Graphical models, exponential families, and variational inference. Foundations and Trends in Machine Learning, 1(1-2), 1-305.