
Operators & Tags

Layer 1: structured linear operators extending lineax.AbstractLinearOperator. All are immutable equinox.Module pytrees, so they compose freely with jit, grad, and vmap. The primitives dispatch on these types: a solve against a Kronecker factorizes per Kronecker factor, a logdet of a BlockDiag sums per block, a LowRankUpdate solve applies Woodbury.

Structured products & sums

The Kronecker product \(A_1 \otimes A_2 \otimes \cdots\) gives \(O(\sum_i n_i^3)\) solves on a \(\prod_i n_i\) grid; the Kronecker sum \(A \otimes I + I \otimes B\) diagonalises in the joint eigenbasis with eigenvalues \(\lambda_i + \mu_j\).

Structured linear algebra and Gaussian primitives for JAX.

Kronecker

Bases: AbstractLinearOperator

Kronecker product operator A₁ ⊗ A₂ ⊗ … ⊗ Aₖ.

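Matvec uses Roth's column lemma, so the full Kronecker product is never materialized. For two factors \(A\) (\(m \times n\)) and \(B\) (\(p \times q\)), the product is computed as \((A \otimes B)\,\mathrm{vec}(X) = \mathrm{vec}(B X A^\top)\), where \(X\) is the input vector reshaped to \(q \times n\) and \(\mathrm{vec}\) stacks columns.

Complexity: per-factor solves cost \(O(\sum_i n_i^3)\) instead of the \(O((\prod_i n_i)^3)\) of a dense factorization, and the matvec never forms the dense \(\prod_i n_i \times \prod_i n_i\) matrix, avoiding its \(O((\prod_i n_i)^2)\) cost.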
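A quick numerical check of the lemma and of the per-factor solve it enables, written in plain NumPy so it runs without gaussx installed. `vec`, `kron_mv`, and `kron_solve` are illustrative helpers, not gaussx's internal `_kronecker_mv`, and the solve assumes square, invertible factors.

```python
import numpy as np


def vec(M):
    # Column-major (Fortran-order) vectorization: stack the columns of M.
    return M.reshape(-1, order="F")


def kron_mv(A, B, x):
    # Roth's column lemma: (A ⊗ B) vec(X) = vec(B X Aᵀ), with X of shape (q, n).
    X = x.reshape(B.shape[1], A.shape[1], order="F")
    return vec(B @ X @ A.T)


def kron_solve(A, B, y):
    # (A ⊗ B)⁻¹ = A⁻¹ ⊗ B⁻¹ for square invertible factors, so
    # x = vec(B⁻¹ Y A⁻ᵀ): only the small per-factor systems are solved.
    Y = y.reshape(B.shape[0], A.shape[0], order="F")
    Z = np.linalg.solve(B, Y)       # B⁻¹ Y
    W = np.linalg.solve(A, Z.T).T   # (B⁻¹ Y) A⁻ᵀ
    return vec(W)


rng = np.random.default_rng(0)
Ma, Mb = rng.standard_normal((3, 3)), rng.standard_normal((4, 4))
A = Ma @ Ma.T + np.eye(3)  # SPD, hence invertible
B = Mb @ Mb.T + np.eye(4)
x = rng.standard_normal(12)

K = np.kron(A, B)  # dense 12 x 12, built only to check the results
assert np.allclose(kron_mv(A, B, x), K @ x)
assert np.allclose(kron_solve(A, B, K @ x), x)
```

The matvec helper also works for rectangular factors; only the solve needs them square.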

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `*operators` | `AbstractLinearOperator` | Two or more `lineax.AbstractLinearOperator` instances. | `()` |
Source code in src/gaussx/_operators/_kronecker.py
class Kronecker(lx.AbstractLinearOperator):
    """Kronecker product operator ``A₁ ⊗ A₂ ⊗ … ⊗ Aₖ``.

    Matvec uses Roth's column lemma for efficient computation without
    materializing the full Kronecker product. For two factors A (m x n)
    and B (p x q), the product (A kron B) vec(X) is computed as
    vec(B X A^T) where X is reshaped to (q, n).

    Complexity: per-factor solves cost O(sum n_i^3) instead of
    O((prod n_i)^3) for a dense factorization, and the matvec avoids
    the O((prod n_i)^2) cost of the dense matrix.

    Args:
        *operators: Two or more ``lineax.AbstractLinearOperator`` instances.
    """

    operators: tuple[lx.AbstractLinearOperator, ...]
    _in_size: int = eqx.field(static=True)
    _out_size: int = eqx.field(static=True)
    _dtype: str = eqx.field(static=True)
    tags: frozenset[object] = eqx.field(static=True)

    def __init__(
        self,
        *operators: lx.AbstractLinearOperator,
        tags: object | frozenset[object] = frozenset(),
    ) -> None:
        if len(operators) < 2:
            raise ValueError("Kronecker requires at least two operators.")
        self.operators = operators
        self._in_size = _prod(op.in_size() for op in operators)
        self._out_size = _prod(op.out_size() for op in operators)
        self._dtype = _resolve_dtype(*operators)
        from gaussx._tags import kronecker_tag

        self.tags = _to_frozenset(tags) | {kronecker_tag}

    def mv(self, vector: Float[Array, " n"]) -> Float[Array, " m"]:
        return _kronecker_mv(self.operators, vector)

    def as_matrix(self) -> Float[Array, "m n"]:
        result = self.operators[0].as_matrix()
        for op in self.operators[1:]:
            result = jnp.kron(result, op.as_matrix())
        return result

    def transpose(self) -> Kronecker:
        return Kronecker(
            *(op.T for op in self.operators),
            tags=lx.transpose_tags(self.tags),
        )

    def in_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._in_size,), jnp.dtype(self._dtype))

    def out_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._out_size,), jnp.dtype(self._dtype))

BlockDiag

Bases: AbstractLinearOperator

Block diagonal operator diag(A₁, A₂, …, Aₖ).

Each sub-operator acts on its own slice of the input vector. Matvec, transpose, logdet, solve, and cholesky all decompose per-block.
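A minimal NumPy sketch of the per-block decomposition, assuming square SPD blocks so that solve, log-determinant, and Cholesky are all defined. The helper names (`block_mv`, `block_solve`, `block_logdet`, `block_cholesky`, `dense`) are hypothetical; gaussx's own primitives may organize this differently.

```python
import numpy as np


def split(v, sizes):
    # Cut a flat vector into consecutive slices of the given sizes.
    return np.split(v, np.cumsum(sizes)[:-1])


def block_mv(blocks, x):
    # Each block acts only on its own slice of x.
    pieces = split(x, [Ai.shape[1] for Ai in blocks])
    return np.concatenate([Ai @ xi for Ai, xi in zip(blocks, pieces)])


def block_solve(blocks, y):
    pieces = split(y, [Ai.shape[0] for Ai in blocks])
    return np.concatenate([np.linalg.solve(Ai, yi) for Ai, yi in zip(blocks, pieces)])


def block_logdet(blocks):
    # log|det diag(A_1, ..., A_k)| = sum_i log|det A_i|
    return sum(np.linalg.slogdet(Ai)[1] for Ai in blocks)


def block_cholesky(blocks):
    # The Cholesky factor of an SPD block-diagonal matrix is block diagonal too.
    return [np.linalg.cholesky(Ai) for Ai in blocks]


def dense(blocks):
    # Materialize diag(A_1, ..., A_k) for comparison only.
    m = sum(Ai.shape[0] for Ai in blocks)
    n = sum(Ai.shape[1] for Ai in blocks)
    out = np.zeros((m, n))
    r = c = 0
    for Ai in blocks:
        out[r:r + Ai.shape[0], c:c + Ai.shape[1]] = Ai
        r += Ai.shape[0]
        c += Ai.shape[1]
    return out


rng = np.random.default_rng(1)
blocks = []
for size in (2, 3, 4):
    M = rng.standard_normal((size, size))
    blocks.append(M @ M.T + np.eye(size))  # SPD blocks

D = dense(blocks)
x = rng.standard_normal(9)
assert np.allclose(block_mv(blocks, x), D @ x)
assert np.allclose(block_solve(blocks, D @ x), x)
assert np.isclose(block_logdet(blocks), np.linalg.slogdet(D)[1])
assert np.allclose(dense(block_cholesky(blocks)), np.linalg.cholesky(D))
```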

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `*operators` | `AbstractLinearOperator` | One or more `lineax.AbstractLinearOperator` instances forming the diagonal blocks. | `()` |
Source code in src/gaussx/_operators/_block_diag.py
class BlockDiag(lx.AbstractLinearOperator):
    """Block diagonal operator ``diag(A₁, A₂, …, Aₖ)``.

    Each sub-operator acts on its own slice of the input vector.
    Matvec, transpose, logdet, solve, and cholesky all decompose
    per-block.

    Args:
        *operators: One or more ``lineax.AbstractLinearOperator`` instances
            forming the diagonal blocks.
    """

    operators: tuple[lx.AbstractLinearOperator, ...]
    _in_size: int = eqx.field(static=True)
    _out_size: int = eqx.field(static=True)
    _dtype: str = eqx.field(static=True)
    tags: frozenset[object] = eqx.field(static=True)

    def __init__(
        self,
        *operators: lx.AbstractLinearOperator,
        tags: object | frozenset[object] = frozenset(),
    ) -> None:
        if len(operators) == 0:
            raise ValueError("BlockDiag requires at least one operator.")
        self.operators = operators
        self._in_size = sum(op.in_size() for op in operators)
        self._out_size = sum(op.out_size() for op in operators)
        self._dtype = _resolve_dtype(*operators)
        from gaussx._tags import block_diagonal_tag

        self.tags = _to_frozenset(tags) | {block_diagonal_tag}

    def mv(self, vector: Float[Array, " n"]) -> Float[Array, " m"]:
        # Use jax.lax.dynamic_slice with static offsets for JIT compatibility
        results = []
        offset = 0
        for op in self.operators:
            size = op.in_size()
            block = jax.lax.dynamic_slice(vector, (offset,), (size,))
            results.append(op.mv(block))
            offset += size
        return jnp.concatenate(results)

    def as_matrix(self) -> Float[Array, "m n"]:
        matrices = [op.as_matrix() for op in self.operators]
        return jax.scipy.linalg.block_diag(*matrices)

    def transpose(self) -> BlockDiag:
        return BlockDiag(
            *(op.T for op in self.operators),
            tags=lx.transpose_tags(self.tags),
        )

    def in_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._in_size,), jnp.dtype(self._dtype))

    def out_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._out_size,), jnp.dtype(self._dtype))

KroneckerSum

Bases: AbstractLinearOperator

Kronecker sum \(A \oplus B = A \otimes I_b + I_a \otimes B\).

Appears in separable PDEs, graph Laplacians, and space-time GPs. If \(A = Q_A \Lambda_A Q_A^\top\) and \(B = Q_B \Lambda_B Q_B^\top\), the Kronecker sum has eigenvectors \(Q_A \otimes Q_B\) with eigenvalues \(\lambda^A_i + \lambda^B_j\).
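Both the matrix-free matvec and the eigenstructure can be checked against a dense \(A \oplus B\). This NumPy sketch assumes symmetric factors (needed for `eigh`) and uses the same column-major layout as the `(a b) -> b a` rearrange in the source; it illustrates the identities rather than reproducing gaussx's code.

```python
import numpy as np

rng = np.random.default_rng(2)


def random_symmetric(n):
    M = rng.standard_normal((n, n))
    return (M + M.T) / 2


n_a, n_b = 3, 4
A, B = random_symmetric(n_a), random_symmetric(n_b)
K = np.kron(A, np.eye(n_b)) + np.kron(np.eye(n_a), B)  # dense A ⊕ B, for checking

# Matrix-free matvec: (A ⊗ I + I ⊗ B) vec(X) = vec(B X + X Aᵀ), X of shape (n_b, n_a).
x = rng.standard_normal(n_a * n_b)
X = x.reshape(n_b, n_a, order="F")
y = (B @ X + X @ A.T).reshape(-1, order="F")
assert np.allclose(y, K @ x)

# Eigenstructure from the factors alone: eigenvectors Q_A ⊗ Q_B,
# eigenvalues λ^A_i + λ^B_j flattened with i as the outer index.
lam_a, Q_a = np.linalg.eigh(A)
lam_b, Q_b = np.linalg.eigh(B)
evals = (lam_a[:, None] + lam_b[None, :]).reshape(-1)
Q = np.kron(Q_a, Q_b)
assert np.allclose(Q @ np.diag(evals) @ Q.T, K)
assert np.allclose(Q.T @ Q, np.eye(n_a * n_b))
```

Only \(n_a \times n_a\) and \(n_b \times n_b\) eigenproblems are solved; the \(n_a n_b\)-dimensional spectrum follows from their pairwise sums.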

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `A` | `AbstractLinearOperator` | First operator, shape `(n_a, n_a)`. | required |
| `B` | `AbstractLinearOperator` | Second operator, shape `(n_b, n_b)`. | required |
Source code in src/gaussx/_operators/_kronecker_sum.py
class KroneckerSum(lx.AbstractLinearOperator):
    r"""Kronecker sum ``A \oplus B = A \otimes I_b + I_a \otimes B``.

    Appears in separable PDEs, graph Laplacians, and space-time GPs.
    If ``A = Q_A \Lambda_A Q_A^T`` and ``B = Q_B \Lambda_B Q_B^T``,
    the Kronecker sum has eigenvectors ``Q_A \otimes Q_B`` with
    eigenvalues ``\lambda^A_i + \lambda^B_j``.

    Args:
        A: First operator, shape ``(n_a, n_a)``.
        B: Second operator, shape ``(n_b, n_b)``.
    """

    A: lx.AbstractLinearOperator
    B: lx.AbstractLinearOperator
    _in_size: int = eqx.field(static=True)
    _out_size: int = eqx.field(static=True)
    _n_a: int = eqx.field(static=True)
    _n_b: int = eqx.field(static=True)
    _dtype: str = eqx.field(static=True)
    tags: frozenset[object] = eqx.field(static=True)

    def __init__(
        self,
        A: lx.AbstractLinearOperator,
        B: lx.AbstractLinearOperator,
        *,
        tags: object | frozenset[object] = frozenset(),
    ) -> None:
        if A.in_size() != A.out_size():
            raise ValueError(
                f"A must be square, got in_size={A.in_size()}, out_size={A.out_size()}."
            )
        if B.in_size() != B.out_size():
            raise ValueError(
                f"B must be square, got in_size={B.in_size()}, out_size={B.out_size()}."
            )
        self.A = A
        self.B = B
        n_a = A.in_size()
        n_b = B.in_size()
        self._n_a = n_a
        self._n_b = n_b
        self._in_size = n_a * n_b
        self._out_size = n_a * n_b
        self._dtype = _resolve_dtype(A, B)
        from gaussx._tags import kronecker_sum_tag

        self.tags = _to_frozenset(tags) | {kronecker_sum_tag}

    def mv(self, vector: Float[Array, " n"]) -> Float[Array, " n"]:
        # (A (x) I_b + I_a (x) B) vec(X) = vec(B X + X A^T)
        # where X is (n_b, n_a)
        X = rearrange(vector, "(a b) -> b a", a=self._n_a, b=self._n_b)
        # I_a (x) B: apply B to each column of X
        BX = jax.vmap(self.B.mv, in_axes=1, out_axes=1)(X)
        # A (x) I_b: right-multiply X by A^T, i.e. apply A to each row of X
        XAt = jax.vmap(self.A.mv)(X)  # (n_b, n_a): apply A to rows
        result = BX + XAt
        return rearrange(result, "b a -> (a b)")

    def as_matrix(self) -> Float[Array, "n n"]:
        A_mat = self.A.as_matrix()
        B_mat = self.B.as_matrix()
        I_a = jnp.eye(self._n_a, dtype=jnp.dtype(self._dtype))
        I_b = jnp.eye(self._n_b, dtype=jnp.dtype(self._dtype))
        return jnp.kron(A_mat, I_b) + jnp.kron(I_a, B_mat)

    def transpose(self) -> KroneckerSum:
        return KroneckerSum(
            self.A.T,
            self.B.T,
            tags=lx.transpose_tags(self.tags),
        )

    def in_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._in_size,), jnp.dtype(self._dtype))

    def out_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._out_size,), jnp.dtype(self._dtype))

    def eigendecompose(
        self,
    ) -> tuple[Float[Array, " n"], Float[Array, "n n"]]:
        """Symmetric eigendecomposition via per-factor decomposition.

        Assumes both factors are symmetric so the returned
        ``Q = Q_A ⊗ Q_B`` is orthonormal — callers rely on
        ``self == Q @ diag(eigenvalues) @ Q.T``. Diagonal factors get a
        structural shortcut; other operators are materialized and
        decomposed via ``jnp.linalg.eigh``. We deliberately avoid
        routing untagged factors through `gaussx.eig` because
        that primitive falls back to ``jnp.linalg.eig`` for untagged
        operators and would return general (non-orthonormal)
        eigenvectors — breaking the ``Q.T == Q^{-1}`` contract for the
        common case of numerically symmetric matrices wrapped as plain
        `lineax.MatrixLinearOperator`.

        Returns:
            Tuple ``(eigenvalues, Q)`` where ``Q = Q_A ⊗ Q_B`` and the
            eigenvalues are ``lambda^A_i + lambda^B_j`` for all pairs.
        """

        evals_a, evecs_a = _eigh_factor(self.A)
        evals_b, evecs_b = _eigh_factor(self.B)
        # Eigenvalues: lambda_a_i + lambda_b_j for all (i, j) pairs
        eigenvalues = rearrange(evals_a[:, None] + evals_b[None, :], "a b -> (a b)")
        # Eigenvectors: Q_A (x) Q_B
        Q = jnp.kron(evecs_a, evecs_b)
        return eigenvalues, Q

eigendecompose() -> tuple[Float[Array, ' n'], Float[Array, 'n n']]

Symmetric eigendecomposition via per-factor decomposition.

Assumes both factors are symmetric so the returned Q = Q_A ⊗ Q_B is orthonormal — callers rely on self == Q @ diag(eigenvalues) @ Q.T. Diagonal factors get a structural shortcut; other operators are materialized and decomposed via jnp.linalg.eigh. We deliberately avoid routing untagged factors through gaussx.eig because that primitive falls back to jnp.linalg.eig for untagged operators and would return general (non-orthonormal) eigenvectors — breaking the Q.T == Q^{-1} contract for the common case of numerically symmetric matrices wrapped as plain lineax.MatrixLinearOperator.

Returns:

| Type | Description |
| --- | --- |
| `Float[Array, ' n']` | Eigenvalues \(\lambda^A_i + \lambda^B_j\) for all pairs \((i, j)\), flattened with \(i\) as the outer index. |
| `Float[Array, 'n n']` | Orthonormal eigenvectors \(Q = Q_A \otimes Q_B\). |
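Because \(Q\) is orthonormal, the returned pair yields solves and log-determinants directly: \((A \oplus B)^{-1} y = Q \Lambda^{-1} Q^\top y\) and \(\log\det(A \oplus B) = \sum_{ij} \log(\lambda^A_i + \lambda^B_j)\). The NumPy sketch below uses a hypothetical dense helper `kron_sum_eigendecompose` in place of the method and assumes SPD factors, so every eigenvalue is positive.

```python
import numpy as np


def kron_sum_eigendecompose(A, B):
    # Hypothetical dense analogue of KroneckerSum.eigendecompose for symmetric A, B.
    lam_a, Q_a = np.linalg.eigh(A)
    lam_b, Q_b = np.linalg.eigh(B)
    evals = (lam_a[:, None] + lam_b[None, :]).reshape(-1)
    return evals, np.kron(Q_a, Q_b)


rng = np.random.default_rng(3)
Ma, Mb = rng.standard_normal((3, 3)), rng.standard_normal((2, 2))
A = Ma @ Ma.T + np.eye(3)  # SPD, so A ⊕ B is positive definite
B = Mb @ Mb.T + np.eye(2)
K = np.kron(A, np.eye(2)) + np.kron(np.eye(3), B)  # dense, for checking only

evals, Q = kron_sum_eigendecompose(A, B)
y = rng.standard_normal(6)
x = Q @ ((Q.T @ y) / evals)     # (A ⊕ B)⁻¹ y = Q Λ⁻¹ Qᵀ y, valid since Qᵀ = Q⁻¹
logdet = np.sum(np.log(evals))  # log det(A ⊕ B) = Σ_ij log(λ^A_i + λ^B_j)

assert np.allclose(K @ x, y)
assert np.isclose(logdet, np.linalg.slogdet(K)[1])
```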


KroneckerSumSqrt

Bases: AbstractLinearOperator

Symmetric square root of \(A \oplus B\) via per-factor eigenvectors.

Represents the symmetric matrix \(S\) with \(S S = A \oplus B\), where \(A \oplus B = A \otimes I + I \otimes B\) is the Kronecker sum. The square root is never materialized: mv and solve apply \(S\) and \(S^{-1}\) matrix-free using the per-factor eigendecompositions. After an \(O(n_a^3 + n_b^3)\) setup, each application costs \(O(n_a n_b (n_a + n_b))\) rather than the \(O((n_a n_b)^2)\) of a dense matvec.
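The matrix-free application of \(S\) and \(S^{-1}\) can be sketched in NumPy as below. `sqrt_factors` and `sqrt_apply` are hypothetical names; the body mirrors the rotate, scale, rotate-back pattern of `mv` and `solve` in the source, and it assumes SPD factors so that the inverse is finite. \(S\) is materialized here only to verify \(S S = A \oplus B\).

```python
import numpy as np


def sqrt_factors(A, B):
    # Assumes symmetric A, B with A ⊕ B positive semidefinite.
    lam_a, Q_a = np.linalg.eigh(A)
    lam_b, Q_b = np.linalg.eigh(B)
    root = np.sqrt(np.maximum(lam_a[:, None] + lam_b[None, :], 0.0))  # (n_a, n_b)
    return Q_a, Q_b, root


def sqrt_apply(Q_a, Q_b, root, x, inverse=False):
    n_a, n_b = root.shape
    X = x.reshape(n_b, n_a, order="F")  # same layout as "(a b) -> b a"
    C = Q_b.T @ X @ Q_a                 # coordinates in the Q_A ⊗ Q_B basis
    C = C / root.T if inverse else C * root.T
    return (Q_b @ C @ Q_a.T).reshape(-1, order="F")


rng = np.random.default_rng(4)
Ma, Mb = rng.standard_normal((3, 3)), rng.standard_normal((4, 4))
A = Ma @ Ma.T + np.eye(3)
B = Mb @ Mb.T + np.eye(4)
K = np.kron(A, np.eye(4)) + np.kron(np.eye(3), B)

Q_a, Q_b, root = sqrt_factors(A, B)
# Build S column by column (S e_i) only to check the square-root property.
S = np.column_stack([sqrt_apply(Q_a, Q_b, root, e) for e in np.eye(12)])
x = rng.standard_normal(12)
assert np.allclose(S, S.T)
assert np.allclose(S @ S, K)
assert np.allclose(sqrt_apply(Q_a, Q_b, root, S @ x, inverse=True), x)
```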

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `A` | `AbstractLinearOperator` | Symmetric PSD factor, shape `(n_a, n_a)`. | required |
| `B` | `AbstractLinearOperator` | Symmetric PSD factor, shape `(n_b, n_b)`. | required |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If either factor is non-square or not tagged as symmetric, or if \(A \oplus B\) is not positive semidefinite. |

Source code in src/gaussx/_operators/_kronecker_sum.py
class KroneckerSumSqrt(lx.AbstractLinearOperator):
    r"""Symmetric square root of ``A \oplus B`` via per-factor eigenvectors.

    Represents the symmetric matrix ``S`` with ``S @ S = A \oplus B``
    (where ``\oplus`` is the Kronecker sum ``A ⊗ I + I ⊗ B``). The square
    root is never materialized: `mv` and `solve` apply ``S`` and
    ``S^{-1}`` matrix-free using the per-factor eigendecompositions, so the
    cost is governed by the factor sizes rather than the full ``n_a · n_b``
    dimension.

    Args:
        A: Symmetric PSD factor, shape ``(n_a, n_a)``.
        B: Symmetric PSD factor, shape ``(n_b, n_b)``.

    Raises:
        ValueError: If either factor is non-square, untagged as symmetric, or
            if ``A \oplus B`` is not positive semidefinite.
    """

    eigenvectors_a: Float[Array, "a a"]
    eigenvectors_b: Float[Array, "b b"]
    sqrt_eigenvalues: Float[Array, "a b"]
    _in_size: int = eqx.field(static=True)
    _out_size: int = eqx.field(static=True)
    _n_a: int = eqx.field(static=True)
    _n_b: int = eqx.field(static=True)
    _dtype: str = eqx.field(static=True)

    def __init__(
        self,
        A: lx.AbstractLinearOperator,
        B: lx.AbstractLinearOperator,
    ) -> None:
        if A.in_size() != A.out_size():
            raise ValueError(
                f"A must be square, got in_size={A.in_size()}, out_size={A.out_size()}."
            )
        if B.in_size() != B.out_size():
            raise ValueError(
                f"B must be square, got in_size={B.in_size()}, out_size={B.out_size()}."
            )
        # The symmetric sqrt is well-defined only when A and B are
        # symmetric PSD. Without these tags, ``jnp.linalg.eigh`` would
        # silently use only the lower triangle and return wrong
        # eigenvectors for non-symmetric inputs.
        if not lx.is_symmetric(A) or not lx.is_symmetric(B):
            raise ValueError(
                "KroneckerSumSqrt requires both factors to be symmetric "
                "(tag them with lx.symmetric_tag or lx.positive_semidefinite_tag)."
            )
        evals_a, evecs_a = _eigh_factor(A)
        evals_b, evecs_b = _eigh_factor(B)
        eigenvalues = (evals_a[:, None] + evals_b[None, :]).astype(
            jnp.result_type(evals_a, evals_b, jnp.float32)
        )
        # Tolerance for "numerically zero" negative eigenvalues. We scale
        # by ``sqrt(spectrum magnitude)`` so the threshold stays tight
        # enough for large-magnitude spectra while still admitting eigh
        # roundoff. Linear scaling becomes too permissive: with
        # ``scale ~ 1e8`` and the previous ``100 * eps * scale`` formula,
        # genuinely-indefinite operators (negatives on the order of
        # ``-1e3``) could slip past the guard.
        scale = jnp.maximum(jnp.max(jnp.abs(eigenvalues)), 1.0)
        threshold = (
            -_NEGATIVE_EIGENVALUE_TOLERANCE_FACTOR
            * jnp.finfo(eigenvalues.dtype).eps
            * jnp.sqrt(scale)
        )
        min_eigenvalue = jnp.min(eigenvalues)
        if bool(min_eigenvalue < threshold):
            raise ValueError(
                "A ⊕ B must be positive semidefinite; "
                f"minimum eigenvalue {float(min_eigenvalue):.2e} is below "
                f"threshold {float(threshold):.2e}."
            )
        sqrt_eigenvalues = jnp.sqrt(jnp.maximum(eigenvalues, 0.0))

        self.eigenvectors_a = evecs_a
        self.eigenvectors_b = evecs_b
        self.sqrt_eigenvalues = sqrt_eigenvalues
        self._n_a = A.in_size()
        self._n_b = B.in_size()
        self._in_size = self._n_a * self._n_b
        self._out_size = self._in_size
        self._dtype = str(jnp.result_type(evecs_a, evecs_b, sqrt_eigenvalues))

    def mv(self, vector: Float[Array, " n"]) -> Float[Array, " n"]:
        X = rearrange(vector, "(a b) -> b a", a=self._n_a, b=self._n_b)
        C = self.eigenvectors_b.T @ X @ self.eigenvectors_a
        C = self.sqrt_eigenvalues.T * C
        result = self.eigenvectors_b @ C @ self.eigenvectors_a.T
        return rearrange(result, "b a -> (a b)")

    def solve(self, vector: Float[Array, " n"]) -> Float[Array, " n"]:
        """Apply the inverse square root ``S^{-1}`` to ``vector``.

        Args:
            vector: Input vector, shape ``(n_a · n_b,)``.

        Returns:
            ``S^{-1} @ vector``, shape ``(n_a · n_b,)``.
        """
        X = rearrange(vector, "(a b) -> b a", a=self._n_a, b=self._n_b)
        C = self.eigenvectors_b.T @ X @ self.eigenvectors_a
        C = C / self.sqrt_eigenvalues.T
        result = self.eigenvectors_b @ C @ self.eigenvectors_a.T
        return rearrange(result, "b a -> (a b)")

    def as_matrix(self) -> Float[Array, "n n"]:
        basis = jnp.eye(self._in_size, dtype=jnp.dtype(self._dtype))
        return jax.vmap(self.mv, in_axes=1, out_axes=1)(basis)

    def transpose(self) -> KroneckerSumSqrt:
        return self

    def in_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._in_size,), jnp.dtype(self._dtype))

    def out_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._out_size,), jnp.dtype(self._dtype))

solve(vector: Float[Array, ' n']) -> Float[Array, ' n']

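Apply the inverse square root \(S^{-1}\) to vector. Each eigencomponent is divided by \(\sqrt{\lambda^A_i + \lambda^B_j}\), so the result is finite only when \(A \oplus B\) is strictly positive definite.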

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `vector` | `Float[Array, ' n']` | Input vector, shape `(n_a · n_b,)`. | required |

Returns:

| Type | Description |
| --- | --- |
| `Float[Array, ' n']` | \(S^{-1}\) applied to `vector`, shape `(n_a · n_b,)`. |


SumKronecker

Bases: AbstractLinearOperator

Sum of Kronecker products \(\sum_k A_k \otimes B_k\).

Appears in multi-output GPs with correlated outputs, e.g. \(K_\text{task} \otimes K_\text{spatial} + \sigma^2 I_\text{task} \otimes I_\text{spatial}\).

Matvec is computed as the sum of the Kronecker matvecs.

For solve and logdet, call eigendecompose, which uses a joint eigendecomposition of the second Kronecker pair. It requires exactly two Kronecker products whose four factors \(A_1, B_1, A_2, B_2\) are all symmetric. The eigendecomposition forms a dense \((n_c n_d) \times (n_c n_d)\) matrix internally, where \(n_c\) and \(n_d\) are the sizes of \(A_2\) and \(B_2\), so it is intended for moderate factor sizes (typical for multi-output GPs where the task dimension is small).
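A NumPy sketch of the rotate-then-diagonalize strategy for two Kronecker products, assuming all four factors are SPD (random here). In the multi-output GP case above one would take \(A_2 = I_\text{task}\) and \(B_2 = \sigma^2 I_\text{spatial}\). The dense matrix `K` is built only to check the result; the sketch is not gaussx's implementation.

```python
import numpy as np

rng = np.random.default_rng(5)


def random_spd(n):
    M = rng.standard_normal((n, n))
    return M @ M.T + np.eye(n)


n_c, n_d = 2, 3
A1, B1 = random_spd(n_c), random_spd(n_d)
A2, B2 = random_spd(n_c), random_spd(n_d)
K = np.kron(A1, B1) + np.kron(A2, B2)  # dense, for checking only

# Diagonalize the second pair and rotate the first pair into that basis:
# (Q_C ⊗ Q_D)ᵀ K (Q_C ⊗ Q_D) = Ã1 ⊗ B̃1 + Λ_C ⊗ Λ_D
lam_c, Q_C = np.linalg.eigh(A2)
lam_d, Q_D = np.linalg.eigh(B2)
A1_t = Q_C.T @ A1 @ Q_C
B1_t = Q_D.T @ B1 @ Q_D
T = np.kron(A1_t, B1_t) + np.diag(np.outer(lam_c, lam_d).reshape(-1))
evals, V = np.linalg.eigh(T)  # the dense (n_c n_d) x (n_c n_d) step
Q = np.kron(Q_C, Q_D) @ V

y = rng.standard_normal(n_c * n_d)
x = Q @ ((Q.T @ y) / evals)  # K⁻¹ y
assert np.allclose(Q @ np.diag(evals) @ Q.T, K)
assert np.allclose(K @ x, y)
assert np.isclose(np.sum(np.log(evals)), np.linalg.slogdet(K)[1])
```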

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `kron1` | `Kronecker` | First Kronecker product \(A_1 \otimes B_1\). | required |
| `kron2` | `Kronecker` | Second Kronecker product \(A_2 \otimes B_2\). | required |
| `*krons` | `Kronecker` | Additional two-factor Kronecker products. | `()` |
Source code in src/gaussx/_operators/_sum_kronecker.py
class SumKronecker(lx.AbstractLinearOperator):
    r"""Sum of Kronecker products ``Σ_k A_k \otimes B_k``.

    Appears in multi-output GPs with correlated outputs, e.g.
    ``K_task \otimes K_spatial + \sigma^2 I_task \otimes I_spatial``.

    Matvec is computed as the sum of the Kronecker matvecs.

    For solve and logdet, call `eigendecompose`, which uses a joint
    eigendecomposition of the second Kronecker pair (requires exactly
    two Kronecker products with symmetric ``A_1, B_1, A_2, B_2``).
    The eigendecomposition forms a dense ``(n_c n_d) x (n_c n_d)``
    matrix internally, where ``n_c`` and ``n_d`` are the sizes of
    ``A_2`` and ``B_2``, so it is intended for moderate factor sizes
    (typical for multi-output GPs where the task dimension is small).

    Args:
        kron1: First Kronecker product ``A_1 \otimes B_1``.
        kron2: Second Kronecker product ``A_2 \otimes B_2``.
        *krons: Additional two-factor Kronecker products.
    """

    operators: tuple[Kronecker, ...]
    _in_size: int = eqx.field(static=True)
    _out_size: int = eqx.field(static=True)
    _dtype: str = eqx.field(static=True)
    tags: frozenset[object] = eqx.field(static=True)

    def __init__(
        self,
        kron1: Kronecker,
        kron2: Kronecker,
        *krons: Kronecker,
        tags: object | frozenset[object] = frozenset(),
    ) -> None:
        operators = (kron1, kron2, *krons)
        if any(len(kron.operators) != 2 for kron in operators):
            raise ValueError("SumKronecker requires two-factor Kronecker products.")
        if any(kron.in_size() != kron1.in_size() for kron in operators[1:]):
            raise ValueError("Kronecker products must have the same size (input size).")
        if any(kron.out_size() != kron1.out_size() for kron in operators[1:]):
            raise ValueError(
                "Kronecker products must have the same size (output size)."
            )
        self.operators = operators
        self._in_size = kron1.in_size()
        self._out_size = kron1.out_size()
        self._dtype = _resolve_dtype(*operators)
        self.tags = _to_frozenset(tags)

    @property
    def kron1(self) -> Kronecker:
        return self.operators[0]

    @property
    def kron2(self) -> Kronecker:
        return self.operators[1]

    def mv(self, vector: Float[Array, " n"]) -> Float[Array, " m"]:
        result = self.operators[0].mv(vector)
        for kron in self.operators[1:]:
            result = result + kron.mv(vector)
        return result

    def as_matrix(self) -> Float[Array, "n n"]:
        result = self.operators[0].as_matrix()
        for kron in self.operators[1:]:
            result = result + kron.as_matrix()
        return result

    def transpose(self) -> SumKronecker:
        return SumKronecker(
            *(
                Kronecker(
                    kron.operators[0].T,
                    kron.operators[1].T,
                    tags=lx.transpose_tags(kron.tags),
                )
                for kron in self.operators
            ),
            tags=lx.transpose_tags(self.tags),
        )

    def in_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._in_size,), jnp.dtype(self._dtype))

    def out_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._out_size,), jnp.dtype(self._dtype))

    def eigendecompose(
        self,
    ) -> tuple[Float[Array, " n"], Float[Array, "n n"]]:
        r"""Eigendecompose via joint eigendecomposition of the second pair.

        Decomposes ``A_2 = Q_C \Lambda_C Q_C^T`` and
        ``B_2 = Q_D \Lambda_D Q_D^T``, then transforms the first pair
        into the eigenbasis and diagonalizes the result.

        !!! note

            This forms a dense ``(n_c n_d) x (n_c n_d)`` matrix
            internally and is O((n_c n_d)^3).  Intended for moderate
            factor sizes (e.g. multi-output GPs where task dimension
            is small).

        Raises:
            ValueError: If there are not exactly two Kronecker products, or
                if any factor of ``kron1`` or ``kron2`` is not symmetric.

        Returns:
            Tuple ``(eigenvalues, Q)`` where
            ``self == Q @ diag(eigenvalues) @ Q^T``.
        """
        if len(self.operators) != 2:
            count = len(self.operators)
            raise ValueError(
                f"eigendecompose requires exactly two Kronecker products, got {count}."
            )
        A2_op, B2_op = self.kron2.operators
        A1_op, B1_op = self.kron1.operators
        # The final ``eigh(transformed)`` call requires ``transformed`` to
        # be symmetric, which in turn requires *both* kron pairs to have
        # symmetric factors (so that ``A1_tilde`` and ``B1_tilde`` stay
        # symmetric under the ``Q^T A Q`` rotation).
        if not lx.is_symmetric(A2_op) or not lx.is_symmetric(B2_op):
            raise ValueError("eigendecompose requires kron2 factors to be symmetric.")
        if not lx.is_symmetric(A1_op) or not lx.is_symmetric(B1_op):
            raise ValueError("eigendecompose requires kron1 factors to be symmetric.")

        from gaussx._primitives._eig import eig

        A1, B1 = (op.as_matrix() for op in self.kron1.operators)

        # Per-factor eigendecomposition routed through the structural
        # primitive: Diagonal / BlockDiag / nested Kronecker factors of
        # ``A_2`` or ``B_2`` skip materialization here.
        evals_c, Q_C = eig(A2_op)
        evals_d, Q_D = eig(B2_op)

        # Transform first pair into eigenbasis of second pair
        A1_tilde = Q_C.T @ A1 @ Q_C
        B1_tilde = Q_D.T @ B1 @ Q_D

        # kron(A1_tilde, B1_tilde) + diag(evals_c kron evals_d)
        diag_vals = rearrange(
            evals_c[:, None] * evals_d[None, :],
            "a b -> (a b)",
        )
        transformed = jnp.kron(A1_tilde, B1_tilde) + jnp.diag(diag_vals)
        evals, V = jnp.linalg.eigh(transformed)

        Q = jnp.kron(Q_C, Q_D) @ V
        return evals, Q

eigendecompose() -> tuple[Float[Array, ' n'], Float[Array, 'n n']]

Eigendecompose via joint eigendecomposition of the second pair.

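Decomposes \(A_2 = Q_C \Lambda_C Q_C^\top\) and \(B_2 = Q_D \Lambda_D Q_D^\top\), then rotates the first pair into this eigenbasis, \((Q_C \otimes Q_D)^\top (A_1 \otimes B_1 + A_2 \otimes B_2)(Q_C \otimes Q_D) = \tilde A_1 \otimes \tilde B_1 + \Lambda_C \otimes \Lambda_D\) with \(\tilde A_1 = Q_C^\top A_1 Q_C\) and \(\tilde B_1 = Q_D^\top B_1 Q_D\), and diagonalizes the result with a dense symmetric eigendecomposition.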

Note

This forms a dense \((n_c n_d) \times (n_c n_d)\) matrix internally and costs \(O((n_c n_d)^3)\), where \(n_c\) and \(n_d\) are the sizes of \(A_2\) and \(B_2\). It is intended for moderate factor sizes (e.g. multi-output GPs where the task dimension is small).

Raises:

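| Type | Description |
| --- | --- |
| `ValueError` | If the sum does not contain exactly two Kronecker products, or if any factor of `kron1` or `kron2` is not symmetric. |

Returns:

| Type | Description |
| --- | --- |
| `Float[Array, ' n']` | Eigenvalues of the operator. |
| `Float[Array, 'n n']` | Eigenvectors `Q`, such that `self == Q @ diag(eigenvalues) @ Q.T`. |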


Low-rank updates

Operators of the form \(L + U\,\mathrm{diag}(d)\,V^\top\): solves use the Woodbury identity and log-determinants use the matrix determinant lemma. The factories build the common special cases directly from arrays.
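A NumPy sketch of both identities for the common case of a diagonal base \(L = \mathrm{diag}(l)\). The helper names are hypothetical, and the Woodbury form used here needs every entry of \(d\) to be nonzero because it inverts \(\mathrm{diag}(d)\). Apart from elementwise work with \(L\), only a \(k \times k\) dense system is solved and factored.

```python
import numpy as np


def lowrank_mv(l, U, d, V, x):
    # (L + U diag(d) Vᵀ) x = L x + U (d * (Vᵀ x)) with L = diag(l)
    return l * x + U @ (d * (V.T @ x))


def woodbury_solve(l, U, d, V, y):
    # (L + U D Vᵀ)⁻¹ y = L⁻¹y - L⁻¹U (D⁻¹ + Vᵀ L⁻¹ U)⁻¹ Vᵀ L⁻¹ y
    Linv_y = y / l
    Linv_U = U / l[:, None]
    cap = np.diag(1.0 / d) + V.T @ Linv_U  # k x k "capacitance" matrix
    return Linv_y - Linv_U @ np.linalg.solve(cap, V.T @ Linv_y)


def lemma_logdet(l, U, d, V):
    # Matrix determinant lemma, returning log|det(L + U D Vᵀ)|:
    # det(L + U D Vᵀ) = det(D⁻¹ + Vᵀ L⁻¹ U) · det(D) · det(L)
    cap = np.diag(1.0 / d) + V.T @ (U / l[:, None])
    return (np.linalg.slogdet(cap)[1]
            + np.sum(np.log(np.abs(d)))
            + np.sum(np.log(np.abs(l))))


rng = np.random.default_rng(6)
n, k = 8, 2
l = 1.0 + rng.random(n)  # positive diagonal base
U = rng.standard_normal((n, k))
d = 0.5 + rng.random(k)
V = U                    # symmetric update L + U diag(d) Uᵀ
M = np.diag(l) + U @ np.diag(d) @ V.T  # dense, for checking only

x = rng.standard_normal(n)
assert np.allclose(lowrank_mv(l, U, d, V, x), M @ x)
assert np.allclose(woodbury_solve(l, U, d, V, M @ x), x)
assert np.isclose(lemma_logdet(l, U, d, V), np.linalg.slogdet(M)[1])
```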


LowRankUpdate

Bases: AbstractLinearOperator

Low-rank update operator L + U diag(d) Vᵀ.

Represents a base operator \(L\) plus a rank-\(k\) update. When \(L\) is cheap to solve (e.g. diagonal), the Woodbury identity reduces a solve with the full operator to solves with \(L\) plus a single \(k \times k\) dense solve.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `base` | `AbstractLinearOperator` | The base operator \(L\), with shape `(m, n)`. | required |
| `U` | `Float[Array, 'm k']` | Left factor, shape `(m, k)`. A 1-D array is treated as a single column. | required |
| `d` | `Float[Array, ' k'] \| None` | Diagonal scaling, shape `(k,)`. Defaults to ones. | `None` |
| `V` | `Float[Array, 'n k'] \| None` | Right factor, shape `(n, k)`. Defaults to `U`, which is only valid for square operators and yields the symmetric update \(L + U\,\mathrm{diag}(d)\,U^\top\). | `None` |
| `orthonormal` | `bool` | When `True`, symmetry inference for orthonormal factors uses object identity, so symmetric SVD-style updates must pass the same array object for `U` and `V`. | `False` |
Source code in src/gaussx/_operators/_low_rank_update.py
class LowRankUpdate(lx.AbstractLinearOperator):
    """Low-rank update operator ``L + U diag(d) Vᵀ``.

    Represents a base operator *L* plus a rank-k update. When *L*
    is cheap to solve (e.g. diagonal), the Woodbury identity gives
    efficient solves for the full operator.

    Args:
        base: The base operator *L*, with shape ``(m, n)``.
        U: Left factor, shape ``(m, k)``.
        d: Diagonal scaling, shape ``(k,)``. Defaults to ones.
        V: Right factor, shape ``(n, k)``. Defaults to *U* for
            square operators, yielding the symmetric update
            ``L + U diag(d) Uᵀ``.
        orthonormal: When ``True``, symmetry inference for orthonormal
            factors uses object identity, so symmetric SVD-style updates
            must pass the same array object for ``U`` and ``V``.
    """

    base: lx.AbstractLinearOperator
    U: Float[Array, "m k"]
    d: Float[Array, " k"]
    V: Float[Array, "n k"]
    orthonormal: bool = eqx.field(static=True)
    tags: frozenset[object] = eqx.field(static=True)

    def __init__(
        self,
        base: lx.AbstractLinearOperator,
        U: Float[Array, "m k"],
        d: Float[Array, " k"] | None = None,
        V: Float[Array, "n k"] | None = None,
        *,
        tags: object | frozenset[object] = frozenset(),
        orthonormal: bool = False,
    ) -> None:
        m = base.out_size()
        n = base.in_size()
        k = U.shape[1] if U.ndim == 2 else 1
        if U.ndim == 1:
            U = U[:, None]
        if d is None:
            d = jnp.ones(k, dtype=U.dtype)
        if V is None:
            V = U
        if V.ndim == 1:
            V = V[:, None]
        if U.shape[0] != m or V.shape[0] != n:
            raise ValueError(
                f"U must have {m} rows and V must have {n} rows to match "
                f"base operator, "
                f"got U.shape={U.shape}, V.shape={V.shape}."
            )
        if U.shape[1] != d.shape[0] or V.shape[1] != d.shape[0]:
            raise ValueError(
                f"Rank dimensions must match: U has {U.shape[1]} cols, "
                f"V has {V.shape[1]} cols, d has {d.shape[0]} entries."
            )
        self.base = base
        self.U = U
        self.d = d
        self.V = V
        self.orthonormal = orthonormal
        from gaussx._tags import low_rank_tag

        inferred_tags = _infer_tags(base, U, d, V, orthonormal=orthonormal)
        self.tags = _to_frozenset(tags) | inferred_tags | {low_rank_tag}

    @property
    def rank(self) -> int:
        """Rank of the low-rank update."""
        return self.d.shape[0]

    def mv(self, vector: Float[Array, " n"]) -> Float[Array, " m"]:
        # (L + U diag(d) V^T) x = L x + U (d * (V^T x))
        base_part = self.base.mv(vector)
        vtx = self.V.T @ vector  # (k,)
        scaled = self.d * vtx  # (k,)
        update_part = self.U @ scaled  # (m,)
        return base_part + update_part

    def as_matrix(self) -> Float[Array, "m n"]:
        L = self.base.as_matrix()
        return L + self.U @ jnp.diag(self.d) @ self.V.T

    def transpose(self) -> LowRankUpdate:
        return LowRankUpdate(
            self.base.T,
            self.V,
            self.d,
            self.U,
            tags=lx.transpose_tags(self.tags),
            orthonormal=self.orthonormal,
        )

    def in_structure(self) -> jax.ShapeDtypeStruct:
        return self.base.in_structure()

    def out_structure(self) -> jax.ShapeDtypeStruct:
        return self.base.out_structure()

rank: int property

Rank of the low-rank update.

SVDLowRankUpdate

Bases: LowRankUpdate

Deprecated subclass of LowRankUpdate with orthonormal=True.

Preserves the pre-consolidation public API for one release:

  • Same constructor signature as the old class — S defaults to ones (via the parent LowRankUpdate) if omitted, and V defaults to U so calls like SVDLowRankUpdate(base, U, S) and SVDLowRankUpdate(base, U) continue to work.
  • Inherits from LowRankUpdate so isinstance / issubclass checks and singledispatch registrations keyed on this class keep working.
  • Forces orthonormal=True and emits a DeprecationWarning on construction.

New code should construct LowRankUpdate(base, U, S, V, orthonormal=True) (or use svd_low_rank_plus_diag) directly. Will be removed in a future release.

Source code in src/gaussx/_operators/_low_rank_update.py
class SVDLowRankUpdate(LowRankUpdate):
    """Deprecated subclass of `LowRankUpdate` with ``orthonormal=True``.

    Preserves the pre-consolidation public API for one release:

    - Same constructor signature as the old class — ``S`` defaults to
      ones (via the parent ``LowRankUpdate``) if omitted, and ``V``
      defaults to ``U`` so calls like ``SVDLowRankUpdate(base, U, S)``
      and ``SVDLowRankUpdate(base, U)`` continue to work.
    - Inherits from `LowRankUpdate` so ``isinstance`` /
      ``issubclass`` checks and ``singledispatch`` registrations keyed
      on this class keep working.
    - Forces ``orthonormal=True`` and emits a
      `DeprecationWarning` on construction.

    New code should construct ``LowRankUpdate(base, U, S, V,
    orthonormal=True)`` (or use `svd_low_rank_plus_diag`)
    directly. Will be removed in a future release.
    """

    def __init__(
        self,
        base: lx.AbstractLinearOperator,
        U: Float[Array, "n k"],
        S: Float[Array, " k"] | None = None,
        V: Float[Array, "n k"] | None = None,
        *,
        tags: object | frozenset[object] = frozenset(),
    ) -> None:
        import warnings

        warnings.warn(
            "SVDLowRankUpdate is deprecated; use "
            "LowRankUpdate(..., orthonormal=True) or svd_low_rank_plus_diag().",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(base, U, S, V, tags=tags, orthonormal=True)

low_rank_plus_diag(diag: Float[Array, ' n'], U: Float[Array, 'n k'], d: Float[Array, ' k'] | None = None, V: Float[Array, 'n k'] | None = None) -> LowRankUpdate

Construct diag(diag) + U diag(d) Vᵀ.

Common pattern for inducing-point / Nystrom approximations where the base is a diagonal matrix.
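One common way to build the factors is \(U = K_{nm} L_m^{-\top}\), where \(L_m\) is the Cholesky factor of the inducing-point kernel \(K_{mm}\). This gives \(U U^\top = K_{nm} K_{mm}^{-1} K_{mn}\). Note that this recipe is assumed here for illustration and is not taken from gaussx. The RBF kernel, the jitter value and the helper names below are hypothetical:

```python
import numpy as np


def rbf(X1, X2, lengthscale=1.0):
    """Squared-exponential kernel matrix (an illustrative kernel choice)."""
    sq = np.sum((X1[:, None, :] - X2[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq / lengthscale**2)


def nystrom_factor(X, Z, jitter=1e-6):
    """Return U (n, m) with U @ U.T == K_xz (K_zz + jitter I)^{-1} K_zx."""
    K_zz = rbf(Z, Z) + jitter * np.eye(Z.shape[0])
    L_z = np.linalg.cholesky(K_zz)              # K_zz = L_z L_z^T (lower)
    return np.linalg.solve(L_z, rbf(Z, X)).T    # K_xz L_z^{-T}


rng = np.random.default_rng(0)
X = rng.uniform(-3.0, 3.0, size=(200, 1))       # data inputs
Z = np.linspace(-3.0, 3.0, 10)[:, None]         # inducing inputs
U = nystrom_factor(X, Z)                        # shape (200, 10)
noise = np.full(X.shape[0], 0.1)
# low_rank_plus_diag(noise, U) would then represent diag(noise) + U U^T.
```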

Parameters:

Name Type Description Default
diag Float[Array, ' n']

Diagonal entries, shape (n,).

required
U Float[Array, 'n k']

Left factor, shape (n, k).

required
d Float[Array, ' k'] | None

Diagonal scaling, shape (k,). Defaults to ones.

None
V Float[Array, 'n k'] | None

Right factor, shape (n, k). Defaults to U.

None

Returns:

Type Description
LowRankUpdate

A LowRankUpdate with a DiagonalLinearOperator base.

Source code in src/gaussx/_operators/_low_rank_update.py
def low_rank_plus_diag(
    diag: Float[Array, " n"],
    U: Float[Array, "n k"],
    d: Float[Array, " k"] | None = None,
    V: Float[Array, "n k"] | None = None,
) -> LowRankUpdate:
    """Construct ``diag(diag) + U diag(d) Vᵀ``.

    Common pattern for inducing-point / Nystrom approximations
    where the base is a diagonal matrix.

    Args:
        diag: Diagonal entries, shape ``(n,)``.
        U: Left factor, shape ``(n, k)``.
        d: Diagonal scaling, shape ``(k,)``. Defaults to ones.
        V: Right factor, shape ``(n, k)``. Defaults to *U*.

    Returns:
        A ``LowRankUpdate`` with a ``DiagonalLinearOperator`` base.
    """
    return _low_rank_update_with_diag_base(diag, U, d, V)

low_rank_plus_identity(U: Float[Array, 'n k'], d: Float[Array, ' k'] | None = None, V: Float[Array, 'n k'] | None = None, *, scale: float = 1.0) -> LowRankUpdate

Construct scale * I + U diag(d) Vᵀ.

Common pattern for regularised low-rank models (e.g. noise + signal).

Parameters:

Name Type Description Default
U Float[Array, 'n k']

Left factor, shape (n, k).

required
d Float[Array, ' k'] | None

Diagonal scaling, shape (k,). Defaults to ones.

None
V Float[Array, 'n k'] | None

Right factor, shape (n, k). Defaults to U.

None
scale float

Scalar multiplier on the identity. Default 1.0.

1.0

Returns:

Type Description
LowRankUpdate

A LowRankUpdate with a scaled identity base.

Source code in src/gaussx/_operators/_low_rank_update.py
def low_rank_plus_identity(
    U: Float[Array, "n k"],
    d: Float[Array, " k"] | None = None,
    V: Float[Array, "n k"] | None = None,
    *,
    scale: float = 1.0,
) -> LowRankUpdate:
    """Construct ``scale * I + U diag(d) Vᵀ``.

    Common pattern for regularised low-rank models (e.g. noise + signal).

    Args:
        U: Left factor, shape ``(n, k)``.
        d: Diagonal scaling, shape ``(k,)``. Defaults to ones.
        V: Right factor, shape ``(n, k)``. Defaults to *U*.
        scale: Scalar multiplier on the identity. Default 1.0.

    Returns:
        A ``LowRankUpdate`` with a scaled identity base.
    """
    n = U.shape[0]
    diag = jnp.full(n, scale, dtype=U.dtype)
    return _low_rank_update_with_diag_base(diag, U, d, V)

svd_low_rank_plus_diag(diag: Float[Array, ' n'], U: Float[Array, 'n k'], S: Float[Array, ' k'], V: Float[Array, 'n k']) -> LowRankUpdate

Construct diag(diag) + U diag(S) Vᵀ from a truncated SVD.
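Here is a sketch of how the arguments could be produced from a dense matrix with NumPy; the helper is hypothetical. Note that `np.linalg.svd` returns \(V^\top\), so the V argument is the transpose of its first k rows. By the Eckart–Young theorem, the spectral-norm error of the rank-k truncation equals the (k+1)-th singular value:

```python
import numpy as np


def truncated_svd(A, k):
    """Rank-k factors (U, S, V) with A ≈ U @ np.diag(S) @ V.T (hypothetical helper)."""
    U_full, S_full, Vh = np.linalg.svd(A)
    # np.linalg.svd returns Vh = V^T, so V is the transpose of its first k rows.
    return U_full[:, :k], S_full[:k], Vh[:k].T


rng = np.random.default_rng(0)
A = rng.standard_normal((40, 40))
U, S, V = truncated_svd(A, k=5)
approx = U @ np.diag(S) @ V.T
# svd_low_rank_plus_diag(diag, U, S, V) would then represent diag(diag) + approx.
```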

Parameters:

Name Type Description Default
diag Float[Array, ' n']

Diagonal entries, shape (n,).

required
U Float[Array, 'n k']

Left singular vectors, shape (n, k).

required
S Float[Array, ' k']

Singular values, shape (k,).

required
V Float[Array, 'n k']

Right singular vectors, shape (n, k).

required

Returns:

Type Description
LowRankUpdate

A LowRankUpdate with a DiagonalLinearOperator base.

Source code in src/gaussx/_operators/_low_rank_update.py
def svd_low_rank_plus_diag(
    diag: Float[Array, " n"],
    U: Float[Array, "n k"],
    S: Float[Array, " k"],
    V: Float[Array, "n k"],
) -> LowRankUpdate:
    """Construct ``diag(diag) + U diag(S) Vᵀ`` from a truncated SVD.

    Args:
        diag: Diagonal entries, shape ``(n,)``.
        U: Left singular vectors, shape ``(n, k)``.
        S: Singular values, shape ``(k,)``.
        V: Right singular vectors, shape ``(n, k)``.

    Returns:
        A ``LowRankUpdate`` with a ``DiagonalLinearOperator`` base.
    """
    return _low_rank_update_with_diag_base(diag, U, S, V, orthonormal=True)

Banded & Toeplitz

Block-tridiagonal operators solve in \(O(N d^3)\) via block-banded Cholesky — the precision structure of Markovian (state-space) GPs. Symmetric Toeplitz operators get \(O(n \log n)\) matvecs and sampling via FFT circulant embedding.


BlockTriDiag

Bases: AbstractLinearOperator

Symmetric block-tridiagonal operator.

Represents the structure:

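\[
\begin{bmatrix}
D_1 & A_1^\top & & & \\
A_1 & D_2 & A_2^\top & & \\
 & A_2 & D_3 & \ddots & \\
 & & \ddots & \ddots & A_{N-1}^\top \\
 & & & A_{N-1} & D_N
\end{bmatrix}
\]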

where D_k are (d, d) diagonal blocks and A_k are (d, d) sub-diagonal blocks. This is the precision matrix structure arising from discretized SDEs in state-space GP inference.
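The scalar case (d = 1) shows where this structure comes from. Consider an AR(1) chain \(x_{k+1} = a x_k + \varepsilon_k\) with \(\varepsilon_k \sim \mathcal{N}(0, q)\) and \(x_1 \sim \mathcal{N}(0, s_0)\), which is what an Euler or exact discretization of an Ornstein–Uhlenbeck process produces. Its precision is tridiagonal. The NumPy sketch below uses hypothetical helper and parameter names. It returns blocks shaped like the diagonal and sub_diagonal arguments, and a second helper builds the dense covariance so the two can be cross-checked:

```python
import numpy as np


def ar1_precision_blocks(N, a, q, s0):
    """Precision of x_1 ~ N(0, s0), x_{k+1} = a x_k + N(0, q), for N >= 2.

    Returns (diagonal, sub_diagonal) with shapes (N, 1, 1) and (N-1, 1, 1).
    """
    diag = np.full(N, (1.0 + a**2) / q)     # interior rows
    diag[0] = 1.0 / s0 + a**2 / q           # first row also carries the prior on x_1
    diag[-1] = 1.0 / q                      # last state has no successor
    sub = np.full(N - 1, -a / q)            # Q[k+1, k] = -a / q
    return diag[:, None, None], sub[:, None, None]


def ar1_covariance(N, a, q, s0):
    """Dense covariance, for cross-checking the precision above."""
    var = np.empty(N)
    var[0] = s0
    for k in range(1, N):
        var[k] = a**2 * var[k - 1] + q      # Var(x_{k+1}) = a^2 Var(x_k) + q
    i, j = np.meshgrid(np.arange(N), np.arange(N), indexing="ij")
    # Cov(x_i, x_j) = a^{|i-j|} Var(x_min(i, j))
    return a ** np.abs(i - j) * var[np.minimum(i, j)]
```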

All primitives (solve, logdet, cholesky, diag, trace) exploit the banded structure for O(Nd³) cost instead of O((Nd)³).
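The Cholesky factor follows a block recursion. Start with \(L_1 = \mathrm{chol}(D_1)\), then for each k compute \(B_k = A_k L_k^{-\top}\) and \(L_{k+1} = \mathrm{chol}(D_{k+1} - B_k B_k^\top)\). That amounts to N factorizations of size d, hence the O(Nd³) cost. The NumPy sketch below uses a hypothetical helper name and is not the gaussx implementation. The returned blocks have the layout of LowerBlockTriDiag:

```python
import numpy as np


def block_tridiag_cholesky(diagonal, sub_diagonal):
    """Block Cholesky of a symmetric positive-definite block-tridiagonal matrix.

    diagonal: (N, d, d) blocks D_k; sub_diagonal: (N-1, d, d) blocks A_k at
    block position (k+1, k). Returns (L_diag, L_sub) for the block
    lower-bidiagonal factor L with L @ L.T equal to the input matrix.
    """
    N = diagonal.shape[0]
    L_diag = np.empty_like(diagonal)
    L_sub = np.empty_like(sub_diagonal)
    L_diag[0] = np.linalg.cholesky(diagonal[0])
    for k in range(N - 1):
        # B_k = A_k L_k^{-T}, computed as (L_k^{-1} A_k^T)^T
        L_sub[k] = np.linalg.solve(L_diag[k], sub_diagonal[k].T).T
        schur = diagonal[k + 1] - L_sub[k] @ L_sub[k].T   # Schur complement block
        L_diag[k + 1] = np.linalg.cholesky(schur)
    return L_diag, L_sub
```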

Parameters:

Name Type Description Default
diagonal Float[Array, 'N d d']

Diagonal blocks, shape (N, d, d).

required
sub_diagonal Float[Array, 'Nm1 d d']

Sub-diagonal blocks, shape (N-1, d, d).

required
Source code in src/gaussx/_operators/_block_tridiag.py
class BlockTriDiag(lx.AbstractLinearOperator):
    r"""Symmetric block-tridiagonal operator.

    Represents the structure:

        [D_1  A_1^T              ]
        [A_1  D_2   A_2^T        ]
        [     A_2   D_3   ...    ]
        [               A_{N-1} D_N]

    where ``D_k`` are ``(d, d)`` diagonal blocks and ``A_k`` are
    ``(d, d)`` sub-diagonal blocks. This is the precision matrix
    structure arising from discretized SDEs in state-space GP inference.

    All primitives (solve, logdet, cholesky, diag, trace) exploit the
    banded structure for O(Nd³) cost instead of O((Nd)³).

    Args:
        diagonal: Diagonal blocks, shape ``(N, d, d)``.
        sub_diagonal: Sub-diagonal blocks, shape ``(N-1, d, d)``.
    """

    diagonal: Float[Array, "N d d"]
    sub_diagonal: Float[Array, "Nm1 d d"]
    _num_blocks: int = eqx.field(static=True)
    _block_size: int = eqx.field(static=True)
    _size: int = eqx.field(static=True)
    _dtype: str = eqx.field(static=True)
    tags: frozenset[object] = eqx.field(static=True)

    def __init__(
        self,
        diagonal: Float[Array, "N d d"],
        sub_diagonal: Float[Array, "Nm1 d d"],
        *,
        tags: object | frozenset[object] = frozenset(),
    ) -> None:
        if diagonal.ndim != 3:
            raise ValueError(
                f"diagonal must have 3 dimensions (N, d, d), got {diagonal.ndim}."
            )
        if sub_diagonal.ndim != 3:
            raise ValueError(
                f"sub_diagonal must have 3 dimensions (N-1, d, d), "
                f"got {sub_diagonal.ndim}."
            )
        N, d, d2 = diagonal.shape
        if d != d2:
            raise ValueError(f"Diagonal blocks must be square, got ({d}, {d2}).")
        if sub_diagonal.shape[0] != N - 1:
            raise ValueError(
                f"sub_diagonal must have {N - 1} blocks, got {sub_diagonal.shape[0]}."
            )
        if sub_diagonal.shape[1] != d or sub_diagonal.shape[2] != d:
            raise ValueError(
                f"Sub-diagonal blocks must have shape ({d}, {d}), "
                f"got ({sub_diagonal.shape[1]}, {sub_diagonal.shape[2]})."
            )
        self.diagonal = diagonal
        self.sub_diagonal = sub_diagonal
        self._num_blocks = N
        self._block_size = d
        self._size = N * d
        self._dtype = str(diagonal.dtype)
        from gaussx._tags import block_tridiagonal_tag

        self.tags = _to_frozenset(tags) | {block_tridiagonal_tag, lx.symmetric_tag}

    def mv(self, vector: Float[Array, " n"]) -> Float[Array, " n"]:
        N = self._num_blocks
        d = self._block_size
        x = rearrange(vector, "(N d) -> N d", N=N, d=d)
        # Dₖ xₖ for all k
        result = einsum(self.diagonal, x, "N d1 d2, N d2 -> N d1")
        # Aₖ₋₁ xₖ₋₁ for k = 1, ..., N-1 (sub-diagonal)
        sub_contrib = einsum(self.sub_diagonal, x[:-1], "N d1 d2, N d2 -> N d1")
        result = result.at[1:].add(sub_contrib)
        # Aₖᵀ xₖ₊₁ for k = 0, ..., N-2 (super-diagonal)
        super_contrib = einsum(self.sub_diagonal, x[1:], "N d1 d2, N d1 -> N d2")
        result = result.at[:-1].add(super_contrib)
        return rearrange(result, "N d -> (N d)")

    def as_matrix(self) -> Float[Array, "n n"]:
        N = self._num_blocks
        d = self._block_size
        n = self._size
        mat = jnp.zeros((n, n), dtype=jnp.dtype(self._dtype))
        for k in range(N):
            r = k * d
            mat = mat.at[r : r + d, r : r + d].set(self.diagonal[k])
        for k in range(N - 1):
            r = (k + 1) * d
            c = k * d
            mat = mat.at[r : r + d, c : c + d].set(self.sub_diagonal[k])
            mat = mat.at[c : c + d, r : r + d].set(self.sub_diagonal[k].T)
        return mat

    def transpose(self) -> BlockTriDiag:
        # as_matrix puts sub[k] at (k+1,k) and sub[k].T at (k,k+1).
        # Transposing: new (k+1,k) = old (k,k+1).T = sub[k].
        # So new sub_diagonal = self.sub_diagonal (unchanged).
        return BlockTriDiag(
            rearrange(self.diagonal, "N i j -> N j i"),
            self.sub_diagonal,
            tags=lx.transpose_tags(self.tags),
        )

    def in_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._size,), jnp.dtype(self._dtype))

    def out_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._size,), jnp.dtype(self._dtype))

    def add(self, other: BlockTriDiag) -> BlockTriDiag:
        """Add two block-tridiagonal operators (e.g. prior + likelihood sites)."""
        return BlockTriDiag(
            self.diagonal + other.diagonal,
            self.sub_diagonal + other.sub_diagonal,
        )

    def __add__(self, other: BlockTriDiag) -> BlockTriDiag:
        return self.add(other)

    def __radd__(self, other: object) -> BlockTriDiag:
        if isinstance(other, BlockTriDiag):
            return other.add(self)
        if other == 0:
            return self
        return NotImplemented

    def __sub__(self, other: BlockTriDiag) -> BlockTriDiag:
        return BlockTriDiag(
            self.diagonal - other.diagonal,
            self.sub_diagonal - other.sub_diagonal,
        )

    def __neg__(self) -> BlockTriDiag:
        return BlockTriDiag(-self.diagonal, -self.sub_diagonal)

    def __mul__(self, other: object) -> BlockTriDiag:
        scalar = jnp.asarray(other)
        if scalar.ndim != 0:
            msg = "BlockTriDiag can only be multiplied by a scalar"
            raise TypeError(msg)
        return BlockTriDiag(scalar * self.diagonal, scalar * self.sub_diagonal)

    def __rmul__(self, other: object) -> BlockTriDiag:
        return self.__mul__(other)

add(other: BlockTriDiag) -> BlockTriDiag

Add two block-tridiagonal operators (e.g. prior + likelihood sites).


LowerBlockTriDiag

Bases: AbstractLinearOperator

Lower triangular block-bidiagonal Cholesky factor.

Represents:

\[
\begin{bmatrix}
L_1 & & & & \\
B_1 & L_2 & & & \\
 & B_2 & L_3 & & \\
 & & \ddots & \ddots & \\
 & & & B_{N-1} & L_N
\end{bmatrix}
\]

where L_k are (d, d) lower-triangular blocks and B_k are (d, d) sub-diagonal blocks.
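With a factor of this shape, a solve against \(L\) is a block forward substitution: first \(L_1 y_1 = b_1\), then \(L_k y_k = b_k - B_{k-1} y_{k-1}\). When \(L\) is a Cholesky factor, the log-determinant is \(\log\det(LL^\top) = 2\sum_k \sum_i \log (L_k)_{ii}\). Below is a NumPy sketch with hypothetical helper names. It assumes the diagonal blocks have positive diagonals, as they do when produced by a Cholesky factorization:

```python
import numpy as np


def lower_block_bidiag_solve(L_diag, L_sub, b):
    """Solve L y = b for block lower-bidiagonal L by forward substitution.

    L_diag: (N, d, d) lower-triangular blocks L_k.
    L_sub: (N-1, d, d) blocks B_k at block position (k+1, k).
    """
    N, d, _ = L_diag.shape
    rhs = b.reshape(N, d)
    y = np.empty_like(rhs)
    y[0] = np.linalg.solve(L_diag[0], rhs[0])           # first block row has no B term
    for k in range(1, N):
        y[k] = np.linalg.solve(L_diag[k], rhs[k] - L_sub[k - 1] @ y[k - 1])
    return y.reshape(-1)


def cholesky_logdet(L_diag):
    """log det(L L^T) = 2 * sum of log(diagonal entries) over all blocks L_k."""
    return 2.0 * np.sum(np.log(np.diagonal(L_diag, axis1=1, axis2=2)))
```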

Parameters:

Name Type Description Default
diagonal Float[Array, 'N d d']

Lower-triangular diagonal blocks, shape (N, d, d).

required
sub_diagonal Float[Array, 'Nm1 d d']

Sub-diagonal blocks, shape (N-1, d, d).

required
Source code in src/gaussx/_operators/_block_tridiag.py
class LowerBlockTriDiag(lx.AbstractLinearOperator):
    """Lower triangular block-bidiagonal Cholesky factor.

    Represents:

        [L_1              ]
        [B_1  L_2          ]
        [     B_2  L_3     ]
        [          ...  L_N]

    where ``L_k`` are ``(d, d)`` lower-triangular blocks and ``B_k`` are
    ``(d, d)`` sub-diagonal blocks.

    Args:
        diagonal: Lower-triangular diagonal blocks, shape ``(N, d, d)``.
        sub_diagonal: Sub-diagonal blocks, shape ``(N-1, d, d)``.
    """

    diagonal: Float[Array, "N d d"]
    sub_diagonal: Float[Array, "Nm1 d d"]
    _num_blocks: int = eqx.field(static=True)
    _block_size: int = eqx.field(static=True)
    _size: int = eqx.field(static=True)
    _dtype: str = eqx.field(static=True)
    tags: frozenset[object] = eqx.field(static=True)

    def __init__(
        self,
        diagonal: Float[Array, "N d d"],
        sub_diagonal: Float[Array, "Nm1 d d"],
        *,
        tags: object | frozenset[object] = frozenset(),
    ) -> None:
        N, d, _ = diagonal.shape
        self.diagonal = diagonal
        self.sub_diagonal = sub_diagonal
        self._num_blocks = N
        self._block_size = d
        self._size = N * d
        self._dtype = str(diagonal.dtype)
        self.tags = _to_frozenset(tags) | {lx.lower_triangular_tag}

    def mv(self, vector: Float[Array, " n"]) -> Float[Array, " n"]:
        N = self._num_blocks
        d = self._block_size
        x = rearrange(vector, "(N d) -> N d", N=N, d=d)
        # Lₖ xₖ
        result = einsum(self.diagonal, x, "N d1 d2, N d2 -> N d1")
        # Bₖ xₖ₋₁
        sub_contrib = einsum(self.sub_diagonal, x[:-1], "N d1 d2, N d2 -> N d1")
        result = result.at[1:].add(sub_contrib)
        return rearrange(result, "N d -> (N d)")

    def as_matrix(self) -> Float[Array, "n n"]:
        N = self._num_blocks
        d = self._block_size
        n = self._size
        mat = jnp.zeros((n, n), dtype=jnp.dtype(self._dtype))
        for k in range(N):
            r = k * d
            mat = mat.at[r : r + d, r : r + d].set(self.diagonal[k])
        for k in range(N - 1):
            r = (k + 1) * d
            c = k * d
            mat = mat.at[r : r + d, c : c + d].set(self.sub_diagonal[k])
        return mat

    def transpose(self) -> UpperBlockTriDiag:
        """Transpose gives upper block-bidiagonal."""
        return UpperBlockTriDiag(
            rearrange(self.diagonal, "N i j -> N j i"),
            rearrange(self.sub_diagonal, "N i j -> N j i"),
        )

    def in_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._size,), jnp.dtype(self._dtype))

    def out_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._size,), jnp.dtype(self._dtype))

transpose() -> UpperBlockTriDiag

Transpose gives upper block-bidiagonal.


UpperBlockTriDiag

Bases: AbstractLinearOperator

Upper triangular block-bidiagonal (transpose of LowerBlockTriDiag).

Represents:

\[
\begin{bmatrix}
U_1 & C_1 & & & \\
 & U_2 & C_2 & & \\
 & & \ddots & \ddots & \\
 & & & U_{N-1} & C_{N-1} \\
 & & & & U_N
\end{bmatrix}
\]

where U_k are upper-triangular diagonal blocks and C_k are super-diagonal blocks.
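Solving against the upper factor runs the same recursion backwards. First solve \(U_N x_N = y_N\), then \(U_k x_k = y_k - C_k x_{k+1}\) for \(k = N-1, \dots, 1\). Paired with a forward sweep through the lower factor, this gives an O(Nd³) solve for \(A = LL^\top\). Below is a NumPy sketch with a hypothetical helper name:

```python
import numpy as np


def upper_block_bidiag_solve(U_diag, U_super, y):
    """Solve U x = y for block upper-bidiagonal U by back substitution.

    U_diag: (N, d, d) upper-triangular blocks U_k.
    U_super: (N-1, d, d) blocks C_k at block position (k, k+1).
    """
    N, d, _ = U_diag.shape
    rhs = y.reshape(N, d)
    x = np.empty_like(rhs)
    x[-1] = np.linalg.solve(U_diag[-1], rhs[-1])        # last block row has no C term
    for k in range(N - 2, -1, -1):
        x[k] = np.linalg.solve(U_diag[k], rhs[k] - U_super[k] @ x[k + 1])
    return x.reshape(-1)
```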

Source code in src/gaussx/_operators/_block_tridiag.py
class UpperBlockTriDiag(lx.AbstractLinearOperator):
    """Upper triangular block-bidiagonal (transpose of LowerBlockTriDiag).

    Represents:

        [U_1  C_1            ]
        [     U_2  C_2        ]
        [          ...   C_{N-1}]
        [               U_N  ]

    where ``U_k`` are upper-triangular diagonal blocks and ``C_k`` are
    super-diagonal blocks.
    """

    diagonal: Float[Array, "N d d"]
    super_diagonal: Float[Array, "Nm1 d d"]
    _num_blocks: int = eqx.field(static=True)
    _block_size: int = eqx.field(static=True)
    _size: int = eqx.field(static=True)
    _dtype: str = eqx.field(static=True)
    tags: frozenset[object] = eqx.field(static=True)

    def __init__(
        self,
        diagonal: Float[Array, "N d d"],
        super_diagonal: Float[Array, "Nm1 d d"],
        *,
        tags: object | frozenset[object] = frozenset(),
    ) -> None:
        N, d, _ = diagonal.shape
        self.diagonal = diagonal
        self.super_diagonal = super_diagonal
        self._num_blocks = N
        self._block_size = d
        self._size = N * d
        self._dtype = str(diagonal.dtype)
        self.tags = _to_frozenset(tags) | {lx.upper_triangular_tag}

    def mv(self, vector: Float[Array, " n"]) -> Float[Array, " n"]:
        N = self._num_blocks
        d = self._block_size
        x = rearrange(vector, "(N d) -> N d", N=N, d=d)
        # Uₖ xₖ
        result = einsum(self.diagonal, x, "N d1 d2, N d2 -> N d1")
        # Cₖ xₖ₊₁
        super_contrib = einsum(self.super_diagonal, x[1:], "N d1 d2, N d2 -> N d1")
        result = result.at[:-1].add(super_contrib)
        return rearrange(result, "N d -> (N d)")

    def as_matrix(self) -> Float[Array, "n n"]:
        N = self._num_blocks
        d = self._block_size
        n = self._size
        mat = jnp.zeros((n, n), dtype=jnp.dtype(self._dtype))
        for k in range(N):
            r = k * d
            mat = mat.at[r : r + d, r : r + d].set(self.diagonal[k])
        for k in range(N - 1):
            r = k * d
            c = (k + 1) * d
            mat = mat.at[r : r + d, c : c + d].set(self.super_diagonal[k])
        return mat

    def transpose(self) -> LowerBlockTriDiag:
        return LowerBlockTriDiag(
            rearrange(self.diagonal, "N i j -> N j i"),
            rearrange(self.super_diagonal, "N i j -> N j i"),
        )

    def in_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._size,), jnp.dtype(self._dtype))

    def out_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._size,), jnp.dtype(self._dtype))

Toeplitz

Bases: AbstractLinearOperator

Symmetric Toeplitz matrix from its first column.

\(K_{ij} = c_{|i-j|}\). Only the first column is stored (O(n) memory), and matvecs cost O(n log n) via circulant embedding and the FFT.

For stationary kernels on regular 1-D grids the full kernel matrix is Toeplitz, so this gives an asymptotic win over dense storage.
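Here is a minimal sketch of the circulant-embedding matvec. It assumes the simplest embedding, of size 2n with a single zero pad; the padding used internally by gaussx is not shown here. The column is embedded in a symmetric circulant, the zero-padded vector is multiplied by it using real FFTs, and the first n entries are kept:

```python
import numpy as np


def toeplitz_mv_fft(column, x):
    """T @ x for the symmetric Toeplitz T with first column `column`, in O(n log n).

    T is embedded in a 2n x 2n circulant whose first column is
    [c_0, ..., c_{n-1}, 0, c_{n-1}, ..., c_1]; a circulant matvec is a circular
    convolution, which the FFT diagonalizes.
    """
    n = column.shape[0]
    circ = np.concatenate([column, [0.0], column[:0:-1]])   # length 2n
    x_pad = np.concatenate([x, np.zeros(n)])                # zero-pad to length 2n
    y = np.fft.irfft(np.fft.rfft(circ) * np.fft.rfft(x_pad), n=2 * n)
    return y[:n]                                            # top-left n x n block


column = np.exp(-np.arange(8) / 3.0)   # e.g. an exponential kernel on a regular grid
x = np.random.default_rng(0).standard_normal(8)
y = toeplitz_mv_fft(column, x)
```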

Parameters:

Name Type Description Default
column Float[Array, ' n']

First column of the Toeplitz matrix, shape (n,).

required
Source code in src/gaussx/_operators/_toeplitz.py
class Toeplitz(lx.AbstractLinearOperator):
    r"""Symmetric Toeplitz matrix from its first column.

    ``K_{ij} = c_{|i-j|}``. Only the first column is stored (O(n)
    memory), and matvecs cost O(n log n) via circulant embedding and FFT.

    For stationary kernels on regular 1-D grids the full kernel matrix
    is Toeplitz, so this gives an asymptotic win over dense storage.

    Args:
        column: First column of the Toeplitz matrix, shape ``(n,)``.
    """

    column: Float[Array, " n"]
    _size: int = eqx.field(static=True)
    _dtype: str = eqx.field(static=True)
    tags: frozenset[object] = eqx.field(static=True)

    def __init__(
        self,
        column: Float[Array, " n"],
        *,
        tags: object | frozenset[object] = frozenset(),
    ) -> None:
        self.column = jnp.asarray(column)
        self._size = self.column.shape[0]
        self._dtype = str(self.column.dtype)
        self.tags = _to_frozenset(tags) | {lx.symmetric_tag}

    def mv(self, vector: Float[Array, " n"]) -> Float[Array, " n"]:
        return _toeplitz_mv(self.column, vector)

    def as_matrix(self) -> Float[Array, "n n"]:
        n = self._size
        indices = jnp.abs(jnp.arange(n)[:, None] - jnp.arange(n)[None, :])
        return self.column[indices]

    def transpose(self) -> Toeplitz:
        # Symmetric Toeplitz is self-transpose.
        return self

    def in_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._size,), jnp.dtype(self._dtype))

    def out_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._size,), jnp.dtype(self._dtype))

ToeplitzCholesky

Bases: AbstractLinearOperator

Circulant-embedding sample factor for a symmetric positive Toeplitz matrix.

The operator has shape (n, embedding_factor * n) and satisfies L @ L.T == Toeplitz(column). It is a rectangular sample factor, not a traditional lower-triangular Cholesky factor. Applying it to standard normal white noise gives samples from 𝒩(0, Toeplitz(column)) when the Wood–Chan condition holds.

The Wood–Chan non-negativity check is implemented with eqx.error_if, so it is JIT-friendly: the error fires at run time rather than at trace time. If the embedding's spectrum has a materially negative eigenvalue, increase embedding_factor (typically 4 or 8 suffices for well-behaved covariances).
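The idea behind the sample factor can be sketched in NumPy under a few stated assumptions. The embedding is zero-padded to size embedding_factor * n, which is not necessarily the padding gaussx uses. The Wood–Chan check uses a relative tolerance. The helper names are hypothetical. The circulant square root \(S\) has eigenvalues \(\sqrt{\lambda}\) and satisfies \(SS^\top = C\), so its first n rows form an (n, m) factor L with \(LL^\top = C_{[:n,:n]} = T\):

```python
import numpy as np


def circulant_sqrt_spectrum(column, embedding_factor=2, rtol=1e-10):
    """sqrt of the real rfft spectrum of a circulant embedding of Toeplitz(column)."""
    n = column.shape[0]
    m = embedding_factor * n
    if m < 2 * n - 1:
        raise ValueError("embedding too small to contain the Toeplitz matrix")
    circ = np.concatenate([column, np.zeros(m - 2 * n + 1), column[:0:-1]])
    lam = np.fft.rfft(circ).real                # eigenvalues of the symmetric circulant
    # Wood–Chan check: the embedding must be (numerically) nonnegative definite.
    if lam.min() < -rtol * np.abs(lam).max():
        raise ValueError("negative eigenvalue in embedding; increase embedding_factor")
    return np.sqrt(np.clip(lam, 0.0, None))     # clip round-off negatives to zero


def sample_factor_mv(sqrt_lam, w, n):
    """L @ w: apply the circulant square root to w (length m), keep the first n rows."""
    m = w.shape[0]
    return np.fft.irfft(sqrt_lam * np.fft.rfft(w), n=m)[:n]


column = np.array([2.0, 1.0, 0.0, 0.0, 0.0])   # tridiagonal Toeplitz, PSD embedding
sqrt_lam = circulant_sqrt_spectrum(column, embedding_factor=2)
rng = np.random.default_rng(0)
sample = sample_factor_mv(sqrt_lam, rng.standard_normal(10), n=5)   # one draw
```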

Parameters:

Name Type Description Default
column Float[Array, ' n']

First column of the Toeplitz matrix, shape (n,).

required
embedding_factor int

Circulant embedding size as a multiple of n.

2
Source code in src/gaussx/_operators/_toeplitz.py
class ToeplitzCholesky(lx.AbstractLinearOperator):
    """Circulant-embedding sample factor for a symmetric positive Toeplitz matrix.

    The operator has shape ``(n, embedding_factor * n)`` and satisfies
    ``L @ L.T == Toeplitz(column)`` — it is a rectangular sample factor,
    *not* a traditional lower-triangular Cholesky factor. Applying it to
    standard normal white noise gives samples from
    ``𝒩(0, Toeplitz(column))`` when the Wood--Chan condition holds.

    The Wood--Chan non-negativity check is implemented with
    ``eqx.error_if`` so it is JIT-friendly: the error fires at evaluation
    time rather than tracing time. If the embedding's spectrum has a
    materially negative eigenvalue, bump ``embedding_factor`` (typically
    ``4`` or ``8`` suffices for well-behaved covariances).

    Args:
        column: First column of the Toeplitz matrix, shape ``(n,)``.
        embedding_factor: Circulant embedding size as a multiple of ``n``.
    """

    sqrt_spectrum: Float[Array, " m_rfft"]
    _size: int = eqx.field(static=True)
    _embedding_size: int = eqx.field(static=True)
    _dtype: str = eqx.field(static=True)
    tags: frozenset[object] = eqx.field(static=True)

    def __init__(
        self,
        column: Float[Array, " n"],
        *,
        embedding_factor: int = 2,
    ) -> None:
        column = _as_floating_column(column)
        self.sqrt_spectrum = _circulant_sqrt_spectrum(
            column,
            embedding_factor=embedding_factor,
        )
        self._size = column.shape[0]
        self._embedding_size = embedding_factor * self._size
        self._dtype = str(column.dtype)
        self.tags = frozenset()

    def mv(self, vector: Float[Array, " m"]) -> Float[Array, " n"]:
        return _circulant_sqrt_mv(self.sqrt_spectrum, vector, self._size)

    def as_matrix(self) -> Float[Array, "n m"]:
        eye = jnp.eye(self._embedding_size, dtype=jnp.dtype(self._dtype))
        return jax.vmap(self.mv, in_axes=1, out_axes=1)(eye)

    def transpose(self) -> _ToeplitzCholeskyTranspose:
        return _ToeplitzCholeskyTranspose(self)

    def in_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._embedding_size,), jnp.dtype(self._dtype))

    def out_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._size,), jnp.dtype(self._dtype))

Kernel operators

Kernel matrices as operators — dense (KernelOperator), matrix-free (ImplicitKernelOperator, rows generated on the fly per matvec), rectangular cross-kernels, and grid-interpolated (KISS-GP style) variants.


KernelOperator

Bases: AbstractLinearOperator

Kernel matrix operator with efficient first-order autodiff.

Represents the matrix K where K[i, j] = kernel_fn(params, X1[i], X2[j]). The matvec K @ v is computed via scan (O(N) memory), and a jax.custom_jvp keeps first-order autodiff efficient without materializing Jacobians.
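The memory argument can be illustrated with a loop that forms one kernel row at a time, so only an (M,) row is live per step. This NumPy sketch is only an analogue of the scan described above and has no custom JVP. Its kernel and helper names are hypothetical, and the kernel is vectorized over its second argument for brevity:

```python
import numpy as np


def rbf_row(params, x, X2):
    """Hypothetical RBF kernel between one point x (D,) and every row of X2 (M, D)."""
    lengthscale, variance = params
    sq_dist = np.sum((X2 - x) ** 2, axis=-1)
    return variance * np.exp(-0.5 * sq_dist / lengthscale**2)


def kernel_mv_rowwise(params, X1, X2, v):
    """K @ v without materializing K: each step builds one (M,) row and reduces it."""
    out = np.empty(X1.shape[0], dtype=np.result_type(X1, v))
    for i, x in enumerate(X1):
        out[i] = rbf_row(params, x, X2) @ v
    return out


rng = np.random.default_rng(0)
X1 = rng.standard_normal((30, 2))
X2 = rng.standard_normal((20, 2))
v = rng.standard_normal(20)
Kv = kernel_mv_rowwise((0.7, 1.5), X1, X2, v)   # shape (30,)
```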

Batched inputs are supported: X1 and X2 may carry leading batch dimensions (*batch, N, D) / (*batch, M, D) (with matching *batch). In that case mv expects a vector of shape (*batch, M) and returns (*batch, N); as_matrix() returns a (*batch, N, M) tensor; in_structure() / out_structure() report the batched shapes so lineax helpers (linear_solve, probe-vector allocators) construct compatible inputs.

Parameters:

Name Type Description Default
kernel_fn Callable

Kernel function k(params, x, x') -> scalar. The first argument is a pytree of hyperparameters.

required
X1 Float[Array, '*batch N D']

First set of data points, shape (*batch, N, D). Leading batch dimensions are optional.

required
X2 Float[Array, '*batch M D']

Second set of data points, shape (*batch, M, D) with *batch matching X1.

required
params Any

Pytree of kernel hyperparameters (differentiable).

required
tags object | frozenset[object]

Optional lineax structural tags.

frozenset()
Source code in src/gaussx/_operators/_kernel.py
class KernelOperator(lx.AbstractLinearOperator):
    r"""Kernel matrix operator with efficient first-order autodiff.

    Represents the matrix ``K`` where ``K[i, j] = kernel_fn(params, X1[i], X2[j])``.
    The matvec ``K @ v`` is computed via scan (O(N) memory), and a
    ``jax.custom_jvp`` keeps first-order autodiff efficient without
    materializing Jacobians.

    Batched inputs are supported: ``X1`` and ``X2`` may carry leading
    batch dimensions ``(*batch, N, D)`` / ``(*batch, M, D)`` (with
    matching ``*batch``). In that case ``mv`` expects a vector of shape
    ``(*batch, M)`` and returns ``(*batch, N)``; ``as_matrix()`` returns
    a ``(*batch, N, M)`` tensor; ``in_structure()`` / ``out_structure()``
    report the batched shapes so lineax helpers (``linear_solve``,
    probe-vector allocators) construct compatible inputs.

    Args:
        kernel_fn: Kernel function ``k(params, x, x') -> scalar``.  The first
            argument is a pytree of hyperparameters.
        X1: First set of data points, shape ``(*batch, N, D)``. Leading
            batch dimensions are optional.
        X2: Second set of data points, shape ``(*batch, M, D)`` with
            ``*batch`` matching ``X1``.
        params: Pytree of kernel hyperparameters (differentiable).
        tags: Optional lineax structural tags.
    """

    kernel_fn: Callable = eqx.field(static=True)
    X1: Float[Array, "*batch N D"]
    X2: Float[Array, "*batch M D"]
    params: Any  # pytree of kernel hyperparameters
    _nrows: int = eqx.field(static=True)
    _ncols: int = eqx.field(static=True)
    _batch_shape: tuple[int, ...] = eqx.field(static=True)
    tags: frozenset[object] = eqx.field(static=True)
    _kernel_mv: Callable = eqx.field(static=True)

    def __init__(
        self,
        kernel_fn: Callable,
        X1: Float[Array, "*batch N D"],
        X2: Float[Array, "*batch M D"],
        params: Any,
        *,
        tags: object | frozenset[object] = frozenset(),
    ) -> None:
        if X1.shape[:-2] != X2.shape[:-2]:
            raise ValueError("X1 and X2 must have matching batch shapes.")
        self.kernel_fn = kernel_fn
        self.X1 = X1
        self.X2 = X2
        self.params = params
        self._nrows = X1.shape[-2]
        self._ncols = X2.shape[-2]
        self._batch_shape = X1.shape[:-2]
        normalized_tags = _to_frozenset(tags)
        if lx.positive_semidefinite_tag in normalized_tags:
            if self._nrows != self._ncols:
                raise ValueError(
                    "positive_semidefinite_tag is only valid for square operators."
                )
            normalized_tags = normalized_tags | {lx.symmetric_tag}
        self.tags = normalized_tags
        self._kernel_mv = _make_kernel_mv(kernel_fn)

    def mv(self, vector: Float[Array, "*batch M"]) -> Float[Array, "*batch N"]:
        """Compute ``K @ v`` via scan with custom JVP support."""
        if vector.shape[:-1] != self._batch_shape:
            raise ValueError(
                "vector must have leading batch dimensions matching X1/X2."
            )
        if not self._batch_shape:
            return self._kernel_mv(self.params, self.X1, self.X2, vector)
        batched_mv = vmap_over_batch_dims(
            lambda x1, x2, v: self._kernel_mv(self.params, x1, x2, v),
            len(self._batch_shape),
        )
        return batched_mv(self.X1, self.X2, vector)

    def as_matrix(self) -> Float[Array, "*batch N M"]:
        """Materialize the full kernel matrix."""
        matrix_fn = lambda X1, X2: jax.vmap(
            lambda x_i: jax.vmap(lambda x_j: self.kernel_fn(self.params, x_i, x_j))(X2)
        )(X1)
        if not self._batch_shape:
            return matrix_fn(self.X1, self.X2)
        batched_matrix_fn = vmap_over_batch_dims(matrix_fn, len(self._batch_shape))
        return batched_matrix_fn(self.X1, self.X2)

    def transpose(self) -> KernelOperator:
        """Return the transpose operator (X1, X2 swapped, kernel transposed)."""
        if lx.symmetric_tag in self.tags:
            return self
        return KernelOperator(
            lambda p, x_i, x_j: self.kernel_fn(p, x_j, x_i),
            self.X2,
            self.X1,
            self.params,
            tags=lx.transpose_tags(self.tags),
        )

    def in_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((*self._batch_shape, self._ncols), self.X1.dtype)

    def out_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((*self._batch_shape, self._nrows), self.X1.dtype)

mv(vector: Float[Array, '*batch M']) -> Float[Array, '*batch N']

Compute K @ v via scan with custom JVP support.
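
The scan-based matvec can be sketched in plain JAX. This is a minimal sketch, not gaussx's internal `_make_kernel_mv`. The RBF kernel `rbf` and the helpers `scan_kernel_mv` and `dense_kernel` are hypothetical, and the custom JVP is omitted. Each scan step evaluates one row of K, so only one length-M row is live at a time:

```python
import jax
import jax.numpy as jnp


def rbf(params, x, xp):
    # Hypothetical squared-exponential kernel; params is a dict pytree.
    sq = jnp.sum((x - xp) ** 2)
    return params["variance"] * jnp.exp(-0.5 * sq / params["lengthscale"] ** 2)


def scan_kernel_mv(kernel_fn, params, X1, X2, v):
    """Compute K(X1, X2) @ v one row at a time with lax.scan."""

    def step(carry, x_i):
        # Row i of K: k(x_i, x_j) for every x_j in X2, shape (M,).
        k_row = jax.vmap(lambda x_j: kernel_fn(params, x_i, x_j))(X2)
        return carry, jnp.dot(k_row, v)

    _, out = jax.lax.scan(step, None, X1)
    return out  # shape (N,)


def dense_kernel(kernel_fn, params, X1, X2):
    # Nested vmap materializes the (N, M) matrix, as as_matrix() does.
    return jax.vmap(
        lambda x_i: jax.vmap(lambda x_j: kernel_fn(params, x_i, x_j))(X2)
    )(X1)


params = {"variance": jnp.asarray(1.5), "lengthscale": jnp.asarray(0.7)}
X1 = jax.random.normal(jax.random.PRNGKey(0), (5, 2))
X2 = jax.random.normal(jax.random.PRNGKey(1), (3, 2))
v = jnp.arange(3.0)

y_scan = scan_kernel_mv(rbf, params, X1, X2, v)
y_dense = dense_kernel(rbf, params, X1, X2) @ v
```

Both paths should agree to floating-point tolerance; the scan version trades some parallelism for memory that does not grow with N.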

as_matrix() -> Float[Array, '*batch N M']

Materialize the full kernel matrix.

transpose() -> KernelOperator

Return the transpose operator (X1, X2 swapped, kernel transposed).
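
Swapping X1/X2 together with the kernel's two point arguments gives exactly the transpose, even when the kernel itself is not symmetric. A minimal check, assuming a hypothetical asymmetric function `skew_kernel` and a dense helper `gram` (neither is part of gaussx):

```python
import jax
import jax.numpy as jnp


def gram(kernel_fn, params, X1, X2):
    # Dense K[i, j] = kernel_fn(params, X1[i], X2[j]).
    return jax.vmap(lambda a: jax.vmap(lambda b: kernel_fn(params, a, b))(X2))(X1)


def skew_kernel(params, x, xp):
    # Hypothetical asymmetric "kernel": k(x, x') != k(x', x) in general.
    return jnp.dot(params["w"] * x, xp) + jnp.sum(x)


params = {"w": jnp.array([1.0, 2.0])}
X1 = jnp.array([[0.0, 1.0], [1.0, 1.0], [2.0, -1.0]])
X2 = jnp.array([[1.0, 0.0], [0.5, 0.5]])

K = gram(skew_kernel, params, X1, X2)  # shape (3, 2)
# Build the transpose the way KernelOperator.transpose does: swap the
# point sets and the kernel's two point arguments.
K_T = gram(lambda p, a, b: skew_kernel(p, b, a), params, X2, X1)  # shape (2, 3)
```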

ImplicitKernelOperator

Bases: AbstractLinearOperator

Matrix-free kernel operator: \((K + \sigma^2 I)\,v\) via sequential scan.

Computes the kernel matvec without materializing the \(N \times N\) kernel matrix, using \(O(N)\) memory instead of \(O(N^2)\). Each element of the output is computed as:

\[ y_i = \sum_j k(x_i, x_j)\, v_j + \sigma^2 v_i \]

The scan-based implementation is compatible with CG / BBMM solvers that only need matvec access.
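
Because only matvec access is required, a solve against \(K + \sigma^2 I\) can go through any Krylov solver that accepts a callable. A minimal sketch using JAX's generic `jax.scipy.sparse.linalg.cg`, which accepts a function as its operator; it stands in for, and does not reproduce, gaussx's own CG/BBMM solvers. The kernel `k`, the helper `implicit_mv`, the data, and the noise level are illustrative assumptions:

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def k(x, xp):
    # Hypothetical RBF kernel with its lengthscale (0.5) closed over.
    return jnp.exp(-0.5 * jnp.sum((x - xp) ** 2) / 0.5**2)


def implicit_mv(X, v, noise_var):
    # (K + noise_var * I) @ v, one row of K per scan step: O(N) memory.
    def step(carry, x_i):
        k_row = jax.vmap(lambda x_j: k(x_i, x_j))(X)
        return carry, jnp.dot(k_row, v)

    _, Kv = jax.lax.scan(step, None, X)
    return Kv + noise_var * v


X = jnp.linspace(0.0, 3.0, 20)[:, None]
y = jnp.sin(2.0 * X[:, 0])
noise_var = 0.1

# cg only ever calls the operator as a function of a vector.
alpha, _ = cg(lambda v: implicit_mv(X, v, noise_var), y, tol=1e-5, maxiter=200)
```

The positive noise term keeps the system symmetric positive definite, which CG requires.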

Supports two kernel signatures:

  • No params (default): k(x, x') -> scalar. Hyperparameters are captured by closure (for example, in a lambda).
  • With params: k(params, x, x') -> scalar. Pass a pytree of differentiable hyperparameters via the params argument and a jax.custom_jvp keeps first-order autodiff efficient.
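
The two calling conventions differ only in where hyperparameters live. A minimal sketch with a hypothetical RBF kernel: in the closure form the lengthscale is baked into the function, while in the params form `jax.grad` can differentiate a scan-based quadratic form with respect to the hyperparameter pytree. Here gradients flow through `lax.scan`'s default autodiff, not through the custom JVP the operator uses:

```python
import jax
import jax.numpy as jnp

lengthscale = 0.8


def k_closed(x, xp):
    # No-params signature: the lengthscale is captured by closure.
    return jnp.exp(-0.5 * jnp.sum((x - xp) ** 2) / lengthscale**2)


def k_params(params, x, xp):
    # Params signature: hyperparameters arrive as a differentiable pytree.
    return jnp.exp(-0.5 * jnp.sum((x - xp) ** 2) / params["lengthscale"] ** 2)


def quad_form(params, X, v):
    # v^T K(params) v, with K applied row by row via lax.scan.
    def step(carry, x_i):
        k_row = jax.vmap(lambda x_j: k_params(params, x_i, x_j))(X)
        return carry, jnp.dot(k_row, v)

    _, Kv = jax.lax.scan(step, None, X)
    return jnp.dot(v, Kv)


X = jnp.linspace(-1.0, 1.0, 6)[:, None]
v = jnp.ones(6)
params = {"lengthscale": jnp.asarray(lengthscale)}

grads = jax.grad(quad_form)(params, X, v)  # dict with a "lengthscale" entry
```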

Batched inputs are supported: X may carry leading batch dimensions (*batch, N, D). In that case mv expects a vector of shape (*batch, N) and returns (*batch, N); as_matrix() returns a (*batch, N, N) tensor; in_structure() / out_structure() report the batched shapes.
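
Leading batch dimensions can be handled by wrapping the unbatched matvec in one `jax.vmap` per batch axis. The source calls a helper `vmap_over_batch_dims` whose definition is not shown here; `vmap_leading` below is a hypothetical stand-in with the same role, not the library's code, and `mv_single` is dense for brevity:

```python
import jax
import jax.numpy as jnp


def vmap_leading(fn, n_batch_dims):
    # Map fn over the first n_batch_dims axes of every argument.
    for _ in range(n_batch_dims):
        fn = jax.vmap(fn)
    return fn


def k(x, xp):
    # Hypothetical RBF kernel.
    return jnp.exp(-0.5 * jnp.sum((x - xp) ** 2))


def mv_single(X, v):
    # Unbatched (N, D), (N,) -> (N,).
    K = jax.vmap(lambda a: jax.vmap(lambda b: k(a, b))(X))(X)
    return K @ v


X = jax.random.normal(jax.random.PRNGKey(0), (2, 3, 4, 2))  # (*batch=(2, 3), N=4, D=2)
v = jnp.ones((2, 3, 4))  # (*batch, N)

out = vmap_leading(mv_single, X.ndim - 2)(X, v)  # shape (2, 3, 4)
```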

Parameters:

  • kernel_fn (Callable, required): Kernel function (see above for signature).
  • X (Float[Array, '*batch N D'], required): Training points, shape (*batch, N, D). Leading batch dimensions are optional.
  • noise_var (float, default 0.0): Diagonal noise variance \(\sigma^2\).
  • params (Any | None, default None): Optional pytree of kernel hyperparameters.
  • tags (object | frozenset[object], default frozenset()): Optional lineax structural tags.
Source code in src/gaussx/_operators/_implicit_kernel.py
class ImplicitKernelOperator(lx.AbstractLinearOperator):
    r"""Matrix-free kernel operator: ``(K + sigma^2 I) v`` via sequential scan.

    Computes the kernel matvec without materializing the ``N x N`` kernel
    matrix, using ``O(N)`` memory instead of ``O(N^2)``.  Each element of
    the output is computed as:

        y_i = \sum_j k(x_i, x_j) v_j + sigma^2 v_i

    The scan-based implementation is compatible with CG / BBMM solvers
    that only need matvec access.

    Supports two kernel signatures:

    - **No params** (default): ``k(x, x') -> scalar``.  Hyperparameters
      are closed over in the lambda.
    - **With params**: ``k(params, x, x') -> scalar``.  Pass a pytree of
      differentiable hyperparameters via the ``params`` argument and a
      ``jax.custom_jvp`` keeps first-order autodiff efficient.

    Batched inputs are supported: ``X`` may carry leading batch
    dimensions ``(*batch, N, D)``. In that case ``mv`` expects a vector
    of shape ``(*batch, N)`` and returns ``(*batch, N)``;
    ``as_matrix()`` returns a ``(*batch, N, N)`` tensor;
    ``in_structure()`` / ``out_structure()`` report the batched shapes.

    Args:
        kernel_fn: Kernel function (see above for signature).
        X: Training points, shape ``(*batch, N, D)``. Leading batch
            dimensions are optional.
        noise_var: Diagonal noise variance ``sigma^2``.
        params: Optional pytree of kernel hyperparameters.
    """

    kernel_fn: Callable = eqx.field(static=True)
    X: Float[Array, "*batch N D"]
    noise_var: float = eqx.field(static=True)
    params: Any
    _size: int = eqx.field(static=True)
    _batch_shape: tuple[int, ...] = eqx.field(static=True)
    _has_params: bool = eqx.field(static=True)
    tags: frozenset[object] = eqx.field(static=True)
    _kernel_mv: Callable | None = eqx.field(static=True)

    def __init__(
        self,
        kernel_fn: Callable,
        X: Float[Array, "*batch N D"],
        noise_var: float = 0.0,
        *,
        params: Any | None = None,
        tags: object | frozenset[object] = frozenset(),
    ) -> None:
        self.kernel_fn = kernel_fn
        self.X = X
        self.noise_var = noise_var
        self._size = X.shape[-2]
        self._batch_shape = X.shape[:-2]
        self._has_params = params is not None
        self.params = params
        normalized_tags = _to_frozenset(tags)
        if lx.positive_semidefinite_tag in normalized_tags:
            normalized_tags = normalized_tags | {lx.symmetric_tag}
        self.tags = normalized_tags
        if self._has_params:
            self._kernel_mv = _make_implicit_kernel_mv(kernel_fn)
        else:
            self._kernel_mv = None

    def mv(self, vector: Float[Array, "*batch N"]) -> Float[Array, "*batch N"]:
        """Compute ``(K + sigma^2 I) @ v`` via scan over data points."""
        if vector.shape[:-1] != self._batch_shape:
            raise ValueError("vector must have leading batch dimensions matching X.")

        def mv_single(
            X: Float[Array, "N D"], v: Float[Array, " N"]
        ) -> Float[Array, " N"]:
            if self._has_params:
                assert self._kernel_mv is not None
                return self._kernel_mv(self.params, X, v)

            def row_dot(x_i: Float[Array, " D"]) -> Float[Array, ""]:
                k_row = jax.vmap(lambda x_j: self.kernel_fn(x_i, x_j))(X)
                return jnp.dot(k_row, v)

            def body_fn(
                carry: None, x_i: Float[Array, " D"]
            ) -> tuple[None, Float[Array, ""]]:
                return carry, row_dot(x_i)

            _, Kv = jax.lax.scan(body_fn, None, X)
            return Kv

        if not self._batch_shape:
            Kv = mv_single(self.X, vector)
        else:
            batched_mv = vmap_over_batch_dims(mv_single, len(self._batch_shape))
            Kv = batched_mv(self.X, vector)

        if self.noise_var != 0.0:
            Kv = Kv + self.noise_var * vector
        return Kv

    def as_matrix(self) -> Float[Array, "*batch N N"]:
        """Materialize ``K + sigma^2 I`` densely (for debugging/testing)."""
        if self._has_params:
            matrix_fn = lambda X: jax.vmap(
                lambda x_i: jax.vmap(lambda x_j: self.kernel_fn(self.params, x_i, x_j))(
                    X
                )
            )(X)
        else:
            matrix_fn = lambda X: jax.vmap(
                lambda x_i: jax.vmap(lambda x_j: self.kernel_fn(x_i, x_j))(X)
            )(X)
        if not self._batch_shape:
            K = matrix_fn(self.X)
        else:
            batched_matrix_fn = vmap_over_batch_dims(matrix_fn, len(self._batch_shape))
            K = batched_matrix_fn(self.X)
        if self.noise_var != 0.0:
            K = K + self.noise_var * jnp.eye(self._size, dtype=K.dtype)
        return K

    def transpose(self) -> ImplicitKernelOperator:
        """Return the transpose operator (kernel arguments swapped)."""
        if lx.symmetric_tag in self.tags:
            return self
        if self._has_params:
            return ImplicitKernelOperator(
                lambda p, x_i, x_j: self.kernel_fn(p, x_j, x_i),
                self.X,
                noise_var=self.noise_var,
                params=self.params,
                tags=lx.transpose_tags(self.tags),
            )
        return ImplicitKernelOperator(
            lambda x_i, x_j: self.kernel_fn(x_j, x_i),
            self.X,
            noise_var=self.noise_var,
            tags=lx.transpose_tags(self.tags),
        )

    def in_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((*self._batch_shape, self._size), self.X.dtype)

    def out_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((*self._batch_shape, self._size), self.X.dtype)

mv(vector: Float[Array, '*batch N']) -> Float[Array, '*batch N']

Compute (K + sigma^2 I) @ v via scan over data points.

as_matrix() -> Float[Array, '*batch N N']

Materialize the dense matrix \(K + \sigma^2 I\) (for debugging/testing).
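
When testing a matrix-free operator, a useful check is that the dense matrix agrees with the matvec applied to every standard basis vector. A minimal sketch, assuming a hypothetical Laplacian kernel `k` and hypothetical helpers `implicit_mv` (scan-based) and `dense` (materialized):

```python
import jax
import jax.numpy as jnp


def k(x, xp):
    # Hypothetical Laplacian (exponential) kernel.
    return jnp.exp(-jnp.sum(jnp.abs(x - xp)))


def implicit_mv(X, v, noise_var):
    # (K + noise_var * I) @ v via lax.scan over rows.
    def step(carry, x_i):
        return carry, jnp.dot(jax.vmap(lambda x_j: k(x_i, x_j))(X), v)

    _, Kv = jax.lax.scan(step, None, X)
    return Kv + noise_var * v


def dense(X, noise_var):
    # Materialize K + noise_var * I, keeping the identity in K's dtype.
    K = jax.vmap(lambda a: jax.vmap(lambda b: k(a, b))(X))(X)
    return K + noise_var * jnp.eye(X.shape[0], dtype=K.dtype)


X = jax.random.uniform(jax.random.PRNGKey(2), (7, 3))
noise_var = 0.05
eye = jnp.eye(7)
# Column j of the operator is mv(e_j): vmap over the identity's columns.
from_mv = jax.vmap(lambda e: implicit_mv(X, e, noise_var), in_axes=1, out_axes=1)(eye)
```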

ImplicitCrossKernelOperator

Bases: AbstractLinearOperator

Matrix-free rectangular kernel operator \(K(X, Z)\,v\).

Computes the cross-kernel matvec without materializing the full \(N \times M\) kernel matrix, using a batched scan that keeps peak memory at \(O(\text{batch\_size} \times M)\) per step.

Forward matvec (mv):

\[ y_i = \sum_j k(x_i, z_j)\, v_j \]

maps an M-vector to an N-vector.
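
The batched forward matvec pads the rows of X_data to a multiple of batch_size, scans over the resulting blocks, and truncates the padding. A minimal sketch in plain JAX, using `reshape` in place of the source's einops `rearrange`; the kernel `k` and the function name `cross_mv` are hypothetical:

```python
import jax
import jax.numpy as jnp


def k(x, z):
    # Hypothetical RBF kernel.
    return jnp.exp(-0.5 * jnp.sum((x - z) ** 2))


def cross_mv(X, Z, v, batch_size):
    """K(X, Z) @ v with one (batch_size, M) block of K live per scan step."""
    n, d = X.shape
    n_padded = -(-n // batch_size) * batch_size  # round n up to a multiple
    X_pad = jnp.pad(X, ((0, n_padded - n), (0, 0)))
    blocks = X_pad.reshape(n_padded // batch_size, batch_size, d)

    def step(carry, X_blk):
        K_blk = jax.vmap(lambda x: jax.vmap(lambda z: k(x, z))(Z))(X_blk)
        return carry, K_blk @ v

    _, out = jax.lax.scan(step, None, blocks)
    # Rows produced by the zero padding are meaningless; drop them.
    return out.reshape(-1)[:n]


X = jax.random.normal(jax.random.PRNGKey(0), (10, 2))
Z = jax.random.normal(jax.random.PRNGKey(1), (4, 2))
v = jnp.array([1.0, -1.0, 0.5, 2.0])

y = cross_mv(X, Z, v, batch_size=3)  # 10 rows -> 4 blocks of 3
```

Padding keeps every scan step the same shape, which `lax.scan` requires.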

The adjoint (transpose) computes \(K^\top u\), mapping an N-vector to an M-vector; for a symmetric kernel this equals \(K(Z, X)\,u\).
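
The adjoint does not need \(K(Z, X)\) at all. Scanning over the same row blocks of X_data and summing \(K_{\text{blk}}^\top u_{\text{blk}}\) into an (M,) carry gives \(K^\top u\) exactly, for any kernel. A minimal sketch under the same assumptions as a padded forward scan (hypothetical kernel `k` and function name `cross_rmv`); padding u with zeros makes the padded rows contribute nothing:

```python
import jax
import jax.numpy as jnp


def k(x, z):
    # Hypothetical asymmetric kernel, to show the adjoint is exact anyway.
    return jnp.exp(-0.5 * jnp.sum((x - z) ** 2)) + 0.1 * jnp.sum(x)


def cross_rmv(X, Z, u, batch_size):
    """K(X, Z)^T @ u, accumulating one block contribution per scan step."""
    n, d = X.shape
    n_padded = -(-n // batch_size) * batch_size
    pad = n_padded - n
    X_blocks = jnp.pad(X, ((0, pad), (0, 0))).reshape(-1, batch_size, d)
    u_blocks = jnp.pad(u, (0, pad)).reshape(-1, batch_size)  # zero weights on padding

    def step(acc, blk):
        X_blk, u_blk = blk
        K_blk = jax.vmap(lambda x: jax.vmap(lambda z: k(x, z))(Z))(X_blk)
        return acc + K_blk.T @ u_blk, None

    acc0 = jnp.zeros(Z.shape[0], dtype=u.dtype)
    result, _ = jax.lax.scan(step, acc0, (X_blocks, u_blocks))
    return result  # shape (M,)


X = jax.random.normal(jax.random.PRNGKey(3), (7, 2))
Z = jax.random.normal(jax.random.PRNGKey(4), (3, 2))
u = jnp.arange(7.0)

w = cross_rmv(X, Z, u, batch_size=4)
```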

Supports two kernel signatures:

  • No params (default): k(x, z) -> scalar.
  • With params: k(params, x, z) -> scalar. Pass a pytree of differentiable hyperparameters and a jax.custom_jvp keeps first-order autodiff efficient.

Batched inputs are supported: X_data and X_inducing may carry leading batch dimensions (*batch, N, D) / (*batch, M, D) (with matching *batch). In that case mv expects a vector of shape (*batch, M) and returns (*batch, N); the transposed operator mirrors this, taking (*batch, N) and returning (*batch, M). as_matrix() returns a (*batch, N, M) tensor; in_structure() / out_structure() report the batched shapes.

Parameters:

  • kernel_fn (Callable, required): Kernel function (see above for signature).
  • X_data (Float[Array, '*batch N D'], required): Data points, shape (*batch, N, D). Leading batch dimensions are optional.
  • X_inducing (Float[Array, '*batch M D'], required): Inducing points, shape (*batch, M, D) with *batch matching X_data.
  • batch_size (int, default 1024): Number of rows of X_data processed per scan step.
  • params (Any | None, default None): Optional pytree of kernel hyperparameters.
  • tags (object | frozenset[object], default frozenset()): Optional lineax structural tags.
Source code in src/gaussx/_operators/_implicit_cross_kernel.py
class ImplicitCrossKernelOperator(lx.AbstractLinearOperator):
    r"""Matrix-free rectangular kernel operator ``K(X, Z) \cdot v``.

    Computes the cross-kernel matvec without materializing the full
    ``N x M`` kernel matrix, using a batched scan that keeps peak memory
    at ``O(batch\_size \times M)`` per step.

    Forward matvec (``mv``):

        y_i = \sum_j k(x_i, z_j) \cdot v_j

    maps an ``M``-vector to an ``N``-vector.

    Adjoint / transpose computes ``K^T u``, mapping an ``N``-vector to an
    ``M``-vector; for a symmetric kernel this equals ``K(Z, X) u``.

    Supports two kernel signatures:

    - **No params** (default): ``k(x, z) -> scalar``.
    - **With params**: ``k(params, x, z) -> scalar``.  Pass a pytree of
      differentiable hyperparameters and a ``jax.custom_jvp`` keeps
      first-order autodiff efficient.

    Batched inputs are supported: ``X_data`` and ``X_inducing`` may
    carry leading batch dimensions ``(*batch, N, D)`` /
    ``(*batch, M, D)`` (with matching ``*batch``). In that case ``mv``
    expects a vector of shape ``(*batch, M)`` and returns ``(*batch, N)``;
    the transposed operator mirrors this, mapping ``(*batch, N)`` to
    ``(*batch, M)``. ``as_matrix()``
    returns a ``(*batch, N, M)`` tensor; ``in_structure()`` /
    ``out_structure()`` report the batched shapes.

    Args:
        kernel_fn: Kernel function (see above for signature).
        X_data: Data points, shape ``(*batch, N, D)``. Leading batch
            dimensions are optional.
        X_inducing: Inducing points, shape ``(*batch, M, D)`` with
            ``*batch`` matching ``X_data``.
        batch_size: Number of rows of ``X_data`` processed per scan step.
        params: Optional pytree of kernel hyperparameters.
    """

    kernel_fn: Callable = eqx.field(static=True)
    X_data: Float[Array, "*batch N D"]
    X_inducing: Float[Array, "*batch M D"]
    params: Any
    batch_size: int = eqx.field(static=True)
    _n: int = eqx.field(static=True)
    _m: int = eqx.field(static=True)
    _batch_shape: tuple[int, ...] = eqx.field(static=True)
    _has_params: bool = eqx.field(static=True)
    tags: frozenset[object] = eqx.field(static=True)
    _kernel_mv: Callable | None = eqx.field(static=True)

    def __init__(
        self,
        kernel_fn: Callable,
        X_data: Float[Array, "*batch N D"],
        X_inducing: Float[Array, "*batch M D"],
        batch_size: int = 1024,
        *,
        params: Any | None = None,
        tags: object | frozenset[object] = frozenset(),
    ) -> None:
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}."
            )
        if X_data.shape[:-2] != X_inducing.shape[:-2]:
            raise ValueError("X_data and X_inducing must have matching batch shapes.")
        self.kernel_fn = kernel_fn
        self.X_data = X_data
        self.X_inducing = X_inducing
        self.params = params
        self.batch_size = batch_size
        self._n = X_data.shape[-2]
        self._m = X_inducing.shape[-2]
        self._batch_shape = X_data.shape[:-2]
        self._has_params = params is not None
        normalized_tags = _to_frozenset(tags)
        if lx.positive_semidefinite_tag in normalized_tags:
            if self._n != self._m:
                raise ValueError(
                    "positive_semidefinite_tag is only valid for square operators."
                )
            normalized_tags = normalized_tags | {lx.symmetric_tag}
        self.tags = normalized_tags
        if self._has_params:
            self._kernel_mv = _make_cross_kernel_mv(kernel_fn, batch_size)
        else:
            self._kernel_mv = None

    def mv(self, vector: Float[Array, "*batch M"]) -> Float[Array, "*batch N"]:
        """Compute ``K(X_data, X_inducing) @ v`` via batched scan.

        Peak memory per step is ``O(batch_size * M)``.
        """
        if vector.shape[:-1] != self._batch_shape:
            raise ValueError(
                "vector must have leading batch dimensions matching X_data/X_inducing."
            )

        def mv_single(
            X_data: Float[Array, "N D"],
            X_inducing: Float[Array, "M D"],
            v: Float[Array, " M"],
        ) -> Float[Array, " N"]:
            if self._has_params:
                assert self._kernel_mv is not None
                return self._kernel_mv(self.params, X_data, X_inducing, v)

            n = self._n
            bs = self.batch_size
            n_padded = ((n + bs - 1) // bs) * bs
            pad_amount = n_padded - n
            X_padded = jnp.pad(X_data, ((0, pad_amount), (0, 0)), mode="constant")
            X_batched = rearrange(X_padded, "(B bs) D -> B bs D", bs=bs)

            def batch_matvec(
                carry: None, X_batch: Float[Array, "batch_size D"]
            ) -> tuple[None, Float[Array, " batch_size"]]:
                K_batch = jax.vmap(
                    lambda x_i: jax.vmap(lambda z_j: self.kernel_fn(x_i, z_j))(
                        X_inducing
                    )
                )(X_batch)
                return carry, K_batch @ v

            _, results = jax.lax.scan(batch_matvec, None, X_batched)
            return rearrange(results, "B bs -> (B bs)")[:n]

        if not self._batch_shape:
            return mv_single(self.X_data, self.X_inducing, vector)
        batched_mv = vmap_over_batch_dims(mv_single, len(self._batch_shape))
        return batched_mv(self.X_data, self.X_inducing, vector)

    def transpose(self) -> _TransposedCrossKernelOperator:
        """Return the adjoint operator ``K^T``.

        Uses a dedicated adjoint matvec that scans over batches of
        ``X_data`` and accumulates ``K_batch^T @ u_batch`` into an
        ``(M,)`` result, keeping peak memory at ``O(batch_size * M)``.
        """
        if lx.symmetric_tag in self.tags:
            return _TransposedCrossKernelOperator(self, tags=self.tags)
        return _TransposedCrossKernelOperator(self, tags=lx.transpose_tags(self.tags))

    def as_matrix(self) -> Float[Array, "*batch N M"]:
        """Materialize the full ``N x M`` cross-kernel matrix."""
        if self._has_params:
            matrix_fn = lambda X_data, X_inducing: jax.vmap(
                lambda x_i: jax.vmap(lambda z_j: self.kernel_fn(self.params, x_i, z_j))(
                    X_inducing
                )
            )(X_data)
        else:
            matrix_fn = lambda X_data, X_inducing: jax.vmap(
                lambda x_i: jax.vmap(lambda z_j: self.kernel_fn(x_i, z_j))(X_inducing)
            )(X_data)
        if not self._batch_shape:
            return matrix_fn(self.X_data, self.X_inducing)
        batched_matrix_fn = vmap_over_batch_dims(matrix_fn, len(self._batch_shape))
        return batched_matrix_fn(self.X_data, self.X_inducing)

    def in_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((*self._batch_shape, self._m), self.X_data.dtype)

    def out_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((*self._batch_shape, self._n), self.X_data.dtype)

mv(vector: Float[Array, '*batch M']) -> Float[Array, '*batch N']

Compute K(X_data, X_inducing) @ v via batched scan.

Peak memory per step is O(batch_size * M).

transpose() -> _TransposedCrossKernelOperator

Return the adjoint operator K^T.

Uses a dedicated adjoint matvec that scans over batches of X_data and accumulates K_batch^T @ u_batch into an (M,) result, keeping peak memory at O(batch_size * M).
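
A standard way to validate any adjoint matvec, including a transposed cross-kernel operator, is the dot-product test \(\langle K v, u \rangle = \langle v, K^\top u \rangle\) for random \(v\) and \(u\). A minimal sketch: the kernel `k` and the helpers `make_mv` and `dot_product_test` are hypothetical, and the reference adjoint comes from `jax.vjp`, since the VJP of a linear map is its transpose. In practice `mv` and `rmv` would be the operator's `mv` and its transpose's `mv`:

```python
import jax
import jax.numpy as jnp


def k(x, z):
    # Hypothetical asymmetric kernel.
    return jnp.exp(-jnp.sum((x - z) ** 2)) * (1.0 + x[0])


def make_mv(X, Z):
    # Forward matvec v (M,) -> K(X, Z) v (N,), one row per scan step.
    def mv(v):
        def step(carry, x_i):
            k_row = jax.vmap(lambda z: k(x_i, z))(Z)
            return carry, jnp.dot(k_row, v)

        _, Kv = jax.lax.scan(step, None, X)
        return Kv

    return mv


def dot_product_test(mv, rmv, v, u):
    # For a correct adjoint, <mv(v), u> == <v, rmv(u)> up to rounding.
    return jnp.vdot(mv(v), u), jnp.vdot(v, rmv(u))


X = jax.random.normal(jax.random.PRNGKey(5), (6, 2))
Z = jax.random.normal(jax.random.PRNGKey(6), (4, 2))
mv = make_mv(X, Z)
v = jax.random.normal(jax.random.PRNGKey(7), (4,))
u = jax.random.normal(jax.random.PRNGKey(8), (6,))

_, vjp_fn = jax.vjp(mv, v)
rmv = lambda w: vjp_fn(w)[0]  # reference adjoint K^T w

lhs, rhs = dot_product_test(mv, rmv, v, u)
```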

as_matrix() -> Float[Array, '*batch N M']

Materialize the full N x M cross-kernel matrix.

InterpolatedOperator

Bases: AbstractLinearOperator

Structured Kernel Interpolation: \(K \approx W K_{uu} W^\top\).

\(W\) is a sparse interpolation matrix with \(p\) nonzeros per row (e.g. cubic interpolation weights). The base operator \(K_{uu}\) acts on the inducing grid (typically Toeplitz for stationary kernels on a regularly spaced grid).
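
The matvec never forms \(W\) densely: \(W^\top v\) is a scatter-add of the interpolation weights into the inducing grid, and applying \(W\) is a gather followed by a row sum. A minimal sketch, assuming hypothetical linear-interpolation weights (\(p = 2\)) on a uniform 1-D grid and a dense array standing in for \(K_{uu}\); `linear_interp_weights` and `ski_mv` are illustrative names:

```python
import jax.numpy as jnp


def linear_interp_weights(x, grid):
    # Hypothetical p = 2 weights: each x sits between two uniform grid points.
    h = grid[1] - grid[0]
    left = jnp.clip(jnp.floor((x - grid[0]) / h).astype(jnp.int32), 0, grid.size - 2)
    t = (x - grid[left]) / h
    indices = jnp.stack([left, left + 1], axis=-1)  # (n, 2)
    values = jnp.stack([1.0 - t, t], axis=-1)       # (n, 2)
    return indices, values


def ski_mv(K_uu, indices, values, v):
    m = K_uu.shape[0]
    # W^T v: scatter-add each weighted entry into the inducing grid.
    wt_v = jnp.zeros(m, dtype=v.dtype).at[indices].add(values * v[:, None])
    # K_uu (W^T v)
    k_wt_v = K_uu @ wt_v
    # W (K_uu W^T v): gather p grid values per point and weight them.
    return jnp.sum(values * k_wt_v[indices], axis=-1)


grid = jnp.linspace(0.0, 1.0, 11)
x = jnp.array([0.05, 0.31, 0.5, 0.99])
K_uu = jnp.exp(-0.5 * (grid[:, None] - grid[None, :]) ** 2 / 0.2**2)
indices, values = linear_interp_weights(x, grid)
v = jnp.array([1.0, 2.0, -1.0, 0.5])

y = ski_mv(K_uu, indices, values, v)
```

`.at[...].add` accumulates repeated indices, which matters when several points share a grid neighbour.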

Total matvec cost: \(O(np + m \log m)\) when the base is Toeplitz, i.e. essentially linear in \(n\).
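
The \(m \log m\) term comes from applying a Toeplitz \(K_{uu}\) through circulant embedding and the FFT. A minimal sketch of that standard technique, assuming \(K_{uu}\) is symmetric Toeplitz, described by its first column `c` with \(m \ge 2\); `toeplitz_mv` is a hypothetical name, not gaussx's Toeplitz operator:

```python
import jax.numpy as jnp


def toeplitz_mv(c, v):
    """Multiply the symmetric Toeplitz matrix with first column c by v."""
    m = c.shape[0]
    # Circulant of size 2m - 2 whose first column is [c, c[m-2], ..., c[1]];
    # its top-left m x m block is the Toeplitz matrix.
    circ = jnp.concatenate([c, c[-2:0:-1]])
    v_pad = jnp.concatenate([v, jnp.zeros(m - 2, dtype=v.dtype)])
    # Circulant matvec = circular convolution, computed with real FFTs.
    out = jnp.fft.irfft(jnp.fft.rfft(circ) * jnp.fft.rfft(v_pad), n=2 * m - 2)
    return out[:m]


grid = jnp.linspace(0.0, 1.0, 64)
c = jnp.exp(-0.5 * (grid - grid[0]) ** 2 / 0.1**2)  # stationary RBF on a uniform grid
v = jnp.sin(6.0 * grid)

y = toeplitz_mv(c, v)
```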

Parameters:

  • base_operator (AbstractLinearOperator, required): The inducing-point kernel \(K_{uu}\), shape (m, m).
  • interp_indices (Int[Array, 'n p'], required): Integer indices into the inducing grid, shape (n, p), where p is the number of nonzeros per row of W.
  • interp_values (Float[Array, 'n p'], required): Interpolation weights, shape (n, p).
  • tags (object | frozenset[object], default frozenset()): Optional lineax structural tags.
Source code in src/gaussx/_operators/_interpolated.py
class InterpolatedOperator(lx.AbstractLinearOperator):
    r"""Structured Kernel Interpolation: ``K \approx W K_{uu} W^T``.

    ``W`` is a sparse interpolation matrix with ``p`` nonzeros per row
    (e.g. cubic interpolation weights).  The base operator ``K_{uu}``
    acts on the inducing grid (typically Toeplitz for stationary
    kernels).

    Total matvec cost: ``O(n p + m log m)`` when the base is Toeplitz,
    essentially linear in ``n``.

    Args:
        base_operator: The inducing-point kernel ``K_{uu}``, shape ``(m, m)``.
        interp_indices: Integer indices into the inducing grid,
            shape ``(n, p)`` where ``p`` is the interpolation order.
        interp_values: Interpolation weights, shape ``(n, p)``.
    """

    base_operator: lx.AbstractLinearOperator
    interp_indices: Int[Array, "n p"]
    interp_values: Float[Array, "n p"]
    _in_size: int = eqx.field(static=True)
    _out_size: int = eqx.field(static=True)
    _m: int = eqx.field(static=True)
    _dtype: str = eqx.field(static=True)
    tags: frozenset[object] = eqx.field(static=True)

    def __init__(
        self,
        base_operator: lx.AbstractLinearOperator,
        interp_indices: Int[Array, "n p"],
        interp_values: Float[Array, "n p"],
        *,
        tags: object | frozenset[object] = frozenset(),
    ) -> None:
        if interp_indices.shape != interp_values.shape:
            raise ValueError(
                f"interp_indices and interp_values must have the same shape, "
                f"got {interp_indices.shape} and {interp_values.shape}."
            )
        m = base_operator.in_size()
        out_m = base_operator.out_size()
        if m != out_m:
            raise ValueError(f"base_operator must be square, got shape ({out_m}, {m}).")
        n = interp_indices.shape[0]
        self.base_operator = base_operator
        self.interp_indices = jnp.asarray(interp_indices)
        self.interp_values = jnp.asarray(interp_values)
        self._in_size = n
        self._out_size = n
        self._m = m
        self._dtype = str(
            jnp.result_type(
                jnp.dtype(_resolve_dtype(base_operator)),
                self.interp_values.dtype,
            )
        )
        self.tags = _to_frozenset(tags)

    def mv(self, vector: Float[Array, " n"]) -> Float[Array, " n"]:
        # W^T v: scatter into inducing space
        wt_v = jnp.zeros(self._m, dtype=vector.dtype)
        wt_v = wt_v.at[self.interp_indices].add(self.interp_values * vector[:, None])
        # K_uu (W^T v)
        k_wt_v = self.base_operator.mv(wt_v)
        # W (K_uu W^T v): gather back
        return jnp.sum(self.interp_values * k_wt_v[self.interp_indices], axis=-1)

    def as_matrix(self) -> Float[Array, "n n"]:
        W = self._build_W()
        K_uu = self.base_operator.as_matrix()
        return W @ K_uu @ W.T

    def transpose(self) -> InterpolatedOperator:
        # (W K_uu W^T)^T = W K_uu^T W^T, so only the base operator is
        # transposed; the result equals self whenever K_uu is symmetric.
        return InterpolatedOperator(
            self.base_operator.T,
            self.interp_indices,
            self.interp_values,
            tags=lx.transpose_tags(self.tags),
        )

    def in_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._in_size,), jnp.dtype(self._dtype))

    def out_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._out_size,), jnp.dtype(self._dtype))

    def _build_W(self) -> Float[Array, "n m"]:
        """Build the dense interpolation matrix W (n x m)."""
        n = self._in_size
        W = jnp.zeros((n, self._m), dtype=jnp.dtype(self._dtype))
        rows = jnp.arange(n)[:, None]  # (n, 1) broadcast with (n, p)
        W = W.at[rows, self.interp_indices].add(self.interp_values)
        return W

MaskedOperator

Bases: AbstractLinearOperator

Row/column-masked view of a base operator.

Given a base operator A of shape (N, N) and boolean masks, produces the sub-matrix A[row_mask][:, col_mask].

Matvec is computed without materializing the sub-matrix: zero-pad input to full size, apply base matvec, then extract masked rows.
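
The scatter, apply, gather pattern can be checked against direct indexing. A minimal sketch with a dense array standing in for the base operator (`masked_mv` is a hypothetical name). The masks are concrete NumPy arrays because the sub-matrix sizes have to be static Python ints:

```python
import numpy as np
import jax.numpy as jnp


def masked_mv(A, row_mask, col_mask, v):
    n = A.shape[0]
    col_idx = np.flatnonzero(col_mask)  # static indices of kept columns
    row_idx = np.flatnonzero(row_mask)  # static indices of kept rows
    # Scatter v into a zero vector of length N at the selected columns.
    full_v = jnp.zeros(n, dtype=v.dtype).at[col_idx].set(v)
    # Apply the base operator, then gather the selected rows.
    return (A @ full_v)[row_idx]


A = jnp.arange(25.0).reshape(5, 5)
row_mask = np.array([True, False, True, True, False])
col_mask = np.array([False, True, True, False, True])
v = jnp.array([1.0, -2.0, 0.5])

y = masked_mv(A, row_mask, col_mask, v)
```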

Parameters:

  • base (AbstractLinearOperator, required): The underlying (N, N) linear operator.
  • row_mask (Bool[Array, ' N'], required): Boolean mask of length N selecting output rows.
  • col_mask (Bool[Array, ' N'], required): Boolean mask of length N selecting input columns.
  • tags (object | frozenset[object], default frozenset()): Optional lineax structural tags.
Source code in src/gaussx/_operators/_masked.py
class MaskedOperator(lx.AbstractLinearOperator):
    """Row/column-masked view of a base operator.

    Given a base operator A of shape ``(N, N)`` and boolean masks,
    produces the sub-matrix ``A[row_mask][:, col_mask]``.

    Matvec is computed without materializing the sub-matrix:
    zero-pad input to full size, apply base matvec, then extract
    masked rows.

    Args:
        base: The underlying ``(N, N)`` linear operator.
        row_mask: Boolean mask of length N selecting output rows.
        col_mask: Boolean mask of length N selecting input columns.
        tags: Optional lineax structural tags.
    """

    base: lx.AbstractLinearOperator
    row_mask: Bool[Array, " N"]
    col_mask: Bool[Array, " N"]
    _in_size: int = eqx.field(static=True)
    _out_size: int = eqx.field(static=True)
    _dtype: str = eqx.field(static=True)
    tags: frozenset[object] = eqx.field(static=True)

    def __init__(
        self,
        base: lx.AbstractLinearOperator,
        row_mask: Bool[Array, " N"],
        col_mask: Bool[Array, " N"],
        *,
        tags: object | frozenset[object] = frozenset(),
    ) -> None:
        if base.in_size() != base.out_size():
            raise ValueError(
                f"Base operator must be square, got in_size={base.in_size()}, "
                f"out_size={base.out_size()}."
            )
        n = base.in_size()
        if row_mask.shape != (n,) or col_mask.shape != (n,):
            raise ValueError(
                f"Masks must have shape ({n},), got row_mask={row_mask.shape}, "
                f"col_mask={col_mask.shape}."
            )
        self.base = base
        self.row_mask = jnp.asarray(row_mask, dtype=bool)
        self.col_mask = jnp.asarray(col_mask, dtype=bool)
        self._in_size = int(jnp.sum(col_mask))
        self._out_size = int(jnp.sum(row_mask))
        struct = base.out_structure()
        leaves = jax.tree.leaves(struct)
        self._dtype = str(leaves[0].dtype)
        self.tags = _to_frozenset(tags)

    def mv(self, vector: Float[Array, " m"]) -> Float[Array, " k"]:
        # Scatter input into full-size vector at col_mask positions
        n = self.base.in_size()
        col_indices = jnp.where(self.col_mask, size=self._in_size)[0]
        full_v = jnp.zeros(n, dtype=vector.dtype).at[col_indices].set(vector)
        # Apply base operator
        full_out = self.base.mv(full_v)
        # Gather output at row_mask positions
        row_indices = jnp.where(self.row_mask, size=self._out_size)[0]
        return full_out[row_indices]

    def as_matrix(self) -> Float[Array, "k m"]:
        full = self.base.as_matrix()
        row_indices = jnp.where(self.row_mask, size=self._out_size)[0]
        col_indices = jnp.where(self.col_mask, size=self._in_size)[0]
        return full[jnp.ix_(row_indices, col_indices)]

    def transpose(self) -> MaskedOperator:
        return MaskedOperator(
            self.base.T,
            self.col_mask,
            self.row_mask,
            tags=lx.transpose_tags(self.tags),
        )

    def in_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._in_size,), jnp.dtype(self._dtype))

    def out_structure(self) -> jax.ShapeDtypeStruct:
        return jax.ShapeDtypeStruct((self._out_size,), jnp.dtype(self._dtype))
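The scatter/apply/gather pattern in mv can be checked against plain boolean indexing. The sketch below uses a dense matrix in place of the base operator, and `masked_mv` is a hypothetical helper. As in __init__, the sub-matrix sizes come from int(jnp.sum(mask)), so the masks must be concrete (not traced) at that point.

```python
import jax.numpy as jnp


def masked_mv(A, row_mask, col_mask, vector):
    """Apply A[row_mask][:, col_mask] to `vector` without slicing A."""
    n = A.shape[1]
    in_size = int(jnp.sum(col_mask))  # requires a concrete mask
    out_size = int(jnp.sum(row_mask))
    # Scatter the short input into a zero vector at the selected column positions.
    col_idx = jnp.where(col_mask, size=in_size)[0]
    full_v = jnp.zeros(n, dtype=vector.dtype).at[col_idx].set(vector)
    # Apply the full operator, then gather the selected rows.
    full_out = A @ full_v
    row_idx = jnp.where(row_mask, size=out_size)[0]
    return full_out[row_idx]
```

The transpose check in the test reflects transpose(): masking A^T with the two masks swapped gives the transpose of the sub-matrix.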

implicit_cross_kernel(kernel_fn: Callable, X_data: Float[Array, 'N D'], X_inducing: Float[Array, 'M D'], batch_size: int = 1024, *, params: Any | None = None, tags: object | frozenset[object] = frozenset()) -> ImplicitCrossKernelOperator

Create a matrix-free rectangular cross-kernel operator.

Convenience wrapper around ImplicitCrossKernelOperator.

Parameters:

- kernel_fn (Callable, required): Kernel function k(x, z) -> scalar or k(params, x, z) -> scalar.
- X_data (Float[Array, 'N D'], required): Data points, shape (N, D).
- X_inducing (Float[Array, 'M D'], required): Inducing points, shape (M, D).
- batch_size (int, default 1024): Rows of X_data processed per scan step.
- params (Any | None, keyword-only, default None): Optional pytree of kernel hyperparameters.
- tags (object | frozenset[object], keyword-only, default frozenset()): Lineax structural tags.

Returns:

- ImplicitCrossKernelOperator: An ImplicitCrossKernelOperator representing K(X_data, X_inducing).

Source code in src/gaussx/_operators/_implicit_cross_kernel.py
def implicit_cross_kernel(
    kernel_fn: Callable,
    X_data: Float[Array, "N D"],
    X_inducing: Float[Array, "M D"],
    batch_size: int = 1024,
    *,
    params: Any | None = None,
    tags: object | frozenset[object] = frozenset(),
) -> ImplicitCrossKernelOperator:
    """Create a matrix-free rectangular cross-kernel operator.

    Convenience wrapper around `ImplicitCrossKernelOperator`.

    Args:
        kernel_fn: Kernel function ``k(x, z) -> scalar`` or
            ``k(params, x, z) -> scalar``.
        X_data: Data points, shape ``(N, D)``.
        X_inducing: Inducing points, shape ``(M, D)``.
        batch_size: Rows of ``X_data`` processed per scan step.
        params: Optional pytree of kernel hyperparameters.
        tags: Lineax structural tags.

    Returns:
        An ``ImplicitCrossKernelOperator`` representing ``K(X_data, X_inducing)``.
    """
    return ImplicitCrossKernelOperator(
        kernel_fn, X_data, X_inducing, batch_size=batch_size, params=params, tags=tags
    )
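A batched, matrix-free product K(X_data, X_inducing) v can be sketched with jax.lax.map over row blocks, so only one (batch_size, M) kernel block exists at a time. The sketch assumes a squared-exponential kernel of the form k(x, z) and requires batch_size to divide N. It does not cover the k(params, x, z) form or how ImplicitCrossKernelOperator treats a remainder block. `rbf` and `cross_kernel_mv` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def rbf(x, z, lengthscale=1.0):
    """Squared-exponential kernel k(x, z) -> scalar (illustrative choice)."""
    return jnp.exp(-0.5 * jnp.sum((x - z) ** 2) / lengthscale**2)


def cross_kernel_mv(kernel_fn, X_data, X_inducing, v, batch_size):
    """K(X_data, X_inducing) @ v, building one (batch_size, M) block at a time."""
    N, D = X_data.shape
    if N % batch_size != 0:
        raise ValueError("this sketch assumes batch_size divides N")
    blocks = X_data.reshape(N // batch_size, batch_size, D)

    # Kernel block via nested vmap: outer over data rows, inner over inducing points.
    block_kernel = jax.vmap(jax.vmap(kernel_fn, in_axes=(None, 0)), in_axes=(0, None))

    def one_block(xb):
        return block_kernel(xb, X_inducing) @ v  # shape (batch_size,)

    # lax.map runs the blocks sequentially, bounding peak memory.
    return jax.lax.map(one_block, blocks).reshape(N)
```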

Lazy algebra & sampling

Sum, scale, and compose operators without materializing them; sample \(\varepsilon \sim \mathcal{N}(0, A)\) for the structured families; and solve bordered systems through the capacitance (Schur-complement) form.


CapacitanceSolver

Bases: Module

Solve a base system subject to homogeneous point constraints.

Given a fast base solver B^{-1} and a set of N_b constrained indices, this enforces a zero solution at those indices (x[b_k] = 0 for every k, in the notation below) via the capacitance-matrix correction:

  1. Base solve: u = B^{-1} f
  2. Sample boundary: u_b = u[boundary]
  3. Correction: alpha = C^{-1} u_b
  4. Subtract: x = u - G^T alpha

where G[k] = B^{-1} e_{b_k} are the Green's functions of the base solver for unit sources at the constrained indices, and C[k, l] = G[l][b_k] is the capacitance matrix. C^{-1} and G are precomputed at construction.

The solver operates on flat vectors. Reshaping between fields and flat vectors, and any masking of the exterior, is left to the caller, which keeps grid and mask concepts out of this class.

Parameters:

- base_solve (Callable[[Float[Array, ' n']], Float[Array, ' n']], required): Callable applying the base inverse B^{-1} to a flat right-hand side of length n.
- boundary_indices (Int[Array, ' Nb'], required): Flat indices of the constrained degrees of freedom, shape (N_b,).
- n (int, required): Length of the flat solution vector.

Attributes:

- base_solve (Callable[[Float[Array, ' n']], Float[Array, ' n']]): The base inverse callable.
- boundary_indices (Int[Array, ' Nb']): The constrained indices.
- green (Float[Array, 'Nb n']): Green's functions G, shape (N_b, n).
- capacitance_inv (Float[Array, 'Nb Nb']): Inverse capacitance matrix C^{-1}, shape (N_b, N_b).

Source code in src/gaussx/_operators/_capacitance.py
class CapacitanceSolver(eqx.Module):
    r"""Solve a base system subject to homogeneous point constraints.

    Given a fast base solver ``B^{-1}`` and a set of ``N_b`` constrained indices,
    this enforces ``u = 0`` at those indices via the capacitance-matrix
    correction:

    1. Base solve:        ``u = B^{-1} f``
    2. Sample boundary:   ``u_b = u[boundary]``
    3. Correction:        ``alpha = C^{-1} u_b``
    4. Subtract:          ``x = u - G^T alpha``

    where ``G[k] = B^{-1} e_{b_k}`` are the Green's functions of the base solver
    for unit sources at the constrained indices, and ``C[k, l] = G[l][b_k]`` is
    the capacitance matrix. ``C^{-1}`` and ``G`` are precomputed at construction.

    The solver operates on **flat** vectors. Any reshaping between fields and
    flat vectors, and any masking of the exterior, is the caller's
    responsibility -- keeping grid/mask concepts out of this class.

    Args:
        base_solve: Callable applying the base inverse ``B^{-1}`` to a flat
            right-hand side of length ``n``.
        boundary_indices: Flat indices of the constrained degrees of freedom,
            shape ``(N_b,)``.
        n: Length of the flat solution vector.

    Attributes:
        base_solve: The base inverse callable.
        boundary_indices: The constrained indices.
        green: Green's functions ``G``, shape ``(N_b, n)``.
        capacitance_inv: Inverse capacitance matrix ``C^{-1}``, shape
            ``(N_b, N_b)``.
    """

    base_solve: Callable[[Float[Array, " n"]], Float[Array, " n"]]
    boundary_indices: Int[Array, " Nb"]
    green: Float[Array, "Nb n"]
    capacitance_inv: Float[Array, "Nb Nb"]

    def __init__(
        self,
        base_solve: Callable[[Float[Array, " n"]], Float[Array, " n"]],
        boundary_indices: Int[Array, " Nb"],
        n: int,
    ):
        indices = jnp.asarray(boundary_indices)
        n_b = indices.shape[0]
        unit_sources = jnp.zeros((n_b, n)).at[jnp.arange(n_b), indices].set(1.0)
        green = jax.vmap(base_solve)(unit_sources)  # [Nb, n]
        capacitance = green[:, indices].T  # C[k, l] = green[l][b_k]

        self.base_solve = base_solve
        self.boundary_indices = indices
        self.green = green
        self.capacitance_inv = jnp.linalg.inv(capacitance)

    def __call__(self, rhs: Float[Array, " n"]) -> Float[Array, " n"]:
        """Solve the constrained system for a flat right-hand side ``rhs``."""
        u = self.base_solve(rhs)
        u_b = u[self.boundary_indices]
        alpha = self.capacitance_inv @ u_b
        correction = self.green.T @ alpha
        return u - correction
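Here is a quick way to see what the correction computes. With a dense base operator B, the result x vanishes at the constrained indices. The residual B x - f is nonzero only at those indices, because B G^T alpha is a combination of unit sources placed there. The sketch below re-implements the four steps with plain jnp, assuming a dense SPD B and a hypothetical `capacitance_solve` helper. It calls jnp.linalg.solve on C instead of storing C^{-1}, since a single solve amortizes nothing.

```python
import jax
import jax.numpy as jnp


def capacitance_solve(base_solve, boundary, n, f):
    """Solve with x = 0 enforced at `boundary` via the capacitance correction."""
    n_b = boundary.shape[0]
    # Green's functions G[k] = B^{-1} e_{b_k}, one per constrained index.
    unit_sources = jnp.zeros((n_b, n)).at[jnp.arange(n_b), boundary].set(1.0)
    G = jax.vmap(base_solve)(unit_sources)  # (N_b, n)
    C = G[:, boundary].T                    # C[k, l] = G[l][b_k]
    u = base_solve(f)                       # 1. base solve
    u_b = u[boundary]                       # 2. sample boundary
    alpha = jnp.linalg.solve(C, u_b)        # 3. correction
    return u - G.T @ alpha                  # 4. subtract
```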

SumOperator(*operators: lx.AbstractLinearOperator, tags: object | frozenset[object] = frozenset()) -> lx.AbstractLinearOperator

Lazy sum (A + B + …) v = A v + B v + ….

Defers materialization so that structured sub-operators keep their efficient matvec. All operators must have the same input and output sizes. Returns a (possibly tagged) chain of lineax AddLinearOperator compositions.

Parameters:

- *operators (AbstractLinearOperator, variadic): Two or more lineax.AbstractLinearOperator instances with matching shapes.
- tags (object | frozenset[object], keyword-only, default frozenset()): Optional explicit lineax tags for the combined operator.

Returns:

- AbstractLinearOperator: The lazy sum as a lineax operator.

Raises:

- ValueError: If fewer than two operators are given, or if their shapes differ.

Source code in src/gaussx/_operators/_lazy_algebra.py
def SumOperator(
    *operators: lx.AbstractLinearOperator,
    tags: object | frozenset[object] = frozenset(),
) -> lx.AbstractLinearOperator:
    """Lazy sum ``(A + B + …) v = A v + B v + …``.

    Defers materialization so that structured sub-operators keep their
    efficient matvec. All operators must have the same input and output
    sizes. Returns a (possibly tagged) chain of lineax
    ``AddLinearOperator`` compositions.

    Args:
        *operators: Two or more ``lineax.AbstractLinearOperator`` instances
            with matching shapes.
        tags: Optional explicit lineax tags for the combined operator.

    Returns:
        The lazy sum as a lineax operator.
    """
    if len(operators) < 2:
        raise ValueError("SumOperator requires at least two operators.")
    in0 = operators[0].in_size()
    out0 = operators[0].out_size()
    for i, op in enumerate(operators[1:], 1):
        if op.in_size() != in0 or op.out_size() != out0:
            raise ValueError(
                f"Shape mismatch: operator 0 has shape ({out0}, {in0}) "
                f"but operator {i} has shape ({op.out_size()}, {op.in_size()})."
            )
    return _maybe_tag(ft.reduce(_op.add, operators), tags)
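Because each term keeps its own matvec, a lazy sum of a diagonal term and a rank-k term costs O(n + nk) per product instead of O(n^2). The sketch below shows the same idea in plain JAX, assuming operators are represented as matvec closures rather than lineax operators. `lazy_sum` and `lazy_scale` are hypothetical helpers that mirror SumOperator's arity check and ScaledOperator's rank-0 check.

```python
import functools as ft

import jax
import jax.numpy as jnp


def lazy_sum(*matvecs):
    """(A + B + ...) v = A v + B v + ...; nothing is materialized."""
    if len(matvecs) < 2:
        raise ValueError("lazy_sum requires at least two operators.")
    return lambda v: ft.reduce(jnp.add, (mv(v) for mv in matvecs))


def lazy_scale(matvec, c):
    """(c A) v = c (A v), rejecting non-scalar multipliers."""
    c = jnp.asarray(c)
    if c.ndim != 0:
        raise ValueError(f"scalar must be rank-0, got shape {c.shape}.")
    return lambda v: c * matvec(v)
```

In gaussx itself the combination is delegated to lineax (AddLinearOperator and MulLinearOperator nodes), which likewise apply each term's own matvec.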

ScaledOperator(operator: lx.AbstractLinearOperator, scalar: float | Float[Array, ''], *, tags: object | frozenset[object] = frozenset()) -> lx.AbstractLinearOperator

Lazy scalar multiply (c A) v = c (A v).

Returns a (possibly tagged) lineax MulLinearOperator.

Parameters:

- operator (AbstractLinearOperator, required): A lineax.AbstractLinearOperator.
- scalar (float | Float[Array, ''], required): A scalar multiplier.
- tags (object | frozenset[object], keyword-only, default frozenset()): Optional explicit lineax tags for the scaled operator.

Returns:

- AbstractLinearOperator: The lazy scaled operator.

Raises:

- ValueError: If scalar is not rank-0.

Source code in src/gaussx/_operators/_lazy_algebra.py
def ScaledOperator(
    operator: lx.AbstractLinearOperator,
    scalar: float | Float[Array, ""],
    *,
    tags: object | frozenset[object] = frozenset(),
) -> lx.AbstractLinearOperator:
    """Lazy scalar multiply ``(c A) v = c (A v)``.

    Returns a (possibly tagged) lineax ``MulLinearOperator``.

    Args:
        operator: A ``lineax.AbstractLinearOperator``.
        scalar: A scalar multiplier.
        tags: Optional explicit lineax tags for the scaled operator.

    Returns:
        The lazy scaled operator.
    """
    scalar_array = jnp.asarray(scalar)
    if scalar_array.ndim != 0:
        msg = (
            "ScaledOperator scalar must be a rank-0 scalar, got "
            f"shape {scalar_array.shape}."
        )
        raise ValueError(msg)
    return _maybe_tag(scalar_array * operator, tags)

ProductOperator(left: lx.AbstractLinearOperator, right: lx.AbstractLinearOperator, *, tags: object | frozenset[object] = frozenset()) -> lx.AbstractLinearOperator

Lazy matmul (A B) v = A (B v).

The inner dimension must match: left.in_size() == right.out_size(). Returns a (possibly tagged) lineax ComposedLinearOperator.

Parameters:

- left (AbstractLinearOperator, required): The left operator A.
- right (AbstractLinearOperator, required): The right operator B.
- tags (object | frozenset[object], keyword-only, default frozenset()): Optional explicit lineax tags for the composed operator.

Returns:

- AbstractLinearOperator: The lazy product as a lineax operator.

Raises:

- ValueError: If left.in_size() != right.out_size().

Source code in src/gaussx/_operators/_lazy_algebra.py
def ProductOperator(
    left: lx.AbstractLinearOperator,
    right: lx.AbstractLinearOperator,
    *,
    tags: object | frozenset[object] = frozenset(),
) -> lx.AbstractLinearOperator:
    """Lazy matmul ``(A B) v = A (B v)``.

    The inner dimension must match: ``left.in_size() == right.out_size()``.
    Returns a (possibly tagged) lineax ``ComposedLinearOperator``.

    Args:
        left: The left operator A.
        right: The right operator B.
        tags: Optional explicit lineax tags for the composed operator.

    Returns:
        The lazy product as a lineax operator.
    """
    if left.in_size() != right.out_size():
        raise ValueError(
            f"Inner dimension mismatch: left.in_size()={left.in_size()} "
            f"!= right.out_size()={right.out_size()}."
        )
    return _maybe_tag(left @ right, tags)
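Composition needs only the inner sizes to agree, and the right operator is applied first. The sketch below shows this in plain JAX with the same inner-dimension check, assuming each operator is passed as a hypothetical (matvec, in_size, out_size) triple. Composing a (3 x 5) map with a (5 x 4) map yields a (3 x 4) operator that is never formed.

```python
import jax
import jax.numpy as jnp


def lazy_product(left, right):
    """(A B) v = A (B v) for operators given as (matvec, in_size, out_size)."""
    left_mv, left_in, left_out = left
    right_mv, right_in, right_out = right
    if left_in != right_out:
        raise ValueError(
            f"Inner dimension mismatch: left in_size={left_in} "
            f"!= right out_size={right_out}."
        )
    # The composed operator maps right's input space to left's output space.
    return (lambda v: left_mv(right_mv(v)), right_in, left_out)
```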

kronecker_sum_sample(A_op: lx.AbstractLinearOperator, B_op: lx.AbstractLinearOperator, *, key: jax.Array, num_samples: int = 1) -> Float[Array, 'num_samples n_a n_b']

Sample from 𝒩(0, A ⊕ B) using per-factor eigendecompositions.

Draws zero-mean samples with covariance A ⊕ B by applying the matrix-free KroneckerSumSqrt to standard normal noise, avoiding materialization of the full (n_a · n_b, n_a · n_b) covariance.

Parameters:

- A_op (AbstractLinearOperator, required): Symmetric PSD factor, shape (n_a, n_a).
- B_op (AbstractLinearOperator, required): Symmetric PSD factor, shape (n_b, n_b).
- key (Array, keyword-only, required): PRNG key for the standard normal draws.
- num_samples (int, keyword-only, default 1): Number of samples to draw.

Returns:

- Float[Array, 'num_samples n_a n_b']: Samples of shape (num_samples, n_a, n_b).

Raises:

- ValueError: If num_samples is less than 1.

Source code in src/gaussx/_operators/_kronecker_sum.py
def kronecker_sum_sample(
    A_op: lx.AbstractLinearOperator,
    B_op: lx.AbstractLinearOperator,
    *,
    key: jax.Array,
    num_samples: int = 1,
) -> Float[Array, "num_samples n_a n_b"]:
    """Sample from ``𝒩(0, A ⊕ B)`` using per-factor eigendecompositions.

    Draws zero-mean samples with covariance ``A ⊕ B`` by applying the
    matrix-free `KroneckerSumSqrt` to standard normal noise, avoiding
    materialization of the full ``(n_a · n_b, n_a · n_b)`` covariance.

    Args:
        A_op: Symmetric PSD factor, shape ``(n_a, n_a)``.
        B_op: Symmetric PSD factor, shape ``(n_b, n_b)``.
        key: PRNG key for the standard normal draws.
        num_samples: Number of samples to draw.

    Returns:
        Samples of shape ``(num_samples, n_a, n_b)``.

    Raises:
        ValueError: If ``num_samples`` is less than 1.
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}.")

    sqrt_op = KroneckerSumSqrt(A_op, B_op)
    eps = jax.random.normal(
        key,
        (num_samples, sqrt_op.in_size()),
        dtype=jnp.dtype(sqrt_op._dtype),
    )
    samples = jax.vmap(sqrt_op.mv)(eps)
    return rearrange(samples, "s (a b) -> s a b", a=sqrt_op._n_a, b=sqrt_op._n_b)
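Per-factor eigendecompositions are enough for the following reason. If A = Q_A diag(λ) Q_A^T and B = Q_B diag(μ) Q_B^T, then A ⊕ B = (Q_A ⊗ Q_B) diag(λ_i + μ_j) (Q_A ⊗ Q_B)^T. So S = (Q_A ⊗ Q_B) diag(sqrt(λ_i + μ_j)) satisfies S S^T = A ⊕ B. On noise reshaped row-major to X of shape (n_a, n_b), S acts as Q_A (R ∘ X) Q_B^T, with R[i, j] = sqrt(λ_i + μ_j). The sketch below builds this from dense factors in plain JAX. It illustrates the construction and is not the KroneckerSumSqrt class. It assumes row-major flattening, consistent with the "s (a b) -> s a b" rearrange above, and clips slightly negative eigenvalue sums to zero.

```python
import jax
import jax.numpy as jnp


def kron_sum_sqrt(A, B):
    """Return a matvec for S with S S^T = kron(A, I) + kron(I, B), never forming S."""
    lam, Q_A = jnp.linalg.eigh(A)
    mu, Q_B = jnp.linalg.eigh(B)
    # Eigenvalues of the Kronecker sum are all pairwise sums lam_i + mu_j.
    root = jnp.sqrt(jnp.clip(lam[:, None] + mu[None, :], 0.0))
    n_a, n_b = lam.shape[0], mu.shape[0]

    def mv(v):
        X = v.reshape(n_a, n_b)                        # row-major (n_a, n_b) view
        return (Q_A @ (root * X) @ Q_B.T).reshape(-1)  # (Q_A ⊗ Q_B) vec(root ∘ X)

    return mv


def kron_sum_sample(A, B, key, num_samples=1):
    """Draw samples from N(0, A ⊕ B), returned with shape (num_samples, n_a, n_b)."""
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}.")
    n_a, n_b = A.shape[0], B.shape[0]
    mv = kron_sum_sqrt(A, B)  # eigendecompositions computed once, reused per sample
    eps = jax.random.normal(key, (num_samples, n_a * n_b), dtype=A.dtype)
    return jax.vmap(mv)(eps).reshape(num_samples, n_a, n_b)
```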

sumkronecker_sample(op: SumKronecker, *, key: jax.Array, num_samples: int = 1, lanczos_order: int = 50) -> Float[Array, 'num_samples n']

Sample from 𝒩(0, op) with matrix-free Lanczos square roots.

The square-root action is computed by Lanczos iteration (from the matfree library) using only op.mv. This avoids materializing the dense (n_A n_B) × (n_A n_B) covariance and costs lanczos_order SumKronecker matvecs per sample.

Parameters:

- op (SumKronecker, required): Positive-semidefinite SumKronecker covariance operator.
- key (Array, keyword-only, required): JAX PRNG key.
- num_samples (int, keyword-only, default 1): Number of independent samples to draw.
- lanczos_order (int, keyword-only, default 50): Lanczos truncation order.

Returns:

- Float[Array, 'num_samples n']: Samples with shape (num_samples, op.in_size()).

Raises:

- ValueError: If op is not square, or if num_samples is less than 1.

Source code in src/gaussx/_operators/_sum_kronecker.py
def sumkronecker_sample(
    op: SumKronecker,
    *,
    key: jax.Array,
    num_samples: int = 1,
    lanczos_order: int = 50,
) -> Float[Array, "num_samples n"]:
    r"""Sample from ``𝒩(0, op)`` with matrix-free Lanczos square roots.

    The square-root action is evaluated by ``matfree`` Lanczos against
    ``op.mv``. This avoids materialising the dense ``(n_A n_B) ×
    (n_A n_B)`` covariance and costs ``lanczos_order`` SumKronecker
    matvecs per sample.

    Args:
        op: Positive-semidefinite SumKronecker covariance operator.
        key: JAX PRNG key.
        num_samples: Number of independent samples to draw.
        lanczos_order: Lanczos truncation order.

    Returns:
        Samples with shape ``(num_samples, op.in_size())``.
    """
    from gaussx._primitives._sqrt import sqrt

    if op.in_size() != op.out_size():
        raise ValueError(
            "sumkronecker_sample requires a square SumKronecker, got "
            f"in_size={op.in_size()} and out_size={op.out_size()}."
        )
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}.")

    sqrt_op = sqrt(op, lanczos_order=lanczos_order)
    eps = jax.random.normal(
        key, (num_samples, op.in_size()), dtype=op.in_structure().dtype
    )
    return jax.vmap(sqrt_op.mv)(eps)
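The Lanczos square root rests on the approximation A^{1/2} v ≈ ||v|| Q_k T_k^{1/2} e_1. Here the columns of Q_k are k orthonormal Krylov vectors started from v, and T_k = Q_k^T A Q_k is tridiagonal, so each sample costs k matvecs plus a small k x k eigendecomposition. The sketch below shows the technique in plain JAX. It is not the matfree routine gaussx calls: it uses a Python loop, reorthogonalizes fully, and has no breakdown handling, so `order` must not exceed n. `lanczos_sqrt_mv` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def lanczos_sqrt_mv(matvec, v, order):
    """Approximate A^{1/2} v with `order` Lanczos steps, for symmetric PSD A."""
    norm_v = jnp.linalg.norm(v)
    q = v / norm_v
    basis = jnp.zeros((order, v.shape[0]), dtype=v.dtype)  # rows are Krylov vectors
    alphas, betas = [], []
    for k in range(order):
        basis = basis.at[k].set(q)
        w = matvec(q)
        alphas.append(jnp.dot(q, w))
        # Full reorthogonalization (two passes) against all stored vectors; this
        # also removes the alpha * q_k and beta * q_{k-1} components.
        for _ in range(2):
            w = w - basis.T @ (basis @ w)
        if k + 1 < order:
            beta = jnp.linalg.norm(w)  # assumes no breakdown (beta > 0)
            betas.append(beta)
            q = w / beta
    T = jnp.diag(jnp.stack(alphas))
    if betas:
        off = jnp.stack(betas)
        T = T + jnp.diag(off, 1) + jnp.diag(off, -1)
    evals, evecs = jnp.linalg.eigh(T)
    # sqrt(T) e_1 = V sqrt(Λ) V^T e_1, and V^T e_1 is the first row of V.
    sqrt_T_e1 = evecs @ (jnp.sqrt(jnp.clip(evals, 0.0)) * evecs[0])
    return norm_v * (basis.T @ sqrt_T_e1)
```

A sample is then lanczos_sqrt_mv(op_mv, eps, order) with eps drawn from N(0, I). With order = n, as in the check, the result matches the dense square root up to rounding.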

toeplitz_sample(column: Float[Array, ' n'], *, key: jax.Array, num_samples: int = 1, embedding_factor: int = 2) -> Float[Array, 'num_samples n']

Sample from 𝒩(0, Toeplitz(column)) via FFT circulant embedding.

The Wood-Chan non-negativity check is JIT-friendly (it uses eqx.error_if). If the circulant embedding is not non-negative definite for the given embedding_factor, the error is raised when the computation is evaluated. There is no automatic runtime fallback, so for difficult covariances increase embedding_factor (typically to 4 or 8).

Parameters:

- column (Float[Array, ' n'], required): First column of the covariance matrix.
- key (Array, keyword-only, required): JAX PRNG key used to draw white noise.
- num_samples (int, keyword-only, default 1): Number of independent samples to draw.
- embedding_factor (int, keyword-only, default 2): Circulant embedding size as a multiple of n.

Returns:

- Float[Array, 'num_samples n']: Samples with shape (num_samples, n).

Raises:

- ValueError: If num_samples is less than 1.

Source code in src/gaussx/_operators/_toeplitz.py
def toeplitz_sample(
    column: Float[Array, " n"],
    *,
    key: jax.Array,
    num_samples: int = 1,
    embedding_factor: int = 2,
) -> Float[Array, "num_samples n"]:
    """Sample from ``𝒩(0, Toeplitz(column))`` via FFT circulant embedding.

    The Wood--Chan non-negativity check is JIT-friendly (via
    ``eqx.error_if``). If the embedding fails for the given
    ``embedding_factor``, the error fires at evaluation time — bump
    ``embedding_factor`` (typically ``4`` or ``8``) for difficult
    covariances rather than relying on a runtime fallback.

    Args:
        column: First column of the covariance matrix.
        key: JAX PRNG key used to draw white noise.
        num_samples: Number of independent samples to draw.
        embedding_factor: Circulant embedding size as a multiple of ``n``.

    Returns:
        Samples with shape ``(num_samples, n)``.
    """
    if num_samples < 1:
        raise ValueError("num_samples must be at least 1.")

    column = _as_floating_column(column)
    factor = ToeplitzCholesky(column, embedding_factor=embedding_factor)
    noise = jr.normal(key, (num_samples, factor.in_size()), dtype=column.dtype)
    return jax.vmap(factor.mv)(noise)
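Circulant embedding works as follows. Place the Toeplitz column inside the first column of a symmetric circulant of size m; that circulant's eigenvalues λ are the (real) FFT of its first column. If they are all non-negative, which is what the Wood-Chan check tests, then x = Re(IFFT(sqrt(λ) ∘ FFT(ε))) with ε ~ N(0, I_m) has the circulant as its covariance. The first n entries of x then have covariance Toeplitz(column). The sketch below uses the minimal embedding m = 2(n - 1) instead of gaussx's embedding_factor sizing, whose padding scheme is not reproduced. It also uses an eager Python check where gaussx uses eqx.error_if. All helper names are hypothetical.

```python
import jax
import jax.numpy as jnp


def circulant_embedding_eigs(column):
    """Eigenvalues of the minimal symmetric circulant embedding of Toeplitz(column)."""
    # Circulant first column: c_0, ..., c_{n-1}, c_{n-2}, ..., c_1 (length 2n - 2).
    embedded = jnp.concatenate([column, jnp.flip(column[1:-1])])
    lam = jnp.fft.fft(embedded).real  # real because `embedded` is symmetric
    if jnp.min(lam) < -1e-6 * jnp.max(lam):  # eager check; gaussx uses eqx.error_if
        raise ValueError("circulant embedding is not non-negative definite")
    return jnp.clip(lam, 0.0)


def circulant_sqrt_mv(lam, eps):
    """C^{1/2} eps for the symmetric circulant C with eigenvalues lam."""
    return jnp.fft.ifft(jnp.sqrt(lam) * jnp.fft.fft(eps)).real


def toeplitz_sample_sketch(column, key, num_samples=1):
    """Samples of shape (num_samples, n) from N(0, Toeplitz(column))."""
    n = column.shape[0]
    lam = circulant_embedding_eigs(column)
    eps = jax.random.normal(key, (num_samples, lam.shape[0]), dtype=column.dtype)
    full = jax.vmap(lambda e: circulant_sqrt_mv(lam, e))(eps)
    return full[:, :n]  # keep the first n entries of each circulant draw
```

The check uses an exponential column c_k = exp(-k / 2). Convex, decreasing covariances of this kind are a standard case where the minimal embedding is non-negative.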

Structural tags & predicates

Tags mark structure and properties on operators; the is_* predicates are what the primitives consult when choosing an algorithm. The property tags (positive_semidefinite_tag, symmetric_tag, the triangular tags, …) are re-exported from lineax so user code only needs one import.
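The underlying pattern is simple: a cheap predicate inspects an operator (its type, or the tags it carries), and the primitive branches on the answer. The sketch below shows this pattern in pure Python with functools.singledispatch, which the is_* predicates below are also decorated with. Every class, tag, and function name in it is hypothetical, not gaussx or lineax API.

```python
import functools as ft
from dataclasses import dataclass, field


class Tag:
    """Named sentinel object, standing in for a structural tag."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __repr__(self) -> str:
        return self.name


kronecker_tag = Tag("kronecker_tag")


@dataclass(frozen=True)
class KroneckerOp:
    """Stand-in for an operator whose type already implies the structure."""

    factors: tuple = ()


@dataclass(frozen=True)
class TaggedOp:
    """Stand-in for a generic operator that only carries tags."""

    tags: frozenset = field(default_factory=frozenset)


@ft.singledispatch
def is_kronecker(operator) -> bool:
    # Default for unregistered operator types: assume no Kronecker structure.
    return False


@is_kronecker.register
def _(operator: KroneckerOp) -> bool:
    return True  # structure known from the type itself


@is_kronecker.register
def _(operator: TaggedOp) -> bool:
    return kronecker_tag in operator.tags  # structure known only from the tags


def choose_solver(operator) -> str:
    """A primitive consults the predicate when picking an algorithm."""
    return "per-factor solve" if is_kronecker(operator) else "dense solve"
```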


is_diagonal = lx.is_diagonal (module attribute)

is_symmetric = lx.is_symmetric (module attribute)

is_positive_semidefinite = lx.is_positive_semidefinite (module attribute)

is_negative_semidefinite = lx.is_negative_semidefinite (module attribute)

is_lower_triangular = lx.is_lower_triangular (module attribute)

is_upper_triangular = lx.is_upper_triangular (module attribute)

kronecker_tag = _Tag('kronecker_tag') (module attribute)

Operator is a Kronecker product.

kronecker_sum_tag = _Tag('kronecker_sum_tag') (module attribute)

Operator is a Kronecker sum A ⊕ B = A ⊗ I_b + I_a ⊗ B.

block_diagonal_tag = _Tag('block_diagonal_tag') (module attribute)

Operator is block diagonal.

block_tridiagonal_tag = _Tag('block_tridiagonal_tag') (module attribute)

Operator is block tridiagonal.

low_rank_tag = _Tag('low_rank_tag') (module attribute)

Operator has low-rank structure (e.g. L + U D V^T).

diagonal_tag = lx.diagonal_tag (module attribute)

symmetric_tag = lx.symmetric_tag (module attribute)

positive_semidefinite_tag = lx.positive_semidefinite_tag (module attribute)

negative_semidefinite_tag = lx.negative_semidefinite_tag (module attribute)

lower_triangular_tag = lx.lower_triangular_tag (module attribute)

upper_triangular_tag = lx.upper_triangular_tag (module attribute)

tridiagonal_tag = lx.tridiagonal_tag (module attribute)

unit_diagonal_tag = lx.unit_diagonal_tag (module attribute)

is_kronecker(operator: lx.AbstractLinearOperator) -> bool

Check whether operator carries the Kronecker tag.

Parameters:

- operator (AbstractLinearOperator, required): Linear operator to inspect.

Returns:

- bool: True if the operator carries the Kronecker tag, else False.

Source code in src/gaussx/_tags.py
@ft.singledispatch
def is_kronecker(operator: lx.AbstractLinearOperator) -> bool:
    """Check whether *operator* carries the Kronecker tag.

    Args:
        operator: Linear operator to inspect.

    Returns:
        ``True`` if the operator carries the Kronecker tag, else ``False``.
    """
    return False

is_kronecker_sum(operator: lx.AbstractLinearOperator) -> bool

Check whether operator carries the Kronecker sum tag.

Parameters:

- operator (AbstractLinearOperator, required): Linear operator to inspect.

Returns:

- bool: True if the operator carries the Kronecker sum tag, else False.

Source code in src/gaussx/_tags.py
@ft.singledispatch
def is_kronecker_sum(operator: lx.AbstractLinearOperator) -> bool:
    """Check whether *operator* carries the Kronecker sum tag.

    Args:
        operator: Linear operator to inspect.

    Returns:
        ``True`` if the operator carries the Kronecker sum tag, else ``False``.
    """
    return False

is_block_diagonal(operator: lx.AbstractLinearOperator) -> bool

Check whether operator carries the block-diagonal tag.

Parameters:

- operator (AbstractLinearOperator, required): Linear operator to inspect.

Returns:

- bool: True if the operator carries the block-diagonal tag, else False.

Source code in src/gaussx/_tags.py
@ft.singledispatch
def is_block_diagonal(operator: lx.AbstractLinearOperator) -> bool:
    """Check whether *operator* carries the block-diagonal tag.

    Args:
        operator: Linear operator to inspect.

    Returns:
        ``True`` if the operator carries the block-diagonal tag, else ``False``.
    """
    return False

is_block_tridiagonal(operator: lx.AbstractLinearOperator) -> bool

Check whether operator carries the block-tridiagonal tag.

Parameters:

- operator (AbstractLinearOperator, required): Linear operator to inspect.

Returns:

- bool: True if the operator carries the block-tridiagonal tag, else False.

Source code in src/gaussx/_tags.py
@ft.singledispatch
def is_block_tridiagonal(operator: lx.AbstractLinearOperator) -> bool:
    """Check whether *operator* carries the block-tridiagonal tag.

    Args:
        operator: Linear operator to inspect.

    Returns:
        ``True`` if the operator carries the block-tridiagonal tag, else
        ``False``.
    """
    return False

is_low_rank(operator: lx.AbstractLinearOperator) -> bool

Check whether operator carries the low-rank tag.

Parameters:

- operator (AbstractLinearOperator, required): Linear operator to inspect.

Returns:

- bool: True if the operator carries the low-rank tag, else False.

Source code in src/gaussx/_tags.py
@ft.singledispatch
def is_low_rank(operator: lx.AbstractLinearOperator) -> bool:
    """Check whether *operator* carries the low-rank tag.

    Args:
        operator: Linear operator to inspect.

    Returns:
        ``True`` if the operator carries the low-rank tag, else ``False``.
    """
    return False