Linear-Algebra Utilities

Numerically careful building blocks shared by the higher layers: robust factorizations, classical matrix identities in operator form, and batched / matrix-RHS solve helpers.

Robust factorization & hygiene

safe_cholesky retries with geometrically growing diagonal jitter inside a jax.lax.while_loop (JIT-compatible) when a matrix is not numerically positive-definite; symmetrize removes the floating-point asymmetry that accumulates in covariance updates.

Structured linear algebra and Gaussian primitives for JAX.

safe_cholesky(operator: lx.AbstractLinearOperator, *, initial_jitter: float = 1e-08, max_jitter: float = 0.01, max_retries: int = 5, growth_factor: float = 10.0) -> Float[Array, 'N N']

Cholesky decomposition with adaptive jitter for near-singular matrices.

The first attempt routes through gaussx._primitives.cholesky, so structured operators (DiagonalLinearOperator, BlockDiag, Kronecker, BlockTriDiag) keep their structure on the happy path. If the result contains NaNs (the matrix is not numerically positive-definite), retries with geometrically increasing diagonal jitter: cholesky(A + eps * I) where eps starts at initial_jitter and grows by growth_factor each retry, up to max_jitter. Jittering in general destroys structure, so retries operate on the dense matrix form.

Uses jax.lax.while_loop internally so the function is fully JIT-compatible.
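As a rough illustration of the retry schedule described above, the following plain NumPy loop mimics the jitter growth with the default parameters. It is a sketch under stated assumptions, not the gaussx implementation: `jittered_cholesky` is a hypothetical helper, and NumPy signals failure by raising `LinAlgError`, whereas the JAX version inspects the factor for NaNs inside `jax.lax.while_loop`.

```python
import numpy as np


def jittered_cholesky(A, initial_jitter=1e-8, max_jitter=1e-2,
                      max_retries=5, growth_factor=10.0):
    """Hypothetical NumPy analogue of the retry loop.

    Returns the factor and the list of jitter values that were tried.
    """
    tried = []
    try:
        return np.linalg.cholesky(A), tried  # un-jittered first attempt
    except np.linalg.LinAlgError:
        pass
    eps = initial_jitter
    eye = np.eye(A.shape[0], dtype=A.dtype)
    for _ in range(max_retries):
        tried.append(eps)
        try:
            return np.linalg.cholesky(A + eps * eye), tried
        except np.linalg.LinAlgError:
            # Grow geometrically, clamped at max_jitter after growth.
            eps = min(eps * growth_factor, max_jitter)
    # All attempts failed: mirror the NaN-filled result described above.
    return np.full_like(A, np.nan), tried


# Slightly indefinite matrix: the trailing pivot is about -1e-6.
A = np.array([[1.0, 1.0], [1.0, 1.0 - 1e-6]])
L, tried = jittered_cholesky(A)
```

Here the factorization fails without jitter and at the first two jitter levels, then succeeds at roughly `1e-6`. With the defaults, the five retries span `1e-8` through `1e-4`, so `max_jitter = 1e-2` only acts as a clamp when `max_retries` or `growth_factor` is raised.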

Parameters:

Name Type Description Default
operator AbstractLinearOperator

A lineax linear operator whose Cholesky factor is required. Must be square and positive-definite.

required
initial_jitter float

Starting jitter magnitude added to the diagonal.

1e-08
max_jitter float

Upper bound on jitter (clamped after growth).

0.01
max_retries int

Maximum number of jittered retries after the initial attempt.

5
growth_factor float

Multiplicative factor applied to jitter each retry.

10.0

Returns:

Type Description
Float[Array, 'N N']

Lower-triangular Cholesky factor as a dense array. If all attempts fail the result will contain NaNs. This is intentional: JAX cannot raise exceptions inside jit-traced code, so callers should check for NaNs when robustness matters.

Source code in src/gaussx/_linalg/_safe_cholesky.py
def safe_cholesky(
    operator: lx.AbstractLinearOperator,
    *,
    initial_jitter: float = 1e-8,
    max_jitter: float = 1e-2,
    max_retries: int = 5,
    growth_factor: float = 10.0,
) -> Float[Array, "N N"]:
    """Cholesky decomposition with adaptive jitter for near-singular matrices.

    The first attempt routes through `gaussx._primitives.cholesky`, so
    structured operators (``DiagonalLinearOperator``, ``BlockDiag``,
    ``Kronecker``, ``BlockTriDiag``) keep their structure on the happy path.
    If the result contains NaNs (the matrix is not numerically
    positive-definite), retries with geometrically increasing diagonal jitter:
    ``cholesky(A + eps * I)`` where *eps* starts at ``initial_jitter`` and
    grows by ``growth_factor`` each retry, up to ``max_jitter``. Jittering in
    general destroys structure, so retries operate on the dense matrix form.

    Uses ``jax.lax.while_loop`` internally so the function is fully
    JIT-compatible.

    Args:
        operator: A lineax linear operator whose Cholesky factor is
            required. Must be square and positive-definite.
        initial_jitter: Starting jitter magnitude added to the diagonal.
        max_jitter: Upper bound on jitter (clamped after growth).
        max_retries: Maximum number of jittered retries after the initial
            attempt.
        growth_factor: Multiplicative factor applied to jitter each retry.

    Returns:
        Lower-triangular Cholesky factor as a dense array.
        If all attempts fail the result will contain NaNs — this is
        intentional: JAX cannot raise exceptions inside ``jit``-traced
        code, so callers should check for NaNs when robustness matters.
    """
    # Structured first attempt — preserves operator structure where possible.
    L0 = _chol_matrix(operator)
    has_nan0 = jnp.any(jnp.isnan(L0))

    A = operator.as_matrix()
    n = A.shape[0]
    eye = jnp.eye(n, dtype=A.dtype)

    # State: (L, jitter, retry_count, still_bad)
    init_state = (L0, initial_jitter, 0, has_nan0)

    def _cond(state):
        _, _, count, still_bad = state
        return still_bad & (count < max_retries)

    def _body(state):
        _, eps, count, _ = state
        jittered = lx.MatrixLinearOperator(A + eps * eye, lx.positive_semidefinite_tag)
        L = _chol_matrix(jittered)
        has_nan = jnp.any(jnp.isnan(L))
        next_eps = jnp.minimum(eps * growth_factor, max_jitter)
        return (L, next_eps, count + 1, has_nan)

    L_final, _, _, _ = jax.lax.while_loop(_cond, _body, init_state)
    return L_final

symmetrize(mat: Float[Array, '... n n']) -> Float[Array, '... n n']

Symmetrize the trailing two axes: 0.5 * (X + X^T).

Eliminates residual floating-point asymmetry after covariance updates. Works on batched stacks (..., n, n).
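The effect on a batched stack can be sketched in plain NumPy, using the same expression as the function body. The drift added here is artificial and only stands in for the asymmetry that real update formulas accumulate; all variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

# Batch of 4 covariance-like matrices, shape (4, 3, 3).
G = rng.normal(size=(4, 3, 3))
cov = G @ np.swapaxes(G, -1, -2)

# Simulate the small asymmetry that accumulates in update formulas.
drifted = cov + 1e-7 * rng.normal(size=cov.shape)

# Same expression as symmetrize: average with the transpose of the
# trailing two axes, leaving the batch axis untouched.
fixed = 0.5 * (drifted + np.swapaxes(drifted, -1, -2))
```

Because floating-point addition is commutative, `fixed` is symmetric bit for bit, not just approximately.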

Parameters:

Name Type Description Default
mat Float[Array, '... n n']

Square matrix or batched stack of square matrices.

required

Returns:

Type Description
Float[Array, '... n n']

The symmetric part of mat.

Source code in src/gaussx/_linalg/_symmetrize.py
def symmetrize(mat: Float[Array, "... n n"]) -> Float[Array, "... n n"]:
    """Symmetrize the trailing two axes: ``0.5 * (X + X^T)``.

    Eliminates residual floating-point asymmetry after covariance
    updates. Works on batched stacks ``(..., n, n)``.

    Args:
        mat: Square matrix or batched stack of square matrices.

    Returns:
        The symmetric part of ``mat``.
    """
    return 0.5 * (mat + jnp.swapaxes(mat, -1, -2))

Matrix identities

Woodbury, Schur complements, and conditional (Schur-complement) variances — the identities behind every Gaussian conditioning step, exposed directly so recipes never re-derive them.

woodbury_solve(base: lx.AbstractLinearOperator, U: Float[Array, 'N k'], D: Float[Array, ' k'], b: Float[Array, ' N']) -> Float[Array, ' N']

Standalone Woodbury identity solve: (L + U diag(D) U^T)^{-1} b.

Convenience function for cases where the user has the components but doesn't want to construct a LowRankUpdate operator.

Uses the identity:

(L + U D U^T)^{-1} b = L^{-1}b - L^{-1}U C^{-1} U^T L^{-1}b

where C = D^{-1} + U^T L^{-1} U is the (k, k) capacitance matrix.
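The identity can be checked numerically with a small NumPy sketch. A diagonal base operator is assumed here purely so that `L^{-1}` is an elementwise division; the variable names are illustrative and nothing below calls gaussx.

```python
import numpy as np

rng = np.random.default_rng(0)
N, k = 6, 2

base_diag = rng.uniform(1.0, 2.0, size=N)  # base operator L = diag(base_diag)
U = rng.normal(size=(N, k))                # low-rank factor
D = rng.uniform(0.5, 1.5, size=k)          # diagonal scaling
b = rng.normal(size=N)

# Direct dense solve of (L + U diag(D) U^T) x = b.
x_direct = np.linalg.solve(np.diag(base_diag) + U @ np.diag(D) @ U.T, b)

# Woodbury: only L^{-1} applications and one (k, k) solve.
Linv_b = b / base_diag
Linv_U = U / base_diag[:, None]
C = np.diag(1.0 / D) + U.T @ Linv_U        # capacitance matrix, shape (k, k)
x_woodbury = Linv_b - Linv_U @ np.linalg.solve(C, U.T @ Linv_b)
```

The only dense solve is against the `(k, k)` capacitance matrix, which is where the savings come from when `k` is much smaller than `N`.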

Parameters:

Name Type Description Default
base AbstractLinearOperator

Base operator L, shape (N, N).

required
U Float[Array, 'N k']

Low-rank factor, shape (N, k).

required
D Float[Array, ' k']

Diagonal scaling, shape (k,).

required
b Float[Array, ' N']

Right-hand side, shape (N,).

required

Returns:

Type Description
Float[Array, ' N']

Solution x, shape (N,).

Source code in src/gaussx/_linalg/_woodbury.py
def woodbury_solve(
    base: lx.AbstractLinearOperator,
    U: Float[Array, "N k"],
    D: Float[Array, " k"],
    b: Float[Array, " N"],
) -> Float[Array, " N"]:
    """Standalone Woodbury identity solve: ``(L + U diag(D) U^T)^{-1} b``.

    Convenience function for cases where the user has the components
    but doesn't want to construct a ``LowRankUpdate`` operator.

    Uses the identity:

        (L + U D U^T)^{-1} b = L^{-1}b - L^{-1}U C^{-1} U^T L^{-1}b

    where ``C = D^{-1} + U^T L^{-1} U`` is the ``(k, k)`` capacitance
    matrix.

    Args:
        base: Base operator L, shape ``(N, N)``.
        U: Low-rank factor, shape ``(N, k)``.
        D: Diagonal scaling, shape ``(k,)``.
        b: Right-hand side, shape ``(N,)``.

    Returns:
        Solution x, shape ``(N,)``.
    """
    from gaussx._operators._low_rank_update import LowRankUpdate

    op = LowRankUpdate(base, U, D)
    return solve(op, b)

schur_complement(K_XX: lx.AbstractLinearOperator, K_XZ: Float[Array, 'N M'], K_ZZ: lx.AbstractLinearOperator) -> LowRankUpdate

Schur complement: K_XX - K_XZ @ K_ZZ^{-1} @ K_ZX.

Central to GP conditional distributions. Returns a LowRankUpdate operator so that downstream operations (solve, logdet) can exploit the low-rank structure.
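To see why a low-rank update can represent the Schur complement, here is a NumPy sketch that builds the `base + U diag(d) V^T` form and compares it with the dense expression. The factor names `U`, `d`, `V` follow the convention used in the source comments; the random joint covariance is an assumption made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
N, M = 5, 3

# Joint positive-definite covariance over N test and M inducing points.
G = rng.normal(size=(N + M, N + M + 2))
K = G @ G.T
K_XX, K_XZ, K_ZZ = K[:N, :N], K[:N, N:], K[N:, N:]

# Dense reference: K_XX - K_XZ K_ZZ^{-1} K_ZX.
schur_dense = K_XX - K_XZ @ np.linalg.solve(K_ZZ, K_XZ.T)

# Low-rank-update form: base + U diag(d) V^T.
U = K_XZ                                  # (N, M)
d = -np.ones(M)                           # minus sign of the subtraction
V = np.linalg.solve(K_ZZ, K_XZ.T).T       # (N, M), equals K_XZ K_ZZ^{-1}
schur_low_rank = K_XX + U @ np.diag(d) @ V.T
```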

Parameters:

Name Type Description Default
K_XX AbstractLinearOperator

Prior covariance, shape (N, N).

required
K_XZ Float[Array, 'N M']

Cross-covariance, shape (N, M).

required
K_ZZ AbstractLinearOperator

Inducing covariance, shape (M, M).

required

Returns:

Type Description
LowRankUpdate

A LowRankUpdate representing K_XX - K_XZ K_ZZ^{-1} K_ZX.

Source code in src/gaussx/_linalg/_schur.py
def schur_complement(
    K_XX: lx.AbstractLinearOperator,
    K_XZ: Float[Array, "N M"],
    K_ZZ: lx.AbstractLinearOperator,
) -> LowRankUpdate:
    """Schur complement: ``K_XX - K_XZ @ K_ZZ^{-1} @ K_ZX``.

    Central to GP conditional distributions. Returns a
    ``LowRankUpdate`` operator so that downstream operations
    (solve, logdet) can exploit the low-rank structure.

    Args:
        K_XX: Prior covariance, shape ``(N, N)``.
        K_XZ: Cross-covariance, shape ``(N, M)``.
        K_ZZ: Inducing covariance, shape ``(M, M)``.

    Returns:
        A ``LowRankUpdate`` representing K_XX - K_XZ K_ZZ^{-1} K_ZX.
    """
    # W^T = K_XZ K_ZZ^{-1} (K_ZZ symmetric): solve K_ZZ w_j = K_XZ[j, :]^T
    # for each row j of K_XZ, giving an (N, M) array.
    M = K_XZ.shape[1]

    from gaussx._linalg._linalg import solve_rows

    W_T = solve_rows(K_ZZ, K_XZ)  # (N, M)

    # LowRankUpdate represents base + U @ diag(d) @ V^T. With
    # U = K_XZ (N, M), d = -ones(M), V = W^T (N, M) this is
    # K_XX - K_XZ @ K_ZZ^{-1} @ K_XZ^T.
    d = -jnp.ones(M)
    return LowRankUpdate(K_XX, K_XZ, d, W_T)

conditional_variance(K_XX_diag: Float[Array, ' N'], K_XZ: Float[Array, 'N M'] | lx.AbstractLinearOperator | None = None, A_X: Float[Array, 'N M'] | None = None, S_u: lx.AbstractLinearOperator | None = None) -> Float[Array, ' N']

Predictive variance: Schur complement diagonal plus optional variational correction.

Computes the diagonal of the conditional covariance:

diag(K_XX - A_X K_XZ^T) + diag(A_X S_u A_X^T)

where A_X = K_XZ K_ZZ^{-1} is the projection matrix. Negative base variances are clamped to zero.

Without S_u this is the exact GP predictive variance (diagonal of the Schur complement). With S_u it is the sparse GP variational predictive variance that adds a variational covariance correction.
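The two diagonals in the formula can be obtained from row-wise contractions, without forming any `(N, N)` product. The NumPy sketch below mirrors that computation and checks it against a dense reference; the random covariances are assumptions made for the example, and `S_u` is a plain array here rather than an operator.

```python
import numpy as np

rng = np.random.default_rng(1)
N, M = 5, 3

G = rng.normal(size=(N + M, N + M + 2))
K = G @ G.T
K_XX, K_XZ, K_ZZ = K[:N, :N], K[:N, N:], K[N:, N:]

A_X = np.linalg.solve(K_ZZ, K_XZ.T).T     # K_XZ K_ZZ^{-1}, shape (N, M)
R = rng.normal(size=(M, M))
S_u = R @ R.T                             # variational covariance, (M, M)

# diag(A_X K_XZ^T) and diag(A_X S_u A_X^T) as row-wise sums.
base = np.clip(np.diag(K_XX) - np.sum(A_X * K_XZ, axis=1), 0.0, None)
var = base + np.sum((A_X @ S_u) * A_X, axis=1)

# Dense reference, for comparison only.
var_dense = np.diag(K_XX - A_X @ K_XZ.T + A_X @ S_u @ A_X.T)
```

The clamp only matters when rounding pushes a base variance slightly below zero; in this well-conditioned example it is a no-op.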

Parameters:

Name Type Description Default
K_XX_diag Float[Array, ' N']

Prior diagonal variances, shape (N,).

required
K_XZ Float[Array, 'N M'] | AbstractLinearOperator | None

Cross-covariance matrix, shape (N, M).

None
A_X Float[Array, 'N M'] | None

Projection matrix K_XZ K_ZZ^{-1}, shape (N, M).

None
S_u AbstractLinearOperator | None

Optional variational covariance, shape (M, M). When provided, adds diag(A_X S_u A_X^T) to the base variance.

None

Returns:

Type Description
Float[Array, ' N']

Predictive variances, shape (N,).

Note

For one release the legacy three-positional-argument form conditional_variance(base_diag, A_X, S_u) (where base_diag was already the Schur diagonal, A_X was the projection, and S_u was the variational covariance) is still accepted. It is detected when the third positional argument (now the A_X slot) is a lineax.AbstractLinearOperator, the second is a JAX array, and S_u is left at None. Such calls emit a DeprecationWarning and compute clip(base_diag, 0) + diag(A_X S_u A_X^T) without the Schur subtraction.

Source code in src/gaussx/_linalg/_schur.py
def conditional_variance(
    K_XX_diag: Float[Array, " N"],
    K_XZ: Float[Array, "N M"] | lx.AbstractLinearOperator | None = None,
    A_X: Float[Array, "N M"] | None = None,
    S_u: lx.AbstractLinearOperator | None = None,
) -> Float[Array, " N"]:
    """Predictive variance: Schur complement diagonal plus optional variational
    correction.

    Computes the diagonal of the conditional covariance:

        diag(K_XX - A_X K_XZ^T) + diag(A_X S_u A_X^T)

    where ``A_X = K_XZ K_ZZ^{-1}`` is the projection matrix.
    Negative base variances are clamped to zero.

    Without ``S_u`` this is the exact GP predictive variance (diagonal of the
    Schur complement).  With ``S_u`` it is the sparse GP variational predictive
    variance that adds a variational covariance correction.

    Args:
        K_XX_diag: Prior diagonal variances, shape ``(N,)``.
        K_XZ: Cross-covariance matrix, shape ``(N, M)``.
        A_X: Projection matrix ``K_XZ K_ZZ^{-1}``, shape ``(N, M)``.
        S_u: Optional variational covariance, shape ``(M, M)``.  When
            provided, adds ``diag(A_X S_u A_X^T)`` to the base variance.

    Returns:
        Predictive variances, shape ``(N,)``.

    Note:
        For one release the legacy three-positional-argument form
        ``conditional_variance(base_diag, A_X, S_u)`` (where
        ``base_diag`` was the *Schur* diagonal already, ``A_X`` was the
        projection, and ``S_u`` was the variational covariance) is
        still accepted: it is detected when the third positional
        argument is a `lineax.AbstractLinearOperator` (the old
        ``S_u`` slot type), the second is a JAX array, and the
        ``S_u`` keyword is left at ``None``. Such calls emit a
        `DeprecationWarning` and compute
        ``clip(base_diag, 0) + diag(A_X S_u A_X^T)`` without the
        Schur subtraction.
    """
    # Backwards-compat: the legacy 3-positional signature was
    # ``conditional_variance(base_diag, A_X, S_u)``. The pre-#152 docs
    # called it as ``conditional_variance(adjusted_diag, A, S_u)`` with
    # the second arg as the projection matrix and the third as the
    # variational covariance operator. We can detect that pattern by
    # the third positional being an AbstractLinearOperator (it
    # corresponds to the old ``S_u``) while ``S_u`` is left at the
    # default ``None``.
    if (
        S_u is None
        and isinstance(A_X, lx.AbstractLinearOperator)
        and isinstance(K_XZ, jax.Array)
    ):
        import warnings

        warnings.warn(
            "conditional_variance(base_diag, A_X, S_u) is deprecated; "
            "use conditional_variance(K_XX_diag, K_XZ, A_X, S_u=S_u). "
            "The legacy form treats the first argument as the "
            "precomputed Schur diagonal and skips the K_XZ-based "
            "subtraction.",
            DeprecationWarning,
            stacklevel=2,
        )
        legacy_A_X = K_XZ  # second positional was the projection matrix
        legacy_S_u: lx.AbstractLinearOperator = A_X  # third was S_u
        S_mat = legacy_S_u.as_matrix()
        AS = legacy_A_X @ S_mat
        diag_ASAt = jnp.sum(AS * legacy_A_X, axis=1)
        return jnp.clip(K_XX_diag, 0.0) + diag_ASAt

    if K_XZ is None or A_X is None:
        raise TypeError(
            "conditional_variance requires K_XZ and A_X (or the legacy "
            f"(base_diag, A_X, S_u) call). Got K_XZ={K_XZ!r}, "
            f"A_X={A_X!r}, S_u={S_u!r}."
        )
    if not isinstance(K_XZ, jax.Array):
        raise TypeError(
            "K_XZ must be a jax array; the legacy 3-positional form "
            "passes an AbstractLinearOperator as the third positional "
            "argument."
        )

    # Diagonal of Schur complement: diag(K_XX - K_XZ K_ZZ^{-1} K_ZX)
    base = jnp.clip(K_XX_diag - jnp.sum(A_X * K_XZ, axis=1), 0.0)
    if S_u is None:
        return base
    # Variational correction: diag(A_X S_u A_X^T)
    S_mat = S_u.as_matrix()
    AS = A_X @ S_mat  # (N, M)
    diag_ASAt = jnp.sum(AS * A_X, axis=1)  # (N,)
    return base + diag_ASAt

diag_conditional_variance(K_XX_diag: Float[Array, ' N'], K_XZ: Float[Array, 'N M'], A_X: Float[Array, 'N M']) -> Float[Array, ' N']

Diagonal of Schur complement: diag(K_XX - A_X K_ZX), where K_ZX = K_XZ^T.

Thin wrapper around gaussx.conditional_variance (defined in gaussx._linalg._schur.conditional_variance) without a variational covariance. Use gaussx.conditional_variance directly when a variational correction S_u is also needed.

Parameters:

Name Type Description Default
K_XX_diag Float[Array, ' N']

Prior diagonal variances, shape (N,).

required
K_XZ Float[Array, 'N M']

Cross-covariance, shape (N, M).

required
A_X Float[Array, 'N M']

Projection matrix K_XZ K_ZZ^{-1}, shape (N, M).

required

Returns:

Type Description
Float[Array, ' N']

Conditional variances, shape (N,).

Source code in src/gaussx/_linalg/_linalg.py
def diag_conditional_variance(
    K_XX_diag: Float[Array, " N"],
    K_XZ: Float[Array, "N M"],
    A_X: Float[Array, "N M"],
) -> Float[Array, " N"]:
    """Diagonal of Schur complement: ``diag(K_XX - A K_ZX)``.

    Thin wrapper around `gaussx.conditional_variance` (defined in
    `gaussx._linalg._schur.conditional_variance`) without a
    variational covariance.  Use `gaussx.conditional_variance`
    directly when a variational correction ``S_u`` is also needed.

    Args:
        K_XX_diag: Prior diagonal variances, shape ``(N,)``.
        K_XZ: Cross-covariance, shape ``(N, M)``.
        A_X: Projection matrix ``K_XZ K_ZZ^{-1}``, shape ``(N, M)``.

    Returns:
        Conditional variances, shape ``(N,)``.
    """
    return _conditional_variance(K_XX_diag, K_XZ, A_X)

cov_transform(J: Float[Array, 'M N'] | lx.AbstractLinearOperator, cov_operator: lx.AbstractLinearOperator) -> lx.AbstractLinearOperator

cov_transform(J: Float[Array, 'M N'], cov_operator: lx.AbstractLinearOperator) -> lx.MatrixLinearOperator
cov_transform(J: lx.AbstractLinearOperator, cov_operator: lx.AbstractLinearOperator) -> lx.AbstractLinearOperator

Covariance propagation through a linear map: J @ Sigma @ J^T.

Used in error propagation, Kalman filter updates, and first-order uncertainty propagation.

Exploits structure where it can:

  • Operator-valued J: routes through sandwich, which preserves matched Kronecker / BlockDiag structure and avoids materialising Sigma when either J or cov_operator is diagonal.
  • Diagonal cov_operator (dense J): computes (J * d) @ J^T directly, skipping the (N, N) materialization of Sigma.

Otherwise materializes Sigma and forms the dense product. The returned operator is tagged symmetric (and positive-semidefinite when the input is).
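The diagonal shortcut mentioned above amounts to scaling the columns of `J` before the contraction. A minimal NumPy sketch, with illustrative names and random inputs:

```python
import numpy as np

rng = np.random.default_rng(0)
M, N = 3, 5
J = rng.normal(size=(M, N))
d = rng.uniform(0.1, 1.0, size=N)     # diagonal of Sigma

# Scale the columns of J by d, then contract: no (N, N) matrix is formed.
fast = (J * d[None, :]) @ J.T

# Reference that materializes Sigma = diag(d).
dense = J @ np.diag(d) @ J.T
```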

Parameters:

Name Type Description Default
J Float[Array, 'M N'] | AbstractLinearOperator

Jacobian or linear map, shape (M, N) — array or operator.

required
cov_operator AbstractLinearOperator

Input covariance, shape (N, N).

required

Returns:

Type Description
AbstractLinearOperator

Transformed covariance operator, shape (M, M). For operator-valued J the structural class of the return type follows sandwich; otherwise it is a lineax.MatrixLinearOperator.

Source code in src/gaussx/_linalg/_linalg.py
def cov_transform(
    J: Float[Array, "M N"] | lx.AbstractLinearOperator,
    cov_operator: lx.AbstractLinearOperator,
) -> lx.AbstractLinearOperator:
    """Covariance propagation through a linear map: ``J @ Sigma @ J^T``.

    Used in error propagation, Kalman filter updates, and
    first-order uncertainty propagation.

    Exploits structure where it can:

    - **Operator-valued** ``J``: routes through `sandwich`, which
      preserves matched ``Kronecker`` / ``BlockDiag`` structure and
      avoids materialising ``Sigma`` when either ``J`` or ``cov_operator``
      is diagonal.
    - **Diagonal** ``cov_operator`` (dense ``J``): computes
      ``(J * d) @ J^T`` directly, skipping the ``(N, N)``
      materialization of ``Sigma``.

    Otherwise materializes ``Sigma`` and forms the dense product. The
    returned operator is tagged symmetric (and positive-semidefinite
    when the input is).

    Args:
        J: Jacobian or linear map, shape ``(M, N)`` — array or operator.
        cov_operator: Input covariance, shape ``(N, N)``.

    Returns:
        Transformed covariance operator, shape ``(M, M)``. For
        operator-valued ``J`` the structural class of the return type
        follows `sandwich`; otherwise it is a
        `lineax.MatrixLinearOperator`.
    """
    if isinstance(J, lx.AbstractLinearOperator):
        return sandwich(J, cov_operator)

    tags = _sandwich_tags(cov_operator)

    if isinstance(cov_operator, lx.DiagonalLinearOperator):
        d = lx.diagonal(cov_operator)
        result = (J * d[None, :]) @ J.T
        return lx.MatrixLinearOperator(result, tags)

    Sigma = cov_operator.as_matrix()
    result = J @ Sigma @ J.T
    return lx.MatrixLinearOperator(result, tags)

sandwich(A: lx.AbstractLinearOperator, P: lx.AbstractLinearOperator) -> lx.AbstractLinearOperator

Return A @ P @ A.T exploiting compatible operator structure.

Parameters:

Name Type Description Default
A AbstractLinearOperator

Linear map with shape (M, N).

required
P AbstractLinearOperator

Covariance operator with shape (N, N).

required

Returns:

Type Description
AbstractLinearOperator

Transformed covariance operator with shape (M, M).

Examples:

A = gaussx.Kronecker(A1, A2)
P = gaussx.Kronecker(P1, P2)
S = gaussx.sandwich(A, P)
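The Kronecker case rests on the mixed-product property, which lets each factor pair be sandwiched separately. The following NumPy sketch checks that property on small dense matrices; it does not call gaussx, and `random_psd` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)


def random_psd(n):
    """Hypothetical helper: random symmetric PSD matrix of size n."""
    G = rng.normal(size=(n, n))
    return G @ G.T


A1, A2 = rng.normal(size=(2, 3)), rng.normal(size=(4, 2))
P1, P2 = random_psd(3), random_psd(2)

# Factor-wise sandwiches, then one Kronecker product of the small results.
structured = np.kron(A1 @ P1 @ A1.T, A2 @ P2 @ A2.T)

# Dense reference on the full (8, 6) map and (6, 6) covariance.
A, P = np.kron(A1, A2), np.kron(P1, P2)
dense = A @ P @ A.T
```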
Source code in src/gaussx/_linalg/_linalg.py
def sandwich(
    A: lx.AbstractLinearOperator,
    P: lx.AbstractLinearOperator,
) -> lx.AbstractLinearOperator:
    """Return ``A @ P @ A.T`` exploiting compatible operator structure.

    Args:
        A: Linear map with shape ``(M, N)``.
        P: Covariance operator with shape ``(N, N)``.

    Returns:
        Transformed covariance operator with shape ``(M, M)``.

    Examples:
        ```python
        A = gaussx.Kronecker(A1, A2)
        P = gaussx.Kronecker(P1, P2)
        S = gaussx.sandwich(A, P)
        ```
    """
    _check_sandwich_shapes(A, P)
    tags = _sandwich_tags(P)

    from gaussx._operators._block_diag import BlockDiag
    from gaussx._operators._kronecker import Kronecker

    if (
        isinstance(A, Kronecker)
        and isinstance(P, Kronecker)
        and len(A.operators) == len(P.operators)
        and all(
            a.in_size() == p.in_size() and p.in_size() == p.out_size()
            for a, p in zip(A.operators, P.operators, strict=True)
        )
    ):
        return Kronecker(
            *(sandwich(a, p) for a, p in zip(A.operators, P.operators, strict=True)),
            tags=tags,
        )

    if (
        isinstance(A, BlockDiag)
        and isinstance(P, BlockDiag)
        and len(A.operators) == len(P.operators)
        and all(
            a.in_size() == p.in_size() and p.in_size() == p.out_size()
            for a, p in zip(A.operators, P.operators, strict=True)
        )
    ):
        return BlockDiag(
            *(sandwich(a, p) for a, p in zip(A.operators, P.operators, strict=True)),
            tags=tags,
        )

    if isinstance(A, lx.DiagonalLinearOperator):
        d = lx.diagonal(A)
        if isinstance(P, lx.DiagonalLinearOperator):
            return lx.TaggedLinearOperator(
                lx.DiagonalLinearOperator(d * lx.diagonal(P) * d),
                tags,
            )
        P_mat = P.as_matrix()
        return lx.MatrixLinearOperator(d[:, None] * P_mat * d[None, :], tags)

    if isinstance(P, lx.DiagonalLinearOperator):
        d = lx.diagonal(P)
        A_mat = A.as_matrix()
        return lx.MatrixLinearOperator((A_mat * d[None, :]) @ A_mat.T, tags)

    A_mat = A.as_matrix()
    P_mat = P.as_matrix()
    return lx.MatrixLinearOperator(A_mat @ P_mat @ A_mat.T, tags)

trace_product(A: lx.AbstractLinearOperator, B: lx.AbstractLinearOperator) -> Float[Array, '']

Trace of a matrix product: tr(A @ B) with structural dispatch.

Uses operator structure where possible to avoid materialization:

  • Both diagonal: sum(diag(A) * diag(B)).
  • Diagonal × general (or vice versa): contract the diagonal with the diagonal of the other operator (no full materialization).
  • Matched gaussx.BlockDiag (same block sizes): sum of per-block trace_product.
  • Matched gaussx.Kronecker (same factor structure): prod_i tr(A_i @ B_i).
  • Otherwise falls back to sum(A * B^T) on the materialized matrices, an O(N²) elementwise reduction that avoids forming the product A @ B.
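The identities behind these dispatch rules can be verified with a few lines of NumPy. This is a sketch on dense arrays with illustrative names, not a call into gaussx.

```python
import numpy as np

rng = np.random.default_rng(0)
A, B = rng.normal(size=(4, 4)), rng.normal(size=(4, 4))

# Dense fallback: tr(A @ B) = sum_ij A_ij B_ji, no matrix product needed.
tr_dense = np.trace(A @ B)
tr_elementwise = np.sum(A * B.T)

# Diagonal x general: only the diagonal of B contributes.
dvec = rng.normal(size=4)
tr_diag = np.sum(dvec * np.diag(B))
tr_diag_dense = np.trace(np.diag(dvec) @ B)

# Matched Kronecker factors: the trace factorizes over the factors.
A1, A2 = rng.normal(size=(2, 2)), rng.normal(size=(3, 3))
B1, B2 = rng.normal(size=(2, 2)), rng.normal(size=(3, 3))
tr_kron = np.trace(A1 @ B1) * np.trace(A2 @ B2)
tr_kron_dense = np.trace(np.kron(A1, A2) @ np.kron(B1, B2))
```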

Parameters:

Name Type Description Default
A AbstractLinearOperator

Linear operator, shape (N, N).

required
B AbstractLinearOperator

Linear operator, shape (N, N).

required

Returns:

Type Description
Float[Array, '']

Scalar tr(A @ B).

Source code in src/gaussx/_linalg/_linalg.py
def trace_product(
    A: lx.AbstractLinearOperator,
    B: lx.AbstractLinearOperator,
) -> Float[Array, ""]:
    """Trace of a matrix product: ``tr(A @ B)`` with structural dispatch.

    Uses operator structure where possible to avoid materialization:

    - **Both diagonal**: ``sum(diag(A) * diag(B))``.
    - **Diagonal × general** (or vice versa): contract the diagonal with
      the diagonal of the other operator (no full materialization).
    - **Matched** `gaussx.BlockDiag` (same block sizes): sum of
      per-block ``trace_product``.
    - **Matched** `gaussx.Kronecker` (same factor structure):
      ``prod_i tr(A_i @ B_i)``.
    - Otherwise falls back to ``sum(A * B^T)`` on the materialized
      matrices — the same O(N²) cost the previous implementation paid.

    Args:
        A: Linear operator, shape ``(N, N)``.
        B: Linear operator, shape ``(N, N)``.

    Returns:
        Scalar ``tr(A @ B)``.
    """
    from gaussx._operators._block_diag import BlockDiag
    from gaussx._operators._kronecker import Kronecker
    from gaussx._primitives._diag import diag

    # Both diagonal: O(N) inner product of diagonals.
    if isinstance(A, lx.DiagonalLinearOperator) and isinstance(
        B, lx.DiagonalLinearOperator
    ):
        return jnp.sum(lx.diagonal(A) * lx.diagonal(B))

    # Diagonal × anything: tr(D @ B) = sum(diag(D) * diag(B)).
    if isinstance(A, lx.DiagonalLinearOperator):
        return jnp.sum(lx.diagonal(A) * diag(B))
    if isinstance(B, lx.DiagonalLinearOperator):
        return jnp.sum(diag(A) * lx.diagonal(B))

    # Matched BlockDiag: tr(blockdiag(A_i) @ blockdiag(B_i)) = sum_i tr(A_i @ B_i).
    if (
        isinstance(A, BlockDiag)
        and isinstance(B, BlockDiag)
        and len(A.operators) == len(B.operators)
        and all(
            a.in_size() == b.in_size()
            for a, b in zip(A.operators, B.operators, strict=True)
        )
    ):
        parts = [
            trace_product(a, b) for a, b in zip(A.operators, B.operators, strict=True)
        ]
        return jnp.sum(jnp.stack(parts))

    # Matched Kronecker: tr((A1⊗A2⊗…) @ (B1⊗B2⊗…)) = prod_i tr(A_i @ B_i).
    if (
        isinstance(A, Kronecker)
        and isinstance(B, Kronecker)
        and len(A.operators) == len(B.operators)
        and all(
            a.in_size() == b.in_size()
            for a, b in zip(A.operators, B.operators, strict=True)
        )
    ):
        parts = [
            trace_product(a, b) for a, b in zip(A.operators, B.operators, strict=True)
        ]
        return jnp.prod(jnp.stack(parts))

    return _trace_product_dense(A, B)

diag_inv(operator: lx.AbstractLinearOperator, *, method: str = 'auto', num_probes: int = 30, key: jax.Array | None = None, solver: AbstractSolveStrategy | None = None) -> Float[Array, ' N']

Compute the diagonal of the inverse of a linear operator.

Returns diag(A⁻¹) without forming the full inverse matrix.

Parameters:

Name Type Description Default
operator AbstractLinearOperator

A linear operator representing A.

required
method str

Algorithm to use. One of "cholesky" (exact via dense Cholesky), "solve" (exact via repeated solves), "hutchinson" (stochastic estimator), or "auto" (cholesky for N ≤ 2048, hutchinson otherwise).

'auto'
num_probes int

Number of Rademacher probe vectors for the hutchinson method.

30
key Array | None

PRNG key for probe generation in the hutchinson method. When None, defaults to jax.random.PRNGKey(0).

None
solver AbstractSolveStrategy | None

Optional solve strategy for "solve" and "hutchinson" methods.

None

Returns:

Type Description
Float[Array, ' N']

1D array of shape (N,) with the diagonal entries of A⁻¹.
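For intuition about the exact and stochastic methods, here is a NumPy sketch of both ideas. The Cholesky route uses the fact that `diag(A⁻¹)` is the vector of squared column norms of `L⁻¹`. The probe estimator shown is the standard Rademacher form; that gaussx uses exactly this estimator is an assumption, and the test matrix is made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
n, num_probes = 6, 30

G = rng.normal(size=(n, n))
A = 4.0 * np.eye(n) + 0.1 * (G @ G.T) / n   # well-conditioned SPD matrix

# Exact, Cholesky-based: A^{-1} = L^{-T} L^{-1}, so diag(A^{-1})_i is the
# squared norm of column i of L^{-1}.
L = np.linalg.cholesky(A)
L_inv = np.linalg.solve(L, np.eye(n))
diag_exact = np.sum(L_inv**2, axis=0)

# Stochastic: average z * (A^{-1} z) over Rademacher probes z (columns of Z).
Z = rng.choice([-1.0, 1.0], size=(n, num_probes))
diag_hutch = np.mean(Z * np.linalg.solve(A, Z), axis=1)
```

The estimator's error depends on the off-diagonal mass of `A⁻¹`, so it is most accurate for diagonally dominant operators like the one above.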

Source code in src/gaussx/_linalg/_diag_inv.py
def diag_inv(
    operator: lx.AbstractLinearOperator,
    *,
    method: str = "auto",
    num_probes: int = 30,
    key: jax.Array | None = None,
    solver: AbstractSolveStrategy | None = None,
) -> Float[Array, " N"]:
    """Compute the diagonal of the inverse of a linear operator.

    Returns ``diag(A⁻¹)`` without forming the full inverse matrix.

    Args:
        operator: A linear operator representing A.
        method: Algorithm to use. One of ``"cholesky"`` (exact via
            dense Cholesky), ``"solve"`` (exact via repeated solves),
            ``"hutchinson"`` (stochastic estimator),
            or ``"auto"`` (cholesky for N ≤ 2048, hutchinson otherwise).
        num_probes: Number of Rademacher probe vectors for the
            hutchinson method.
        key: PRNG key for probe generation in the hutchinson method.
            When ``None``, defaults to ``jax.random.PRNGKey(0)``.
        solver: Optional solve strategy for ``"solve"`` and
            ``"hutchinson"`` methods.

    Returns:
        1D array of shape ``(N,)`` with the diagonal entries of A⁻¹.
    """
    n = operator.in_size()

    if method == "auto":
        method = "cholesky" if n <= 2048 else "hutchinson"

    if method == "cholesky":
        return _diag_inv_cholesky(operator)
    if method == "solve":
        return _diag_inv_solve(operator, solver=solver)
    if method == "hutchinson":
        return _diag_inv_hutchinson(
            operator, num_probes=num_probes, key=key, solver=solver
        )

    msg = (
        f"Unknown method {method!r}; expected 'cholesky', 'solve', "
        "'hutchinson', or 'auto'."
    )
    raise ValueError(msg)

Matrix-RHS & batched solves

Solve \(AX = B\) for matrix right-hand sides with one factorization (solve_matrix), per-column or per-row structured dispatch (solve_columns / solve_rows), or \(O(n)\) Thomas-algorithm tridiagonal solves.

solve_matrix(operator: lx.AbstractLinearOperator, matrix: Float[Array, 'N K'], *, solver: AbstractSolveStrategy | None = None) -> Float[Array, 'N K']

Solve A X = B with a single factorization on the matrix RHS.

When solver is None and lx.is_positive_semidefinite(A) holds (a check on the operator's declared structure, not a numerical test), this factors A = L L^T once via gaussx.cholesky and then uses a single cho_solve on the full matrix RHS, avoiding the per-column re-factorization incurred by solve_columns.

For non-PSD operators (or when a custom solver is supplied), falls back to solve_columns.
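The difference between the two paths can be sketched in NumPy: factor once and reuse the factor for every column, versus solving each column independently. For brevity the sketch uses the general `np.linalg.solve` on the triangular factors rather than a dedicated triangular solver; names and inputs are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
N, K = 5, 3
G = rng.normal(size=(N, N + 2))
A = G @ G.T                        # SPD system matrix
B = rng.normal(size=(N, K))        # matrix right-hand side

# One factorization, reused for all K columns.
L = np.linalg.cholesky(A)
Y = np.linalg.solve(L, B)          # forward step: L Y = B
X_once = np.linalg.solve(L.T, Y)   # backward step: L^T X = Y

# Column-by-column reference: each solve treats A independently.
X_cols = np.stack([np.linalg.solve(A, B[:, j]) for j in range(K)], axis=1)
```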

Parameters:

Name Type Description Default
operator AbstractLinearOperator

Linear operator A, shape (N, N).

required
matrix Float[Array, 'N K']

Right-hand side B, shape (N, K).

required
solver AbstractSolveStrategy | None

Optional solver strategy. When provided, dispatch is delegated column-by-column via solve_columns.

None

Returns:

Type Description
Float[Array, 'N K']

Solution X = A⁻¹ B, shape (N, K).

Source code in src/gaussx/_linalg/_linalg.py
def solve_matrix(
    operator: lx.AbstractLinearOperator,
    matrix: Float[Array, "N K"],
    *,
    solver: AbstractSolveStrategy | None = None,
) -> Float[Array, "N K"]:
    """Solve ``A X = B`` with a single factorization on the matrix RHS.

    When ``solver`` is ``None`` and ``A`` is positive semidefinite, this
    factors ``A = L L^T`` once via `gaussx.cholesky` and then uses
    a single ``cho_solve`` on the full matrix RHS — avoiding the
    per-column re-factorization incurred by `solve_columns`.

    For non-PSD operators (or when a custom ``solver`` is supplied),
    falls back to `solve_columns`.

    Args:
        operator: Linear operator A, shape ``(N, N)``.
        matrix: Right-hand side B, shape ``(N, K)``.
        solver: Optional solver strategy. When provided, dispatch is
            delegated column-by-column via `solve_columns`.

    Returns:
        Solution X = A⁻¹ B, shape ``(N, K)``.
    """
    if solver is None and lx.is_positive_semidefinite(operator):
        L = cholesky(operator).as_matrix()
        return jax.scipy.linalg.cho_solve((L, True), matrix)
    return solve_columns(operator, matrix, solver=solver)
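To make the single-factorization path concrete, here is a minimal sketch in plain JAX, assuming a dense symmetric positive-definite `A`. The helper name `solve_matrix_dense` is hypothetical and bypasses gaussx's structured dispatch; it only illustrates factoring once and reusing the factor for all `K` columns.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve


def solve_matrix_dense(A, B):
    """Hypothetical dense stand-in: factor A = L L^T once, reuse it for all columns."""
    L = jnp.linalg.cholesky(A)        # lower-triangular factor
    return cho_solve((L, True), B)    # True marks the factor as lower-triangular


A = jnp.array([[4.0, 1.0], [1.0, 3.0]])             # symmetric positive-definite
B = jnp.array([[1.0, 0.0, 2.0], [0.0, 1.0, 1.0]])   # shape (N, K) = (2, 3)
X = solve_matrix_dense(A, B)
```

The factorization cost is paid once regardless of `K`; each extra column adds only two triangular solves.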

solve_columns(operator: lx.AbstractLinearOperator, matrix: Float[Array, 'N K'], *, solver: AbstractSolveStrategy | None = None) -> Float[Array, 'N K']

Solve A X = B column-by-column via vmap.

Equivalent to A⁻¹ @ B, but each column goes through structured dispatch. Use this when A is a lineax operator with an efficient per-vector solve (e.g., Cholesky, diagonal, block-diagonal).

Parameters:

Name Type Description Default
operator AbstractLinearOperator

Linear operator A, shape (N, N).

required
matrix Float[Array, 'N K']

Right-hand side B, shape (N, K).

required
solver AbstractSolveStrategy | None

Optional solver strategy.

None

Returns:

Type Description
Float[Array, 'N K']

Solution X = A⁻¹ B, shape (N, K).

Source code in src/gaussx/_linalg/_linalg.py
def solve_columns(
    operator: lx.AbstractLinearOperator,
    matrix: Float[Array, "N K"],
    *,
    solver: AbstractSolveStrategy | None = None,
) -> Float[Array, "N K"]:
    """Solve A X = B column-by-column via vmap.

    Equivalent to ``A⁻¹ @ B`` but uses structured dispatch per column.
    Use this when ``A`` is a lineax operator with efficient per-vector
    solve (e.g., Cholesky, diagonal, block-diagonal).

    Args:
        operator: Linear operator A, shape ``(N, N)``.
        matrix: Right-hand side B, shape ``(N, K)``.
        solver: Optional solver strategy.

    Returns:
        Solution X = A⁻¹ B, shape ``(N, K)``.
    """
    if solver is not None:
        return jax.vmap(
            lambda col: dispatch_solve(operator, col, solver),
            in_axes=1,
            out_axes=1,
        )(matrix)
    return jax.vmap(
        lambda col: solve(operator, col),
        in_axes=1,
        out_axes=1,
    )(matrix)
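The column-wise mapping can be sketched without gaussx as follows. This is an illustration only: `solve_vector` is a hypothetical per-vector solve (a dense `jnp.linalg.solve` here) standing in for the structured `solve` / `dispatch_solve` calls in the source above.

```python
import jax
import jax.numpy as jnp


def solve_vector(A, b):
    """Hypothetical per-vector solve standing in for a structured solve."""
    return jnp.linalg.solve(A, b)


def solve_columns_dense(A, B):
    # in_axes=1 feeds one column of B per call; out_axes=1 stacks results as columns.
    return jax.vmap(lambda col: solve_vector(A, col), in_axes=1, out_axes=1)(B)


A = jnp.array([[2.0, 1.0], [0.0, 4.0]])
B = jnp.array([[2.0, 1.0, 0.0], [4.0, 0.0, 8.0]])   # shape (N, K) = (2, 3)
X = solve_columns_dense(A, B)
```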

solve_rows(operator: lx.AbstractLinearOperator, matrix: Float[Array, 'K N'], *, solver: AbstractSolveStrategy | None = None) -> Float[Array, 'K N']

Solve A x = bᵢ for each row bᵢ of a matrix via vmap.

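Row i of the result is A⁻¹ bᵢ, so the stacked output equals B @ A⁻ᵀ, which coincides with B @ A⁻¹ when A is symmetric (as covariance matrices are). Used in Kalman gain, Schur complement, and GP prediction computations.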

Parameters:

Name Type Description Default
operator AbstractLinearOperator

Linear operator A, shape (N, N).

required
matrix Float[Array, 'K N']

Rows of right-hand sides, shape (K, N).

required
solver AbstractSolveStrategy | None

Optional solver strategy.

None

Returns:

Type Description
Float[Array, 'K N']

Solutions, shape (K, N).

Source code in src/gaussx/_linalg/_linalg.py
def solve_rows(
    operator: lx.AbstractLinearOperator,
    matrix: Float[Array, "K N"],
    *,
    solver: AbstractSolveStrategy | None = None,
) -> Float[Array, "K N"]:
    """Solve A x = bᵢ for each row bᵢ of a matrix via vmap.

    Row ``i`` of the result is ``A⁻¹ bᵢ``, so the output equals
    ``B @ A⁻ᵀ``; this matches ``B @ A⁻¹`` when ``A`` is symmetric. Used
    in Kalman gain, Schur complement, and GP prediction computations.

    Args:
        operator: Linear operator A, shape ``(N, N)``.
        matrix: Rows of right-hand sides, shape ``(K, N)``.
        solver: Optional solver strategy.

    Returns:
        Solutions, shape ``(K, N)``.
    """
    if solver is not None:
        return jax.vmap(
            lambda row: dispatch_solve(operator, row, solver),
        )(matrix)
    return jax.vmap(lambda row: solve(operator, row))(matrix)
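A small check of what the row-wise solve produces may help here. The sketch below uses a dense `jnp.linalg.solve` as a stand-in for the structured solve (an assumption for illustration) and a deliberately non-symmetric `A`, so that the difference between `B @ A⁻ᵀ` and `B @ A⁻¹` is visible.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[2.0, 1.0], [0.0, 4.0]])               # deliberately non-symmetric
B = jnp.array([[2.0, 4.0], [1.0, 0.0], [0.0, 8.0]])   # K = 3 rows of length N = 2

# vmap maps over axis 0 by default, i.e. over the rows of B.
X = jax.vmap(lambda row: jnp.linalg.solve(A, row))(B)

A_inv = jnp.linalg.inv(A)
matches_transposed = bool(jnp.allclose(X, B @ A_inv.T, atol=1e-5))
matches_plain = bool(jnp.allclose(X, B @ A_inv, atol=1e-5))
```

For symmetric operators such as covariance matrices the two expressions agree, which is the setting of the Kalman and GP use cases.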

solve_tridiagonal(lower: Float[Array, ' n_minus_1'], diag: Float[Array, ' n'], upper: Float[Array, ' n_minus_1'], rhs: Float[Array, ' n']) -> Float[Array, ' n']

Solve a tridiagonal system A x = d.

Thin wrapper over lineax.TridiagonalLinearOperator, which delegates to jax.lax.linalg.tridiagonal_solve (LAPACK / cuSPARSE).

Parameters:

Name Type Description Default
lower Float[Array, ' n_minus_1']

Sub-diagonal of A, length n - 1.

required
diag Float[Array, ' n']

Main diagonal of A, length n.

required
upper Float[Array, ' n_minus_1']

Super-diagonal of A, length n - 1.

required
rhs Float[Array, ' n']

Right-hand side, length n.

required

Returns:

Type Description
Float[Array, ' n']

The solution x, length n.

Source code in src/gaussx/_linalg/_tridiagonal.py
def solve_tridiagonal(
    lower: Float[Array, " n_minus_1"],
    diag: Float[Array, " n"],
    upper: Float[Array, " n_minus_1"],
    rhs: Float[Array, " n"],
) -> Float[Array, " n"]:
    """Solve a tridiagonal system ``A x = d``.

    Thin wrapper over `lineax.TridiagonalLinearOperator`, which delegates
    to ``jax.lax.linalg.tridiagonal_solve`` (LAPACK / cuSPARSE).

    Args:
        lower: Sub-diagonal of ``A``, length ``n - 1``.
        diag: Main diagonal of ``A``, length ``n``.
        upper: Super-diagonal of ``A``, length ``n - 1``.
        rhs: Right-hand side, length ``n``.

    Returns:
        The solution ``x``, length ``n``.
    """
    operator = lx.TridiagonalLinearOperator(diag, lower, upper)
    return lx.linear_solve(operator, rhs, solver=lx.Tridiagonal()).value
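For intuition about why a tridiagonal solve costs O(n), the classical Thomas recurrence can be sketched with two `jax.lax.scan` passes. This is not what `solve_tridiagonal` does (it delegates to lineax); `thomas_solve` is a hypothetical helper that uses no pivoting, so it assumes a well-behaved system such as a diagonally dominant one.

```python
import jax
import jax.numpy as jnp


def thomas_solve(lower, diag, upper, rhs):
    """Hypothetical O(n) Thomas solve without pivoting."""
    # Pad so that every row i has a sub-diagonal a[i] and a super-diagonal c[i].
    a = jnp.concatenate([jnp.zeros(1, diag.dtype), lower])
    c = jnp.concatenate([upper, jnp.zeros(1, diag.dtype)])
    zero = jnp.zeros((), diag.dtype)

    def forward(carry, row):
        c_prev, d_prev = carry
        a_i, b_i, c_i, d_i = row
        denom = b_i - a_i * c_prev            # eliminate the sub-diagonal entry
        c_new = c_i / denom
        d_new = (d_i - a_i * d_prev) / denom
        return (c_new, d_new), (c_new, d_new)

    _, (c_star, d_star) = jax.lax.scan(forward, (zero, zero), (a, diag, c, rhs))

    def backward(x_next, row):
        c_i, d_i = row
        x_i = d_i - c_i * x_next              # back substitution
        return x_i, x_i

    _, x = jax.lax.scan(backward, zero, (c_star, d_star), reverse=True)
    return x


lower = jnp.array([1.0, 1.0])
diag = jnp.array([4.0, 4.0, 4.0])
upper = jnp.array([1.0, 1.0])
rhs = jnp.array([1.0, 2.0, 3.0])
x = thomas_solve(lower, diag, upper, rhs)

# Dense reference that makes the argument convention explicit.
A = jnp.diag(diag) + jnp.diag(lower, -1) + jnp.diag(upper, 1)
```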

solve_tridiagonal_batched(lower: Float[Array, '*batch n_minus_1'], diag: Float[Array, '*batch n'], upper: Float[Array, '*batch n_minus_1'], rhs: Float[Array, '*batch n']) -> Float[Array, '*batch n']

Solve independent tridiagonal systems over leading batch dimensions.

Applies solve_tridiagonal vmapped over all leading dimensions.

Parameters:

Name Type Description Default
lower Float[Array, '*batch n_minus_1']

Sub-diagonals, shape (*batch, n - 1).

required
diag Float[Array, '*batch n']

Main diagonals, shape (*batch, n).

required
upper Float[Array, '*batch n_minus_1']

Super-diagonals, shape (*batch, n - 1).

required
rhs Float[Array, '*batch n']

Right-hand sides, shape (*batch, n).

required

Returns:

Type Description
Float[Array, '*batch n']

Solutions, shape (*batch, n).

Source code in src/gaussx/_linalg/_tridiagonal.py
def solve_tridiagonal_batched(
    lower: Float[Array, "*batch n_minus_1"],
    diag: Float[Array, "*batch n"],
    upper: Float[Array, "*batch n_minus_1"],
    rhs: Float[Array, "*batch n"],
) -> Float[Array, "*batch n"]:
    """Solve independent tridiagonal systems over leading batch dimensions.

    Applies `solve_tridiagonal` vmapped over all leading dimensions.

    Args:
        lower: Sub-diagonals, shape ``(*batch, n - 1)``.
        diag: Main diagonals, shape ``(*batch, n)``.
        upper: Super-diagonals, shape ``(*batch, n - 1)``.
        rhs: Right-hand sides, shape ``(*batch, n)``.

    Returns:
        Solutions, shape ``(*batch, n)``.
    """
    n_batch = diag.ndim - 1
    fn = solve_tridiagonal
    for _ in range(n_batch):
        fn = eqx.filter_vmap(fn)
    return fn(lower, diag, upper, rhs)
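The nesting trick generalizes to any function written for unbatched inputs. Below is a minimal sketch using `jax.vmap` instead of `eqx.filter_vmap`; `vmap_over_leading` and `solve_diagonal` are hypothetical names, and a diagonal system stands in for the tridiagonal one.

```python
import jax
import jax.numpy as jnp


def vmap_over_leading(fn, n_batch):
    """Wrap fn in one vmap per leading batch dimension."""
    for _ in range(n_batch):
        fn = jax.vmap(fn)
    return fn


def solve_diagonal(diag, rhs):
    """Hypothetical per-system solve that only accepts 1-D inputs."""
    return jnp.linalg.solve(jnp.diag(diag), rhs)


diag = jnp.full((2, 3, 4), 2.0)   # batch shape (2, 3), system size n = 4
rhs = jnp.ones((2, 3, 4))
x = vmap_over_leading(solve_diagonal, diag.ndim - 1)(diag, rhs)
```

Each `vmap` layer strips one leading axis, so the innermost call always sees 1-D arrays of length `n`.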

batched_kernel_matvec(kernel_fn: Callable[[Float[Array, ' D'], Float[Array, ' D']], Float[Array, '']], X: Float[Array, 'N D'], Z: Float[Array, 'M D'], v: Float[Array, ' M'], batch_size: int = 1024) -> Float[Array, ' N']

Compute K(X, Z) @ v in memory-efficient batches.

Each batch evaluates a (batch_size, M) kernel sub-matrix and immediately contracts with v, keeping peak memory at O(batch_size * M) instead of O(N * M).

Parameters:

Name Type Description Default
kernel_fn Callable[[Float[Array, ' D'], Float[Array, ' D']], Float[Array, '']]

Pairwise kernel function k(x, z) -> scalar.

required
X Float[Array, 'N D']

First set of points, shape (N, D).

required
Z Float[Array, 'M D']

Second set of points, shape (M, D).

required
v Float[Array, ' M']

Vector to multiply, shape (M,).

required
batch_size int

Rows of X processed per scan step.

1024

Returns:

Type Description
Float[Array, ' N']

Result of K(X, Z) @ v, shape (N,).

Source code in src/gaussx/_linalg/_batched_matvec.py
def batched_kernel_matvec(
    kernel_fn: Callable[[Float[Array, " D"], Float[Array, " D"]], Float[Array, ""]],
    X: Float[Array, "N D"],
    Z: Float[Array, "M D"],
    v: Float[Array, " M"],
    batch_size: int = 1024,
) -> Float[Array, " N"]:
    r"""Compute ``K(X, Z) @ v`` in memory-efficient batches.

    Each batch evaluates a ``(batch_size, M)`` kernel sub-matrix and
    immediately contracts with ``v``, keeping peak memory at
    ``O(batch_size * M)`` instead of ``O(N * M)``.

    Args:
        kernel_fn: Pairwise kernel function ``k(x, z) -> scalar``.
        X: First set of points, shape ``(N, D)``.
        Z: Second set of points, shape ``(M, D)``.
        v: Vector to multiply, shape ``(M,)``.
        batch_size: Rows of ``X`` processed per scan step.

    Returns:
        Result of ``K(X, Z) @ v``, shape ``(N,)``.
    """
    n = X.shape[0]
    bs = batch_size
    n_padded = ((n + bs - 1) // bs) * bs
    pad_amount = n_padded - n
    X_padded = jnp.pad(X, ((0, pad_amount), (0, 0)), mode="constant")
    X_batched = rearrange(X_padded, "(B bs) D -> B bs D", bs=bs)

    def scan_body(
        carry: None, X_batch: Float[Array, "batch_size D"]
    ) -> tuple[None, Float[Array, " batch_size"]]:
        K_batch = jax.vmap(lambda x_i: jax.vmap(lambda z_j: kernel_fn(x_i, z_j))(Z))(
            X_batch
        )
        return carry, K_batch @ v

    _, results = jax.lax.scan(scan_body, None, X_batched)
    return rearrange(results, "B bs -> (B bs)")[:n]
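A simple way to sanity-check the batching is to compare against the dense product on a small problem. The sketch below is a simplified stand-in, not the library routine: it uses a Python loop with a ragged last block rather than zero-padding and `jax.lax.scan`, and `rbf`, `gram`, and `chunked_matvec` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def rbf(x, z):
    """Hypothetical pairwise kernel k(x, z) -> scalar."""
    return jnp.exp(-0.5 * jnp.sum((x - z) ** 2))


def gram(X, Z):
    # Nested vmap: outer over rows of X, inner over rows of Z.
    return jax.vmap(lambda x: jax.vmap(lambda z: rbf(x, z))(Z))(X)


def chunked_matvec(X, Z, v, batch_size):
    # Only one (batch_size, M) kernel block is materialized at a time.
    blocks = [
        gram(X[i:i + batch_size], Z) @ v
        for i in range(0, X.shape[0], batch_size)
    ]
    return jnp.concatenate(blocks)


X = jnp.linspace(0.0, 1.0, 10).reshape(5, 2)    # N = 5, D = 2
Z = jnp.linspace(-1.0, 1.0, 6).reshape(3, 2)    # M = 3
v = jnp.array([1.0, -2.0, 0.5])
out = chunked_matvec(X, Z, v, batch_size=2)
```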

batched_kernel_rmatvec(kernel_fn: Callable[[Float[Array, ' D'], Float[Array, ' D']], Float[Array, '']], X: Float[Array, 'N D'], Z: Float[Array, 'M D'], u: Float[Array, ' N'], batch_size: int = 1024) -> Float[Array, ' M']

Compute K(X, Z)^T @ u in memory-efficient batches.

Scans over batches of X, building each (batch_size, M) kernel block and accumulating K_batch^T @ u_batch into an (M,) result. Peak memory per step: O(batch_size * M).

Parameters:

Name Type Description Default
kernel_fn Callable[[Float[Array, ' D'], Float[Array, ' D']], Float[Array, '']]

Pairwise kernel function k(x, z) -> scalar.

required
X Float[Array, 'N D']

First set of points, shape (N, D).

required
Z Float[Array, 'M D']

Second set of points, shape (M, D).

required
u Float[Array, ' N']

Vector to multiply, shape (N,).

required
batch_size int

Rows of X processed per scan step.

1024

Returns:

Type Description
Float[Array, ' M']

Result of K(X, Z)^T @ u, shape (M,).

Source code in src/gaussx/_linalg/_batched_matvec.py
def batched_kernel_rmatvec(
    kernel_fn: Callable[[Float[Array, " D"], Float[Array, " D"]], Float[Array, ""]],
    X: Float[Array, "N D"],
    Z: Float[Array, "M D"],
    u: Float[Array, " N"],
    batch_size: int = 1024,
) -> Float[Array, " M"]:
    r"""Compute ``K(X, Z)^T @ u`` in memory-efficient batches.

    Scans over batches of ``X``, building each ``(batch_size, M)``
    kernel block and accumulating ``K_batch^T @ u_batch`` into an
    ``(M,)`` result.  Peak memory per step: ``O(batch_size * M)``.

    Args:
        kernel_fn: Pairwise kernel function ``k(x, z) -> scalar``.
        X: First set of points, shape ``(N, D)``.
        Z: Second set of points, shape ``(M, D)``.
        u: Vector to multiply, shape ``(N,)``.
        batch_size: Rows of ``X`` processed per scan step.

    Returns:
        Result of ``K(X, Z)^T @ u``, shape ``(M,)``.
    """
    n = X.shape[0]
    m = Z.shape[0]
    bs = batch_size
    n_padded = ((n + bs - 1) // bs) * bs
    pad_amount = n_padded - n
    X_padded = jnp.pad(X, ((0, pad_amount), (0, 0)), mode="constant")
    u_padded = jnp.pad(u, (0, pad_amount), mode="constant")
    X_batched = rearrange(X_padded, "(B bs) D -> B bs D", bs=bs)
    u_batched = rearrange(u_padded, "(B bs) -> B bs", bs=bs)

    def scan_body(
        acc: Float[Array, " M"],
        xu: tuple[Float[Array, "batch_size D"], Float[Array, " batch_size"]],
    ) -> tuple[Float[Array, " M"], None]:
        X_batch, u_batch = xu
        K_batch = jax.vmap(lambda x_i: jax.vmap(lambda z_j: kernel_fn(x_i, z_j))(Z))(
            X_batch
        )
        acc = acc + K_batch.T @ u_batch
        return acc, None

    acc_dtype = jnp.result_type(X, Z, u)
    result, _ = jax.lax.scan(
        scan_body,
        jnp.zeros(m, dtype=acc_dtype),
        (X_batched, u_batched),
    )
    return result
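The zero-padding used above is harmless for the transposed product because padded rows are multiplied by zero entries of `u`. A minimal sketch of that argument, with a Python loop in place of `jax.lax.scan` and hypothetical helpers `rbf` and `gram`:

```python
import jax
import jax.numpy as jnp


def rbf(x, z):
    """Hypothetical pairwise kernel k(x, z) -> scalar."""
    return jnp.exp(-0.5 * jnp.sum((x - z) ** 2))


def gram(X, Z):
    return jax.vmap(lambda x: jax.vmap(lambda z: rbf(x, z))(Z))(X)


X = jnp.linspace(0.0, 1.0, 10).reshape(5, 2)    # N = 5, D = 2
Z = jnp.linspace(-1.0, 1.0, 6).reshape(3, 2)    # M = 3
u = jnp.arange(1.0, 6.0)                        # shape (N,)

bs = 2
pad = (-X.shape[0]) % bs                        # rows needed to reach a multiple of bs
X_pad = jnp.pad(X, ((0, pad), (0, 0)))
u_pad = jnp.pad(u, (0, pad))                    # padded entries of u are zero

acc = jnp.zeros(Z.shape[0])
for start in range(0, X_pad.shape[0], bs):
    K_block = gram(X_pad[start:start + bs], Z)  # (bs, M) block
    acc = acc + K_block.T @ u_pad[start:start + bs]
```

This relies on the kernel returning finite values at the padded (all-zero) points; a kernel that produces `NaN` or `inf` there would contaminate the accumulator.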

Stable kernel arithmetic & Lyapunov

stable_squared_distances(X: Float[Array, 'N D'], Z: Float[Array, 'M D'], *, compute_dtype: jnp.dtype = jnp.float32, accumulate_dtype: jnp.dtype = jnp.float64) -> Float[Array, 'N M']

Squared Euclidean distances with mixed-precision stability.

The expansion ||x - z||^2 = ||x||^2 + ||z||^2 - 2 x^T z suffers catastrophic cancellation in float32 for high-D data, producing negative distances and non-PSD kernels.

This function computes dot products in compute_dtype (fast) and performs the subtraction in accumulate_dtype (stable), then casts the result back to compute_dtype.

Parameters:

Name Type Description Default
X Float[Array, 'N D']

First set of points, shape (N, D).

required
Z Float[Array, 'M D']

Second set of points, shape (M, D).

required
compute_dtype dtype

Dtype for dot products (default float32).

float32
accumulate_dtype dtype

Dtype for subtraction (default float64).

float64

Returns:

Type Description
Float[Array, 'N M']

Squared distances, shape (N, M), guaranteed non-negative.

Source code in src/gaussx/_linalg/_mixed_precision.py
def stable_squared_distances(
    X: Float[Array, "N D"],
    Z: Float[Array, "M D"],
    *,
    compute_dtype: jnp.dtype = jnp.float32,
    accumulate_dtype: jnp.dtype = jnp.float64,
) -> Float[Array, "N M"]:
    r"""Squared Euclidean distances with mixed-precision stability.

    The expansion ``||x - z||^2 = ||x||^2 + ||z||^2 - 2 x^T z`` suffers
    catastrophic cancellation in float32 for high-D data, producing
    negative distances and non-PSD kernels.

    This function computes dot products in ``compute_dtype`` (fast) and
    performs the subtraction in ``accumulate_dtype`` (stable), then casts
    the result back to ``compute_dtype``.

    Args:
        X: First set of points, shape ``(N, D)``.
        Z: Second set of points, shape ``(M, D)``.
        compute_dtype: Dtype for dot products (default float32).
        accumulate_dtype: Dtype for subtraction (default float64).

    Returns:
        Squared distances, shape ``(N, M)``, guaranteed non-negative.
    """
    X_c = X.astype(compute_dtype)
    Z_c = Z.astype(compute_dtype)

    # Squared norms — computed in compute_dtype
    X_sq = reduce(X_c**2, "N D -> N", "sum")
    Z_sq = reduce(Z_c**2, "M D -> M", "sum")

    # Cross term — computed in compute_dtype
    cross = X_c @ Z_c.T  # (N, M)

    # Subtraction in accumulate_dtype for stability
    dist_sq = (
        X_sq[:, None].astype(accumulate_dtype)
        + Z_sq[None, :].astype(accumulate_dtype)
        - 2.0 * cross.astype(accumulate_dtype)
    )

    # Clamp and cast back
    dist_sq = jnp.maximum(dist_sq, 0.0)
    return dist_sq.astype(compute_dtype)
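One practical point when using the float64 accumulation: JAX only honors float64 when the `jax_enable_x64` flag is set, and otherwise casts to float64 fall back to float32. The sketch below restates the expansion in plain `jax.numpy` under that assumption; `sq_dists_mixed` is a hypothetical name, not the library function.

```python
import jax

# Without this flag, casts to float64 fall back to float32.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def sq_dists_mixed(X, Z):
    """Hypothetical restatement: float32 products, float64 subtraction."""
    Xc, Zc = X.astype(jnp.float32), Z.astype(jnp.float32)
    x_sq = jnp.sum(Xc**2, axis=1).astype(jnp.float64)
    z_sq = jnp.sum(Zc**2, axis=1).astype(jnp.float64)
    cross = (Xc @ Zc.T).astype(jnp.float64)
    d2 = x_sq[:, None] + z_sq[None, :] - 2.0 * cross
    return jnp.maximum(d2, 0.0).astype(jnp.float32)   # clamp, then cast back


X = jnp.array([[0.0, 0.0], [1.0, 0.0]])
Z = jnp.array([[0.0, 1.0], [1.0, 1.0]])
D2 = sq_dists_mixed(X, Z)

# Direct pairwise differences as a reference.
direct = jnp.sum((X[:, None, :] - Z[None, :, :]) ** 2, axis=-1)
```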

stable_rbf_kernel(X: Float[Array, 'N D'], Z: Float[Array, 'M D'], lengthscale: float | Float[Array, ''], variance: float | Float[Array, ''] = 1.0, *, compute_dtype: jnp.dtype = jnp.float32, accumulate_dtype: jnp.dtype = jnp.float64) -> Float[Array, 'N M']

RBF (squared exponential) kernel with mixed-precision stability.

Computes variance * exp(-0.5 * ||x - z||^2 / lengthscale^2) using stable_squared_distances for the distance computation.

Parameters:

Name Type Description Default
X Float[Array, 'N D']

First set of points, shape (N, D).

required
Z Float[Array, 'M D']

Second set of points, shape (M, D).

required
lengthscale float | Float[Array, '']

Kernel lengthscale.

required
variance float | Float[Array, '']

Kernel signal variance (default 1.0).

1.0
compute_dtype dtype

Dtype for dot products (default float32).

float32
accumulate_dtype dtype

Dtype for subtraction (default float64).

float64

Returns:

Type Description
Float[Array, 'N M']

Kernel matrix, shape (N, M).

Source code in src/gaussx/_linalg/_mixed_precision.py
def stable_rbf_kernel(
    X: Float[Array, "N D"],
    Z: Float[Array, "M D"],
    lengthscale: float | Float[Array, ""],
    variance: float | Float[Array, ""] = 1.0,
    *,
    compute_dtype: jnp.dtype = jnp.float32,
    accumulate_dtype: jnp.dtype = jnp.float64,
) -> Float[Array, "N M"]:
    r"""RBF (squared exponential) kernel with mixed-precision stability.

    Computes ``variance * exp(-0.5 * ||x - z||^2 / lengthscale^2)``
    using `stable_squared_distances` for the distance computation.

    Args:
        X: First set of points, shape ``(N, D)``.
        Z: Second set of points, shape ``(M, D)``.
        lengthscale: Kernel lengthscale.
        variance: Kernel signal variance (default 1.0).
        compute_dtype: Dtype for dot products (default float32).
        accumulate_dtype: Dtype for subtraction (default float64).

    Returns:
        Kernel matrix, shape ``(N, M)``.
    """
    dist_sq = stable_squared_distances(
        X, Z, compute_dtype=compute_dtype, accumulate_dtype=accumulate_dtype
    )
    return variance * jnp.exp(-0.5 * dist_sq / lengthscale**2)
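As a quick illustration of the formula, the sketch below applies it to precomputed squared distances for three 1-D points. `rbf_from_sq_dists` is a hypothetical helper that mirrors only the final line of the source above.

```python
import jax.numpy as jnp


def rbf_from_sq_dists(dist_sq, lengthscale, variance=1.0):
    """Hypothetical helper: variance * exp(-0.5 * d^2 / lengthscale^2)."""
    return variance * jnp.exp(-0.5 * dist_sq / lengthscale**2)


X = jnp.array([[0.0], [1.0], [3.0]])   # three points in one dimension
dist_sq = (X - X.T) ** 2               # (3, 3) pairwise squared distances
K = rbf_from_sq_dists(dist_sq, lengthscale=2.0, variance=1.5)
```

The diagonal of `K` equals `variance` because each point is at distance zero from itself.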

discrete_lyapunov_solve(G: Float[Array, 'N N'], Q: Float[Array, 'N N']) -> Float[Array, 'N N']

Solve the discrete Lyapunov equation P - G P G^T = Q.

Uses the eigendecomposition G = V Λ V^{-1} so that

\[ \tilde{P} - \Lambda \tilde{P} \Lambda^T = V^{-1} Q V^{-T}, \quad \tilde{P}_{ij} = \frac{(V^{-1} Q V^{-T})_{ij}}{1 - \lambda_i \lambda_j}, \]

and then P = V \tilde{P} V^T. This avoids materializing the (N², N²) Kronecker matrix the vectorized formulation requires.

Cost is O(N³) (one general eigendecomposition + a couple of matrix multiplies) versus O(N⁶) for the vectorized solve, and the memory footprint drops from O(N⁴) to O(N²).

The discrete Lyapunov equation has a unique solution iff λ_i λ_j ≠ 1 for all eigenvalue pairs (λ_i, λ_j) of G. The standard sufficient condition — G is stable (spectral radius < 1) — guarantees this.

Warning

This implementation assumes G is diagonalizable. For defective G (e.g., a Jordan block with eigenvalue magnitude < 1) the eigenvector matrix V is singular and this solve will return NaN / inf. The discrete Lyapunov equation still has a unique solution in that case, but recovering it requires a Schur decomposition (Bartels-Stewart), which JAX does not currently expose. Fall back to the vectorized (I − G ⊗ G) solve for those operators.

Parameters:

Name Type Description Default
G Float[Array, 'N N']

Square matrix, shape (N, N). Should be stable for a unique steady-state solution.

required
Q Float[Array, 'N N']

Right-hand side, shape (N, N). Typically symmetric PSD in Kalman-smoother applications.

required

Returns:

Type Description
Float[Array, 'N N']

Symmetric matrix P of shape (N, N) solving P - G P G^T = Q.

Source code in src/gaussx/_linalg/_lyapunov.py
def discrete_lyapunov_solve(
    G: Float[Array, "N N"],
    Q: Float[Array, "N N"],
) -> Float[Array, "N N"]:
    r"""Solve the discrete Lyapunov equation ``P - G P G^T = Q``.

    Uses the eigendecomposition ``G = V Λ V^{-1}`` so that

    $$
    \tilde{P} - \Lambda \tilde{P} \Lambda^T
        = V^{-1} Q V^{-T}, \quad
    \tilde{P}_{ij} = \frac{(V^{-1} Q V^{-T})_{ij}}{1 - \lambda_i \lambda_j},
    $$

    and then ``P = V \tilde{P} V^T``. This avoids materializing the
    ``(N², N²)`` Kronecker matrix the vectorized formulation requires.

    Cost is ``O(N³)`` (one general eigendecomposition + a couple of
    matrix multiplies) versus ``O(N⁶)`` for the vectorized solve, and
    the memory footprint drops from ``O(N⁴)`` to ``O(N²)``.

    The discrete Lyapunov equation has a unique solution iff
    ``λ_i λ_j ≠ 1`` for all eigenvalue pairs ``(λ_i, λ_j)`` of ``G``.
    The standard sufficient condition — ``G`` is stable
    (spectral radius < 1) — guarantees this.

    !!! warning

        This implementation assumes ``G`` is **diagonalizable**. For
        defective ``G`` (e.g., a Jordan block with eigenvalue magnitude
        ``< 1``) the eigenvector matrix ``V`` is singular and this
        solve will return ``NaN`` / ``inf``. The discrete Lyapunov
        equation still has a unique solution in that case, but
        recovering it requires a Schur decomposition (Bartels-Stewart),
        which JAX does not currently expose. Fall back to the
        vectorized ``(I − G ⊗ G)`` solve for those operators.

    Args:
        G: Square matrix, shape ``(N, N)``. Should be stable for a
            unique steady-state solution.
        Q: Right-hand side, shape ``(N, N)``. Typically symmetric PSD
            in Kalman-smoother applications.

    Returns:
        Symmetric matrix ``P`` of shape ``(N, N)`` solving
        ``P - G P G^T = Q``.
    """
    eigs, V = jnp.linalg.eig(G)
    # Use solves rather than an explicit ``inv(V)``, which is more stable
    # and avoids forming the inverse. We need
    #     Q_tilde = V^{-1} Q V^{-T}
    # so first solve V Y = Q for Y = V^{-1} Q, then solve V Z = Y^T for
    # Z = V^{-1} Y^T = V^{-1} Q^T V^{-T}, whose transpose is
    # Z^T = V^{-1} Q V^{-T} = Q_tilde.
    Q_complex = Q.astype(V.dtype)
    Y = jnp.linalg.solve(V, Q_complex)
    Q_tilde = jnp.linalg.solve(V, Y.T).T
    denom = 1.0 - eigs[:, None] * eigs[None, :]
    P_tilde = Q_tilde / denom
    P = V @ P_tilde @ V.T
    P_real = jnp.real(P)
    # Symmetrize to eliminate residual floating-point asymmetry —
    # consistent with conditional / infinite_horizon_smoother covariance
    # post-processing.
    return symmetrize(P_real)
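The vectorized fallback mentioned in the warning can be sketched as a dense Kronecker solve. This is a minimal illustration, assuming small `N` (the system matrix has shape `(N², N²)`); `lyapunov_kron` is a hypothetical name and not part of gaussx. The example uses a defective `G`, the case where the eigendecomposition route breaks down.

```python
import jax.numpy as jnp


def lyapunov_kron(G, Q):
    """Hypothetical dense fallback: solve (I - G ⊗ G) vec(P) = vec(Q)."""
    n = G.shape[0]
    # With row-major flattening, vec(G P G^T) = (G ⊗ G) vec(P).
    M = jnp.eye(n * n, dtype=G.dtype) - jnp.kron(G, G)
    return jnp.linalg.solve(M, Q.reshape(-1)).reshape(n, n)


G = jnp.array([[0.5, 1.0], [0.0, 0.5]])   # defective: a single Jordan block
Q = jnp.eye(2)
P = lyapunov_kron(G, Q)
residual = P - G @ P @ G.T - Q
```

The cost is O(N⁶) time and O(N⁴) memory, as noted above, so this is only practical for small state dimensions.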