Operators & Tags
Layer 1: structured linear operators extending
lineax.AbstractLinearOperator.
All are immutable equinox.Module pytrees, so they compose freely with jit,
grad, and vmap. The primitives dispatch on these types: a
solve against a Kronecker factorizes per Kronecker factor, a logdet of a
BlockDiag sums per block, a LowRankUpdate solve applies Woodbury.
Structured products & sums
The Kronecker product \(A_1 \otimes A_2 \otimes \cdots\) gives \(O(\sum_i n_i^3)\) solves on a \(\prod_i n_i\) grid; the Kronecker sum \(A \otimes I + I \otimes B\) diagonalises in the joint eigenbasis with eigenvalues \(\lambda_i + \mu_j\).
Structured linear algebra and Gaussian primitives for JAX.
Kronecker
Bases: AbstractLinearOperator
Kronecker product operator A₁ ⊗ A₂ ⊗ … ⊗ Aₖ.
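Matvec uses Roth's column lemma to avoid materializing the full Kronecker product. For two factors A (m × n) and B (p × q), the product (A ⊗ B) vec(X) is computed as vec(B X Aᵀ), where vec stacks columns and the input vector is reshaped into X of shape (q, n).
Complexity: per-factor factorizations (used by solve and logdet) cost O(sum n_i^3) instead of O((prod n_i)^3) for a dense factorization, and for square factors each matvec costs O(prod n_i · sum n_i) instead of O((prod n_i)^2).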
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| *operators | AbstractLinearOperator | Two or more … | () |
Source code in src/gaussx/_operators/_kronecker.py
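A minimal NumPy sketch of the two-factor case follows; it is not gaussx's implementation, and the helper name kron_mv is hypothetical. It checks the lemma against the dense np.kron product, using column-major reshapes because vec stacks columns:

```python
import numpy as np


def kron_mv(A, B, x):
    """Compute (A ⊗ B) @ x without forming the Kronecker product.

    Roth's column lemma: (A ⊗ B) vec(X) = vec(B X Aᵀ), where vec stacks
    columns, hence the column-major (order="F") reshapes.
    """
    m, n = A.shape
    p, q = B.shape
    X = x.reshape(q, n, order="F")       # vec(X) == x
    Y = B @ X @ A.T                      # shape (p, m)
    return Y.reshape(p * m, order="F")   # vec(Y)


rng = np.random.default_rng(0)
A = rng.standard_normal((3, 4))
B = rng.standard_normal((2, 5))
x = rng.standard_normal(4 * 5)
y = kron_mv(A, B, x)  # O(pqn + pnm) work instead of O(mp * nq)
```

With more than two factors, the same reshape trick is typically applied one factor (mode) at a time.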
BlockDiag
Bases: AbstractLinearOperator
Block diagonal operator diag(A₁, A₂, …, Aₖ).
Each sub-operator acts on its own slice of the input vector. Matvec, transpose, logdet, solve, and cholesky all decompose per-block.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| *operators | AbstractLinearOperator | One or more … | () |
Source code in src/gaussx/_operators/_block_diag.py
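A NumPy sketch of the per-block decomposition for logdet and solve (hypothetical helper names, not the gaussx API). The dense matrix exists only to check the result:

```python
import numpy as np


def block_diag_logdet(blocks):
    # log|det diag(A_1, ..., A_k)| = sum_i log|det A_i|
    return sum(np.linalg.slogdet(Ab)[1] for Ab in blocks)


def block_diag_solve(blocks, b):
    # Each block solves against its own slice of the right-hand side.
    out, start = [], 0
    for Ab in blocks:
        size = Ab.shape[0]
        out.append(np.linalg.solve(Ab, b[start:start + size]))
        start += size
    return np.concatenate(out)


rng = np.random.default_rng(1)
blocks = []
for size in (2, 3, 4):
    M = rng.standard_normal((size, size))
    blocks.append(M @ M.T + size * np.eye(size))  # SPD blocks

# Dense reference, only for checking the sketch.
dense = np.zeros((9, 9))
offset = 0
for Ab in blocks:
    size = Ab.shape[0]
    dense[offset:offset + size, offset:offset + size] = Ab
    offset += size
b = rng.standard_normal(9)
```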
KroneckerSum
Bases: AbstractLinearOperator
Kronecker sum A ⊕ B = A ⊗ I_b + I_a ⊗ B.
Appears in separable PDEs, graph Laplacians, and space-time GPs.
If A = Q_A Λ_A Q_Aᵀ and B = Q_B Λ_B Q_Bᵀ,
the Kronecker sum has eigenvectors Q_A ⊗ Q_B with
eigenvalues λ^A_i + λ^B_j.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| A | AbstractLinearOperator | First operator, shape … | required |
| B | AbstractLinearOperator | Second operator, shape … | required |
Source code in src/gaussx/_operators/_kronecker_sum.py
eigendecompose() -> tuple[Float[Array, ' n'], Float[Array, 'n n']]
Symmetric eigendecomposition via per-factor decomposition.
Assumes both factors are symmetric so the returned
Q = Q_A ⊗ Q_B is orthonormal — callers rely on
self == Q @ diag(eigenvalues) @ Q.T. Diagonal factors get a
structural shortcut; other operators are materialized and
decomposed via jnp.linalg.eigh. We deliberately avoid
routing untagged factors through gaussx.eig because
that primitive falls back to jnp.linalg.eig for untagged
operators and would return general (non-orthonormal)
eigenvectors — breaking the Q.T == Q^{-1} contract for the
common case of numerically symmetric matrices wrapped as plain
lineax.MatrixLinearOperator.
Returns:
| Type | Description |
|---|---|
| Float[Array, ' n'] | Tuple … |
| Float[Array, 'n n'] | eigenvalues are … |
Source code in src/gaussx/_operators/_kronecker_sum.py
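A small NumPy sketch of the per-factor route described above for general symmetric factors: eigh on each factor, with eigenvalues combined as λ^A_i + λ^B_j in np.kron ordering. The helper kron_sum_eigh is hypothetical, and Q is formed densely only so that the reconstruction can be checked:

```python
import numpy as np


def kron_sum_eigh(A, B):
    """Eigendecomposition of A ⊕ B = A ⊗ I_b + I_a ⊗ B for symmetric A and B."""
    lam_a, Q_a = np.linalg.eigh(A)
    lam_b, Q_b = np.linalg.eigh(B)
    # Eigenvector Q_a[:, i] ⊗ Q_b[:, j] has eigenvalue lam_a[i] + lam_b[j]
    # and sits in column i * n_b + j of np.kron(Q_a, Q_b).
    evals = (lam_a[:, None] + lam_b[None, :]).ravel()
    Q = np.kron(Q_a, Q_b)  # orthonormal because Q_a and Q_b are
    return evals, Q


rng = np.random.default_rng(2)
Ma, Mb = rng.standard_normal((3, 3)), rng.standard_normal((4, 4))
A, B = Ma + Ma.T, Mb + Mb.T
evals, Q = kron_sum_eigh(A, B)
S = np.kron(A, np.eye(4)) + np.kron(np.eye(3), B)  # dense A ⊕ B, reference only
```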
KroneckerSumSqrt
Bases: AbstractLinearOperator
Symmetric square root of A ⊕ B via per-factor eigenvectors.
Represents the symmetric matrix S with S @ S = A ⊕ B
(where ⊕ is the Kronecker sum A ⊗ I + I ⊗ B). The square
root is never materialized: mv and solve apply S and
S⁻¹ matrix-free using the per-factor eigendecompositions, so the
cost is O(n_a³ + n_b³) for the factor eigendecompositions plus
O(n_a n_b (n_a + n_b)) per application, rather than the
O((n_a n_b)³) of a dense square root.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| A | AbstractLinearOperator | Symmetric PSD factor, shape … | required |
| B | AbstractLinearOperator | Symmetric PSD factor, shape … | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If either factor is non-square, untagged as symmetric, or if … |
Source code in src/gaussx/_operators/_kronecker_sum.py
solve(vector: Float[Array, ' n']) -> Float[Array, ' n']
Apply the inverse square root S⁻¹ to vector.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| vector | Float[Array, ' n'] | Input vector, shape … | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' n'] |  |
Source code in src/gaussx/_operators/_kronecker_sum.py
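The matrix-free application can be sketched in NumPy as follows (kron_sum_sqrt_mv is a hypothetical name, not the gaussx API). The vector is reshaped to (n_a, n_b), rotated into the joint eigenbasis with the two factor eigenvector matrices, scaled by the square roots of λ^A_i + λ^B_j (or their reciprocals for the solve), and rotated back:

```python
import numpy as np


def kron_sum_sqrt_mv(A, B, v, inverse=False):
    """Apply S (or S⁻¹), the symmetric square root of A ⊕ B, without forming it."""
    lam_a, Q_a = np.linalg.eigh(A)
    lam_b, Q_b = np.linalg.eigh(B)
    lam = lam_a[:, None] + lam_b[None, :]      # eigenvalues of A ⊕ B, shape (n_a, n_b)
    root = np.sqrt(np.clip(lam, 0.0, None))    # clip round-off below zero
    if inverse:
        root = 1.0 / root                      # needs A ⊕ B positive definite
    V = v.reshape(A.shape[0], B.shape[0])      # row-major: v == V.ravel()
    W = Q_a.T @ V @ Q_b                        # (Q_a ⊗ Q_b)ᵀ v
    return (Q_a @ (root * W) @ Q_b.T).ravel()  # (Q_a ⊗ Q_b) diag(root) W


rng = np.random.default_rng(3)
Ma, Mb = rng.standard_normal((3, 3)), rng.standard_normal((4, 4))
A, B = Ma @ Ma.T + np.eye(3), Mb @ Mb.T + np.eye(4)  # symmetric PSD factors
v = rng.standard_normal(12)
Sv = kron_sum_sqrt_mv(A, B, v)
```

Applying the same map to standard normal noise gives samples with covariance A ⊕ B, which is the approach kronecker_sum_sample is documented to take.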
SumKronecker
Bases: AbstractLinearOperator
Sum of Kronecker products Σ_k A_k ⊗ B_k.
Appears in multi-output GPs with correlated outputs, e.g.
K_task ⊗ K_spatial + σ² I_task ⊗ I_spatial.
Matvec is computed as the sum of the Kronecker matvecs.
For solve and logdet, call eigendecompose, which uses a
joint eigendecomposition of the second Kronecker pair (requires
A_2 and B_2 to be symmetric). With n_c and n_d the sizes of A_2 and
B_2, the eigendecomposition forms a dense (n_c n_d) x (n_c n_d) matrix
internally, so it is intended for moderate factor sizes (typical for
multi-output GPs where the task dimension is small).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| kron1 | Kronecker | First Kronecker product | required |
| kron2 | Kronecker | Second Kronecker product | required |
| *krons | Kronecker | Additional two-factor Kronecker products. | () |
Source code in src/gaussx/_operators/_sum_kronecker.py
eigendecompose() -> tuple[Float[Array, ' n'], Float[Array, 'n n']]
Eigendecompose via joint eigendecomposition of the second pair.
Decomposes A_2 = Q_C Λ_C Q_Cᵀ and
B_2 = Q_D Λ_D Q_Dᵀ, then transforms the first pair
into the eigenbasis and diagonalizes the result.
Note
This forms a dense (n_c n_d) x (n_c n_d) matrix
internally and is O((n_c n_d)^3). Intended for moderate
factor sizes (e.g. multi-output GPs where task dimension
is small).
Raises:
| Type | Description |
|---|---|
| ValueError | If the factors of … |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' n'] | Tuple … |
| Float[Array, 'n n'] |  |
Source code in src/gaussx/_operators/_sum_kronecker.py
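A NumPy sketch of the joint-eigendecomposition idea for the two-term case. It assumes all four factors are symmetric (the final dense step uses eigh, so in this sketch the first pair must be symmetric too); sum_kron_eigh is a hypothetical helper, not the gaussx method:

```python
import numpy as np


def sum_kron_eigh(A1, B1, A2, B2):
    """Eigendecompose A1 ⊗ B1 + A2 ⊗ B2 with all factors symmetric."""
    lam_c, Q_c = np.linalg.eigh(A2)
    lam_d, Q_d = np.linalg.eigh(B2)
    # In the eigenbasis Q_c ⊗ Q_d the second pair becomes diagonal ...
    M = np.kron(Q_c.T @ A1 @ Q_c, Q_d.T @ B1 @ Q_d)
    M = M + np.diag(np.kron(lam_c, lam_d))
    # ... and the remaining dense (n_c n_d) x (n_c n_d) problem costs O((n_c n_d)^3).
    evals, P = np.linalg.eigh(M)
    return evals, np.kron(Q_c, Q_d) @ P


def random_spd(rng, size):
    M = rng.standard_normal((size, size))
    return M @ M.T + size * np.eye(size)


rng = np.random.default_rng(4)
A1, A2 = random_spd(rng, 2), random_spd(rng, 2)  # small "task" factors
B1, B2 = random_spd(rng, 5), random_spd(rng, 5)  # "spatial" factors
K = np.kron(A1, B1) + np.kron(A2, B2)
evals, Q = sum_kron_eigh(A1, B1, A2, B2)
```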
Low-rank updates
\(L + U\,\mathrm{diag}(d)\,V^\top\) with Woodbury-efficient solves and matrix-determinant-lemma logdets. The factories build the common special cases directly from arrays.
LowRankUpdate
Bases: AbstractLinearOperator
Low-rank update operator L + U diag(d) Vᵀ.
Represents a base operator L plus a rank-k update. When L is cheap to solve (e.g. diagonal), the Woodbury identity gives efficient solves for the full operator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| base | AbstractLinearOperator | The base operator L, with shape … | required |
| U | Float[Array, 'n k'] | Left factor, shape … | required |
| d | Float[Array, ' k'] \| None | Diagonal scaling, shape … | None |
| V | Float[Array, 'n k'] \| None | Right factor, shape … | None |
| orthonormal | bool | When … | False |
Source code in src/gaussx/_operators/_low_rank_update.py
rank: int
property
Rank of the low-rank update.
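For a diagonal base L, the Woodbury identity (L + U D Vᵀ)⁻¹ = L⁻¹ - L⁻¹U (D⁻¹ + VᵀL⁻¹U)⁻¹ VᵀL⁻¹ and the matrix determinant lemma det(L + U D Vᵀ) = det(D⁻¹ + VᵀL⁻¹U) det(D) det(L) both reduce to one k × k capacitance matrix, so the work is O(n k² + k³) instead of O(n³). A NumPy sketch under that assumption (hypothetical helper names; the gaussx dispatch may differ):

```python
import numpy as np


def woodbury_solve(l, U, d, V, b):
    """Solve (diag(l) + U diag(d) Vᵀ) x = b with the Woodbury identity."""
    Linv_b = b / l
    Linv_U = U / l[:, None]
    cap = np.diag(1.0 / d) + V.T @ Linv_U  # k x k capacitance matrix
    return Linv_b - Linv_U @ np.linalg.solve(cap, V.T @ Linv_b)


def lemma_logdet(l, U, d, V):
    """log|det(diag(l) + U diag(d) Vᵀ)| via the matrix determinant lemma."""
    cap = np.diag(1.0 / d) + V.T @ (U / l[:, None])
    return (np.sum(np.log(np.abs(l))) + np.sum(np.log(np.abs(d)))
            + np.linalg.slogdet(cap)[1])


rng = np.random.default_rng(6)
n, k = 8, 2
l = 1.0 + rng.random(n)  # diagonal base L
U = rng.standard_normal((n, k))
d = 0.5 + rng.random(k)
V = U                    # symmetric case V = U
dense = np.diag(l) + U @ np.diag(d) @ V.T
b = rng.standard_normal(n)
```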
SVDLowRankUpdate
Bases: LowRankUpdate
Deprecated subclass of LowRankUpdate with orthonormal=True.
Preserves the pre-consolidation public API for one release:
- Same constructor signature as the old class: S defaults to ones (via the parent LowRankUpdate) if omitted, and V defaults to U, so calls like SVDLowRankUpdate(base, U, S) and SVDLowRankUpdate(base, U) continue to work.
- Inherits from LowRankUpdate, so isinstance/issubclass checks and singledispatch registrations keyed on this class keep working.
- Forces orthonormal=True and emits a DeprecationWarning on construction.
New code should construct LowRankUpdate(base, U, S, V,
orthonormal=True) (or use svd_low_rank_plus_diag)
directly. This class will be removed in a future release.
Source code in src/gaussx/_operators/_low_rank_update.py
low_rank_plus_diag(diag: Float[Array, ' n'], U: Float[Array, 'n k'], d: Float[Array, ' k'] | None = None, V: Float[Array, 'n k'] | None = None) -> LowRankUpdate
Construct diag(diag) + U diag(d) Vᵀ.
Common pattern for inducing-point / Nystrom approximations where the base is a diagonal matrix.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| diag | Float[Array, ' n'] | Diagonal entries, shape … | required |
| U | Float[Array, 'n k'] | Left factor, shape … | required |
| d | Float[Array, ' k'] \| None | Diagonal scaling, shape … | None |
| V | Float[Array, 'n k'] \| None | Right factor, shape … | None |
Returns:
| Type | Description |
|---|---|
| LowRankUpdate | A … |
Source code in src/gaussx/_operators/_low_rank_update.py
low_rank_plus_identity(U: Float[Array, 'n k'], d: Float[Array, ' k'] | None = None, V: Float[Array, 'n k'] | None = None, *, scale: float = 1.0) -> LowRankUpdate
Construct scale * I + U diag(d) Vᵀ.
Common pattern for regularised low-rank models (e.g. noise + signal).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| U | Float[Array, 'n k'] | Left factor, shape … | required |
| d | Float[Array, ' k'] \| None | Diagonal scaling, shape … | None |
| V | Float[Array, 'n k'] \| None | Right factor, shape … | None |
| scale | float | Scalar multiplier on the identity. Default 1.0. | 1.0 |
Returns:
| Type | Description |
|---|---|
| LowRankUpdate | A … |
Source code in src/gaussx/_operators/_low_rank_update.py
svd_low_rank_plus_diag(diag: Float[Array, ' n'], U: Float[Array, 'n k'], S: Float[Array, ' k'], V: Float[Array, 'n k']) -> LowRankUpdate
Construct diag(diag) + U diag(S) Vᵀ from a truncated SVD.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| diag | Float[Array, ' n'] | Diagonal entries, shape … | required |
| U | Float[Array, 'n k'] | Left singular vectors, shape … | required |
| S | Float[Array, ' k'] | Singular values, shape … | required |
| V | Float[Array, 'n k'] | Right singular vectors, shape … | required |
Returns:
| Type | Description |
|---|---|
| LowRankUpdate | A … |
Source code in src/gaussx/_operators/_low_rank_update.py
Banded & Toeplitz
Block-tridiagonal operators solve in \(O(N d^3)\) via block-banded Cholesky — the precision structure of Markovian (state-space) GPs. Symmetric Toeplitz operators get \(O(n \log n)\) matvecs and sampling via FFT circulant embedding.
BlockTriDiag
Bases: AbstractLinearOperator
Symmetric block-tridiagonal operator.
Represents the structure:
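\[
\begin{bmatrix}
D_1 & A_1^\top & & & \\
A_1 & D_2 & A_2^\top & & \\
 & A_2 & D_3 & \ddots & \\
 & & \ddots & \ddots & A_{N-1}^\top \\
 & & & A_{N-1} & D_N
\end{bmatrix}
\]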
where D_k are (d, d) diagonal blocks and A_k are
(d, d) sub-diagonal blocks. This is the precision matrix
structure arising from discretized SDEs in state-space GP inference.
All primitives (solve, logdet, cholesky, diag, trace) exploit the banded structure for O(Nd³) cost instead of O((Nd)³).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| diagonal | Float[Array, 'N d d'] | Diagonal blocks, shape … | required |
| sub_diagonal | Float[Array, 'Nm1 d d'] | Sub-diagonal blocks, shape … | required |
Source code in src/gaussx/_operators/_block_tridiag.py
add(other: BlockTriDiag) -> BlockTriDiag
Add two block-tridiagonal operators (e.g. prior + likelihood sites).
LowerBlockTriDiag
Bases: AbstractLinearOperator
Lower triangular block-bidiagonal Cholesky factor.
Represents:
\[
\begin{bmatrix}
L_1 & & & & \\
B_1 & L_2 & & & \\
 & B_2 & L_3 & & \\
 & & \ddots & \ddots & \\
 & & & B_{N-1} & L_N
\end{bmatrix}
\]
where L_k are (d, d) lower-triangular blocks and B_k are
(d, d) sub-diagonal blocks.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| diagonal | Float[Array, 'N d d'] | Lower-triangular diagonal blocks, shape … | required |
| sub_diagonal | Float[Array, 'Nm1 d d'] | Sub-diagonal blocks, shape … | required |
Source code in src/gaussx/_operators/_block_tridiag.py
transpose() -> UpperBlockTriDiag
Transpose gives upper block-bidiagonal.
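The O(Nd³) cost quoted for BlockTriDiag comes from a block Cholesky recurrence that yields exactly this lower block-bidiagonal shape: L_1 = chol(D_1), B_k = A_k L_k⁻ᵀ, L_{k+1} = chol(D_{k+1} - B_k B_kᵀ). A NumPy sketch of that standard recurrence follows; whether gaussx uses the same recurrence is an assumption, and block_tridiag_cholesky is a hypothetical name. The log-determinant is read off as 2 Σ log diag(L_k):

```python
import numpy as np


def block_tridiag_cholesky(D, A):
    """Block Cholesky of a symmetric block-tridiagonal matrix in O(N d^3).

    D: (N, d, d) diagonal blocks; A: (N - 1, d, d) sub-diagonal blocks.
    Returns the blocks (L_k, B_k) of the lower block-bidiagonal factor.
    """
    N, d, _ = D.shape
    Ls = [np.linalg.cholesky(D[0])]
    Bs = []
    for k in range(N - 1):
        # B_k L_kᵀ = A_k  =>  B_kᵀ = L_k⁻¹ A_kᵀ (a triangular solve would be cheaper)
        B = np.linalg.solve(Ls[k], A[k].T).T
        Bs.append(B)
        # Schur complement: L_{k+1} L_{k+1}ᵀ = D_{k+1} - B_k B_kᵀ
        Ls.append(np.linalg.cholesky(D[k + 1] - B @ B.T))
    Bs = np.stack(Bs) if Bs else np.zeros((0, d, d))
    return np.stack(Ls), Bs


# SPD block-tridiagonal test matrix built as Lfull @ Lfull.T.
rng = np.random.default_rng(7)
N, d = 4, 2
Lfull = np.zeros((N * d, N * d))
for k in range(N):
    rows = slice(k * d, (k + 1) * d)
    Lfull[rows, rows] = np.tril(rng.standard_normal((d, d)), -1) + np.diag(1.0 + rng.random(d))
    if k > 0:
        Lfull[rows, (k - 1) * d:k * d] = rng.standard_normal((d, d))
K = Lfull @ Lfull.T
D = np.stack([K[k * d:(k + 1) * d, k * d:(k + 1) * d] for k in range(N)])
A = np.stack([K[(k + 1) * d:(k + 2) * d, k * d:(k + 1) * d] for k in range(N - 1)])

Ls, Bs = block_tridiag_cholesky(D, A)
logdet = 2.0 * np.sum(np.log(np.diagonal(Ls, axis1=1, axis2=2)))
```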
UpperBlockTriDiag
Bases: AbstractLinearOperator
Upper triangular block-bidiagonal (transpose of LowerBlockTriDiag).
Represents:
\[
\begin{bmatrix}
U_1 & C_1 & & \\
 & U_2 & C_2 & \\
 & & \ddots & C_{N-1} \\
 & & & U_N
\end{bmatrix}
\]
where U_k are upper-triangular diagonal blocks and C_k are
super-diagonal blocks.
Source code in src/gaussx/_operators/_block_tridiag.py
Toeplitz
Bases: AbstractLinearOperator
Symmetric Toeplitz matrix from its first column.
K[i, j] = c[|i - j|]. Stored in O(n) memory, with O(n log n) matvecs
via circulant embedding and the FFT.
For stationary kernels on regular 1-D grids the full kernel matrix is Toeplitz, so this gives an asymptotic win over dense storage.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| column | Float[Array, ' n'] | First column of the Toeplitz matrix, shape … | required |
Source code in src/gaussx/_operators/_toeplitz.py
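A NumPy sketch of the circulant-embedding matvec. The helper name and the size-2n embedding are illustrative choices, not necessarily what gaussx uses. The first column is mirrored into a circulant, the input is zero-padded, and the product becomes an FFT-based circular convolution:

```python
import numpy as np


def toeplitz_mv(c, x):
    """Symmetric Toeplitz matvec in O(n log n) via a size-2n circulant embedding."""
    n = c.shape[0]
    # First column of a circulant whose leading n x n block is Toeplitz(c).
    col = np.concatenate([c, [0.0], c[:0:-1]])
    x_pad = np.concatenate([x, np.zeros(n)])
    # A circulant matvec is a circular convolution, diagonalized by the FFT.
    y = np.fft.ifft(np.fft.fft(col) * np.fft.fft(x_pad))
    return y[:n].real


n = 6
c = np.exp(-0.5 * (0.4 * np.arange(n)) ** 2)  # squared-exponential kernel on a 1-D grid
x = np.random.default_rng(8).standard_normal(n)
T = c[np.abs(np.subtract.outer(np.arange(n), np.arange(n)))]  # dense reference
```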
ToeplitzCholesky
Bases: AbstractLinearOperator
Circulant-embedding sample factor for a symmetric positive Toeplitz matrix.
The operator has shape (n, embedding_factor * n) and satisfies
L @ L.T == Toeplitz(column) — it is a rectangular sample factor,
not a traditional lower-triangular Cholesky factor. Applying it to
standard normal white noise gives samples from
𝒩(0, Toeplitz(column)) when the Wood–Chan condition holds.
The Wood–Chan non-negativity check is implemented with
eqx.error_if so it is JIT-friendly: the error fires at evaluation
time rather than tracing time. If the embedding's spectrum has a
materially negative eigenvalue, bump embedding_factor (typically
4 or 8 suffices for well-behaved covariances).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| column | Float[Array, ' n'] | First column of the Toeplitz matrix, shape … | required |
| embedding_factor | int | Circulant embedding size as a multiple of … | 2 |
Source code in src/gaussx/_operators/_toeplitz.py
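The rectangular sample factor can be sketched in NumPy (not the gaussx code; circulant_sample is a hypothetical name). It builds the size-m circulant, checks its FFT eigenvalues for the Wood–Chan condition, applies the circulant's symmetric square root through the FFT, and keeps the first n entries. Unlike gaussx, this sketch raises eagerly instead of going through eqx.error_if:

```python
import numpy as np


def circulant_sample(c, z, embedding_factor=2):
    """Apply the (n, m) factor, m = embedding_factor * n, to noise z of shape (..., m).

    With z ~ N(0, I_m) the result is distributed as N(0, Toeplitz(c)).
    """
    n = c.shape[0]
    m = embedding_factor * n
    col = np.zeros(m)
    col[:n] = c
    col[m - n + 1:] = c[:0:-1]   # symmetric wrap-around
    lam = np.fft.fft(col).real   # circulant eigenvalues
    # Wood–Chan condition: the embedding must be non-negative definite.
    if lam.min() < -1e-8 * lam.max():
        raise ValueError("negative embedding eigenvalue: increase embedding_factor")
    root = np.sqrt(np.clip(lam, 0.0, None))
    # Symmetric square root of the circulant via the FFT, truncated to n entries.
    return np.fft.ifft(root * np.fft.fft(z, axis=-1), axis=-1).real[..., :n]


n = 8
c = np.exp(-np.arange(n))  # exponential (Matérn-1/2) kernel on a 1-D grid
T = c[np.abs(np.subtract.outer(np.arange(n), np.arange(n)))]
rng = np.random.default_rng(9)
samples = circulant_sample(c, rng.standard_normal((5, 2 * n)))  # shape (5, 8)
```

Feeding the identity instead of noise recovers the factor itself (row j is column j of the n × m factor), which makes L @ L.T == Toeplitz(column) easy to check.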
Kernel operators
Kernel matrices as operators — dense (KernelOperator), matrix-free
(ImplicitKernelOperator, rows generated on the fly per matvec), rectangular
cross-kernels, and grid-interpolated (KISS-GP style) variants.
KernelOperator
Bases: AbstractLinearOperator
Kernel matrix operator with efficient first-order autodiff.
Represents the matrix K where K[i, j] = kernel_fn(params, X1[i], X2[j]).
The matvec K @ v is computed via scan (O(N) memory), and a
jax.custom_jvp keeps first-order autodiff efficient without
materializing Jacobians.
Batched inputs are supported: X1 and X2 may carry leading
batch dimensions (*batch, N, D) / (*batch, M, D) (with
matching *batch). In that case mv expects a vector of shape
(*batch, M) and returns (*batch, N); as_matrix() returns
a (*batch, N, M) tensor; in_structure() / out_structure()
report the batched shapes so lineax helpers (linear_solve,
probe-vector allocators) construct compatible inputs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| kernel_fn | Callable | Kernel function … | required |
| X1 | Float[Array, 'N D'] | First set of data points, shape … | required |
| X2 | Float[Array, 'M D'] | Second set of data points, shape … | required |
| params | Any | Pytree of kernel hyperparameters (differentiable). | required |
| tags | object \| frozenset[object] | Optional lineax structural tags. | frozenset() |
Source code in src/gaussx/_operators/_kernel.py
mv(vector: Float[Array, '*batch M']) -> Float[Array, '*batch N']
Compute K @ v via scan with custom JVP support.
Source code in src/gaussx/_operators/_kernel.py
as_matrix() -> Float[Array, '*batch N M']
Materialize the full kernel matrix.
Source code in src/gaussx/_operators/_kernel.py
transpose() -> KernelOperator
Return the transpose operator (X1, X2 swapped, kernel transposed).
Source code in src/gaussx/_operators/_kernel.py
ImplicitKernelOperator
Bases: AbstractLinearOperator
Matrix-free kernel operator: (K + σ² I) v via sequential scan.
Computes the kernel matvec without materializing the N x N kernel
matrix, using O(N) memory instead of O(N^2). Each element of
the output is computed as:
y_i = Σ_j k(x_i, x_j) v_j + σ² v_i
The scan-based implementation is compatible with CG / BBMM solvers that only need matvec access.
Supports two kernel signatures:
- No params (default): k(x, x') -> scalar. Hyperparameters are closed over in the lambda.
- With params: k(params, x, x') -> scalar. Pass a pytree of differentiable hyperparameters via the params argument, and a jax.custom_jvp keeps first-order autodiff efficient.
Batched inputs are supported: X may carry leading batch
dimensions (*batch, N, D). In that case mv expects a vector
of shape (*batch, N) and returns (*batch, N);
as_matrix() returns a (*batch, N, N) tensor;
in_structure() / out_structure() report the batched shapes.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| kernel_fn | Callable | Kernel function (see above for signature). | required |
| X | Float[Array, 'N D'] | Training points, shape … | required |
| noise_var | float | Diagonal noise variance | 0.0 |
| params | Any \| None | Optional pytree of kernel hyperparameters. | None |
Source code in src/gaussx/_operators/_implicit_kernel.py
mv(vector: Float[Array, '*batch N']) -> Float[Array, '*batch N']
Compute (K + sigma^2 I) @ v via scan over data points.
Source code in src/gaussx/_operators/_implicit_kernel.py
as_matrix() -> Float[Array, '*batch N N']
Materialize the full kernel matrix (for debugging/testing).
Source code in src/gaussx/_operators/_implicit_kernel.py
ImplicitCrossKernelOperator
Bases: AbstractLinearOperator
Matrix-free rectangular kernel operator K(X, Z) · v.
Computes the cross-kernel matvec without materializing the full
N x M kernel matrix, using a batched scan that keeps peak memory
at O(batch_size × M) per step.
Forward matvec (mv):
y_i = Σ_j k(x_i, z_j) · v_j
maps an M-vector to an N-vector.
Adjoint / transpose computes Kᵀ u = K(Z, X) u, mapping an
N-vector to an M-vector.
Supports two kernel signatures:
- No params (default): k(x, z) -> scalar.
- With params: k(params, x, z) -> scalar. Pass a pytree of differentiable hyperparameters, and a jax.custom_jvp keeps first-order autodiff efficient.
Batched inputs are supported: X_data and X_inducing may
carry leading batch dimensions (*batch, N, D) /
(*batch, M, D) (with matching *batch). In that case mv
expects a vector of shape (*batch, M) and returns (*batch, N);
the transposed operator follows the symmetric pattern. as_matrix()
returns a (*batch, N, M) tensor; in_structure() /
out_structure() report the batched shapes.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| kernel_fn | Callable | Kernel function (see above for signature). | required |
| X_data | Float[Array, 'N D'] | Data points, shape … | required |
| X_inducing | Float[Array, 'M D'] | Inducing points, shape … | required |
| batch_size | int | Number of rows of … | 1024 |
| params | Any \| None | Optional pytree of kernel hyperparameters. | None |
Source code in src/gaussx/_operators/_implicit_cross_kernel.py
mv(vector: Float[Array, '*batch M']) -> Float[Array, '*batch N']
Compute K(X_data, X_inducing) @ v via batched scan.
Peak memory per step is O(batch_size * M).
Source code in src/gaussx/_operators/_implicit_cross_kernel.py
transpose() -> _TransposedCrossKernelOperator
Return the adjoint operator K^T.
Uses a dedicated adjoint matvec that scans over batches of
X_data and accumulates K_batch^T @ u_batch into an
(M,) result, keeping peak memory at O(batch_size x M).
Source code in src/gaussx/_operators/_implicit_cross_kernel.py
as_matrix() -> Float[Array, '*batch N M']
Materialize the full N x M cross-kernel matrix.
Source code in src/gaussx/_operators/_implicit_cross_kernel.py
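A NumPy sketch of the batched forward and adjoint matvecs. Python loops stand in for the scan, and the vectorized rbf kernel and helper names are assumptions rather than the gaussx kernel_fn convention. Only batch_size rows of K exist at any one time:

```python
import numpy as np


def rbf(x, Z, lengthscale=0.5):
    # k(x, z) for one point x against every row of Z (vectorized for brevity).
    return np.exp(-0.5 * np.sum((Z - x) ** 2, axis=-1) / lengthscale**2)


def cross_kernel_mv(kernel, X, Z, v, batch_size=4):
    # y = K(X, Z) v, building batch_size rows of K at a time: O(batch_size * M) memory.
    out = []
    for s in range(0, X.shape[0], batch_size):
        Kb = np.stack([kernel(x, Z) for x in X[s:s + batch_size]])  # (b, M)
        out.append(Kb @ v)
    return np.concatenate(out)


def cross_kernel_rmv(kernel, X, Z, u, batch_size=4):
    # Kᵀ u, accumulating K_batchᵀ u_batch into an (M,) result.
    acc = np.zeros(Z.shape[0])
    for s in range(0, X.shape[0], batch_size):
        Kb = np.stack([kernel(x, Z) for x in X[s:s + batch_size]])
        acc += Kb.T @ u[s:s + batch_size]
    return acc


rng = np.random.default_rng(10)
X, Z = rng.random((10, 2)), rng.random((3, 2))  # N = 10 data, M = 3 inducing points
v, u = rng.standard_normal(3), rng.standard_normal(10)
K_dense = np.stack([rbf(x, Z) for x in X])       # reference only
```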
InterpolatedOperator
Bases: AbstractLinearOperator
Structured Kernel Interpolation: K ≈ W K_uu Wᵀ.
W is a sparse interpolation matrix with p nonzeros per row
(e.g. cubic interpolation weights). The base operator K_uu
acts on the inducing grid (typically Toeplitz for stationary
kernels).
Total matvec cost: O(n p + m log m) when the base is Toeplitz,
essentially linear in n.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| base_operator | AbstractLinearOperator | The inducing-point kernel | required |
| interp_indices | Int[Array, 'n p'] | Integer indices into the inducing grid, shape … | required |
| interp_values | Float[Array, 'n p'] | Interpolation weights, shape … | required |
Source code in src/gaussx/_operators/_interpolated.py
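A NumPy sketch of the three-stage SKI matvec (scatter Wᵀv, apply K_uu, gather W). It assumes linear interpolation (p = 2) for brevity and uses a dense K_uu in place of a Toeplitz base; ski_mv is hypothetical, though its (n, p) index/weight layout mirrors the interp_indices and interp_values shapes listed in the table:

```python
import numpy as np


def ski_mv(Kuu_mv, idx, vals, m, v):
    """K v ≈ W K_uu Wᵀ v with W stored as (n, p) grid indices and weights."""
    u = np.zeros(m)
    np.add.at(u, idx, vals * v[:, None])  # Wᵀ v: scatter-add onto the grid, O(n p)
    w = Kuu_mv(u)                         # base operator on the m-point grid
    return np.sum(vals * w[idx], axis=1)  # W w: gather and weight, O(n p)


# Linear interpolation (p = 2) onto a regular grid on [0, 1]; cubic would use p = 4.
rng = np.random.default_rng(11)
m, n = 20, 7
x = rng.random(n)
pos = x * (m - 1)
left = np.minimum(np.floor(pos).astype(int), m - 2)
t = pos - left
idx = np.stack([left, left + 1], axis=1)  # (n, 2) grid indices
vals = np.stack([1.0 - t, t], axis=1)     # (n, 2) weights

grid = np.linspace(0.0, 1.0, m)
Kuu = np.exp(-0.5 * (np.subtract.outer(grid, grid) / 0.2) ** 2)  # dense stand-in
W = np.zeros((n, m))
np.add.at(W, (np.arange(n)[:, None], idx), vals)  # dense W, reference only
v = rng.standard_normal(n)
```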
MaskedOperator
Bases: AbstractLinearOperator
Row/column-masked view of a base operator.
Given a base operator A of shape (N, N) and boolean masks,
produces the sub-matrix A[row_mask][:, col_mask].
Matvec is computed without materializing the sub-matrix: zero-pad input to full size, apply base matvec, then extract masked rows.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| base | AbstractLinearOperator | The underlying … | required |
| row_mask | Bool[Array, ' N'] | Boolean mask of length N selecting output rows. | required |
| col_mask | Bool[Array, ' N'] | Boolean mask of length N selecting input columns. | required |
Source code in src/gaussx/_operators/_masked.py
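The zero-pad, apply, extract recipe in a few lines of NumPy (masked_mv is a hypothetical helper, and NumPy boolean indexing is used for brevity):

```python
import numpy as np


def masked_mv(base_mv, row_mask, col_mask, x):
    """A[row_mask][:, col_mask] @ x without forming the sub-matrix."""
    full = np.zeros(col_mask.shape[0])
    full[col_mask] = x              # zero-pad into the selected columns
    return base_mv(full)[row_mask]  # full matvec, then keep the selected rows


rng = np.random.default_rng(12)
N = 6
A = rng.standard_normal((N, N))
row_mask = np.array([True, False, True, True, False, True])
col_mask = np.array([False, True, True, False, True, False])
x = rng.standard_normal(int(col_mask.sum()))
y = masked_mv(lambda z: A @ z, row_mask, col_mask, x)
```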
implicit_cross_kernel(kernel_fn: Callable, X_data: Float[Array, 'N D'], X_inducing: Float[Array, 'M D'], batch_size: int = 1024, *, params: Any | None = None, tags: object | frozenset[object] = frozenset()) -> ImplicitCrossKernelOperator
Create a matrix-free rectangular cross-kernel operator.
Convenience wrapper around ImplicitCrossKernelOperator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| kernel_fn | Callable | Kernel function … | required |
| X_data | Float[Array, 'N D'] | Data points, shape … | required |
| X_inducing | Float[Array, 'M D'] | Inducing points, shape … | required |
| batch_size | int | Rows of … | 1024 |
| params | Any \| None | Optional pytree of kernel hyperparameters. | None |
| tags | object \| frozenset[object] | Lineax structural tags. | frozenset() |
Returns:
| Type | Description |
|---|---|
| ImplicitCrossKernelOperator | An … |
Source code in src/gaussx/_operators/_implicit_cross_kernel.py
Lazy algebra & sampling
Sum / scale / compose operators without materializing, sample \(\varepsilon \sim \mathcal{N}(0, A)\) for the structured families, and solve bordered systems through the capacitance (Schur-complement) form.
CapacitanceSolver
Bases: Module
Solve a base system subject to homogeneous point constraints.
Given a fast base solver B^{-1} and a set of N_b constrained indices,
this enforces u = 0 at those indices via the capacitance-matrix
correction:
- Base solve:
u = B^{-1} f - Sample boundary:
u_b = u[boundary] - Correction:
alpha = C^{-1} u_b - Subtract:
x = u - G^T alpha
where G[k] = B^{-1} e_{b_k} are the Green's functions of the base solver
for unit sources at the constrained indices, and C[k, l] = G[l][b_k] is
the capacitance matrix. C^{-1} and G are precomputed at construction.
The solver operates on flat vectors. Any reshaping between fields and flat vectors, and any masking of the exterior, is the caller's responsibility, which keeps grid/mask concepts out of this class.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| base_solve | Callable[[Float[Array, ' n']], Float[Array, ' n']] | Callable applying the base inverse … | required |
| boundary_indices | Int[Array, ' Nb'] | Flat indices of the constrained degrees of freedom, shape … | required |
| n | int | Length of the flat solution vector. | required |
Attributes:
| Name | Type | Description |
|---|---|---|
| base_solve | Callable[[Float[Array, ' n']], Float[Array, ' n']] | The base inverse callable. |
| boundary_indices | Int[Array, ' Nb'] | The constrained indices. |
| green | Float[Array, 'Nb n'] | Green's functions |
| capacitance_inv | Float[Array, 'Nb Nb'] | Inverse capacitance matrix |
Source code in src/gaussx/_operators/_capacitance.py
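A NumPy sketch of the construction-time precomputation and the four-step solve listed in the docstring (make_capacitance_solver is a hypothetical name, not the gaussx constructor). The checks confirm the two properties the correction guarantees: the solution vanishes at the constrained indices, and the residual of the base system is nonzero only there:

```python
import numpy as np


def make_capacitance_solver(base_solve, boundary, n):
    """Return solve(f) whose output x satisfies x[boundary] == 0."""
    eye = np.eye(n)
    G = np.stack([base_solve(eye[b]) for b in boundary])  # (N_b, n) Green's functions
    C_inv = np.linalg.inv(G[:, boundary].T)               # C[k, l] = G[l][b_k]

    def solve(f):
        u = base_solve(f)              # base solve
        alpha = C_inv @ u[boundary]    # correction from the boundary samples
        return u - G.T @ alpha         # subtract

    return solve


n = 10
# Base operator: a shifted 1-D Laplacian, solved densely here for illustration.
Bop = 2.5 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)
boundary = np.array([2, 7])
solve = make_capacitance_solver(lambda r: np.linalg.solve(Bop, r), boundary, n)
f = np.random.default_rng(13).standard_normal(n)
x = solve(f)
```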
SumOperator(*operators: lx.AbstractLinearOperator, tags: object | frozenset[object] = frozenset()) -> lx.AbstractLinearOperator
Lazy sum (A + B + …) v = A v + B v + ….
Defers materialization so that structured sub-operators keep their
efficient matvec. All operators must have the same input and output
sizes. Returns a (possibly tagged) chain of lineax
AddLinearOperator compositions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| *operators | AbstractLinearOperator | Two or more … | () |
| tags | object \| frozenset[object] | Optional explicit lineax tags for the combined operator. | frozenset() |
Returns:
| Type | Description |
|---|---|
| AbstractLinearOperator | The lazy sum as a lineax operator. |
Source code in src/gaussx/_operators/_lazy_algebra.py
ScaledOperator(operator: lx.AbstractLinearOperator, scalar: float | Float[Array, ''], *, tags: object | frozenset[object] = frozenset()) -> lx.AbstractLinearOperator
Lazy scalar multiply (c A) v = c (A v).
Returns a (possibly tagged) lineax MulLinearOperator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | A … | required |
| scalar | float \| Float[Array, ''] | A scalar multiplier. | required |
| tags | object \| frozenset[object] | Optional explicit lineax tags for the scaled operator. | frozenset() |
Returns:
| Type | Description |
|---|---|
| AbstractLinearOperator | The lazy scaled operator. |
Source code in src/gaussx/_operators/_lazy_algebra.py
ProductOperator(left: lx.AbstractLinearOperator, right: lx.AbstractLinearOperator, *, tags: object | frozenset[object] = frozenset()) -> lx.AbstractLinearOperator
Lazy matmul (A B) v = A (B v).
The inner dimension must match: left.in_size() == right.out_size().
Returns a (possibly tagged) lineax ComposedLinearOperator.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| left | AbstractLinearOperator | The left operator A. | required |
| right | AbstractLinearOperator | The right operator B. | required |
| tags | object \| frozenset[object] | Optional explicit lineax tags for the composed operator. | frozenset() |
Returns:
| Type | Description |
|---|---|
| AbstractLinearOperator | The lazy product as a lineax operator. |
Source code in src/gaussx/_operators/_lazy_algebra.py
kronecker_sum_sample(A_op: lx.AbstractLinearOperator, B_op: lx.AbstractLinearOperator, *, key: jax.Array, num_samples: int = 1) -> Float[Array, 'num_samples n_a n_b']
Sample from 𝒩(0, A ⊕ B) using per-factor eigendecompositions.
Draws zero-mean samples with covariance A ⊕ B by applying the
matrix-free KroneckerSumSqrt to standard normal noise, avoiding
materialization of the full (n_a · n_b, n_a · n_b) covariance.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| A_op | AbstractLinearOperator | Symmetric PSD factor, shape … | required |
| B_op | AbstractLinearOperator | Symmetric PSD factor, shape … | required |
| key | Array | PRNG key for the standard normal draws. | required |
| num_samples | int | Number of samples to draw. | 1 |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'num_samples n_a n_b'] | Samples of shape … |
Raises:
| Type | Description |
|---|---|
| ValueError | If … |
Source code in src/gaussx/_operators/_kronecker_sum.py
sumkronecker_sample(op: SumKronecker, *, key: jax.Array, num_samples: int = 1, lanczos_order: int = 50) -> Float[Array, 'num_samples n']
Sample from 𝒩(0, op) with matrix-free Lanczos square roots.
The square-root action is evaluated by matfree Lanczos against
op.mv. This avoids materialising the dense (n_A n_B) ×
(n_A n_B) covariance and costs lanczos_order SumKronecker
matvecs per sample.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| op | SumKronecker | Positive-semidefinite SumKronecker covariance operator. | required |
| key | Array | JAX PRNG key. | required |
| num_samples | int | Number of independent samples to draw. | 1 |
| lanczos_order | int | Lanczos truncation order. | 50 |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'num_samples n'] | Samples with shape … |
Source code in src/gaussx/_operators/_sum_kronecker.py
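gaussx delegates the square-root action to matfree. The NumPy sketch below is an independent illustration of the same Lanczos idea, not matfree's API (lanczos_sqrt_mv is a hypothetical name). It uses only matvecs and approximates A^{1/2} z ≈ ‖z‖ V √T e_1. Full reorthogonalization is an extra safety choice for this tiny example. With order equal to n the result is exact up to round-off, which the check relies on; the documented default lanczos_order is 50:

```python
import numpy as np


def lanczos_sqrt_mv(mv, z, order):
    """Approximate A^{1/2} z from `order` matvecs with A (Lanczos).

    A^{1/2} z ≈ ||z|| V sqrt(T) e_1, where V holds the Lanczos vectors and
    T is the tridiagonal projection of A onto the Krylov subspace.
    """
    n = z.shape[0]
    V = np.zeros((order, n))
    alpha, beta = np.zeros(order), np.zeros(order)
    z_norm = np.linalg.norm(z)
    V[0] = z / z_norm
    for j in range(order):
        w = mv(V[j])
        alpha[j] = V[j] @ w
        for _ in range(2):  # full reorthogonalization, two Gram-Schmidt passes
            w = w - V[: j + 1].T @ (V[: j + 1] @ w)
        beta[j] = np.linalg.norm(w)
        if j + 1 < order:
            V[j + 1] = w / beta[j]
    T = np.diag(alpha) + np.diag(beta[:-1], 1) + np.diag(beta[:-1], -1)
    evals, S = np.linalg.eigh(T)
    sqrtT_e1 = S @ (np.sqrt(np.clip(evals, 0.0, None)) * S[0])
    return z_norm * (V.T @ sqrtT_e1)


def random_spd(rng, size):
    M = rng.standard_normal((size, size))
    return M @ M.T + size * np.eye(size)


rng = np.random.default_rng(14)
A1, A2, B1, B2 = (random_spd(rng, s) for s in (2, 2, 5, 5))
K = np.kron(A1, B1) + np.kron(A2, B2)  # dense stand-in for op.mv
z = rng.standard_normal(K.shape[0])
sample = lanczos_sqrt_mv(lambda v: K @ v, z, order=10)  # order == n
```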
toeplitz_sample(column: Float[Array, ' n'], *, key: jax.Array, num_samples: int = 1, embedding_factor: int = 2) -> Float[Array, 'num_samples n']
Sample from 𝒩(0, Toeplitz(column)) via FFT circulant embedding.
The Wood–Chan non-negativity check is JIT-friendly (via
eqx.error_if). If the embedding fails for the given
embedding_factor, the error fires at evaluation time — bump
embedding_factor (typically 4 or 8) for difficult
covariances rather than relying on a runtime fallback.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
column
|
Float[Array, ' n']
|
First column of the covariance matrix. |
required |
key
|
Array
|
JAX PRNG key used to draw white noise. |
required |
num_samples
|
int
|
Number of independent samples to draw. |
1
|
embedding_factor
|
int
|
Circulant embedding size as a multiple of |
2
|
Returns:
| Type | Description |
|---|---|
Float[Array, 'num_samples n']
|
Samples with shape |
Source code in src/gaussx/_operators/_toeplitz.py
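A minimal sketch of Wood–Chan circulant embedding in plain jax.numpy; it is not gaussx's code, and circulant_eigs and toeplitz_sample_sketch are hypothetical names. The Toeplitz column is embedded in a symmetric circulant of length m = embedding_factor · n (zero-padded in the middle), whose eigenvalues are the real FFT of its first column. If they are non-negative, scaling complex white noise by √(λ/m) and applying an FFT gives a vector whose real part has the circulant as covariance, so its first n entries have covariance Toeplitz(column). Unlike the documented function, this sketch checks eagerly with a Python `if` rather than eqx.error_if, so it is not jit-compatible as written, and it requires embedding_factor ≥ 2.

```python
import jax
import jax.numpy as jnp


def circulant_eigs(column, embedding_factor=2):
    """Eigenvalues of the symmetric circulant embedding of a Toeplitz column."""
    n = column.shape[0]
    m = embedding_factor * n
    # First column of the circulant: [c_0..c_{n-1}, zeros, c_{n-1}..c_1]
    pad = jnp.zeros(m - 2 * n + 1, dtype=column.dtype)
    c = jnp.concatenate([column, pad, column[:0:-1]])
    return jnp.fft.fft(c).real


def toeplitz_sample_sketch(column, key, num_samples=1, embedding_factor=2):
    """Draw samples of shape (num_samples, n) with covariance Toeplitz(column)."""
    n = column.shape[0]
    lam = circulant_eigs(column, embedding_factor)
    m = lam.shape[0]
    # Eager non-negativity check (relative tolerance for round-off)
    if float(lam.min()) < -1e-6 * float(lam.max()):
        raise ValueError("circulant embedding is not PSD; increase embedding_factor")
    scale = jnp.sqrt(jnp.clip(lam, 0.0) / m)
    k1, k2 = jax.random.split(key)
    z = jax.random.normal(k1, (num_samples, m)) + 1j * jax.random.normal(k2, (num_samples, m))
    y = jnp.fft.fft(scale * z, axis=-1)
    # Real part has the circulant covariance; keep the first n entries
    return y.real[:, :n]
```

The imaginary part of y is a second, independent draw with the same covariance; the sketch discards it for simplicity.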
Structural tags & predicates
Tags mark structure and properties on operators; the is_* predicates are what
the primitives consult when choosing an algorithm. The property tags
(positive_semidefinite_tag, symmetric_tag, the triangular tags, …) are
re-exported from lineax so user code only needs one import.
is_diagonal = lx.is_diagonal
module-attribute
is_symmetric = lx.is_symmetric
module-attribute
is_positive_semidefinite = lx.is_positive_semidefinite
module-attribute
is_negative_semidefinite = lx.is_negative_semidefinite
module-attribute
is_lower_triangular = lx.is_lower_triangular
module-attribute
is_upper_triangular = lx.is_upper_triangular
module-attribute
kronecker_tag = _Tag('kronecker_tag')
module-attribute
Operator is a Kronecker product.
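The Kronecker tag is what makes a per-factor solve available. A minimal two-factor sketch, assuming dense invertible factors; kron_solve is a hypothetical helper. It uses row-major vec, under which (A ⊗ B) vec(X) = vec(A X Bᵀ) (the column-major form is vec(B X Aᵀ)), so two factor-sized solves replace one dense solve of size n_a · n_b.

```python
import jax.numpy as jnp


def kron_solve(A, B, y):
    """Solve (A ⊗ B) x = y with two factor-sized solves (row-major vec)."""
    Y = y.reshape(A.shape[0], B.shape[0])
    # (A ⊗ B) vec(X) = vec(A X B^T), so X = A^{-1} Y B^{-T}
    W = jnp.linalg.solve(A, Y)        # A^{-1} Y
    X = jnp.linalg.solve(B, W.T).T    # (B^{-1} W^T)^T = W B^{-T}
    return X.reshape(-1)
```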
kronecker_sum_tag = _Tag('kronecker_sum_tag')
module-attribute
Operator is a Kronecker sum A ⊕ B = A ⊗ I_b + I_a ⊗ B.
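A Kronecker sum with symmetric factors can be solved in the joint eigenbasis, where it is diagonal with entries λ_i + μ_j. With row-major vec, (A ⊕ B) vec(X) = vec(A X + X B) for symmetric B, so this is a solve of the Sylvester equation A X + X B = Y. A hypothetical sketch (kron_sum_solve is an illustrative name) assuming dense symmetric A and B with every λ_i + μ_j nonzero: rotate the right-hand side, divide elementwise, rotate back.

```python
import jax.numpy as jnp


def kron_sum_solve(A, B, y):
    """Solve (A ⊕ B) x = y for symmetric A, B via their eigendecompositions."""
    lam_a, Q_a = jnp.linalg.eigh(A)
    lam_b, Q_b = jnp.linalg.eigh(B)
    Y = y.reshape(A.shape[0], B.shape[0])
    # Rotate into the joint eigenbasis, where A ⊕ B is diagonal
    Y_t = Q_a.T @ Y @ Q_b
    X_t = Y_t / (lam_a[:, None] + lam_b[None, :])
    # Rotate back and flatten (row-major)
    return (Q_a @ X_t @ Q_b.T).reshape(-1)
```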
block_diagonal_tag = _Tag('block_diagonal_tag')
module-attribute
Operator is block diagonal.
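For a block-diagonal operator, log|det| is a sum over blocks and a solve splits the right-hand side into per-block slices. A hypothetical dense sketch (block_diag_logabsdet and block_diag_solve are illustrative names) that assumes square, invertible blocks:

```python
import itertools

import jax.numpy as jnp


def block_diag_logabsdet(blocks):
    """log|det| of diag(A_1, ..., A_k) as the sum of per-block log|det| values."""
    return sum(jnp.linalg.slogdet(blk)[1] for blk in blocks)


def block_diag_solve(blocks, y):
    """Solve diag(A_1, ..., A_k) x = y block by block."""
    sizes = [blk.shape[0] for blk in blocks]
    # Split points are the running totals of block sizes, minus the last one
    splits = list(itertools.accumulate(sizes))[:-1]
    pieces = jnp.split(y, splits)
    return jnp.concatenate([jnp.linalg.solve(blk, p) for blk, p in zip(blocks, pieces)])
```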
block_tridiagonal_tag = _Tag('block_tridiagonal_tag')
module-attribute
Operator is block tridiagonal.
low_rank_tag = _Tag('low_rank_tag')
module-attribute
Operator has low-rank structure (e.g. L + U D V^T).
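The payoff of low-rank structure is the Woodbury identity: for L + U D Vᵀ with L cheap to invert, U and V of shape (n, k) and invertible D of shape (k, k), (L + U D Vᵀ)⁻¹ = L⁻¹ − L⁻¹ U (D⁻¹ + Vᵀ L⁻¹ U)⁻¹ Vᵀ L⁻¹, so only a k × k capacitance system has to be solved. A minimal sketch assuming L is diagonal and passed as its diagonal vector; woodbury_solve is a hypothetical name, not the LowRankUpdate API.

```python
import jax.numpy as jnp


def woodbury_solve(l_diag, U, D, V, b):
    """Solve (diag(l_diag) + U D V^T) x = b via the Woodbury identity."""
    Linv_b = b / l_diag              # L^{-1} b, O(n)
    Linv_U = U / l_diag[:, None]     # L^{-1} U, O(n k)
    # k x k capacitance matrix D^{-1} + V^T L^{-1} U
    cap = jnp.linalg.inv(D) + V.T @ Linv_U
    correction = Linv_U @ jnp.linalg.solve(cap, V.T @ Linv_b)
    return Linv_b - correction
```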
diagonal_tag = lx.diagonal_tag
module-attribute
symmetric_tag = lx.symmetric_tag
module-attribute
positive_semidefinite_tag = lx.positive_semidefinite_tag
module-attribute
negative_semidefinite_tag = lx.negative_semidefinite_tag
module-attribute
lower_triangular_tag = lx.lower_triangular_tag
module-attribute
upper_triangular_tag = lx.upper_triangular_tag
module-attribute
tridiagonal_tag = lx.tridiagonal_tag
module-attribute
unit_diagonal_tag = lx.unit_diagonal_tag
module-attribute
is_kronecker(operator: lx.AbstractLinearOperator) -> bool
Check whether operator carries the Kronecker tag.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | Linear operator to inspect. | required |
Returns:
| Type | Description |
|---|---|
| bool |  |
Source code in src/gaussx/_tags.py
is_kronecker_sum(operator: lx.AbstractLinearOperator) -> bool
Check whether operator carries the Kronecker sum tag.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | Linear operator to inspect. | required |
Returns:
| Type | Description |
|---|---|
| bool |  |
Source code in src/gaussx/_tags.py
is_block_diagonal(operator: lx.AbstractLinearOperator) -> bool
Check whether operator carries the block-diagonal tag.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | Linear operator to inspect. | required |
Returns:
| Type | Description |
|---|---|
| bool |  |
Source code in src/gaussx/_tags.py
is_block_tridiagonal(operator: lx.AbstractLinearOperator) -> bool
Check whether operator carries the block-tridiagonal tag.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | Linear operator to inspect. | required |
Returns:
| Type | Description |
|---|---|
| bool |  |
Source code in src/gaussx/_tags.py
is_low_rank(operator: lx.AbstractLinearOperator) -> bool
Check whether operator carries the low-rank tag.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | Linear operator to inspect. | required |
Returns:
| Type | Description |
|---|---|
| bool |  |