gaussx¶
Structured linear algebra, Gaussian distributions, and exponential family primitives for JAX.
Built on lineax, equinox, and matfree.
New here? Start with the Vision to understand why gaussx exists, then read the Architecture to see how it's organized.
Installation¶

```shell
pip install gaussx
```

Or with uv:

```shell
uv add gaussx
```
Quickstart¶
```python
import jax.numpy as jnp
import lineax as lx

import gaussx

# Structured operators
A = lx.DiagonalLinearOperator(jnp.array([1.0, 2.0]))
B = lx.DiagonalLinearOperator(jnp.array([3.0, 4.0]))
K = gaussx.Kronecker(A, B)

# Primitives exploit structure automatically
v = jnp.ones(4)
x = gaussx.solve(K, v)    # per-factor solve
ld = gaussx.logdet(K)     # n2*logdet(A) + n1*logdet(B)
L = gaussx.cholesky(K)    # Kronecker(chol(A), chol(B))
t = gaussx.trace(K)       # trace(A) * trace(B)
```
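The structure-exploiting identities noted in the comments above can be verified against a dense Kronecker product; a minimal sketch using only `jax.numpy` (no gaussx required):

```python
import jax.numpy as jnp

# Dense versions of the diagonal factors from the quickstart
A = jnp.diag(jnp.array([1.0, 2.0]))
B = jnp.diag(jnp.array([3.0, 4.0]))
K = jnp.kron(A, B)  # 4x4 dense Kronecker product
n1, n2 = A.shape[0], B.shape[0]

# logdet(kron(A, B)) == n2*logdet(A) + n1*logdet(B)
ld_dense = jnp.linalg.slogdet(K)[1]
ld_struct = n2 * jnp.linalg.slogdet(A)[1] + n1 * jnp.linalg.slogdet(B)[1]

# trace(kron(A, B)) == trace(A) * trace(B)
tr_dense = jnp.trace(K)
tr_struct = jnp.trace(A) * jnp.trace(B)
```

Both pairs agree, which is what lets gaussx compute `logdet` and `trace` from the small factors without ever materializing `K`.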
API Notes¶
Several of the newer public APIs have explicit requirements that are worth calling out up front:
- `gaussx.kronecker_posterior_predictive(...)` requires `K_test_diag_factors=` when you want predictive variances.
- `gaussx.ssm_to_naturals(...)` validates that `Q[0]` matches `P_0` so the joint prior is internally consistent.
- `gaussx.ImplicitKernelOperator(...)` only advertises symmetry and PSD to lineax when those tags are provided explicitly.
```python
import jax.numpy as jnp
import lineax as lx

import gaussx

mean, var = gaussx.kronecker_posterior_predictive(
    [Kx, Ky],
    y,
    noise_var=1e-2,
    grid_shape=(nx, ny),
    K_cross_factors=[Kx_star, Ky_star],
    K_test_diag_factors=[jnp.ones(nx_star), jnp.ones(ny_star)],
)

theta_1, theta_2 = gaussx.ssm_to_naturals(A, Q, mu_0, P_0=Q[0])

kernel_op = gaussx.ImplicitKernelOperator(
    kernel_fn,
    X,
    tags=frozenset({lx.symmetric_tag, lx.positive_semidefinite_tag}),
)
```
Examples¶
- Basics — operators, primitives, JAX transforms
- Operator Zoo — every operator type with structure visualization
- Woodbury Solve — step-by-step Woodbury identity
- Kronecker Eigendecomposition — per-factor eigen/cholesky/sqrt
- Kernel Regression — GP regression with hyperparameter optimization
- GP on a 2D Grid — Kronecker structure for spatial data
- Sparse Variational GP — inducing points with ELBO optimization
- Structured GP — Kronecker and low-rank comparison
- Solver Comparison — DenseSolver vs CGSolver
- Differentiating Through Solve — jax.grad through gaussx primitives