Solvers & Preconditioners

Layer 1.5: strategy objects that encapsulate how a solve or logdet is computed, decoupled from what is being solved. Everything that accepts a solver= keyword anywhere in gaussx takes one of these; None falls back to structural dispatch on the operator.

A strategy bundles a solve and a logdet algorithm. Mix and match with ComposedSolver — e.g. CG for the solve, stochastic Lanczos quadrature for the logdet — which is the standard recipe for large kernel matrices.

The front door

linear_solve is the high-level entry point: it accepts a lineax operator or a bare (matvec, shape) pair, normalises negative-definite systems, picks a sensible iterative solver from the operator's tags, and threads a preconditioner through. as_linear_operator wraps a raw matvec callable into a tagged FunctionLinearOperator for matrix-free workflows.

Structured linear algebra and Gaussian primitives for JAX.

linear_solve(operator: OperatorLike, vector: Float[Array, ' n'], *, solver: AbstractSolveStrategy | None = None, preconditioner: PreconditionerLike | None = None) -> Float[Array, ' n']

Solve A x = b through the unified front door.

Accepts either a built operator or a (matvec, shape) pair, handles the negative-definite sign convention, picks a default solver when none is given, and optionally applies a preconditioner.

Sign handling: an operator tagged negative semidefinite (and not PSD) is solved via the equivalent positive-definite system (-A) x = -b, so that a CG-style solver can be used directly. This is the common case for elliptic PDE operators (e.g. a discrete Laplacian), which finite-volume / spectral callers hand over as negative-definite matvecs.
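The sign flip can be illustrated outside gaussx with a small NumPy sketch. The `cg` helper below is a textbook conjugate-gradient loop written for this example, not the lineax solver that gaussx calls, and the dense Laplacian matrix is only there to make the check easy.

```python
import numpy as np


def cg(matvec, b, rtol=1e-10, max_steps=1000):
    """Textbook conjugate gradients for a symmetric positive-definite matvec."""
    x = np.zeros_like(b)
    r = b - matvec(x)
    p = r.copy()
    rs = r @ r
    for _ in range(max_steps):
        # Stop once the residual is small relative to the right-hand side.
        if np.sqrt(rs) <= rtol * np.linalg.norm(b):
            break
        Ap = matvec(p)
        step = rs / (p @ Ap)
        x = x + step * p
        r = r - step * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


# Negative-definite 1D Dirichlet Laplacian, stencil (1, -2, 1), on n points.
n = 50
A = -2.0 * np.eye(n) + np.eye(n, k=1) + np.eye(n, k=-1)
b = np.ones(n)

# Flip the sign of both sides: (-A) x = -b has the same solution x as A x = b.
x = cg(lambda v: -(A @ v), -b)
residual = np.linalg.norm(A @ x - b)
```

Running CG on `A` directly would violate its positive-definiteness assumption (every curvature term `p @ Ap` would be negative), which is why the negated system is the one handed to the solver.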

Default solver selection (when solver is None):

  • positive semidefinite operator -> CGSolver
  • symmetric (possibly indefinite) operator -> MINRESSolver
  • otherwise -> a ValueError asking for an explicit solver
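
The selection rules above can be sketched as a small dispatch function. This is an illustration only: the string tags stand in for the lineax structural tags, and `pick_default_solver` is a hypothetical name, not the private helper gaussx uses.

```python
def pick_default_solver(tags: set[str]) -> str:
    """Return the name of the default solver for a set of structural tags."""
    # PSD is checked first because a PSD operator is also symmetric.
    if "positive_semidefinite" in tags:
        return "CGSolver"
    if "symmetric" in tags:
        return "MINRESSolver"
    raise ValueError("No default solver for this operator; pass solver= explicitly.")


choice_psd = pick_default_solver({"symmetric", "positive_semidefinite"})
choice_sym = pick_default_solver({"symmetric"})
```

An operator tagged negative semidefinite has already been negated by the time selection runs, so under the sign handling described earlier it would be expected to take the CG branch.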

Parameters:

Name Type Description Default
operator OperatorLike

The linear operator A, or a (matvec, shape) pair.

required
vector Float[Array, ' n']

Right-hand side b, shape (n,).

required
solver AbstractSolveStrategy | None

Solve strategy to use. When None a default is selected from the operator's structural tags.

None
preconditioner PreconditionerLike | None

Optional preconditioner. May be an AbstractPreconditioner, a lineax operator applying M^{-1}, or a callable v -> M^{-1} v. Preconditioning is currently applied through CGSolver.

None

Returns:

Type Description
Float[Array, ' n']

The solution x, shape (n,).

Source code in src/gaussx/_solve_frontend.py
def linear_solve(
    operator: OperatorLike,
    vector: Float[Array, " n"],
    *,
    solver: AbstractSolveStrategy | None = None,
    preconditioner: PreconditionerLike | None = None,
) -> Float[Array, " n"]:
    """Solve ``A x = b`` through the unified front door.

    Accepts either a built operator or a ``(matvec, shape)`` pair, handles the
    negative-definite sign convention, picks a default solver when none is
    given, and optionally applies a preconditioner.

    Sign handling: an operator tagged *negative* semidefinite (and not PSD) is
    solved via the equivalent positive-definite system ``(-A) x = -b``, so that
    a CG-style solver can be used directly. This is the common case for elliptic
    PDE operators (e.g. a discrete Laplacian), which finite-volume / spectral
    callers hand over as negative-definite matvecs.

    Default solver selection (when ``solver is None``):

    * positive semidefinite operator -> `CGSolver`
    * symmetric (possibly indefinite) operator -> `MINRESSolver`
    * otherwise -> a `ValueError` asking for an explicit solver

    Args:
        operator: The linear operator ``A``, or a ``(matvec, shape)`` pair.
        vector: Right-hand side ``b``, shape ``(n,)``.
        solver: Solve strategy to use. When ``None`` a default is selected from
            the operator's structural tags.
        preconditioner: Optional preconditioner. May be an
            `AbstractPreconditioner`, a lineax operator applying
            ``M^{-1}``, or a callable ``v -> M^{-1} v``. Preconditioning is
            currently applied through `CGSolver`.

    Returns:
        The solution ``x``, shape ``(n,)``.
    """
    op = _coerce_operator(operator, vector)

    # Negative-definite -> solve the equivalent positive-definite system.
    if is_negative_semidefinite(op) and not is_positive_semidefinite(op):
        op = _negate(op)
        vector = -vector

    if solver is None:
        solver = _default_solver(op)

    if preconditioner is not None:
        solver = _attach_preconditioner(solver, _as_preconditioner(preconditioner))

    return solver.solve(op, vector)

as_linear_operator(matvec: MatvecLike, *, shape: tuple[int, int] | None = None, in_structure: object | None = None, dtype: object = float, symmetric: bool = False, positive_semidefinite: bool = False, negative_definite: bool = False) -> lx.AbstractLinearOperator

Wrap a raw matvec callable as a tagged lineax operator.

This is the matrix-free front door: callers that only have a v -> A @ v function (rather than a structured operator object) use this to obtain an operator that gaussx's solvers and primitives understand.
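
As an illustration of the kind of callable this is meant for, here is a NumPy sketch of a matrix-free 1D Laplacian matvec. The function name `laplacian_matvec` is made up for this example. The dense matrix is built only to confirm the structure that the tags would assert.

```python
import numpy as np


def laplacian_matvec(v):
    """Apply the 1D Dirichlet Laplacian stencil (1, -2, 1) without forming a matrix."""
    out = -2.0 * v        # diagonal contribution (allocates a fresh array)
    out[1:] += v[:-1]     # sub-diagonal contribution
    out[:-1] += v[1:]     # super-diagonal contribution
    return out


n = 6
# Materialise the operator column by column, only to inspect its structure.
dense = np.stack([laplacian_matvec(e) for e in np.eye(n)], axis=1)
is_symmetric = np.allclose(dense, dense.T)
eigenvalues = np.linalg.eigvalsh(dense)
```

The matrix is symmetric with strictly negative eigenvalues, which is the situation the `symmetric` and `negative_definite` flags describe. A JAX version of the same matvec would need functional updates (`out.at[1:].add(...)`) in place of the in-place additions used here.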

Parameters:

Name Type Description Default
matvec MatvecLike

The matrix-vector product v -> A @ v.

required
shape tuple[int, int] | None

Matrix shape (out, in). The operator's input structure is inferred as a length-in vector of dtype. Ignored when in_structure is given.

None
in_structure object | None

An explicit input structure -- either a jax.ShapeDtypeStruct, an int (interpreted as a vector length), or a PyTree thereof. Takes precedence over shape.

None
dtype object

Dtype used when building the input structure from shape.

float
symmetric bool

Tag the operator symmetric (A == A^T).

False
positive_semidefinite bool

Tag the operator PSD (implies symmetric).

False
negative_definite bool

Tag the operator negative semidefinite (implies symmetric). Use this for elliptic operators such as a discrete Laplacian; linear_solve will route the solve through the equivalent positive-definite system.

False

Returns:

Type Description
AbstractLinearOperator

A lineax.FunctionLinearOperator carrying the requested tags.

Raises:

Type Description
ValueError

If neither shape nor in_structure is provided.

Source code in src/gaussx/_solve_frontend.py
def as_linear_operator(
    matvec: MatvecLike,
    *,
    shape: tuple[int, int] | None = None,
    in_structure: object | None = None,
    dtype: object = float,
    symmetric: bool = False,
    positive_semidefinite: bool = False,
    negative_definite: bool = False,
) -> lx.AbstractLinearOperator:
    """Wrap a raw ``matvec`` callable as a tagged lineax operator.

    This is the matrix-free front door: callers that only have a
    ``v -> A @ v`` function (rather than a structured operator object) use this
    to obtain an operator that gaussx's solvers and primitives understand.

    Args:
        matvec: The matrix-vector product ``v -> A @ v``.
        shape: Matrix shape ``(out, in)``. The operator's input structure is
            inferred as a length-``in`` vector of ``dtype``. Ignored when
            ``in_structure`` is given.
        in_structure: An explicit input structure -- either a
            ``jax.ShapeDtypeStruct``, an ``int`` (interpreted as a vector
            length), or a PyTree thereof. Takes precedence over ``shape``.
        dtype: Dtype used when building the input structure from ``shape``.
        symmetric: Tag the operator symmetric (``A == A^T``).
        positive_semidefinite: Tag the operator PSD (implies symmetric).
        negative_definite: Tag the operator negative semidefinite (implies
            symmetric). Use this for elliptic operators such as a discrete
            Laplacian; `linear_solve` will route the solve through the
            equivalent positive-definite system.

    Returns:
        A `lineax.FunctionLinearOperator` carrying the requested tags.

    Raises:
        ValueError: If neither ``shape`` nor ``in_structure`` is provided.
    """
    import jax

    if in_structure is not None:
        structure = _normalise_in_structure(in_structure, dtype)
    elif shape is not None:
        structure = jax.ShapeDtypeStruct((shape[1],), dtype)
    else:
        raise ValueError("Provide either `shape=(out, in)` or `in_structure`.")

    tags: list[object] = []
    if symmetric or positive_semidefinite or negative_definite:
        tags.append(lx.symmetric_tag)
    if positive_semidefinite:
        tags.append(lx.positive_semidefinite_tag)
    if negative_definite:
        tags.append(lx.negative_semidefinite_tag)

    return lx.FunctionLinearOperator(matvec, structure, tags=tuple(tags))

Abstract interfaces

AbstractSolveStrategy

Bases: Module

Protocol for linear solve strategies.

A solve strategy encapsulates the choice of algorithm for solving A x = b. Separating solve from logdet lets users mix-and-match via ComposedSolver.

Source code in src/gaussx/_strategies/_base.py
class AbstractSolveStrategy(eqx.Module):
    """Protocol for linear solve strategies.

    A solve strategy encapsulates the choice of algorithm for
    solving ``A x = b``.  Separating solve from logdet lets
    users mix-and-match via `ComposedSolver`.
    """

    @abc.abstractmethod
    def solve(
        self,
        operator: lx.AbstractLinearOperator,
        vector: Float[Array, " n"],
    ) -> Float[Array, " n"]:
        """Solve ``A x = b``.

        Args:
            operator: Linear operator ``A``.
            vector: Right-hand side ``b``, shape ``(n,)``.

        Returns:
            Solution ``x``, shape ``(n,)``.
        """
        ...

solve(operator: lx.AbstractLinearOperator, vector: Float[Array, ' n']) -> Float[Array, ' n'] abstractmethod

Solve A x = b.

Parameters:

Name Type Description Default
operator AbstractLinearOperator

Linear operator A.

required
vector Float[Array, ' n']

Right-hand side b, shape (n,).

required

Returns:

Type Description
Float[Array, ' n']

Solution x, shape (n,).

Source code in src/gaussx/_strategies/_base.py
@abc.abstractmethod
def solve(
    self,
    operator: lx.AbstractLinearOperator,
    vector: Float[Array, " n"],
) -> Float[Array, " n"]:
    """Solve ``A x = b``.

    Args:
        operator: Linear operator ``A``.
        vector: Right-hand side ``b``, shape ``(n,)``.

    Returns:
        Solution ``x``, shape ``(n,)``.
    """
    ...

AbstractLogdetStrategy

Bases: Module

Protocol for log-determinant strategies.

A logdet strategy encapsulates the choice of algorithm for computing log |det(A)|. Separating logdet from solve lets users mix-and-match via ComposedSolver.

All implementations accept an optional key parameter for stochastic methods. Deterministic strategies ignore it.

Source code in src/gaussx/_strategies/_base.py
class AbstractLogdetStrategy(eqx.Module):
    """Protocol for log-determinant strategies.

    A logdet strategy encapsulates the choice of algorithm for
    computing ``log |det(A)|``.  Separating logdet from solve
    lets users mix-and-match via `ComposedSolver`.

    All implementations accept an optional ``key`` parameter for
    stochastic methods.  Deterministic strategies ignore it.
    """

    @abc.abstractmethod
    def logdet(
        self,
        operator: lx.AbstractLinearOperator,
        *,
        key: jax.Array | None = None,
    ) -> Float[Array, ""]:
        """Compute log |det(A)|.

        Args:
            operator: Linear operator.
            key: Optional PRNG key for stochastic estimators.
        """
        ...

logdet(operator: lx.AbstractLinearOperator, *, key: jax.Array | None = None) -> Float[Array, ''] abstractmethod

Compute log |det(A)|.

Parameters:

Name Type Description Default
operator AbstractLinearOperator

Linear operator.

required
key Array | None

Optional PRNG key for stochastic estimators.

None

Source code in src/gaussx/_strategies/_base.py
@abc.abstractmethod
def logdet(
    self,
    operator: lx.AbstractLinearOperator,
    *,
    key: jax.Array | None = None,
) -> Float[Array, ""]:
    """Compute log |det(A)|.

    Args:
        operator: Linear operator.
        key: Optional PRNG key for stochastic estimators.
    """
    ...

AbstractSolverStrategy

Bases: AbstractSolveStrategy, AbstractLogdetStrategy

Protocol for solver strategies that pair solve + logdet.

A solver strategy encapsulates the choice of algorithm for solving linear systems and computing log-determinants. This decouples distribution objects from solver implementation details.

Subclasses must implement both solve and logdet. For independent control, see AbstractSolveStrategy and AbstractLogdetStrategy, composable via ComposedSolver.
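
The split between the two protocols can be sketched with plain Python classes. Everything below is a simplified stand-in: the classes use `abc.ABC` rather than `eqx.Module`, operate on dense NumPy matrices rather than lineax operators, and `Composed` is a hypothetical illustration of the idea behind `ComposedSolver`, whose actual constructor is not shown in this section.

```python
import abc

import numpy as np


class SolveStrategy(abc.ABC):
    @abc.abstractmethod
    def solve(self, matrix, vector): ...


class LogdetStrategy(abc.ABC):
    @abc.abstractmethod
    def logdet(self, matrix, *, key=None): ...


class DenseSolve(SolveStrategy):
    def solve(self, matrix, vector):
        return np.linalg.solve(matrix, vector)


class CholeskyLogdet(LogdetStrategy):
    def logdet(self, matrix, *, key=None):
        # Deterministic strategy: `key` is accepted for protocol parity and ignored.
        chol = np.linalg.cholesky(matrix)
        return 2.0 * np.sum(np.log(np.diag(chol)))


class Composed(SolveStrategy, LogdetStrategy):
    """Hypothetical pairing object: delegate each half to its own strategy."""

    def __init__(self, solve_strategy, logdet_strategy):
        self.solve_strategy = solve_strategy
        self.logdet_strategy = logdet_strategy

    def solve(self, matrix, vector):
        return self.solve_strategy.solve(matrix, vector)

    def logdet(self, matrix, *, key=None):
        return self.logdet_strategy.logdet(matrix, key=key)


A = np.array([[4.0, 1.0], [1.0, 3.0]])
b = np.array([1.0, 2.0])
solver = Composed(DenseSolve(), CholeskyLogdet())
x = solver.solve(A, b)
ld = solver.logdet(A)
```

Because callers only see `solve` and `logdet`, either half can be swapped (for example an iterative solve with a stochastic logdet) without touching the code that consumes the strategy.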

Source code in src/gaussx/_strategies/_base.py
class AbstractSolverStrategy(AbstractSolveStrategy, AbstractLogdetStrategy):
    """Protocol for solver strategies that pair solve + logdet.

    A solver strategy encapsulates the choice of algorithm for
    solving linear systems and computing log-determinants. This
    decouples distribution objects from solver implementation
    details.

    Subclasses must implement both `solve` and `logdet`.
    For independent control, see `AbstractSolveStrategy` and
    `AbstractLogdetStrategy`, composable via
    `ComposedSolver`.
    """

Direct & iterative solvers

DenseSolver

Bases: AbstractSolverStrategy

Dense solver strategy using gaussx structural dispatch.

Delegates to gaussx.solve and gaussx.logdet which automatically select the best algorithm based on operator structure (Diagonal, BlockDiag, Kronecker, LowRankUpdate, or dense fallback via lineax).

Source code in src/gaussx/_strategies/_dense.py
class DenseSolver(AbstractSolverStrategy):
    """Dense solver strategy using gaussx structural dispatch.

    Delegates to ``gaussx.solve`` and ``gaussx.logdet`` which
    automatically select the best algorithm based on operator
    structure (Diagonal, BlockDiag, Kronecker, LowRankUpdate,
    or dense fallback via lineax).
    """

    def solve(
        self,
        operator: lx.AbstractLinearOperator,
        vector: Float[Array, " n"],
    ) -> Float[Array, " n"]:
        """Solve ``A x = b`` via structural dispatch (``gaussx.solve``).

        Args:
            operator: Linear operator ``A``.
            vector: Right-hand side ``b``, shape ``(n,)``.

        Returns:
            Solution ``x``, shape ``(n,)``.
        """
        return _solve(operator, vector)

    def logdet(
        self,
        operator: lx.AbstractLinearOperator,
        *,
        key: jax.Array | None = None,
    ) -> Float[Array, ""]:
        """Compute ``log|det(A)|`` via structural dispatch (``gaussx.logdet``).

        Args:
            operator: Linear operator ``A``.
            key: Unused; present for protocol compatibility with stochastic
                estimators.

        Returns:
            Scalar log-determinant.
        """
        return _logdet(operator)

solve(operator: lx.AbstractLinearOperator, vector: Float[Array, ' n']) -> Float[Array, ' n']

Solve A x = b via structural dispatch (gaussx.solve).

Parameters:

Name Type Description Default
operator AbstractLinearOperator

Linear operator A.

required
vector Float[Array, ' n']

Right-hand side b, shape (n,).

required

Returns:

Type Description
Float[Array, ' n']

Solution x, shape (n,).

Source code in src/gaussx/_strategies/_dense.py
def solve(
    self,
    operator: lx.AbstractLinearOperator,
    vector: Float[Array, " n"],
) -> Float[Array, " n"]:
    """Solve ``A x = b`` via structural dispatch (``gaussx.solve``).

    Args:
        operator: Linear operator ``A``.
        vector: Right-hand side ``b``, shape ``(n,)``.

    Returns:
        Solution ``x``, shape ``(n,)``.
    """
    return _solve(operator, vector)

logdet(operator: lx.AbstractLinearOperator, *, key: jax.Array | None = None) -> Float[Array, '']

Compute log|det(A)| via structural dispatch (gaussx.logdet).

Parameters:

Name Type Description Default
operator AbstractLinearOperator

Linear operator A.

required
key Array | None

Unused; present for protocol compatibility with stochastic estimators.

None

Returns:

Type Description
Float[Array, '']

Scalar log-determinant.

Source code in src/gaussx/_strategies/_dense.py
def logdet(
    self,
    operator: lx.AbstractLinearOperator,
    *,
    key: jax.Array | None = None,
) -> Float[Array, ""]:
    """Compute ``log|det(A)|`` via structural dispatch (``gaussx.logdet``).

    Args:
        operator: Linear operator ``A``.
        key: Unused; present for protocol compatibility with stochastic
            estimators.

    Returns:
        Scalar log-determinant.
    """
    return _logdet(operator)

AutoSolver

Bases: AbstractSolverStrategy

Automatic solver selection based on operator type and size.

Selection logic:

  • Structured (Diagonal, BlockDiag, Kronecker, LowRankUpdate): DenseSolver (structural dispatch handles efficiency)
  • Small dense (N <= size_threshold): DenseSolver
  • Large PSD: CGSolver
  • Large general: DenseSolver (fallback)

Attributes:

Name Type Description
size_threshold int

Matrix dimension above which iterative solvers are preferred. Default: 1000.

Source code in src/gaussx/_strategies/_auto.py
class AutoSolver(AbstractSolverStrategy):
    """Automatic solver selection based on operator type and size.

    Selection logic:

    - Structured (Diagonal, BlockDiag, Kronecker, LowRankUpdate):
      DenseSolver (structural dispatch handles efficiency)
    - Small dense (N <= size_threshold): DenseSolver
    - Large PSD: CGSolver
    - Large general: DenseSolver (fallback)

    Attributes:
        size_threshold: Matrix dimension above which iterative
            solvers are preferred. Default: 1000.
    """

    size_threshold: int = 1000

    def solve(
        self,
        operator: lx.AbstractLinearOperator,
        vector: Float[Array, " n"],
    ) -> Float[Array, " n"]:
        """Solve A x = b with automatically selected algorithm.

        Args:
            operator: The linear operator A.
            vector: The right-hand side b.

        Returns:
            The solution x.
        """
        return self._get_strategy(operator).solve(operator, vector)

    def logdet(
        self,
        operator: lx.AbstractLinearOperator,
        *,
        key: jax.Array | None = None,
    ) -> Float[Array, ""]:
        """Compute log |det(A)| with automatically selected algorithm.

        Args:
            operator: The linear operator A.
            key: Optional PRNG key (forwarded to stochastic strategies).

        Returns:
            Scalar log |det(A)|.
        """
        return self._get_strategy(operator).logdet(operator, key=key)

    def _get_strategy(
        self, operator: lx.AbstractLinearOperator
    ) -> AbstractSolverStrategy:
        """Select the best solver strategy for the given operator."""
        from gaussx._operators._block_diag import BlockDiag
        from gaussx._operators._kronecker import Kronecker
        from gaussx._operators._low_rank_update import LowRankUpdate
        from gaussx._strategies._cg import CGSolver
        from gaussx._strategies._dense import DenseSolver

        if isinstance(
            operator, (lx.DiagonalLinearOperator, BlockDiag, Kronecker, LowRankUpdate)
        ):
            return DenseSolver()

        n = operator.in_size()
        if n <= self.size_threshold:
            return DenseSolver()

        # Large operators: use CG for PSD, DenseSolver otherwise
        if lx.is_positive_semidefinite(operator):
            return CGSolver()

        return DenseSolver()

solve(operator: lx.AbstractLinearOperator, vector: Float[Array, ' n']) -> Float[Array, ' n']

Solve A x = b with automatically selected algorithm.

Parameters:

Name Type Description Default
operator AbstractLinearOperator

The linear operator A.

required
vector Float[Array, ' n']

The right-hand side b.

required

Returns:

Type Description
Float[Array, ' n']

The solution x.

Source code in src/gaussx/_strategies/_auto.py
def solve(
    self,
    operator: lx.AbstractLinearOperator,
    vector: Float[Array, " n"],
) -> Float[Array, " n"]:
    """Solve A x = b with automatically selected algorithm.

    Args:
        operator: The linear operator A.
        vector: The right-hand side b.

    Returns:
        The solution x.
    """
    return self._get_strategy(operator).solve(operator, vector)

logdet(operator: lx.AbstractLinearOperator, *, key: jax.Array | None = None) -> Float[Array, '']

Compute log |det(A)| with automatically selected algorithm.

Parameters:

Name Type Description Default
operator AbstractLinearOperator

The linear operator A.

required
key Array | None

Optional PRNG key (forwarded to stochastic strategies).

None

Returns:

Type Description
Float[Array, '']

Scalar log |det(A)|.

Source code in src/gaussx/_strategies/_auto.py
def logdet(
    self,
    operator: lx.AbstractLinearOperator,
    *,
    key: jax.Array | None = None,
) -> Float[Array, ""]:
    """Compute log |det(A)| with automatically selected algorithm.

    Args:
        operator: The linear operator A.
        key: Optional PRNG key (forwarded to stochastic strategies).

    Returns:
        Scalar log |det(A)|.
    """
    return self._get_strategy(operator).logdet(operator, key=key)

CGSolver

Bases: AbstractSolverStrategy

Iterative CG solver with stochastic log-determinant.

Uses lineax CG for the linear solve and matfree's stochastic Lanczos quadrature (SLQ) for the log-determinant. Suitable for large PSD operators where dense factorization is too expensive.
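
For readers unfamiliar with SLQ, the following NumPy sketch shows the idea: estimate tr(log A) by averaging Lanczos quadrature rules over random probe vectors. It is not matfree's implementation. The helper names `lanczos_tridiag` and `slq_logdet` and the choice of Rademacher probes are assumptions made for this example; only the parameter names `num_probes` and `lanczos_order` mirror the attributes documented here.

```python
import numpy as np


def lanczos_tridiag(matvec, v0, order):
    """Run up to `order` Lanczos steps with full reorthogonalisation."""
    n = v0.shape[0]
    Q = np.zeros((order, n))
    alphas, betas = [], []
    q = v0 / np.linalg.norm(v0)
    for k in range(order):
        Q[k] = q
        w = matvec(q)
        alphas.append(q @ w)
        # Remove components along every Lanczos vector built so far.
        w = w - Q[: k + 1].T @ (Q[: k + 1] @ w)
        beta = np.linalg.norm(w)
        if k == order - 1 or beta < 1e-12:
            break  # finished, or the Krylov space is exhausted
        betas.append(beta)
        q = w / beta
    return np.array(alphas), np.array(betas)


def slq_logdet(matvec, n, num_probes=20, lanczos_order=30, seed=0):
    """Estimate log det(A) = tr(log A) for a symmetric positive-definite matvec."""
    rng = np.random.default_rng(seed)
    total = 0.0
    for _ in range(num_probes):
        z = rng.choice([-1.0, 1.0], size=n)  # Rademacher probe
        alphas, betas = lanczos_tridiag(matvec, z, min(lanczos_order, n))
        T = np.diag(alphas) + np.diag(betas, 1) + np.diag(betas, -1)
        theta, U = np.linalg.eigh(T)
        # Gauss quadrature: nodes are the Ritz values, weights come from the
        # squared first components of the eigenvectors of T.
        total += (z @ z) * np.sum(U[0] ** 2 * np.log(theta))
    return total / num_probes


d = np.linspace(1.0, 5.0, 8)
estimate = slq_logdet(lambda v: d * v, 8, num_probes=4, lanczos_order=8)
exact = np.sum(np.log(d))
```

For this diagonal operator with full Lanczos order the estimate agrees with the exact value up to rounding. For a general operator the result is a noisy estimate whose variance shrinks as `num_probes` grows and whose quadrature error shrinks as `lanczos_order` grows.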

Attributes:

Name Type Description
rtol float

Relative tolerance for CG.

atol float

Absolute tolerance for CG.

max_steps int

Maximum CG iterations.

num_probes int

Number of probe vectors for stochastic logdet.

lanczos_order int

Order of the Lanczos decomposition for SLQ.

preconditioner AbstractPreconditioner | None

Optional preconditioner. When set, its approximate inverse is passed to lineax CG to accelerate convergence.

Source code in src/gaussx/_strategies/_cg.py
class CGSolver(AbstractSolverStrategy):
    """Iterative CG solver with stochastic log-determinant.

    Uses lineax CG for the linear solve and matfree's stochastic
    Lanczos quadrature (SLQ) for the log-determinant. Suitable
    for large PSD operators where dense factorization is too
    expensive.

    Attributes:
        rtol: Relative tolerance for CG.
        atol: Absolute tolerance for CG.
        max_steps: Maximum CG iterations.
        num_probes: Number of probe vectors for stochastic logdet.
        lanczos_order: Order of the Lanczos decomposition for SLQ.
        preconditioner: Optional preconditioner. When set, its approximate
            inverse is passed to lineax CG to accelerate convergence.
    """

    rtol: float = 1e-5
    atol: float = 1e-5
    max_steps: int = 1000
    num_probes: int = 20
    lanczos_order: int = 30
    preconditioner: AbstractPreconditioner | None = None

    def solve(
        self,
        operator: lx.AbstractLinearOperator,
        vector: Float[Array, " n"],
    ) -> Float[Array, " n"]:
        """Solve ``A x = b`` with conjugate gradients.

        Args:
            operator: A PSD linear operator ``A``.
            vector: Right-hand side ``b``, shape ``(n,)``.

        Returns:
            Solution ``x``, shape ``(n,)``.
        """
        solver = lx.CG(rtol=self.rtol, atol=self.atol, max_steps=self.max_steps)
        options: dict[str, lx.AbstractLinearOperator] = {}
        if self.preconditioner is not None:
            precond_op = self.preconditioner.as_operator(operator)
            if precond_op is not None:
                options["preconditioner"] = precond_op
        return lx.linear_solve(operator, vector, solver, options=options).value

    def logdet(
        self,
        operator: lx.AbstractLinearOperator,
        *,
        key: jax.Array | None = None,
    ) -> Float[Array, ""]:
        """Stochastic log-determinant via Lanczos quadrature.

        Args:
            operator: A PSD linear operator.
            key: PRNG key for probe vector sampling. If None,
                uses ``jax.random.PRNGKey(0)``.

        Returns:
            Scalar estimate of log |det(A)|.
        """
        return SLQLogdet(
            num_probes=self.num_probes,
            lanczos_order=self.lanczos_order,
        ).logdet(operator, key=key)

solve(operator: lx.AbstractLinearOperator, vector: Float[Array, ' n']) -> Float[Array, ' n']

Solve A x = b with conjugate gradients.

Parameters:

Name Type Description Default
operator AbstractLinearOperator

A PSD linear operator A.

required
vector Float[Array, ' n']

Right-hand side b, shape (n,).

required

Returns:

Type Description
Float[Array, ' n']

Solution x, shape (n,).

Source code in src/gaussx/_strategies/_cg.py
def solve(
    self,
    operator: lx.AbstractLinearOperator,
    vector: Float[Array, " n"],
) -> Float[Array, " n"]:
    """Solve ``A x = b`` with conjugate gradients.

    Args:
        operator: A PSD linear operator ``A``.
        vector: Right-hand side ``b``, shape ``(n,)``.

    Returns:
        Solution ``x``, shape ``(n,)``.
    """
    solver = lx.CG(rtol=self.rtol, atol=self.atol, max_steps=self.max_steps)
    options: dict[str, lx.AbstractLinearOperator] = {}
    if self.preconditioner is not None:
        precond_op = self.preconditioner.as_operator(operator)
        if precond_op is not None:
            options["preconditioner"] = precond_op
    return lx.linear_solve(operator, vector, solver, options=options).value

logdet(operator: lx.AbstractLinearOperator, *, key: jax.Array | None = None) -> Float[Array, '']

Stochastic log-determinant via Lanczos quadrature.

Parameters:

Name Type Description Default
operator AbstractLinearOperator

A PSD linear operator.

required
key Array | None

PRNG key for probe vector sampling. If None, uses jax.random.PRNGKey(0).

None

Returns:

Type Description
Float[Array, '']

Scalar estimate of log |det(A)|.

Source code in src/gaussx/_strategies/_cg.py
def logdet(
    self,
    operator: lx.AbstractLinearOperator,
    *,
    key: jax.Array | None = None,
) -> Float[Array, ""]:
    """Stochastic log-determinant via Lanczos quadrature.

    Args:
        operator: A PSD linear operator.
        key: PRNG key for probe vector sampling. If None,
            uses ``jax.random.PRNGKey(0)``.

    Returns:
        Scalar estimate of log |det(A)|.
    """
    return SLQLogdet(
        num_probes=self.num_probes,
        lanczos_order=self.lanczos_order,
    ).logdet(operator, key=key)

PreconditionedCGSolver

Bases: AbstractSolverStrategy

CG solver with pivoted partial Cholesky preconditioner.

Uses matfree's low_rank.cholesky_partial_pivot to build a rank-k preconditioner, then solves (sI + LL^T)^{-1} v via the Woodbury identity inside lineax CG.

For operators of the form K + sigma^2 I, preconditioning dramatically reduces the number of CG iterations.
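
The Woodbury step can be checked in isolation with NumPy. In the sketch below `L` is a random n x k matrix standing in for the partial Cholesky factor, and `woodbury_apply` is a name chosen for this example rather than a gaussx function.

```python
import numpy as np


def woodbury_apply(L, s, v):
    """Apply (s I + L L^T)^{-1} to v using a k x k solve instead of an n x n one."""
    k = L.shape[1]
    small = s * np.eye(k) + L.T @ L  # k x k capacitance matrix
    return (v - L @ np.linalg.solve(small, L.T @ v)) / s


rng = np.random.default_rng(0)
n, k, s = 40, 5, 0.5
L = rng.standard_normal((n, k))
v = rng.standard_normal(n)

fast = woodbury_apply(L, s, v)
dense = np.linalg.solve(s * np.eye(n) + L @ L.T, v)
```

Each application costs O(nk) for the products plus a k x k solve, which is what makes the preconditioner cheap enough to apply at every CG iteration when k is much smaller than n.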

Attributes:

Name Type Description
preconditioner_rank int

Rank of the partial Cholesky. Set to 0 to disable preconditioning (falls back to plain CG).

shift float

Diagonal shift s for the preconditioner. Typically the noise variance sigma^2.

rtol float

Relative tolerance for CG.

atol float

Absolute tolerance for CG.

max_steps int

Maximum CG iterations.

num_probes int

Number of probe vectors for stochastic logdet.

lanczos_order int

Lanczos iterations for SLQ logdet.

seed int

Seed for probe vector generation.

Source code in src/gaussx/_strategies/_precond_cg.py
class PreconditionedCGSolver(AbstractSolverStrategy):
    """CG solver with pivoted partial Cholesky preconditioner.

    Uses matfree's ``low_rank.cholesky_partial_pivot`` to build a
    rank-k preconditioner, then solves ``(sI + LL^T)^{-1} v`` via
    the Woodbury identity inside lineax CG.

    For operators of the form ``K + sigma^2 I``, preconditioning
    dramatically reduces the number of CG iterations.

    Attributes:
        preconditioner_rank: Rank of the partial Cholesky. Set to 0
            to disable preconditioning (falls back to plain CG).
        shift: Diagonal shift ``s`` for the preconditioner.
            Typically the noise variance ``sigma^2``.
        rtol: Relative tolerance for CG.
        atol: Absolute tolerance for CG.
        max_steps: Maximum CG iterations.
        num_probes: Number of probe vectors for stochastic logdet.
        lanczos_order: Lanczos iterations for SLQ logdet.
        seed: Seed for probe vector generation.
    """

    preconditioner_rank: int = 50
    shift: float = 1.0
    rtol: float = 1e-5
    atol: float = 1e-5
    max_steps: int = 1000
    num_probes: int = 20
    lanczos_order: int = 30
    seed: int = 0

    def solve(
        self,
        operator: lx.AbstractLinearOperator,
        vector: Float[Array, " n"],
    ) -> Float[Array, " n"]:
        """Solve A x = b via preconditioned CG.

        Args:
            operator: A PSD linear operator.
            vector: The right-hand side b.

        Returns:
            The solution x.
        """
        preconditioner = PartialCholeskyPreconditioner(
            rank=self.preconditioner_rank, shift=self.shift
        )
        return CGSolver(
            rtol=self.rtol,
            atol=self.atol,
            max_steps=self.max_steps,
            preconditioner=preconditioner,
        ).solve(operator, vector)

    def logdet(
        self,
        operator: lx.AbstractLinearOperator,
        *,
        key: jax.Array | None = None,
    ) -> Float[Array, ""]:
        """Stochastic log-determinant via Lanczos quadrature.

        Args:
            operator: A PSD linear operator.
            key: Optional PRNG key for probe vector sampling, forwarded
                to ``SLQLogdet``.

        Returns:
            Scalar estimate of log |det(A)|.
        """
        return SLQLogdet(
            num_probes=self.num_probes,
            lanczos_order=self.lanczos_order,
            seed=self.seed,
        ).logdet(operator, key=key)

solve(operator: lx.AbstractLinearOperator, vector: Float[Array, ' n']) -> Float[Array, ' n']

Solve A x = b via preconditioned CG.

Parameters:

Name Type Description Default
operator AbstractLinearOperator

A PSD linear operator.

required
vector Float[Array, ' n']

The right-hand side b.

required

Returns:

Type Description
Float[Array, ' n']

The solution x.

Source code in src/gaussx/_strategies/_precond_cg.py
def solve(
    self,
    operator: lx.AbstractLinearOperator,
    vector: Float[Array, " n"],
) -> Float[Array, " n"]:
    """Solve A x = b via preconditioned CG.

    Args:
        operator: A PSD linear operator.
        vector: The right-hand side b.

    Returns:
        The solution x.
    """
    preconditioner = PartialCholeskyPreconditioner(
        rank=self.preconditioner_rank, shift=self.shift
    )
    return CGSolver(
        rtol=self.rtol,
        atol=self.atol,
        max_steps=self.max_steps,
        preconditioner=preconditioner,
    ).solve(operator, vector)

logdet(operator: lx.AbstractLinearOperator, *, key: jax.Array | None = None) -> Float[Array, '']

Stochastic log-determinant via Lanczos quadrature.

Parameters:

Name Type Description Default
operator AbstractLinearOperator

A PSD linear operator.

required
key Array | None

Optional PRNG key for probe vector sampling, forwarded to SLQLogdet.

None

Returns:

Type Description
Float[Array, '']

Scalar estimate of log |det(A)|.

Source code in src/gaussx/_strategies/_precond_cg.py
def logdet(
    self,
    operator: lx.AbstractLinearOperator,
    *,
    key: jax.Array | None = None,
) -> Float[Array, ""]:
    """Stochastic log-determinant via Lanczos quadrature.

    Args:
        operator: A PSD linear operator.
        key: Optional PRNG key for probe vector sampling, forwarded
            to ``SLQLogdet``.

    Returns:
        Scalar estimate of log |det(A)|.
    """
    return SLQLogdet(
        num_probes=self.num_probes,
        lanczos_order=self.lanczos_order,
        seed=self.seed,
    ).logdet(operator, key=key)

MINRESSolver

Bases: AbstractSolverStrategy

MINRES solver for symmetric (possibly indefinite) systems.

Uses the Lanczos-based MINRES algorithm for the linear solve and matfree's stochastic Lanczos quadrature (SLQ) for the log-determinant. Unlike CG, MINRES only requires symmetry — it works on indefinite and singular systems.

Use cases: EP natural parameters, saddle-point systems, Laplace approximation Hessians.

Attributes:

Name Type Description
rtol float

Relative tolerance for MINRES.

atol float

Absolute tolerance for MINRES.

max_steps int

Maximum MINRES iterations.

shift float

Diagonal shift — solves (A + shift * I) x = b.

num_probes int

Number of probe vectors for stochastic logdet.

lanczos_order int

Order of the Lanczos decomposition for SLQ.

Source code in src/gaussx/_strategies/_minres.py
class MINRESSolver(AbstractSolverStrategy):
    """MINRES solver for symmetric (possibly indefinite) systems.

    Uses the Lanczos-based MINRES algorithm for the linear solve
    and matfree's stochastic Lanczos quadrature (SLQ) for the
    log-determinant. Unlike CG, MINRES only requires symmetry —
    it works on indefinite and singular systems.

    Use cases: EP natural parameters, saddle-point systems,
    Laplace approximation Hessians.

    Attributes:
        rtol: Relative tolerance for MINRES.
        atol: Absolute tolerance for MINRES.
        max_steps: Maximum MINRES iterations.
        shift: Diagonal shift — solves ``(A + shift * I) x = b``.
        num_probes: Number of probe vectors for stochastic logdet.
        lanczos_order: Order of the Lanczos decomposition for SLQ.
    """

    rtol: float = 1e-5
    atol: float = 1e-5
    max_steps: int = 1000
    shift: float = 0.0
    num_probes: int = 20
    lanczos_order: int = 30

    def solve(
        self,
        operator: lx.AbstractLinearOperator,
        vector: Float[Array, " n"],
    ) -> Float[Array, " n"]:
        """Solve ``(A + shift I) x = b`` with MINRES.

        Args:
            operator: A symmetric (possibly indefinite) linear operator ``A``.
            vector: Right-hand side ``b``, shape ``(n,)``.

        Returns:
            Solution ``x``, shape ``(n,)``.
        """
        return _minres_solve(
            operator.mv,
            vector,
            rtol=self.rtol,
            atol=self.atol,
            max_steps=self.max_steps,
            shift=self.shift,
        )

    def logdet(
        self,
        operator: lx.AbstractLinearOperator,
        *,
        key: jax.Array | None = None,
    ) -> Float[Array, ""]:
        """Stochastic ``log|det(A + shift I)|`` via Lanczos quadrature.

        Args:
            operator: A symmetric linear operator.
            key: PRNG key for probe vector sampling. If None,
                uses ``jax.random.PRNGKey(0)``.

        Returns:
            Scalar estimate of ``log|det(A + shift I)|``.
        """
        return IndefiniteSLQLogdet(
            num_probes=self.num_probes,
            lanczos_order=self.lanczos_order,
            shift=self.shift,
        ).logdet(operator, key=key)

solve(operator: lx.AbstractLinearOperator, vector: Float[Array, ' n']) -> Float[Array, ' n']

Solve (A + shift I) x = b with MINRES.

Parameters:

Name Type Description Default
operator AbstractLinearOperator

A symmetric (possibly indefinite) linear operator A.

required
vector Float[Array, ' n']

Right-hand side b, shape (n,).

required

Returns:

Type Description
Float[Array, ' n']

Solution x, shape (n,).

Source code in src/gaussx/_strategies/_minres.py
def solve(
    self,
    operator: lx.AbstractLinearOperator,
    vector: Float[Array, " n"],
) -> Float[Array, " n"]:
    """Solve ``(A + shift I) x = b`` with MINRES.

    Args:
        operator: A symmetric (possibly indefinite) linear operator ``A``.
        vector: Right-hand side ``b``, shape ``(n,)``.

    Returns:
        Solution ``x``, shape ``(n,)``.
    """
    return _minres_solve(
        operator.mv,
        vector,
        rtol=self.rtol,
        atol=self.atol,
        max_steps=self.max_steps,
        shift=self.shift,
    )

logdet(operator: lx.AbstractLinearOperator, *, key: jax.Array | None = None) -> Float[Array, '']

Stochastic log|det(A + shift I)| via Lanczos quadrature.

Parameters:

Name Type Description Default
operator AbstractLinearOperator

A symmetric linear operator.

required
key Array | None

PRNG key for probe vector sampling. If None, uses jax.random.PRNGKey(0).

None

Returns:

Type Description
Float[Array, '']

Scalar estimate of log|det(A + shift I)|.

Source code in src/gaussx/_strategies/_minres.py
def logdet(
    self,
    operator: lx.AbstractLinearOperator,
    *,
    key: jax.Array | None = None,
) -> Float[Array, ""]:
    """Stochastic ``log|det(A + shift I)|`` via Lanczos quadrature.

    Args:
        operator: A symmetric linear operator.
        key: PRNG key for probe vector sampling. If None,
            uses ``jax.random.PRNGKey(0)``.

    Returns:
        Scalar estimate of ``log|det(A + shift I)|``.
    """
    return IndefiniteSLQLogdet(
        num_probes=self.num_probes,
        lanczos_order=self.lanczos_order,
        shift=self.shift,
    ).logdet(operator, key=key)
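
To make the "only requires symmetry" point concrete, here is a small NumPy sketch of the residual-minimising Krylov solve that MINRES performs, written in its textbook Lanczos plus least-squares form rather than the short-recurrence form a production solver would use. It is not gaussx's _minres_solve; the function name minres_lanczos and the test matrix are assumptions made for the example.

```python
import numpy as np


def minres_lanczos(matvec, b, shift=0.0, rtol=1e-10, max_steps=None):
    """Solve (A + shift I) x = b for symmetric A by minimising ||b - (A + shift I) x||
    over the growing Krylov subspace (naive MINRES, full reorthogonalisation)."""
    n = b.shape[0]
    max_steps = n if max_steps is None else max_steps
    beta0 = np.linalg.norm(b)
    x = np.zeros_like(b)
    if beta0 == 0.0:
        return x
    V = [b / beta0]                              # orthonormal Lanczos basis
    T = np.zeros((max_steps + 1, max_steps))     # (k+1) x k tridiagonal
    for k in range(max_steps):
        w = matvec(V[k]) + shift * V[k]
        T[k, k] = V[k] @ w
        for v in V:                              # Gram-Schmidt against all of V
            w = w - (v @ w) * v
        beta = np.linalg.norm(w)
        T[k + 1, k] = beta
        if k + 1 < max_steps:
            T[k, k + 1] = beta
        Tk = T[: k + 2, : k + 1]
        rhs = np.zeros(k + 2)
        rhs[0] = beta0
        # Small least-squares problem: min_y || beta0 * e1 - Tk y ||.
        y, *_ = np.linalg.lstsq(Tk, rhs, rcond=None)
        x = np.column_stack(V) @ y
        converged = np.linalg.norm(rhs - Tk @ y) <= rtol * beta0
        if converged or beta <= 1e-12 * beta0:   # second test: invariant subspace
            break
        V.append(w / beta)
    return x


# Symmetric indefinite test matrix with known eigenvalues: A = Q D Q,
# where Q is a Householder reflection (orthogonal and symmetric).
u = np.arange(1.0, 7.0)
Q = np.eye(6) - 2.0 * np.outer(u, u) / (u @ u)
A = Q @ np.diag([-3.0, -1.0, 0.5, 2.0, 4.0, 7.0]) @ Q
b = np.ones(6)

x = minres_lanczos(lambda v: A @ v, b)
x_shifted = minres_lanczos(lambda v: A @ v, b, shift=0.5)
print("residual norm:", np.linalg.norm(A @ x - b))
```

CG applied to this matrix has no convergence guarantee because A has eigenvalues of both signs; minimising the residual norm stays well defined regardless of sign.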

LSMRSolver

Bases: AbstractSolverStrategy

LSMR iterative least-squares solver (Fong & Saunders 2011).

Matrix-free solver that only requires a matvec and a transpose-matvec. Supports Tikhonov regularization via the damp parameter: it minimizes ||Ax - b||^2 + damp^2 ||x||^2.

Suitable for rectangular, ill-conditioned, or regularized systems.

The undamped path delegates to lineax.LSMR (with lineax's implicit-differentiation rules). lineax's LSMR has no Tikhonov damp parameter, so damped solves use matfree's LSMR, which has a custom VJP for memory-efficient backpropagation.
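
The damped objective can be checked independently of any iterative solver. The NumPy sketch below (illustrative only; damped_lstsq is a hypothetical helper, not part of gaussx, lineax, or matfree) uses the standard fact that minimising ||Ax - b||^2 + damp^2 ||x||^2 is an ordinary least-squares problem on the augmented system [A; damp·I] x = [b; 0], with normal equations (AᵀA + damp^2 I) x = Aᵀb.

```python
import numpy as np


def damped_lstsq(A, b, damp):
    """Minimise ||A x - b||^2 + damp^2 ||x||^2 via the augmented least-squares system."""
    n = A.shape[1]
    A_aug = np.vstack([A, damp * np.eye(n)])          # stack damp * I under A
    b_aug = np.concatenate([b, np.zeros(n)])          # pad the right-hand side with zeros
    x, *_ = np.linalg.lstsq(A_aug, b_aug, rcond=None)
    return x


A = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])    # rectangular, 3 x 2
b = np.array([1.0, 0.0, 1.0])
damp = 0.5

x_damped = damped_lstsq(A, b, damp)
x_normal = np.linalg.solve(A.T @ A + damp**2 * np.eye(2), A.T @ b)
x_plain, *_ = np.linalg.lstsq(A, b, rcond=None)       # damp = 0 for comparison
print("damped:", x_damped, "undamped:", x_plain)
```

Damping shrinks the solution norm, which is why it helps on ill-conditioned systems.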

Attributes:

Name Type Description
atol float

Absolute tolerance.

btol float

Relative tolerance on the residual.

ctol float

Condition number tolerance (the lineax path uses conlim = 1 / ctol).

maxiter int

Maximum iterations.

damp float

Tikhonov damping parameter.

num_probes int

Number of probe vectors for stochastic logdet.

lanczos_order int

Lanczos iterations for SLQ logdet.

seed int

Seed for probe vector generation.

Source code in src/gaussx/_strategies/_lsmr.py
class LSMRSolver(AbstractSolverStrategy):
    """LSMR iterative least-squares solver (Fong & Saunders 2011).

    Matrix-free solver that only requires matvec and transpose-matvec.
    Supports Tikhonov regularization via the ``damp`` parameter:
    minimizes ``||Ax - b||^2 + damp^2 ||x||^2``.

    Suitable for rectangular, ill-conditioned, or regularized systems.

    The undamped path delegates to `lineax.LSMR` (with lineax's
    implicit-differentiation rules). lineax's LSMR has no Tikhonov
    ``damp`` parameter, so damped solves use matfree's LSMR, which has
    a custom VJP for memory-efficient backpropagation.

    Attributes:
        atol: Absolute tolerance.
        btol: Relative tolerance on the residual.
        ctol: Condition number tolerance (the lineax path uses
            ``conlim = 1 / ctol``).
        maxiter: Maximum iterations.
        damp: Tikhonov damping parameter.
        num_probes: Number of probe vectors for stochastic logdet.
        lanczos_order: Lanczos iterations for SLQ logdet.
        seed: Seed for probe vector generation.
    """

    atol: float = 1e-6
    btol: float = 1e-6
    ctol: float = 1e-6
    maxiter: int = 1000
    damp: float = 0.0
    num_probes: int = 20
    lanczos_order: int = 30
    seed: int = 0

    def solve(
        self,
        operator: lx.AbstractLinearOperator,
        vector: Float[Array, " m"],
    ) -> Float[Array, " n"]:
        """Solve A x = b via LSMR.

        Args:
            operator: A linear operator (may be rectangular).
            vector: The right-hand side b.

        Returns:
            The (least-squares) solution x.
        """
        if self.damp == 0.0:
            solver = lx.LSMR(
                rtol=self.btol,
                atol=self.atol,
                max_steps=self.maxiter,
                conlim=1.0 / self.ctol if self.ctol > 0 else 1e8,
            )
            return lx.linear_solve(operator, vector, solver).value

        # Tikhonov damping: not supported by lineax's LSMR, so the
        # damped path stays on matfree.
        lsmr_fn = matfree.lstsq.lsmr(
            atol=self.atol,
            btol=self.btol,
            ctol=self.ctol,
            maxiter=self.maxiter,
        )

        def vecmat(v):
            return operator.T.mv(v)

        result = lsmr_fn(vecmat, vector, damp=self.damp)
        return result[0]

    def logdet(
        self,
        operator: lx.AbstractLinearOperator,
        *,
        key: jax.Array | None = None,
    ) -> Float[Array, ""]:
        """Stochastic log-determinant via Lanczos quadrature.

        Args:
            operator: A PSD linear operator.
            key: PRNG key for probe vector sampling. If ``None``,
                uses ``jax.random.PRNGKey(self.seed)``.

        Returns:
            Scalar estimate of log |det(A)|.
        """
        return SLQLogdet(
            num_probes=self.num_probes,
            lanczos_order=self.lanczos_order,
            seed=self.seed,
        ).logdet(operator, key=key)

solve(operator: lx.AbstractLinearOperator, vector: Float[Array, ' m']) -> Float[Array, ' n']

Solve A x = b via LSMR.

Parameters:

Name Type Description Default
operator AbstractLinearOperator

A linear operator (may be rectangular).

required
vector Float[Array, ' m']

The right-hand side b.

required

Returns:

Type Description
Float[Array, ' n']

The (least-squares) solution x.

Source code in src/gaussx/_strategies/_lsmr.py
def solve(
    self,
    operator: lx.AbstractLinearOperator,
    vector: Float[Array, " m"],
) -> Float[Array, " n"]:
    """Solve A x = b via LSMR.

    Args:
        operator: A linear operator (may be rectangular).
        vector: The right-hand side b.

    Returns:
        The (least-squares) solution x.
    """
    if self.damp == 0.0:
        solver = lx.LSMR(
            rtol=self.btol,
            atol=self.atol,
            max_steps=self.maxiter,
            conlim=1.0 / self.ctol if self.ctol > 0 else 1e8,
        )
        return lx.linear_solve(operator, vector, solver).value

    # Tikhonov damping: not supported by lineax's LSMR, so the
    # damped path stays on matfree.
    lsmr_fn = matfree.lstsq.lsmr(
        atol=self.atol,
        btol=self.btol,
        ctol=self.ctol,
        maxiter=self.maxiter,
    )

    def vecmat(v):
        return operator.T.mv(v)

    result = lsmr_fn(vecmat, vector, damp=self.damp)
    return result[0]

logdet(operator: lx.AbstractLinearOperator, *, key: jax.Array | None = None) -> Float[Array, '']

Stochastic log-determinant via Lanczos quadrature.

Parameters:

Name Type Description Default
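operator AbstractLinearOperator

A PSD linear operator.

required
key Array | None

PRNG key for probe vector sampling. If None, uses jax.random.PRNGKey(self.seed).

None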

Returns:

Type Description
Float[Array, '']

Scalar estimate of log |det(A)|.

Source code in src/gaussx/_strategies/_lsmr.py
def logdet(
    self,
    operator: lx.AbstractLinearOperator,
    *,
    key: jax.Array | None = None,
) -> Float[Array, ""]:
    """Stochastic log-determinant via Lanczos quadrature.

    Args:
        operator: A PSD linear operator.
        key: PRNG key for probe vector sampling. If ``None``,
            uses ``jax.random.PRNGKey(self.seed)``.

    Returns:
        Scalar estimate of log |det(A)|.
    """
    return SLQLogdet(
        num_probes=self.num_probes,
        lanczos_order=self.lanczos_order,
        seed=self.seed,
    ).logdet(operator, key=key)

BBMMSolver

Bases: AbstractSolverStrategy

Black-Box Matrix-Matrix solver (Gardner et al. 2018).

Simultaneously solves multiple RHS and computes logdet via modified batched CG (mBCG). Amortizes matvecs across solve and logdet.

Solve: CG via lineax on each RHS column. Logdet: Stochastic Lanczos Quadrature via matfree.

When no key is passed, probe vectors are derived from seed at call time. This makes logdet (called without a key) and solve_and_logdet deterministic functions of the operator, so no PRNG key is needed at call time.

Attributes:

Name Type Description
cg_max_iter int

Maximum CG iterations.

cg_tolerance float

Relative tolerance for CG.

lanczos_iter int

Lanczos iterations for SLQ.

num_probes int

Number of probe vectors for the Hutchinson trace estimator.

seed int

Seed for probe vector generation.

Source code in src/gaussx/_strategies/_bbmm.py
class BBMMSolver(AbstractSolverStrategy):
    """Black-Box Matrix-Matrix solver (Gardner et al. 2018).

    Simultaneously solves multiple RHS and computes logdet via
    modified batched CG (mBCG). Amortizes matvecs across solve
    and logdet.

    Solve: CG via lineax on each RHS column.
    Logdet: Stochastic Lanczos Quadrature via matfree.

    When no ``key`` is passed, probe vectors are derived from ``seed``
    at call time. This makes ``logdet`` (called without a key) and
    ``solve_and_logdet`` deterministic functions of the operator, so
    no PRNG key is needed at call time.

    Attributes:
        cg_max_iter: Maximum CG iterations.
        cg_tolerance: Relative tolerance for CG.
        lanczos_iter: Lanczos iterations for SLQ.
        num_probes: Number of probe vectors for the Hutchinson trace estimator.
        seed: Seed for probe vector generation.
    """

    cg_max_iter: int = 1000
    cg_tolerance: float = 1e-4
    lanczos_iter: int = 100
    num_probes: int = 10
    seed: int = 0

    def solve(
        self,
        operator: lx.AbstractLinearOperator,
        vector: Float[Array, " n"],
    ) -> Float[Array, " n"]:
        """Solve A x = b via CG.

        Args:
            operator: A PSD linear operator.
            vector: The right-hand side b.

        Returns:
            The solution x.
        """
        solver = lx.CG(
            rtol=self.cg_tolerance,
            atol=self.cg_tolerance,
            max_steps=self.cg_max_iter,
        )
        return lx.linear_solve(operator, vector, solver).value

    def logdet(
        self,
        operator: lx.AbstractLinearOperator,
        *,
        key: jax.Array | None = None,
    ) -> Float[Array, ""]:
        """Stochastic log-determinant via Lanczos quadrature.

        Probe vectors are generated deterministically from ``self.seed``.

        Args:
            operator: A PSD linear operator.
            key: PRNG key for probe vector sampling. If ``None``,
                uses ``jax.random.PRNGKey(self.seed)``.

        Returns:
            Scalar estimate of log |det(A)|.
        """
        return SLQLogdet(
            num_probes=self.num_probes,
            lanczos_order=self.lanczos_iter,
            seed=self.seed,
        ).logdet(operator, key=key)

    def solve_and_logdet(
        self,
        operator: lx.AbstractLinearOperator,
        vector: Float[Array, " n"],
    ) -> tuple[Float[Array, " n"], Float[Array, ""]]:
        """Joint solve + logdet.

        Computes both solve(A, b) and logdet(A) sharing the
        operator's matvec calls where possible.

        Args:
            operator: A PSD linear operator.
            vector: The right-hand side b.

        Returns:
            Tuple of (solution, log_determinant).
        """
        return self.solve(operator, vector), self.logdet(operator)

solve(operator: lx.AbstractLinearOperator, vector: Float[Array, ' n']) -> Float[Array, ' n']

Solve A x = b via CG.

Parameters:

Name Type Description Default
operator AbstractLinearOperator

A PSD linear operator.

required
vector Float[Array, ' n']

The right-hand side b.

required

Returns:

Type Description
Float[Array, ' n']

The solution x.

Source code in src/gaussx/_strategies/_bbmm.py
def solve(
    self,
    operator: lx.AbstractLinearOperator,
    vector: Float[Array, " n"],
) -> Float[Array, " n"]:
    """Solve A x = b via CG.

    Args:
        operator: A PSD linear operator.
        vector: The right-hand side b.

    Returns:
        The solution x.
    """
    solver = lx.CG(
        rtol=self.cg_tolerance,
        atol=self.cg_tolerance,
        max_steps=self.cg_max_iter,
    )
    return lx.linear_solve(operator, vector, solver).value

logdet(operator: lx.AbstractLinearOperator, *, key: jax.Array | None = None) -> Float[Array, '']

Stochastic log-determinant via Lanczos quadrature.

Probe vectors are generated deterministically from self.seed.

Parameters:

Name Type Description Default
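operator AbstractLinearOperator

A PSD linear operator.

required
key Array | None

PRNG key for probe vector sampling. If None, uses jax.random.PRNGKey(self.seed).

None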

Returns:

Type Description
Float[Array, '']

Scalar estimate of log |det(A)|.

Source code in src/gaussx/_strategies/_bbmm.py
def logdet(
    self,
    operator: lx.AbstractLinearOperator,
    *,
    key: jax.Array | None = None,
) -> Float[Array, ""]:
    """Stochastic log-determinant via Lanczos quadrature.

    Probe vectors are generated deterministically from ``self.seed``.

    Args:
        operator: A PSD linear operator.
        key: PRNG key for probe vector sampling. If ``None``,
            uses ``jax.random.PRNGKey(self.seed)``.

    Returns:
        Scalar estimate of log |det(A)|.
    """
    return SLQLogdet(
        num_probes=self.num_probes,
        lanczos_order=self.lanczos_iter,
        seed=self.seed,
    ).logdet(operator, key=key)

solve_and_logdet(operator: lx.AbstractLinearOperator, vector: Float[Array, ' n']) -> tuple[Float[Array, ' n'], Float[Array, '']]

Joint solve + logdet.

Computes both solve(A, b) and logdet(A) sharing the operator's matvec calls where possible.

Parameters:

Name Type Description Default
operator AbstractLinearOperator

A PSD linear operator.

required
vector Float[Array, ' n']

The right-hand side b.

required

Returns:

Type Description
tuple[Float[Array, ' n'], Float[Array, '']]

Tuple of (solution, log_determinant).

Source code in src/gaussx/_strategies/_bbmm.py
def solve_and_logdet(
    self,
    operator: lx.AbstractLinearOperator,
    vector: Float[Array, " n"],
) -> tuple[Float[Array, " n"], Float[Array, ""]]:
    """Joint solve + logdet.

    Computes both solve(A, b) and logdet(A) sharing the
    operator's matvec calls where possible.

    Args:
        operator: A PSD linear operator.
        vector: The right-hand side b.

    Returns:
        Tuple of (solution, log_determinant).
    """
    return self.solve(operator, vector), self.logdet(operator)
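
The solve_and_logdet shown above calls solve and logdet one after the other. The matvec-sharing idea behind mBCG that the class description refers to can be sketched in NumPy as follows: the step sizes CG already computes are enough to rebuild the Lanczos tridiagonal matrix T for the same starting vector, and T in turn gives a quadrature estimate of bᵀ log(A) b with no extra matvecs. This is an illustration of the technique under the standard CG/Lanczos relation, not gaussx's code; cg_with_tridiag and quadform_log are hypothetical names, and in mBCG the same recurrence is run on random probe columns rather than on b.

```python
import numpy as np


def cg_with_tridiag(matvec, b, num_steps):
    """Run CG from x0 = 0 and rebuild the Lanczos tridiagonal from its scalars."""
    x = np.zeros_like(b)
    r = b.copy()
    p = r.copy()
    rr = r @ r
    alphas, betas = [], []
    for _ in range(num_steps):
        Ap = matvec(p)                       # the only matvec in each iteration
        alpha = rr / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rr_new = r @ r
        beta = rr_new / rr
        alphas.append(alpha)
        betas.append(beta)
        p = r + beta * p
        rr = rr_new
    T = np.zeros((num_steps, num_steps))
    for j in range(num_steps):
        T[j, j] = 1.0 / alphas[j]
        if j > 0:
            T[j, j] += betas[j - 1] / alphas[j - 1]
            off = np.sqrt(betas[j - 1]) / alphas[j - 1]
            T[j, j - 1] = T[j - 1, j] = off
    return x, T


def quadform_log(T, b_norm_sq):
    """Gauss quadrature estimate of b^T log(A) b from the tridiagonal T."""
    theta, U = np.linalg.eigh(T)
    return b_norm_sq * np.sum(U[0, :] ** 2 * np.log(theta))


# SPD test matrix with eigenvalues 1..6 (Householder reflection as eigenbasis).
u = np.arange(1.0, 7.0)
Q = np.eye(6) - 2.0 * np.outer(u, u) / (u @ u)
A = Q @ np.diag(np.arange(1.0, 7.0)) @ Q
b = np.ones(6)

x, T = cg_with_tridiag(lambda v: A @ v, b, num_steps=6)
quad = quadform_log(T, b @ b)

lam, E = np.linalg.eigh(A)
exact = np.sum((E.T @ b) ** 2 * np.log(lam))   # b^T log(A) b, computed densely
print("quadrature:", quad, "exact:", exact)
```

Averaging such quadratic forms over random probes gives the stochastic logdet, so one batched CG run can serve both the solve and the logdet estimate.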

ComposedSolver

Bases: AbstractSolverStrategy

Mix-and-match solve and logdet from different strategies.

This lets you pair, e.g., an exact dense solve with a stochastic log-determinant estimator, or an iterative CG solve with a closed-form Kronecker log-determinant.

Accepts either fine-grained protocols (AbstractSolveStrategy, AbstractLogdetStrategy) or full solver strategies.

Attributes:

Name Type Description
solve_strategy AbstractSolveStrategy

Strategy whose .solve() method will be used.

logdet_strategy AbstractLogdetStrategy

Strategy whose .logdet() method will be used.

Examples:

solver = ComposedSolver(
    solve_strategy=DenseSolver(),
    logdet_strategy=SLQLogdet(num_probes=50, lanczos_order=30),
)

Source code in src/gaussx/_strategies/_composed.py
class ComposedSolver(AbstractSolverStrategy):
    """Mix-and-match solve and logdet from different strategies.

    This lets you pair, e.g., an exact dense solve with a stochastic
    log-determinant estimator, or an iterative CG solve with a
    closed-form Kronecker log-determinant.

    Accepts either fine-grained protocols (`AbstractSolveStrategy`,
    `AbstractLogdetStrategy`) or full solver strategies.

    Attributes:
        solve_strategy: Strategy whose ``.solve()`` method will be used.
        logdet_strategy: Strategy whose ``.logdet()`` method will be used.

    Examples:

        solver = ComposedSolver(
            solve_strategy=DenseSolver(),
            logdet_strategy=SLQLogdet(num_probes=50, lanczos_order=30),
        )
    """

    solve_strategy: AbstractSolveStrategy
    logdet_strategy: AbstractLogdetStrategy

    def solve(
        self,
        operator: lx.AbstractLinearOperator,
        vector: Float[Array, " n"],
    ) -> Float[Array, " n"]:
        """Solve ``A x = b`` by delegating to ``solve_strategy``.

        Args:
            operator: Linear operator ``A``.
            vector: Right-hand side ``b``, shape ``(n,)``.

        Returns:
            Solution ``x``, shape ``(n,)``.
        """
        return self.solve_strategy.solve(operator, vector)

    def logdet(
        self,
        operator: lx.AbstractLinearOperator,
        *,
        key: jax.Array | None = None,
    ) -> Float[Array, ""]:
        """Compute ``log|det(A)|`` by delegating to ``logdet_strategy``.

        Args:
            operator: Linear operator ``A``.
            key: Optional PRNG key forwarded to stochastic estimators.

        Returns:
            Scalar log-determinant.
        """
        return self.logdet_strategy.logdet(operator, key=key)

solve(operator: lx.AbstractLinearOperator, vector: Float[Array, ' n']) -> Float[Array, ' n']

Solve A x = b by delegating to solve_strategy.

Parameters:

Name Type Description Default
operator AbstractLinearOperator

Linear operator A.

required
vector Float[Array, ' n']

Right-hand side b, shape (n,).

required

Returns:

Type Description
Float[Array, ' n']

Solution x, shape (n,).

Source code in src/gaussx/_strategies/_composed.py
def solve(
    self,
    operator: lx.AbstractLinearOperator,
    vector: Float[Array, " n"],
) -> Float[Array, " n"]:
    """Solve ``A x = b`` by delegating to ``solve_strategy``.

    Args:
        operator: Linear operator ``A``.
        vector: Right-hand side ``b``, shape ``(n,)``.

    Returns:
        Solution ``x``, shape ``(n,)``.
    """
    return self.solve_strategy.solve(operator, vector)

logdet(operator: lx.AbstractLinearOperator, *, key: jax.Array | None = None) -> Float[Array, '']

Compute log|det(A)| by delegating to logdet_strategy.

Parameters:

Name Type Description Default
operator AbstractLinearOperator

Linear operator A.

required
key Array | None

Optional PRNG key forwarded to stochastic estimators.

None

Returns:

Type Description
Float[Array, '']

Scalar log-determinant.

Source code in src/gaussx/_strategies/_composed.py
def logdet(
    self,
    operator: lx.AbstractLinearOperator,
    *,
    key: jax.Array | None = None,
) -> Float[Array, ""]:
    """Compute ``log|det(A)|`` by delegating to ``logdet_strategy``.

    Args:
        operator: Linear operator ``A``.
        key: Optional PRNG key forwarded to stochastic estimators.

    Returns:
        Scalar log-determinant.
    """
    return self.logdet_strategy.logdet(operator, key=key)
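
The delegation pattern is simple enough to show without gaussx at all. The sketch below uses plain NumPy and hypothetical stand-in classes (DenseSolve, CholeskyLogdet, Composed) in place of the real strategy types; it only illustrates how a consumer that needs both a solve and a logdet, here a Gaussian log-density, stays agnostic about which algorithm provides each.

```python
from dataclasses import dataclass

import numpy as np


class DenseSolve:
    """Hypothetical stand-in for a solve strategy."""

    def solve(self, A, b):
        return np.linalg.solve(A, b)


class CholeskyLogdet:
    """Hypothetical stand-in for a logdet strategy (requires SPD input)."""

    def logdet(self, A, *, key=None):
        # key is accepted for interface compatibility and ignored (deterministic).
        return 2.0 * np.sum(np.log(np.diag(np.linalg.cholesky(A))))


@dataclass(frozen=True)
class Composed:
    """Pairs any object with .solve() with any object with .logdet()."""

    solve_strategy: object
    logdet_strategy: object

    def solve(self, A, b):
        return self.solve_strategy.solve(A, b)

    def logdet(self, A, *, key=None):
        return self.logdet_strategy.logdet(A, key=key)


def gaussian_logpdf(residual, cov, solver):
    """log N(residual | 0, cov), written against the solver interface only."""
    n = residual.shape[0]
    quad = residual @ solver.solve(cov, residual)
    return -0.5 * (quad + solver.logdet(cov) + n * np.log(2.0 * np.pi))


cov = np.array([[2.0, 0.5], [0.5, 1.0]])
residual = np.array([1.0, -1.0])
solver = Composed(solve_strategy=DenseSolve(), logdet_strategy=CholeskyLogdet())
logp = gaussian_logpdf(residual, cov, solver)
print("log density:", logp)
```

Swapping either strategy changes the numerical method without touching gaussian_logpdf.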

Logdet strategies

DenseLogdet gives exact log-determinants through structural dispatch (with a dense fallback); stochastic Lanczos quadrature (SLQ) gives estimates on large PSD (or symmetric-indefinite) operators at a cost of num_probes × lanczos_order matvecs.

DenseLogdet

Bases: AbstractLogdetStrategy

Dense log-determinant via gaussx structural dispatch.

Delegates to gaussx.logdet which automatically selects the best algorithm based on operator structure (Diagonal, BlockDiag, Kronecker, LowRankUpdate, or dense fallback).

Source code in src/gaussx/_strategies/_slq_logdet.py
class DenseLogdet(AbstractLogdetStrategy):
    """Dense log-determinant via gaussx structural dispatch.

    Delegates to `gaussx.logdet` which automatically selects
    the best algorithm based on operator structure (Diagonal,
    BlockDiag, Kronecker, LowRankUpdate, or dense fallback).
    """

    def logdet(
        self,
        operator: lx.AbstractLinearOperator,
        *,
        key: jax.Array | None = None,
    ) -> Float[Array, ""]:
        """Compute ``log |det(A)|`` via structural dispatch.

        Args:
            operator: A linear operator.
            key: Ignored (deterministic).

        Returns:
            Scalar ``log |det(A)|``.
        """
        from gaussx._primitives._logdet import logdet as _logdet

        return _logdet(operator)

logdet(operator: lx.AbstractLinearOperator, *, key: jax.Array | None = None) -> Float[Array, '']

Compute log |det(A)| via structural dispatch.

Parameters:

Name Type Description Default
operator AbstractLinearOperator

A linear operator.

required
key Array | None

Ignored (deterministic).

None

Returns:

Type Description
Float[Array, '']

Scalar log |det(A)|.

Source code in src/gaussx/_strategies/_slq_logdet.py
def logdet(
    self,
    operator: lx.AbstractLinearOperator,
    *,
    key: jax.Array | None = None,
) -> Float[Array, ""]:
    """Compute ``log |det(A)|`` via structural dispatch.

    Args:
        operator: A linear operator.
        key: Ignored (deterministic).

    Returns:
        Scalar ``log |det(A)|``.
    """
    from gaussx._primitives._logdet import logdet as _logdet

    return _logdet(operator)
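
Why structural dispatch pays off can be seen from the Kronecker case. For A of size n × n and B of size m × m, det(A ⊗ B) = det(A)^m · det(B)^n, so the log-determinant of the nm × nm product needs only the two small factors. The NumPy check below illustrates that identity; whether gaussx.logdet uses exactly this formula internally is an assumption, and kron_logdet is a hypothetical helper.

```python
import numpy as np


def kron_logdet(A, B):
    """log|det(A kron B)| from the factors: m * log|det A| + n * log|det B|."""
    n, m = A.shape[0], B.shape[0]
    return m * np.linalg.slogdet(A)[1] + n * np.linalg.slogdet(B)[1]


A = np.array([[2.0, 0.5], [0.5, 1.0]])                              # n = 2
B = np.array([[3.0, 1.0, 0.0], [1.0, 2.0, 0.5], [0.0, 0.5, 1.5]])   # m = 3

structured = kron_logdet(A, B)
dense = np.linalg.slogdet(np.kron(A, B))[1]    # forms the full 6 x 6 matrix
print("structured:", structured, "dense:", dense)
```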

SLQLogdet

Bases: AbstractLogdetStrategy

Stochastic log-determinant via Lanczos quadrature (SLQ).

Estimates log det(A) for PSD operators using stochastic trace estimation: logdet(A) = tr(log(A)). Uses matfree's Lanczos decomposition with sign-flip ("Rademacher") probe vectors by default.

Attributes:

Name Type Description
num_probes int

Number of probe vectors for the Hutchinson estimator.

lanczos_order int

Order of the Lanczos decomposition.

seed int

Seed for probe vector generation (used when no key is passed to logdet).

sampler SamplerName

Probe distribution ("signs", "normal", "sphere").

Source code in src/gaussx/_strategies/_slq_logdet.py
class SLQLogdet(AbstractLogdetStrategy):
    """Stochastic log-determinant via Lanczos quadrature (SLQ).

    Estimates ``log det(A)`` for PSD operators using stochastic
    trace estimation: ``logdet(A) = tr(log(A))``.  Uses matfree's
    Lanczos decomposition with sign-flip ("Rademacher") probe vectors
    by default.

    Attributes:
        num_probes: Number of probe vectors for the Hutchinson estimator.
        lanczos_order: Order of the Lanczos decomposition.
        seed: Seed for probe vector generation (used when no
            ``key`` is passed to `logdet`).
        sampler: Probe distribution (``"signs"``, ``"normal"``,
            ``"sphere"``).
    """

    num_probes: int = 20
    lanczos_order: int = 30
    seed: int = 0
    sampler: SamplerName = "signs"

    def _integrand(self, n: int):
        order = min(self.lanczos_order, n)
        tridiag = matfree.decomp.tridiag_sym(order, reortho="full")
        return matfree.funm.monte_carlo_funm_sym_logdet(tridiag)

    def logdet(
        self,
        operator: lx.AbstractLinearOperator,
        *,
        key: jax.Array | None = None,
    ) -> Float[Array, ""]:
        """Stochastic ``log det(A)`` via Lanczos quadrature.

        Args:
            operator: A PSD linear operator.
            key: PRNG key for probe vector sampling.  If ``None``,
                uses ``jax.random.PRNGKey(self.seed)``.

        Returns:
            Scalar estimate of ``log det(A)``.
        """
        if key is None:
            key = jax.random.PRNGKey(self.seed)

        n = operator.in_size()
        point, _ = _slq_estimators(self._integrand(n), n, self.num_probes, self.sampler)
        return point(operator.mv, key)

    def logdet_and_error(
        self,
        operator: lx.AbstractLinearOperator,
        *,
        key: jax.Array | None = None,
    ) -> tuple[Float[Array, ""], Float[Array, ""]]:
        """Stochastic ``log det(A)`` with its standard error.

        Args:
            operator: A PSD linear operator.
            key: PRNG key for probe vector sampling.  If ``None``,
                uses ``jax.random.PRNGKey(self.seed)``.

        Returns:
            Tuple ``(estimate, standard_error)`` where the standard
            error is the standard error of the mean across probes.
        """
        if key is None:
            key = jax.random.PRNGKey(self.seed)

        n = operator.in_size()
        _, with_sem = _slq_estimators(
            self._integrand(n), n, self.num_probes, self.sampler
        )
        return with_sem(operator.mv, key)

logdet(operator: lx.AbstractLinearOperator, *, key: jax.Array | None = None) -> Float[Array, '']

Stochastic log det(A) via Lanczos quadrature.

Parameters:

Name Type Description Default
operator AbstractLinearOperator

A PSD linear operator.

required
key Array | None

PRNG key for probe vector sampling. If None, uses jax.random.PRNGKey(self.seed).

None

Returns:

Type Description
Float[Array, '']

Scalar estimate of log det(A).

Source code in src/gaussx/_strategies/_slq_logdet.py
def logdet(
    self,
    operator: lx.AbstractLinearOperator,
    *,
    key: jax.Array | None = None,
) -> Float[Array, ""]:
    """Stochastic ``log det(A)`` via Lanczos quadrature.

    Args:
        operator: A PSD linear operator.
        key: PRNG key for probe vector sampling.  If ``None``,
            uses ``jax.random.PRNGKey(self.seed)``.

    Returns:
        Scalar estimate of ``log det(A)``.
    """
    if key is None:
        key = jax.random.PRNGKey(self.seed)

    n = operator.in_size()
    point, _ = _slq_estimators(self._integrand(n), n, self.num_probes, self.sampler)
    return point(operator.mv, key)
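
For readers who want to see what the estimator computes, here is a minimal NumPy sketch of stochastic Lanczos quadrature. It does not use matfree and is not the implementation above; lanczos_tridiag and slq_logdet are hypothetical names. Each Rademacher probe z starts a Lanczos run, and the eigen-decomposition of the resulting tridiagonal T gives a Gauss quadrature estimate of zᵀ log(A) z; the mean over probes estimates tr(log(A)) = log det(A).

```python
import numpy as np


def lanczos_tridiag(matvec, v0, order):
    """Lanczos with full reorthogonalisation; returns the tridiagonal matrix T."""
    V = [v0 / np.linalg.norm(v0)]
    alphas, betas = [], []
    for k in range(order):
        w = matvec(V[k])
        alphas.append(V[k] @ w)
        for v in V:                          # orthogonalise against every basis vector
            w = w - (v @ w) * v
        beta = np.linalg.norm(w)
        if k == order - 1 or beta < 1e-12:   # done, or hit an invariant subspace
            break
        betas.append(beta)
        V.append(w / beta)
    return np.diag(alphas) + np.diag(betas, 1) + np.diag(betas, -1)


def slq_logdet(matvec, n, num_probes=20, lanczos_order=30, seed=0):
    """Estimate log det(A) for an SPD operator given only its matvec."""
    rng = np.random.default_rng(seed)
    order = min(lanczos_order, n)
    samples = []
    for _ in range(num_probes):
        z = rng.choice([-1.0, 1.0], size=n)  # sign-flip probe, ||z||^2 = n
        T = lanczos_tridiag(matvec, z, order)
        theta, U = np.linalg.eigh(T)
        # Quadrature nodes theta, weights U[0, :]**2, scaled by ||z||^2.
        samples.append(n * np.sum(U[0, :] ** 2 * np.log(theta)))
    return float(np.mean(samples))


# Small dense SPD example; the estimate is stochastic, so it only approximates
# the exact value.
rng = np.random.default_rng(1)
M = rng.standard_normal((50, 50))
A = M @ M.T / 50 + np.eye(50)
estimate = slq_logdet(lambda v: A @ v, 50)
print("SLQ estimate:", estimate, "exact:", np.linalg.slogdet(A)[1])
```

The total cost is num_probes × lanczos_order matvecs, which is what makes the method attractive when A is only available as a matvec.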

logdet_and_error(operator: lx.AbstractLinearOperator, *, key: jax.Array | None = None) -> tuple[Float[Array, ''], Float[Array, '']]

Stochastic log det(A) with its standard error.

Parameters:

Name Type Description Default
operator AbstractLinearOperator

A PSD linear operator.

required
key Array | None

PRNG key for probe vector sampling. If None, uses jax.random.PRNGKey(self.seed).

None

Returns:

Type Description
tuple[Float[Array, ''], Float[Array, '']]

Tuple (estimate, standard_error), where standard_error is the standard error of the mean across probes.

Source code in src/gaussx/_strategies/_slq_logdet.py
def logdet_and_error(
    self,
    operator: lx.AbstractLinearOperator,
    *,
    key: jax.Array | None = None,
) -> tuple[Float[Array, ""], Float[Array, ""]]:
    """Stochastic ``log det(A)`` with its standard error.

    Args:
        operator: A PSD linear operator.
        key: PRNG key for probe vector sampling.  If ``None``,
            uses ``jax.random.PRNGKey(self.seed)``.

    Returns:
        Tuple ``(estimate, standard_error)`` where the standard
        error is the standard error of the mean across probes.
    """
    if key is None:
        key = jax.random.PRNGKey(self.seed)

    n = operator.in_size()
    _, with_sem = _slq_estimators(
        self._integrand(n), n, self.num_probes, self.sampler
    )
    return with_sem(operator.mv, key)
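
The standard error is what tells you whether num_probes is large enough. The sketch below shows how a standard error of the mean is typically formed from per-probe samples; it evaluates zᵀ log(A) z exactly through a dense eigendecomposition, so only the probe noise remains. The function name and the use of the sample standard deviation with ddof=1 are assumptions for the example, not a statement about how gaussx computes its value.

```python
import numpy as np


def hutchinson_logdet_with_sem(A, num_probes=20, seed=0):
    """Hutchinson estimate of log det(A) for SPD A, with its standard error."""
    rng = np.random.default_rng(seed)
    lam, Q = np.linalg.eigh(A)
    log_A = (Q * np.log(lam)) @ Q.T                    # dense log(A), illustration only
    Z = rng.choice([-1.0, 1.0], size=(num_probes, A.shape[0]))
    samples = np.einsum("pi,ij,pj->p", Z, log_A, Z)    # z^T log(A) z for each probe
    estimate = samples.mean()
    sem = samples.std(ddof=1) / np.sqrt(num_probes)    # standard error of the mean
    return estimate, sem


rng = np.random.default_rng(2)
M = rng.standard_normal((40, 40))
A = M @ M.T / 40 + np.eye(40)

for probes in (10, 40, 160):
    est, sem = hutchinson_logdet_with_sem(A, num_probes=probes)
    print(probes, "probes:", est, "+/-", sem)
print("exact:", np.linalg.slogdet(A)[1])
```

Because the error shrinks like 1/sqrt(num_probes), halving it costs roughly four times as many probes.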

IndefiniteSLQLogdet

Bases: AbstractLogdetStrategy

Stochastic log|det(A)| for symmetric (possibly indefinite) operators.

Like SLQLogdet but uses log(|lambda|) as the matrix function, so it works on indefinite and negative-definite matrices. Supports a diagonal shift (A + shift * I).

Attributes:

Name Type Description
num_probes int

Number of probe vectors for the Hutchinson estimator.

lanczos_order int

Order of the Lanczos decomposition.

shift float

Diagonal shift applied before computing the logdet.

seed int

Seed for probe vector generation.

sampler SamplerName

Probe distribution ("signs", "normal", "sphere").

Source code in src/gaussx/_strategies/_slq_logdet.py
class IndefiniteSLQLogdet(AbstractLogdetStrategy):
    """Stochastic ``log|det(A)|`` for symmetric (possibly indefinite) operators.

    Like `SLQLogdet` but uses ``log(|lambda|)`` as the matrix
    function, so it works on indefinite and negative-definite matrices.
    Supports a diagonal shift ``(A + shift * I)``.

    Attributes:
        num_probes: Number of probe vectors for the Hutchinson estimator.
        lanczos_order: Order of the Lanczos decomposition.
        shift: Diagonal shift applied before computing the logdet.
        seed: Seed for probe vector generation.
        sampler: Probe distribution (``"signs"``, ``"normal"``,
            ``"sphere"``).
    """

    num_probes: int = 20
    lanczos_order: int = 30
    shift: float = 0.0
    seed: int = 0
    sampler: SamplerName = "signs"

    def _shifted_matvec(self, operator: lx.AbstractLinearOperator):
        shift = self.shift

        def matvec(v):
            return operator.mv(v) + shift * v

        return matvec

    def logdet(
        self,
        operator: lx.AbstractLinearOperator,
        *,
        key: jax.Array | None = None,
    ) -> Float[Array, ""]:
        r"""Stochastic ``log|det(A + shift I)|`` via Lanczos quadrature.

        Args:
            operator: A symmetric linear operator.
            key: PRNG key for probe vector sampling.  If ``None``,
                uses ``jax.random.PRNGKey(self.seed)``.

        Returns:
            Scalar estimate of ``log|det(A + shift I)|``.
        """
        if key is None:
            key = jax.random.PRNGKey(self.seed)

        n = operator.in_size()
        order = min(self.lanczos_order, n)
        point, _ = _slq_estimators(
            _logabsdet_integrand(order), n, self.num_probes, self.sampler
        )
        return point(self._shifted_matvec(operator), key)

    def logdet_and_error(
        self,
        operator: lx.AbstractLinearOperator,
        *,
        key: jax.Array | None = None,
    ) -> tuple[Float[Array, ""], Float[Array, ""]]:
        """Stochastic ``log|det(A + shift I)|`` with its standard error.

        Args:
            operator: A symmetric linear operator.
            key: PRNG key for probe vector sampling.  If ``None``,
                uses ``jax.random.PRNGKey(self.seed)``.

        Returns:
            Tuple ``(estimate, standard_error)``.
        """
        if key is None:
            key = jax.random.PRNGKey(self.seed)

        n = operator.in_size()
        order = min(self.lanczos_order, n)
        _, with_sem = _slq_estimators(
            _logabsdet_integrand(order), n, self.num_probes, self.sampler
        )
        return with_sem(self._shifted_matvec(operator), key)

logdet(operator: lx.AbstractLinearOperator, *, key: jax.Array | None = None) -> Float[Array, '']

Stochastic log|det(A + shift I)| via Lanczos quadrature.

Parameters:

Name Type Description Default
operator AbstractLinearOperator

A symmetric linear operator.

required
key Array | None

PRNG key for probe vector sampling. If None, uses jax.random.PRNGKey(self.seed).

None

Returns:

Type Description
Float[Array, '']

Scalar estimate of log|det(A + shift I)|.

Source code in src/gaussx/_strategies/_slq_logdet.py
def logdet(
    self,
    operator: lx.AbstractLinearOperator,
    *,
    key: jax.Array | None = None,
) -> Float[Array, ""]:
    r"""Stochastic ``log|det(A + shift I)|`` via Lanczos quadrature.

    Args:
        operator: A symmetric linear operator.
        key: PRNG key for probe vector sampling.  If ``None``,
            uses ``jax.random.PRNGKey(self.seed)``.

    Returns:
        Scalar estimate of ``log|det(A + shift I)|``.
    """
    if key is None:
        key = jax.random.PRNGKey(self.seed)

    n = operator.in_size()
    order = min(self.lanczos_order, n)
    point, _ = _slq_estimators(
        _logabsdet_integrand(order), n, self.num_probes, self.sampler
    )
    return point(self._shifted_matvec(operator), key)

logdet_and_error(operator: lx.AbstractLinearOperator, *, key: jax.Array | None = None) -> tuple[Float[Array, ''], Float[Array, '']]

Stochastic log|det(A + shift I)| with its standard error.

Parameters:

Name Type Description Default
operator AbstractLinearOperator

A symmetric linear operator.

required
key Array | None

PRNG key for probe vector sampling. If None, uses jax.random.PRNGKey(self.seed).

None

Returns:

Type Description
tuple[Float[Array, ''], Float[Array, '']]

Tuple (estimate, standard_error).

Source code in src/gaussx/_strategies/_slq_logdet.py
def logdet_and_error(
    self,
    operator: lx.AbstractLinearOperator,
    *,
    key: jax.Array | None = None,
) -> tuple[Float[Array, ""], Float[Array, ""]]:
    """Stochastic ``log|det(A + shift I)|`` with its standard error.

    Args:
        operator: A symmetric linear operator.
        key: PRNG key for probe vector sampling.  If ``None``,
            uses ``jax.random.PRNGKey(self.seed)``.

    Returns:
        Tuple ``(estimate, standard_error)``.
    """
    if key is None:
        key = jax.random.PRNGKey(self.seed)

    n = operator.in_size()
    order = min(self.lanczos_order, n)
    _, with_sem = _slq_estimators(
        _logabsdet_integrand(order), n, self.num_probes, self.sampler
    )
    return with_sem(self._shifted_matvec(operator), key)

Preconditioners

Approximate inverses \(M^{-1} \approx A^{-1}\) that accelerate the iterative solvers above. Pass them via the preconditioner= argument of linear_solve, CGSolver, or PreconditionedCGSolver.

AbstractPreconditioner

Bases: Module

Protocol for preconditioners producing an approximate inverse.

Subclasses implement as_operator, returning a PSD lineax operator that applies M^{-1} (or None to disable preconditioning).

Source code in src/gaussx/_preconditioners/_base.py
class AbstractPreconditioner(eqx.Module):
    """Protocol for preconditioners producing an approximate inverse.

    Subclasses implement `as_operator`, returning a PSD lineax operator
    that applies ``M^{-1}`` (or ``None`` to disable preconditioning).
    """

    @abc.abstractmethod
    def as_operator(
        self,
        operator: lx.AbstractLinearOperator | None = None,
    ) -> lx.AbstractLinearOperator | None:
        """Return ``M^{-1}`` as a PSD lineax operator.

        Args:
            operator: The system operator ``A`` being solved. Used by
                data-dependent preconditioners (e.g. partial-Cholesky); ignored
                by static ones.

        Returns:
            A positive-semidefinite operator applying ``M^{-1}``, or ``None`` to
            indicate that no preconditioning should be applied.
        """
        ...

    def __call__(
        self,
        vector: Float[Array, " n"],
        operator: lx.AbstractLinearOperator | None = None,
    ) -> Float[Array, " n"]:
        """Apply ``M^{-1}`` to *vector* (identity when disabled)."""
        op = self.as_operator(operator)
        if op is None:
            return vector
        return op.mv(vector)

as_operator(operator: lx.AbstractLinearOperator | None = None) -> lx.AbstractLinearOperator | None abstractmethod

Return M^{-1} as a PSD lineax operator.

Parameters:

Name Type Description Default
operator AbstractLinearOperator | None

The system operator A being solved. Used by data-dependent preconditioners (e.g. partial-Cholesky); ignored by static ones.

None

Returns:

Type Description
AbstractLinearOperator | None

A positive-semidefinite operator applying M^{-1}, or None to indicate that no preconditioning should be applied.

Source code in src/gaussx/_preconditioners/_base.py
@abc.abstractmethod
def as_operator(
    self,
    operator: lx.AbstractLinearOperator | None = None,
) -> lx.AbstractLinearOperator | None:
    """Return ``M^{-1}`` as a PSD lineax operator.

    Args:
        operator: The system operator ``A`` being solved. Used by
            data-dependent preconditioners (e.g. partial-Cholesky); ignored
            by static ones.

    Returns:
        A positive-semidefinite operator applying ``M^{-1}``, or ``None`` to
        indicate that no preconditioning should be applied.
    """
    ...

JacobiPreconditioner

Bases: AbstractPreconditioner

Diagonal preconditioner M^{-1} = diag(1 / diag(A)).

The cheapest preconditioner: scales each coordinate by the reciprocal of the corresponding diagonal entry of A. Effective when A is diagonally dominant.

Attributes:

Name Type Description
diagonal Float[Array, ' n'] | None

The diagonal of A. When None, it is extracted from the operator passed to as_operator via gaussx.diag.
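As an illustration of why diagonal scaling helps, the sketch below runs a textbook preconditioned conjugate gradient loop in NumPy, once without a preconditioner and once with M^{-1} = diag(1 / diag(A)). The pcg function and the test matrix are assumptions made for this example; they are not gaussx's CGSolver.

```python
import numpy as np


def pcg(matvec, b, precond=None, tol=1e-8, max_iter=500):
    """Textbook preconditioned CG; returns (solution, iterations used)."""
    apply_m_inv = precond if precond is not None else (lambda r: r)
    x = np.zeros_like(b)
    r = b.copy()
    z = apply_m_inv(r)
    p = z.copy()
    rz = r @ z
    b_norm = np.linalg.norm(b)
    for k in range(1, max_iter + 1):
        ap = matvec(p)
        step = rz / (p @ ap)
        x = x + step * p
        r = r - step * ap
        if np.linalg.norm(r) <= tol * b_norm:
            return x, k
        z = apply_m_inv(r)
        rz_new = r @ z
        p = z + (rz_new / rz) * p
        rz = rz_new
    return x, max_iter


# Badly scaled SPD matrix: a well-conditioned tridiagonal core rescaled by a
# diagonal that spans four orders of magnitude.
n = 50
d = np.logspace(0, 4, n)
core = np.eye(n) + 0.2 * (np.eye(n, k=1) + np.eye(n, k=-1))
a = np.sqrt(d)[:, None] * core * np.sqrt(d)[None, :]
rhs = np.ones(n)

# Jacobi: reciprocal of the diagonal, with zero entries mapped to zero.
diag_a = np.diag(a)
inv_diag = np.divide(1.0, diag_a, out=np.zeros(n), where=diag_a != 0.0)

x_plain, iters_plain = pcg(lambda v: a @ v, rhs)
x_jacobi, iters_jacobi = pcg(lambda v: a @ v, rhs, precond=lambda r: inv_diag * r)
```

Here the Jacobi-scaled system has the conditioning of the tridiagonal core, so it needs far fewer iterations than the unscaled solve. On matrices whose ill-conditioning does not come from the diagonal, the gain can be negligible.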

Source code in src/gaussx/_preconditioners/_jacobi.py
class JacobiPreconditioner(AbstractPreconditioner):
    """Diagonal preconditioner ``M^{-1} = diag(1 / diag(A))``.

    The cheapest preconditioner: scales each coordinate by the reciprocal of
    the corresponding diagonal entry of ``A``. Effective when ``A`` is
    diagonally dominant.

    Attributes:
        diagonal: The diagonal of ``A``. When ``None``, it is extracted from the
            operator passed to `as_operator` via `gaussx.diag`.
    """

    diagonal: Float[Array, " n"] | None = None

    def as_operator(
        self,
        operator: lx.AbstractLinearOperator | None = None,
    ) -> lx.AbstractLinearOperator:
        """Return ``diag(1 / d)`` as a PSD operator."""
        d = self.diagonal
        if d is None:
            if operator is None:
                raise ValueError(
                    "JacobiPreconditioner needs either an explicit `diagonal` "
                    "or an operator to extract one from."
                )
            from gaussx._primitives import diag

            d = diag(operator)
        inv = jnp.where(d != 0.0, 1.0 / d, 0.0)
        return lx.TaggedLinearOperator(
            lx.DiagonalLinearOperator(inv), lx.positive_semidefinite_tag
        )

as_operator(operator: lx.AbstractLinearOperator | None = None) -> lx.AbstractLinearOperator

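Return diag(1 / d) as a PSD-tagged operator. Zero entries of d are mapped to zero instead of being inverted. The PSD tag is applied unconditionally, so d should be non-negative.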

Source code in src/gaussx/_preconditioners/_jacobi.py
def as_operator(
    self,
    operator: lx.AbstractLinearOperator | None = None,
) -> lx.AbstractLinearOperator:
    """Return ``diag(1 / d)`` as a PSD operator."""
    d = self.diagonal
    if d is None:
        if operator is None:
            raise ValueError(
                "JacobiPreconditioner needs either an explicit `diagonal` "
                "or an operator to extract one from."
            )
        from gaussx._primitives import diag

        d = diag(operator)
    inv = jnp.where(d != 0.0, 1.0 / d, 0.0)
    return lx.TaggedLinearOperator(
        lx.DiagonalLinearOperator(inv), lx.positive_semidefinite_tag
    )

OperatorPreconditioner

Bases: AbstractPreconditioner

Use an externally supplied approximate inverse as a preconditioner.

Attributes:

Name Type Description
approx_inverse AbstractLinearOperator | Callable

The approximate inverse M^{-1}, either as a lineax operator or as a callable v -> M^{-1} v.

in_structure object

Input structure for the callable form. When None, it is taken from out_structure() of the system operator passed to as_operator.
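Because the supplied approximate inverse is tagged positive semidefinite without being checked, it can be worth verifying a candidate callable on a small problem first. The NumPy sketch below materialises a callable v -> M^{-1} v column by column and tests symmetry and positivity. The helpers materialise and looks_spd are hypothetical names for this example and are not part of gaussx.

```python
import numpy as np


def materialise(apply_fn, n):
    """Dense matrix of a linear map, built by applying it to unit vectors."""
    return np.column_stack([apply_fn(e) for e in np.eye(n)])


def looks_spd(apply_fn, n, tol=1e-10):
    """True if the map is symmetric (up to tol) with positive eigenvalues."""
    m = materialise(apply_fn, n)
    symmetric = np.allclose(m, m.T, atol=tol)
    eigvals = np.linalg.eigvalsh(0.5 * (m + m.T))
    return symmetric and bool(eigvals.min() > 0.0)


n = 8
a = 4.0 * np.eye(n) + np.eye(n, k=1) + np.eye(n, k=-1) + 0.1 * np.ones((n, n))

# Candidate 1: solve with the tridiagonal part of A (symmetric positive definite).
tridiag = np.triu(np.tril(a, 1), -1)
ok_tridiag = looks_spd(lambda v: np.linalg.solve(tridiag, v), n)

# Candidate 2: solve with the lower triangle of A (Gauss-Seidel style), which
# is not symmetric and therefore not a valid CG preconditioner.
lower = np.tril(a)
ok_lower = looks_spd(lambda v: np.linalg.solve(lower, v), n)
```

The first candidate passes and the second fails, even though both are reasonable approximate inverses in other settings.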

Source code in src/gaussx/_preconditioners/_operator.py
class OperatorPreconditioner(AbstractPreconditioner):
    """Use an externally supplied approximate inverse as a preconditioner.

    Attributes:
        approx_inverse: The approximate inverse ``M^{-1}``, either as a lineax
            operator or as a callable ``v -> M^{-1} v``.
        in_structure: Input structure for the callable form. When ``None`` it is
            taken from the system operator passed to `as_operator`.
    """

    approx_inverse: lx.AbstractLinearOperator | Callable
    in_structure: object = None

    def as_operator(
        self,
        operator: lx.AbstractLinearOperator | None = None,
    ) -> lx.AbstractLinearOperator:
        """Return the approximate inverse as a PSD lineax operator."""
        ai = self.approx_inverse
        if isinstance(ai, lx.AbstractLinearOperator):
            if lx.is_positive_semidefinite(ai):
                return ai
            return lx.TaggedLinearOperator(ai, lx.positive_semidefinite_tag)

        structure = self.in_structure
        if structure is None:
            if operator is None:
                raise ValueError(
                    "OperatorPreconditioner with a callable approximate inverse "
                    "needs `in_structure` or a system operator to infer it."
                )
            structure = operator.out_structure()
        return lx.FunctionLinearOperator(ai, structure, lx.positive_semidefinite_tag)

as_operator(operator: lx.AbstractLinearOperator | None = None) -> lx.AbstractLinearOperator

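Return the approximate inverse as a PSD lineax operator. An operator that is not already tagged PSD is wrapped with the PSD tag, and a callable is wrapped in a PSD-tagged FunctionLinearOperator; in both cases the tag is asserted, not verified.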

Source code in src/gaussx/_preconditioners/_operator.py
def as_operator(
    self,
    operator: lx.AbstractLinearOperator | None = None,
) -> lx.AbstractLinearOperator:
    """Return the approximate inverse as a PSD lineax operator."""
    ai = self.approx_inverse
    if isinstance(ai, lx.AbstractLinearOperator):
        if lx.is_positive_semidefinite(ai):
            return ai
        return lx.TaggedLinearOperator(ai, lx.positive_semidefinite_tag)

    structure = self.in_structure
    if structure is None:
        if operator is None:
            raise ValueError(
                "OperatorPreconditioner with a callable approximate inverse "
                "needs `in_structure` or a system operator to infer it."
            )
        structure = operator.out_structure()
    return lx.FunctionLinearOperator(ai, structure, lx.positive_semidefinite_tag)

NystromPreconditioner

Bases: AbstractPreconditioner

Low-rank approximate inverse from randomized operator probing.

Builds a rank-k Nyström approximation of a symmetric positive semidefinite operator A and uses it as an approximate inverse. Good when A is available only through matvecs and a handful of probes captures its dominant spectrum.

Algorithm (for PSD A):

  1. Draw a Gaussian probe matrix Omega in R^{n x k} and orthonormalise it via QR to get Q.
  2. Form Y = A Q (k matvecs) and the small matrix B = Q^T Y.
  3. Eigendecompose B = U S U^T and set W = Q U (orthonormal columns).
  4. The approximate inverse is M^{-1} x = a x + W ((s_inv - a) (W^T x)), where s_inv = 1 / |eig(B)| and the scalar fallback a keeps directions outside the captured subspace near the inverse of the smallest captured eigenvalue (so CG does not falsely converge in the preconditioned norm).

Construct via from_operator.

Attributes:

Name Type Description
basis Float[Array, 'n k']

Orthonormal basis W, shape (n, k).

scale Float[Array, ' k']

Per-direction extra scaling s_inv - a, shape (k,).

shift Float[Array, '']

Scalar fallback a applied to the full space; from_operator sets it to max(s_inv).
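The four algorithm steps above can be traced in plain NumPy. The sketch below mirrors them on a dense matrix; the function name nystrom_approx_inverse and the test matrix are assumptions for illustration, and the code is not the gaussx implementation.

```python
import numpy as np


def nystrom_approx_inverse(matvec, n, rank, seed=0):
    """Return a callable x -> M^{-1} x built from `rank` random probes."""
    rng = np.random.default_rng(seed)
    k = min(rank, n)

    # Step 1: Gaussian probes, orthonormalised by QR.
    omega = rng.standard_normal((n, k))
    q, _ = np.linalg.qr(omega)

    # Step 2: Y = A Q (k matvecs) and the small matrix B = Q^T Y.
    y = np.column_stack([matvec(q[:, j]) for j in range(k)])
    b = q.T @ y
    b = 0.5 * (b + b.T)  # remove floating-point asymmetry before eigh

    # Step 3: eigendecompose B and rotate the basis, W = Q U.
    eigvals, u = np.linalg.eigh(b)
    w = q @ u

    # Step 4: invert the captured eigenvalues, guarding against tiny ones.
    abs_eig = np.abs(eigvals)
    eps = np.finfo(abs_eig.dtype).eps * n
    s_inv = np.divide(1.0, abs_eig, out=np.zeros_like(abs_eig), where=abs_eig > eps)
    shift = s_inv.max()    # scalar fallback a for uncaptured directions
    scale = s_inv - shift  # extra scaling inside the captured subspace

    def apply(x):
        return shift * x + w @ (scale * (w.T @ x))

    return apply


rng = np.random.default_rng(1)
g = rng.standard_normal((12, 12))
a = g @ g.T / 12 + np.eye(12)  # symmetric positive definite test matrix
x = rng.standard_normal(12)

# With rank == n the subspace is the whole space, so M^{-1} equals A^{-1}.
full = nystrom_approx_inverse(lambda v: a @ v, n=12, rank=12)
recovered = full(a @ x)

# With rank < n the map is still symmetric positive definite.
low = nystrom_approx_inverse(lambda v: a @ v, n=12, rank=4)
m_low = np.column_stack([low(e) for e in np.eye(12)])
```

The rank equal to n case is only a correctness check; in practice the rank is chosen much smaller than n, and the quality of the preconditioner depends on how much of the spectrum the probes capture.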

Source code in src/gaussx/_preconditioners/_nystrom.py
class NystromPreconditioner(AbstractPreconditioner):
    r"""Low-rank approximate inverse from randomized operator probing.

    Builds a rank-``k`` Nyström approximation of a symmetric positive
    semidefinite operator ``A`` and uses it as an approximate inverse. Good when
    ``A`` is available only through matvecs and a handful of probes captures its
    dominant spectrum.

    Algorithm (for PSD ``A``):

    1. Draw a Gaussian probe matrix ``Omega in R^{n x k}`` and orthonormalise it
       via QR to get ``Q``.
    2. Form ``Y = A Q`` (``k`` matvecs) and the small matrix ``B = Q^T Y``.
    3. Eigendecompose ``B = U S U^T`` and set ``W = Q U`` (orthonormal columns).
    4. The approximate inverse is
       ``M^{-1} x = a x + W ((s_inv - a) (W^T x))``,
       where ``s_inv = 1 / |eig(B)|`` and the scalar fallback ``a`` keeps
       directions outside the captured subspace near the inverse of the smallest
       captured eigenvalue (so CG does not falsely converge in the
       preconditioned norm).

    Construct via `from_operator`.

    Attributes:
        basis: Orthonormal basis ``W``, shape ``(n, k)``.
        scale: Per-direction extra scaling ``s_inv - a``, shape ``(k,)``.
        shift: Scalar fallback ``a`` applied to the full space.
    """

    basis: Float[Array, "n k"]
    scale: Float[Array, " k"]
    shift: Float[Array, ""]

    @classmethod
    def from_operator(
        cls,
        operator: lx.AbstractLinearOperator,
        rank: int = 50,
        key: jax.Array | None = None,
    ) -> NystromPreconditioner:
        """Build a Nyström preconditioner by probing *operator*.

        Args:
            operator: A symmetric PSD operator ``A``.
            rank: Number of probe vectors (approximation rank).
            key: PRNG key for the probe matrix. Defaults to
                ``jax.random.PRNGKey(0)``.

        Returns:
            A ready-to-use `NystromPreconditioner`.
        """
        if key is None:
            key = jax.random.PRNGKey(0)

        n = operator.in_size()
        k = min(rank, n)

        omega = jax.random.normal(key, (n, k))
        q, _ = jnp.linalg.qr(omega)

        y = eqx.filter_vmap(operator.mv, in_axes=1, out_axes=1)(q)
        b = einsum(q, y, "n k, n j -> k j")
        # Symmetrize before the eigendecomposition: b is symmetric in exact
        # arithmetic, but floating-point asymmetry in the off-diagonals can
        # perturb eigh. (Matches the convention in _distributions/_conditional.)
        b = symmetrize(b)
        eigvals, u = jnp.linalg.eigh(b)

        abs_eigvals = jnp.abs(eigvals)
        eps = jnp.finfo(abs_eigvals.dtype).eps * n
        s_inv = jnp.where(abs_eigvals > eps, 1.0 / abs_eigvals, 0.0)

        w = einsum(q, u, "n k, k j -> n j")
        # Fallback for uncaptured directions: random probing captures the
        # largest eigenvalues, so uncaptured directions have smaller eigenvalues
        # and larger inverses. The largest captured inverse is the best
        # available proxy and keeps the preconditioned spectrum near 1.
        shift = jnp.max(s_inv)
        scale = s_inv - shift
        return cls(basis=w, scale=scale, shift=shift)

    def as_operator(
        self,
        operator: lx.AbstractLinearOperator | None = None,
    ) -> lx.AbstractLinearOperator:
        """Return the rank-``k`` approximate inverse as a PSD operator."""
        w = self.basis
        scale = self.scale
        shift = self.shift
        structure = jax.ShapeDtypeStruct((w.shape[0],), w.dtype)

        def matvec(x: Float[Array, " n"]) -> Float[Array, " n"]:
            coeffs = einsum(w, x, "n k, n -> k")
            return shift * x + einsum(w, scale * coeffs, "n k, k -> n")

        return lx.FunctionLinearOperator(
            matvec, structure, lx.positive_semidefinite_tag
        )

from_operator(operator: lx.AbstractLinearOperator, rank: int = 50, key: jax.Array | None = None) -> NystromPreconditioner classmethod

Build a Nyström preconditioner by probing operator.

Parameters:

Name Type Description Default
operator AbstractLinearOperator

A symmetric PSD operator A.

required
rank int

Number of probe vectors (approximation rank).

50
key Array | None

PRNG key for the probe matrix. Defaults to jax.random.PRNGKey(0).

None

Returns:

Type Description
NystromPreconditioner

A ready-to-use NystromPreconditioner.

Source code in src/gaussx/_preconditioners/_nystrom.py
@classmethod
def from_operator(
    cls,
    operator: lx.AbstractLinearOperator,
    rank: int = 50,
    key: jax.Array | None = None,
) -> NystromPreconditioner:
    """Build a Nyström preconditioner by probing *operator*.

    Args:
        operator: A symmetric PSD operator ``A``.
        rank: Number of probe vectors (approximation rank).
        key: PRNG key for the probe matrix. Defaults to
            ``jax.random.PRNGKey(0)``.

    Returns:
        A ready-to-use `NystromPreconditioner`.
    """
    if key is None:
        key = jax.random.PRNGKey(0)

    n = operator.in_size()
    k = min(rank, n)

    omega = jax.random.normal(key, (n, k))
    q, _ = jnp.linalg.qr(omega)

    y = eqx.filter_vmap(operator.mv, in_axes=1, out_axes=1)(q)
    b = einsum(q, y, "n k, n j -> k j")
    # Symmetrize before the eigendecomposition: b is symmetric in exact
    # arithmetic, but floating-point asymmetry in the off-diagonals can
    # perturb eigh. (Matches the convention in _distributions/_conditional.)
    b = symmetrize(b)
    eigvals, u = jnp.linalg.eigh(b)

    abs_eigvals = jnp.abs(eigvals)
    eps = jnp.finfo(abs_eigvals.dtype).eps * n
    s_inv = jnp.where(abs_eigvals > eps, 1.0 / abs_eigvals, 0.0)

    w = einsum(q, u, "n k, k j -> n j")
    # Fallback for uncaptured directions: random probing captures the
    # largest eigenvalues, so uncaptured directions have smaller eigenvalues
    # and larger inverses. The largest captured inverse is the best
    # available proxy and keeps the preconditioned spectrum near 1.
    shift = jnp.max(s_inv)
    scale = s_inv - shift
    return cls(basis=w, scale=scale, shift=shift)

as_operator(operator: lx.AbstractLinearOperator | None = None) -> lx.AbstractLinearOperator

Return the rank-k approximate inverse as a PSD operator.

Source code in src/gaussx/_preconditioners/_nystrom.py
def as_operator(
    self,
    operator: lx.AbstractLinearOperator | None = None,
) -> lx.AbstractLinearOperator:
    """Return the rank-``k`` approximate inverse as a PSD operator."""
    w = self.basis
    scale = self.scale
    shift = self.shift
    structure = jax.ShapeDtypeStruct((w.shape[0],), w.dtype)

    def matvec(x: Float[Array, " n"]) -> Float[Array, " n"]:
        coeffs = einsum(w, x, "n k, n -> k")
        return shift * x + einsum(w, scale * coeffs, "n k, k -> n")

    return lx.FunctionLinearOperator(
        matvec, structure, lx.positive_semidefinite_tag
    )

PartialCholeskyPreconditioner

Bases: AbstractPreconditioner

Preconditioner from a pivoted partial Cholesky factor.

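Builds a rank-k partial Cholesky factor L of the system operator via matfree, then applies (s I + L L^T)^{-1} through the Woodbury identity. For operators of the form K + sigma^2 I whose kernel part K has a rapidly decaying spectrum, this can reduce CG iteration counts substantially. Matrix entries are read by applying the operator to a unit vector, so each entry lookup costs one matvec.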

Attributes:

Name Type Description
rank int

Rank of the partial Cholesky factor. A value <= 0 disables preconditioning (as_operator returns None).

shift float

Diagonal shift s for the preconditioner, typically the noise variance sigma^2.
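The two ingredients, a pivoted partial Cholesky factor and the Woodbury solve, can be sketched in NumPy on a dense kernel matrix. This stands in for the matfree routines used by the class and is not their implementation; the names partial_pivoted_cholesky and woodbury_apply, and the RBF test kernel, are assumptions for this example.

```python
import numpy as np


def partial_pivoted_cholesky(k_mat, rank):
    """Greedy rank-`rank` factor L with L L^T approximating k_mat (PSD)."""
    n = k_mat.shape[0]
    factor = np.zeros((n, rank))
    residual_diag = np.diag(k_mat).astype(float).copy()
    for j in range(rank):
        # Pivot on the largest remaining diagonal residual (assumed > 0).
        pivot = int(np.argmax(residual_diag))
        col = k_mat[:, pivot] - factor[:, :j] @ factor[pivot, :j]
        factor[:, j] = col / np.sqrt(residual_diag[pivot])
        residual_diag = np.maximum(residual_diag - factor[:, j] ** 2, 0.0)
    return factor


def woodbury_apply(factor, shift, v):
    """Apply (shift * I + L L^T)^{-1} to v using only a k x k solve."""
    k = factor.shape[1]
    small = shift * np.eye(k) + factor.T @ factor
    return (v - factor @ np.linalg.solve(small, factor.T @ v)) / shift


# RBF kernel on 40 points in [0, 1]; the system operator is K + noise * I.
pts = np.linspace(0.0, 1.0, 40)
k_mat = np.exp(-0.5 * (pts[:, None] - pts[None, :]) ** 2 / 0.1**2)
noise = 0.1

factor = partial_pivoted_cholesky(k_mat, rank=8)
v = np.sin(2 * np.pi * pts)
applied = woodbury_apply(factor, noise, v)

# Reference: the same solve done densely, for comparison.
dense = np.linalg.solve(noise * np.eye(40) + factor @ factor.T, v)
```

The Woodbury form needs only a k x k solve per application, which is what makes the preconditioner cheap once the factor is built. Setting shift to the noise variance makes s I + L L^T a direct low-rank model of K + sigma^2 I.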

Source code in src/gaussx/_preconditioners/_partial_cholesky.py
class PartialCholeskyPreconditioner(AbstractPreconditioner):
    """Preconditioner from a pivoted partial Cholesky factor.

    Builds a rank-``k`` partial Cholesky factor ``L`` of the system operator via
    matfree, then applies ``(s I + L L^T)^{-1}`` through the Woodbury identity.
    For operators of the form ``K + sigma^2 I`` this dramatically reduces CG
    iteration counts.

    Attributes:
        rank: Rank of the partial Cholesky. ``<= 0`` disables preconditioning
            (`as_operator` returns ``None``).
        shift: Diagonal shift ``s`` for the preconditioner, typically the noise
            variance ``sigma^2``.
    """

    rank: int = 50
    shift: float = 1.0

    def as_operator(
        self,
        operator: lx.AbstractLinearOperator | None = None,
    ) -> lx.AbstractLinearOperator | None:
        """Build the Woodbury preconditioner operator from *operator*."""
        if self.rank <= 0:
            return None
        if operator is None:
            raise ValueError(
                "PartialCholeskyPreconditioner.as_operator requires the system "
                "operator to build its factor."
            )

        n = operator.in_size()
        rank = min(self.rank, n)

        def mat_el(i, j):
            ej = jnp.zeros(n).at[j].set(1.0)
            return operator.mv(ej)[i]

        chol_fn = matfree.low_rank.cholesky_partial_pivot(mat_el, nrows=n, rank=rank)
        factor, info = chol_fn()
        precond_fn = matfree.low_rank.preconditioner(lambda: (factor, info))

        def precond_matvec(v: Float[Array, " n"]) -> Float[Array, " n"]:
            applied, _ = precond_fn(v, self.shift)
            return applied

        return lx.FunctionLinearOperator(
            precond_matvec,
            operator.out_structure(),
            lx.positive_semidefinite_tag,
        )

as_operator(operator: lx.AbstractLinearOperator | None = None) -> lx.AbstractLinearOperator | None

Build the Woodbury preconditioner operator from operator.

Source code in src/gaussx/_preconditioners/_partial_cholesky.py
def as_operator(
    self,
    operator: lx.AbstractLinearOperator | None = None,
) -> lx.AbstractLinearOperator | None:
    """Build the Woodbury preconditioner operator from *operator*."""
    if self.rank <= 0:
        return None
    if operator is None:
        raise ValueError(
            "PartialCholeskyPreconditioner.as_operator requires the system "
            "operator to build its factor."
        )

    n = operator.in_size()
    rank = min(self.rank, n)

    def mat_el(i, j):
        ej = jnp.zeros(n).at[j].set(1.0)
        return operator.mv(ej)[i]

    chol_fn = matfree.low_rank.cholesky_partial_pivot(mat_el, nrows=n, rank=rank)
    factor, info = chol_fn()
    precond_fn = matfree.low_rank.preconditioner(lambda: (factor, info))

    def precond_matvec(v: Float[Array, " n"]) -> Float[Array, " n"]:
        applied, _ = precond_fn(v, self.shift)
        return applied

    return lx.FunctionLinearOperator(
        precond_matvec,
        operator.out_structure(),
        lx.positive_semidefinite_tag,
    )