Linear-Algebra Utilities
Numerically careful building blocks shared by the higher layers: robust factorizations, classical matrix identities in operator form, and batched / matrix-RHS solve helpers.
Robust factorization & hygiene
safe_cholesky retries with geometrically growing diagonal jitter inside a
jax.lax.while_loop (JIT-compatible) when a matrix is not numerically
positive-definite; symmetrize removes the floating-point asymmetry that
accumulates in covariance updates.
Structured linear algebra and Gaussian primitives for JAX.
safe_cholesky(operator: lx.AbstractLinearOperator, *, initial_jitter: float = 1e-08, max_jitter: float = 0.01, max_retries: int = 5, growth_factor: float = 10.0) -> Float[Array, 'N N']
Cholesky decomposition with adaptive jitter for near-singular matrices.
The first attempt routes through gaussx._primitives.cholesky, so
structured operators (DiagonalLinearOperator, BlockDiag,
Kronecker, BlockTriDiag) keep their structure on the happy path.
If the result contains NaNs (the matrix is not numerically
positive-definite), the function retries with geometrically increasing diagonal jitter,
computing cholesky(A + eps * I) where eps starts at initial_jitter and
is multiplied by growth_factor on each retry, capped at max_jitter, for at most
max_retries retries. Jittering generally destroys structure, so retries operate on the dense matrix form.
Uses jax.lax.while_loop internally so the function is fully
JIT-compatible.
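The retry loop described above can be sketched on a dense matrix with jax.lax.while_loop. This is a minimal illustration under stated assumptions, not the gaussx implementation: it skips the structured first attempt through gaussx._primitives.cholesky, and the helper name jittered_cholesky is hypothetical. It relies on jnp.linalg.cholesky signalling failure with NaNs rather than raising.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def jittered_cholesky(A, initial_jitter=1e-8, max_jitter=1e-2,
                      max_retries=5, growth_factor=10.0):
    """Hypothetical dense sketch of Cholesky with adaptive diagonal jitter."""
    eye = jnp.eye(A.shape[-1], dtype=A.dtype)
    L0 = jnp.linalg.cholesky(A)  # plain attempt first

    def needs_retry(state):
        L, _, retries = state
        # Keep looping while the factor has NaNs and retries remain.
        return jnp.any(jnp.isnan(L)) & (retries < max_retries)

    def retry(state):
        _, eps, retries = state
        L = jnp.linalg.cholesky(A + eps * eye)
        # Grow the jitter geometrically, clamped at max_jitter.
        next_eps = jnp.minimum(eps * growth_factor, max_jitter)
        return L, next_eps, retries + 1

    init = (L0, jnp.asarray(initial_jitter, A.dtype), jnp.array(0, jnp.int32))
    L, _, _ = jax.lax.while_loop(needs_retry, retry, init)
    return L  # may still contain NaNs if every attempt failed
```

Because the loop is a jax.lax.while_loop, jax.jit(jittered_cholesky) traces without Python-level control flow on array values.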
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | A lineax linear operator whose Cholesky factor is required. Must be square and positive-definite. | required |
| initial_jitter | float | Starting jitter magnitude added to the diagonal. | 1e-08 |
| max_jitter | float | Upper bound on jitter (clamped after growth). | 0.01 |
| max_retries | int | Maximum number of jittered retries after the initial attempt. | 5 |
| growth_factor | float | Multiplicative factor applied to jitter each retry. | 10.0 |

Returns:

| Type | Description |
|---|---|
| Float[Array, 'N N'] | Lower-triangular Cholesky factor as a dense array. If all attempts fail, the result contains NaNs. This is intentional: JAX cannot raise exceptions inside JIT-compiled code, so callers should check for NaNs when robustness matters. |

Source code in src/gaussx/_linalg/_safe_cholesky.py
symmetrize(mat: Float[Array, '... n n']) -> Float[Array, '... n n']
Symmetrize the trailing two axes: 0.5 * (X + X^T).
Eliminates residual floating-point asymmetry after covariance
updates. Works on batched stacks (..., n, n).
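Because the operation only touches the trailing two axes, it broadcasts over any leading batch dimensions. A one-line sketch (symmetrize_sketch is a hypothetical stand-in, not the gaussx function):

```python
import jax
import jax.numpy as jnp


def symmetrize_sketch(mat):
    """0.5 * (X + X^T) over the trailing two axes of a (..., n, n) stack."""
    return 0.5 * (mat + jnp.swapaxes(mat, -1, -2))


# A batch of three asymmetric 2x2 matrices.
stack = jax.random.normal(jax.random.PRNGKey(0), (3, 2, 2))
sym = symmetrize_sketch(stack)
```

Since floating-point addition is commutative, the result is exactly symmetric, and an already symmetric input is returned unchanged.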
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| mat | Float[Array, '... n n'] | Square matrix or batched stack of square matrices. | required |

Returns:

| Type | Description |
|---|---|
| Float[Array, '... n n'] | The symmetric part of mat. |

Source code in src/gaussx/_linalg/_symmetrize.py
Matrix identities
Woodbury solves, Schur complements, and conditional (Schur-complement) variances: the identities behind every Gaussian conditioning step, exposed directly so recipes never have to re-derive them.
woodbury_solve(base: lx.AbstractLinearOperator, U: Float[Array, 'N k'], D: Float[Array, ' k'], b: Float[Array, ' N']) -> Float[Array, ' N']
Standalone Woodbury identity solve: (L + U diag(D) U^T)^{-1} b.
Convenience function for cases where the user has the components
but doesn't want to construct a LowRankUpdate operator.
Uses the identity:
(L + U D U^T)^{-1} b = L^{-1}b - L^{-1}U C^{-1} U^T L^{-1}b
where C = D^{-1} + U^T L^{-1} U is the (k, k) capacitance
matrix.
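A dense sketch of the same identity, useful for checking results on small problems. It uses plain jnp.linalg.solve for every L⁻¹ application instead of lineax dispatch, and woodbury_solve_dense is a hypothetical name. Only one small (k, k) system, the capacitance matrix, is solved beyond the applications of L⁻¹.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def woodbury_solve_dense(L, U, D, b):
    """Solve (L + U diag(D) U^T) x = b via the Woodbury identity (dense sketch)."""
    Linv_b = jnp.linalg.solve(L, b)  # L^{-1} b, shape (N,)
    Linv_U = jnp.linalg.solve(L, U)  # L^{-1} U, shape (N, k)
    # Capacitance matrix C = D^{-1} + U^T L^{-1} U, shape (k, k).
    C = jnp.diag(1.0 / D) + U.T @ Linv_U
    # L^{-1} U C^{-1} U^T L^{-1} b
    correction = Linv_U @ jnp.linalg.solve(C, U.T @ Linv_b)
    return Linv_b - correction
```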
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| base | AbstractLinearOperator | Base operator L, shape (N, N). | required |
| U | Float[Array, 'N k'] | Low-rank factor, shape (N, k). | required |
| D | Float[Array, ' k'] | Diagonal scaling, shape (k,). | required |
| b | Float[Array, ' N'] | Right-hand side, shape (N,). | required |

Returns:

| Type | Description |
|---|---|
| Float[Array, ' N'] | Solution x, shape (N,). |

Source code in src/gaussx/_linalg/_woodbury.py
schur_complement(K_XX: lx.AbstractLinearOperator, K_XZ: Float[Array, 'N M'], K_ZZ: lx.AbstractLinearOperator) -> LowRankUpdate
Schur complement: K_XX - K_XZ @ K_ZZ^{-1} @ K_ZX.
Central to GP conditional distributions. Returns a
LowRankUpdate operator so that downstream operations
(solve, logdet) can exploit the low-rank structure.
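The Schur complement is the conditional covariance of X given Z under a joint Gaussian. A quick dense check of that fact, using plain arrays rather than the LowRankUpdate return type (schur_dense is a hypothetical helper): the inverse of the (X, X) block of the joint precision equals K_XX - K_XZ K_ZZ⁻¹ K_ZX.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def schur_dense(K_XX, K_XZ, K_ZZ):
    """Dense Schur complement K_XX - K_XZ K_ZZ^{-1} K_ZX."""
    return K_XX - K_XZ @ jnp.linalg.solve(K_ZZ, K_XZ.T)


# Random SPD joint covariance over (X, Z) with N = 3, M = 2.
F = jax.random.normal(jax.random.PRNGKey(0), (5, 5))
K = F @ F.T + 5.0 * jnp.eye(5)
K_XX, K_XZ, K_ZZ = K[:3, :3], K[:3, 3:], K[3:, 3:]

S = schur_dense(K_XX, K_XZ, K_ZZ)
# Block-inverse identity: the (X, X) block of K^{-1} is S^{-1}.
precision_XX = jnp.linalg.inv(K)[:3, :3]
```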
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| K_XX | AbstractLinearOperator | Prior covariance, shape (N, N). | required |
| K_XZ | Float[Array, 'N M'] | Cross-covariance, shape (N, M). | required |
| K_ZZ | AbstractLinearOperator | Inducing covariance, shape (M, M). | required |

Returns:

| Type | Description |
|---|---|
| LowRankUpdate | A LowRankUpdate operator representing the Schur complement, shape (N, N). |

Source code in src/gaussx/_linalg/_schur.py
conditional_variance(K_XX_diag: Float[Array, ' N'], K_XZ: Float[Array, 'N M'] | lx.AbstractLinearOperator | None = None, A_X: Float[Array, 'N M'] | None = None, S_u: lx.AbstractLinearOperator | None = None) -> Float[Array, ' N']
Predictive variance: Schur complement diagonal plus optional variational correction.
Computes the diagonal of the conditional covariance:
diag(K_XX - A_X K_XZ^T) + diag(A_X S_u A_X^T)
where A_X = K_XZ K_ZZ^{-1} is the projection matrix.
Negative base variances are clamped to zero.
Without S_u this is the exact GP predictive variance (diagonal of the
Schur complement). With S_u it is the sparse GP variational predictive
variance that adds a variational covariance correction.
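Both diagonals can be computed from row-wise dot products, so neither (N, N) matrix is ever formed. A dense sketch of the formula above; conditional_variance_dense is hypothetical and, unlike gaussx.conditional_variance, takes K_ZZ directly and builds A_X itself.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def conditional_variance_dense(K_XX_diag, K_XZ, K_ZZ, S_u=None):
    """Dense sketch of diag(K_XX - A_X K_XZ^T) + diag(A_X S_u A_X^T)."""
    # Projection A_X = K_XZ K_ZZ^{-1} via a solve (K_ZZ is symmetric).
    A_X = jnp.linalg.solve(K_ZZ, K_XZ.T).T
    # diag(A_X K_XZ^T)_i = sum_j A_X[i, j] * K_XZ[i, j]
    base = K_XX_diag - jnp.sum(A_X * K_XZ, axis=1)
    base = jnp.maximum(base, 0.0)  # clamp negative round-off, as described above
    if S_u is None:
        return base
    # diag(A_X S_u A_X^T)_i = sum_j (A_X S_u)[i, j] * A_X[i, j]
    return base + jnp.sum((A_X @ S_u) * A_X, axis=1)
```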
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| K_XX_diag | Float[Array, ' N'] | Prior diagonal variances, shape (N,). | required |
| K_XZ | Float[Array, 'N M'] \| AbstractLinearOperator \| None | Cross-covariance matrix, shape (N, M). | None |
| A_X | Float[Array, 'N M'] \| None | Projection matrix A_X = K_XZ K_ZZ^{-1}, shape (N, M). | None |
| S_u | AbstractLinearOperator \| None | Optional variational covariance, shape (M, M). | None |

Returns:

| Type | Description |
|---|---|
| Float[Array, ' N'] | Predictive variances, shape (N,). |

Note
For one release the legacy three-positional-argument form
conditional_variance(base_diag, A_X, S_u) (where
base_diag was the Schur diagonal already, A_X was the
projection, and S_u was the variational covariance) is
still accepted: it is detected when the second positional
argument is a lineax.AbstractLinearOperator (the old
S_u slot type). Such calls emit a
DeprecationWarning and compute
base_diag + diag(A_X S_u A_X^T) without the
Schur subtraction.
Source code in src/gaussx/_linalg/_schur.py
diag_conditional_variance(K_XX_diag: Float[Array, ' N'], K_XZ: Float[Array, 'N M'], A_X: Float[Array, 'N M']) -> Float[Array, ' N']
Diagonal of the Schur complement: diag(K_XX - A_X K_ZX).
Thin wrapper around gaussx.conditional_variance (defined in
gaussx._linalg._schur.conditional_variance) without a
variational covariance. Use gaussx.conditional_variance
directly when a variational correction S_u is also needed.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| K_XX_diag | Float[Array, ' N'] | Prior diagonal variances, shape (N,). | required |
| K_XZ | Float[Array, 'N M'] | Cross-covariance, shape (N, M). | required |
| A_X | Float[Array, 'N M'] | Projection matrix A_X = K_XZ K_ZZ^{-1}, shape (N, M). | required |

Returns:

| Type | Description |
|---|---|
| Float[Array, ' N'] | Conditional variances, shape (N,). |

Source code in src/gaussx/_linalg/_linalg.py
cov_transform(J: Float[Array, 'M N'] | lx.AbstractLinearOperator, cov_operator: lx.AbstractLinearOperator) -> lx.AbstractLinearOperator
Covariance propagation through a linear map: J @ Sigma @ J^T.
Used in error propagation, Kalman filter updates, and first-order uncertainty propagation.
Exploits structure where it can:
- Operator-valued J: routes through sandwich, which preserves matched Kronecker/BlockDiag structure and avoids materializing Sigma when either J or cov_operator is diagonal.
- Diagonal cov_operator (dense J): computes (J * d) @ J^T directly, skipping the (N, N) materialization of Sigma.
Otherwise materializes Sigma and forms the dense product. The
returned operator is tagged symmetric (and positive-semidefinite
when the input is).
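The diagonal shortcut works because broadcasting J * d scales column j of J by d[j], which is exactly J @ diag(d). A small dense sketch with plain arrays (cov_transform_diag is a hypothetical name, not the gaussx function):

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def cov_transform_diag(J, d):
    """J diag(d) J^T without building the (N, N) diagonal matrix."""
    # (J * d)[i, j] = J[i, j] * d[j], i.e. J @ diag(d).
    return (J * d) @ J.T


J = jax.random.normal(jax.random.PRNGKey(0), (3, 4))  # (M, N) Jacobian
d = jnp.array([1.0, 0.5, 2.0, 0.1])  # diagonal of Sigma
out = cov_transform_diag(J, d)  # (M, M) transformed covariance
```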
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| J | Float[Array, 'M N'] \| AbstractLinearOperator | Jacobian or linear map, shape (M, N). | required |
| cov_operator | AbstractLinearOperator | Input covariance, shape (N, N). | required |

Returns:

| Type | Description |
|---|---|
| AbstractLinearOperator | Transformed covariance operator, shape (M, M). For operator-valued J, the result follows sandwich. |

Source code in src/gaussx/_linalg/_linalg.py
sandwich(A: lx.AbstractLinearOperator, P: lx.AbstractLinearOperator) -> lx.AbstractLinearOperator
Return A @ P @ A.T exploiting compatible operator structure.
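One structure of this kind is a matched Kronecker product: (A₁ ⊗ A₂)(P₁ ⊗ P₂)(A₁ ⊗ A₂)ᵀ = (A₁P₁A₁ᵀ) ⊗ (A₂P₂A₂ᵀ), so only the small factors need to be multiplied. The dense check below illustrates the algebra with jnp.kron; it does not show how gaussx represents the structured result.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
A1, A2 = jax.random.normal(k1, (2, 3)), jax.random.normal(k2, (4, 2))
F1, F2 = jax.random.normal(k3, (3, 3)), jax.random.normal(k4, (2, 2))
P1, P2 = F1 @ F1.T, F2 @ F2.T  # PSD covariance factors

# Factor-wise sandwich: only the small per-factor products are formed.
structured = jnp.kron(A1 @ P1 @ A1.T, A2 @ P2 @ A2.T)

# Dense reference: materialize both Kronecker products first.
A, P = jnp.kron(A1, A2), jnp.kron(P1, P2)  # (8, 6) and (6, 6)
dense = A @ P @ A.T
```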
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| A | AbstractLinearOperator | Linear map with shape (M, N). | required |
| P | AbstractLinearOperator | Covariance operator with shape (N, N). | required |

Returns:

| Type | Description |
|---|---|
| AbstractLinearOperator | Transformed covariance operator with shape (M, M). |

Source code in src/gaussx/_linalg/_linalg.py
trace_product(A: lx.AbstractLinearOperator, B: lx.AbstractLinearOperator) -> Float[Array, '']
Trace of a matrix product: tr(A @ B) with structural dispatch.
Uses operator structure where possible to avoid materialization:
- Both diagonal: sum(diag(A) * diag(B)).
- Diagonal × general (or vice versa): contract the diagonal with the diagonal of the other operator (no full materialization).
- Matched gaussx.BlockDiag (same block sizes): sum of per-block trace_product.
- Matched gaussx.Kronecker (same factor structure): prod_i tr(A_i @ B_i).
- Otherwise, falls back to sum(A * B^T) on the materialized matrices, at the same O(N²) cost the previous implementation paid.
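The identities behind the dispatch rules above can be checked on small dense arrays. This is an illustrative sketch with plain jnp arrays, not the gaussx dispatch code.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

keys = jax.random.split(jax.random.PRNGKey(0), 7)
A = jax.random.normal(keys[0], (4, 4))
B = jax.random.normal(keys[1], (4, 4))
d = jax.random.normal(keys[2], (4,))

# Dense fallback: tr(A @ B) = sum_ij A_ij B_ji = sum(A * B^T).
dense_fallback = jnp.sum(A * B.T)

# Diagonal x general: tr(diag(d) @ B) only needs the diagonal of B.
diag_general = jnp.sum(d * jnp.diag(B))

# Matched Kronecker: tr((A1 ⊗ A2)(B1 ⊗ B2)) = tr(A1 B1) * tr(A2 B2).
A1, B1 = jax.random.normal(keys[3], (2, 2)), jax.random.normal(keys[4], (2, 2))
A2, B2 = jax.random.normal(keys[5], (3, 3)), jax.random.normal(keys[6], (3, 3))
kron_trace = jnp.trace(A1 @ B1) * jnp.trace(A2 @ B2)
```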
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| A | AbstractLinearOperator | Linear operator, shape … | required |
| B | AbstractLinearOperator | Linear operator, shape … | required |

Returns:

| Type | Description |
|---|---|
| Float[Array, ''] | Scalar tr(A @ B). |

Source code in src/gaussx/_linalg/_linalg.py
diag_inv(operator: lx.AbstractLinearOperator, *, method: str = 'auto', num_probes: int = 30, key: jax.Array | None = None, solver: AbstractSolveStrategy | None = None) -> Float[Array, ' N']
Compute the diagonal of the inverse of a linear operator.
Returns diag(A⁻¹) without forming the full inverse matrix.
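A minimal sketch of one common stochastic approach, a Hutchinson-style diagonal estimator: for Rademacher probes z, the expectation of z ⊙ A⁻¹z is diag(A⁻¹), because E[zᵢzⱼ] is 1 for i = j and 0 otherwise. The function name diag_inv_hutchinson is hypothetical, dense jnp.linalg.solve stands in for structured solves, and whether the library's hutchinson method uses exactly this estimator is an assumption.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def diag_inv_hutchinson(A, key, num_probes=30):
    """Stochastic estimate of diag(A^{-1}) from Rademacher probes (sketch)."""
    n = A.shape[0]
    # Probes with entries +1/-1, one per column: shape (n, num_probes).
    Z = jax.random.rademacher(key, (n, num_probes)).astype(A.dtype)
    # One solve per probe column: Y = A^{-1} Z.
    Y = jnp.linalg.solve(A, Z)
    # Average z * (A^{-1} z) over probes.
    return jnp.mean(Z * Y, axis=1)
```

For a diagonal A every probe returns the exact answer, since zᵢ² = 1; off-diagonal entries of A⁻¹ contribute noise that shrinks as num_probes grows.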
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | A linear operator representing A. | required |
| method | str | Algorithm to use. One of … | 'auto' |
| num_probes | int | Number of Rademacher probe vectors for the hutchinson method. | 30 |
| key | Array \| None | PRNG key for probe generation in the hutchinson method. When … | None |
| solver | AbstractSolveStrategy \| None | Optional solve strategy for … | None |

Returns:

| Type | Description |
|---|---|
| Float[Array, ' N'] | 1D array of shape (N,) containing diag(A⁻¹). |

Source code in src/gaussx/_linalg/_diag_inv.py
Matrix-RHS & batched solves
Solve \(AX = B\) for matrix right-hand sides with one factorization
(solve_matrix) or with per-column or per-row structured dispatch
(solve_columns / solve_rows), solve tridiagonal systems in \(O(n)\), singly or
batched, and compute kernel matrix-vector products in memory-bounded batches.
solve_matrix(operator: lx.AbstractLinearOperator, matrix: Float[Array, 'N K'], *, solver: AbstractSolveStrategy | None = None) -> Float[Array, 'N K']
Solve A X = B with a single factorization on the matrix RHS.
When solver is None and A is positive semidefinite, this
factors A = L L^T once via gaussx.cholesky and then applies
a single cho_solve to the full matrix RHS, avoiding the
per-column re-factorization incurred by solve_columns.
For non-PSD operators (or when a custom solver is supplied),
it falls back to solve_columns.
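The two strategies can be contrasted on dense arrays: factor once with jax.scipy.linalg.cho_factor and reuse the factor for every column, or vmap a vector solve over the columns. Both helper names are hypothetical and the dense calls stand in for lineax dispatch; the sketch only shows that the two routes agree.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

jax.config.update("jax_enable_x64", True)


def solve_matrix_spd(A, B):
    """Factor an SPD A once, then solve for all K columns of B together."""
    factor = cho_factor(A, lower=True)  # O(N^3), done once
    return cho_solve(factor, B)  # triangular solves on the full (N, K) RHS


def solve_columns_dense(A, B):
    """Per-column route: vmap a vector solve over the columns of B."""
    return jax.vmap(lambda b: jnp.linalg.solve(A, b), in_axes=1, out_axes=1)(B)
```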
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | Linear operator A, shape (N, N). | required |
| matrix | Float[Array, 'N K'] | Right-hand side B, shape (N, K). | required |
| solver | AbstractSolveStrategy \| None | Optional solver strategy. When provided, dispatch is delegated column-by-column via solve_columns. | None |

Returns:

| Type | Description |
|---|---|
| Float[Array, 'N K'] | Solution X = A⁻¹ B, shape (N, K). |

Source code in src/gaussx/_linalg/_linalg.py
solve_columns(operator: lx.AbstractLinearOperator, matrix: Float[Array, 'N K'], *, solver: AbstractSolveStrategy | None = None) -> Float[Array, 'N K']
Solve A X = B column-by-column via vmap.
Equivalent to A⁻¹ @ B but uses structured dispatch per column.
Use this when A is a lineax operator with efficient per-vector
solve (e.g., Cholesky, diagonal, block-diagonal).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | Linear operator A, shape (N, N). | required |
| matrix | Float[Array, 'N K'] | Right-hand side B, shape (N, K). | required |
| solver | AbstractSolveStrategy \| None | Optional solver strategy. | None |

Returns:

| Type | Description |
|---|---|
| Float[Array, 'N K'] | Solution X = A⁻¹ B, shape (N, K). |

Source code in src/gaussx/_linalg/_linalg.py
solve_rows(operator: lx.AbstractLinearOperator, matrix: Float[Array, 'K N'], *, solver: AbstractSolveStrategy | None = None) -> Float[Array, 'K N']
Solve A x = bᵢ for each row bᵢ of a matrix via vmap.
Stacking the solutions row-by-row gives B @ A⁻ᵀ, which equals B @ A⁻¹ whenever A
is symmetric. Used in Kalman gain, Schur complement, and GP prediction computations.
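A dense sketch showing why the row-wise result is B A⁻ᵀ, which coincides with B A⁻¹ only for symmetric A. solve_rows_dense is a hypothetical helper built on jax.vmap, not the gaussx function.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def solve_rows_dense(A, B):
    """Solve A x_i = b_i for every row b_i of B (dense sketch)."""
    return jax.vmap(lambda b: jnp.linalg.solve(A, b))(B)


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(key_a, (3, 3)) + 5.0 * jnp.eye(3)  # non-symmetric
B = jax.random.normal(key_b, (2, 3))
X = solve_rows_dense(A, B)  # row i is A^{-1} b_i, so X = B A^{-T}
```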
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | Linear operator A, shape (N, N). | required |
| matrix | Float[Array, 'K N'] | Rows of right-hand sides, shape (K, N). | required |
| solver | AbstractSolveStrategy \| None | Optional solver strategy. | None |

Returns:

| Type | Description |
|---|---|
| Float[Array, 'K N'] | Solutions, shape (K, N). |

Source code in src/gaussx/_linalg/_linalg.py
solve_tridiagonal(lower: Float[Array, ' n_minus_1'], diag: Float[Array, ' n'], upper: Float[Array, ' n_minus_1'], rhs: Float[Array, ' n']) -> Float[Array, ' n']
Solve a tridiagonal system A x = d.
Thin wrapper over lineax.TridiagonalLinearOperator, which delegates
to jax.lax.linalg.tridiagonal_solve (LAPACK / cuSPARSE).
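For reference, the classical O(n) Thomas algorithm (forward elimination, then back substitution) can be written with jax.lax.scan, using the same (lower, diag, upper, rhs) layout as solve_tridiagonal. This is an illustrative sketch, not what solve_tridiagonal runs; it does no pivoting, so it assumes a well-conditioned system such as a diagonally dominant one. The name thomas_solve is hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def thomas_solve(lower, diag, upper, rhs):
    """O(n) Thomas algorithm for a tridiagonal system (no pivoting)."""
    # Pad so that a[i] = A[i, i-1] (a[0] = 0) and c[i] = A[i, i+1] (c[-1] = 0).
    a = jnp.concatenate([jnp.zeros(1, diag.dtype), lower])
    c = jnp.concatenate([upper, jnp.zeros(1, diag.dtype)])

    def forward(carry, row):
        c_prev, d_prev = carry
        a_i, b_i, c_i, d_i = row
        denom = b_i - a_i * c_prev
        c_new = c_i / denom
        d_new = (d_i - a_i * d_prev) / denom
        return (c_new, d_new), (c_new, d_new)

    zero = jnp.zeros((), diag.dtype)
    _, (c_mod, d_mod) = jax.lax.scan(forward, (zero, zero), (a, diag, c, rhs))

    def backward(x_next, row):
        c_i, d_i = row
        x_i = d_i - c_i * x_next
        return x_i, x_i

    # reverse=True walks from the last row upward; outputs keep the original order.
    _, x = jax.lax.scan(backward, zero, (c_mod, d_mod), reverse=True)
    return x
```

Because it is built from scan, jax.vmap(thomas_solve) also handles a leading batch axis, which is the pattern solve_tridiagonal_batched describes.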
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| lower | Float[Array, ' n_minus_1'] | Sub-diagonal of A, length n - 1. | required |
| diag | Float[Array, ' n'] | Main diagonal of A, length n. | required |
| upper | Float[Array, ' n_minus_1'] | Super-diagonal of A, length n - 1. | required |
| rhs | Float[Array, ' n'] | Right-hand side, length n. | required |

Returns:

| Type | Description |
|---|---|
| Float[Array, ' n'] | The solution x, length n. |

Source code in src/gaussx/_linalg/_tridiagonal.py
solve_tridiagonal_batched(lower: Float[Array, '*batch n_minus_1'], diag: Float[Array, '*batch n'], upper: Float[Array, '*batch n_minus_1'], rhs: Float[Array, '*batch n']) -> Float[Array, '*batch n']
Solve independent tridiagonal systems over leading batch dimensions.
Applies solve_tridiagonal vmapped over all leading dimensions.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| lower | Float[Array, '*batch n_minus_1'] | Sub-diagonals, shape (*batch, n - 1). | required |
| diag | Float[Array, '*batch n'] | Main diagonals, shape (*batch, n). | required |
| upper | Float[Array, '*batch n_minus_1'] | Super-diagonals, shape (*batch, n - 1). | required |
| rhs | Float[Array, '*batch n'] | Right-hand sides, shape (*batch, n). | required |

Returns:

| Type | Description |
|---|---|
| Float[Array, '*batch n'] | Solutions, shape (*batch, n). |

Source code in src/gaussx/_linalg/_tridiagonal.py
batched_kernel_matvec(kernel_fn: Callable[[Float[Array, ' D'], Float[Array, ' D']], Float[Array, '']], X: Float[Array, 'N D'], Z: Float[Array, 'M D'], v: Float[Array, ' M'], batch_size: int = 1024) -> Float[Array, ' N']
Compute K(X, Z) @ v in memory-efficient batches.
Each batch evaluates a (batch_size, M) kernel sub-matrix and
immediately contracts with v, keeping peak memory at
O(batch_size * M) instead of O(N * M).
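A sketch of the batching scheme: pad X to a multiple of batch_size, then jax.lax.map builds one (batch_size, M) block at a time and contracts it with v immediately. kernel_matvec_batched and the rbf kernel are hypothetical illustrations, not the gaussx code; padded rows are evaluated but trimmed from the output.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def rbf(x, z):
    """Example pairwise kernel k(x, z) -> scalar (unit lengthscale and variance)."""
    return jnp.exp(-0.5 * jnp.sum((x - z) ** 2))


def kernel_matvec_batched(kernel_fn, X, Z, v, batch_size=1024):
    """Compute K(X, Z) @ v one (batch_size, M) block at a time (sketch)."""
    n = X.shape[0]
    num_batches = -(-n // batch_size)  # ceiling division
    pad = num_batches * batch_size - n
    # Pad X so it reshapes evenly into (num_batches, batch_size, D).
    X_batches = jnp.pad(X, ((0, pad), (0, 0))).reshape(num_batches, batch_size, -1)
    # Inner vmap over rows of Z, outer vmap over rows of the X batch.
    pairwise = jax.vmap(jax.vmap(kernel_fn, in_axes=(None, 0)), in_axes=(0, None))

    def one_batch(X_b):
        return pairwise(X_b, Z) @ v  # (batch_size, M) block, contracted at once

    out = jax.lax.map(one_batch, X_batches)  # sequential over batches
    return out.reshape(-1)[:n]  # drop rows that came from padding
```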
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| kernel_fn | Callable[[Float[Array, ' D'], Float[Array, ' D']], Float[Array, '']] | Pairwise kernel function mapping two (D,) points to a scalar. | required |
| X | Float[Array, 'N D'] | First set of points, shape (N, D). | required |
| Z | Float[Array, 'M D'] | Second set of points, shape (M, D). | required |
| v | Float[Array, ' M'] | Vector to multiply, shape (M,). | required |
| batch_size | int | Rows of X processed per batch. | 1024 |

Returns:

| Type | Description |
|---|---|
| Float[Array, ' N'] | Result of K(X, Z) @ v, shape (N,). |

Source code in src/gaussx/_linalg/_batched_matvec.py
batched_kernel_rmatvec(kernel_fn: Callable[[Float[Array, ' D'], Float[Array, ' D']], Float[Array, '']], X: Float[Array, 'N D'], Z: Float[Array, 'M D'], u: Float[Array, ' N'], batch_size: int = 1024) -> Float[Array, ' M']
Compute K(X, Z)^T @ u in memory-efficient batches.
Scans over batches of X, building each (batch_size, M)
kernel block and accumulating K_batch^T @ u_batch into an
(M,) result. Peak memory per step: O(batch_size * M).
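The transposed product needs an accumulator rather than a concatenation, which maps naturally onto jax.lax.scan with an (M,) carry. In this sketch (kernel_rmatvec_batched and rbf are hypothetical names), u is zero-padded alongside X so the padded rows contribute nothing to the sum.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def rbf(x, z):
    """Example pairwise kernel k(x, z) -> scalar."""
    return jnp.exp(-0.5 * jnp.sum((x - z) ** 2))


def kernel_rmatvec_batched(kernel_fn, X, Z, u, batch_size=1024):
    """Compute K(X, Z)^T @ u by scanning over row batches of X (sketch)."""
    n = X.shape[0]
    num_batches = -(-n // batch_size)  # ceiling division
    pad = num_batches * batch_size - n
    X_batches = jnp.pad(X, ((0, pad), (0, 0))).reshape(num_batches, batch_size, -1)
    # Zero-padded u entries make the padded rows contribute nothing.
    u_batches = jnp.pad(u, (0, pad)).reshape(num_batches, batch_size)
    pairwise = jax.vmap(jax.vmap(kernel_fn, in_axes=(None, 0)), in_axes=(0, None))

    def step(acc, batch):
        X_b, u_b = batch
        # K_batch^T @ u_batch, accumulated into the (M,) carry.
        return acc + pairwise(X_b, Z).T @ u_b, None

    init = jnp.zeros(Z.shape[0], dtype=u.dtype)
    result, _ = jax.lax.scan(step, init, (X_batches, u_batches))
    return result
```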
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| kernel_fn | Callable[[Float[Array, ' D'], Float[Array, ' D']], Float[Array, '']] | Pairwise kernel function mapping two (D,) points to a scalar. | required |
| X | Float[Array, 'N D'] | First set of points, shape (N, D). | required |
| Z | Float[Array, 'M D'] | Second set of points, shape (M, D). | required |
| u | Float[Array, ' N'] | Vector to multiply, shape (N,). | required |
| batch_size | int | Rows of X processed per batch. | 1024 |

Returns:

| Type | Description |
|---|---|
| Float[Array, ' M'] | Result of K(X, Z)^T @ u, shape (M,). |

Source code in src/gaussx/_linalg/_batched_matvec.py
Stable kernel arithmetic & Lyapunov
stable_squared_distances(X: Float[Array, 'N D'], Z: Float[Array, 'M D'], *, compute_dtype: jnp.dtype = jnp.float32, accumulate_dtype: jnp.dtype = jnp.float64) -> Float[Array, 'N M']
Squared Euclidean distances with mixed-precision stability.
The expansion ||x - z||^2 = ||x||^2 + ||z||^2 - 2 x^T z suffers
catastrophic cancellation in float32 for high-D data, producing
negative distances and non-PSD kernels.
This function computes dot products in compute_dtype (fast) and
performs the subtraction in accumulate_dtype (stable), then casts
the result back to compute_dtype.
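A minimal sketch of the mixed-precision recipe just described, under the assumption that norms are computed alongside the dot products in compute_dtype. In JAX, float64 only takes effect when 64-bit mode is enabled (jax_enable_x64); otherwise arrays stay float32. The final clamp at zero is an extra safeguard in this sketch, not something the description above states. squared_distances_mixed is a hypothetical name.

```python
import jax
import jax.numpy as jnp

# float64 accumulation only takes effect when JAX's x64 mode is enabled.
jax.config.update("jax_enable_x64", True)


def squared_distances_mixed(X, Z, compute_dtype=jnp.float32,
                            accumulate_dtype=jnp.float64):
    """||x||^2 + ||z||^2 - 2 x^T z with the subtraction in higher precision."""
    Xc, Zc = X.astype(compute_dtype), Z.astype(compute_dtype)
    # Dot products and squared norms in the (fast) compute dtype.
    cross = (Xc @ Zc.T).astype(accumulate_dtype)
    x_sq = jnp.sum(Xc * Xc, axis=1).astype(accumulate_dtype)
    z_sq = jnp.sum(Zc * Zc, axis=1).astype(accumulate_dtype)
    # Combine in the accumulate dtype, then cast back.
    d2 = x_sq[:, None] + z_sq[None, :] - 2.0 * cross
    return jnp.maximum(d2, 0.0).astype(compute_dtype)
```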
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| X | Float[Array, 'N D'] | First set of points, shape (N, D). | required |
| Z | Float[Array, 'M D'] | Second set of points, shape (M, D). | required |
| compute_dtype | dtype | Dtype for dot products (default float32). | float32 |
| accumulate_dtype | dtype | Dtype for subtraction (default float64). | float64 |

Returns:

| Type | Description |
|---|---|
| Float[Array, 'N M'] | Squared distances, shape (N, M). |

Source code in src/gaussx/_linalg/_mixed_precision.py
stable_rbf_kernel(X: Float[Array, 'N D'], Z: Float[Array, 'M D'], lengthscale: float | Float[Array, ''], variance: float | Float[Array, ''] = 1.0, *, compute_dtype: jnp.dtype = jnp.float32, accumulate_dtype: jnp.dtype = jnp.float64) -> Float[Array, 'N M']
RBF (squared exponential) kernel with mixed-precision stability.
Computes variance * exp(-0.5 * ||x - z||^2 / lengthscale^2)
using stable_squared_distances for the distance computation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| X | Float[Array, 'N D'] | First set of points, shape (N, D). | required |
| Z | Float[Array, 'M D'] | Second set of points, shape (M, D). | required |
| lengthscale | float \| Float[Array, ''] | Kernel lengthscale. | required |
| variance | float \| Float[Array, ''] | Kernel signal variance (default 1.0). | 1.0 |
| compute_dtype | dtype | Dtype for dot products (default float32). | float32 |
| accumulate_dtype | dtype | Dtype for subtraction (default float64). | float64 |

Returns:

| Type | Description |
|---|---|
| Float[Array, 'N M'] | Kernel matrix, shape (N, M). |

Source code in src/gaussx/_linalg/_mixed_precision.py
discrete_lyapunov_solve(G: Float[Array, 'N N'], Q: Float[Array, 'N N']) -> Float[Array, 'N N']
Solve the discrete Lyapunov equation P - G P G^T = Q.
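Uses the eigendecomposition \(G = V \Lambda V^{-1}\) with \(\Lambda = \operatorname{diag}(\lambda_1, \dots, \lambda_N)\). Substituting \(P = V \tilde{P} V^T\) turns the equation into \(\tilde{P} - \Lambda \tilde{P} \Lambda = \tilde{Q}\), which is solved elementwise:
\[
\tilde{P}_{ij} = \frac{\tilde{Q}_{ij}}{1 - \lambda_i \lambda_j}, \qquad \tilde{Q} = V^{-1} Q V^{-T},
\]
and then \(P = V \tilde{P} V^T\). This avoids materializing the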
(N², N²) Kronecker matrix the vectorized formulation requires.
Cost is O(N³) (one general eigendecomposition + a couple of
matrix multiplies) versus O(N⁶) for the vectorized solve, and
the memory footprint drops from O(N⁴) to O(N²).
The discrete Lyapunov equation has a unique solution iff
λ_i λ_j ≠ 1 for all eigenvalue pairs (λ_i, λ_j) of G.
The standard sufficient condition, G being stable
(spectral radius < 1), guarantees this.
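The eigendecomposition route above fits in a few lines of jax.numpy. This is a hedged sketch (lyapunov_eig is a hypothetical name) that, like the method described, assumes G is diagonalizable. It uses the plain transpose V^T rather than the conjugate transpose, which keeps the algebra valid for complex eigenvectors, and note that jnp.linalg.eig is only available on CPU.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def lyapunov_eig(G, Q):
    """Solve P - G P G^T = Q via G = V diag(lam) V^{-1} (diagonalizable G only)."""
    lam, V = jnp.linalg.eig(G)  # complex eigenpairs
    V_inv = jnp.linalg.inv(V)
    # Transformed right-hand side Q~ = V^{-1} Q V^{-T} (plain transpose).
    Q_t = V_inv @ Q.astype(V.dtype) @ V_inv.T
    # Elementwise solve: P~_ij = Q~_ij / (1 - lam_i lam_j).
    P_t = Q_t / (1.0 - lam[:, None] * lam[None, :])
    # Back-transform; for real G and Q the imaginary part is round-off.
    return jnp.real(V @ P_t @ V.T)


G = jnp.array([[0.6, -0.5], [0.5, 0.6]])  # eigenvalues 0.6 ± 0.5i, |lam| < 1
Q = jnp.array([[1.0, 0.2], [0.2, 0.5]])
P = lyapunov_eig(G, Q)
```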
Warning
This implementation assumes G is diagonalizable. For
defective G (e.g., a Jordan block with eigenvalue magnitude
< 1) the eigenvector matrix V is singular and this
solve will return NaN / inf. The discrete Lyapunov
equation still has a unique solution in that case, but
recovering it requires a Schur decomposition (Bartels-Stewart),
which JAX does not currently expose. Fall back to the
vectorized (I − G ⊗ G) solve for those operators.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| G | Float[Array, 'N N'] | Square matrix, shape (N, N). | required |
| Q | Float[Array, 'N N'] | Right-hand side, shape (N, N). | required |

Returns:

| Type | Description |
|---|---|
| Float[Array, 'N N'] | Symmetric matrix P solving P - G P G^T = Q, shape (N, N). |