
Architecture

gaussx is organized as a four-layer stack that extends lineax, plus an intermediate solver-strategy layer (Layer 1.5) between operators and distributions. Each layer builds on the one below, and users can enter at whichever layer fits their needs.

Four-Layer Stack

```
┌──────────────────────────────────────────────────────────────────┐
│  Layer 3 — Recipes                                       (v0.3+) │
│  Kalman filter/smoother, Kronecker GP recipes, LOVE,             │
│  CVI sites, SSM natural params, interpolation                    │
├──────────────────────────────────────────────────────────────────┤
│  Layer 2 — Distributions + Sugar                         (v0.2+) │
│  MultivariateNormal, Schur complement, project,                  │
│  conditional variance, exponential family                        │
├──────────────────────────────────────────────────────────────────┤
│  Layer 1 — Operators                                      (v0.0) │
│  Kronecker, BlockDiag, BlockTriDiag, LowRankUpdate,              │
│  ImplicitKernelOperator                                          │
│  Extend lineax.AbstractLinearOperator                            │
├──────────────────────────────────────────────────────────────────┤
│  Layer 0 — Primitives                                     (v0.0) │
│  solve, logdet, cholesky, diag, trace, sqrt, inv                 │
│  Dispatch on operator type + tags                                │
└──────────────────────────────────────────────────────────────────┘
                 │                          │
          ┌──────▼──────┐           ┌───────▼──────┐
          │   lineax    │           │   matfree    │
          │  (solvers)  │           │  (Lanczos,   │
          │             │           │   SLQ, etc.) │
          └─────────────┘           └──────────────┘
```

Layer 0: Primitives

Pure functions that mirror the notation of the equations in papers. Each takes an operator and returns an array or another operator:

```python
import gaussx

# A is any gaussx or lineax operator; b is a JAX array of matching size.
x = gaussx.solve(A, b)      # solve Ax = b
ld = gaussx.logdet(A)        # log|det(A)|
L = gaussx.cholesky(A)       # A = LL^T
d = gaussx.diag(A)           # diagonal entries
t = gaussx.trace(A)          # tr(A)
S = gaussx.sqrt(A)           # S such that SS = A
A_inv = gaussx.inv(A)        # lazy A^{-1}
```

Each primitive dispatches on the operator's type (via isinstance checks) and its structural tags to select an efficient code path:

| Operator | solve | logdet | cholesky |
|---|---|---|---|
| Diagonal | O(n) divide | O(n) sum(log) | O(n) sqrt |
| BlockDiag | per-block | sum of logdets | per-block |
| Kronecker | per-factor solves via Roth's lemma | scaled sum | per-factor |
| LowRankUpdate | Woodbury | det lemma | n/a |
| Dense | lineax solver | slogdet | jax.scipy cholesky |
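The structured rows in this table rest on standard identities. Take \(A \in \mathbb{R}^{n \times n}\), \(B \in \mathbb{R}^{m \times m}\), and a rank-\(k\) update with invertible \(D = \mathrm{diag}(d)\):

\[
\begin{aligned}
\log\lvert\mathrm{diag}(A, B)\rvert &= \log\lvert A\rvert + \log\lvert B\rvert \\
\log\lvert A \otimes B\rvert &= m \log\lvert A\rvert + n \log\lvert B\rvert \\
(L + U D V^\top)^{-1} &= L^{-1} - L^{-1} U \left(D^{-1} + V^\top L^{-1} U\right)^{-1} V^\top L^{-1} \\
\det(L + U D V^\top) &= \det\left(D^{-1} + V^\top L^{-1} U\right) \det(D) \det(L)
\end{aligned}
\]

The last two identities replace one large problem with solves against \(L\) plus a single \(k \times k\) system. This is why LowRankUpdate pairs Woodbury with the determinant lemma.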

For large unstructured operators, Layer 0 delegates to matfree's iterative and stochastic algorithms: stochastic Lanczos quadrature for logdet and Hutchinson's estimator for trace.
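Hutchinson's estimator needs only matrix-vector products, which is what makes it suitable for matrix-free operators. The sketch below implements it in plain JAX to show the technique. It is not matfree's API, and `hutchinson_trace` and its parameters are hypothetical names.

```python
import jax
import jax.numpy as jnp


def hutchinson_trace(matvec, n, key, num_probes=1000):
    """Estimate tr(A) using only products v -> A v (hypothetical helper)."""
    # Rademacher probes: each entry is +1 or -1 with equal probability,
    # so E[z z^T] = I and therefore E[z^T A z] = tr(A).
    z = jax.random.rademacher(key, (num_probes, n)).astype(jnp.float32)
    # Apply A to every probe; A itself is never materialized.
    az = jax.vmap(matvec)(z)
    # Average the quadratic forms z^T A z over all probes.
    return jnp.mean(jnp.sum(z * az, axis=-1))


# Usage: a small test matrix, accessed only through its matvec.
A = 2.0 * jnp.eye(4) + 0.1 * jnp.ones((4, 4))
estimate = hutchinson_trace(lambda v: A @ v, 4, jax.random.PRNGKey(0), num_probes=4000)
```

For a diagonal operator every Rademacher probe returns the exact trace, because \(z_i^2 = 1\). The estimator's variance comes only from off-diagonal mass.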

Layer 1: Operators

Structured operators extending lineax.AbstractLinearOperator. Each is an equinox.Module (an immutable PyTree) that implements mv, as_matrix, and transpose, and carries structural tags.

| Operator | Represents | Efficient mv |
|---|---|---|
| Kronecker(A, B, ...) | \(A \otimes B \otimes \cdots\) | Roth's column lemma via einx |
| BlockDiag(A, B, ...) | \(\mathrm{diag}(A, B, \ldots)\) | Per-block, concatenate |
| BlockTriDiag(D, A) | Symmetric block-tridiagonal precision | Banded block matvec |
| LowRankUpdate(L, U, d, V) | \(L + U \mathrm{diag}(d) V^\top\) | Base mv + rank-k update |
| ImplicitKernelOperator(k, X) | Matrix-free kernel Gram operator | Nested vmap kernel matvec |
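The Kronecker matvec relies on the row-major form of Roth's column lemma, \((A \otimes B)\,\mathrm{vec}(X) = \mathrm{vec}(A X B^\top)\). The sketch below writes that contraction with plain jax.numpy. The table says gaussx performs it through einx, so `kron_mv` is a hypothetical stand-in and not its implementation.

```python
import jax.numpy as jnp


def kron_mv(A, B, x):
    """Compute kron(A, B) @ x without forming the Kronecker product.

    Row-major identity: kron(A, B) @ X.ravel() == (A @ X @ B.T).ravel(),
    where X = x.reshape(A.shape[1], B.shape[1]).
    """
    # Unflatten x into a matrix whose axes match the columns of A and B.
    X = x.reshape(A.shape[1], B.shape[1])
    # Two small matmuls replace one large dense matvec.
    return (A @ X @ B.T).reshape(-1)
```

With square factors of sizes \(n\) and \(m\), this costs \(O(nm(n + m))\) instead of the \(O(n^2 m^2)\) of a dense matvec.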

ImplicitKernelOperator keeps structural claims explicit: if a kernel should be treated as symmetric or PSD by lineax, pass those tags when constructing it rather than relying on the operator to infer them.

Operator arithmetic (+, @, and scalar *) composes gaussx operators with lineax's built-in operators:

```python
import gaussx
import lineax as lx
K = gaussx.Kronecker(A, B)
perturbed = K + 0.1 * lx.IdentityLinearOperator(...)
```

Layer 1.5: Solver Strategies

Strategy objects pair a solve with a matching logdet so the pair can be reused:

| Strategy | Algorithm | Best for |
|---|---|---|
| DenseSolver | Structural dispatch (Cholesky for PSD, etc.) | Small to medium, structured |
| CGSolver | Iterative CG + stochastic Lanczos logdet | Large PSD, matrix-free |

Strategies decouple the distribution from the solver: a MultivariateNormal calls its strategy's solve and logdet without knowing or caring whether they run a dense Cholesky or iterative CG.
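The decoupling can be sketched in plain JAX, assuming dense covariance matrices. `CholeskyStrategy`, `CGStrategy`, and `mvn_logpdf` are hypothetical names, not gaussx's DenseSolver/CGSolver API. The CG variant substitutes an exact slogdet for the stochastic Lanczos logdet described above, to keep the example short.

```python
import jax.numpy as jnp
import jax.scipy.linalg as jsl
from jax.scipy.sparse.linalg import cg


class CholeskyStrategy:
    """Hypothetical dense strategy: Cholesky factors serve solve and logdet."""

    def solve(self, cov, b):
        return jsl.cho_solve(jsl.cho_factor(cov, lower=True), b)

    def logdet(self, cov):
        chol = jnp.linalg.cholesky(cov)
        return 2.0 * jnp.sum(jnp.log(jnp.diag(chol)))


class CGStrategy:
    """Hypothetical iterative strategy: conjugate gradients for solve.

    Assumption: exact slogdet stands in for stochastic Lanczos here.
    """

    def solve(self, cov, b):
        x, _ = cg(lambda v: cov @ v, b)
        return x

    def logdet(self, cov):
        return jnp.linalg.slogdet(cov)[1]


def mvn_logpdf(x, loc, cov, strategy):
    # The density only calls strategy.solve and strategy.logdet;
    # it never inspects which algorithm sits behind them.
    r = x - loc
    n = x.shape[0]
    quad = r @ strategy.solve(cov, r)
    return -0.5 * (quad + strategy.logdet(cov) + n * jnp.log(2.0 * jnp.pi))
```

Swapping `CholeskyStrategy()` for `CGStrategy()` changes the algorithm without touching `mvn_logpdf`.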

Layer 2: Distributions + Sugar

  • MultivariateNormal(loc, cov_operator, solver=...): accepts any covariance operator and any solver strategy
  • Compound operations: project, unwhiten, schur_complement, conditional_variance
  • Gaussian exponential family: natural/expectation parameters, Fisher information
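Conditioning is the Gaussian identity behind schur_complement and conditional_variance. For a joint covariance partitioned into blocks, \(\mathrm{Cov}(x_1 \mid x_2) = \Sigma_{11} - \Sigma_{12}\Sigma_{22}^{-1}\Sigma_{21}\). The dense sketch below computes it. `conditional_cov` is a hypothetical name, not gaussx's signature.

```python
import jax.numpy as jnp


def conditional_cov(cov, k):
    """Cov(x1 | x2) for a joint covariance whose first k coordinates are x1.

    Computes the Schur complement S11 - S12 S22^{-1} S21 with a linear
    solve instead of forming S22^{-1} explicitly.
    """
    s11, s12 = cov[:k, :k], cov[:k, k:]
    s21, s22 = cov[k:, :k], cov[k:, k:]
    return s11 - s12 @ jnp.linalg.solve(s22, s21)
```

The same matrix equals the inverse of the top-left block of the joint precision \(\Sigma^{-1}\), which gives a convenient correctness check.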

Layer 3: Recipes

Cross-library patterns: Kalman filter, RTS smoother, ensemble covariance, natural gradient updates, Kronecker GP marginal likelihood / posterior prediction, LOVE predictive variance, CVI site updates, and SSM-natural conversions. Each recipe is thin wiring of Layer 0 to 2 operations into a domain-specific sequence.
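The Kalman filter recipe is a predict/update loop built from matrix products and linear solves against the innovation covariance. The dense one-step sketch below illustrates the standard algorithm. `kalman_step` is a hypothetical name, not the gaussx recipe's signature.

```python
import jax.numpy as jnp


def kalman_step(m, P, y, F, Q, H, R):
    """One predict + update step of a linear-Gaussian Kalman filter."""
    # Predict: propagate mean and covariance through the dynamics.
    m_pred = F @ m
    P_pred = F @ P @ F.T + Q
    # Update: condition the prediction on the observation y.
    S = H @ P_pred @ H.T + R                  # innovation covariance
    K = jnp.linalg.solve(S, H @ P_pred).T     # gain P_pred H^T S^{-1} (S, P_pred symmetric)
    m_new = m_pred + K @ (y - H @ m_pred)
    P_new = P_pred - K @ S @ K.T
    return m_new, P_new
```

In a scalar example with m = 0, P = 1, F = H = 1, Q = R = 0.5 and y = 1, the step predicts P = 1.5, computes gain 0.75, and returns mean 0.75 and variance 0.375.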

Two API constraints worth knowing:

  • kronecker_posterior_predictive(...) needs exact test prior diagonals via K_test_diag_factors for predictive variances.
  • ssm_to_naturals(...) expects Q[0] == P_0 and raises on inconsistent initial covariance inputs.
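The Q[0] == P_0 convention treats the initial covariance as the first noise term. Under that convention the joint prior precision of the state sequence is banded, the structure BlockTriDiag represents. The sketch below makes strong assumptions: scalar states, zero mean, and a hypothetical helper `ssm_prior_precision`. It is not ssm_to_naturals itself. Write \(x_0 = q_0\) and \(x_k = a_k x_{k-1} + q_k\) with \(q_k \sim \mathcal{N}(0, Q_k)\). Then \(q = Lx\) for a lower-bidiagonal \(L\), so the precision is \(L^\top \mathrm{diag}(Q)^{-1} L\).

```python
import jax.numpy as jnp


def ssm_prior_precision(a, Q):
    """Joint prior precision of a scalar linear-Gaussian SSM (hypothetical helper).

    a: transition coefficients a_1..a_{T-1} (length T-1).
    Q: noise variances of length T, with Q[0] == P_0 (the initial variance).
    """
    T = Q.shape[0]
    # L maps states to noises: (L x)_0 = x_0, (L x)_k = x_k - a_k x_{k-1}.
    L = jnp.eye(T) - jnp.diag(a, k=-1)
    # x = L^{-1} q with q ~ N(0, diag(Q)), so the precision is L^T diag(Q)^{-1} L,
    # which is tridiagonal because L is bidiagonal.
    return L.T @ jnp.diag(1.0 / Q) @ L
```

Inverting the result recovers the prior covariance. Its first entry is Q[0], which is why an initial covariance inconsistent with Q[0] cannot be represented.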

Structural Tags

Operators carry tags that drive dispatch. gaussx extends lineax's tag set:

| Tag | Source | Used by |
|---|---|---|
| symmetric_tag | lineax | solve, logdet |
| positive_semidefinite_tag | lineax | cholesky, sqrt, CG |
| diagonal_tag | lineax | all primitives (O(n) paths) |
| kronecker_tag | gaussx | all primitives (per-factor) |
| block_diagonal_tag | gaussx | all primitives (per-block) |
| low_rank_tag | gaussx | solve (Woodbury), logdet (det lemma) |

Query functions (is_kronecker, is_block_diagonal, is_low_rank, plus all lineax queries) let you inspect operator properties without knowing the concrete type.

Dependencies

| Package | Role | Required |
|---|---|---|
| jax / jaxlib | Array backend | Yes |
| equinox | Module system, PyTrees | Yes |
| lineax | Linear operators, solvers | Yes |
| matfree | Krylov methods, stochastic trace | Yes |
| jaxtyping | Array type annotations | Yes |
| einx | Tensor reshaping/contraction (Principle 5) | Yes |
| numpyro | Distributions (Layer 2) | Planned |

Package Layout

```
src/gaussx/
├── __init__.py              # Public API
├── _tags.py                 # Structural tags + queries
├── _operators/              # Layer 1
│   ├── _kronecker.py
│   ├── _block_diag.py
│   ├── _block_tridiag.py
│   ├── _implicit_kernel.py
│   └── _low_rank_update.py
├── _primitives/             # Layer 0
│   ├── _solve.py
│   ├── _logdet.py
│   ├── _cholesky.py
│   ├── _diag.py
│   ├── _trace.py
│   ├── _sqrt.py
│   └── _inv.py
├── _strategies/             # Layer 1.5
│   ├── _base.py
│   ├── _dense.py
│   └── _cg.py
├── _recipes/                # Layer 3
│   ├── _kalman.py
│   ├── _kronecker_gp.py
│   ├── _love.py
│   ├── _cvi.py
│   └── _ssm_natural.py
└── _testing.py              # Test utilities
```