Architecture¶
gaussx is organized as a four-layer stack that extends lineax. Each layer builds on the one below. Users can enter at any layer depending on their needs.
Four-Layer Stack¶
┌─────────────────────────────────────────────────────────────────┐
│ Layer 3 — Recipes (v0.3+) │
│ Kalman filter/smoother, Kronecker GP recipes, LOVE, │
│ CVI sites, SSM natural params, interpolation │
├──────────────────────────────────────────────────────────────────┤
│ Layer 2 — Distributions + Sugar (v0.2+) │
│ MultivariateNormal, Schur complement, project, │
│ conditional variance, exponential family │
├──────────────────────────────────────────────────────────────────┤
│ Layer 1 — Operators (v0.0) │
│ Kronecker, BlockDiag, BlockTriDiag, LowRankUpdate, │
│ ImplicitKernelOperator │
│ Extend lineax.AbstractLinearOperator │
├──────────────────────────────────────────────────────────────────┤
│ Layer 0 — Primitives (v0.0) │
│ solve, logdet, cholesky, diag, trace, sqrt, inv │
│ Dispatch on operator type + tags │
└──────────────────────────────────────────────────────────────────┘
│ │
┌──────▼──────┐ ┌───────▼──────┐
│ lineax │ │ matfree │
│ (solvers) │ │ (Lanczos, │
│ │ │ SLQ, etc.) │
└─────────────┘ └──────────────┘
Layer 0 --- Primitives¶
Pure functions that match the equations in papers. Every function takes an operator and returns arrays or operators:
x = gaussx.solve(A, b) # solve Ax = b
ld = gaussx.logdet(A) # log|det(A)|
L = gaussx.cholesky(A) # A = LL^T
d = gaussx.diag(A) # diagonal entries
t = gaussx.trace(A) # tr(A)
S = gaussx.sqrt(A) # S such that SS = A
A_inv = gaussx.inv(A) # lazy A^{-1}
Each primitive uses isinstance dispatch to select the efficient code path based on operator type:
| Operator | solve |
logdet |
cholesky |
|---|---|---|---|
| Diagonal | O(n) divide | O(n) sum(log) | O(n) sqrt |
| BlockDiag | per-block | sum of logdets | per-block |
| Kronecker | Roth's lemma | scaled sum | per-factor |
| LowRankUpdate | Woodbury | det lemma | --- |
| Dense | lineax solver | slogdet | jax.scipy cholesky |
For large unstructured operators, Layer 0 delegates to matfree for iterative/stochastic algorithms (Lanczos for logdet, Hutchinson for trace).
Layer 1 --- Operators¶
Structured operators extending lineax.AbstractLinearOperator. Each is an equinox.Module (immutable PyTree), supports mv, as_matrix, transpose, and carries structural tags.
| Operator | Represents | Efficient mv |
|---|---|---|
Kronecker(A, B, ...) |
\(A \otimes B \otimes \cdots\) | Roth's column lemma via einops |
BlockDiag(A, B, ...) |
\(\mathrm{diag}(A, B, \ldots)\) | Per-block, concatenate |
BlockTriDiag(D, A) |
Symmetric block-tridiagonal precision | Banded block matvec |
LowRankUpdate(L, U, d, V) |
\(L + U \mathrm{diag}(d) V^\top\) | Base mv + rank-k update |
ImplicitKernelOperator(k, X) |
Matrix-free kernel Gram operator | Nested vmap kernel matvec |
ImplicitKernelOperator keeps structural claims explicit: if a kernel should be treated as symmetric or PSD by lineax, pass those tags when constructing it rather than relying on the operator to infer them.
Arithmetic (+, @, *) composes with lineax's built-in operators:
Layer 1.5 --- Solver Strategies¶
Pair solve + logdet into reusable strategy objects:
| Strategy | Algorithm | Best for |
|---|---|---|
DenseSolver |
Structural dispatch (Cholesky for PSD, etc.) | Small-medium, structured |
CGSolver |
Iterative CG + stochastic Lanczos logdet | Large PSD, matrix-free |
Strategies decouple the distribution from the solver --- a MultivariateNormal doesn't know or care whether it's doing dense Cholesky or iterative CG.
Layer 2 --- Distributions + Sugar¶
MultivariateNormal(loc, cov_operator, solver=...)--- accepts any covariance operator and any solver strategy- Compound operations:
project,unwhiten,schur_complement,conditional_variance - Gaussian exponential family: natural/expectation parameters, Fisher information
Layer 3 --- Recipes¶
Cross-library patterns: Kalman filter, RTS smoother, ensemble covariance, natural gradient updates, Kronecker GP marginal likelihood / posterior prediction, LOVE predictive variance, CVI site updates, and SSM-natural conversions. Thin wiring of Layer 0--2 operations into domain-specific sequences.
Two API constraints worth knowing:
kronecker_posterior_predictive(...)needs exact test prior diagonals viaK_test_diag_factorsfor predictive variances.ssm_to_naturals(...)expectsQ[0] == P_0and raises on inconsistent initial covariance inputs.
Structural Tags¶
Operators carry tags that drive dispatch. gaussx extends lineax's tag set:
| Tag | Source | Used by |
|---|---|---|
symmetric_tag |
lineax | solve, logdet |
positive_semidefinite_tag |
lineax | cholesky, sqrt, CG |
diagonal_tag |
lineax | all primitives (O(n) paths) |
kronecker_tag |
gaussx | all primitives (per-factor) |
block_diagonal_tag |
gaussx | all primitives (per-block) |
low_rank_tag |
gaussx | solve (Woodbury), logdet (det lemma) |
Query functions (is_kronecker, is_block_diagonal, is_low_rank, plus all lineax queries) let you inspect operator properties without knowing the concrete type.
Dependencies¶
| Package | Role | Required |
|---|---|---|
jax / jaxlib |
Array backend | Yes |
equinox |
Module system, PyTrees | Yes |
lineax |
Linear operators, solvers | Yes |
matfree |
Krylov methods, stochastic trace | Yes |
jaxtyping |
Array type annotations | Yes |
einops |
Tensor reshaping (Principle 5) | Yes |
numpyro |
Distributions (Layer 2) | Planned |
Package Layout¶
src/gaussx/
├── __init__.py # Public API
├── _tags.py # Structural tags + queries
├── _operators/ # Layer 1
│ ├── _kronecker.py
│ ├── _block_diag.py
│ ├── _block_tridiag.py
│ ├── _implicit_kernel.py
│ └── _low_rank_update.py
├── _primitives/ # Layer 0
│ ├── _solve.py
│ ├── _logdet.py
│ ├── _cholesky.py
│ ├── _diag.py
│ ├── _trace.py
│ ├── _sqrt.py
│ └── _inv.py
├── _strategies/ # Layer 1.5
│ ├── _base.py
│ ├── _dense.py
│ └── _cg.py
├── _recipes/ # Layer 3
│ ├── _kalman.py
│ ├── _kronecker_gp.py
│ ├── _love.py
│ ├── _cvi.py
│ └── _ssm_natural.py
└── _testing.py # Test utilities