gaussx

Structured linear algebra, Gaussian distributions, and exponential family primitives for JAX.

Built on lineax, equinox, and matfree.

New here? Start with the Vision to understand why gaussx exists, then read the Architecture to see how it's organized.

Installation

pip install gaussx

Or with uv:

uv add gaussx

Quickstart

import jax.numpy as jnp
import lineax as lx
import gaussx

# Structured operators
A = lx.DiagonalLinearOperator(jnp.array([1.0, 2.0]))
B = lx.DiagonalLinearOperator(jnp.array([3.0, 4.0]))
K = gaussx.Kronecker(A, B)

# Primitives exploit structure automatically
v = jnp.ones(4)
x = gaussx.solve(K, v)       # per-factor solve
ld = gaussx.logdet(K)         # n2*logdet(A) + n1*logdet(B)
L = gaussx.cholesky(K)        # Kronecker(chol(A), chol(B))
t = gaussx.trace(K)           # trace(A) * trace(B)
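The identities in the comments above can be checked against dense arithmetic. A minimal sketch using NumPy only (no gaussx), for the same two diagonal factors:

```python
import numpy as np

A = np.diag([1.0, 2.0])
B = np.diag([3.0, 4.0])
K = np.kron(A, B)  # the dense Kronecker product, for comparison

# trace(A ⊗ B) = trace(A) * trace(B)
assert np.isclose(np.trace(K), np.trace(A) * np.trace(B))

# logdet(A ⊗ B) = n_B * logdet(A) + n_A * logdet(B)
ld = np.linalg.slogdet(K)[1]
expected = B.shape[0] * np.linalg.slogdet(A)[1] + A.shape[0] * np.linalg.slogdet(B)[1]
assert np.isclose(ld, expected)

# chol(A ⊗ B) = chol(A) ⊗ chol(B)
assert np.allclose(
    np.linalg.cholesky(K),
    np.kron(np.linalg.cholesky(A), np.linalg.cholesky(B)),
)

# Per-factor solve: reshape v to (n_A, n_B), then
# (A ⊗ B)^{-1} v = vec(A^{-1} V B^{-T}) under row-major vec
v = np.ones(4)
V = v.reshape(A.shape[0], B.shape[0])
x_factored = np.linalg.solve(B, np.linalg.solve(A, V).T).T.reshape(-1)
assert np.allclose(x_factored, np.linalg.solve(K, v))
```

The factored solve never materializes the full 4×4 operator, which is the point of the structured primitives: costs scale with the factor sizes rather than their product.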

API Notes

Several of the newer public APIs have explicit requirements that are worth calling out up front:

  • gaussx.kronecker_posterior_predictive(...) requires K_test_diag_factors= when you want predictive variances.
  • gaussx.ssm_to_naturals(...) validates that Q[0] matches P_0 so the joint prior is internally consistent.
  • gaussx.ImplicitKernelOperator(...) only advertises symmetry and PSD to lineax when those tags are provided explicitly.
The sketch below exercises all three APIs; Kx, Ky, Kx_star, Ky_star, y, the grid sizes, A, Q, mu_0, kernel_fn, and X are assumed to be defined elsewhere:

import jax.numpy as jnp
import lineax as lx
import gaussx

# Predictive variances require the diagonal of the test-point kernel,
# supplied per Kronecker factor via K_test_diag_factors=.
mean, var = gaussx.kronecker_posterior_predictive(
    [Kx, Ky],
    y,
    noise_var=1e-2,
    grid_shape=(nx, ny),
    K_cross_factors=[Kx_star, Ky_star],
    K_test_diag_factors=[jnp.ones(nx_star), jnp.ones(ny_star)],
)

# Passing P_0=Q[0] satisfies the consistency check on the joint prior.
theta_1, theta_2 = gaussx.ssm_to_naturals(A, Q, mu_0, P_0=Q[0])

# Tags must be passed explicitly for lineax to exploit symmetry/PSD.
kernel_op = gaussx.ImplicitKernelOperator(
    kernel_fn,
    X,
    tags=frozenset({lx.symmetric_tag, lx.positive_semidefinite_tag}),
)
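For background on what ssm_to_naturals produces: the natural parameters of a single Gaussian N(mu, Sigma) are theta_1 = Sigma⁻¹ mu and theta_2 = -½ Sigma⁻¹. A minimal NumPy sketch (independent of gaussx) showing the conversion and its inverse:

```python
import numpy as np

mu = np.array([1.0, -0.5])
Sigma = np.array([[2.0, 0.3],
                  [0.3, 1.0]])

P = np.linalg.inv(Sigma)   # precision matrix
theta_1 = P @ mu           # first natural parameter
theta_2 = -0.5 * P         # second natural parameter

# Recover the mean/covariance parametrization from the naturals
Sigma_rec = -0.5 * np.linalg.inv(theta_2)
mu_rec = Sigma_rec @ theta_1
assert np.allclose(Sigma_rec, Sigma)
assert np.allclose(mu_rec, mu)
```

In the state-space setting, the joint prior over the state trajectory is a single large Gaussian assembled from A, Q, mu_0, and P_0, which is why gaussx validates that Q[0] and P_0 agree before building the natural parameters.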

Examples