Exponential Family Gaussians
This notebook demonstrates the gaussx exponential-family module: working with Gaussians in natural-parameter form.
What you’ll learn:
- The exponential family form of the Gaussian: natural parameters $\eta_1 = \Lambda \mu$ and $\eta_2 = -\tfrac{1}{2}\Lambda$
- Converting between mean/covariance and natural parameters
- Log-partition function, sufficient statistics, Fisher information
- KL divergence via Bregman divergence in natural parameter space
- Why natural parameters are useful (conjugate updates, message passing)
1. Background¶
Any member of the exponential family can be written as:

$$p(x \mid \eta) = h(x) \exp\!\left(\eta^\top T(x) - A(\eta)\right)$$

For a multivariate Gaussian with mean $\mu$ and precision $\Lambda = \Sigma^{-1}$, the natural parameters are:

$$\eta_1 = \Lambda \mu, \qquad \eta_2 = -\tfrac{1}{2}\Lambda$$

The sufficient statistics are $T_1(x) = x$ and $T_2(x) = x x^\top$, and the log-partition function is:

$$A(\eta) = -\tfrac{1}{4}\,\eta_1^\top \eta_2^{-1} \eta_1 - \tfrac{1}{2}\log\lvert -2\eta_2 \rvert + \tfrac{N}{2}\log 2\pi$$
This encodes the normalization constant. Everything in gaussx’s
_expfam module is built on these identities.
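These identities can be checked with plain `jax.numpy`, independently of gaussx. Here is a minimal sketch (the 2D values are illustrative; nothing below uses the gaussx API) that computes the log-partition once in natural coordinates and once in mean/covariance coordinates and confirms they agree:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Mean and covariance of a 2D Gaussian (illustrative values)
mu = jnp.array([1.0, -0.5])
Sigma = jnp.array([[2.0, 0.5], [0.5, 1.0]])
Lam = jnp.linalg.inv(Sigma)  # precision

# Natural parameters: eta1 = Lambda mu, eta2 = -1/2 Lambda
eta1 = Lam @ mu
eta2 = -0.5 * Lam

# Log-partition in natural coordinates:
# A(eta) = -1/4 eta1^T eta2^{-1} eta1 - 1/2 log|-2 eta2| + N/2 log(2 pi)
N = mu.shape[0]
A_nat = (
    -0.25 * eta1 @ jnp.linalg.inv(eta2) @ eta1
    - 0.5 * jnp.linalg.slogdet(-2.0 * eta2)[1]
    + 0.5 * N * jnp.log(2.0 * jnp.pi)
)

# Same quantity in mean/covariance coordinates:
# A = 1/2 mu^T Lambda mu + 1/2 log|Sigma| + N/2 log(2 pi)
A_mean = (
    0.5 * mu @ Lam @ mu
    + 0.5 * jnp.linalg.slogdet(Sigma)[1]
    + 0.5 * N * jnp.log(2.0 * jnp.pi)
)
print(jnp.allclose(A_nat, A_mean, atol=1e-10))  # True
```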
The Gaussian is the maximum-entropy distribution for a given mean and covariance (Jaynes, 1957), which makes its exponential family form central to several inference frameworks:
- Expectation propagation (Minka, 2001) — message passing operates directly in natural parameter space, where site updates are additive.
- Natural gradient methods (Amari, 1998) — the Fisher information metric gives the steepest descent direction in distribution space, and natural parameters make the Fisher matrix readily available.
- Variational inference (Wainwright & Jordan, 2008) — conjugate-computation variational inference exploits natural-parameter additivity to perform closed-form coordinate updates.
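To make the natural-gradient point concrete, here is a toy sketch using only `jax.grad` and `jax.hessian` (not gaussx, and in one dimension). For any exponential family, the natural gradient of $\mathrm{KL}(q \,\|\, p)$ with respect to $\eta_q$ works out to $\eta_q - \eta_p$, so a single natural-gradient step of size 1 lands exactly on the target:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Log-partition of a 1D Gaussian in natural coordinates eta = (eta1, eta2)
def log_partition(eta):
    eta1, eta2 = eta
    return -0.25 * eta1**2 / eta2 - 0.5 * jnp.log(-2.0 * eta2) + 0.5 * jnp.log(2.0 * jnp.pi)

def to_natural(mu, var):
    return jnp.array([mu / var, -0.5 / var])

# KL(q || p) written as a Bregman divergence of the log-partition
def kl(eta_q, eta_p):
    grad_A = jax.grad(log_partition)(eta_q)
    return log_partition(eta_p) - log_partition(eta_q) - (eta_p - eta_q) @ grad_A

eta_p = to_natural(0.0, 1.0)   # target N(0, 1)
eta_q = to_natural(2.0, 3.0)   # initial N(2, 3)

g = jax.grad(kl)(eta_q, eta_p)            # ordinary gradient in eta
F = jax.hessian(log_partition)(eta_q)     # Fisher metric = Hessian of A
nat_grad = jnp.linalg.solve(F, g)         # natural gradient F^{-1} g

eta_next = eta_q - 1.0 * nat_grad         # one unit-size step
print(jnp.allclose(eta_next, eta_p, atol=1e-8))  # True: lands on the target
```

An ordinary gradient step in $\eta$ would not have this property; the Fisher preconditioning is exactly what turns the gradient into the coordinate difference.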
2. Setup¶
from __future__ import annotations
import warnings
warnings.filterwarnings("ignore", message=r".*IProgress.*")
import jax
import jax.numpy as jnp
import lineax as lx
import gaussx
jax.config.update("jax_enable_x64", True)
We will work with a simple 3D Gaussian throughout.
N = 3
mu = jnp.array([1.0, -0.5, 2.0])
# A small positive-definite covariance matrix
Sigma_mat = jnp.array(
    [
        [2.0, 0.5, 0.3],
        [0.5, 1.0, 0.1],
        [0.3, 0.1, 1.5],
    ]
)
Sigma_op = lx.MatrixLinearOperator(Sigma_mat)
# Precision = Sigma^{-1}
Lambda_mat = jnp.linalg.inv(Sigma_mat)
Lambda_op = lx.MatrixLinearOperator(Lambda_mat)
print("mu =", mu)
print("Sigma =\n", Sigma_mat)
print("Lambda =\n", Lambda_mat)
mu = [ 1. -0.5 2. ]
Sigma =
[[2. 0.5 0.3]
[0.5 1. 0.1]
[0.3 0.1 1.5]]
Lambda =
[[ 0.58546169 -0.28290766 -0.09823183]
[-0.28290766 1.14341847 -0.01964637]
[-0.09823183 -0.01964637 0.68762279]]
3. Construction¶
There are three ways to build a GaussianExpFam.
# Method 1: from mean and covariance
q1 = gaussx.GaussianExpFam.from_mean_cov(mu, Sigma_op)
# Method 2: from mean and precision
q2 = gaussx.GaussianExpFam.from_mean_prec(mu, Lambda_op)
# Method 3: directly from natural parameters
eta1_manual = Lambda_mat @ mu
eta2_manual = -0.5 * Lambda_mat
q3 = gaussx.GaussianExpFam(
    eta1=eta1_manual,
    eta2=lx.MatrixLinearOperator(eta2_manual),
)
print("eta1 (from_mean_cov) =", q1.eta1)
print("eta1 (from_mean_prec) =", q2.eta1)
print("eta1 (manual) =", q3.eta1)
print()
print("eta2 (from_mean_cov) =\n", q1.eta2.as_matrix())
print("eta2 (from_mean_prec) =\n", q2.eta2.as_matrix())
print("eta2 (manual) =\n", q3.eta2.as_matrix())
eta1 (from_mean_cov) = [ 0.53045187 -0.89390963 1.28683694]
eta1 (from_mean_prec) = [ 0.53045187 -0.89390963 1.28683694]
eta1 (manual) = [ 0.53045187 -0.89390963 1.28683694]
eta2 (from_mean_cov) =
[[-0.29273084 0.14145383 0.04911591]
[ 0.14145383 -0.57170923 0.00982318]
[ 0.04911591 0.00982318 -0.34381139]]
eta2 (from_mean_prec) =
[[-0.29273084 0.14145383 0.04911591]
[ 0.14145383 -0.57170923 0.00982318]
[ 0.04911591 0.00982318 -0.34381139]]
eta2 (manual) =
[[-0.29273084 0.14145383 0.04911591]
[ 0.14145383 -0.57170923 0.00982318]
[ 0.04911591 0.00982318 -0.34381139]]
# Verify all three agree
assert jnp.allclose(q1.eta1, q2.eta1, atol=1e-12)
assert jnp.allclose(q1.eta1, q3.eta1, atol=1e-12)
assert jnp.allclose(q1.eta2.as_matrix(), q2.eta2.as_matrix(), atol=1e-12)
assert jnp.allclose(q1.eta2.as_matrix(), q3.eta2.as_matrix(), atol=1e-12)
print("All three constructors produce identical natural parameters.")
All three constructors produce identical natural parameters.
4. Conversions¶
Round-trip: natural → expectation → natural.
# Recover mean and covariance from the exp-fam object
mu_recovered, Sigma_recovered = gaussx.to_expectation(q1)
print("Original mu =", mu)
print("Recovered mu =", mu_recovered)
print("Match:", jnp.allclose(mu, mu_recovered, atol=1e-12))
print()
print("Original Sigma =\n", Sigma_mat)
print("Recovered Sigma =\n", Sigma_recovered.as_matrix())
print("Match:", jnp.allclose(Sigma_mat, Sigma_recovered.as_matrix(), atol=1e-10))
Original mu = [ 1. -0.5 2. ]
Recovered mu = [ 1. -0.5 2. ]
Match: True
Original Sigma =
[[2. 0.5 0.3]
[0.5 1. 0.1]
[0.3 0.1 1.5]]
Recovered Sigma =
[[2. 0.5 0.3]
[0.5 1. 0.1]
[0.3 0.1 1.5]]
Match: True
# Also verify to_natural produces matching eta1, eta2
eta1_fn, eta2_fn = gaussx.to_natural(mu, Sigma_op)
print("eta1 (to_natural) =", eta1_fn)
print("eta1 (object) =", q1.eta1)
print("Match:", jnp.allclose(eta1_fn, q1.eta1, atol=1e-12))
print()
print("eta2 (to_natural) =\n", eta2_fn.as_matrix())
print("eta2 (object) =\n", q1.eta2.as_matrix())
print("Match:", jnp.allclose(eta2_fn.as_matrix(), q1.eta2.as_matrix(), atol=1e-12))
eta1 (to_natural) = [ 0.53045187 -0.89390963 1.28683694]
eta1 (object) = [ 0.53045187 -0.89390963 1.28683694]
Match: True
eta2 (to_natural) =
[[-0.29273084 0.14145383 0.04911591]
[ 0.14145383 -0.57170923 0.00982318]
[ 0.04911591 0.00982318 -0.34381139]]
eta2 (object) =
[[-0.29273084 0.14145383 0.04911591]
[ 0.14145383 -0.57170923 0.00982318]
[ 0.04911591 0.00982318 -0.34381139]]
Match: True
5. Log-partition function¶
The log-partition function encodes the normalization constant. We verify the gaussx implementation against a manual computation.
A_gaussx = gaussx.log_partition(q1)
# Manual computation:
# A(eta) = -0.25 * eta1^T eta2^{-1} eta1 - 0.5 * log|-2 eta2| + N/2 * log(2pi)
eta2_inv = jnp.linalg.inv(q1.eta2.as_matrix())
quad_term = -0.25 * q1.eta1 @ eta2_inv @ q1.eta1
neg2_eta2 = -2.0 * q1.eta2.as_matrix()
logdet_term = -0.5 * jnp.linalg.slogdet(neg2_eta2)[1]
base_term = 0.5 * N * jnp.log(2.0 * jnp.pi)
A_manual = quad_term + logdet_term + base_term
print(f"A(eta) [gaussx]: {A_gaussx:.10f}")
print(f"A(eta) [manual]: {A_manual:.10f}")
print(f"Match: {jnp.allclose(A_gaussx, A_manual, atol=1e-10)}")
A(eta) [gaussx]: 4.9994211997
A(eta) [manual]: 4.9994211997
Match: True
6. Sufficient statistics¶
For the Gaussian, the sufficient statistics are $T_1(x) = x$ and $T_2(x) = x x^\top$.
# Single vector
x = jnp.array([0.5, -1.0, 0.3])
t1, t2 = gaussx.sufficient_stats(x)
print("x =", x)
print("T_1(x) = x =", t1)
print("T_2(x) = x x^T =\n", t2)
print("Matches outer product:", jnp.allclose(t2, jnp.outer(x, x)))
x = [ 0.5 -1. 0.3]
T_1(x) = x = [ 0.5 -1. 0.3]
T_2(x) = x x^T =
[[ 0.25 -0.5 0.15]
[-0.5 1. -0.3 ]
[ 0.15 -0.3 0.09]]
Matches outer product: True
# Batch of vectors
X = jnp.array(
    [
        [0.5, -1.0, 0.3],
        [1.0, 0.0, -0.5],
        [2.0, 1.0, 1.0],
    ]
)
t1_batch, t2_batch = gaussx.sufficient_stats(X)
print(f"Batch input shape: {X.shape}")
print(f"T_1 shape: {t1_batch.shape}")
print(f"T_2 shape: {t2_batch.shape}")
for i in range(X.shape[0]):
    ok = jnp.allclose(t2_batch[i], jnp.outer(X[i], X[i]))
    print(f" Sample {i}: outer product match = {ok}")
Batch input shape: (3, 3)
T_1 shape: (3, 3)
T_2 shape: (3, 3, 3)
Sample 0: outer product match = True
Sample 1: outer product match = True
Sample 2: outer product match = True
7. Fisher information¶
For a Gaussian, the Fisher information in the natural parameterization equals the precision matrix $\Lambda = \Sigma^{-1}$.
More generally, for any exponential family the Fisher information is the Hessian of the log-partition function:

$$F(\eta) = \nabla^2_{\eta} A(\eta) = \operatorname{Cov}_{\eta}\!\left[T(x)\right]$$

This connects the geometry of the natural parameter space to the curvature of $A(\eta)$, which is always convex (Barndorff-Nielsen, 1978).
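This identity is easy to sanity-check in one dimension without gaussx: the Hessian of the scalar log-partition, computed by `jax.hessian`, should equal the closed-form covariance of the sufficient statistics $T(x) = (x, x^2)$ of a scalar Gaussian (standard Gaussian moment formulas):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Log-partition of a 1D Gaussian in natural coordinates eta = (eta1, eta2)
def log_partition(eta):
    eta1, eta2 = eta
    return -0.25 * eta1**2 / eta2 - 0.5 * jnp.log(-2.0 * eta2) + 0.5 * jnp.log(2.0 * jnp.pi)

mu, var = 1.0, 2.0
eta = jnp.array([mu / var, -0.5 / var])

# Fisher information as the Hessian of A(eta)
F = jax.hessian(log_partition)(eta)

# Closed-form covariance of T(x) = (x, x^2) for x ~ N(mu, var):
#   Var(x) = var,  Cov(x, x^2) = 2 mu var,  Var(x^2) = 4 mu^2 var + 2 var^2
cov_T = jnp.array(
    [
        [var, 2.0 * mu * var],
        [2.0 * mu * var, 4.0 * mu**2 * var + 2.0 * var**2],
    ]
)
print(jnp.allclose(F, cov_T, atol=1e-8))  # True
```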
F = gaussx.fisher_info(q1)
print("Fisher info =\n", F.as_matrix())
print("Precision =\n", Lambda_mat)
print("Match:", jnp.allclose(F.as_matrix(), Lambda_mat, atol=1e-10))
Fisher info =
[[ 0.58546169 -0.28290766 -0.09823183]
[-0.28290766 1.14341847 -0.01964637]
[-0.09823183 -0.01964637 0.68762279]]
Precision =
[[ 0.58546169 -0.28290766 -0.09823183]
[-0.28290766 1.14341847 -0.01964637]
[-0.09823183 -0.01964637 0.68762279]]
Match: True
8. KL divergence¶
gaussx.kl_divergence(q, p) computes $\mathrm{KL}(q \,\|\, p)$ via the Bregman divergence in natural parameter space:

$$\mathrm{KL}(q \,\|\, p) = A(\eta_p) - A(\eta_q) - (\eta_p - \eta_q)^\top \nabla A(\eta_q)$$
The Bregman divergence interpretation means KL can be computed using only the log-partition function and its gradient — no explicit matrix inversions needed beyond what is in the natural-to-expectation conversion. This is computationally advantageous when the precision has structure (e.g. Kronecker, block-diagonal, or sparse).
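A one-dimensional sketch of this Bregman computation, using only `jax.grad` and the scalar log-partition (not the gaussx implementation), checked against the closed-form KL between univariate Gaussians:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# 1D Gaussian log-partition in natural coordinates eta = (eta1, eta2)
def log_partition(eta):
    eta1, eta2 = eta
    return -0.25 * eta1**2 / eta2 - 0.5 * jnp.log(-2.0 * eta2) + 0.5 * jnp.log(2.0 * jnp.pi)

def to_natural(mu, var):
    return jnp.array([mu / var, -0.5 / var])

# KL(q || p) as the Bregman divergence of A: only A and grad A are needed
def kl_bregman(eta_q, eta_p):
    return (
        log_partition(eta_p) - log_partition(eta_q)
        - (eta_p - eta_q) @ jax.grad(log_partition)(eta_q)
    )

mu_q, var_q = 1.0, 2.0
mu_p, var_p = 0.0, 1.5
kl_b = kl_bregman(to_natural(mu_q, var_q), to_natural(mu_p, var_p))

# Closed-form KL between univariate Gaussians for comparison
kl_closed = 0.5 * (
    var_q / var_p + (mu_p - mu_q) ** 2 / var_p - 1.0 + jnp.log(var_p / var_q)
)
print(jnp.allclose(kl_b, kl_closed, atol=1e-10))  # True
```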
We verify against the standard formula:

$$\mathrm{KL}(q \,\|\, p) = \tfrac{1}{2}\left( \operatorname{tr}\!\left(\Sigma_p^{-1}\Sigma_q\right) + (\mu_p - \mu_q)^\top \Sigma_p^{-1} (\mu_p - \mu_q) - N + \log\frac{\lvert\Sigma_p\rvert}{\lvert\Sigma_q\rvert} \right)$$
# Create a second Gaussian
mu_p = jnp.array([0.0, 1.0, -1.0])
Sigma_p_mat = jnp.array(
    [
        [1.5, 0.2, 0.0],
        [0.2, 2.0, 0.4],
        [0.0, 0.4, 1.0],
    ]
)
Sigma_p_op = lx.MatrixLinearOperator(Sigma_p_mat)
p = gaussx.GaussianExpFam.from_mean_cov(mu_p, Sigma_p_op)
# gaussx KL
kl_gaussx = gaussx.kl_divergence(q1, p)
# Standard formula KL(q || p)
Lambda_p = jnp.linalg.inv(Sigma_p_mat)
diff = mu_p - mu
kl_standard = 0.5 * (
    jnp.trace(Lambda_p @ Sigma_mat)
    + diff @ Lambda_p @ diff
    - N
    + jnp.linalg.slogdet(Sigma_p_mat)[1]
    - jnp.linalg.slogdet(Sigma_mat)[1]
)
print(f"KL(q || p) [gaussx]: {kl_gaussx:.10f}")
print(f"KL(q || p) [standard]: {kl_standard:.10f}")
print(f"Match: {jnp.allclose(kl_gaussx, kl_standard, atol=1e-8)}")
KL(q || p) [gaussx]: 7.2985079681
KL(q || p) [standard]: 7.2985079681
Match: True
9. Why natural parameters?¶
In natural parameter space, conjugate updates are just addition. When we combine a Gaussian prior with a Gaussian likelihood site, the posterior natural parameters are the sum of the prior and site natural parameters:

$$\eta_1^{\text{post}} = \eta_1^{\text{prior}} + \eta_1^{\text{site}}, \qquad \eta_2^{\text{post}} = \eta_2^{\text{prior}} + \eta_2^{\text{site}}$$
This makes message passing and variational inference extremely simple: no matrix inversions are needed for the update itself.
# Prior: our original Gaussian q1
prior = q1
print("Prior mean =", mu)
# A likelihood "site" -- e.g. from a single noisy observation
# Observation model: y = x + noise, noise ~ N(0, R)
R_mat = 0.1 * jnp.eye(N) # observation noise
y_obs = jnp.array([1.2, -0.3, 2.1]) # observed value
# Site natural parameters: eta1_site = R^{-1} y, eta2_site = -0.5 R^{-1}
R_inv = jnp.linalg.inv(R_mat)
site_eta1 = R_inv @ y_obs
site_eta2_mat = -0.5 * R_inv
print(f"Site eta1 = {site_eta1}")
Prior mean = [ 1. -0.5 2. ]
Site eta1 = [12. -3. 21.]
# Posterior = prior + site (just addition in natural parameter space!)
post_eta1 = prior.eta1 + site_eta1
post_eta2_mat = prior.eta2.as_matrix() + site_eta2_mat
posterior = gaussx.GaussianExpFam(
    eta1=post_eta1,
    eta2=lx.MatrixLinearOperator(post_eta2_mat),
)
# Recover posterior mean and covariance
mu_post, Sigma_post = gaussx.to_expectation(posterior)
print("Posterior mean =", mu_post)
print("Posterior cov =\n", Sigma_post.as_matrix())
Posterior mean = [ 1.19475983 -0.31540861 2.09569557]
Posterior cov =
[[0.09454148 0.00240175 0.00087336]
[0.00240175 0.08980037 0.00018715]
[0.00087336 0.00018715 0.09357455]]
# Verify against the standard Gaussian conditioning formula:
# Lambda_post = Lambda_prior + R^{-1}
# eta1_post = Lambda_prior @ mu_prior + R^{-1} @ y
# mu_post = Sigma_post @ eta1_post
Lambda_post_expected = Lambda_mat + R_inv
Sigma_post_expected = jnp.linalg.inv(Lambda_post_expected)
mu_post_expected = Sigma_post_expected @ (Lambda_mat @ mu + R_inv @ y_obs)
print("Expected posterior mean =", mu_post_expected)
print("Mean match:", jnp.allclose(mu_post, mu_post_expected, atol=1e-10))
print()
print("Expected posterior cov =\n", Sigma_post_expected)
print(
    "Cov match:", jnp.allclose(Sigma_post.as_matrix(), Sigma_post_expected, atol=1e-10)
)
Expected posterior mean = [ 1.19475983 -0.31540861 2.09569557]
Mean match: True
Expected posterior cov =
[[0.09454148 0.00240175 0.00087336]
[0.00240175 0.08980037 0.00018715]
[0.00087336 0.00018715 0.09357455]]
Cov match: True
10. Summary¶
| Concept | gaussx function |
|---|---|
| Build from mean/cov | GaussianExpFam.from_mean_cov(mu, Sigma) |
| Build from mean/prec | GaussianExpFam.from_mean_prec(mu, Lambda) |
| Natural → expectation | gaussx.to_expectation(q) |
| Expectation → natural | gaussx.to_natural(mu, Sigma) |
| Log-partition | gaussx.log_partition(q) |
| Sufficient statistics | gaussx.sufficient_stats(x) |
| Fisher information | gaussx.fisher_info(q) |
| KL divergence | gaussx.kl_divergence(q, p) |
The key takeaway: natural parameters turn Bayesian updates into simple addition, making them the natural choice for message passing, variational inference, and conjugate update algorithms.
References¶
- Amari, S. (1998). Natural gradient works efficiently in learning. Neural Computation, 10(2), 251-276.
- Barndorff-Nielsen, O. (1978). Information and Exponential Families in Statistical Theory. Wiley.
- Jaynes, E. T. (1957). Information theory and statistical mechanics. Physical Review, 106(4), 620-630.
- Minka, T. P. (2001). A Family of Algorithms for Approximate Bayesian Inference. PhD thesis, MIT.
- Wainwright, M. J. & Jordan, M. I. (2008). Graphical models, exponential families, and variational inference. Foundations and Trends in Machine Learning, 1(1-2), 1-305.