Primitives
Layer 0: pure functions over lineax.AbstractLinearOperator with structural dispatch. Each primitive inspects the operator (diagonal, Kronecker, block-diagonal, low-rank, block-tridiagonal, …) and routes to the cheapest exact algorithm, falling back to a dense computation (with a DenseFallbackWarning) only when no structured path exists.
Solve, logdet & Cholesky
The workhorses behind Gaussian densities: \(A^{-1}b\), \(\log|A|\), and \(A = LL^\top\).
cholesky returns a lazy lower-triangular operator that preserves structure (the Cholesky factor of a Kronecker product is the Kronecker product of the Cholesky factors); cholesky_logdet turns an existing factor into \(\log|A| = 2\sum_i \log L_{ii}\) with a single O(N) pass over the diagonal, without refactorizing.
Structured linear algebra and Gaussian primitives for JAX.
solve(operator: lx.AbstractLinearOperator, vector: Float[Array, ' n'], *, solver: lx.AbstractLinearSolver | None = None) -> Float[Array, ' n']
Solve A x = b with structural dispatch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | The linear operator A. | required |
| vector | Float[Array, ' n'] | The right-hand side b. | required |
| solver | AbstractLinearSolver \| None | Optional lineax solver override for the fallback path. | None |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' n'] | The solution x. |
Source code in src/gaussx/_primitives/_solve.py
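To make "structural dispatch" concrete, here is a minimal NumPy sketch of the idea: route a solve by operator type and warn before falling back to a dense solve. The classes `Diagonal` and `Kron`, the function `structured_solve`, and the local `DenseFallbackWarning` are hypothetical stand-ins, not the gaussx or lineax API. The Kronecker branch uses the identity \((A \otimes B)\,x = \mathrm{vec}(A X B^\top)\) for the row-major reshape \(X\) of \(x\).

```python
import warnings
from dataclasses import dataclass

import numpy as np


class DenseFallbackWarning(UserWarning):
    """Hypothetical stand-in for the warning emitted on the dense path."""


@dataclass
class Diagonal:
    d: np.ndarray  # diagonal entries, shape (n,)


@dataclass
class Kron:
    A: np.ndarray  # left factor, shape (p, p)
    B: np.ndarray  # right factor, shape (q, q)


def structured_solve(op, b):
    """Solve op @ x = b, picking the algorithm from the operator's type."""
    if isinstance(op, Diagonal):
        return b / op.d  # O(n) elementwise division
    if isinstance(op, Kron):
        # (A kron B) x = b  is  A X B^T = Y  for the row-major reshapes X, Y of x, b
        p, q = op.A.shape[0], op.B.shape[0]
        Y = b.reshape(p, q)
        Z = np.linalg.solve(op.A, Y)      # Z = A^{-1} Y
        X = np.linalg.solve(op.B, Z.T).T  # X = Z B^{-T}
        return X.reshape(-1)
    # No structured path: materialize the matrix and warn
    warnings.warn("no structured solve available; using a dense fallback", DenseFallbackWarning)
    return np.linalg.solve(np.asarray(op), b)


rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3)) + 3 * np.eye(3)
B = rng.standard_normal((4, 4)) + 3 * np.eye(4)
b = rng.standard_normal(12)
x = structured_solve(Kron(A, B), b)  # never forms the 12 x 12 matrix
```

The Kronecker branch costs two small solves, \(O(p^3 + q^3)\) plus reshapes, instead of the \(O(p^3 q^3)\) dense factorization.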
logdet(operator: lx.AbstractLinearOperator) -> Float[Array, '']
Compute log |det(A)| with structural dispatch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | The linear operator A. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar log \|det(A)\|. |
Source code in src/gaussx/_primitives/_logdet.py
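As an example of the kind of identity a structured logdet path can exploit (a sketch of the Kronecker case, not taken from the gaussx source): for \(A\) of size \(p \times p\) and \(B\) of size \(q \times q\), \(\det(A \otimes B) = \det(A)^q \det(B)^p\), so \(\log|A \otimes B| = q\log|A| + p\log|B|\) and only the small factors are ever factored. The helper name `kron_logdet` is hypothetical.

```python
import numpy as np


def kron_logdet(A, B):
    """log|det(A kron B)| from the factors only: q * log|A| + p * log|B|."""
    p, q = A.shape[0], B.shape[0]
    _, ld_a = np.linalg.slogdet(A)  # log|det A|
    _, ld_b = np.linalg.slogdet(B)  # log|det B|
    return q * ld_a + p * ld_b


rng = np.random.default_rng(1)
A = rng.standard_normal((3, 3))
A = A @ A.T + np.eye(3)
B = rng.standard_normal((5, 5))
B = B @ B.T + np.eye(5)
structured = kron_logdet(A, B)  # factors a 3 x 3 and a 5 x 5 matrix, never the 15 x 15 one
```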
cholesky(operator: lx.AbstractLinearOperator) -> lx.AbstractLinearOperator
Compute the Cholesky factor L such that A = L L^T.
Returns a linear operator (not a raw array). For structured operators, the result preserves structure.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | A positive-definite linear operator. | required |
Returns:
| Type | Description |
|---|---|
| AbstractLinearOperator | Lower-triangular operator L. |
Source code in src/gaussx/_primitives/_cholesky.py
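The structure-preservation claim for Kronecker operators can be checked numerically. This dense NumPy sketch (for verification only) relies on two facts: the Kronecker product of lower-triangular factors with positive diagonals is again lower triangular with a positive diagonal, and \((L_A \otimes L_B)(L_A \otimes L_B)^\top = (L_A L_A^\top) \otimes (L_B L_B^\top)\). By uniqueness of the Cholesky factor, \(\mathrm{chol}(A \otimes B) = \mathrm{chol}(A) \otimes \mathrm{chol}(B)\).

```python
import numpy as np

rng = np.random.default_rng(2)


def random_spd(n):
    M = rng.standard_normal((n, n))
    return M @ M.T + n * np.eye(n)


A, B = random_spd(3), random_spd(4)
L_A = np.linalg.cholesky(A)
L_B = np.linalg.cholesky(B)

# Kronecker of the small factors (what a structured path can keep lazily)
L_kron = np.kron(L_A, L_B)
# Dense reference: factor the full 12 x 12 matrix
L_dense = np.linalg.cholesky(np.kron(A, B))

print(np.allclose(L_kron, L_dense))
```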
cholesky_logdet(L: Float[Array, 'N N']) -> Float[Array, '']
Compute log|A| from the Cholesky factor L, where A = L Lᵀ.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| L | Float[Array, 'N N'] | Lower-triangular Cholesky factor, shape `(N, N)`. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar log-determinant. |
Source code in src/gaussx/_primitives/_logdet.py
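The formula follows from \(\det(A) = \det(L)^2\) and \(\det(L) = \prod_i L_{ii}\) for triangular \(L\). A minimal NumPy sketch of the same computation; the name `cholesky_logdet_sketch` is hypothetical and this is not the gaussx source.

```python
import numpy as np


def cholesky_logdet_sketch(L):
    """log|A| for A = L L^T: twice the sum of the logs of L's diagonal."""
    return 2.0 * np.sum(np.log(np.diag(L)))


rng = np.random.default_rng(8)
M = rng.standard_normal((6, 6))
A = M @ M.T + np.eye(6)
L = np.linalg.cholesky(A)
value = cholesky_logdet_sketch(L)  # O(N) once L is available
```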
Trace & diagonal
Exact where structure allows; stochastic (Hutchinson / XTrace probing) for
matrix-free operators. trace_and_diag shares one probe pass between both
estimates.
trace(operator: lx.AbstractLinearOperator, *, stochastic: bool = False, num_probes: int = 20, key: jax.Array | None = None, sampler: SamplerName | None = None, algorithm: Literal['hutchinson', 'xtrace'] = 'hutchinson') -> Float[Array, '']
Compute the trace of an operator.
When stochastic=True, uses a matfree stochastic estimator that only requires matvec access, with no materialization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | A square linear operator. | required |
| stochastic | bool | If `True`, estimate the trace stochastically from matvecs. | False |
| num_probes | int | Number of probe vectors for stochastic mode. | 20 |
| key | Array \| None | PRNG key for stochastic mode. | None |
| sampler | SamplerName \| None | Probe distribution for stochastic mode. | None |
| algorithm | Literal['hutchinson', 'xtrace'] | Stochastic estimator to use: `'hutchinson'` or `'xtrace'`. | 'hutchinson' |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar trace value (exact or estimated). |
Source code in src/gaussx/_primitives/_trace.py
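For intuition about the Hutchinson option: with Rademacher ("signs") probes \(z\), \(\mathbb{E}[zz^\top] = I\), so \(\mathbb{E}[z^\top A z] = \mathrm{tr}(A)\) and averaging over probes gives an unbiased estimate that touches \(A\) only through matvecs. This is a plain NumPy sketch under that assumption; `hutchinson_trace` is a hypothetical name, it does not call matfree, and it does not implement XTrace.

```python
import numpy as np


def hutchinson_trace(matvec, n, num_probes=20, seed=0):
    """Estimate tr(A) as the mean of z^T A z over Rademacher probes z."""
    rng = np.random.default_rng(seed)
    total = 0.0
    for _ in range(num_probes):
        z = rng.choice([-1.0, 1.0], size=n)  # Rademacher probe
        total += z @ matvec(z)               # one matvec per probe
    return total / num_probes


rng = np.random.default_rng(9)
S = rng.standard_normal((10, 10))
S = (S + S.T) / 2
A = 5 * np.eye(10) + 0.1 * S
estimate = hutchinson_trace(lambda v: A @ v, 10, num_probes=2000, seed=1)
```

For a diagonal matrix every probe returns the exact trace, since \(z_i^2 = 1\); the variance comes entirely from off-diagonal entries.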
diag(operator: lx.AbstractLinearOperator, *, stochastic: bool = False, num_probes: int = 20, key: jax.Array | None = None, sampler: SamplerName = 'signs') -> Float[Array, ' n']
Extract the diagonal of an operator as a 1D array.
When stochastic=True, uses Hutchinson's diagonal estimator via matfree; it only requires matvec access, with no materialization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | A linear operator. | required |
| stochastic | bool | If `True`, estimate the diagonal stochastically from matvecs. | False |
| num_probes | int | Number of probe vectors for stochastic mode. | 20 |
| key | Array \| None | PRNG key for stochastic mode. | None |
| sampler | SamplerName | Probe distribution for stochastic mode. | 'signs' |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' n'] | 1D array of diagonal entries (exact or estimated). |
Source code in src/gaussx/_primitives/_diag.py
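Hutchinson's diagonal estimator uses the elementwise product \(z \odot (Az)\): entry \(i\) is \(\sum_j A_{ij} z_i z_j\), whose expectation under Rademacher probes is \(A_{ii}\). A minimal NumPy sketch under that assumption; `hutchinson_diag` is a hypothetical name and this is not the matfree implementation.

```python
import numpy as np


def hutchinson_diag(matvec, n, num_probes=20, seed=0):
    """Estimate diag(A) as the mean of z * (A z) over Rademacher probes z."""
    rng = np.random.default_rng(seed)
    acc = np.zeros(n)
    for _ in range(num_probes):
        z = rng.choice([-1.0, 1.0], size=n)
        acc += z * matvec(z)  # entry i: sum_j A_ij z_i z_j, with mean A_ii
    return acc / num_probes


rng = np.random.default_rng(12)
S = rng.standard_normal((10, 10))
S = (S + S.T) / 2
A = 5 * np.eye(10) + 0.1 * S
diag_est = hutchinson_diag(lambda v: A @ v, 10, num_probes=4000, seed=1)
```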
trace_and_diag(operator: lx.AbstractLinearOperator, *, num_probes: int = 20, key: jax.Array | None = None, sampler: SamplerName = 'signs') -> tuple[Float[Array, ''], Float[Array, ' n']]
Jointly estimate the trace and diagonal from one probe pass.
Halves the matvec budget relative to calling trace(..., stochastic=True) and diag(..., stochastic=True) separately: both statistics are accumulated from the same A @ probe products.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | A square linear operator. | required |
| num_probes | int | Number of probe vectors. | 20 |
| key | Array \| None | PRNG key. If `None`, … | None |
| sampler | SamplerName | Probe distribution. | 'signs' |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, ''], Float[Array, ' n']] | Tuple `(trace, diag)` of the estimated trace and diagonal. |
Source code in src/gaussx/_primitives/_trace.py
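Why one pass is enough under Hutchinson-style probing: \(z^\top A z = \sum_i z_i (Az)_i\), so with shared probes the trace estimate is just the sum of the diagonal estimate. A NumPy sketch of that sharing, assuming Rademacher probes; `trace_and_diag_sketch` is hypothetical and not the gaussx implementation.

```python
import numpy as np


def trace_and_diag_sketch(matvec, n, num_probes=20, seed=0):
    """One probe pass: each product A @ z feeds both estimates."""
    rng = np.random.default_rng(seed)
    diag_acc = np.zeros(n)
    for _ in range(num_probes):
        z = rng.choice([-1.0, 1.0], size=n)
        diag_acc += z * matvec(z)  # the only matvec for this probe
    diag_est = diag_acc / num_probes
    # z^T A z = sum_i z_i (A z)_i, so the trace estimate is the sum of the diagonal estimate
    trace_est = diag_est.sum()
    return trace_est, diag_est


calls = []
d = np.array([1.0, -2.0, 3.5])


def counting_matvec(v):
    calls.append(1)  # record each matvec
    return d * v


t_est, d_est = trace_and_diag_sketch(counting_matvec, 3, num_probes=7)
```

Here `counting_matvec` is called exactly `num_probes` times, versus twice that for two independent estimators.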
Inverse, square root & spectral decompositions
inv and sqrt return lazy operators that route their matvecs through
structured solves / Lanczos; eig, eigvals, and svd take an optional rank
for partial (Krylov) decompositions.
inv(operator: lx.AbstractLinearOperator, *, solver: lx.AbstractLinearSolver | None = None) -> lx.AbstractLinearOperator
Return a lazy inverse operator A^{-1}.
The returned operator computes A^{-1} v via solve(A, v) when mv is called. For structured operators, the inverse preserves structure.
Related to lineax.invert (lineax >= 0.1.1), which wraps lx.linear_solve in a FunctionLinearOperator. The gaussx fallback InverseOperator differs in that its matvec routes through the structured gaussx solve dispatch, and its as_matrix uses a Cholesky path for PSD operators.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | An invertible linear operator. | required |
| solver | AbstractLinearSolver \| None | Optional lineax solver for the fallback InverseOperator. | None |
Returns:
| Type | Description |
|---|---|
| AbstractLinearOperator | An operator representing A^{-1}. |
Source code in src/gaussx/_primitives/_inv.py
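A minimal sketch of the lazy-inverse pattern described above: nothing is inverted at construction, mv performs a solve, and as_matrix uses the Cholesky identity \(A^{-1} = L^{-\top} L^{-1}\) for PSD \(A\). The class `LazyInverse` is hypothetical, and a dense `np.linalg.solve` stands in for the structured gaussx dispatch.

```python
import numpy as np


class LazyInverse:
    """Hypothetical lazy A^{-1}: work happens only in mv or as_matrix."""

    def __init__(self, A):
        self.A = A

    def mv(self, v):
        # A^{-1} v via a solve, never via an explicit inverse
        return np.linalg.solve(self.A, v)

    def as_matrix(self):
        # PSD path: A = L L^T  implies  A^{-1} = L^{-T} L^{-1}
        L = np.linalg.cholesky(self.A)
        L_inv = np.linalg.solve(L, np.eye(L.shape[0]))
        return L_inv.T @ L_inv


rng = np.random.default_rng(7)
M = rng.standard_normal((5, 5))
A = M @ M.T + 5 * np.eye(5)
inv_op = LazyInverse(A)
v = rng.standard_normal(5)
x = inv_op.mv(v)
```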
sqrt(operator: lx.AbstractLinearOperator, *, lanczos_order: int | None = None) -> lx.AbstractLinearOperator
Compute the matrix square root S such that S @ S = A.
Requires A to be positive semi-definite.
When lanczos_order is given, returns a lazy SqrtOperator that computes sqrt(A) @ v via matfree Lanczos without materializing the full square-root matrix.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | A PSD linear operator. | required |
| lanczos_order | int \| None | Order of the Lanczos iteration for the matrix-free sqrt. If `None`, the lazy Lanczos path is not used. | None |
Returns:
| Type | Description |
|---|---|
| AbstractLinearOperator | Operator S satisfying S @ S = A. |
Source code in src/gaussx/_primitives/_sqrt.py
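The Lanczos route to \(\sqrt{A}\,v\) can be sketched as follows: build an orthonormal Krylov basis \(Q\) from \(v\), form the tridiagonal \(T = Q^\top A Q\), and return \(\|v\|\, Q \sqrt{T}\, e_1\). This NumPy version is an illustration under stated assumptions (symmetric PSD \(A\), no Lanczos breakdown, full reorthogonalization for stability); `lanczos_sqrt_mv` is a hypothetical name, not matfree's API.

```python
import numpy as np


def lanczos_sqrt_mv(matvec, v, order):
    """Approximate sqrt(A) @ v for symmetric PSD A from `order` Lanczos steps.

    Builds an orthonormal Krylov basis Q and the tridiagonal T = Q^T A Q,
    then returns ||v|| * Q @ sqrt(T) @ e1. Assumes no Lanczos breakdown.
    """
    n = v.shape[0]
    order = min(order, n)  # the Krylov dimension cannot exceed n
    beta0 = np.linalg.norm(v)
    Q = np.zeros((n, order))
    alpha = np.zeros(order)
    beta = np.zeros(order - 1)
    q = v / beta0
    for j in range(order):
        Q[:, j] = q
        w = matvec(q)
        alpha[j] = q @ w
        for _ in range(2):  # full reorthogonalization, applied twice for stability
            w = w - Q[:, : j + 1] @ (Q[:, : j + 1].T @ w)
        if j < order - 1:
            beta[j] = np.linalg.norm(w)
            q = w / beta[j]
    T = np.diag(alpha) + np.diag(beta, 1) + np.diag(beta, -1)
    evals, evecs = np.linalg.eigh(T)
    # f(T) e1 = V f(Lambda) V^T e1, and V^T e1 is the first row of V
    f_T_e1 = evecs @ (np.sqrt(np.clip(evals, 0.0, None)) * evecs[0, :])
    return beta0 * (Q @ f_T_e1)


rng = np.random.default_rng(5)
M = rng.standard_normal((8, 8))
A = M @ M.T + 8 * np.eye(8)
v = rng.standard_normal(8)
approx = lanczos_sqrt_mv(lambda x: A @ x, v, order=8)
```

With `order` equal to the operator size the basis is complete and the result is exact up to rounding; smaller orders trade accuracy for fewer matvecs.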
eig(operator: lx.AbstractLinearOperator, *, rank: int | None = None, key: jax.Array | None = None) -> tuple[Array, Array]
Compute eigenvalues and eigenvectors.
For symmetric operators, returns real eigenvalues via eigh.
When rank is given, computes a partial eigendecomposition via matfree Lanczos (symmetric operators) with no matrix materialization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | A square linear operator. | required |
| rank | int \| None | Number of eigenvalues to compute. If `None`, the full decomposition is computed. | None |
| key | Array \| None | PRNG key for the initial random vector when using partial eig. If `None`, … | None |
Returns:
| Type | Description |
|---|---|
| Array | First element of the returned tuple: the eigenvalues. |
| Array | Second element of the returned tuple: the eigenvectors. |
Source code in src/gaussx/_primitives/_eig.py
eigvals(operator: lx.AbstractLinearOperator, *, rank: int | None = None, key: jax.Array | None = None) -> Array
Compute eigenvalues only.
When rank is given, returns the top-k eigenvalues via matfree Lanczos without matrix materialization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | A square linear operator. | required |
| rank | int \| None | Number of eigenvalues to compute. | None |
| key | Array \| None | PRNG key for partial eigendecomposition. | None |
Returns:
| Type | Description |
|---|---|
| Array | Array of eigenvalues (the top `rank` eigenvalues when `rank` is given). |
Source code in src/gaussx/_primitives/_eig.py
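The partial (rank-given) path relies on the fact that the eigenvalues of the small Lanczos tridiagonal \(T\) (the Ritz values) converge to the extreme eigenvalues of a symmetric \(A\) first. A NumPy sketch under that assumption, with a seeded random start vector playing the role of `key`; `lanczos_eigvals` is a hypothetical name, not the gaussx or matfree implementation.

```python
import numpy as np


def lanczos_eigvals(matvec, n, rank, seed=0):
    """Ritz values of a `rank`-step Lanczos run, in descending order.

    The extreme Ritz values converge first, so the leading entries approximate
    the top eigenvalues of a symmetric A without ever forming A.
    """
    rng = np.random.default_rng(seed)  # plays the role of the PRNG key
    rank = min(rank, n)
    q = rng.standard_normal(n)
    q /= np.linalg.norm(q)  # random unit start vector
    Q = np.zeros((n, rank))
    T = np.zeros((rank, rank))
    for j in range(rank):
        Q[:, j] = q
        w = matvec(q)
        T[j, j] = q @ w
        for _ in range(2):  # full reorthogonalization, applied twice
            w = w - Q[:, : j + 1] @ (Q[:, : j + 1].T @ w)
        if j < rank - 1:
            b = np.linalg.norm(w)
            T[j, j + 1] = T[j + 1, j] = b
            q = w / b
    return np.sort(np.linalg.eigvalsh(T))[::-1]


d = np.concatenate([[1000.0, 500.0], np.arange(1.0, 19.0)])  # spectrum with a clear top gap
top = lanczos_eigvals(lambda x: d * x, 20, rank=8)
```

By eigenvalue interlacing the largest Ritz value never exceeds the largest eigenvalue, and with a large spectral gap it converges within a few steps.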
svd(operator: lx.AbstractLinearOperator, *, rank: int | None = None, key: jax.Array | None = None) -> tuple[Float[Array, 'm k'], Float[Array, ' k'], Float[Array, 'k n']]
Compute the singular value decomposition A = U diag(s) V^T.
When rank is given, computes a partial (truncated) SVD via matfree's Golub-Kahan bidiagonalization, with no matrix materialization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | A linear operator. | required |
| rank | int \| None | Number of singular values to compute. If `None`, the full SVD is computed. | None |
| key | Array \| None | PRNG key for the initial random vector when using partial SVD. If `None`, … | None |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'm k'] | U, shape `(m, k)`. |
| Float[Array, ' k'] | s, shape `(k,)`. |
| Float[Array, 'k n'] | V^T, shape `(k, n)`. |
Source code in src/gaussx/_primitives/_svd.py
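The return contract of a rank-\(k\) SVD can be illustrated with a dense stand-in. This sketch does not implement Golub-Kahan bidiagonalization; it only shows the shapes \(U \in \mathbb{R}^{m\times k}\), \(s \in \mathbb{R}^{k}\), \(V^\top \in \mathbb{R}^{k\times n}\) and the Eckart-Young property that the truncation is the best rank-\(k\) approximation in Frobenius norm. The helper `truncated_svd_dense` is hypothetical.

```python
import numpy as np


def truncated_svd_dense(A, rank):
    """Dense stand-in for a rank-k SVD: U (m, k), s (k,), Vt (k, n) with A ~ U diag(s) Vt."""
    U, s, Vt = np.linalg.svd(A, full_matrices=False)
    return U[:, :rank], s[:rank], Vt[:rank, :]


rng = np.random.default_rng(3)
A = rng.standard_normal((6, 4))
U, s, Vt = truncated_svd_dense(A, rank=2)
A2 = U @ np.diag(s) @ Vt  # best rank-2 approximation (Eckart-Young)
```

The Frobenius error of `A2` equals the root-sum-square of the discarded singular values.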
frobenius_norm(operator: lx.AbstractLinearOperator, *, stochastic: bool = False, num_probes: int = 20, key: jax.Array | None = None, sampler: SamplerName = 'signs') -> Float[Array, '']
Compute the Frobenius norm ||A||_F with structural dispatch.
Structured operators avoid materialization:
- Diagonal: vector 2-norm of the diagonal.
- BlockDiag: square root of the sum of squared block norms.
- Kronecker: ||A ⊗ B||_F = ||A||_F · ||B||_F.
- Scaled, negated, and tagged operators delegate to the wrapped operator.
When stochastic=True, estimates ||A||_F^2 = tr(A^T A) via matfree's Hutchinson estimator, using matvec access only.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | A linear operator. | required |
| stochastic | bool | If `True`, estimate the norm stochastically from matvecs. | False |
| num_probes | int | Number of probe vectors for stochastic mode. | 20 |
| key | Array \| None | PRNG key for stochastic mode. | None |
| sampler | SamplerName | Probe distribution for stochastic mode. | 'signs' |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar Frobenius norm (exact or estimated). |
Source code in src/gaussx/_primitives/_frobenius.py
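The structural rules listed above, and the stochastic identity \(\|A\|_F^2 = \mathrm{tr}(A^\top A) = \mathbb{E}\|Az\|^2\) for Rademacher \(z\), can be checked with dense NumPy. This is a verification sketch only; it does not reproduce the gaussx dispatch.

```python
import numpy as np

rng = np.random.default_rng(4)
A = rng.standard_normal((3, 3))
B = rng.standard_normal((4, 4))
blocks = [rng.standard_normal((2, 2)), rng.standard_normal((3, 3))]
d = rng.standard_normal(5)

fro = np.linalg.norm  # Frobenius norm for 2-D arrays (default ord)

# Kronecker: ||A kron B||_F = ||A||_F * ||B||_F
kron_rule = fro(A) * fro(B)
# Block-diagonal: square root of the sum of squared block norms
blockdiag_rule = np.sqrt(sum(fro(blk) ** 2 for blk in blocks))
# Diagonal: 2-norm of the diagonal vector
diag_rule = np.linalg.norm(d)

# Stochastic: average ||K z||^2 over Rademacher probes z (one matvec each)
K = np.kron(A, B)
Z = rng.choice([-1.0, 1.0], size=(K.shape[1], 4000))
fro_sq_est = np.mean(np.sum((K @ Z) ** 2, axis=0))
```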
submatrix(operator: lx.AbstractLinearOperator, row_idx: Int[Array, ' R'], col_idx: Int[Array, ' C']) -> Float[Array, 'R C']
Extract A[row_idx, col_idx] without forming the full matrix.
For structured operators, exploits the structure to avoid materializing the full (N, N) matrix when only a sub-block is needed (e.g., the conditional Gaussian extracts Sigma_AA, Sigma_AB, Sigma_BB from a joint covariance).
Currently dispatches on:
- lineax.DiagonalLinearOperator
- lineax.TaggedLinearOperator (delegates to the wrapped operator)
- gaussx.BlockDiag
- gaussx.Kronecker
Falls back to operator.as_matrix()[ix_(row_idx, col_idx)] for other operators.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | Linear operator A, shape `(N, N)`. | required |
| row_idx | Int[Array, ' R'] | Row indices, shape `(R,)`. | required |
| col_idx | Int[Array, ' C'] | Column indices, shape `(C,)`. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'R C'] | Dense sub-matrix of shape `(R, C)`. |
Source code in src/gaussx/_primitives/_submatrix.py
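For the Kronecker case, each requested entry can be read off the factors: entry \((r, c)\) of \(A \otimes B\) is \(A_{\lfloor r/q_r \rfloor, \lfloor c/q_c \rfloor}\, B_{r \bmod q_r,\, c \bmod q_c}\), where \(B\) has shape \((q_r, q_c)\). A NumPy sketch of that idea; `kron_submatrix` is a hypothetical helper, not the gaussx code path.

```python
import numpy as np


def kron_submatrix(A, B, row_idx, col_idx):
    """(A kron B)[ix_(row_idx, col_idx)] computed from the factors only.

    Cost is O(R * C), independent of the size of A kron B.
    """
    qr, qc = B.shape
    r = np.asarray(row_idx)[:, None]  # shape (R, 1)
    c = np.asarray(col_idx)[None, :]  # shape (1, C)
    return A[r // qr, c // qc] * B[r % qr, c % qc]  # broadcasts to (R, C)


rng = np.random.default_rng(13)
A = rng.standard_normal((3, 3))
B = rng.standard_normal((4, 4))
row_idx = np.array([0, 5, 11])
col_idx = np.array([1, 2, 7, 10])
block = kron_submatrix(A, B, row_idx, col_idx)
```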
Root decompositions
Tall-factor approximations \(RR^\top \approx A\) (and \(R^- (R^-)^\top \approx A^{-1}\)) via Cholesky, pivoted Cholesky, Lanczos, or truncated SVD — the building block for low-rank posterior sampling and BBMM-style solvers.
RootDecomposition
Bases: Module
Tall factor R with the semantics R Rᵀ ≈ A.
Attributes:
| Name | Type | Description |
|---|---|---|
| root | Float[Array, 'N k'] | Tall factor with shape `(N, k)`. |
Source code in src/gaussx/_primitives/_root.py
rank: int (property)
Number of retained root directions.
root_decomposition(operator: lx.AbstractLinearOperator, rank: int = 50, method: RootMethod = 'lanczos', key: jax.Array | None = None) -> RootDecomposition
Compute a tall factor R such that R Rᵀ ≈ A.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | Square symmetric positive-definite operator. | required |
| rank | int | Number of retained directions. Ignored by … | 50 |
| method | RootMethod | Decomposition method (Cholesky, pivoted Cholesky, Lanczos, or truncated SVD). | 'lanczos' |
| key | Array \| None | PRNG key for random-start methods. | None |
Returns:
| Type | Description |
|---|---|
| RootDecomposition | A `RootDecomposition` holding the tall factor R. |
Source code in src/gaussx/_primitives/_root.py
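Pivoted Cholesky, one of the methods listed for this section, builds \(R\) one column at a time by pivoting on the largest remaining diagonal of the residual \(A - RR^\top\), and it stops early once that residual is numerically zero. A dense NumPy sketch under those assumptions; `pivoted_cholesky_root` and its `tol` parameter are hypothetical, not the gaussx implementation.

```python
import numpy as np


def pivoted_cholesky_root(A, rank, tol=1e-12):
    """Greedy pivoted Cholesky: R of shape (N, k), k <= rank, with R @ R.T ~ A."""
    n = A.shape[0]
    rank = min(rank, n)
    R = np.zeros((n, rank))
    resid_diag = np.diag(A).astype(float).copy()  # diagonal of A - R R^T
    for k in range(rank):
        i = int(np.argmax(resid_diag))
        if resid_diag[i] <= tol:  # residual is numerically zero: stop early
            return R[:, :k]
        # column i of the residual, scaled by the square root of the pivot
        col = A[:, i] - R[:, :k] @ R[i, :k]
        R[:, k] = col / np.sqrt(resid_diag[i])
        resid_diag -= R[:, k] ** 2
    return R


rng = np.random.default_rng(6)
G = rng.standard_normal((10, 3))
A_low = G @ G.T  # rank-3 PSD matrix
R = pivoted_cholesky_root(A_low, rank=5)
```

For an exactly low-rank matrix the loop terminates after as many columns as the rank, which is what makes the method attractive for low-rank posterior approximations.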
root_inv_decomposition(operator: lx.AbstractLinearOperator, rank: int = 50, method: RootMethod = 'lanczos', key: jax.Array | None = None) -> RootDecomposition
Compute a tall factor R⁻ such that R⁻ (R⁻)ᵀ ≈ A⁻¹.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | Square symmetric positive-definite operator. | required |
| rank | int | Number of retained directions. Ignored by … | 50 |
| method | RootMethod | Decomposition method (Cholesky, pivoted Cholesky, Lanczos, or truncated SVD). | 'lanczos' |
| key | Array \| None | PRNG key for random-start methods. | None |
Returns:
| Type | Description |
|---|---|
| RootDecomposition | A `RootDecomposition` holding the tall factor R⁻. |
Source code in src/gaussx/_primitives/_root.py
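The simplest exact inverse root comes from a Cholesky factor: with \(A = LL^\top\), take \(R^- = L^{-\top}\), since \(L^{-\top}L^{-1} = (LL^\top)^{-1} = A^{-1}\). A NumPy sketch of that case only (the low-rank methods are not shown); `cholesky_root_inv` is a hypothetical name.

```python
import numpy as np


def cholesky_root_inv(A):
    """Return R_inv = L^{-T} for A = L L^T, so R_inv @ R_inv.T = A^{-1}."""
    L = np.linalg.cholesky(A)
    L_inv = np.linalg.solve(L, np.eye(A.shape[0]))  # generic solve; a triangular solve would be cheaper
    return L_inv.T


rng = np.random.default_rng(10)
M = rng.standard_normal((5, 5))
A = M @ M.T + 5 * np.eye(5)
R_inv = cholesky_root_inv(A)
# R_inv @ eps with eps ~ N(0, I) has covariance R_inv @ R_inv.T = A^{-1}
sample = R_inv @ rng.standard_normal(5)
```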
Support types
SumKroneckerSqrt
Bases: SqrtOperator
Lazy Lanczos square-root operator for SumKronecker covariances.
Specialization of SqrtOperator that narrows original to a SumKronecker operator. mv computes sqrt(A) v via matfree's Lanczos matrix-function product without materializing the full square root.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| original | SumKronecker | The `SumKronecker` operator A whose square root is represented. | required |
| lanczos_order | int | Number of Lanczos iterations; clamped to the operator size. | _DEFAULT_LANCZOS_ORDER |
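A Lanczos matrix-function product touches the operator only through matvecs, and for a sum of Kronecker products each matvec can be computed from the small factors via \((A \otimes B)\,v = \mathrm{vec}(A X B^\top)\) with \(X\) the row-major reshape of \(v\). A NumPy sketch; representing the operator as a list of `(A_i, B_i)` pairs and the name `sum_kron_mv` are assumptions for illustration, not the SumKronecker internals.

```python
import numpy as np


def sum_kron_mv(terms, v):
    """(sum_i A_i kron B_i) @ v without forming any Kronecker product.

    Uses (A kron B) v = (A @ X @ B.T).reshape(-1) with X = v.reshape(p, q), row-major.
    """
    p, q = terms[0][0].shape[0], terms[0][1].shape[0]
    X = v.reshape(p, q)
    return sum(A @ X @ B.T for A, B in terms).reshape(-1)


rng = np.random.default_rng(11)
terms = [(rng.standard_normal((3, 3)), rng.standard_normal((4, 4))) for _ in range(2)]
v = rng.standard_normal(12)
y = sum_kron_mv(terms, v)
```

Each term costs \(O(pq(p+q))\) instead of the \(O(p^2q^2)\) of a dense matvec, which is what keeps every Lanczos iteration cheap.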