
Primitives

Layer 0: pure functions over lineax.AbstractLinearOperator with structural dispatch — each primitive inspects the operator (diagonal, Kronecker, block-diagonal, low-rank, block-tridiagonal, …) and routes to the cheapest exact algorithm, falling back to a dense computation (with a DenseFallbackWarning) only when no structured path exists.

Solve, logdet & Cholesky

The workhorses behind Gaussian densities: \(A^{-1}b\), \(\log|A|\), and \(A = LL^\top\). cholesky returns a lazy lower-triangular operator that preserves structure (the Cholesky factor of a Kronecker product is the Kronecker product of the factors' Cholesky factors); cholesky_logdet turns an existing dense factor into \(\log|A| = 2\sum_i \log L_{ii}\) in \(O(N)\) time, with no further factorization.
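The structure-preservation claim for Kronecker products can be checked numerically. The sketch below is plain JAX and does not call gaussx: it factors the two small blocks separately and confirms that their Kronecker product is the lower-triangular Cholesky factor of \(A \otimes B\). The helper `random_spd` is a hypothetical test fixture.

```python
import jax
import jax.numpy as jnp


def random_spd(key, n):
    """Hypothetical test fixture: a well-conditioned SPD matrix."""
    M = jax.random.normal(key, (n, n))
    return M @ M.T + n * jnp.eye(n)


kA, kB = jax.random.split(jax.random.PRNGKey(0))
A = random_spd(kA, 3)
B = random_spd(kB, 4)

# Factor the small blocks: O(n^3 + m^3) work instead of O((n m)^3).
L_A = jnp.linalg.cholesky(A)
L_B = jnp.linalg.cholesky(B)

# kron(L_A, L_B) is lower triangular with a positive diagonal, and
# kron(L_A, L_B) @ kron(L_A, L_B).T == kron(L_A @ L_A.T, L_B @ L_B.T) == kron(A, B),
# so by uniqueness of the Cholesky factor it equals cholesky(kron(A, B)).
L_struct = jnp.kron(L_A, L_B)
L_dense = jnp.linalg.cholesky(jnp.kron(A, B))  # dense, for checking only
```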

Structured linear algebra and Gaussian primitives for JAX.

solve(operator: lx.AbstractLinearOperator, vector: Float[Array, ' n'], *, solver: lx.AbstractLinearSolver | None = None) -> Float[Array, ' n']

Solve A x = b with structural dispatch.

Parameters:

  • operator (AbstractLinearOperator, required): The linear operator A.
  • vector (Float[Array, ' n'], required): The right-hand side b.
  • solver (AbstractLinearSolver | None, default None): Optional lineax solver override for the fallback path.

Returns:

  • Float[Array, ' n']: The solution x.

Source code in src/gaussx/_primitives/_solve.py
def solve(
    operator: lx.AbstractLinearOperator,
    vector: Float[Array, " n"],
    *,
    solver: lx.AbstractLinearSolver | None = None,
) -> Float[Array, " n"]:
    """Solve ``A x = b`` with structural dispatch.

    Args:
        operator: The linear operator A.
        vector: The right-hand side b.
        solver: Optional lineax solver override for the fallback path.

    Returns:
        The solution x.
    """
    if isinstance(operator, lx.IdentityLinearOperator):
        return vector
    if isinstance(operator, lx.DiagonalLinearOperator):
        return _solve_diagonal(operator, vector)
    if isinstance(operator, BlockDiag):
        return _solve_block_diag(operator, vector, solver)
    if isinstance(operator, Kronecker):
        return _solve_kronecker(operator, vector, solver)
    if isinstance(operator, LowRankUpdate):
        return _solve_low_rank(operator, vector, solver)
    if isinstance(operator, KroneckerSum):
        return _solve_kronecker_sum(operator, vector)
    if isinstance(operator, KroneckerSumSqrt):
        return operator.solve(vector)
    if isinstance(operator, BlockTriDiag):
        return _solve_block_tridiag(operator, vector)
    if isinstance(operator, LowerBlockTriDiag):
        return _solve_lower_block_tridiag(operator, vector)
    if isinstance(operator, UpperBlockTriDiag):
        return _solve_upper_block_tridiag(operator, vector)
    if isinstance(operator, lx.TaggedLinearOperator):
        return _solve_tagged(operator, vector, solver)
    if isinstance(operator, lx.MulLinearOperator):
        return solve(operator.operator, vector, solver=solver) / operator.scalar
    if isinstance(operator, lx.DivLinearOperator):
        return solve(operator.operator, vector, solver=solver) * operator.scalar
    if isinstance(operator, lx.NegLinearOperator):
        return -solve(operator.operator, vector, solver=solver)
    if isinstance(operator, lx.ComposedLinearOperator):
        # (A B) x = b  =>  x = B^{-1} (A^{-1} b)
        y = solve(operator.operator1, vector, solver=solver)
        return solve(operator.operator2, y, solver=solver)
    return _solve_fallback(operator, vector, solver)
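To see why a Kronecker branch can avoid the \(O((nm)^3)\) dense solve, here is a minimal plain-JAX sketch of one standard approach, assuming row-major vectorisation: \((A \otimes B)\,\mathrm{vec}(X) = \mathrm{vec}(A X B^\top)\), so the solve reduces to two small solves. The helper `kron_solve` is hypothetical and is not necessarily how `_solve_kronecker` is implemented.

```python
import jax
import jax.numpy as jnp


def kron_solve(A, B, b):
    """Hypothetical helper: solve kron(A, B) x = b without forming kron(A, B).

    Row-major identity: kron(A, B) @ X.reshape(-1) == (A @ X @ B.T).reshape(-1),
    so x = (A^-1 @ Y @ B^-T).reshape(-1) with Y = b.reshape(n, m).
    """
    n, m = A.shape[0], B.shape[0]
    Y = b.reshape(n, m)
    Z = jnp.linalg.solve(A, Y)        # A^-1 Y, shape (n, m)
    X = jnp.linalg.solve(B, Z.T).T    # (A^-1 Y) B^-T
    return X.reshape(-1)


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
A = 3.0 * jnp.eye(3) + 0.5 * jax.random.normal(k1, (3, 3))
B = 3.0 * jnp.eye(4) + 0.5 * jax.random.normal(k2, (4, 4))
b = jax.random.normal(k3, (12,))

x_struct = kron_solve(A, B, b)
x_dense = jnp.linalg.solve(jnp.kron(A, B), b)  # dense, for checking only
```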

logdet(operator: lx.AbstractLinearOperator) -> Float[Array, '']

Compute log |det(A)| with structural dispatch.

Parameters:

  • operator (AbstractLinearOperator, required): The linear operator A.

Returns:

  • Float[Array, '']: Scalar log |det(A)|.

Source code in src/gaussx/_primitives/_logdet.py
def logdet(operator: lx.AbstractLinearOperator) -> Float[Array, ""]:
    """Compute log |det(A)| with structural dispatch.

    Args:
        operator: The linear operator A.

    Returns:
        Scalar log |det(A)|.
    """
    if isinstance(operator, lx.IdentityLinearOperator):
        return jnp.array(0.0)
    if isinstance(operator, lx.DiagonalLinearOperator):
        return _logdet_diagonal(operator)
    if isinstance(operator, BlockDiag):
        return _logdet_block_diag(operator)
    if isinstance(operator, Kronecker):
        return _logdet_kronecker(operator)
    if isinstance(operator, LowRankUpdate):
        return _logdet_low_rank(operator)
    if isinstance(operator, KroneckerSum):
        return _logdet_kronecker_sum(operator)
    if isinstance(operator, BlockTriDiag):
        return _logdet_block_tridiag(operator)
    if isinstance(operator, LowerBlockTriDiag | UpperBlockTriDiag):
        return _logdet_block_bidiagonal(operator)
    if isinstance(operator, lx.TaggedLinearOperator):
        return logdet(operator.operator)
    if isinstance(operator, lx.MulLinearOperator):
        n = operator.out_size()
        return n * jnp.log(jnp.abs(operator.scalar)) + logdet(operator.operator)
    if isinstance(operator, lx.DivLinearOperator):
        n = operator.out_size()
        return logdet(operator.operator) - n * jnp.log(jnp.abs(operator.scalar))
    if isinstance(operator, lx.NegLinearOperator):
        # log|det(-A)| = log|det(A)| since |(-1)^n| = 1.
        return logdet(operator.operator)
    if isinstance(operator, lx.ComposedLinearOperator) and (
        operator.operator1.in_size() == operator.operator1.out_size()
        and operator.operator2.in_size() == operator.operator2.out_size()
    ):
        return logdet(operator.operator1) + logdet(operator.operator2)
    return _logdet_dense(operator)
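The Kronecker and scalar branches rest on two determinant identities: for \(A \in \mathbb{R}^{n\times n}\) and \(B \in \mathbb{R}^{m\times m}\), \(\log|\det(A\otimes B)| = m\log|\det A| + n\log|\det B|\), and \(\log|\det(cA)| = n\log|c| + \log|\det A|\). A quick plain-JAX check of both; `kron_logdet` is a hypothetical helper, not the gaussx internal.

```python
import jax
import jax.numpy as jnp


def logabsdet(M):
    return jnp.linalg.slogdet(M)[1]


def kron_logdet(A, B):
    """Hypothetical helper: det(kron(A, B)) = det(A)^m * det(B)^n."""
    n, m = A.shape[0], B.shape[0]
    return m * logabsdet(A) + n * logabsdet(B)


kA, kB = jax.random.split(jax.random.PRNGKey(2))
A = 2.0 * jnp.eye(3) + 0.3 * jax.random.normal(kA, (3, 3))
B = 2.0 * jnp.eye(4) + 0.3 * jax.random.normal(kB, (4, 4))

structured = kron_logdet(A, B)
dense = logabsdet(jnp.kron(A, B))  # dense, for checking only

# Scalar rule used by the MulLinearOperator branch.
c = -1.5
scaled_rule = A.shape[0] * jnp.log(jnp.abs(c)) + logabsdet(A)
scaled_dense = logabsdet(c * A)
```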

cholesky(operator: lx.AbstractLinearOperator) -> lx.AbstractLinearOperator

Compute Cholesky factor L such that A = L L^T.

Returns a linear operator (not a raw array). For structured operators, the result preserves structure.

Parameters:

  • operator (AbstractLinearOperator, required): A positive-definite linear operator.

Returns:

  • AbstractLinearOperator: Lower-triangular operator L.

Source code in src/gaussx/_primitives/_cholesky.py
def cholesky(
    operator: lx.AbstractLinearOperator,
) -> lx.AbstractLinearOperator:
    """Compute Cholesky factor L such that A = L L^T.

    Returns a linear operator (not a raw array). For structured
    operators, the result preserves structure.

    Args:
        operator: A positive-definite linear operator.

    Returns:
        Lower-triangular operator L.
    """
    if isinstance(operator, lx.IdentityLinearOperator):
        return operator
    if isinstance(operator, lx.DiagonalLinearOperator):
        return _cholesky_diagonal(operator)
    if isinstance(operator, BlockDiag):
        return _cholesky_block_diag(operator)
    if isinstance(operator, Kronecker):
        return _cholesky_kronecker(operator)
    if isinstance(operator, BlockTriDiag):
        return _cholesky_block_tridiag(operator)
    if isinstance(operator, SumKronecker):
        return _cholesky_sum_kronecker(operator)
    if isinstance(operator, lx.TaggedLinearOperator):
        return cholesky(operator.operator)
    return _cholesky_dense(operator)

cholesky_logdet(L: Float[Array, 'N N']) -> Float[Array, '']

Compute log|A| from Cholesky factor L where A = L Lᵀ.

Parameters:

  • L (Float[Array, 'N N'], required): Lower-triangular Cholesky factor, shape (N, N).

Returns:

  • Float[Array, '']: Scalar log-determinant.

Source code in src/gaussx/_primitives/_logdet.py
def cholesky_logdet(L: Float[Array, "N N"]) -> Float[Array, ""]:
    """Compute log|A| from Cholesky factor L where A = L Lᵀ.

    Args:
        L: Lower-triangular Cholesky factor, shape ``(N, N)``.

    Returns:
        Scalar log-determinant.
    """
    return 2.0 * jnp.sum(jnp.log(jnp.diag(L)))
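cholesky_logdet is cheap because a Cholesky factor has a positive diagonal and \(\det A = \det(L)^2 = \prod_i L_{ii}^2\). A small plain-JAX check of the one-line body above against `jnp.linalg.slogdet`, on a dense factor:

```python
import jax
import jax.numpy as jnp


def cholesky_logdet(L):
    # log|A| = log|L L^T| = 2 * sum_i log L_ii  (L_ii > 0 for a Cholesky factor)
    return 2.0 * jnp.sum(jnp.log(jnp.diag(L)))


M = jax.random.normal(jax.random.PRNGKey(1), (5, 5))
A = M @ M.T + 5.0 * jnp.eye(5)
L = jnp.linalg.cholesky(A)

sign, logabsdet = jnp.linalg.slogdet(A)  # O(N^3) reference value
from_factor = cholesky_logdet(L)         # O(N) given the factor
```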

Trace & diagonal

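Exact where structure allows. Otherwise the operator is materialized, unless stochastic=True selects Hutchinson or XTrace probing, which needs only matvecs and so never materializes a matrix-free operator. trace_and_diag shares one probe pass between both estimates.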
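For intuition about the stochastic path, here is a minimal Hutchinson estimator in plain JAX, assuming Rademacher ("signs") probes. It is a sketch, not matfree's implementation, and `hutchinson_trace` is a hypothetical name. With sign probes the estimate is exact for diagonal matrices, because \(z_i^2 = 1\).

```python
import jax
import jax.numpy as jnp


def hutchinson_trace(matvec, n, num_probes, key):
    """Hypothetical helper: tr(A) ~ mean_k z_k^T A z_k with Rademacher z_k.

    Only `matvec` is called, so A is never materialized.
    """
    # Rademacher probes: entries are +1 or -1 with equal probability.
    Z = jax.random.rademacher(key, (num_probes, n)).astype(jnp.float32)
    AZ = jax.vmap(matvec)(Z)                  # one matvec per probe
    return jnp.mean(jnp.sum(Z * AZ, axis=1))


# Diagonal operator: every probe gives z^T D z = sum(d) exactly.
d = jnp.arange(1.0, 7.0)
A_diag = jnp.diag(d)
est_diag = hutchinson_trace(lambda v: A_diag @ v, 6, 8, jax.random.PRNGKey(0))

# Dense SPD operator: the estimate is unbiased, with O(1/sqrt(num_probes)) error.
M = jax.random.normal(jax.random.PRNGKey(3), (50, 50))
A = M @ M.T
est = hutchinson_trace(lambda v: A @ v, 50, 2000, jax.random.PRNGKey(4))
```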


trace(operator: lx.AbstractLinearOperator, *, stochastic: bool = False, num_probes: int = 20, key: jax.Array | None = None, sampler: SamplerName | None = None, algorithm: Literal['hutchinson', 'xtrace'] = 'hutchinson') -> Float[Array, '']

Compute the trace of an operator.

When stochastic=True, uses a matfree stochastic estimator — only requires matvec access, no materialization.

Parameters:

  • operator (AbstractLinearOperator, required): A square linear operator.
  • stochastic (bool, default False): If True, use stochastic trace estimation.
  • num_probes (int, default 20): Number of probe vectors for stochastic mode.
  • key (Array | None, default None): PRNG key for stochastic mode.
  • sampler (SamplerName | None, default None): Probe distribution ("signs", "normal", "sphere"). Defaults to "signs" for Hutchinson and "sphere" for XTrace (which requires a rotationally invariant distribution).
  • algorithm (Literal['hutchinson', 'xtrace'], default 'hutchinson'): "hutchinson" (Monte-Carlo) or "xtrace" (leave-one-out, Epperly et al. 2024), which has much lower variance for the same number of matvecs.

Returns:

  • Float[Array, '']: Scalar trace value (exact or estimated).

Source code in src/gaussx/_primitives/_trace.py
def trace(
    operator: lx.AbstractLinearOperator,
    *,
    stochastic: bool = False,
    num_probes: int = 20,
    key: jax.Array | None = None,
    sampler: SamplerName | None = None,
    algorithm: Literal["hutchinson", "xtrace"] = "hutchinson",
) -> Float[Array, ""]:
    """Compute the trace of an operator.

    When ``stochastic=True``, uses a matfree stochastic estimator —
    only requires matvec access, no materialization.

    Args:
        operator: A square linear operator.
        stochastic: If ``True``, use stochastic trace estimation.
        num_probes: Number of probe vectors for stochastic mode.
        key: PRNG key for stochastic mode.
        sampler: Probe distribution (``"signs"``, ``"normal"``,
            ``"sphere"``). Defaults to ``"signs"`` for Hutchinson and
            ``"sphere"`` for XTrace (which requires a rotationally
            invariant distribution).
        algorithm: ``"hutchinson"`` (Monte-Carlo) or ``"xtrace"``
            (leave-one-out, Epperly et al. 2024 — much lower variance
            for the same number of matvecs).

    Returns:
        Scalar trace value (exact or estimated).
    """
    if isinstance(operator, lx.IdentityLinearOperator):
        return jnp.asarray(float(operator.in_size()))
    if isinstance(operator, lx.DiagonalLinearOperator):
        return jnp.sum(lx.diagonal(operator))
    if isinstance(operator, BlockDiag):
        return _trace_block_diag(operator)
    if isinstance(operator, Kronecker):
        return _trace_kronecker(operator)
    if isinstance(operator, BlockTriDiag):
        return _trace_block_tridiag(operator)
    if isinstance(operator, LowRankUpdate):
        return _trace_low_rank(operator)
    if isinstance(operator, KroneckerSum):
        return _trace_kronecker_sum(operator)
    if isinstance(operator, SumKronecker):
        return ft.reduce(jnp.add, (trace(kron) for kron in operator.operators))
    # Forward the stochastic options through wrapper operators so that a
    # matrix-free operator inside them still takes the matvec-only path.
    options = dict(
        stochastic=stochastic,
        num_probes=num_probes,
        key=key,
        sampler=sampler,
        algorithm=algorithm,
    )
    if isinstance(operator, lx.TaggedLinearOperator):
        return trace(operator.operator, **options)
    if isinstance(operator, lx.AddLinearOperator):
        return trace(operator.operator1, **options) + trace(
            operator.operator2, **options
        )
    if isinstance(operator, lx.MulLinearOperator):
        return operator.scalar * trace(operator.operator, **options)
    if isinstance(operator, lx.DivLinearOperator):
        return trace(operator.operator, **options) / operator.scalar
    if isinstance(operator, lx.NegLinearOperator):
        return -trace(operator.operator, **options)
    if stochastic:
        return _trace_stochastic(operator, num_probes, key, sampler, algorithm)
    return jnp.trace(operator.as_matrix())
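The exact branches reduce the trace to small pieces. A plain-JAX check of the Kronecker identity \(\operatorname{tr}(A\otimes B) = \operatorname{tr}A\,\operatorname{tr}B\) and of the Kronecker-sum identity, assuming KroneckerSum means \(A\otimes I_m + I_n\otimes B\) (the convention is not shown in this source):

```python
import jax
import jax.numpy as jnp

kA, kB = jax.random.split(jax.random.PRNGKey(5))
A = jax.random.normal(kA, (3, 3))
B = jax.random.normal(kB, (4, 4))
n, m = A.shape[0], B.shape[0]

# Kronecker branch: the trace factorizes over the factors.
kron_structured = jnp.trace(A) * jnp.trace(B)
kron_dense = jnp.trace(jnp.kron(A, B))

# KroneckerSum branch, assuming A (+) B = kron(A, I_m) + kron(I_n, B):
# tr(kron(A, I_m)) = m tr(A) and tr(kron(I_n, B)) = n tr(B).
ksum_structured = m * jnp.trace(A) + n * jnp.trace(B)
ksum_dense = jnp.trace(jnp.kron(A, jnp.eye(m)) + jnp.kron(jnp.eye(n), B))
```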

diag(operator: lx.AbstractLinearOperator, *, stochastic: bool = False, num_probes: int = 20, key: jax.Array | None = None, sampler: SamplerName = 'signs') -> Float[Array, ' n']

Extract the diagonal of an operator as a 1D array.

When stochastic=True, uses Hutchinson's diagonal estimator via matfree — only requires matvec access, no materialization.

Parameters:

  • operator (AbstractLinearOperator, required): A linear operator.
  • stochastic (bool, default False): If True, use stochastic diagonal estimation.
  • num_probes (int, default 20): Number of probe vectors for stochastic mode.
  • key (Array | None, default None): PRNG key for stochastic mode.
  • sampler (SamplerName, default 'signs'): Probe distribution for stochastic mode ("signs", "normal", "sphere").

Returns:

  • Float[Array, ' n']: 1D array of diagonal entries (exact or estimated).

Source code in src/gaussx/_primitives/_diag.py
def diag(
    operator: lx.AbstractLinearOperator,
    *,
    stochastic: bool = False,
    num_probes: int = 20,
    key: jax.Array | None = None,
    sampler: SamplerName = "signs",
) -> Float[Array, " n"]:
    """Extract the diagonal of an operator as a 1D array.

    When ``stochastic=True``, uses Hutchinson's diagonal estimator
    via matfree — only requires matvec access, no materialization.

    Args:
        operator: A linear operator.
        stochastic: If ``True``, use stochastic diagonal estimation.
        num_probes: Number of probe vectors for stochastic mode.
        key: PRNG key for stochastic mode.
        sampler: Probe distribution for stochastic mode (``"signs"``,
            ``"normal"``, ``"sphere"``).

    Returns:
        1D array of diagonal entries (exact or estimated).
    """
    if isinstance(operator, lx.IdentityLinearOperator):
        return jnp.ones(operator.in_size(), dtype=operator.in_structure().dtype)
    if isinstance(operator, lx.DiagonalLinearOperator):
        return lx.diagonal(operator)
    if isinstance(operator, BlockDiag):
        return _diag_block_diag(operator)
    if isinstance(operator, Kronecker):
        return _diag_kronecker(operator)
    if isinstance(operator, BlockTriDiag | LowerBlockTriDiag | UpperBlockTriDiag):
        return _diag_block_tridiag(operator)
    if isinstance(operator, LowRankUpdate):
        return _diag_low_rank(operator)
    if isinstance(operator, KroneckerSum):
        return _diag_kronecker_sum(operator)
    if isinstance(operator, SumKronecker):
        return ft.reduce(jnp.add, (diag(kron) for kron in operator.operators))
    # Forward the stochastic options through wrapper operators so that a
    # matrix-free operator inside them still takes the matvec-only path.
    options = dict(
        stochastic=stochastic,
        num_probes=num_probes,
        key=key,
        sampler=sampler,
    )
    if isinstance(operator, lx.TaggedLinearOperator):
        return diag(operator.operator, **options)
    if isinstance(operator, lx.AddLinearOperator):
        return diag(operator.operator1, **options) + diag(
            operator.operator2, **options
        )
    if isinstance(operator, lx.MulLinearOperator):
        return operator.scalar * diag(operator.operator, **options)
    if isinstance(operator, lx.DivLinearOperator):
        return diag(operator.operator, **options) / operator.scalar
    if isinstance(operator, lx.NegLinearOperator):
        return -diag(operator.operator, **options)
    if stochastic:
        return _diag_stochastic(operator, num_probes, key, sampler)
    return jnp.diag(operator.as_matrix())
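The Kronecker and BlockDiag branches of diag need only the diagonals of the factors or blocks. A plain-JAX sketch of both identities (not the gaussx helpers themselves):

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import block_diag

kA, kB = jax.random.split(jax.random.PRNGKey(6))
A = jax.random.normal(kA, (3, 3))
B = jax.random.normal(kB, (4, 4))

# Kronecker: diag(kron(A, B)) = kron(diag(A), diag(B)), an O(n m) vector.
kron_structured = jnp.kron(jnp.diag(A), jnp.diag(B))
kron_dense = jnp.diag(jnp.kron(A, B))

# BlockDiag: concatenate the diagonals of the blocks.
bd_structured = jnp.concatenate([jnp.diag(A), jnp.diag(B)])
bd_dense = jnp.diag(block_diag(A, B))
```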

trace_and_diag(operator: lx.AbstractLinearOperator, *, num_probes: int = 20, key: jax.Array | None = None, sampler: SamplerName = 'signs') -> tuple[Float[Array, ''], Float[Array, ' n']]

Jointly estimate the trace and diagonal from one probe pass.

Halves the matvec budget relative to calling trace(..., stochastic=True) and diag(..., stochastic=True) separately — both statistics are accumulated from the same A @ probe products.

Parameters:

  • operator (AbstractLinearOperator, required): A square linear operator.
  • num_probes (int, default 20): Number of probe vectors.
  • key (Array | None, default None): PRNG key. If None, uses jax.random.PRNGKey(0).
  • sampler (SamplerName, default 'signs'): Probe distribution ("signs", "normal", "sphere").

Returns:

  • tuple[Float[Array, ''], Float[Array, ' n']]: Tuple (trace_estimate, diagonal_estimate).

Source code in src/gaussx/_primitives/_trace.py
def trace_and_diag(
    operator: lx.AbstractLinearOperator,
    *,
    num_probes: int = 20,
    key: jax.Array | None = None,
    sampler: SamplerName = "signs",
) -> tuple[Float[Array, ""], Float[Array, " n"]]:
    """Jointly estimate the trace and diagonal from one probe pass.

    Halves the matvec budget relative to calling
    ``trace(..., stochastic=True)`` and ``diag(..., stochastic=True)``
    separately — both statistics are accumulated from the same
    ``A @ probe`` products.

    Args:
        operator: A square linear operator.
        num_probes: Number of probe vectors.
        key: PRNG key. If ``None``, uses ``jax.random.PRNGKey(0)``.
        sampler: Probe distribution (``"signs"``, ``"normal"``,
            ``"sphere"``).

    Returns:
        Tuple ``(trace_estimate, diagonal_estimate)``.
    """
    if key is None:
        key = jax.random.PRNGKey(0)

    n = operator.in_size()
    probe_fn = resolve_sampler(sampler, n, num_probes)
    integrand = matfree.stochtrace.monte_carlo_trace_and_diagonal()
    estimate = matfree.stochtrace.estimator_monte_carlo(integrand, probe_fn)
    result = estimate(operator.mv, key)
    return result["trace"], result["diagonal"]
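The saving comes from reusing each product \(Az\) for both statistics: for Rademacher probes, \(\mathbb{E}[z \odot Az] = \operatorname{diag}(A)\) and \(\mathbb{E}[z^\top A z] = \operatorname{tr}(A)\). A minimal plain-JAX sketch of that idea; the function name is hypothetical, and the real implementation delegates to matfree.

```python
import jax
import jax.numpy as jnp


def trace_and_diag_estimate(matvec, n, num_probes, key):
    """Hypothetical helper: trace and diagonal from a single probe pass.

    For Rademacher probes z, E[z * (A z)] = diag(A) elementwise and
    E[z^T A z] = tr(A), so both statistics reuse the same products A z.
    """
    Z = jax.random.rademacher(key, (num_probes, n)).astype(jnp.float32)
    AZ = jax.vmap(matvec)(Z)               # num_probes matvecs in total
    diag_est = jnp.mean(Z * AZ, axis=0)    # shape (n,)
    trace_est = jnp.sum(diag_est)          # equals mean_k z_k^T A z_k
    return trace_est, diag_est


# Diagonal operator: sign probes make both estimates exact (z_i^2 = 1).
d = jnp.arange(1.0, 6.0)
tr_d, diag_d = trace_and_diag_estimate(lambda v: d * v, 5, 4, jax.random.PRNGKey(8))

# Dense SPD operator accessed only through matvecs.
M = jax.random.normal(jax.random.PRNGKey(7), (30, 30))
A = M @ M.T
tr_est, diag_est = trace_and_diag_estimate(
    lambda v: A @ v, 30, 500, jax.random.PRNGKey(9)
)
```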

Inverse, square root & spectral decompositions

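inv returns a lazy operator whose matvecs route through the structured solve dispatch. sqrt preserves structure where it can and otherwise uses a dense eigendecomposition or, when lanczos_order is given, a lazy Lanczos operator. eig, eigvals, and svd take an optional rank for partial (Krylov) decompositions.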


inv(operator: lx.AbstractLinearOperator, *, solver: lx.AbstractLinearSolver | None = None) -> lx.AbstractLinearOperator

Return a lazy inverse operator A^{-1}.

The returned operator computes A^{-1} v via solve(A, v) when mv is called. For structured operators, the inverse preserves structure.

Related to lineax.invert (lineax >= 0.1.1), which wraps lx.linear_solve in a FunctionLinearOperator. The gaussx fallback InverseOperator differs in that its matvec routes through the structured gaussx solve dispatch, and its as_matrix uses a Cholesky path for PSD operators.

Parameters:

  • operator (AbstractLinearOperator, required): An invertible linear operator.
  • solver (AbstractLinearSolver | None, default None): Optional lineax solver for the fallback InverseOperator.

Returns:

  • AbstractLinearOperator: An operator representing A^{-1}.

Source code in src/gaussx/_primitives/_inv.py
def inv(
    operator: lx.AbstractLinearOperator,
    *,
    solver: lx.AbstractLinearSolver | None = None,
) -> lx.AbstractLinearOperator:
    """Return a lazy inverse operator A^{-1}.

    The returned operator computes A^{-1} v via ``solve(A, v)``
    when ``mv`` is called. For structured operators, the inverse
    preserves structure.

    Related to ``lineax.invert`` (lineax >= 0.1.1), which wraps
    ``lx.linear_solve`` in a ``FunctionLinearOperator``. The gaussx
    fallback ``InverseOperator`` differs in that its matvec routes
    through the *structured* gaussx ``solve`` dispatch, and its
    ``as_matrix`` uses a Cholesky path for PSD operators.

    Args:
        operator: An invertible linear operator.
        solver: Optional lineax solver for the fallback InverseOperator.

    Returns:
        An operator representing A^{-1}.
    """
    if isinstance(operator, lx.IdentityLinearOperator):
        return operator
    if isinstance(operator, lx.DiagonalLinearOperator):
        return _inv_diagonal(operator)
    if isinstance(operator, BlockDiag):
        return _inv_block_diag(operator)
    if isinstance(operator, Kronecker):
        return _inv_kronecker(operator)
    if (
        isinstance(operator, LowRankUpdate)
        and lx.is_symmetric(operator)
        and _arrays_match(operator.U, operator.V)
    ):
        return _inv_low_rank_symmetric(operator, solver)
    if isinstance(operator, lx.MulLinearOperator):
        return (1.0 / operator.scalar) * inv(operator.operator, solver=solver)
    if isinstance(operator, lx.DivLinearOperator):
        return operator.scalar * inv(operator.operator, solver=solver)
    if isinstance(operator, lx.NegLinearOperator):
        return -inv(operator.operator, solver=solver)
    if isinstance(operator, lx.ComposedLinearOperator) and (
        operator.operator1.in_size() == operator.operator1.out_size()
        and operator.operator2.in_size() == operator.operator2.out_size()
    ):
        # (A B)^{-1} = B^{-1} A^{-1}
        return inv(operator.operator2, solver=solver) @ inv(
            operator.operator1, solver=solver
        )
    return InverseOperator(operator, solver)
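For the symmetric LowRankUpdate branch, a standard structure-preserving inverse is the Woodbury identity \((D + UU^\top)^{-1} = D^{-1} - D^{-1}U\,(I_k + U^\top D^{-1} U)^{-1}\,U^\top D^{-1}\). The sketch below assumes the update has the form \(D + UU^\top\) with a positive diagonal \(D\); the exact LowRankUpdate parameterisation is not shown here. `woodbury_inverse_mv` is a hypothetical helper that returns a lazy matvec and solves only a \(k\times k\) system per call.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve


def woodbury_inverse_mv(d, U):
    """Hypothetical helper: return v -> (diag(d) + U U^T)^{-1} v.

    Assumes d > 0. The n x n inverse is never formed; the k x k
    capacitance matrix is factored once and reused for every matvec.
    """
    Dinv_U = U / d[:, None]                          # D^{-1} U, shape (n, k)
    k = U.shape[1]
    capacitance = jnp.eye(k) + U.T @ Dinv_U          # I_k + U^T D^{-1} U
    chol = jnp.linalg.cholesky(capacitance)

    def mv(v):
        Dinv_v = v / d
        w = cho_solve((chol, True), U.T @ Dinv_v)    # k x k solve
        return Dinv_v - Dinv_U @ w

    return mv


kd, ku, kv = jax.random.split(jax.random.PRNGKey(9), 3)
d = 1.0 + jax.random.uniform(kd, (8,))
U = jax.random.normal(ku, (8, 2))
v = jax.random.normal(kv, (8,))

x = woodbury_inverse_mv(d, U)(v)
x_dense = jnp.linalg.solve(jnp.diag(d) + U @ U.T, v)  # dense, for checking only
```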

sqrt(operator: lx.AbstractLinearOperator, *, lanczos_order: int | None = None) -> lx.AbstractLinearOperator

Compute matrix square root S such that S @ S = A.

Requires A to be positive semi-definite.

When lanczos_order is given and no structured path applies, returns a lazy SqrtOperator that computes sqrt(A) @ v via matfree Lanczos without materializing the full square-root matrix.

Parameters:

  • operator (AbstractLinearOperator, required): A PSD linear operator.
  • lanczos_order (int | None, default None): Order of the Lanczos iteration for a matrix-free sqrt. If None, operators without a structured square root use a dense eigendecomposition; SumKronecker is the exception, where None falls back to a Lanczos sqrt with the module default order (no closed-form sqrt exists for the sum-of-Kronecker structure). Pass an explicit lanczos_order to override the default order.

Returns:

  • AbstractLinearOperator: Operator S satisfying S @ S = A.

Source code in src/gaussx/_primitives/_sqrt.py
def sqrt(
    operator: lx.AbstractLinearOperator,
    *,
    lanczos_order: int | None = None,
) -> lx.AbstractLinearOperator:
    """Compute matrix square root S such that S @ S = A.

    Requires A to be positive semi-definite.

    When ``lanczos_order`` is given, returns a lazy ``SqrtOperator``
    that computes ``sqrt(A) @ v`` via matfree Lanczos without
    materializing the full square root matrix.

    Args:
        operator: A PSD linear operator.
        lanczos_order: Order of Lanczos iteration for matrix-free
            sqrt. If ``None``, uses dense eigendecomposition for most
            operators, *except* `SumKronecker` — where ``None``
            falls back to a Lanczos sqrt with the module default order
            (no closed-form sqrt exists for the sum-of-Kronecker
            structure). Pass an explicit ``lanczos_order`` to override
            the default rank.

    Returns:
        Operator S satisfying S @ S = A.
    """
    if isinstance(operator, lx.IdentityLinearOperator):
        return operator
    if isinstance(operator, lx.DiagonalLinearOperator):
        return _sqrt_diagonal(operator)
    if isinstance(operator, lx.TaggedLinearOperator):
        return sqrt(operator.operator, lanczos_order=lanczos_order)
    if isinstance(operator, BlockDiag):
        return _sqrt_block_diag(operator)
    if isinstance(operator, Kronecker):
        return _sqrt_kronecker(operator)
    if isinstance(operator, KroneckerSum):
        return _sqrt_kronecker_sum(operator)
    if isinstance(operator, SumKronecker):
        return _sqrt_sum_kronecker(
            operator,
            lanczos_order=(
                _DEFAULT_LANCZOS_ORDER if lanczos_order is None else lanczos_order
            ),
        )
    if lanczos_order is not None:
        return SqrtOperator(operator, lanczos_order)
    return _sqrt_dense(operator)
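When lanczos_order is set, the square-root matvec can be approximated from a Krylov subspace: after \(k\) Lanczos steps with basis \(Q_k\) and tridiagonal \(T_k = Q_k^\top A Q_k\), \(\sqrt{A}\,v \approx \|v\|\,Q_k\sqrt{T_k}\,e_1\). The following is a generic plain-JAX sketch with full reorthogonalisation, assuming a symmetric PSD matrix; it is not matfree's or SqrtOperator's code, and `lanczos_sqrt_mv` is a hypothetical name. With `order` equal to the dimension it reproduces the exact square root up to rounding.

```python
import jax
import jax.numpy as jnp


def lanczos_sqrt_mv(matvec, v, order):
    """Hypothetical helper: approximate sqrt(A) @ v with `order` Lanczos steps."""
    beta0 = jnp.linalg.norm(v)
    basis = [v / beta0]
    alphas, betas = [], []
    q_prev, beta = jnp.zeros_like(v), 0.0
    for j in range(order):
        q = basis[j]
        w = matvec(q) - beta * q_prev             # three-term recurrence
        alpha = jnp.dot(q, w)
        w = w - alpha * q
        Q = jnp.stack(basis, axis=1)
        w = w - Q @ (Q.T @ w)                     # full reorthogonalisation
        alphas.append(alpha)
        if j < order - 1:
            beta = jnp.linalg.norm(w)
            betas.append(beta)
            q_prev = q
            basis.append(w / beta)
    Q = jnp.stack(basis, axis=1)                  # (n, order)
    T = jnp.diag(jnp.stack(alphas))
    if betas:
        off = jnp.stack(betas)
        T = T + jnp.diag(off, 1) + jnp.diag(off, -1)
    evals, evecs = jnp.linalg.eigh(T)
    # sqrt(T) e_1 = V sqrt(Lambda) V^T e_1, and V^T e_1 is the first row of V.
    sqrtT_e1 = evecs @ (jnp.sqrt(jnp.clip(evals, 0.0)) * evecs[0, :])
    return beta0 * (Q @ sqrtT_e1)


M = jax.random.normal(jax.random.PRNGKey(10), (6, 6))
A = M @ M.T + 6.0 * jnp.eye(6)
v = jax.random.normal(jax.random.PRNGKey(11), (6,))

lam, V = jnp.linalg.eigh(A)
exact = V @ (jnp.sqrt(lam) * (V.T @ v))           # dense reference sqrt(A) @ v
approx = lanczos_sqrt_mv(lambda x: A @ x, v, order=6)
```

With a smaller order the result is an approximation whose accuracy depends on the spread of A's spectrum.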

eig(operator: lx.AbstractLinearOperator, *, rank: int | None = None, key: jax.Array | None = None) -> tuple[Array, Array]

Compute eigenvalues and eigenvectors.

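For symmetric operators returns real eigenvalues via eigh. When rank is given, computes a partial eigendecomposition via matfree Lanczos (symmetric), with no matrix materialization. rank only applies on the generic path: Diagonal, BlockDiag, Kronecker, and KroneckerSum operators always return their full structured decomposition.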

Parameters:

  • operator (AbstractLinearOperator, required): A square linear operator.
  • rank (int | None, default None): Number of eigenvalues to compute. If None, computes the full eigendecomposition.
  • key (Array | None, default None): PRNG key for the initial random vector when using partial eig. If None, uses jax.random.PRNGKey(0).

Returns:

  • tuple[Array, Array]: Tuple (eigenvalues, eigenvectors) where eigenvalues has shape (K,) and eigenvectors has shape (N, K).

Source code in src/gaussx/_primitives/_eig.py
def eig(
    operator: lx.AbstractLinearOperator,
    *,
    rank: int | None = None,
    key: jax.Array | None = None,
) -> tuple[Array, Array]:
    """Compute eigenvalues and eigenvectors.

    For symmetric operators returns real eigenvalues via ``eigh``.
    When ``rank`` is given, computes a partial eigendecomposition
    via matfree Lanczos (symmetric) — no matrix materialization.

    Args:
        operator: A square linear operator.
        rank: Number of eigenvalues to compute. If ``None``,
            computes the full eigendecomposition.
        key: PRNG key for the initial random vector when using
            partial eig. If ``None``, uses ``jax.random.PRNGKey(0)``.

    Returns:
        Tuple ``(eigenvalues, eigenvectors)`` where eigenvalues has
        shape ``(K,)`` and eigenvectors has shape ``(N, K)``.
    """
    if isinstance(operator, lx.DiagonalLinearOperator):
        return _eig_diagonal(operator)
    if isinstance(operator, BlockDiag):
        return _eig_block_diag(operator)
    if isinstance(operator, Kronecker):
        return _eig_kronecker(operator)
    if isinstance(operator, KroneckerSum):
        return _eig_kronecker_sum(operator)
    if rank is not None:
        return _eig_partial(operator, rank, key)
    return _eig_dense(operator)
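The Kronecker branch can return a full eigendecomposition without ever forming \(A\otimes B\): eigenvalues multiply and eigenvectors combine by Kronecker product. A plain-JAX check for symmetric factors; the eigenvalues come out unsorted, and whether `_eig_kronecker` sorts them is not shown in this source.

```python
import jax
import jax.numpy as jnp


def random_spd(key, n):
    """Hypothetical test fixture: a symmetric positive-definite matrix."""
    M = jax.random.normal(key, (n, n))
    return M @ M.T / n + jnp.eye(n)


kA, kB = jax.random.split(jax.random.PRNGKey(12))
A, B = random_spd(kA, 3), random_spd(kB, 4)

lam_A, V_A = jnp.linalg.eigh(A)
lam_B, V_B = jnp.linalg.eigh(B)

# kron(A, B) = kron(V_A, V_B) @ diag(kron(lam_A, lam_B)) @ kron(V_A, V_B).T
lam = jnp.kron(lam_A, lam_B)   # shape (12,), not sorted
V = jnp.kron(V_A, V_B)         # shape (12, 12), orthogonal
K = jnp.kron(A, B)             # dense, for checking only
```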

eigvals(operator: lx.AbstractLinearOperator, *, rank: int | None = None, key: jax.Array | None = None) -> Array

Compute eigenvalues only.

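When rank is given, returns the top-k eigenvalues via matfree Lanczos without matrix materialization. As with eig, rank is ignored for Diagonal, BlockDiag, Kronecker, and KroneckerSum operators, which return all eigenvalues.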

Parameters:

  • operator (AbstractLinearOperator, required): A square linear operator.
  • rank (int | None, default None): Number of eigenvalues to compute. If None, computes all eigenvalues.
  • key (Array | None, default None): PRNG key for the partial eigendecomposition.

Returns:

  • Array: Eigenvalues array of shape (K,).

Source code in src/gaussx/_primitives/_eig.py
def eigvals(
    operator: lx.AbstractLinearOperator,
    *,
    rank: int | None = None,
    key: jax.Array | None = None,
) -> Array:
    """Compute eigenvalues only.

    When ``rank`` is given, returns the top-k eigenvalues via
    matfree Lanczos without matrix materialization.

    Args:
        operator: A square linear operator.
        rank: Number of eigenvalues to compute.
        key: PRNG key for partial eigendecomposition.

    Returns:
        Eigenvalues array of shape ``(K,)``.
    """
    if isinstance(operator, lx.DiagonalLinearOperator):
        return lx.diagonal(operator)
    if isinstance(operator, BlockDiag):
        return jnp.concatenate([eigvals(op) for op in operator.operators])
    if isinstance(operator, Kronecker):
        return _eigvals_kronecker(operator)
    if isinstance(operator, KroneckerSum):
        return _eigvals_kronecker_sum(operator)
    if rank is not None:
        vals, _ = _eig_partial(operator, rank, key)
        return vals
    return _eigvals_dense(operator)
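For the KroneckerSum branch, assuming the convention \(A\oplus B = A\otimes I_m + I_n\otimes B\), the eigenvalues are all pairwise sums \(\lambda_i(A) + \mu_j(B)\). A plain-JAX check with symmetric factors:

```python
import jax
import jax.numpy as jnp


def random_spd(key, n):
    """Hypothetical test fixture: a symmetric positive-definite matrix."""
    M = jax.random.normal(key, (n, n))
    return M @ M.T / n + jnp.eye(n)


kA, kB = jax.random.split(jax.random.PRNGKey(13))
A, B = random_spd(kA, 3), random_spd(kB, 4)
n, m = A.shape[0], B.shape[0]

# Every pairwise sum lam_i(A) + mu_j(B); row-major flattening matches the
# ordering of the kron(V_A, V_B) eigenvectors.
lam_A = jnp.linalg.eigvalsh(A)
lam_B = jnp.linalg.eigvalsh(B)
lam_sum = (lam_A[:, None] + lam_B[None, :]).reshape(-1)

dense = jnp.kron(A, jnp.eye(m)) + jnp.kron(jnp.eye(n), B)  # for checking only
```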

svd(operator: lx.AbstractLinearOperator, *, rank: int | None = None, key: jax.Array | None = None) -> tuple[Float[Array, 'm k'], Float[Array, ' k'], Float[Array, 'k n']]

Compute the singular value decomposition A = U diag(s) V^T.

When rank is given, computes a partial (truncated) SVD via matfree's Golub-Kahan bidiagonalization — no matrix materialization.

Parameters:

  • operator (AbstractLinearOperator, required): A linear operator.
  • rank (int | None, default None): Number of singular values to compute. If None, computes the full SVD (requires materialization).
  • key (Array | None, default None): PRNG key for the initial random vector when using partial SVD. If None, uses jax.random.PRNGKey(0).

Returns:

  • tuple[Float[Array, 'm k'], Float[Array, ' k'], Float[Array, 'k n']]: Tuple (U, s, Vt) where U has shape (M, K), s has shape (K,), and Vt has shape (K, N).

Source code in src/gaussx/_primitives/_svd.py
def svd(
    operator: lx.AbstractLinearOperator,
    *,
    rank: int | None = None,
    key: jax.Array | None = None,
) -> tuple[Float[Array, "m k"], Float[Array, " k"], Float[Array, "k n"]]:
    """Compute the singular value decomposition ``A = U diag(s) V^T``.

    When ``rank`` is given, computes a partial (truncated) SVD via
    matfree's Golub-Kahan bidiagonalization — no matrix materialization.

    Args:
        operator: A linear operator.
        rank: Number of singular values to compute. If ``None``,
            computes the full SVD (requires materialization).
        key: PRNG key for the initial random vector when using
            partial SVD. If ``None``, uses ``jax.random.PRNGKey(0)``.

    Returns:
        Tuple ``(U, s, Vt)`` where U has shape ``(M, K)``,
        s has shape ``(K,)``, and Vt has shape ``(K, N)``.
    """
    if isinstance(operator, lx.DiagonalLinearOperator):
        return _svd_diagonal(operator)
    if rank is not None:
        return _svd_partial(operator, rank, key)
    if isinstance(operator, Kronecker):
        return _svd_kronecker(operator)
    if isinstance(operator, BlockDiag):
        return _svd_block_diag(operator)
    if isinstance(operator, lx.TaggedLinearOperator):
        return svd(operator.operator)
    return _svd_dense(operator)
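The BlockDiag branch can assemble an SVD from per-block SVDs. In this plain-JAX sketch the singular values are simply concatenated, so they are not globally sorted; how `_svd_block_diag` orders them is not shown here.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import block_diag

kA, kB = jax.random.split(jax.random.PRNGKey(14))
A = jax.random.normal(kA, (3, 3))
B = jax.random.normal(kB, (2, 2))

U_A, s_A, Vt_A = jnp.linalg.svd(A)
U_B, s_B, Vt_B = jnp.linalg.svd(B)

# Block-diagonal factors reproduce block_diag(A, B) = U @ diag(s) @ Vt.
U = block_diag(U_A, U_B)
s = jnp.concatenate([s_A, s_B])   # not globally sorted
Vt = block_diag(Vt_A, Vt_B)

dense = block_diag(A, B)          # for checking only
```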

frobenius_norm(operator: lx.AbstractLinearOperator, *, stochastic: bool = False, num_probes: int = 20, key: jax.Array | None = None, sampler: SamplerName = 'signs') -> Float[Array, '']

Compute the Frobenius norm ||A||_F with structural dispatch.

Structured operators avoid materialization:

  • Diagonal: vector 2-norm of the diagonal.
  • BlockDiag: root of the sum of squared block norms.
  • Kronecker: ||A (x) B||_F = ||A||_F * ||B||_F.
  • Scaled/negated/tagged operators delegate to the wrapped operator.

When stochastic=True, estimates ||A||_F^2 = tr(A^T A) via matfree's Hutchinson estimator — matvec access only.

Parameters:

  • operator (AbstractLinearOperator, required): A linear operator.
  • stochastic (bool, default False): If True, use stochastic estimation.
  • num_probes (int, default 20): Number of probe vectors for stochastic mode.
  • key (Array | None, default None): PRNG key for stochastic mode.
  • sampler (SamplerName, default 'signs'): Probe distribution for stochastic mode ("signs", "normal", "sphere").

Returns:

  • Float[Array, '']: Scalar Frobenius norm (exact or estimated).

Source code in src/gaussx/_primitives/_frobenius.py
def frobenius_norm(
    operator: lx.AbstractLinearOperator,
    *,
    stochastic: bool = False,
    num_probes: int = 20,
    key: jax.Array | None = None,
    sampler: SamplerName = "signs",
) -> Float[Array, ""]:
    """Compute the Frobenius norm ``||A||_F`` with structural dispatch.

    Structured operators avoid materialization:

    - Diagonal: vector 2-norm of the diagonal.
    - BlockDiag: root of the sum of squared block norms.
    - Kronecker: ``||A (x) B||_F = ||A||_F * ||B||_F``.
    - Scaled/negated/tagged operators delegate to the wrapped operator.

    When ``stochastic=True``, estimates ``||A||_F^2 = tr(A^T A)`` via
    matfree's Hutchinson estimator — matvec access only.

    Args:
        operator: A linear operator.
        stochastic: If ``True``, use stochastic estimation.
        num_probes: Number of probe vectors for stochastic mode.
        key: PRNG key for stochastic mode.
        sampler: Probe distribution for stochastic mode (``"signs"``,
            ``"normal"``, ``"sphere"``).

    Returns:
        Scalar Frobenius norm (exact or estimated).
    """
    if isinstance(operator, lx.IdentityLinearOperator):
        return jnp.sqrt(jnp.asarray(float(operator.in_size())))
    if isinstance(operator, lx.DiagonalLinearOperator):
        d = lx.diagonal(operator)
        return jnp.sqrt(jnp.sum(d * d))
    if isinstance(operator, BlockDiag):
        norms = jnp.stack([frobenius_norm(op) for op in operator.operators])
        return jnp.sqrt(jnp.sum(norms * norms))
    if isinstance(operator, Kronecker):
        return ft.reduce(
            jnp.multiply, (frobenius_norm(op) for op in operator.operators)
        )
    if isinstance(operator, lx.TaggedLinearOperator):
        return frobenius_norm(
            operator.operator,
            stochastic=stochastic,
            num_probes=num_probes,
            key=key,
            sampler=sampler,
        )
    if isinstance(operator, lx.MulLinearOperator):
        return jnp.abs(operator.scalar) * frobenius_norm(operator.operator)
    if isinstance(operator, lx.DivLinearOperator):
        return frobenius_norm(operator.operator) / jnp.abs(operator.scalar)
    if isinstance(operator, lx.NegLinearOperator):
        return frobenius_norm(operator.operator)
    if stochastic:
        return _frobenius_stochastic(operator, num_probes, key, sampler)
    mat = operator.as_matrix()
    return jnp.sqrt(jnp.sum(mat * mat))

submatrix(operator: lx.AbstractLinearOperator, row_idx: Int[Array, ' R'], col_idx: Int[Array, ' C']) -> Float[Array, 'R C']

Extract A[row_idx, col_idx] without forming the full matrix.

For structured operators, exploits the structure to avoid materializing the full (N, N) matrix when only a sub-block is needed (e.g., the conditional Gaussian extracts Sigma_AA, Sigma_AB, Sigma_BB from a joint covariance).

Currently dispatches on:

  • lineax.DiagonalLinearOperator
  • lineax.TaggedLinearOperator (delegates to the wrapped operator)
  • gaussx.BlockDiag
  • gaussx.Kronecker

Falls back to operator.as_matrix()[ix_(row_idx, col_idx)] for other operators.
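For a Kronecker operator, the sub-block can be gathered directly from the factors. With a p×p factor B, entry (i, j) of A (x) B equals A[i // p, j // p] * B[i % p, j % p]. The sketch below assumes square factors. kron_submatrix is a hypothetical helper, not gaussx's private _submatrix_kronecker:

```python
import jax.numpy as jnp


def kron_submatrix(A, B, row_idx, col_idx):
    """kron(A, B)[ix_(row_idx, col_idx)] without forming kron(A, B) (hypothetical).

    Assumes square factors A (m, m) and B (p, p), so the product is
    (m*p, m*p) and flat index i splits into (i // p, i % p).
    """
    p = B.shape[0]
    # Split each flat index into its (outer, inner) factor indices.
    ra, rb = row_idx // p, row_idx % p
    ca, cb = col_idx // p, col_idx % p
    # Gather an (R, C) block from each factor; their elementwise product is the answer.
    return A[jnp.ix_(ra, ca)] * B[jnp.ix_(rb, cb)]


A = jnp.arange(1.0, 10.0).reshape(3, 3)
B = jnp.array([[2.0, -1.0], [0.5, 3.0]])
rows = jnp.array([0, 3, 5])
cols = jnp.array([1, 4])
block = kron_submatrix(A, B, rows, cols)           # shape (3, 2)
reference = jnp.kron(A, B)[jnp.ix_(rows, cols)]    # dense check only
```

Memory scales with R·C instead of (m·p)².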

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | Linear operator A, shape (N, N). | required |
| row_idx | Int[Array, ' R'] | Row indices, shape (R,). | required |
| col_idx | Int[Array, ' C'] | Column indices, shape (C,). | required |

Returns:

| Type | Description |
|---|---|
| Float[Array, 'R C'] | Dense sub-matrix A[ix_(row_idx, col_idx)] of shape (R, C). |

Source code in src/gaussx/_primitives/_submatrix.py
def submatrix(
    operator: lx.AbstractLinearOperator,
    row_idx: Int[Array, " R"],
    col_idx: Int[Array, " C"],
) -> Float[Array, "R C"]:
    """Extract ``A[row_idx, col_idx]`` without forming the full matrix.

    For structured operators, exploits the structure to avoid
    materializing the full ``(N, N)`` matrix when only a sub-block is
    needed (e.g., the conditional Gaussian extracts ``Sigma_AA``,
    ``Sigma_AB``, ``Sigma_BB`` from a joint covariance).

    Currently dispatches on:

    - `lineax.DiagonalLinearOperator`
    - `lineax.TaggedLinearOperator` (delegates to the wrapped operator)
    - `gaussx.BlockDiag`
    - `gaussx.Kronecker`

    Falls back to ``operator.as_matrix()[ix_(row_idx, col_idx)]`` for
    other operators.

    Args:
        operator: Linear operator A, shape ``(N, N)``.
        row_idx: Row indices, shape ``(R,)``.
        col_idx: Column indices, shape ``(C,)``.

    Returns:
        Dense sub-matrix ``A[ix_(row_idx, col_idx)]`` of shape ``(R, C)``.
    """
    n = operator.in_size()
    row_idx = _normalize_indices(row_idx, n)
    col_idx = _normalize_indices(col_idx, n)
    if isinstance(operator, lx.DiagonalLinearOperator):
        return _submatrix_diagonal(operator, row_idx, col_idx)
    if isinstance(operator, BlockDiag):
        return _submatrix_block_diag(operator, row_idx, col_idx)
    if isinstance(operator, Kronecker):
        return _submatrix_kronecker(operator, row_idx, col_idx)
    if isinstance(operator, lx.TaggedLinearOperator):
        return submatrix(operator.operator, row_idx, col_idx)
    return operator.as_matrix()[jnp.ix_(row_idx, col_idx)]

Root decompositions

Tall-factor approximations \(RR^\top \approx A\) (and \(R^- (R^-)^\top \approx A^{-1}\)) via Cholesky, pivoted Cholesky, Lanczos, or truncated SVD — the building block for low-rank posterior sampling and BBMM-style solvers.
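The sketch below illustrates the tall-factor semantics by building both factors from an exact eigendecomposition. A full jnp.linalg.eigh stands in for the iterative Lanczos or SVD paths, and eig_root is a hypothetical helper. Keeping the k largest eigenpairs gives the best rank-k approximation, so the Frobenius error of R Rᵀ equals the norm of the discarded eigenvalues. The inverse variant here simply inverts the retained eigenvalues, so it reproduces A⁻¹ exactly only at full rank.

```python
import jax
import jax.numpy as jnp


def eig_root(A, rank, inverse=False):
    """Tall factor from the `rank` largest eigenpairs of a symmetric PD matrix.

    inverse=False: R = V_k sqrt(L_k), so R R^T is the best rank-k approximation of A.
    inverse=True:  R = V_k / sqrt(L_k), so R R^T equals A^{-1} when rank == N.
    """
    lam, V = jnp.linalg.eigh(A)                 # eigenvalues in ascending order
    lam, V = lam[::-1][:rank], V[:, ::-1][:, :rank]
    scale = 1.0 / jnp.sqrt(lam) if inverse else jnp.sqrt(lam)
    return V * scale                            # column j scaled by scale[j]


# SPD test matrix with known spectrum 10, 5, 2, 1, 0.5.
Q, _ = jnp.linalg.qr(jax.random.normal(jax.random.PRNGKey(0), (5, 5)))
lams = jnp.array([10.0, 5.0, 2.0, 1.0, 0.5])
A = (Q * lams) @ Q.T

R = eig_root(A, rank=3)
trunc_err = jnp.linalg.norm(A - R @ R.T)        # Frobenius norm of the residual
R_inv = eig_root(A, rank=5, inverse=True)
inv_err = jnp.max(jnp.abs(R_inv @ R_inv.T - jnp.linalg.inv(A)))
```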


RootDecomposition

Bases: Module

Tall factor R satisfying R Rᵀ ≈ A.

Attributes:

| Name | Type | Description |
|---|---|---|
| root | Float[Array, 'N k'] | Tall factor with shape (N, k). |

Source code in src/gaussx/_primitives/_root.py
class RootDecomposition(eqx.Module):
    r"""Tall factor ``R`` with ``R Rᵀ ≈ A`` semantics.

    Attributes:
        root: Tall factor with shape ``(N, k)``.
    """

    root: Float[Array, "N k"]

    @property
    def rank(self) -> int:
        """Number of retained root directions."""
        return self.root.shape[1]

    def matmul(self, x: Float[Array, "*B k"]) -> Float[Array, "*B N"]:
        """Right-multiply by the root factor: ``x Rᵀ``."""
        return x @ self.root.T

rank: int (property)

Number of retained root directions.

matmul(x: Float[Array, '*B k']) -> Float[Array, '*B N']

Right-multiply by the root factor: x Rᵀ.
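R Rᵀ is the covariance of R z for z ~ N(0, I_k). The same x Rᵀ product therefore turns a batch of S standard-normal draws into correlated samples in O(S·N·k) work, without forming the N×N covariance. In the sketch below, a hand-picked array stands in for RootDecomposition.root; its values are arbitrary:

```python
import jax
import jax.numpy as jnp

# Arbitrary tall factor standing in for RootDecomposition.root, shape (N, k) = (3, 2).
root = jnp.array([[1.0, 0.0],
                  [0.5, 1.0],
                  [0.0, 2.0]])

# Standard-normal draws in the k-dimensional root space, batch of S samples.
z = jax.random.normal(jax.random.PRNGKey(0), (100_000, 2))

# Same operation as RootDecomposition.matmul: x R^T maps (S, k) -> (S, N).
samples = z @ root.T

# Each row is a draw from N(0, R R^T); compare empirical and target covariance.
emp_cov = samples.T @ samples / samples.shape[0]
target_cov = root @ root.T
```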


root_decomposition(operator: lx.AbstractLinearOperator, rank: int = 50, method: RootMethod = 'lanczos', key: jax.Array | None = None) -> RootDecomposition

Compute a tall factor R such that R Rᵀ ≈ A.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | Square symmetric positive-definite operator A. | required |
| rank | int | Number of retained directions. Ignored by "cholesky", which returns the exact full-rank factor. | 50 |
| method | RootMethod | Decomposition method: "lanczos", "cholesky", "pivoted_cholesky", or "svd". | 'lanczos' |
| key | Array \| None | PRNG key for random-start methods. | None |

Returns:

| Type | Description |
|---|---|
| RootDecomposition | A RootDecomposition with root shape (N, k). |
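The "pivoted_cholesky" method presumably follows the classic greedy algorithm. At each step it pivots on the largest remaining diagonal entry of A − L Lᵀ and appends the normalized residual column. The sketch below of that algorithm works on a dense matrix, as the method does in the source, which calls _symmetric_operator_matrix first. pivoted_cholesky is a hypothetical name, gaussx's private _pivoted_cholesky_root may differ, and the Python loop with int() pivots is not jit-compatible:

```python
import jax
import jax.numpy as jnp


def pivoted_cholesky(A, rank):
    """Greedy pivoted Cholesky (hypothetical helper): L of shape (N, rank), L L^T ~= A."""
    n = A.shape[0]
    L = jnp.zeros((n, rank), dtype=A.dtype)
    d = jnp.diag(A)                          # diagonal of the residual A - L L^T
    for m in range(rank):
        i = int(jnp.argmax(d))               # pivot on the largest residual variance
        # Column i of the residual, scaled by the square root of its pivot.
        col = (A[:, i] - L[:, :m] @ L[i, :m]) / jnp.sqrt(d[i])
        L = L.at[:, m].set(col)
        d = d - col**2                       # residual diagonal after this step
    return L


F = jax.random.normal(jax.random.PRNGKey(0), (6, 3))
low_rank = F @ F.T                           # exactly rank 3
full_rank = low_rank + jnp.eye(6)            # SPD, rank 6

L3 = pivoted_cholesky(low_rank, rank=3)
L6 = pivoted_cholesky(full_rank, rank=6)
```

A rank-k matrix is recovered exactly after k steps, which is why the method suits covariances with fast-decaying spectra.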

Source code in src/gaussx/_primitives/_root.py
def root_decomposition(
    operator: lx.AbstractLinearOperator,
    rank: int = 50,
    method: RootMethod = "lanczos",
    key: jax.Array | None = None,
) -> RootDecomposition:
    r"""Compute a tall factor ``R`` such that ``R Rᵀ ≈ A``.

    Args:
        operator: Square symmetric positive-definite operator ``A``.
        rank: Number of retained directions. Ignored by ``"cholesky"``,
            which returns the exact full-rank factor.
        method: Decomposition method: ``"lanczos"``, ``"cholesky"``,
            ``"pivoted_cholesky"``, or ``"svd"``.
        key: PRNG key for random-start methods.

    Returns:
        A `RootDecomposition` with root shape ``(N, k)``.
    """
    _n, rank = _validate_square_rank(operator, rank, method=method)
    if isinstance(operator, lx.DiagonalLinearOperator):
        return RootDecomposition(_diagonal_root(operator, rank, method, inverse=False))
    if method == "cholesky":
        return RootDecomposition(_cholesky_root(operator))
    if method == "lanczos":
        return RootDecomposition(_eig_root(operator, rank, key, inverse=False))
    if method == "svd":
        return RootDecomposition(_svd_root(operator, rank, key, inverse=False))
    if method == "pivoted_cholesky":
        mat = _symmetric_operator_matrix(operator)
        return RootDecomposition(_pivoted_cholesky_root(mat, rank))
    raise ValueError(f"Unknown root decomposition method: {method!r}")

root_inv_decomposition(operator: lx.AbstractLinearOperator, rank: int = 50, method: RootMethod = 'lanczos', key: jax.Array | None = None) -> RootDecomposition

Compute a tall factor R⁻ such that R⁻ (R⁻)ᵀ ≈ A⁻¹.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| operator | AbstractLinearOperator | Square symmetric positive-definite operator A. | required |
| rank | int | Number of retained directions. Ignored by "cholesky", which returns the exact full-rank inverse-root factor. | 50 |
| method | RootMethod | Decomposition method: "lanczos", "cholesky", "pivoted_cholesky", or "svd". | 'lanczos' |
| key | Array \| None | PRNG key for random-start methods. | None |

Returns:

| Type | Description |
|---|---|
| RootDecomposition | A RootDecomposition with inverse-root shape (N, k). |
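The "cholesky" branch of root_inv_decomposition builds the inverse root from A = L Lᵀ as R⁻ = L⁻ᵀ, since R⁻ (R⁻)ᵀ = L⁻ᵀ L⁻¹ = (L Lᵀ)⁻¹ = A⁻¹. The snippet below applies the same steps to a small dense SPD matrix so the identity can be checked numerically:

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# Small SPD test matrix A = F F^T + 4 I.
F = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
A = F @ F.T + 4.0 * jnp.eye(4)

L = jnp.linalg.cholesky(A)                            # A = L L^T, L lower-triangular
L_inv = solve_triangular(L, jnp.eye(4), lower=True)   # L^{-1} via a triangular solve
R_inv = L_inv.T                                       # R^- = L^{-T}

# R^- (R^-)^T = L^{-T} L^{-1} = (L L^T)^{-1} = A^{-1}
A_inv_from_root = R_inv @ R_inv.T
```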

Source code in src/gaussx/_primitives/_root.py
def root_inv_decomposition(
    operator: lx.AbstractLinearOperator,
    rank: int = 50,
    method: RootMethod = "lanczos",
    key: jax.Array | None = None,
) -> RootDecomposition:
    r"""Compute a tall factor ``R⁻`` such that ``R⁻ (R⁻)ᵀ ≈ A⁻¹``.

    Args:
        operator: Square symmetric positive-definite operator ``A``.
        rank: Number of retained directions. Ignored by ``"cholesky"``,
            which returns the exact full-rank inverse-root factor.
        method: Decomposition method: ``"lanczos"``, ``"cholesky"``,
            ``"pivoted_cholesky"``, or ``"svd"``.
        key: PRNG key for random-start methods.

    Returns:
        A `RootDecomposition` with inverse-root shape ``(N, k)``.
    """
    n, rank = _validate_square_rank(operator, rank, method=method)
    if isinstance(operator, lx.DiagonalLinearOperator):
        return RootDecomposition(_diagonal_root(operator, rank, method, inverse=True))
    if method == "cholesky":
        L = _cholesky_root(operator)
        identity = jnp.eye(n, dtype=L.dtype)
        L_inv = jax.scipy.linalg.solve_triangular(L, identity, lower=True)
        return RootDecomposition(L_inv.T)
    if method == "lanczos":
        return RootDecomposition(_eig_root(operator, rank, key, inverse=True))
    if method == "svd":
        return RootDecomposition(_svd_root(operator, rank, key, inverse=True))
    if method == "pivoted_cholesky":
        # Dense fallback: forms A^{-1} explicitly via the lazy inverse
        # then runs pivoted Cholesky on it — O(N^3) time, O(N^2) memory.
        # Prefer ``method="lanczos"`` for large operators where structural
        # solves/matvecs are available.
        inv_mat = _symmetric_operator_matrix(inv(operator))
        return RootDecomposition(_pivoted_cholesky_root(inv_mat, rank))
    raise ValueError(f"Unknown inverse-root decomposition method: {method!r}")

Support types


SumKroneckerSqrt

Bases: SqrtOperator

Lazy Lanczos square-root operator for SumKronecker covariances.

Specialization of SqrtOperator that narrows the original field to a SumKronecker operator. Its mv method computes sqrt(A) v via matfree's Lanczos matrix-function product without materializing the full square root.
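The plain-JAX sketch below illustrates this behavior. It is not matfree's implementation; sumkron_matvec and lanczos_sqrt_mv are hypothetical helpers. Full reorthogonalization is a choice made here for stability on a tiny example. The Lanczos order is set to the operator size (6), so the result should match a dense eigendecomposition; in practice a much smaller order would be used.

```python
import jax
import jax.numpy as jnp


def sumkron_matvec(factors, v):
    """(sum_i A_i (x) B_i) v via reshapes, never forming a Kronecker product."""
    m, p = factors[0][0].shape[0], factors[0][1].shape[0]
    V = v.reshape(m, p)
    # Row-major identity: (A (x) B) vec(V) = vec(A V B^T).
    return sum(A @ V @ B.T for A, B in factors).reshape(-1)


def lanczos_sqrt_mv(matvec, v, order):
    """Approximate sqrt(A) v with `order` >= 2 Lanczos steps (hypothetical helper)."""
    beta0 = jnp.linalg.norm(v)
    qs, alphas, betas = [v / beta0], [], []
    for k in range(order):
        w = matvec(qs[k])
        alphas.append(jnp.dot(qs[k], w))
        # Full reorthogonalization against every Lanczos vector so far.
        Q = jnp.stack(qs, axis=1)
        w = w - Q @ (Q.T @ w)
        if k == order - 1:
            break
        betas.append(jnp.linalg.norm(w))
        qs.append(w / betas[-1])
    # Tridiagonal T = Q^T A Q; then f(A) v ~= ||v|| Q f(T) e_1 with f = sqrt.
    a, b = jnp.stack(alphas), jnp.stack(betas)
    T = jnp.diag(a) + jnp.diag(b, 1) + jnp.diag(b, -1)
    evals, evecs = jnp.linalg.eigh(T)
    f_T_e1 = evecs @ (jnp.sqrt(evals) * evecs[0, :])
    return beta0 * (jnp.stack(qs, axis=1) @ f_T_e1)


def spd(key, n):
    # Random symmetric positive-definite factor G G^T + I.
    G = jax.random.normal(key, (n, n))
    return G @ G.T + jnp.eye(n)


keys = jax.random.split(jax.random.PRNGKey(0), 5)
factors = [(spd(keys[0], 2), spd(keys[1], 3)), (spd(keys[2], 2), spd(keys[3], 3))]
v = jax.random.normal(keys[4], (6,))

approx = lanczos_sqrt_mv(lambda x: sumkron_matvec(factors, x), v, order=6)

# Dense reference, only feasible because the example is 6 x 6.
K = sum(jnp.kron(A, B) for A, B in factors)
lam, U = jnp.linalg.eigh(K)
exact = (U * jnp.sqrt(lam)) @ U.T @ v
```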

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| original | SumKronecker | The SumKronecker covariance to take the square root of. | required |
| lanczos_order | int | Number of Lanczos iterations; clamped to the operator size. | _DEFAULT_LANCZOS_ORDER |

Source code in src/gaussx/_primitives/_sqrt.py
class SumKroneckerSqrt(SqrtOperator):
    """Lazy Lanczos square-root operator for ``SumKronecker`` covariances.

    Specialization of `SqrtOperator` that narrows ``original`` to a
    `SumKronecker` operator. `mv` computes ``sqrt(A) v`` via
    matfree's Lanczos matrix-function product without materializing the full
    square root.

    Args:
        original: The ``SumKronecker`` covariance to take the square root of.
        lanczos_order: Number of Lanczos iterations; clamped to the operator
            size.
    """

    original: SumKronecker

    def __init__(
        self,
        original: SumKronecker,
        lanczos_order: int = _DEFAULT_LANCZOS_ORDER,
    ) -> None:
        super().__init__(original, lanczos_order=lanczos_order)

DenseFallbackWarning

Bases: UserWarning

Warning emitted when a structured primitive materialises an operator.
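Because DenseFallbackWarning subclasses UserWarning, the standard warnings filters can escalate or silence it. The sketch below assumes gaussx emits it through warnings.warn. So that the snippet runs without gaussx installed, it defines a local stand-in class and a hypothetical fake_primitive; in real code, import the actual class from wherever gaussx exports it.

```python
import warnings


class DenseFallbackWarning(UserWarning):
    """Local stand-in for gaussx's DenseFallbackWarning."""


def fake_primitive():
    # Hypothetical primitive that takes the dense path and warns about it.
    warnings.warn("materialising operator", DenseFallbackWarning)
    return 0.0


# Escalate dense fallbacks to errors, e.g. in a test that must stay structured.
with warnings.catch_warnings():
    warnings.simplefilter("error", DenseFallbackWarning)
    try:
        fake_primitive()
        escalated = False
    except DenseFallbackWarning:
        escalated = True

# Or silence them where the dense path is known to be acceptable.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("ignore", DenseFallbackWarning)
    fake_primitive()
```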

Source code in src/gaussx/_primitives/_cholesky.py
class DenseFallbackWarning(UserWarning):
    """Warning emitted when a structured primitive materialises an operator."""