Distributions & Exponential Family
Layer 2: Gaussian distributions over structured covariance operators, the sugar operations that probabilistic code actually calls, and the exponential-family (natural-parameter) view used by variational and EP-style inference.
Multivariate normal distributions
NumPyro-compatible distributions whose covariance (or precision) is a lineax
operator, so sample / log_prob inherit every structured fast path.
MultivariateNormalPrecision carries \(\Lambda = \Sigma^{-1}\) directly — the
natural home for natural-parameter guides, where materializing \(\Sigma\) would
be wasted work. Both require numpyro to be installed.
Structured linear algebra and Gaussian primitives for JAX.
MultivariateNormal
Bases: Distribution
Multivariate normal parameterized by a lineax linear operator.
Unlike numpyro.distributions.MultivariateNormal, which requires
dense arrays, this distribution accepts any
lineax.AbstractLinearOperator as its covariance. This enables
efficient log-prob, sampling, and entropy for structured covariances
(Kronecker, block-diagonal, low-rank, diagonal, etc.) via gaussx
structural dispatch.
Requires the numpyro optional extra
(pip install "gaussx[numpyro]").
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `loc` | `Float[Array, '*batch N']` | Mean vector of shape `(*batch, N)`. | required |
| `cov_operator` | `AbstractLinearOperator` | Covariance as a lineax linear operator of shape `(N, N)`. | required |
| `solver` | `AbstractSolverStrategy \| None` | Optional solver strategy. | `None` |
| `validate_args` | `bool \| None` | Whether to validate input arguments. | `None` |
Examples:
>>> import jax.numpy as jnp
>>> import lineax as lx
>>> from gaussx._distributions import MultivariateNormal
>>> Sigma = lx.MatrixLinearOperator(
... jnp.eye(3), lx.positive_semidefinite_tag
... )
>>> d = MultivariateNormal(jnp.zeros(3), Sigma)
>>> d.log_prob(jnp.ones(3))
Source code in src/gaussx/_distributions/_mvn.py
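For intuition about what sampling involves, here is a minimal dense sketch of the standard reparameterized sampler x = loc + L z, where Sigma = L L^T and z ~ N(0, I). It is plain JAX on a dense covariance, not the gaussx code path (which works on structured operators); `sample_mvn_dense` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def sample_mvn_dense(key, loc, cov, num_samples):
    """Reparameterized sampling x = loc + L z with cov = L L^T (hypothetical helper).

    Always forms a dense Cholesky factor; gaussx instead dispatches on
    the structure of the covariance operator.
    """
    L = jnp.linalg.cholesky(cov)  # lower-triangular factor
    z = jax.random.normal(key, (num_samples, loc.shape[-1]))
    return loc + z @ L.T  # row i is loc + L @ z_i


key = jax.random.PRNGKey(0)
loc = jnp.array([1.0, -2.0])
cov = jnp.array([[2.0, 0.6], [0.6, 1.0]])
samples = sample_mvn_dense(key, loc, cov, 200_000)
emp_mean = samples.mean(axis=0)
emp_cov = jnp.cov(samples, rowvar=False)  # should approach cov
```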
MultivariateNormalPrecision
Bases: Distribution
Multivariate normal parameterized by a precision (inverse covariance) operator.
This is the natural parameterization for many inference algorithms
(e.g. message passing, variational inference in natural coordinates).
The precision operator Lambda satisfies Lambda = Sigma^{-1}.
Requires the numpyro optional extra
(pip install "gaussx[numpyro]").
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `loc` | `Float[Array, '*batch N']` | Mean vector of shape `(*batch, N)`. | required |
| `prec_operator` | `AbstractLinearOperator` | Precision matrix as a lineax linear operator of shape `(N, N)`. | required |
| `solver` | `AbstractSolverStrategy \| None` | Optional solver strategy. | `None` |
| `validate_args` | `bool \| None` | Whether to validate input arguments. | `None` |
Examples:
>>> import jax.numpy as jnp
>>> import lineax as lx
>>> from gaussx._distributions import MultivariateNormalPrecision
>>> Lambda = lx.MatrixLinearOperator(
... 2.0 * jnp.eye(3), lx.positive_semidefinite_tag
... )
>>> d = MultivariateNormalPrecision(jnp.zeros(3), Lambda)
>>> d.log_prob(jnp.ones(3))
Source code in src/gaussx/_distributions/_mvn_prec.py
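With a precision operator, samples can be drawn without ever forming Sigma: factor Lambda = L L^T and solve L^T v = z, so v = L^{-T} z has covariance L^{-T} L^{-1} = Lambda^{-1}. The sketch below shows this standard technique in plain JAX on a dense matrix; whether MultivariateNormalPrecision samples exactly this way is an assumption, and `sample_from_precision` is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def sample_from_precision(key, loc, prec, num_samples):
    """Draw x = loc + L^{-T} z, where prec = L L^T and z ~ N(0, I)."""
    L = jnp.linalg.cholesky(prec)
    z = jax.random.normal(key, (loc.shape[-1], num_samples))
    # Back-substitution with L^T; Sigma = prec^{-1} is never formed.
    v = solve_triangular(L, z, lower=True, trans="T")
    return loc + v.T


prec = jnp.array([[2.0, 0.5], [0.5, 1.0]])
L = jnp.linalg.cholesky(prec)
# The linear map z -> L^{-T} z has covariance L^{-T} L^{-1} = prec^{-1}.
A = solve_triangular(L, jnp.eye(2), lower=True, trans="T")
implied_cov = A @ A.T
samples = sample_from_precision(jax.random.PRNGKey(1), jnp.zeros(2), prec, 1000)
```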
Gaussian sugar ops
The Gaussian log-density \(\log \mathcal{N}(x \mid \mu, \Sigma)\), evaluated through structured solve + logdet, plus entropy, quadratic
forms, KL divergences, conditioning, and the numerically stable Joseph-form
covariance update.
gaussian_log_prob(loc: Float[Array, ' N'], cov_operator: lx.AbstractLinearOperator, value: Float[Array, ' N'], *, solver: AbstractSolverStrategy | None = None) -> Float[Array, '']
Multivariate normal log-probability.
Computes:
log N(value | loc, Sigma)
= -0.5 * (N log(2 pi) + log|Sigma| + (value - loc)^T Sigma^{-1} (value - loc))
All expensive operations (solve, logdet) dispatch on
operator structure automatically, or through an explicit solver.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `loc` | `Float[Array, ' N']` | Mean vector, shape `(N,)`. | required |
| `cov_operator` | `AbstractLinearOperator` | Covariance operator Sigma, shape `(N, N)`. | required |
| `value` | `Float[Array, ' N']` | Observation vector, shape `(N,)`. | required |
| `solver` | `AbstractSolverStrategy \| None` | Optional solver strategy (needs both solve and logdet). When `None`, solve and logdet dispatch on operator structure. | `None` |
Returns:
| Type | Description |
|---|---|
| `Float[Array, '']` | Scalar log-probability. |
Source code in src/gaussx/_distributions/_gaussian.py
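A dense reference for the formula above, assuming a symmetric positive-definite covariance: one Cholesky factor supplies both the solve and the log-determinant. This is a plain-JAX sketch for checking results, not the gaussx implementation, which dispatches on operator structure; `gaussian_log_prob_dense` is a hypothetical name.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve
from jax.scipy.stats import multivariate_normal


def gaussian_log_prob_dense(loc, cov, value):
    """Dense reference for log N(value | loc, cov) using one Cholesky factor."""
    n = loc.shape[-1]
    factor = cho_factor(cov, lower=True)                    # (c, lower) pair
    diff = value - loc
    maha = diff @ cho_solve(factor, diff)                   # (v - mu)^T Sigma^{-1} (v - mu)
    logdet = 2.0 * jnp.sum(jnp.log(jnp.diag(factor[0])))   # log|Sigma| from the factor
    return -0.5 * (n * jnp.log(2.0 * jnp.pi) + logdet + maha)


cov = jnp.array([[2.0, 0.3, 0.0], [0.3, 1.0, 0.2], [0.0, 0.2, 0.5]])
loc = jnp.array([0.1, -0.2, 0.3])
value = jnp.array([1.0, 0.0, -1.0])
lp = gaussian_log_prob_dense(loc, cov, value)
lp_ref = multivariate_normal.logpdf(value, loc, cov)  # JAX's own dense reference
lp_identity = gaussian_log_prob_dense(jnp.zeros(3), jnp.eye(3), jnp.ones(3))
```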
gaussian_entropy(cov_operator: lx.AbstractLinearOperator, *, solver: AbstractLogdetStrategy | None = None) -> Float[Array, '']
Entropy of a multivariate normal N(mu, Sigma).
Computes:
H = 0.5 * (N * (1 + log(2 pi)) + log|Sigma|)
Independent of the mean.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `cov_operator` | `AbstractLinearOperator` | Covariance operator, shape `(N, N)`. | required |
| `solver` | `AbstractLogdetStrategy \| None` | Optional logdet strategy. | `None` |
Returns:
| Type | Description |
|---|---|
| `Float[Array, '']` | Scalar entropy. |
Source code in src/gaussx/_distributions/_gaussian.py
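The entropy only needs log|Sigma|. The sketch below (plain JAX, hypothetical helper names) contrasts a dense slogdet with the diagonal case, where the log-determinant is a sum of logs in O(N); that is the kind of shortcut structural dispatch is meant to pick up.

```python
import jax.numpy as jnp


def gaussian_entropy_dense(cov):
    """H = 0.5 * (N * (1 + log(2 pi)) + log|Sigma|) for a dense covariance."""
    n = cov.shape[-1]
    _, logdet = jnp.linalg.slogdet(cov)
    return 0.5 * (n * (1.0 + jnp.log(2.0 * jnp.pi)) + logdet)


def gaussian_entropy_diag(variances):
    """Same quantity for a diagonal covariance: log|Sigma| is a sum of logs, O(N)."""
    n = variances.shape[-1]
    return 0.5 * (n * (1.0 + jnp.log(2.0 * jnp.pi)) + jnp.sum(jnp.log(variances)))


variances = jnp.array([0.5, 1.0, 2.0])  # log|Sigma| = log(0.5 * 1 * 2) = 0
h_dense = gaussian_entropy_dense(jnp.diag(variances))
h_diag = gaussian_entropy_diag(variances)
```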
quadratic_form(operator: lx.AbstractLinearOperator, x: Float[Array, ' N'], *, solver: AbstractSolveStrategy | None = None) -> Float[Array, '']
Compute x^T A^{-1} x via a single solve.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `operator` | `AbstractLinearOperator` | A non-singular linear operator A. | required |
| `x` | `Float[Array, ' N']` | Vector, shape `(N,)`. | required |
| `solver` | `AbstractSolveStrategy \| None` | Optional solve strategy. | `None` |
Returns:
| Type | Description |
|---|---|
| `Float[Array, '']` | Scalar `x^T A^{-1} x`. |
Source code in src/gaussx/_distributions/_gaussian.py
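A one-line dense sketch of the same idea: a single linear solve gives A^{-1} x, so the inverse is never formed. `quadratic_form_dense` is a hypothetical helper in plain JAX.

```python
import jax.numpy as jnp


def quadratic_form_dense(A, x):
    """x^T A^{-1} x with one linear solve; A^{-1} is never formed."""
    return x @ jnp.linalg.solve(A, x)


A = jnp.array([[3.0, 1.0], [0.5, 2.0]])  # non-singular, not symmetric
x = jnp.array([1.0, -1.0])
q = quadratic_form_dense(A, x)  # equals 6.5 / 5.5 for these values
```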
kl_standard_normal(m: Float[Array, ' N'], S: lx.AbstractLinearOperator, *, solver: AbstractLogdetStrategy | None = None) -> Float[Array, '']
KL divergence KL(N(m, S) || N(0, I)).
Special case of dist_kl_divergence
with q_loc = 0 and q_cov = I. The identity prior means no
matrix inversion is required, making this more efficient than calling
the general form directly.
Computes:
KL = 0.5 * (tr(S) + m^T m - N - log|S|)
Ubiquitous in variational inference as the prior KL term.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `m` | `Float[Array, ' N']` | Mean vector, shape `(N,)`. | required |
| `S` | `AbstractLinearOperator` | Covariance operator, shape `(N, N)`. | required |
| `solver` | `AbstractLogdetStrategy \| None` | Optional logdet strategy. | `None` |
Returns:
| Type | Description |
|---|---|
| `Float[Array, '']` | Scalar KL divergence. |
See Also
dist_kl_divergence: General KL
between two multivariate normals with arbitrary lineax covariance
operators.
Source code in src/gaussx/_distributions/_gaussian.py
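Because the prior is N(0, I), the formula needs only a trace, an inner product, and a log-determinant, with no solve against S. A dense plain-JAX sketch (hypothetical helper name) for checking values:

```python
import jax.numpy as jnp


def kl_standard_normal_dense(m, S):
    """KL(N(m, S) || N(0, I)) = 0.5 * (tr(S) + m^T m - N - log|S|)."""
    n = m.shape[-1]
    _, logdet = jnp.linalg.slogdet(S)
    return 0.5 * (jnp.trace(S) + m @ m - n - logdet)


kl_zero = kl_standard_normal_dense(jnp.zeros(2), jnp.eye(2))              # identical distributions
kl_shift = kl_standard_normal_dense(jnp.array([1.0, 0.0]), jnp.eye(2))    # unit mean shift
```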
dist_kl_divergence(p_loc: Float[Array, ' N'], p_cov: lx.AbstractLinearOperator, q_loc: Float[Array, ' N'], q_cov: lx.AbstractLinearOperator) -> Float[Array, '']
KL divergence KL(p || q) between two multivariate normals.
This is the canonical KL implementation for lineax-operator covariances. The specialised variants below all compute the same quantity but with different parameterisations suited to their use cases:
- kl_standard_normal: special case KL(N(m, S) || N(0, I)); avoids matrix inversion.
- gauss_kl: Cholesky-parameterised form for GP/SVGP models; supports multi-output and diagonal q_sqrt.
- kl_divergence: Bregman-divergence form operating on natural parameters for the exponential family.
Exploits structured operators for the trace and logdet terms.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `p_loc` | `Float[Array, ' N']` | Mean of distribution p, shape `(N,)`. | required |
| `p_cov` | `AbstractLinearOperator` | Covariance operator of distribution p. | required |
| `q_loc` | `Float[Array, ' N']` | Mean of distribution q, shape `(N,)`. | required |
| `q_cov` | `AbstractLinearOperator` | Covariance operator of distribution q. | required |
Returns:
| Type | Description |
|---|---|
| `Float[Array, '']` | Scalar KL divergence. |
Source code in src/gaussx/_distributions/_kl.py
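For reference, the general closed form is KL(p || q) = 0.5 * (tr(Sigma_q^{-1} Sigma_p) + (mu_q - mu_p)^T Sigma_q^{-1} (mu_q - mu_p) - N + log|Sigma_q| - log|Sigma_p|). A dense sketch that evaluates it with one Cholesky factor of q_cov follows; it is plain JAX, not the gaussx implementation, and `dist_kl_dense` is a hypothetical name.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve


def dist_kl_dense(p_loc, p_cov, q_loc, q_cov):
    """KL(p || q) for dense covariances, via one Cholesky factor of q_cov."""
    n = p_loc.shape[-1]
    q_factor = cho_factor(q_cov, lower=True)
    diff = q_loc - p_loc
    trace_term = jnp.trace(cho_solve(q_factor, p_cov))   # tr(Sigma_q^{-1} Sigma_p)
    maha = diff @ cho_solve(q_factor, diff)              # Mahalanobis term
    logdet_q = 2.0 * jnp.sum(jnp.log(jnp.diag(q_factor[0])))
    _, logdet_p = jnp.linalg.slogdet(p_cov)
    return 0.5 * (trace_term + maha - n + logdet_q - logdet_p)


# 1-D check: KL(N(0, 1) || N(1, 2)) = 0.5 * (1/2 + 1/2 - 1 + log 2) = 0.5 * log 2
kl_1d = dist_kl_dense(jnp.zeros(1), jnp.eye(1), jnp.ones(1), 2.0 * jnp.eye(1))
```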
conditional(loc: Float[Array, ' N'], cov: lx.AbstractLinearOperator, obs_idx: Int[Array, ' M'], obs_values: Float[Array, ' M'], *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, ' R'], lx.AbstractLinearOperator]
Compute p(x_A | x_B = b) from a joint Gaussian p(x_A, x_B).
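Given a joint distribution \(\mathcal{N}(\mu, \Sigma)\) and observed indices B with values b, returns the conditional distribution over the remaining indices A:
\[
\mu_{A \mid B} = \mu_A + \Sigma_{AB} \Sigma_{BB}^{-1} (b - \mu_B), \qquad
\Sigma_{A \mid B} = \Sigma_{AA} - \Sigma_{AB} \Sigma_{BB}^{-1} \Sigma_{BA}.
\]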
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `loc` | `Float[Array, ' N']` | Mean vector of the joint distribution, shape `(N,)`. | required |
| `cov` | `AbstractLinearOperator` | Covariance operator of the joint distribution, shape `(N, N)`. | required |
| `obs_idx` | `Int[Array, ' M']` | Indices of the observed variables, shape `(M,)`. | required |
| `obs_values` | `Float[Array, ' M']` | Observed values, shape `(M,)`. | required |
| `solver` | `AbstractSolverStrategy \| None` | Optional solver strategy for structured linear algebra. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[Float[Array, ' R'], AbstractLinearOperator]` | Tuple of the mean (shape `(R,)`, with `R = N - M`) and the covariance operator of the conditional distribution over the unobserved variables. |
Source code in src/gaussx/_distributions/_conditional.py
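A dense, eager-mode sketch of the conditioning identities mu_A + Sigma_AB Sigma_BB^{-1} (b - mu_B) and Sigma_AA - Sigma_AB Sigma_BB^{-1} Sigma_BA. It assumes a dense covariance matrix and is not jit-compatible as written (the size of the remaining index set is data-dependent); `conditional_dense` is a hypothetical helper.

```python
import jax.numpy as jnp


def conditional_dense(loc, cov, obs_idx, obs_values):
    """Condition N(loc, cov) on x[obs_idx] = obs_values (dense, eager-mode sketch)."""
    n = loc.shape[0]
    # Remaining (unobserved) indices A; data-dependent size, so not jit-safe as written.
    rest_idx = jnp.setdiff1d(jnp.arange(n), obs_idx)
    S_AA = cov[jnp.ix_(rest_idx, rest_idx)]
    S_AB = cov[jnp.ix_(rest_idx, obs_idx)]
    S_BB = cov[jnp.ix_(obs_idx, obs_idx)]
    # gain = Sigma_AB Sigma_BB^{-1}, via a solve against the symmetric Sigma_BB
    gain = jnp.linalg.solve(S_BB, S_AB.T).T
    cond_loc = loc[rest_idx] + gain @ (obs_values - loc[obs_idx])
    cond_cov = S_AA - gain @ S_AB.T
    return cond_loc, cond_cov


cov = jnp.array([[1.0, 0.5], [0.5, 1.0]])
cond_loc, cond_cov = conditional_dense(jnp.zeros(2), cov, jnp.array([1]), jnp.array([1.0]))
```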
joseph_update(P_pred: Float[Array, 'N N'], K: Float[Array, 'N M'], H: Float[Array, 'M N'], R: Float[Array, 'M M']) -> Float[Array, 'N N']
Numerically stable Joseph-form covariance update.
Computes the updated covariance after a Kalman measurement update:
P_update = (I - K H) P_pred (I - K H)^T + K R K^T
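This form is more numerically stable than the simplified updates
P = P_pred - K S K^T (with innovation covariance S = H P_pred H^T + R)
or P = (I - K H) P_pred: for symmetric positive semi-definite P_pred and R
it always returns a symmetric positive semi-definite matrix, and it remains
correct when the Kalman gain K is approximate (suboptimal) or the system is
poorly conditioned.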
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `P_pred` | `Float[Array, 'N N']` | Predicted covariance, shape `(N, N)`. | required |
| `K` | `Float[Array, 'N M']` | Kalman gain, shape `(N, M)`. | required |
| `H` | `Float[Array, 'M N']` | Observation model, shape `(M, N)`. | required |
| `R` | `Float[Array, 'M M']` | Observation noise covariance, shape `(M, M)`. | required |
Returns:
| Type | Description |
|---|---|
| `Float[Array, 'N N']` | Updated covariance, shape `(N, N)`. |
Source code in src/gaussx/_distributions/_joseph.py
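A dense sketch of the Joseph update in plain JAX (hypothetical helper name). With the optimal gain K = P_pred H^T S^{-1} it agrees with the short form P_pred - K S K^T; with a deliberately suboptimal gain it still returns a symmetric positive-definite matrix.

```python
import jax.numpy as jnp


def joseph_update_dense(P_pred, K, H, R):
    """(I - K H) P_pred (I - K H)^T + K R K^T."""
    I_KH = jnp.eye(P_pred.shape[0]) - K @ H
    return I_KH @ P_pred @ I_KH.T + K @ R @ K.T


P_pred = jnp.array([[2.0, 0.3], [0.3, 1.0]])
H = jnp.array([[1.0, 0.0]])
R = jnp.array([[0.5]])
S = H @ P_pred @ H.T + R                     # innovation covariance
K_opt = jnp.linalg.solve(S, H @ P_pred).T    # optimal gain P_pred H^T S^{-1}

P_joseph = joseph_update_dense(P_pred, K_opt, H, R)
P_simple = P_pred - K_opt @ S @ K_opt.T      # agrees with Joseph only for the optimal gain
P_subopt = joseph_update_dense(P_pred, 0.5 * K_opt, H, R)  # still symmetric PD
```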
add_jitter(operator: lx.AbstractLinearOperator, jitter: float = 1e-06) -> lx.AbstractLinearOperator
Add diagonal jitter for numerical stability: A + eps * I.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `operator` | `AbstractLinearOperator` | A linear operator, shape `(N, N)`. | required |
| `jitter` | `float` | Scalar jitter value added to the diagonal. | `1e-06` |
Returns:
| Type | Description |
|---|---|
| `AbstractLinearOperator` | The jittered operator `A + jitter * I`. |
Source code in src/gaussx/_distributions/_gaussian.py
project(K_XZ: Float[Array, 'B M'], L_Z: lx.AbstractLinearOperator) -> Float[Array, 'B M']
Compute A_X = K_XZ @ K_ZZ^{-1} via Cholesky solve.
Solves L_Z @ L_Z^T @ A_X^T = K_XZ^T using forward/backward
substitution. Used in sparse variational GPs to project test
points onto the inducing space.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `K_XZ` | `Float[Array, 'B M']` | Cross-covariance matrix, shape `(B, M)`. | required |
| `L_Z` | `AbstractLinearOperator` | Lower-triangular Cholesky factor of K_ZZ, shape `(M, M)`. | required |
Returns:
| Type | Description |
|---|---|
| `Float[Array, 'B M']` | Projection matrix A_X, shape `(B, M)`. |
Source code in src/gaussx/_distributions/_project.py
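The two triangular solves described above, sketched densely in plain JAX; `project_dense` is a hypothetical name, and L_Z is assumed to be a dense lower-triangular array rather than a lineax operator.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def project_dense(K_XZ, L_Z):
    """A_X = K_XZ K_ZZ^{-1}, where K_ZZ = L_Z L_Z^T, via two triangular solves."""
    Y = solve_triangular(L_Z, K_XZ.T, lower=True)              # forward:  L_Z Y = K_XZ^T
    A_X_T = solve_triangular(L_Z, Y, lower=True, trans="T")    # backward: L_Z^T A_X^T = Y
    return A_X_T.T


K_ZZ = jnp.array([[2.0, 0.5, 0.0], [0.5, 1.5, 0.2], [0.0, 0.2, 1.0]])
K_XZ = jnp.array([[0.3, 0.1, 0.0], [0.0, 0.7, 0.2], [0.4, 0.4, 0.4], [1.0, 0.0, 0.5]])
A_X = project_dense(K_XZ, jnp.linalg.cholesky(K_ZZ))
```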
Exponential family
The Gaussian in natural form: \(\eta_1 = \Lambda\mu\), \(\eta_2 = -\tfrac12 \Lambda\). Conversions between mean/covariance, natural, and expectation parameterizations — multivariate (operator-aware) and univariate (per-site diagonal) — plus the log-partition, Fisher information, and sufficient statistics that natural-gradient and EP updates are built from.
GaussianExpFam
Bases: Module
Gaussian in natural (exponential family) parameters, with density
p(x) = h(x) exp(eta1^T x + tr(eta2 x x^T) - A(eta)),
where:
- Natural parameters: eta1 = Lambda @ mu, eta2 = -0.5 * Lambda
- Sufficient statistics: T(x) = [x, x x^T]
- Log-partition: A(eta) = -0.25 * eta1^T eta2^{-1} eta1 - 0.5 * log|-2 eta2|
- Base measure: h(x) = (2 pi)^{-N/2}
Attributes:
| Name | Type | Description |
|---|---|---|
| `eta1` | `Float[Array, ' N']` | Natural location parameter, shape `(N,)`. |
| `eta2` | `AbstractLinearOperator` | Natural precision-like operator, shape `(N, N)`. |
Source code in src/gaussx/_expfam/_gaussian.py
from_mean_cov(mu: Float[Array, ' N'], Sigma: lx.AbstractLinearOperator) -> GaussianExpFam
staticmethod
Construct from mean and covariance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `mu` | `Float[Array, ' N']` | Mean vector, shape `(N,)`. | required |
| `Sigma` | `AbstractLinearOperator` | Covariance operator, shape `(N, N)`. | required |
Returns:
| Type | Description |
|---|---|
| `GaussianExpFam` | A `GaussianExpFam` instance. |
Source code in src/gaussx/_expfam/_gaussian.py
from_mean_prec(mu: Float[Array, ' N'], Lambda: lx.AbstractLinearOperator) -> GaussianExpFam
staticmethod
Construct from mean and precision.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `mu` | `Float[Array, ' N']` | Mean vector, shape `(N,)`. | required |
| `Lambda` | `AbstractLinearOperator` | Precision operator, shape `(N, N)`. | required |
Returns:
| Type | Description |
|---|---|
| `GaussianExpFam` | A `GaussianExpFam` instance. |
Source code in src/gaussx/_expfam/_gaussian.py
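To check the natural-form definitions of the class numerically, the sketch below evaluates log h(x) + eta1^T x + tr(eta2 x x^T) - A(eta) with dense arrays and compares it against JAX's own multivariate normal log-density. Helper names are hypothetical; this is plain JAX, not the GaussianExpFam implementation.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def log_partition_dense(eta1, eta2):
    """A(eta) = -0.25 eta1^T eta2^{-1} eta1 - 0.5 log|-2 eta2|."""
    _, logdet = jnp.linalg.slogdet(-2.0 * eta2)
    return -0.25 * eta1 @ jnp.linalg.solve(eta2, eta1) - 0.5 * logdet


def expfam_log_density(x, eta1, eta2):
    """log h(x) + <eta, T(x)> - A(eta), with T(x) = [x, x x^T]."""
    n = x.shape[-1]
    log_h = -0.5 * n * jnp.log(2.0 * jnp.pi)                    # base measure (2 pi)^{-N/2}
    natural_term = eta1 @ x + jnp.sum(eta2 * jnp.outer(x, x))   # eta1^T x + tr(eta2 x x^T)
    return log_h + natural_term - log_partition_dense(eta1, eta2)


mu = jnp.array([0.5, -1.0])
Sigma = jnp.array([[1.0, 0.3], [0.3, 0.5]])
Lambda = jnp.linalg.inv(Sigma)
eta1, eta2 = Lambda @ mu, -0.5 * Lambda   # natural parameters from mean and precision
x = jnp.array([0.2, 0.1])
lp_natural = expfam_log_density(x, eta1, eta2)
lp_reference = multivariate_normal.logpdf(x, mu, Sigma)
```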
to_natural(mu: Float[Array, ' N'], Sigma: lx.AbstractLinearOperator) -> tuple[Float[Array, ' N'], lx.AbstractLinearOperator]
Convert expectation to natural parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `mu` | `Float[Array, ' N']` | Mean vector, shape `(N,)`. | required |
| `Sigma` | `AbstractLinearOperator` | Covariance operator, shape `(N, N)`. | required |
Returns:
| Type | Description |
|---|---|
| `tuple[Float[Array, ' N'], AbstractLinearOperator]` | Tuple `(eta1, eta2)` of natural parameters. |
Source code in src/gaussx/_expfam/_gaussian.py
to_expectation(expfam: GaussianExpFam) -> tuple[Float[Array, ' N'], lx.AbstractLinearOperator]
Convert natural to expectation parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
expfam
|
GaussianExpFam
|
Gaussian in natural form. |
required |
Returns:
| Type | Description |
|---|---|
tuple[Float[Array, ' N'], AbstractLinearOperator]
|
Tuple |
Source code in src/gaussx/_expfam/_gaussian.py
mean_cov_to_natural(mu: Float[Array, ' N'], Sigma: lx.AbstractLinearOperator, *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, ' N'], lx.AbstractLinearOperator]
Convert mean/covariance to natural parameters (operator form).
Given mean mu and covariance Sigma:
- eta1 = solve(Sigma, mu)
- eta2 = -0.5 * inv(Sigma)
Operator structure (diagonal, Kronecker, …) is exploited via
structural dispatch. For dense-array inputs see
meanvar_to_natural.
For block-tridiagonal (SSM) inputs see
gaussx._ssm._ssm_natural.ssm_to_naturals.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `mu` | `Float[Array, ' N']` | Mean vector, shape `(N,)`. | required |
| `Sigma` | `AbstractLinearOperator` | Covariance operator, shape `(N, N)`. | required |
| `solver` | `AbstractSolverStrategy \| None` | Optional solver strategy. When `None`, operator structure is exploited via structural dispatch. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[Float[Array, ' N'], AbstractLinearOperator]` | Tuple `(eta1, eta2)`, where `eta2` is a linear operator. |
Source code in src/gaussx/_expfam/_natural.py
natural_to_mean_cov(eta1: Float[Array, ' N'], eta2: lx.AbstractLinearOperator, *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, ' N'], lx.AbstractLinearOperator]
Convert natural parameters to mean/covariance (operator form).
Given natural parameters (eta1, eta2) where
eta1 = Lambda @ mu and eta2 = -0.5 * Lambda:
- mu = solve(-2 * eta2, eta1)
- Sigma = inv(-2 * eta2)
Operator structure (diagonal, Kronecker, …) is exploited via
structural dispatch. For dense-array inputs see
natural_to_meanvar.
For block-tridiagonal (SSM) inputs see
gaussx._ssm._ssm_natural.naturals_to_ssm.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `eta1` | `Float[Array, ' N']` | Natural location parameter, shape `(N,)`. | required |
| `eta2` | `AbstractLinearOperator` | Natural precision-like operator, shape `(N, N)`. | required |
| `solver` | `AbstractSolverStrategy \| None` | Optional solver strategy. When `None`, operator structure is exploited via structural dispatch. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[Float[Array, ' N'], AbstractLinearOperator]` | Tuple `(mu, Sigma)`, where `Sigma` is a linear operator. |
Source code in src/gaussx/_expfam/_natural.py
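A dense sketch of the two conversions (plain JAX, hypothetical helper names). With a diagonal covariance they reduce to elementwise operations, eta1 = mu / d and eta2 = diag(-0.5 / d), which is the kind of structure the operator-form functions can exploit.

```python
import jax.numpy as jnp


def mean_cov_to_natural_dense(mu, Sigma):
    """eta1 = Sigma^{-1} mu, eta2 = -0.5 Sigma^{-1} (dense)."""
    return jnp.linalg.solve(Sigma, mu), -0.5 * jnp.linalg.inv(Sigma)


def natural_to_mean_cov_dense(eta1, eta2):
    """mu = (-2 eta2)^{-1} eta1, Sigma = (-2 eta2)^{-1} (dense)."""
    Lambda = -2.0 * eta2
    return jnp.linalg.solve(Lambda, eta1), jnp.linalg.inv(Lambda)


mu = jnp.array([1.0, -0.5, 2.0])
d = jnp.array([0.5, 2.0, 4.0])  # diagonal covariance entries
eta1, eta2 = mean_cov_to_natural_dense(mu, jnp.diag(d))
mu_back, Sigma_back = natural_to_mean_cov_dense(eta1, eta2)
```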
meanvar_to_natural(mu: Float[Array, '*batch N'], S_sqrt: Float[Array, '*batch N N']) -> tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]
Convert mean/variance (Cholesky) to natural parameters.
Given mu and lower-triangular S_sqrt such that
Sigma = S_sqrt @ S_sqrt^T:
- eta1 = Sigma^{-1} mu
- eta2 = -0.5 * Sigma^{-1}
Uses the Cholesky factor directly via triangular solves; no solver parameter is exposed because the underlying systems are triangular rather than symmetric/PSD, and iterative strategies (CG, BBMM, PreconditionedCG, MINRES) are not valid here.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `mu` | `Float[Array, '*batch N']` | Mean vector, shape `(*batch, N)`. | required |
| `S_sqrt` | `Float[Array, '*batch N N']` | Lower-triangular Cholesky factor, shape `(*batch, N, N)`. | required |
Returns:
| Type | Description |
|---|---|
| `tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]` | Tuple `(eta1, eta2)`. |
Source code in src/gaussx/_expfam/_natural.py
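A sketch of the triangular-solve route for a single (unbatched) input, assuming S_sqrt is a dense lower-triangular array; batching is added with jax.vmap. The helper name is hypothetical and this is not the gaussx implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def meanvar_to_natural_single(mu, S_sqrt):
    """eta1 = Sigma^{-1} mu, eta2 = -0.5 Sigma^{-1}, with Sigma = S S^T."""
    n = mu.shape[-1]
    w = solve_triangular(S_sqrt, mu, lower=True)                # forward:  S w = mu
    eta1 = solve_triangular(S_sqrt, w, lower=True, trans="T")   # backward: S^T eta1 = w
    S_inv = solve_triangular(S_sqrt, jnp.eye(n), lower=True)    # S^{-1}
    eta2 = -0.5 * (S_inv.T @ S_inv)                             # Sigma^{-1} = S^{-T} S^{-1}
    return eta1, eta2


# Batching over a leading axis with vmap.
meanvar_to_natural_batched = jax.vmap(meanvar_to_natural_single)

mu = jnp.array([1.0, -0.5])
S_sqrt = jnp.array([[1.0, 0.0], [0.5, 2.0]])
eta1, eta2 = meanvar_to_natural_single(mu, S_sqrt)
b_eta1, b_eta2 = meanvar_to_natural_batched(jnp.stack([mu, 2.0 * mu]), jnp.stack([S_sqrt, S_sqrt]))
```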
natural_to_meanvar(eta1: Float[Array, '*batch N'], eta2: Float[Array, '*batch N N'], *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]
Convert natural parameters to mean/variance (Cholesky).
Given eta1 = Lambda @ mu and eta2 = -0.5 * Lambda:
- Sigma = (-2 * eta2)^{-1}
- mu = Sigma @ eta1
- S_sqrt = cholesky(Sigma)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `eta1` | `Float[Array, '*batch N']` | Natural location parameter, shape `(*batch, N)`. | required |
| `eta2` | `Float[Array, '*batch N N']` | Natural quadratic parameter, shape `(*batch, N, N)`. | required |
| `solver` | `AbstractSolverStrategy \| None` | Optional solver strategy for structured linear algebra. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]` | Tuple `(mu, S_sqrt)`, where `S_sqrt` is the Cholesky factor of the covariance. |
Source code in src/gaussx/_expfam/_natural.py
meanvar_to_expectation(mu: Float[Array, '*batch N'], S_sqrt: Float[Array, '*batch N N']) -> tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]
Convert mean/variance (Cholesky) to expectation parameters.
Given mu and S_sqrt (lower-triangular Cholesky of Sigma):
- m1 = mu
- m2 = mu @ mu^T + Sigma = mu @ mu^T + S_sqrt @ S_sqrt^T
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `mu` | `Float[Array, '*batch N']` | Mean vector, shape `(*batch, N)`. | required |
| `S_sqrt` | `Float[Array, '*batch N N']` | Lower-triangular Cholesky factor, shape `(*batch, N, N)`. | required |
Returns:
| Type | Description |
|---|---|
| `tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]` | Tuple `(m1, m2)`. |
Source code in src/gaussx/_expfam/_natural.py
expectation_to_meanvar(m1: Float[Array, '*batch N'], m2: Float[Array, '*batch N N']) -> tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]
Convert expectation parameters to mean/variance (Cholesky).
Given m1 = mu and m2 = mu @ mu^T + Sigma:
- mu = m1
- Sigma = m2 - m1 @ m1^T
- S_sqrt = cholesky(Sigma)
No solver parameter is exposed because the only linear-algebra operation is Cholesky factorization, which is structurally fixed.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `m1` | `Float[Array, '*batch N']` | First moment (mean), shape `(*batch, N)`. | required |
| `m2` | `Float[Array, '*batch N N']` | Second moment, shape `(*batch, N, N)`. | required |
Returns:
| Type | Description |
|---|---|
| `tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]` | Tuple `(mu, S_sqrt)`, where `S_sqrt` is the Cholesky factor of the covariance. |
Source code in src/gaussx/_expfam/_natural.py
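A dense round-trip sketch for single inputs (hypothetical helper names). The subtraction m2 - m1 m1^T can lose precision when the mean is large relative to the spread, which is worth keeping in mind when storing expectation parameters.

```python
import jax.numpy as jnp


def meanvar_to_expectation_single(mu, S_sqrt):
    """m1 = mu, m2 = mu mu^T + S S^T."""
    return mu, jnp.outer(mu, mu) + S_sqrt @ S_sqrt.T


def expectation_to_meanvar_single(m1, m2):
    """mu = m1, S_sqrt = cholesky(m2 - m1 m1^T)."""
    Sigma = m2 - jnp.outer(m1, m1)  # cancellation-prone when |mu| >> scale of Sigma
    return m1, jnp.linalg.cholesky(Sigma)


mu = jnp.array([0.3, -0.2])
S_sqrt = jnp.array([[1.0, 0.0], [0.4, 0.8]])  # lower-triangular, positive diagonal
m1, m2 = meanvar_to_expectation_single(mu, S_sqrt)
mu_back, S_back = expectation_to_meanvar_single(m1, m2)
```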
expectation_to_natural(m1: Float[Array, '*batch N'], m2: Float[Array, '*batch N N'], *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]
Convert expectation parameters to natural parameters.
Given m1 = mu and m2 = mu @ mu^T + Sigma:
- Sigma = m2 - m1 @ m1^T
- eta1 = Sigma^{-1} @ m1
- eta2 = -0.5 * Sigma^{-1}
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `m1` | `Float[Array, '*batch N']` | First moment (mean), shape `(*batch, N)`. | required |
| `m2` | `Float[Array, '*batch N N']` | Second moment, shape `(*batch, N, N)`. | required |
| `solver` | `AbstractSolverStrategy \| None` | Optional solver strategy for structured linear algebra. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]` | Tuple `(eta1, eta2)`. |
Source code in src/gaussx/_expfam/_natural.py
natural_to_expectation(eta1: Float[Array, '*batch N'], eta2: Float[Array, '*batch N N'], *, solver: AbstractSolverStrategy | None = None) -> tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]
Convert natural parameters to expectation parameters.
Given eta1 = Lambda @ mu and eta2 = -0.5 * Lambda:
- Sigma = (-2 * eta2)^{-1}
- mu = Sigma @ eta1
- m1 = mu
- m2 = mu @ mu^T + Sigma
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `eta1` | `Float[Array, '*batch N']` | Natural location parameter, shape `(*batch, N)`. | required |
| `eta2` | `Float[Array, '*batch N N']` | Natural quadratic parameter, shape `(*batch, N, N)`. | required |
| `solver` | `AbstractSolverStrategy \| None` | Optional solver strategy for structured linear algebra. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]` | Tuple `(m1, m2)`. |
Source code in src/gaussx/_expfam/_natural.py
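A dense round-trip sketch for single inputs, plain JAX with hypothetical helper names:

```python
import jax.numpy as jnp


def natural_to_expectation_single(eta1, eta2):
    """Sigma = (-2 eta2)^{-1}, mu = Sigma eta1, then (m1, m2) = (mu, mu mu^T + Sigma)."""
    Sigma = jnp.linalg.inv(-2.0 * eta2)
    mu = Sigma @ eta1
    return mu, jnp.outer(mu, mu) + Sigma


def expectation_to_natural_single(m1, m2):
    """Sigma = m2 - m1 m1^T, then eta1 = Sigma^{-1} m1, eta2 = -0.5 Sigma^{-1}."""
    Sigma = m2 - jnp.outer(m1, m1)
    return jnp.linalg.solve(Sigma, m1), -0.5 * jnp.linalg.inv(Sigma)


Lambda = jnp.array([[2.0, 0.4], [0.4, 1.0]])
eta1, eta2 = Lambda @ jnp.array([1.0, -1.0]), -0.5 * Lambda  # mean is (1, -1)
m1, m2 = natural_to_expectation_single(eta1, eta2)
eta1_back, eta2_back = expectation_to_natural_single(m1, m2)
```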
log_partition(expfam: GaussianExpFam) -> Float[Array, '']
Log-partition function A(eta).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `expfam` | `GaussianExpFam` | Gaussian in natural form. | required |
Returns:
| Type | Description |
|---|---|
| `Float[Array, '']` | Scalar log-partition value. |
Source code in src/gaussx/_expfam/_gaussian.py
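A useful property of the log-partition is that its gradient gives the expectation parameters: nabla A(eta) = (E[x], E[x x^T]) = (mu, mu mu^T + Sigma). The sketch below checks this with jax.grad on a dense A(eta); the helper is hypothetical and assumes dense arrays.

```python
import jax
import jax.numpy as jnp


def log_partition_dense(eta1, eta2):
    """A(eta) = -0.25 eta1^T eta2^{-1} eta1 - 0.5 log|-2 eta2| (dense)."""
    _, logdet = jnp.linalg.slogdet(-2.0 * eta2)
    return -0.25 * eta1 @ jnp.linalg.solve(eta2, eta1) - 0.5 * logdet


mu = jnp.array([0.5, -1.0])
Sigma = jnp.array([[1.0, 0.3], [0.3, 0.5]])
eta1 = jnp.linalg.solve(Sigma, mu)
eta2 = -0.5 * jnp.linalg.inv(Sigma)

# grad A(eta) recovers the expectation parameters (E[x], E[x x^T]).
g1, g2 = jax.grad(log_partition_dense, argnums=(0, 1))(eta1, eta2)
```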
fisher_info(expfam: GaussianExpFam) -> lx.AbstractLinearOperator
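Fisher information matrix F(eta) = nabla^2 A(eta).
For a Gaussian, this function returns the precision Sigma^{-1}, which is
the Fisher information with respect to the mean mu. Note that in natural
coordinates the eta1 block of nabla^2 A(eta) is the covariance Sigma,
i.e. the inverse of the returned operator.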
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `expfam` | `GaussianExpFam` | Gaussian in natural form. | required |
Returns:
| Type | Description |
|---|---|
| `AbstractLinearOperator` | Precision operator (the Fisher information matrix). |
Source code in src/gaussx/_expfam/_gaussian.py
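A numerical check in plain JAX (dense arrays, hypothetical helper): the negative Hessian of the log-density with respect to the mean is the precision Sigma^{-1}, while the Hessian of A(eta) with respect to eta1 alone is the covariance Sigma.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def log_partition_dense(eta1, eta2):
    _, logdet = jnp.linalg.slogdet(-2.0 * eta2)
    return -0.25 * eta1 @ jnp.linalg.solve(eta2, eta1) - 0.5 * logdet


mu = jnp.array([0.5, -1.0])
Sigma = jnp.array([[1.0, 0.3], [0.3, 0.5]])
Lambda = jnp.linalg.inv(Sigma)
eta1, eta2 = Lambda @ mu, -0.5 * Lambda

# Fisher information w.r.t. the mean: -d^2/dmu^2 log p(x | mu), independent of x here.
x = jnp.array([0.0, 0.0])
fisher_mu = -jax.hessian(lambda m: multivariate_normal.logpdf(x, m, Sigma))(mu)

# Hessian of A w.r.t. eta1 alone: this block is Sigma, the inverse of fisher_mu.
hess_eta1 = jax.hessian(lambda e1: log_partition_dense(e1, eta2))(eta1)
```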
sufficient_stats(x: Float[Array, '*batch N']) -> tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]
Compute sufficient statistics T(x) = [x, x x^T].
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Float[Array, '*batch N']` | Data vector, shape `(*batch, N)`. | required |
Returns:
| Type | Description |
|---|---|
| `tuple[Float[Array, '*batch N'], Float[Array, '*batch N N']]` | Tuple `(x, x x^T)` with shapes `(*batch, N)` and `(*batch, N, N)`. |
Source code in src/gaussx/_expfam/_gaussian.py
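A sketch of the batched outer product (plain JAX, hypothetical helper name). Averaging T(x) over samples gives a Monte Carlo estimate of the expectation parameters (mu, mu mu^T + Sigma).

```python
import jax
import jax.numpy as jnp


def sufficient_stats_dense(x):
    """T(x) = (x, x x^T) over the last axis; leading batch axes are preserved."""
    return x, x[..., :, None] * x[..., None, :]


mu = jnp.array([0.5, -1.0])
Sigma = jnp.array([[1.0, 0.3], [0.3, 0.5]])
z = jax.random.normal(jax.random.PRNGKey(0), (100_000, 2))
samples = mu + z @ jnp.linalg.cholesky(Sigma).T

t1, t2 = sufficient_stats_dense(samples)  # shapes (100000, 2) and (100000, 2, 2)
m1_hat, m2_hat = t1.mean(axis=0), t2.mean(axis=0)
```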
kl_divergence(q: GaussianExpFam, p: GaussianExpFam) -> Float[Array, '']
KL divergence KL(q || p) via the Bregman-divergence form on
natural parameters.
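Exponential-family expression of the KL divergence in terms of the
log-partition A and the natural parameters of q and p:
KL(q || p) = A(eta_p) - A(eta_q) - (eta_p - eta_q)^T nabla A(eta_q),
where the inner product on the eta2 block is the trace inner product.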
Mathematically equivalent to
dist_kl_divergence.
The current implementation evaluates the Bregman form by routing
through to_expectation for the natural-gradient term
(eta_p - eta_q)^T nabla A(eta_q). The second-moment contraction
splits into a quadratic form (operator matvecs) plus
gaussx.trace_product, so structured eta2 / Sigma_q
operators are never materialized. The benefit relative to
dist_kl_divergence is keeping the gradient flowing in
natural-parameter space (suitable inside a natural-gradient loop).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `q` | `GaussianExpFam` | First Gaussian (the "true" distribution). | required |
| `p` | `GaussianExpFam` | Second Gaussian (the "approximate" distribution). | required |
Returns:
| Type | Description |
|---|---|
| `Float[Array, '']` | Scalar KL divergence. |
See Also
dist_kl_divergence: General KL
in mean/covariance form with lineax operators.
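The Bregman form reads KL(q || p) = A(eta_p) - A(eta_q) - <eta_p - eta_q, nabla A(eta_q)>, with the trace inner product on the eta2 block. The sketch below evaluates it densely, taking nabla A from jax.grad rather than from to_expectation, and compares it against the mean/covariance closed form. Helper names are hypothetical; this is not the gaussx implementation.

```python
import jax
import jax.numpy as jnp


def log_partition_dense(eta1, eta2):
    _, logdet = jnp.linalg.slogdet(-2.0 * eta2)
    return -0.25 * eta1 @ jnp.linalg.solve(eta2, eta1) - 0.5 * logdet


def to_natural_dense(mu, Sigma):
    Lambda = jnp.linalg.inv(Sigma)
    return Lambda @ mu, -0.5 * Lambda


def bregman_kl(q, p):
    """KL(q || p) = A(eta_p) - A(eta_q) - <eta_p - eta_q, grad A(eta_q)>."""
    g1, g2 = jax.grad(log_partition_dense, argnums=(0, 1))(*q)  # (E_q[x], E_q[x x^T])
    inner = (p[0] - q[0]) @ g1 + jnp.sum((p[1] - q[1]) * g2)    # trace inner product on eta2
    return log_partition_dense(*p) - log_partition_dense(*q) - inner


def kl_mean_cov(mu_q, S_q, mu_p, S_p):
    """Closed-form KL(q || p) in mean/covariance form, for comparison."""
    n = mu_q.shape[-1]
    diff = mu_p - mu_q
    _, ld_p = jnp.linalg.slogdet(S_p)
    _, ld_q = jnp.linalg.slogdet(S_q)
    trace_term = jnp.trace(jnp.linalg.solve(S_p, S_q))
    return 0.5 * (trace_term + diff @ jnp.linalg.solve(S_p, diff) - n + ld_p - ld_q)


mu_q, S_q = jnp.array([0.5, 0.0]), jnp.array([[1.0, 0.2], [0.2, 0.8]])
mu_p, S_p = jnp.array([0.0, 1.0]), jnp.array([[1.5, 0.0], [0.0, 1.0]])
kl_bregman = bregman_kl(to_natural_dense(mu_q, S_q), to_natural_dense(mu_p, S_p))
kl_closed = kl_mean_cov(mu_q, S_q, mu_p, S_p)
```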