
Kernels & Approximations

Low-rank kernel approximations, spectral preconditioning for kernel SGD, kernel two-sample / independence statistics, and the grid helpers behind interpolation-based (KISS-GP style) operators.

Low-rank kernel operators

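Nyström (\(K \approx K_{nm} K_{mm}^{-1} K_{mn}\)) and random-Fourier-feature (RFF) approximations, returned as LowRankUpdate operators. Both carry a zero diagonal base, so the operator alone is rank-deficient. The low-rank structure is what lets solves and log-determinants go through the Woodbury identity once a nonsingular base, such as a noise diagonal, is present.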

Structured linear algebra and Gaussian primitives for JAX.

nystrom_operator(K_XZ: Float[Array, 'N M'], K_ZZ_op: lx.AbstractLinearOperator) -> LowRankUpdate

Nyström low-rank kernel approximation.

Approximates \(K_{XX} \approx K_{XZ} K_{ZZ}^{-1} K_{ZX}\) as a LowRankUpdate with a zero base:

\[K_{XX} \approx U D U^T\]

where \(U = K_{XZ} L_{ZZ}^{-T}\) and \(D = I\) (so the approximation is \(U U^T\)), with \(L_{ZZ} = \mathrm{cholesky}(K_{ZZ})\).

Parameters:

Name Type Description Default
K_XZ Float[Array, 'N M']

Cross-covariance between data and inducing points, shape (N, M).

required
K_ZZ_op AbstractLinearOperator

Inducing-point covariance operator, shape (M, M).

required

Returns:

Type Description
LowRankUpdate

LowRankUpdate operator of shape (N, N).

Source code in src/gaussx/_kernels/_kernel_approx.py
def nystrom_operator(
    K_XZ: Float[Array, "N M"],
    K_ZZ_op: lx.AbstractLinearOperator,
) -> LowRankUpdate:
    r"""Nystrom low-rank kernel approximation.

    Approximates ``K_{XX} \approx K_{XZ} K_{ZZ}^{-1} K_{ZX}`` as a
    `LowRankUpdate` with zero base:

        K_{XX} \approx U D U^T

    where ``U = K_{XZ} L_{ZZ}^{-T}`` and ``D = I`` (i.e. ``UU^T``),
    with ``L_{ZZ} = cholesky(K_{ZZ})``.

    Args:
        K_XZ: Cross-covariance between data and inducing points,
            shape ``(N, M)``.
        K_ZZ_op: Inducing-point covariance operator, shape ``(M, M)``.

    Returns:
        `LowRankUpdate` operator of shape ``(N, N)``.
    """

    L = cholesky(K_ZZ_op)
    # U = K_XZ @ L^{-T} = solve(L, K_XZ^T)^T
    # Solve L @ A_col = K_XZ^T_col for each column
    from gaussx._linalg._linalg import solve_columns

    K_ZX = K_XZ.T  # (M, N)
    A = solve_columns(L, K_ZX)
    U = A.T  # (N, M)

    N = K_XZ.shape[0]
    M = K_XZ.shape[1]
    base = lx.DiagonalLinearOperator(jnp.zeros(N))
    D = jnp.ones(M)
    return LowRankUpdate(
        base=base,
        U=U,
        d=D,
        V=U,
        tags=frozenset({lx.symmetric_tag, lx.positive_semidefinite_tag}),
    )
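The factorization can be checked without gaussx. This is a minimal sketch under illustrative assumptions: a hand-written RBF kernel, five evenly spaced inducing points, and a 1e-6 jitter on K_ZZ, with `jax.scipy.linalg.solve_triangular` standing in for gaussx's internal `solve_columns`. It forms \(U = K_{XZ} L_{ZZ}^{-T}\) and compares \(U U^T\) with a direct \(K_{XZ} K_{ZZ}^{-1} K_{ZX}\).

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def rbf(A, B, lengthscale=1.0):
    """Squared-exponential kernel between row sets A (N, D) and B (M, D)."""
    sq_dist = jnp.sum((A[:, None, :] - B[None, :, :]) ** 2, axis=-1)
    return jnp.exp(-0.5 * sq_dist / lengthscale**2)


X = jax.random.uniform(jax.random.PRNGKey(0), (50, 1), minval=0.0, maxval=5.0)
Z = jnp.linspace(0.0, 5.0, 5)[:, None]  # inducing points (illustrative choice)

K_XZ = rbf(X, Z)                                # (N, M)
K_ZZ = rbf(Z, Z) + 1e-6 * jnp.eye(Z.shape[0])   # jitter is an assumption
L = jnp.linalg.cholesky(K_ZZ)                   # K_ZZ = L L^T

# U = K_XZ L^{-T}: solve L A = K_ZX column by column, then transpose.
U = solve_triangular(L, K_XZ.T, lower=True).T   # (N, M)
K_nystrom = U @ U.T                             # (N, N), rank <= M

# Reference: K_XZ K_ZZ^{-1} K_ZX via a direct dense solve.
K_direct = K_XZ @ jnp.linalg.solve(K_ZZ, K_XZ.T)
```

Because U has only M columns, \(U U^T\) has rank at most M, and its diagonal never exceeds the exact kernel diagonal.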

rff_operator(X: Float[Array, 'N D'], omega: Float[Array, 'D_rff D'], b: Float[Array, ' D_rff']) -> LowRankUpdate

Random Fourier Features kernel approximation.

Approximates \(K_{XX} \approx \Phi \Phi^T\), where:

\[\Phi_{i,j} = \sqrt{2/D_{rff}} \cos(X_i \cdot \omega_j + b_j)\]

Returns a LowRankUpdate that never materializes the N × N matrix.

Parameters:

Name Type Description Default
X Float[Array, 'N D']

Data points, shape (N, D).

required
omega Float[Array, 'D_rff D']

Random frequencies, shape (D_rff, D). Sample from the spectral density of the kernel.

required
b Float[Array, ' D_rff']

Random phase offsets, shape (D_rff,). Sample uniformly from [0, 2*pi].

required

Returns:

Type Description
LowRankUpdate

LowRankUpdate operator of shape (N, N).

Source code in src/gaussx/_kernels/_kernel_approx.py
def rff_operator(
    X: Float[Array, "N D"],
    omega: Float[Array, "D_rff D"],
    b: Float[Array, " D_rff"],
) -> LowRankUpdate:
    r"""Random Fourier Features kernel approximation.

    Approximates ``K_{XX} \approx \Phi \Phi^T`` where:

        \Phi_{i,j} = \sqrt{2/D_{rff}} \cos(X_i \cdot \omega_j + b_j)

    Returns a `LowRankUpdate` that never materializes
    the ``N x N`` matrix.

    Args:
        X: Data points, shape ``(N, D)``.
        omega: Random frequencies, shape ``(D_rff, D)``.
            Sample from the spectral density of the kernel.
        b: Random phase offsets, shape ``(D_rff,)``.
            Sample uniformly from ``[0, 2*pi]``.

    Returns:
        `LowRankUpdate` operator of shape ``(N, N)``.
    """
    D_rff = omega.shape[0]
    N = X.shape[0]
    Phi = jnp.sqrt(2.0 / D_rff) * jnp.cos(X @ omega.T + b[None, :])  # (N, D_rff)

    base = lx.DiagonalLinearOperator(jnp.zeros(N))
    D = jnp.ones(D_rff)
    return LowRankUpdate(
        base=base,
        U=Phi,
        d=D,
        V=Phi,
        tags=frozenset({lx.symmetric_tag, lx.positive_semidefinite_tag}),
    )
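A self-contained check of the RFF formula, assuming a unit-variance RBF kernel with lengthscale ℓ, whose normalized spectral density is \(\mathcal{N}(0, \ell^{-2} I)\). The sizes, lengthscale, and seed are illustrative. It builds Φ the same way rff_operator does and compares \(\Phi \Phi^T\) with the exact Gram matrix; the Monte Carlo error shrinks roughly like \(1/\sqrt{D_{rff}}\).

```python
import jax
import jax.numpy as jnp

N, D, D_rff, lengthscale = 20, 2, 4000, 0.8  # illustrative sizes
key_x, key_w, key_b = jax.random.split(jax.random.PRNGKey(0), 3)
X = jax.random.normal(key_x, (N, D))

# Unit-variance RBF: spectral density is N(0, lengthscale^-2 I).
omega = jax.random.normal(key_w, (D_rff, D)) / lengthscale
b = jax.random.uniform(key_b, (D_rff,), minval=0.0, maxval=2.0 * jnp.pi)

# Feature map, matching the formula in rff_operator.
Phi = jnp.sqrt(2.0 / D_rff) * jnp.cos(X @ omega.T + b[None, :])  # (N, D_rff)
K_rff = Phi @ Phi.T

sq_dist = jnp.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)
K_exact = jnp.exp(-0.5 * sq_dist / lengthscale**2)
max_err = jnp.max(jnp.abs(K_rff - K_exact))
```

This approximates the kernel normalized to \(k(x, x) = 1\); for an output scale \(\sigma^2\), multiply Φ by σ.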

EigenPro preconditioning

Spectral preconditioning for kernel stochastic gradient descent: damp the top eigendirections of the kernel operator so the step size is governed by the residual spectrum (Ma & Belkin, 2017).


EigenProPreconditioner

Bases: Module

Spectral correction state for EigenPro kernel SGD.

Stores the top eigenspace of K_mm / m on a subsample and the corresponding EigenPro correction weights.

Attributes:

Name Type Description
V Float[Array, 'm k']

Top eigenvectors of K_mm / m, shape (m, k).

D Float[Array, ' k']

Correction weights, shape (k,).

subsample_indices Int[Array, ' m']

Indices used for the subsample, shape (m,).

max_eigenvalue Float[Array, '']

Largest eigenvalue of K_mm / m.

beta Float[Array, '']

Maximum residual kernel diagonal used for the step size.

Source code in src/gaussx/_kernels/_eigenpro.py
class EigenProPreconditioner(eqx.Module):
    r"""Spectral correction state for EigenPro kernel SGD.

    Stores the top eigenspace of ``K_mm / m`` on a subsample and the
    corresponding EigenPro correction weights.

    Attributes:
        V: Top eigenvectors of ``K_mm / m``, shape ``(m, k)``.
        D: Correction weights, shape ``(k,)``.
        subsample_indices: Indices used for the subsample, shape ``(m,)``.
        max_eigenvalue: Largest eigenvalue of ``K_mm / m``.
        beta: Maximum residual kernel diagonal used for the step size.
    """

    V: Float[Array, "m k"]
    D: Float[Array, " k"]
    subsample_indices: Int[Array, " m"]
    max_eigenvalue: Float[Array, ""]
    beta: Float[Array, ""]

eigenpro_preconditioner(kernel_op: lx.AbstractLinearOperator, *, subsample_size: int = 4000, n_components: int = 100, alpha: float = 0.95, key: jax.Array | None = None) -> EigenProPreconditioner

Build an EigenPro spectral preconditioner from a kernel operator.

The helper samples m rows/columns, eigendecomposes K_mm / m, and stores the top-k eigenspace correction D_i = (1 - (λ_{k+1} / λ_i)^α) / λ_i.

Parameters:

Name Type Description Default
kernel_op AbstractLinearOperator

Square kernel linear operator.

required
subsample_size int

Number of points in the eigendecomposition subsample.

4000
n_components int

Number of leading eigenvectors to keep.

100
alpha float

Spectral decay exponent in (0, 1].

0.95
key Array | None

Optional PRNG key. If omitted, the first subsample_size points are used deterministically.

None

Returns:

Type Description
EigenProPreconditioner

EigenPro preconditioner state for kernel SGD.

Source code in src/gaussx/_kernels/_eigenpro.py
def eigenpro_preconditioner(
    kernel_op: lx.AbstractLinearOperator,
    *,
    subsample_size: int = 4000,
    n_components: int = 100,
    alpha: float = 0.95,
    key: jax.Array | None = None,
) -> EigenProPreconditioner:
    r"""Build an EigenPro spectral preconditioner from a kernel operator.

    The helper samples ``m`` rows/columns, eigendecomposes ``K_mm / m``, and
    stores the top-``k`` eigenspace correction
    ``D_i = (1 - (λ_{k+1} / λ_i)^α) / λ_i``.

    Args:
        kernel_op: Square kernel linear operator.
        subsample_size: Number of points in the eigendecomposition subsample.
        n_components: Number of leading eigenvectors to keep.
        alpha: Spectral decay exponent in ``(0, 1]``.
        key: Optional PRNG key. If omitted, the first ``subsample_size`` points
            are used deterministically.

    Returns:
        EigenPro preconditioner state for kernel SGD.
    """

    if subsample_size <= 1:
        raise ValueError("subsample_size must be greater than 1.")
    if n_components <= 0:
        raise ValueError("n_components must be positive.")
    if n_components >= subsample_size:
        raise ValueError("n_components must be smaller than subsample_size.")
    if not 0.0 < alpha <= 1.0:
        raise ValueError("alpha must be in (0, 1].")
    if not lx.is_symmetric(kernel_op):
        raise ValueError(
            "kernel_op must be symmetric (tag with ``lx.symmetric_tag`` or "
            "``lx.positive_semidefinite_tag``). EigenPro relies on K_mm and "
            "the data-side diagonal coming from the same symmetric kernel."
        )

    in_shape = kernel_op.in_structure().shape
    out_shape = kernel_op.out_structure().shape
    if len(in_shape) != 1 or len(out_shape) != 1 or in_shape != out_shape:
        raise ValueError("kernel_op must be an unbatched square operator.")

    n = in_shape[0]
    if subsample_size > n:
        raise ValueError("subsample_size cannot exceed the operator size.")

    subsample_indices = _subsample_indices(n, subsample_size, key)
    K_mm = _subsample_matrix(kernel_op, subsample_indices)
    m = K_mm.shape[0]
    K_mm_scaled = symmetrize(K_mm) / m

    # Partial eigendecomposition: top (n_components + 1) eigenpairs via
    # matfree Lanczos. The "+1" gives us the tail eigenvalue needed for
    # the EigenPro correction weights. Falls back to dense eigh for very
    # small ``m`` where Lanczos isn't worth the overhead.
    rank = n_components + 1
    if rank >= m:
        eigvals_all, eigvecs = jnp.linalg.eigh(K_mm_scaled)
        eigvals_all = eigvals_all[::-1]
        eigvecs = eigvecs[:, ::-1]
    else:
        K_mm_op = lx.MatrixLinearOperator(K_mm_scaled, lx.positive_semidefinite_tag)
        eig_key = key if key is not None else jr.PRNGKey(0)
        raw_vals, raw_vecs = eig(K_mm_op, rank=rank, key=eig_key)
        # eig may return complex/unsorted for partial Lanczos; project
        # back to real and sort descending so we keep the top-k.
        eigvals_all = jnp.real(raw_vals)
        eigvecs = jnp.real(raw_vecs)
        order = jnp.argsort(eigvals_all)[::-1]
        eigvals_all = eigvals_all[order]
        eigvecs = eigvecs[:, order]

    top_eigvals = eigvals_all[:n_components]
    tail_eigenvalue = eigvals_all[n_components]
    eps = jnp.finfo(K_mm.dtype).eps
    safe_top_eigvals = jnp.maximum(top_eigvals, eps)
    safe_tail_eigenvalue = jnp.maximum(tail_eigenvalue, eps)
    ratio = jnp.minimum(safe_tail_eigenvalue / safe_top_eigvals, 1.0)
    weights = (1.0 - ratio**alpha) / safe_top_eigvals

    V = eigvecs[:, :n_components]
    beta = _residual_kernel_diagonal(kernel_op, subsample_indices, V, safe_top_eigvals)

    return EigenProPreconditioner(
        V=V,
        D=weights,
        subsample_indices=subsample_indices,
        max_eigenvalue=safe_top_eigvals[0],
        beta=beta,
    )
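To see what the correction weights do, the sketch below uses a dense `jnp.linalg.eigh` in place of the Lanczos path, an illustrative 1-D RBF kernel, and α = 1. It forms \(D_i = (1 - (\lambda_{k+1}/\lambda_i)^\alpha)/\lambda_i\) from the top-k eigenpairs of \(K_{mm}/m\) and applies \(P = I - V \mathrm{diag}(\lambda_i D_i) V^T\), which is assumed here to be the standard EigenPro preconditioner. Each top eigenvalue \(\lambda_i\) maps to \(\lambda_i^{1-\alpha} \lambda_{k+1}^\alpha\), so with α = 1 the top-k spectrum collapses onto \(\lambda_{k+1}\). This illustrates the spectral effect only; it is not gaussx's code path.

```python
import jax
import jax.numpy as jnp

m, k, alpha = 200, 5, 1.0  # subsample size, kept components, decay exponent (illustrative)
x = jax.random.uniform(jax.random.PRNGKey(0), (m,), minval=0.0, maxval=10.0)
K_mm = jnp.exp(-0.5 * (x[:, None] - x[None, :]) ** 2)  # 1-D RBF, lengthscale 1
A = 0.5 * (K_mm + K_mm.T) / m                           # symmetrized K_mm / m

eigvals, eigvecs = jnp.linalg.eigh(A)                   # ascending order
eigvals, eigvecs = eigvals[::-1], eigvecs[:, ::-1]      # sort descending

top, tail = eigvals[:k], eigvals[k]                     # lambda_1..k and lambda_{k+1}
D = (1.0 - (tail / top) ** alpha) / top                 # EigenPro correction weights
V = eigvecs[:, :k]

# Assumed EigenPro form: P = I - V diag(lambda_i * D_i) V^T
P = jnp.eye(m) - (V * (top * D)) @ V.T
PA = P @ A
new_eigs = jnp.linalg.eigvalsh(0.5 * (PA + PA.T))

print(float(eigvals[0]), float(tail), float(new_eigs.max()))
```

The largest eigenvalue drops from \(\lambda_1\) to \(\lambda_{k+1}\), which is what allows a larger stable step size.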

eigenpro_step_size(precond: EigenProPreconditioner, batch_size: int | Float[Array, '']) -> Float[Array, '']

Return the EigenPro preconditioned step size for a batch size.

batch_size may be a Python int or a scalar JAX array; the function is JIT-friendly with either. Batch sizes below 1 raise an error: eagerly (ValueError) when batch_size is a Python int, and at runtime via eqx.error_if when it is a JAX array.

Parameters:

Name Type Description Default
precond EigenProPreconditioner

The EigenPro preconditioner carrying beta and the top eigenvalue.

required
batch_size int | Float[Array, '']

Mini-batch size, as a Python int or scalar JAX array.

required

Returns:

Type Description
Float[Array, '']

Scalar preconditioned step size.

Raises:

Type Description
ValueError

If batch_size is a Python int below 1.

Source code in src/gaussx/_kernels/_eigenpro.py
def eigenpro_step_size(
    precond: EigenProPreconditioner,
    batch_size: int | Float[Array, ""],
) -> Float[Array, ""]:
    r"""Return the EigenPro preconditioned step size for a batch size.

    ``batch_size`` may be a Python int or a scalar JAX array; the
    function is JIT-friendly with either. Non-positive batch sizes raise
    (eagerly when ``batch_size`` is a Python int; traced via
    ``eqx.error_if`` when it is a JAX array).

    Args:
        precond: The EigenPro preconditioner carrying ``beta`` and the top
            eigenvalue.
        batch_size: Mini-batch size, as a Python int or scalar JAX array.

    Returns:
        Scalar preconditioned step size.

    Raises:
        ValueError: If ``batch_size`` is a Python int below 1.
    """
    if isinstance(batch_size, int) and batch_size < 1:
        raise ValueError("batch_size must be a positive integer.")
    batch = jnp.asarray(batch_size, dtype=precond.beta.dtype)
    batch = eqx.error_if(batch, batch < 1.0, "batch_size must be >= 1.")
    beta = precond.beta
    lambda_1 = precond.max_eigenvalue
    return jnp.where(
        batch < beta / lambda_1,
        batch / beta,
        (2.0 * batch) / (beta + (batch - 1.0) * lambda_1),
    )
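The piecewise rule in eigenpro_step_size can be traced by hand. This plain-Python worked example mirrors the two branches of the source; the β and λ₁ values are made up, giving a critical batch size \(m^* = \beta / \lambda_1 = 20\).

```python
def eigenpro_step(batch: float, beta: float, lambda_1: float) -> float:
    """Piecewise EigenPro step size, mirroring the rule in eigenpro_step_size."""
    if batch < beta / lambda_1:  # below the critical batch size m* = beta / lambda_1
        return batch / beta      # linear-scaling regime
    return 2.0 * batch / (beta + (batch - 1.0) * lambda_1)  # saturating regime


beta, lambda_1 = 2.0, 0.1  # illustrative values; critical batch size is 20
for b in (10, 20, 100, 10_000):
    print(b, eigenpro_step(b, beta, lambda_1))
```

Below \(m^*\) the step grows linearly with batch size; above it the step approaches \(2/\lambda_1\) (here 20), so larger batches give diminishing returns.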

eigenpro_correction(precond: EigenProPreconditioner, K_batch_sub: Float[Array, 'B m'], gradient: Float[Array, 'B C'], step_size: float | Float[Array, '']) -> Float[Array, 'm C']

Compute the EigenPro eigenspace correction for one SGD mini-batch.

step_size may be a Python float or a scalar JAX array. Passing a JAX scalar keeps the value traced, so changing it between training steps does not trigger recompilation, including under eqx.filter_jit, which holds Python scalars static.

Parameters:

Name Type Description Default
precond EigenProPreconditioner

The EigenPro preconditioner carrying the top eigenvectors V and spectral weights D.

required
K_batch_sub Float[Array, 'B m']

Kernel block between the mini-batch and the subsampled points, shape (B, m).

required
gradient Float[Array, 'B C']

Per-example gradients for the mini-batch, shape (B, C).

required
step_size float | Float[Array, '']

Scalar step size, as a Python float or scalar JAX array.

required

Returns:

Type Description
Float[Array, 'm C']

Eigenspace correction over the subsampled points, shape (m, C).

Source code in src/gaussx/_kernels/_eigenpro.py
def eigenpro_correction(
    precond: EigenProPreconditioner,
    K_batch_sub: Float[Array, "B m"],
    gradient: Float[Array, "B C"],
    step_size: float | Float[Array, ""],
) -> Float[Array, "m C"]:
    r"""Compute the EigenPro eigenspace correction for one SGD mini-batch.

    ``step_size`` may be a Python float or a scalar JAX array. Passing a
    JAX scalar avoids recompilation under ``jax.jit`` when the value
    changes between training steps.

    Args:
        precond: The EigenPro preconditioner carrying the top eigenvectors
            ``V`` and spectral weights ``D``.
        K_batch_sub: Kernel block between the mini-batch and the subsampled
            points, shape ``(B, m)``.
        gradient: Per-example gradients for the mini-batch, shape ``(B, C)``.
        step_size: Scalar step size, as a Python float or scalar JAX array.

    Returns:
        Eigenspace correction over the subsampled points, shape ``(m, C)``.
    """
    step_size_arr = jnp.asarray(step_size, dtype=gradient.dtype)
    projected_gradient = precond.V.T @ (K_batch_sub.T @ gradient)
    weight_shape = (-1,) + (1,) * (projected_gradient.ndim - 1)
    weighted_gradient = precond.D.reshape(weight_shape) * projected_gradient
    return step_size_arr * (precond.V @ weighted_gradient)
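The contraction order in eigenpro_correction avoids ever forming an m × m matrix. The sketch below uses random stand-ins for V, D, and the kernel block (all illustrative, not outputs of eigenpro_preconditioner) and checks that the matrix-free order agrees with the dense product \(\eta\, V \mathrm{diag}(D) V^T K^T g\).

```python
import jax
import jax.numpy as jnp

B, m, k, C = 32, 100, 8, 3  # batch, subsample, components, outputs (illustrative)
keys = jax.random.split(jax.random.PRNGKey(0), 4)
V, _ = jnp.linalg.qr(jax.random.normal(keys[0], (m, k)))  # orthonormal stand-in eigenvectors
D = jax.random.uniform(keys[1], (k,))                      # stand-in correction weights
K_batch_sub = jax.random.normal(keys[2], (B, m))           # stand-in kernel block
gradient = jax.random.normal(keys[3], (B, C))
step_size = 0.5

# Same contraction order as eigenpro_correction: O(B m C + m k C) work.
projected = V.T @ (K_batch_sub.T @ gradient)               # (k, C)
correction = step_size * (V @ (D[:, None] * projected))    # (m, C)

# Dense reference: forms the (m, m) matrix V diag(D) V^T explicitly.
dense = step_size * (V @ jnp.diag(D) @ V.T) @ (K_batch_sub.T @ gradient)
```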

Kernel statistics

Centering, the Hilbert-Schmidt independence criterion, and maximum mean discrepancy.


center_kernel(K: lx.AbstractLinearOperator) -> lx.MatrixLinearOperator

Center a kernel matrix: H K H.

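Computes the centered Gram matrix \(H K H\), where \(H = I - (1/n) \mathbf{1}\mathbf{1}^T\), by subtracting row and column means and adding back the grand mean. The input is materialized as a dense n × n matrix. Used in kernel PCA, HSIC, and other centered kernel methods.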

Parameters:

Name Type Description Default
K AbstractLinearOperator

Kernel (Gram) matrix operator, shape (n, n).

required

Returns:

Type Description
MatrixLinearOperator

Centered kernel operator, shape (n, n).

Source code in src/gaussx/_kernels/_kernel_approx.py
def center_kernel(
    K: lx.AbstractLinearOperator,
) -> lx.MatrixLinearOperator:
    r"""Center a kernel matrix: ``H K H``.

    Computes the centered Gram matrix where ``H = I - (1/n) 11^T``.
    Used in kernel PCA, HSIC, and other centered kernel methods.

    Args:
        K: Kernel (Gram) matrix operator, shape ``(n, n)``.

    Returns:
        Centered kernel operator, shape ``(n, n)``.
    """
    K_mat = K.as_matrix()
    row_mean = jnp.mean(K_mat, axis=1, keepdims=True)
    col_mean = jnp.mean(K_mat, axis=0, keepdims=True)
    total_mean = jnp.mean(K_mat)
    K_centered = K_mat - row_mean - col_mean + total_mean
    tags = frozenset()
    if lx.is_symmetric(K):
        tags = frozenset({lx.symmetric_tag})
    return lx.MatrixLinearOperator(K_centered, tags)
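The mean-subtraction formula used by center_kernel is algebraically identical to \(H K H\). A short check with an illustrative RBF Gram matrix:

```python
import jax
import jax.numpy as jnp

n = 6
X = jax.random.normal(jax.random.PRNGKey(0), (n, 2))
K = jnp.exp(-0.5 * jnp.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1))  # RBF Gram

# Mean-subtraction form, as in center_kernel
K_c = K - K.mean(axis=1, keepdims=True) - K.mean(axis=0, keepdims=True) + K.mean()

# Explicit H K H with H = I - (1/n) 1 1^T
H = jnp.eye(n) - jnp.ones((n, n)) / n
K_c_explicit = H @ K @ H
```

Every row and column of the centered matrix sums to zero, which is a quick sanity check.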

centering_operator(n: int) -> LowRankUpdate

Centering matrix \(H = I - (1/n) \mathbf{1}\mathbf{1}^T\).

Returned as a LowRankUpdate so the structure is preserved for downstream operations like H K H.

Parameters:

Name Type Description Default
n int

Dimension of the centering matrix.

required

Returns:

Type Description
LowRankUpdate

LowRankUpdate operator of shape (n, n).

Source code in src/gaussx/_kernels/_kernel_approx.py
def centering_operator(n: int) -> LowRankUpdate:
    r"""Centering matrix ``H = I - (1/n) \mathbf{1}\mathbf{1}^T``.

    Returns as a `LowRankUpdate` so the structure is
    preserved for downstream operations like ``H K H``.

    Args:
        n: Dimension of the centering matrix.

    Returns:
        `LowRankUpdate` operator of shape ``(n, n)``.
    """
    base = lx.DiagonalLinearOperator(jnp.ones(n))
    ones = jnp.ones((n, 1))
    D = jnp.array([-1.0 / n])
    return LowRankUpdate(base=base, U=ones, d=D, V=ones, tags=frozenset())
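With base I, U = V = 1, and d = −1/n, a matvec with H costs O(n) and simply removes the mean. This sketch writes out the rank-1 update by hand in plain JAX; the helper name `centering_matvec` is illustrative, not part of gaussx.

```python
import jax.numpy as jnp

n = 5
x = jnp.arange(n, dtype=jnp.float32)
ones = jnp.ones((n, 1))     # U = V = 1
d = jnp.array([-1.0 / n])   # diagonal of the rank-1 update


def centering_matvec(v):
    # (I + U diag(d) V^T) v, computed in O(n) without forming H
    return v + ones @ (d * (ones.T @ v))


Hx = centering_matvec(x)     # x with its mean removed
HHx = centering_matvec(Hx)   # H is idempotent: H (H x) = H x
```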

hsic(K_f: lx.AbstractLinearOperator, K_q: lx.AbstractLinearOperator) -> Float[Array, '']

Biased HSIC estimator.

Computes the Hilbert-Schmidt Independence Criterion:

\[\mathrm{HSIC} = \frac{1}{n^2} \mathrm{tr}(K_f H K_q H)\]

where \(H = I - (1/n) \mathbf{1}\mathbf{1}^T\) is the centering matrix.

Parameters:

Name Type Description Default
K_f AbstractLinearOperator

First kernel matrix, shape (n, n).

required
K_q AbstractLinearOperator

Second kernel matrix, shape (n, n).

required

Returns:

Type Description
Float[Array, '']

Scalar HSIC estimate.

Source code in src/gaussx/_kernels/_kernel_approx.py
def hsic(
    K_f: lx.AbstractLinearOperator,
    K_q: lx.AbstractLinearOperator,
) -> Float[Array, ""]:
    r"""Biased HSIC estimator.

    Computes the Hilbert-Schmidt Independence Criterion:

        HSIC = (1/n^2) \mathrm{tr}(K_f H K_q H)

    where ``H = I - (1/n) 11^T`` is the centering matrix.

    Args:
        K_f: First kernel matrix, shape ``(n, n)``.
        K_q: Second kernel matrix, shape ``(n, n)``.

    Returns:
        Scalar HSIC estimate.
    """
    from gaussx._linalg._linalg import trace_product

    K_f_centered = center_kernel(K_f)
    K_q_centered = center_kernel(K_q)
    n = K_f_centered.as_matrix().shape[0]
    return trace_product(K_f_centered, K_q_centered) / (n * n)
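A library-free sketch of the biased HSIC, assuming RBF kernels with lengthscale 1 and synthetic data. Because both centered matrices are symmetric, \(\mathrm{tr}(AB) = \sum_{ij} A_{ij} B_{ij}\), which gives an O(n²) trace without a matrix product. The example uses \(y = x^2\) plus noise: y depends on x, yet the two are uncorrelated in population, so linear correlation misses what HSIC detects.

```python
import jax
import jax.numpy as jnp


def rbf(a, b, lengthscale=1.0):
    """1-D squared-exponential kernel matrix."""
    return jnp.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / lengthscale**2)


def center(K):
    # H K H via row/column/grand means, as in center_kernel
    return K - K.mean(axis=1, keepdims=True) - K.mean(axis=0, keepdims=True) + K.mean()


def hsic_biased(K_f, K_q):
    n = K_f.shape[0]
    # tr(A B) = sum(A * B) for symmetric A, B
    return jnp.sum(center(K_f) * center(K_q)) / (n * n)


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
n = 200
x = jax.random.normal(k1, (n,))
y_dep = x**2 + 0.1 * jax.random.normal(k2, (n,))  # depends on x, uncorrelated in population
y_ind = jax.random.normal(k3, (n,))               # independent of x

K_x = rbf(x, x)
hsic_dep = hsic_biased(K_x, rbf(y_dep, y_dep))
hsic_ind = hsic_biased(K_x, rbf(y_ind, y_ind))
```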

mmd_squared(K_xx: Float[Array, 'Nx Nx'], K_yy: Float[Array, 'Ny Ny'], K_xy: Float[Array, 'Nx Ny']) -> Float[Array, '']

Biased squared Maximum Mean Discrepancy.

Computes:

\[\mathrm{MMD}^2 = \mathrm{mean}(K_{xx}) + \mathrm{mean}(K_{yy}) - 2 \cdot \mathrm{mean}(K_{xy})\]

Parameters:

Name Type Description Default
K_xx Float[Array, 'Nx Nx']

Kernel matrix within first sample, shape (Nx, Nx).

required
K_yy Float[Array, 'Ny Ny']

Kernel matrix within second sample, shape (Ny, Ny).

required
K_xy Float[Array, 'Nx Ny']

Cross-kernel matrix between samples, shape (Nx, Ny).

required

Returns:

Type Description
Float[Array, '']

Scalar biased MMD^2 estimate.

Source code in src/gaussx/_kernels/_kernel_approx.py
def mmd_squared(
    K_xx: Float[Array, "Nx Nx"],
    K_yy: Float[Array, "Ny Ny"],
    K_xy: Float[Array, "Nx Ny"],
) -> Float[Array, ""]:
    r"""Biased squared Maximum Mean Discrepancy.

    Computes:

        MMD^2 = mean(K_{xx}) + mean(K_{yy}) - 2 \cdot mean(K_{xy})

    Args:
        K_xx: Kernel matrix within first sample, shape ``(Nx, Nx)``.
        K_yy: Kernel matrix within second sample, shape ``(Ny, Ny)``.
        K_xy: Cross-kernel matrix between samples, shape ``(Nx, Ny)``.

    Returns:
        Scalar biased MMD^2 estimate.
    """
    return jnp.mean(K_xx) + jnp.mean(K_yy) - 2.0 * jnp.mean(K_xy)
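A small two-sample illustration that reimplements the same formula rather than calling gaussx; the RBF kernel (lengthscale 1), sample sizes, and mean shift of 1.5 are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def rbf(a, b, lengthscale=1.0):
    """1-D squared-exponential kernel matrix between samples a and b."""
    return jnp.exp(-0.5 * (a[:, None] - b[None, :]) ** 2 / lengthscale**2)


def mmd2_biased(x, y):
    # Same V-statistic as mmd_squared: mean(K_xx) + mean(K_yy) - 2 mean(K_xy)
    return rbf(x, x).mean() + rbf(y, y).mean() - 2.0 * rbf(x, y).mean()


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (300,))
y_same = jax.random.normal(k2, (300,))         # same distribution as x
y_shift = jax.random.normal(k3, (300,)) + 1.5  # mean shifted by 1.5

mmd_same = mmd2_biased(x, y_same)
mmd_shift = mmd2_biased(x, y_shift)
```

The biased estimate is the squared RKHS distance between empirical mean embeddings, so it is non-negative up to rounding, but it carries a positive bias of order 1/n even when the two distributions match.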

Grids & interpolation


create_grid(grid_sizes: list[int] | tuple[int, ...], grid_bounds: list[tuple[float, float]] | tuple[tuple[float, float], ...]) -> list[Float[Array, ' n']]

Create a regular grid from per-dimension sizes and bounds.

Parameters:

Name Type Description Default
grid_sizes list[int] | tuple[int, ...]

Number of points per dimension, length D.

required
grid_bounds list[tuple[float, float]] | tuple[tuple[float, float], ...]

(lo, hi) bounds per dimension, length D.

required

Returns:

Type Description
list[Float[Array, ' n']]

List of 1-D arrays, one per dimension.

Source code in src/gaussx/_kernels/_grid.py
def create_grid(
    grid_sizes: list[int] | tuple[int, ...],
    grid_bounds: list[tuple[float, float]] | tuple[tuple[float, float], ...],
) -> list[Float[Array, " n"]]:
    """Create a regular grid from per-dimension sizes and bounds.

    Args:
        grid_sizes: Number of points per dimension, length D.
        grid_bounds: (lo, hi) bounds per dimension, length D.

    Returns:
        List of 1-D arrays, one per dimension.
    """
    return [
        jnp.linspace(lo, hi, n)
        for n, (lo, hi) in zip(grid_sizes, grid_bounds, strict=True)
    ]

grid_data(grid: list[Float[Array, ' n']]) -> Float[Array, 'G D']

Expand a grid to the full Cartesian product of grid points.

Parameters:

Name Type Description Default
grid list[Float[Array, ' n']]

List of 1-D arrays, one per dimension.

required

Returns:

Type Description
Float[Array, 'G D']

Shape (prod(sizes), D) array of all grid points.

Source code in src/gaussx/_kernels/_grid.py
def grid_data(grid: list[Float[Array, " n"]]) -> Float[Array, "G D"]:
    """Expand a grid to the full Cartesian product of grid points.

    Args:
        grid: List of 1-D arrays, one per dimension.

    Returns:
        Shape ``(prod(sizes), D)`` array of all grid points.
    """
    meshes = jnp.meshgrid(*grid, indexing="ij")  # D arrays of shape sizes
    stacked = jnp.stack(meshes, axis=0)  # (D, *sizes)
    return rearrange(stacked, "D ... -> (...) D")

cubic_interpolation_weights(x_target: Float[Array, 'B D'], grid: list[Float[Array, ' n']]) -> tuple[Int[Array, 'B K'], Float[Array, 'B K']]

Compute cubic interpolation indices and weights for SKI.

For each target point, finds the 4ᴰ nearest grid neighbors and computes cubic convolution interpolation weights.

Parameters:

Name Type Description Default
x_target Float[Array, 'B D']

Target points, shape (B, D).

required
grid list[Float[Array, ' n']]

List of 1-D arrays, one per dimension (length D). Each dimension must have ≥ 4 points.

required

Returns:

Type Description
tuple[Int[Array, 'B K'], Float[Array, 'B K']]

Tuple (indices, weights), both of shape (B, 4ᴰ):
  • indices: Flat indices into the grid (product of sizes).
  • weights: Interpolation weights summing to ≈ 1.

Raises:

Type Description
ValueError

If any grid dimension has fewer than 4 points.

Source code in src/gaussx/_kernels/_grid.py
def cubic_interpolation_weights(
    x_target: Float[Array, "B D"],
    grid: list[Float[Array, " n"]],
) -> tuple[Int[Array, "B K"], Float[Array, "B K"]]:
    """Compute cubic interpolation indices and weights for SKI.

    For each target point, finds the 4ᴰ nearest grid neighbors and
    computes cubic convolution interpolation weights.

    Args:
        x_target: Target points, shape ``(B, D)``.
        grid: List of 1-D arrays, one per dimension (length D).
            Each dimension must have ≥ 4 points.

    Returns:
        Tuple ``(indices, weights)`` both of shape ``(B, 4ᴰ)``:

        - ``indices``: Flat indices into the grid (product of sizes).
        - ``weights``: Interpolation weights summing to ≈ 1.

    Raises:
        ValueError: If any grid dimension has fewer than 4 points.
    """
    D = len(grid)
    grid_sizes = [g.shape[0] for g in grid]

    for d, n in enumerate(grid_sizes):
        if n < 4:
            msg = (
                f"Cubic interpolation requires at least 4 grid points per "
                f"dimension, but dimension {d} has {n}."
            )
            raise ValueError(msg)

    # Strides for manual flat indexing: strides[d] = ∏ sizes[d+1:]
    strides = []
    s = 1
    for d in range(D - 1, -1, -1):
        strides.append(s)
        s *= grid_sizes[d]
    strides = strides[::-1]

    # Per-dimension: 4 indices and weights
    dim_indices = []  # each (B, 4)
    dim_weights = []  # each (B, 4)

    for d in range(D):
        g = grid[d]
        n = g.shape[0]
        x_d = x_target[:, d]  # (B,)

        h = (g[-1] - g[0]) / (n - 1)  # cell spacing
        cont_idx = (x_d - g[0]) / h  # continuous index, (B,)

        # Integer cell index, clamped for 4 neighbors
        cell = jnp.floor(cont_idx).astype(jnp.int32)
        cell = jnp.clip(cell, 1, n - 3)  # (B,)

        t = cont_idx - cell.astype(x_d.dtype)  # fractional part, (B,)

        # 4 neighbor indices: cell−1, cell, cell+1, cell+2
        idx = jnp.stack([cell - 1, cell, cell + 1, cell + 2], axis=-1)  # (B, 4)
        idx = jnp.clip(idx, 0, n - 1)

        dim_indices.append(idx)
        dim_weights.append(_cubic_weights_1d(t))

    # Cross-dimensional outer product of indices/weights
    flat_indices = dim_indices[0] * strides[0]  # (B, 4)
    weights = dim_weights[0]  # (B, 4)

    for d in range(1, D):
        # Outer product: (B, K) × (B, 4) → (B, K·4)
        fi_exp = flat_indices[:, :, None]  # (B, K, 1)
        di_exp = dim_indices[d][:, None, :] * strides[d]  # (B, 1, 4)
        flat_indices = rearrange(fi_exp + di_exp, "B K four -> B (K four)")

        w_exp = weights[:, :, None]  # (B, K, 1)
        dw_exp = dim_weights[d][:, None, :]  # (B, 1, 4)
        weights = rearrange(w_exp * dw_exp, "B K four -> B (K four)")

    return flat_indices, weights
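The private helper _cubic_weights_1d is not shown on this page. The sketch below assumes it is the Keys cubic-convolution kernel with a = −0.5, written as polynomials in the fractional offset t; `keys_weights` and `interp_1d` are hypothetical names. Under that assumption the weights satisfy \(\sum w = 1\), \(\sum w_k k = t\), and \(\sum w_k k^2 = t^2\) for every t. That is why clamping the cell index near the boundary, which pushes t outside [0, 1], still reproduces quadratics. The code mirrors the per-dimension steps of cubic_interpolation_weights for one dimension.

```python
import jax.numpy as jnp


def keys_weights(t):
    """Assumed Keys cubic-convolution weights (a = -0.5) for offsets -1, 0, 1, 2."""
    t2, t3 = t * t, t * t * t
    return jnp.stack(
        [
            0.5 * (-t + 2.0 * t2 - t3),
            1.0 - 2.5 * t2 + 1.5 * t3,
            0.5 * (t + 4.0 * t2 - 3.0 * t3),
            0.5 * (t3 - t2),
        ],
        axis=-1,
    )


def interp_1d(x, g):
    """Neighbour indices and weights on a uniform 1-D grid g, clamped like the source."""
    n = g.shape[0]
    h = (g[-1] - g[0]) / (n - 1)
    cont = (x - g[0]) / h
    cell = jnp.clip(jnp.floor(cont).astype(jnp.int32), 1, n - 3)
    t = cont - cell.astype(x.dtype)
    idx = jnp.stack([cell - 1, cell, cell + 1, cell + 2], axis=-1)  # (B, 4)
    return idx, keys_weights(t)                                     # (B, 4), (B, 4)


g = jnp.linspace(0.0, 1.0, 11)
x = jnp.array([0.0, 0.03, 0.37, 0.5, 0.96, 1.0])
idx, w = interp_1d(x, g)

f = lambda z: 1.0 + 2.0 * z + 3.0 * z**2  # any quadratic is reproduced exactly
approx = jnp.sum(w * f(g)[idx], axis=-1)

# In 2-D, weights are outer products of per-dimension weights: 4 x 4 = 16 per point.
w2 = (w[:, :, None] * w[:, None, :]).reshape(w.shape[0], 16)
```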