Kernels & Approximations
Low-rank kernel approximations, spectral preconditioning for kernel SGD, kernel two-sample / independence statistics, and the grid helpers behind interpolation-based (KISS-GP style) operators.
Low-rank kernel operators
Nyström (\(K \approx K_{nm} K_{mm}^{-1} K_{mn}\)) and random-Fourier-feature
approximations, returned as LowRankUpdate operators so solves
and logdets go through Woodbury automatically.
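The payoff of the LowRankUpdate representation is that only a small r x r system is ever factorized. Below is a minimal NumPy sketch of that idea (not gaussx code). It assumes a diagonal noise term \(\sigma^2 I\) has been added to the low-rank part, as in GP regression. The solve uses the Woodbury identity and the log-determinant uses the matrix determinant lemma. The helper name woodbury_solve_logdet is hypothetical.

```python
import numpy as np


def woodbury_solve_logdet(U, sigma2, y):
    """Solve (sigma2 * I + U U^T) x = y and return log det(sigma2 * I + U U^T).

    Hypothetical helper: only the r x r capacitance matrix is factorized,
    so the cost is O(n r^2) instead of O(n^3).
    """
    n, r = U.shape
    # Capacitance matrix C = I_r + U^T U / sigma2.
    C = np.eye(r) + U.T @ U / sigma2
    L = np.linalg.cholesky(C)
    # Woodbury: (s I + U U^T)^{-1} y = y / s - U C^{-1} U^T y / s^2.
    z = np.linalg.solve(L.T, np.linalg.solve(L, U.T @ y))
    x = y / sigma2 - U @ z / sigma2**2
    # Matrix determinant lemma: log det(s I + U U^T) = n log s + log det C.
    logdet = n * np.log(sigma2) + 2.0 * np.sum(np.log(np.diag(L)))
    return x, logdet


rng = np.random.default_rng(0)
U = rng.normal(size=(200, 5))
y = rng.normal(size=200)
x, logdet = woodbury_solve_logdet(U, 0.1, y)
```

The dense 200 x 200 matrix is never factorized. Only the 5 x 5 capacitance matrix is.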
Structured linear algebra and Gaussian primitives for JAX.
nystrom_operator(K_XZ: Float[Array, 'N M'], K_ZZ_op: lx.AbstractLinearOperator) -> LowRankUpdate
Nyström low-rank kernel approximation.
Approximates \(K_{XX} \approx K_{XZ} K_{ZZ}^{-1} K_{ZX}\) as a
LowRankUpdate with zero base:
\[ K_{XX} \approx U D U^T, \]
where \(U = K_{XZ} L_{ZZ}^{-T}\) and \(D = I\) (i.e. \(UU^T\)),
with \(L_{ZZ} = \operatorname{cholesky}(K_{ZZ})\).
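To make the factorization concrete, here is a small NumPy sketch that builds \(U = K_{XZ} L_{ZZ}^{-T}\) from dense kernel blocks and checks that \(UU^T\) reproduces \(K_{XZ} K_{ZZ}^{-1} K_{ZX}\). The RBF kernel, the choice of inducing points, and the diagonal jitter are assumptions of this sketch. The helper names rbf and nystrom_factor are hypothetical, not gaussx functions.

```python
import numpy as np


def rbf(A, B, lengthscale=1.0):
    # Squared-exponential kernel matrix between the rows of A and B.
    sq = np.sum(A**2, axis=1)[:, None] + np.sum(B**2, axis=1)[None, :] - 2.0 * A @ B.T
    return np.exp(-0.5 * np.maximum(sq, 0.0) / lengthscale**2)


def nystrom_factor(K_XZ, K_ZZ, jitter=1e-6):
    """Return U = K_XZ L_ZZ^{-T} so that U U^T = K_XZ K_ZZ^{-1} K_ZX."""
    # Jitter is an assumption here; it keeps the Cholesky factorization stable.
    L = np.linalg.cholesky(K_ZZ + jitter * np.eye(K_ZZ.shape[0]))
    # Solve L W = K_ZX, i.e. W = L^{-1} K_ZX; then U = W^T = K_XZ L^{-T}.
    W = np.linalg.solve(L, K_XZ.T)
    return W.T


rng = np.random.default_rng(0)
X = rng.normal(size=(50, 2))
Z = X[:10]  # inducing points: a subset of the data
K_XZ, K_ZZ = rbf(X, Z), rbf(Z, Z)
U = nystrom_factor(K_XZ, K_ZZ)  # shape (50, 10)
```

Because \(K_{ZZ}^{-1} = L_{ZZ}^{-T} L_{ZZ}^{-1}\), the product \(UU^T\) equals the Nyström approximation, and \(D = I\) needs no storage.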
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| K_XZ | Float[Array, 'N M'] | Cross-covariance between data and inducing points, shape (N, M). | required |
| K_ZZ_op | AbstractLinearOperator | Inducing-point covariance operator, shape (M, M). | required |
Returns:
| Type | Description |
|---|---|
| LowRankUpdate | Nyström approximation \(UU^T\) with zero base. |
Source code in src/gaussx/_kernels/_kernel_approx.py
rff_operator(X: Float[Array, 'N D'], omega: Float[Array, 'D_rff D'], b: Float[Array, ' D_rff']) -> LowRankUpdate
Random Fourier Features kernel approximation.
Approximates \(K_{XX} \approx \Phi \Phi^T\), where
\[ \Phi_{i,j} = \sqrt{2/D_{rff}}\,\cos(X_i \cdot \omega_j + b_j). \]
Returns a LowRankUpdate that never materializes
the N x N matrix.
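The feature map above can be checked numerically with a short NumPy sketch. It assumes the target is the RBF kernel, for which random Fourier features (Rahimi & Recht, 2007) draw \(\omega_j \sim \mathcal{N}(0, I/\ell^2)\) and \(b_j \sim \mathrm{Uniform}(0, 2\pi)\). How gaussx callers are expected to sample omega and b is not stated here, and the helper name rff_features is hypothetical.

```python
import numpy as np


def rff_features(X, omega, b):
    # Phi[i, j] = sqrt(2 / D_rff) * cos(X_i . omega_j + b_j); shape (N, D_rff).
    D_rff = omega.shape[0]
    return np.sqrt(2.0 / D_rff) * np.cos(X @ omega.T + b)


rng = np.random.default_rng(0)
N, D, D_rff, lengthscale = 5, 2, 20_000, 1.0
X = rng.normal(size=(N, D))
# Assumed sampling for the RBF kernel exp(-||x - y||^2 / (2 l^2)):
# omega_j ~ N(0, I / l^2) and b_j ~ Uniform(0, 2 pi).
omega = rng.normal(scale=1.0 / lengthscale, size=(D_rff, D))
b = rng.uniform(0.0, 2.0 * np.pi, size=D_rff)
Phi = rff_features(X, omega, b)
K_approx = Phi @ Phi.T  # Monte Carlo estimate of the RBF Gram matrix
```

The approximation error shrinks roughly like \(1/\sqrt{D_{rff}}\), so the feature count trades accuracy against the rank of the LowRankUpdate.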
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| X | Float[Array, 'N D'] | Data points, shape (N, D). | required |
| omega | Float[Array, 'D_rff D'] | Random frequencies, shape (D_rff, D). | required |
| b | Float[Array, ' D_rff'] | Random phase offsets, shape (D_rff,). | required |
Returns:
| Type | Description |
|---|---|
| LowRankUpdate | Low-rank approximation \(\Phi\Phi^T\). |
Source code in src/gaussx/_kernels/_kernel_approx.py
EigenPro preconditioning
Spectral preconditioning for kernel stochastic gradient descent: damp the top eigendirections of the kernel operator so the step size is governed by the residual spectrum (Ma & Belkin, 2017).
EigenProPreconditioner
Bases: Module
Spectral correction state for EigenPro kernel SGD.
Stores the top eigenspace of \(K_{mm} / m\) on a subsample together with the
corresponding EigenPro correction weights.
Attributes:
| Name | Type | Description |
|---|---|---|
| V | Float[Array, 'm k'] | Top eigenvectors of \(K_{mm} / m\), shape (m, k). |
| D | Float[Array, ' k'] | Correction weights, shape (k,). |
| subsample_indices | Int[Array, ' m'] | Indices used for the subsample, shape (m,). |
| max_eigenvalue | Float[Array, ''] | Largest eigenvalue of \(K_{mm} / m\). |
| beta | Float[Array, ''] | Maximum residual kernel diagonal used for the step size. |
Source code in src/gaussx/_kernels/_eigenpro.py
eigenpro_preconditioner(kernel_op: lx.AbstractLinearOperator, *, subsample_size: int = 4000, n_components: int = 100, alpha: float = 0.95, key: jax.Array | None = None) -> EigenProPreconditioner
Build an EigenPro spectral preconditioner from a kernel operator.
The helper samples m rows/columns, eigendecomposes \(K_{mm} / m\), and
stores the top-k eigenspace correction
\[ D_i = \frac{1 - (\lambda_{k+1} / \lambda_i)^{\alpha}}{\lambda_i}. \]
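The eigendecomposition step and the weight formula can be sketched in NumPy on a dense subsample kernel. This is an illustration of the formula only: it works on an explicit RBF matrix rather than a kernel_op, draws the subsample with a NumPy generator instead of a JAX PRNG key, and the helper name eigenpro_weights is hypothetical.

```python
import numpy as np


def eigenpro_weights(K_mm, k, alpha=0.95):
    """Top-k eigenspace of K_mm / m and weights D_i = (1 - (lam_{k+1}/lam_i)^alpha) / lam_i."""
    m = K_mm.shape[0]
    if not 0 < k < m:
        raise ValueError("need 0 < k < m so that lambda_{k+1} exists")
    evals, evecs = np.linalg.eigh(K_mm / m)
    # eigh returns eigenvalues in ascending order; flip to descending.
    evals, evecs = evals[::-1], evecs[:, ::-1]
    lam = evals[:k]      # lambda_1 >= ... >= lambda_k
    lam_next = evals[k]  # lambda_{k+1}, the largest eigenvalue left undamped
    D = (1.0 - (lam_next / lam) ** alpha) / lam
    return evecs[:, :k], D, evals[0]


rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 3))
idx = rng.choice(X.shape[0], size=200, replace=False)  # random subsample, m = 200
Xs = X[idx]
sq = np.sum((Xs[:, None, :] - Xs[None, :, :]) ** 2, axis=-1)
K_mm = np.exp(-0.5 * sq)
V, D, max_eig = eigenpro_weights(K_mm, k=10)
```

Since \(\lambda_{k+1} \le \lambda_i\) for every kept component, all weights are non-negative, and the damping is strongest for the largest eigenvalues.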
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| kernel_op | AbstractLinearOperator | Square kernel linear operator. | required |
| subsample_size | int | Number of points m in the eigendecomposition subsample. | 4000 |
| n_components | int | Number of leading eigenvectors k to keep. | 100 |
| alpha | float | Spectral decay exponent \(\alpha\) in the correction weights \(D_i\). | 0.95 |
| key | Array \| None | Optional PRNG key. If omitted, the first … | None |
Returns:
| Type | Description |
|---|---|
| EigenProPreconditioner | EigenPro preconditioner state for kernel SGD. |
Source code in src/gaussx/_kernels/_eigenpro.py
eigenpro_step_size(precond: EigenProPreconditioner, batch_size: int | Float[Array, '']) -> Float[Array, '']
Return the EigenPro preconditioned step size for a given batch size.
batch_size may be a Python int or a scalar JAX array; the
function is JIT-friendly with either. Non-positive batch sizes raise an error:
eagerly when batch_size is a Python int, and inside the traced computation via
eqx.error_if when it is a JAX array.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| precond | EigenProPreconditioner | The EigenPro preconditioner carrying … | required |
| batch_size | int \| Float[Array, ''] | Mini-batch size, as a Python int or scalar JAX array. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar preconditioned step size. |
Raises:
| Type | Description |
|---|---|
| ValueError | If batch_size is non-positive. |
Source code in src/gaussx/_kernels/_eigenpro.py
eigenpro_correction(precond: EigenProPreconditioner, K_batch_sub: Float[Array, 'B m'], gradient: Float[Array, 'B C'], step_size: float | Float[Array, '']) -> Float[Array, 'm C']
Compute the EigenPro eigenspace correction for one SGD mini-batch.
step_size may be a Python float or a scalar JAX array. Passing a
JAX scalar avoids recompilation under jax.jit when the value
changes between training steps.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| precond | EigenProPreconditioner | The EigenPro preconditioner carrying the top eigenvectors. | required |
| K_batch_sub | Float[Array, 'B m'] | Kernel block between the mini-batch and the subsampled points, shape (B, m). | required |
| gradient | Float[Array, 'B C'] | Per-example gradients for the mini-batch, shape (B, C). | required |
| step_size | float \| Float[Array, ''] | Scalar step size, as a Python float or scalar JAX array. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'm C'] | Eigenspace correction over the subsampled points, shape (m, C). |
Source code in src/gaussx/_kernels/_eigenpro.py
Kernel statistics
Centering, the Hilbert-Schmidt independence criterion, and maximum mean discrepancy.
center_kernel(K: lx.AbstractLinearOperator) -> lx.MatrixLinearOperator
Center a kernel matrix: \(HKH\).
Computes the centered Gram matrix \(HKH\), where \(H = I - \frac{1}{n}\mathbf{1}\mathbf{1}^T\).
Used in kernel PCA, HSIC, and other centered kernel methods.
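A quick NumPy sketch shows that the explicit product \(HKH\) equals the cheaper "subtract row and column means, add back the grand mean" form. It also shows that, for a linear kernel, centering the Gram matrix is the same as centering the features. Both helper names are hypothetical. Whether gaussx uses either route internally is not stated here.

```python
import numpy as np


def center_dense(K):
    # Explicit H K H with H = I - (1/n) 1 1^T: O(n^3) with dense matmuls.
    n = K.shape[0]
    H = np.eye(n) - np.ones((n, n)) / n
    return H @ K @ H


def center_means(K):
    # Same result in O(n^2): subtract column and row means, add back the grand mean.
    return K - K.mean(axis=0, keepdims=True) - K.mean(axis=1, keepdims=True) + K.mean()


rng = np.random.default_rng(0)
A = rng.normal(size=(6, 3))
K = A @ A.T  # a linear-kernel Gram matrix
Kc = center_means(K)
# For a linear kernel, H K H = (H A)(H A)^T, the Gram matrix of centered features.
A_centered = A - A.mean(axis=0)
```

Every row and column of the centered matrix sums to zero, which is what makes it the Gram matrix of mean-zero feature vectors.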
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| K | AbstractLinearOperator | Kernel (Gram) matrix operator, shape (n, n). | required |
Returns:
| Type | Description |
|---|---|
| MatrixLinearOperator | Centered kernel operator, shape (n, n). |
Source code in src/gaussx/_kernels/_kernel_approx.py
centering_operator(n: int) -> LowRankUpdate
Centering matrix \(H = I - \frac{1}{n}\mathbf{1}\mathbf{1}^T\).
Returned as a LowRankUpdate so the structure is
preserved for downstream operations such as \(HKH\).
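Why the low-rank form helps: \(H\) is the identity plus a rank-one term, so applying it to a vector is just subtracting the mean. The NumPy sketch below assumes the rank-one factors are \(U = \mathbf{1}\) and \(D = -1/n\). That is one valid choice, not necessarily how gaussx stores them, and both helper names are hypothetical.

```python
import numpy as np


def centering_matvec(v):
    # H v = v - (1/n) 1 (1^T v) = v - mean(v): O(n) work, no n x n matrix.
    return v - v.mean(axis=0, keepdims=True)


def centering_lowrank_parts(n):
    # H = I + U D U^T with U = 1 (an n x 1 column of ones) and D = [[-1/n]].
    return np.ones((n, 1)), np.array([[-1.0 / n]])


n = 5
U, D = centering_lowrank_parts(n)
H = np.eye(n) + U @ D @ U.T
v = np.arange(n, dtype=float)
```

\(H\) is also idempotent (\(H^2 = H\)), so centering an already centered vector leaves it unchanged.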
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n | int | Dimension of the centering matrix. | required |
Returns:
| Type | Description |
|---|---|
| LowRankUpdate | Centering matrix \(H\) as a low-rank update of the identity. |
Source code in src/gaussx/_kernels/_kernel_approx.py
hsic(K_f: lx.AbstractLinearOperator, K_q: lx.AbstractLinearOperator) -> Float[Array, '']
Biased HSIC estimator.
Computes the Hilbert-Schmidt Independence Criterion:
\[ \mathrm{HSIC} = \frac{1}{n^2}\,\operatorname{tr}(K_f H K_q H), \]
where \(H = I - \frac{1}{n}\mathbf{1}\mathbf{1}^T\) is the centering matrix.
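The estimator can be sketched directly from the trace formula in NumPy on dense RBF Gram matrices. The kernel choice and lengthscale are assumptions of this sketch, and the helper names are hypothetical. A strongly dependent pair of samples should give a much larger value than an independent pair.

```python
import numpy as np


def rbf_1d(x, lengthscale=1.0):
    # RBF Gram matrix for a 1-D sample x of shape (n,).
    return np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / lengthscale**2)


def hsic_biased(K_f, K_q):
    # HSIC = (1/n^2) tr(K_f H K_q H), with H = I - (1/n) 1 1^T.
    n = K_f.shape[0]
    H = np.eye(n) - np.ones((n, n)) / n
    return np.trace(K_f @ H @ K_q @ H) / n**2


rng = np.random.default_rng(0)
x = rng.normal(size=200)
y_dep = x + 0.1 * rng.normal(size=200)  # strongly dependent on x
y_ind = rng.normal(size=200)            # independent of x
h_dep = hsic_biased(rbf_1d(x), rbf_1d(y_dep))
h_ind = hsic_biased(rbf_1d(x), rbf_1d(y_ind))
```

The biased estimate is non-negative and symmetric in its arguments (by the cyclic property of the trace). Its small positive value for independent samples is the bias that the name refers to.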
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| K_f | AbstractLinearOperator | First kernel matrix, shape (n, n). | required |
| K_q | AbstractLinearOperator | Second kernel matrix, shape (n, n). | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar HSIC estimate. |
Source code in src/gaussx/_kernels/_kernel_approx.py
mmd_squared(K_xx: Float[Array, 'Nx Nx'], K_yy: Float[Array, 'Ny Ny'], K_xy: Float[Array, 'Nx Ny']) -> Float[Array, '']
Biased squared Maximum Mean Discrepancy.
Computes:
\[ \mathrm{MMD}^2 = \operatorname{mean}(K_{xx}) + \operatorname{mean}(K_{yy}) - 2\,\operatorname{mean}(K_{xy}) \]
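Since the formula only needs means of three kernel blocks, it is easy to sanity-check in NumPy. The RBF kernel and the sample sizes are assumptions of this sketch, and the helper names are hypothetical. Comparing a sample against itself gives exactly zero, and a mean-shifted sample gives a clearly larger value than a sample from the same distribution.

```python
import numpy as np


def rbf(A, B, lengthscale=1.0):
    # RBF kernel matrix between the rows of A and B.
    sq = np.sum((A[:, None, :] - B[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq / lengthscale**2)


def mmd2_biased(K_xx, K_yy, K_xy):
    # MMD^2 = mean(K_xx) + mean(K_yy) - 2 mean(K_xy)
    return K_xx.mean() + K_yy.mean() - 2.0 * K_xy.mean()


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))
Y_same = rng.normal(size=(80, 2))         # same distribution as X
Y_shift = rng.normal(size=(80, 2)) + 1.0  # mean-shifted distribution
mmd_same = mmd2_biased(rbf(X, X), rbf(Y_same, Y_same), rbf(X, Y_same))
mmd_shift = mmd2_biased(rbf(X, X), rbf(Y_shift, Y_shift), rbf(X, Y_shift))
```

The biased estimate is the squared distance between the empirical kernel mean embeddings, so it is never negative. It stays slightly positive even for samples from the same distribution.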
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| K_xx | Float[Array, 'Nx Nx'] | Kernel matrix within the first sample, shape (Nx, Nx). | required |
| K_yy | Float[Array, 'Ny Ny'] | Kernel matrix within the second sample, shape (Ny, Ny). | required |
| K_xy | Float[Array, 'Nx Ny'] | Cross-kernel matrix between samples, shape (Nx, Ny). | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, ''] | Scalar biased MMD^2 estimate. |
Source code in src/gaussx/_kernels/_kernel_approx.py
Grids & interpolation
create_grid(grid_sizes: list[int] | tuple[int, ...], grid_bounds: list[tuple[float, float]] | tuple[tuple[float, float], ...]) -> list[Float[Array, ' n']]
Create a regular grid from per-dimension sizes and bounds.
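A regular grid of this kind can be sketched with one np.linspace call per dimension. That both endpoints (lo, hi) are included is an assumption of this sketch, not something the signature guarantees, and make_grid is a hypothetical name.

```python
import numpy as np


def make_grid(grid_sizes, grid_bounds):
    """Evenly spaced 1-D grids, endpoints included (an assumption of this sketch)."""
    if len(grid_sizes) != len(grid_bounds):
        raise ValueError("grid_sizes and grid_bounds must have the same length D")
    return [np.linspace(lo, hi, n) for n, (lo, hi) in zip(grid_sizes, grid_bounds)]


grid = make_grid([5, 3], [(0.0, 1.0), (-2.0, 2.0)])
# grid[0] -> [0.0, 0.25, 0.5, 0.75, 1.0]; grid[1] -> [-2.0, 0.0, 2.0]
```

Keeping the grid as a list of 1-D arrays, rather than the full product, is what lets Kronecker-structured operators stay cheap.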
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| grid_sizes | list[int] \| tuple[int, ...] | Number of points per dimension, length D. | required |
| grid_bounds | list[tuple[float, float]] \| tuple[tuple[float, float], ...] | (lo, hi) bounds per dimension, length D. | required |
Returns:
| Type | Description |
|---|---|
| list[Float[Array, ' n']] | List of 1-D arrays, one per dimension. |
Source code in src/gaussx/_kernels/_grid.py
grid_data(grid: list[Float[Array, ' n']]) -> Float[Array, 'G D']
Expand a grid to the full Cartesian product of grid points.
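The expansion amounts to a meshgrid followed by flattening. The NumPy sketch below assumes "ij" indexing, so the last dimension varies fastest. Whether gaussx uses the same row ordering is an assumption, and cartesian_points is a hypothetical name.

```python
import numpy as np


def cartesian_points(grid):
    # Full Cartesian product of the 1-D grids, shape (G, D) with G = prod(len(g)).
    # "ij" indexing makes the last dimension vary fastest (assumed ordering).
    mesh = np.meshgrid(*grid, indexing="ij")
    return np.stack([m.ravel() for m in mesh], axis=-1)


points = cartesian_points([np.array([0.0, 1.0]), np.array([10.0, 20.0, 30.0])])
# points[:3] -> [[0, 10], [0, 20], [0, 30]]
```

G grows as the product of the per-dimension sizes, which is why this expansion is usually reserved for small grids or diagnostics.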
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| grid | list[Float[Array, ' n']] | List of 1-D arrays, one per dimension. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'G D'] | Shape (G, D): every point of the Cartesian product grid. |
Source code in src/gaussx/_kernels/_grid.py
cubic_interpolation_weights(x_target: Float[Array, 'B D'], grid: list[Float[Array, ' n']]) -> tuple[Int[Array, 'B K'], Float[Array, 'B K']]
Compute cubic interpolation indices and weights for SKI.
For each target point, finds the \(4^D\) neighboring grid points (4 per dimension) and computes cubic convolution interpolation weights.
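The construction can be sketched in NumPy as follows. Several assumptions are labeled in the code: Keys' cubic convolution kernel with \(a = -0.5\), uniformly spaced grid dimensions, a stencil clamped to stay inside the grid, and flat indices in C order (matching an "ij" Cartesian product). The exact kernel parameter and index layout gaussx uses may differ, and all helper names are hypothetical. Each dimension contributes 4 indices and weights, and their tensor product gives the \(4^D\) entries per target point.

```python
import itertools

import numpy as np


def keys_kernel(s, a=-0.5):
    # Keys cubic convolution kernel (a = -0.5 is an assumed, common choice).
    s = np.abs(s)
    near = (a + 2) * s**3 - (a + 3) * s**2 + 1
    far = a * s**3 - 5 * a * s**2 + 8 * a * s - 4 * a
    return np.where(s <= 1, near, np.where(s < 2, far, 0.0))


def cubic_weights_1d(x, g):
    # Assumes g is uniformly spaced; returns (B, 4) indices and weights.
    n = g.shape[0]
    if n < 4:
        raise ValueError("each grid dimension needs at least 4 points")
    h = g[1] - g[0]
    t = (x - g[0]) / h  # position in grid units
    # Clamp the stencil start so all 4 neighbors lie inside the grid.
    i = np.clip(np.floor(t).astype(int), 1, n - 3)
    idx = i[:, None] + np.arange(-1, 3)[None, :]
    return idx, keys_kernel(t[:, None] - idx)


def cubic_weights(x_target, grid):
    # Tensor product of 1-D stencils: K = 4**D entries per target point.
    B, D = x_target.shape
    per_dim = [cubic_weights_1d(x_target[:, d], grid[d]) for d in range(D)]
    sizes = [g.shape[0] for g in grid]
    idx_out, w_out = [], []
    for combo in itertools.product(range(4), repeat=D):
        multi = tuple(per_dim[d][0][:, c] for d, c in enumerate(combo))
        idx_out.append(np.ravel_multi_index(multi, sizes))  # C-order flat index
        w = np.ones(B)
        for d, c in enumerate(combo):
            w = w * per_dim[d][1][:, c]
        w_out.append(w)
    return np.stack(idx_out, axis=1), np.stack(w_out, axis=1)


g = np.linspace(0.0, 1.0, 11)
grid = [g, g]
rng = np.random.default_rng(0)
x_target = rng.uniform(0.15, 0.85, size=(7, 2))
idx, w = cubic_weights(x_target, grid)  # both of shape (7, 16), since 4**2 = 16
```

Away from the grid boundary, the weights in each row sum to 1 and reproduce linear functions exactly. Near the edges, the clamped stencil in this sketch does not preserve that property.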
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x_target | Float[Array, 'B D'] | Target points, shape (B, D). | required |
| grid | list[Float[Array, ' n']] | List of 1-D arrays, one per dimension (length D). Each dimension must have ≥ 4 points. | required |
Returns:
| Type | Description |
|---|---|
| tuple[Int[Array, 'B K'], Float[Array, 'B K']] | Tuple (indices, weights): integer grid indices and cubic interpolation weights for the K = \(4^D\) neighbors of each target point, each of shape (B, K). |
Raises:
| Type | Description |
|---|---|
| ValueError | If any grid dimension has fewer than 4 points. |