
Solver Strategy Comparison

gaussx provides two solver strategies, each of which pairs a linear solve with a log-determinant computation:

  • DenseSolver: direct factorization chosen by structural dispatch (e.g. Cholesky for PSD operators)
  • CGSolver: iterative CG solve plus a stochastic Lanczos logdet

This notebook compares them on the same problem.
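For a PSD matrix, a single strategy can pair both operations because one Cholesky factorization $K = LL^\top$ serves both: the solve reduces to two solves against the triangular factor, and $\log\det K = 2\sum_i \log L_{ii}$. The NumPy sketch below illustrates this on the same RBF-plus-noise kernel used in the setup cell. `cholesky_solve_and_logdet` is a hypothetical helper, not gaussx's `DenseSolver` implementation.

```python
import numpy as np


def cholesky_solve_and_logdet(K, b):
    """Solve K x = b and compute log det K from a single Cholesky factor.

    Hypothetical helper; assumes K is symmetric positive definite.
    """
    L = np.linalg.cholesky(K)  # K = L @ L.T with L lower-triangular
    # Two solves against the factor: L y = b, then L.T x = y.
    # np.linalg.solve does not exploit triangularity; a production code
    # would call a dedicated triangular solver instead.
    y = np.linalg.solve(L, b)
    x = np.linalg.solve(L.T, y)
    # det K = det(L)^2 = prod(diag L)^2, so log det K = 2 * sum(log diag L).
    logdet = 2.0 * np.sum(np.log(np.diag(L)))
    return x, logdet


# Same RBF-plus-noise kernel as the setup cell (lengthscale 1, noise 0.1).
n = 50
t = np.linspace(0, 5, n)
K = np.exp(-0.5 * (t[:, None] - t[None, :]) ** 2) + 0.1 * np.eye(n)
b = np.random.default_rng(0).standard_normal(n)

x, logdet = cholesky_solve_and_logdet(K, b)
print(f"max residual: {np.max(np.abs(K @ x - b)):.1e}")
print(f"logdet: {logdet:.6f}  (slogdet: {np.linalg.slogdet(K)[1]:.6f})")
```

Because the matrix is the same as in the setup cell, the logdet printed here should agree with the notebook's reference value of about -91.66.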

For small-to-medium problems ($N \lesssim 5000$), direct factorization (Cholesky for PSD systems) is usually the best choice: it costs about $N^3/3$ flops and is accurate to near machine precision. For larger problems, iterative methods such as conjugate gradients (CG) can reach useful accuracy in far fewer flops, because each iteration needs only one matrix-vector product ($O(N^2)$ flops for a dense matrix, less for structured or sparse ones). CG converges fastest when the matrix is well-conditioned or a good preconditioner is available.
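To make the conditioning remark concrete, here is a textbook unpreconditioned CG loop in NumPy. It touches the matrix only through products `A @ p`. Running it on the same RBF kernel with a large versus a tiny noise term shows the iteration count growing with the condition number. This is a sketch, not gaussx's `CGSolver`. The helpers `conjugate_gradient` and `rbf_kernel`, and the noise values, are hypothetical choices for illustration.

```python
import numpy as np


def conjugate_gradient(A, b, rtol=1e-8, max_steps=1000):
    """Plain (unpreconditioned) CG for symmetric positive definite A.

    Returns the approximate solution and the number of iterations used.
    """
    x = np.zeros_like(b)
    r = b.copy()          # residual b - A x for the initial guess x = 0
    p = r.copy()          # first search direction
    rs = r @ r
    tol = rtol * np.linalg.norm(b)
    for k in range(max_steps):
        if np.sqrt(rs) <= tol:
            return x, k
        Ap = A @ p        # the only access to A: one matrix-vector product
        alpha = rs / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p   # new A-conjugate direction
        rs = rs_new
    return x, max_steps


def rbf_kernel(n, noise):
    """Hypothetical helper: RBF kernel on [0, 5] (lengthscale 1) plus noise * I."""
    t = np.linspace(0, 5, n)
    return np.exp(-0.5 * (t[:, None] - t[None, :]) ** 2) + noise * np.eye(n)


b = np.random.default_rng(0).standard_normal(50)
for noise in (1.0, 1e-4):
    K = rbf_kernel(50, noise)
    x, iters = conjugate_gradient(K, b)
    print(f"noise={noise:g}  cond={np.linalg.cond(K):.1e}  CG iterations={iters}")
```

A larger noise term lifts the smallest eigenvalue, lowering the condition number, so CG should stop in fewer iterations. A preconditioner aims for the same effect without changing the problem.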

from __future__ import annotations

import warnings


warnings.filterwarnings("ignore", message=r".*IProgress.*")

import jax
import jax.numpy as jnp
import lineax as lx
import matplotlib.pyplot as plt

import gaussx


jax.config.update("jax_enable_x64", True)

Setup: PSD kernel matrix

key = jax.random.PRNGKey(0)
n = 50

# RBF kernel + noise
x = jnp.linspace(0, 5, n)
sq_dist = (x[:, None] - x[None, :]) ** 2
K = jnp.exp(-0.5 * sq_dist / 1.0**2) + 0.1 * jnp.eye(n)

op = lx.MatrixLinearOperator(K, lx.positive_semidefinite_tag)
b = jax.random.normal(key, (n,))

print(f"Problem size: {n}x{n}")
Problem size: 50x50

DenseSolver

dense = gaussx.DenseSolver()

x_dense = dense.solve(op, b)
ld_dense = dense.logdet(op)

print("DenseSolver:")
print(f"  solve residual: {jnp.max(jnp.abs(op.mv(x_dense) - b)):.2e}")
print(f"  logdet: {ld_dense:.6f}")
DenseSolver:
  solve residual: 1.09e-14
  logdet: -91.659733

CGSolver

The CGSolver pairs conjugate gradients for the linear solve with stochastic Lanczos quadrature (SLQ) for the log-determinant. For symmetric positive definite $A$, SLQ exploits the identity $\log\det A = \operatorname{tr}(\log A)$ and then applies Hutchinson’s trace estimator: draw random probe vectors $z_p$ with $\mathbb{E}[z_p z_p^\top] = I$ (for example, Rademacher vectors with independent $\pm 1$ entries) and approximate $\operatorname{tr}(\log A) \approx \frac{1}{P}\sum_{p=1}^P z_p^\top \log(A)\, z_p$. Each quadratic form $z^\top \log(A)\, z$ is evaluated with a short Lanczos run started from $z$, which produces a small tridiagonal matrix $T$. If $\theta_k$ are the eigenvalues of $T$ and $\tau_k$ the first components of its normalized eigenvectors, Gauss quadrature gives $z^\top \log(A)\, z \approx \|z\|^2 \sum_k \tau_k^2 \log\theta_k$. See Ubaru et al. (2017) for convergence analysis.
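The estimator just described can be sketched in NumPy as follows. Each probe gets one Lanczos run (with full reorthogonalization, for robustness), and Gauss quadrature is applied to the resulting tridiagonal matrix. The helper names, the Rademacher probes, and the default of 20 Lanczos steps are assumptions for illustration; gaussx's `CGSolver` may make different choices.

```python
import numpy as np


def lanczos_tridiag(A, z, m):
    """Run up to m Lanczos steps on symmetric A starting from z / ||z||.

    Returns the diagonal (alpha) and off-diagonal (beta) of tridiagonal T.
    Stops early if the Krylov subspace becomes numerically invariant.
    """
    n = z.shape[0]
    m = min(m, n)
    Q = np.zeros((n, m))
    alpha, beta = [], []
    Q[:, 0] = z / np.linalg.norm(z)
    for j in range(m):
        w = A @ Q[:, j]
        alpha.append(Q[:, j] @ w)
        # Full reorthogonalization against all Lanczos vectors so far (twice).
        for _ in range(2):
            w -= Q[:, : j + 1] @ (Q[:, : j + 1].T @ w)
        if j == m - 1:
            break
        b = np.linalg.norm(w)
        if b < 1e-12 * abs(alpha[0]):  # invariant subspace: quadrature is exact
            break
        beta.append(b)
        Q[:, j + 1] = w / b
    return np.array(alpha), np.array(beta)


def slq_logdet(A, num_probes=50, lanczos_steps=20, seed=0):
    """Estimate log det A = tr(log A) by Hutchinson probes + Lanczos quadrature."""
    rng = np.random.default_rng(seed)
    n = A.shape[0]
    estimates = []
    for _ in range(num_probes):
        z = rng.choice([-1.0, 1.0], size=n)  # Rademacher probe: E[z z^T] = I
        alpha, beta = lanczos_tridiag(A, z, lanczos_steps)
        T = np.diag(alpha) + np.diag(beta, 1) + np.diag(beta, -1)
        theta, U = np.linalg.eigh(T)         # Gauss nodes theta_k
        tau_sq = U[0, :] ** 2                # Gauss weights (sum to 1)
        # z^T log(A) z ~= ||z||^2 * sum_k tau_k^2 log(theta_k)
        estimates.append((z @ z) * np.sum(tau_sq * np.log(theta)))
    return np.mean(estimates)


# Same RBF-plus-noise kernel as the setup cell.
n = 50
t = np.linspace(0, 5, n)
K = np.exp(-0.5 * (t[:, None] - t[None, :]) ** 2) + 0.1 * np.eye(n)
print("SLQ estimate:", slq_logdet(K, num_probes=50))
print("Exact       :", np.linalg.slogdet(K)[1])
```

Each probe costs `lanczos_steps` matrix-vector products, so the whole estimate costs about `num_probes * lanczos_steps` products and never needs the dense matrix. The probe average is a Monte Carlo estimate, so its error shrinks only like $1/\sqrt{P}$.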

cg = gaussx.CGSolver(rtol=1e-8, atol=1e-8, max_steps=200, num_probes=50)

x_cg = cg.solve(op, b)
ld_cg = cg.logdet(op, key=jax.random.PRNGKey(42))

print("CGSolver:")
print(f"  solve residual: {jnp.max(jnp.abs(op.mv(x_cg) - b)):.2e}")
print(f"  logdet: {ld_cg:.6f}")
CGSolver:
  solve residual: 3.51e-09
  logdet: -90.873256

Comparison

print(f"Solve difference: {jnp.max(jnp.abs(x_dense - x_cg)):.2e}")
print(f"Logdet difference: {jnp.abs(ld_dense - ld_cg):.4f}")

# True logdet for reference
ld_true = jnp.linalg.slogdet(K)[1]
print(f"\nTrue logdet:    {ld_true:.6f}")
print(f"Dense logdet:   {ld_dense:.6f}  (error: {jnp.abs(ld_dense - ld_true):.2e})")
print(f"CG logdet:      {ld_cg:.6f}  (error: {jnp.abs(ld_cg - ld_true):.4f})")
Solve difference: 3.41e-09
Logdet difference: 0.7865

True logdet:    -91.659733
Dense logdet:   -91.659733  (error: 0.00e+00)
CG logdet:      -90.873256  (error: 0.7865)
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Solve comparison
axes[0].plot(x_dense, "C0-", lw=2, label="DenseSolver", alpha=0.8)
axes[0].plot(x_cg, "C1--", lw=2, label="CGSolver", alpha=0.8)
axes[0].set_xlabel("Index")
axes[0].set_ylabel("Solution")
axes[0].set_title("Solve comparison")
axes[0].legend(fontsize=9)
axes[0].grid(True, which="major", alpha=0.3)
axes[0].grid(True, which="minor", alpha=0.1)
axes[0].minorticks_on()

# Solve difference
axes[1].semilogy(jnp.abs(x_dense - x_cg), "C2-")
axes[1].set_xlabel("Index")
axes[1].set_ylabel("|Dense - CG|")
axes[1].set_title("Pointwise solve difference")
axes[1].grid(True, which="major", alpha=0.3)
axes[1].grid(True, which="minor", alpha=0.1)
axes[1].minorticks_on()

plt.tight_layout()
plt.show()
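[Figure: left panel "Solve comparison", DenseSolver and CGSolver solutions against index; right panel "Pointwise solve difference", |Dense - CG| against index on a log scale.]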

When to use which

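| Strategy | Best for | Solve | Logdet |
|---|---|---|---|
| DenseSolver | Small-to-medium, structured operators | Exact (structural dispatch) | Exact |
| CGSolver | Large PSD, matrix-free operators | Iterative (CG) | Stochastic estimate (SLQ) |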

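The DenseSolver is exact up to floating-point round-off and exploits gaussx structural dispatch (Kronecker, BlockDiag, LowRank, and Diagonal fast paths). The CGSolver needs only matrix-vector products, so it works for any PSD operator, including matrix-free ones where as_matrix() is unavailable. Its logdet, however, is a stochastic estimate: in the run above, 50 probes gave an error of about 0.79 (under 1% relative), and the Monte Carlo part of that error shrinks like $1/\sqrt{P}$ in the number of probes $P$.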
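As an illustration of what a structural fast path can exploit, two standard identities let a Kronecker-structured operator $A \otimes B$ (with $A$ of size $n \times n$ and $B$ of size $m \times m$) be solved and log-determined without forming the $nm \times nm$ matrix. The first is $(A \otimes B)\,\mathrm{vec}(X) = \mathrm{vec}(A X B^\top)$ for row-major $\mathrm{vec}$. The second is $\log\det(A \otimes B) = m \log\det A + n \log\det B$. This NumPy sketch uses hypothetical helper names and is not gaussx's dispatch code.

```python
import numpy as np


def kron_solve(A, B, y):
    """Solve (A kron B) x = y without forming the Kronecker product.

    A is n x n, B is m x m, and y has length n*m in row-major order,
    i.e. y = Y.reshape(-1) for an n x m array Y.
    """
    n, m = A.shape[0], B.shape[0]
    Y = y.reshape(n, m)
    # (A kron B) X.reshape(-1) == (A @ X @ B.T).reshape(-1),
    # so X = A^{-1} Y B^{-T}.
    Z = np.linalg.solve(A, Y)        # A^{-1} Y
    X = np.linalg.solve(B, Z.T).T    # Z B^{-T}
    return X.reshape(-1)


def kron_logdet(A, B):
    """log det(A kron B) = m log det A + n log det B (assumes det A, det B > 0)."""
    n, m = A.shape[0], B.shape[0]
    return m * np.linalg.slogdet(A)[1] + n * np.linalg.slogdet(B)[1]


rng = np.random.default_rng(0)


def random_spd(k):
    M = rng.standard_normal((k, k))
    return M @ M.T + k * np.eye(k)  # diagonal shift keeps it well-conditioned


A, B = random_spd(3), random_spd(4)
y = rng.standard_normal(12)
full = np.kron(A, B)  # formed only to check the fast paths
print(np.allclose(full @ kron_solve(A, B, y), y))
print(np.isclose(kron_logdet(A, B), np.linalg.slogdet(full)[1]))
```

Working on the factors costs $O(n^3 + m^3)$ instead of $O(n^3 m^3)$ for a dense factorization of $A \otimes B$. This is the kind of saving structural dispatch is meant to capture.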

The crossover point depends on hardware (GPU memory, FLOP throughput) and on matrix conditioning. A dense float64 $N \times N$ matrix alone occupies $8N^2$ bytes, so on modern GPUs Cholesky is practical up to roughly $N \sim 10{,}000$ to $50{,}000$; beyond that, CG-based methods become necessary.
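The arithmetic behind that range fits in a few lines. The sketch below tabulates the storage of a dense float64 matrix and the leading-order $N^3/3$ flop count of its Cholesky factorization for a few sizes. The helper names are illustrative, and real solvers need extra workspace beyond the matrix itself.

```python
def dense_matrix_gib(n, bytes_per_entry=8):
    """Memory for one dense n x n matrix in GiB (8 bytes per float64 entry)."""
    return n * n * bytes_per_entry / 2**30


def cholesky_flops(n):
    """Leading-order flop count of a dense Cholesky factorization."""
    return n**3 / 3


for n in (5_000, 10_000, 50_000, 100_000):
    print(f"N = {n:>7,}: {dense_matrix_gib(n):7.2f} GiB, "
          f"{cholesky_flops(n):.1e} flops for Cholesky")
```

At $N = 50{,}000$ the matrix alone is about 18.6 GiB and the factorization about $4.2 \times 10^{13}$ flops. This is why the upper end of the range depends on how much memory the GPU has.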

References

  • Golub, G. H. & Van Loan, C. F. (2013). Matrix Computations (4th ed.). Johns Hopkins University Press.
  • Ubaru, S., Chen, J., & Saad, Y. (2017). Fast estimation of $\operatorname{tr}(f(A))$ via stochastic Lanczos quadrature. SIAM Journal on Matrix Analysis and Applications, 38(4), 1075-1099.
  • Hestenes, M. R. & Stiefel, E. (1952). Methods of conjugate gradients for solving linear systems. Journal of Research of the National Bureau of Standards, 49(6), 409-436.