Multigrid Helmholtz Solver
This page covers the theory behind the geometric multigrid solver in finitevolX, its variable-coefficient Helmholtz operator, and the three differentiation strategies for computing gradients through the solve.
The Problem
The spectral and capacitance-matrix solvers in finitevolX handle constant-coefficient Helmholtz equations efficiently. However, many ocean modelling tasks require solving the variable-coefficient generalisation

\[
\nabla \cdot \big( c(x, y)\, \nabla u \big) - \lambda\, u = f,
\]

where \(c(x,y)\) is a spatially varying coefficient (e.g. layer thickness, diffusivity, or Rossby-radius field) and the domain may be irregular (masked).
| Application | Equation | \(c(x,y)\) | \(\lambda\) |
|---|---|---|---|
| Variable-thickness QG | \(\nabla \cdot (H\,\nabla \psi) - \psi/R_d^2 = q\) | Layer thickness \(H(x,y)\) | \(1/R_d^2\) |
| Spatially varying diffusion | \(\nabla \cdot (\kappa\,\nabla T) = f\) | Diffusivity \(\kappa(x,y)\) | 0 |
| Topographic PV inversion | \(\nabla \cdot (f_0^2/N^2 \cdot \nabla \psi) - \beta y\,\psi = q\) | Stratification-dependent | Varies |
Spectral methods cannot handle spatially varying \(c(x,y)\) because the eigenvalue decomposition assumes constant coefficients. Conjugate-gradient (CG) iteration can handle it, but converges slowly without a good preconditioner. Geometric multigrid is the standard approach: it provides both a fast standalone solver and an excellent preconditioner for CG.
Geometric Multigrid: Overview
Multigrid accelerates iterative solvers by exploiting a hierarchy of progressively coarser grids. The key insight: simple smoothers (e.g. Jacobi) efficiently damp high-frequency error on any grid, but leave low-frequency error untouched. By restricting the problem to a coarser grid, those low-frequency components become high-frequency — and can be damped cheaply.
The V-Cycle
The fundamental building block is the V-cycle, a recursive algorithm that visits each level of the grid hierarchy:
Level 0 (finest) ●───smooth───●───────────────────────●───smooth───●
↓ restrict prolong ↑
Level 1 ·····●───smooth───●───────●───smooth───●·····
↓ restrict prolong ↑
Level 2 (coarsest) ··········●───bottom solve───●··········
Algorithm for v_cycle(u, rhs, level):
- If at the coarsest level: run many Jacobi iterations (bottom solve).
- Pre-smooth: apply \(\nu_1\) weighted Jacobi iterations.
- Residual: compute \(r = f - A u\).
- Restrict: transfer \(r\) to the coarse grid.
- Recurse: solve \(A_c\,e_c = r_c\) on the coarse grid (V-cycle).
- Prolongate: interpolate \(e_c\) back to the fine grid.
- Correct: \(u \leftarrow u + e_{\text{fine}}\).
- Post-smooth: apply \(\nu_2\) weighted Jacobi iterations.
Multiple V-cycles are applied to drive the residual to convergence.
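The recursion above can be sketched in a few lines. The following is a minimal NumPy illustration for the 1D constant-coefficient Poisson problem \(-u'' = f\) with homogeneous Dirichlet boundaries — a stand-in for the library's masked 2D implementation; all function names here are illustrative, not the finitevolX API.

```python
import numpy as np

def jacobi(u, f, h, n_iter, omega=0.8):
    # Weighted Jacobi for the 1D Poisson model problem -u'' = f,
    # with homogeneous Dirichlet boundaries (u = 0 at both ends).
    for _ in range(n_iter):
        u_new = 0.5 * (np.roll(u, 1) + np.roll(u, -1) + h**2 * f)
        u_new[0] = u_new[-1] = 0.0
        u = (1 - omega) * u + omega * u_new
    return u

def residual(u, f, h):
    r = f + (np.roll(u, 1) - 2 * u + np.roll(u, -1)) / h**2
    r[0] = r[-1] = 0.0
    return r

def restrict(r):
    # Full weighting: coarse[i] = (r[2i-1] + 2 r[2i] + r[2i+1]) / 4.
    rc = np.zeros((len(r) - 1) // 2 + 1)
    rc[1:-1] = 0.25 * (r[1:-2:2] + 2 * r[2:-1:2] + r[3::2])
    return rc

def prolong(ec, n_fine):
    # Linear interpolation from the coarse grid back to the fine grid.
    e = np.zeros(n_fine)
    e[::2] = ec
    e[1::2] = 0.5 * (ec[:-1] + ec[1:])
    return e

def v_cycle(u, f, h, level, n_levels, nu1=3, nu2=3):
    if level == n_levels - 1:
        return jacobi(u, f, h, 100)              # bottom solve
    u = jacobi(u, f, h, nu1)                     # pre-smooth
    rc = restrict(residual(u, f, h))             # restrict residual
    ec = v_cycle(np.zeros_like(rc), rc, 2 * h, level + 1, n_levels)
    u = u + prolong(ec, len(u))                  # coarse-grid correction
    return jacobi(u, f, h, nu2)                  # post-smooth
```

A few V-cycles on a 65-point grid drive the residual down by several orders of magnitude, while the same number of plain Jacobi sweeps would barely touch the smooth error components.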
Discrete Operator
Variable-Coefficient Helmholtz Stencil
The operator is discretised using a 5-point finite-volume stencil on a cell-centred grid. Face coefficients \(c_x\) and \(c_y\) live on the staggered faces between cell centres:

\[
(A u)_{j,i} = \frac{c_{x,j,i}\,(u_{j,i+1} - u_{j,i}) - c_{x,j,i-1}\,(u_{j,i} - u_{j,i-1})}{\Delta x^2}
+ \frac{c_{y,j,i}\,(u_{j+1,i} - u_{j,i}) - c_{y,j-1,i}\,(u_{j,i} - u_{j-1,i})}{\Delta y^2}
- \lambda\, u_{j,i}.
\]

Face coefficients are computed by averaging the cell-centre coefficient \(c(x,y)\) to the faces:

\[
c_{x,j,i} = \tfrac{1}{2}\big(c_{j,i} + c_{j,i+1}\big)\, m_{x,j,i},
\qquad
c_{y,j,i} = \tfrac{1}{2}\big(c_{j,i} + c_{j+1,i}\big)\, m_{y,j,i},
\]

where \(m_{x,j,i} = 1\) only when both cells \((j,i)\) and \((j,i+1)\) are wet (inside the mask), and similarly for \(m_{y,j,i}\). This ensures the operator respects irregular domain boundaries.
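A minimal NumPy sketch of this masked face averaging (the function name is illustrative, not necessarily the finitevolX API):

```python
import numpy as np

def face_coefficients(c, mask):
    # Average the cell-centre coefficient c to x- and y-faces, zeroing any
    # face that touches a dry (masked-out) cell. Shapes: c and mask are
    # (Ny, Nx); c_x is (Ny, Nx-1), c_y is (Ny-1, Nx).
    m_x = mask[:, :-1] * mask[:, 1:]          # 1 only if both cells are wet
    m_y = mask[:-1, :] * mask[1:, :]
    c_x = 0.5 * (c[:, :-1] + c[:, 1:]) * m_x
    c_y = 0.5 * (c[:-1, :] + c[1:, :]) * m_y
    return c_x, c_y
```

Zeroing the faces rather than the ghost cells is what makes the no-flux boundary condition fall out of the operator itself, as described in the next subsection.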
Boundary Conditions
Zero normal flux at domain edges is enforced by the face coefficients:
boundary faces have zero coefficient, so no flux crosses the domain
boundary. Out-of-bounds neighbours are zero-padded as an implementation
convenience (avoids periodic wrapping from jnp.roll), but the actual
BC is determined by the zeroed face coefficients, not the ghost
values. This means no periodic wrapping at the domain edges — the
natural choice for bounded ocean basins.
Components
Weighted Jacobi Smoother
The smoother updates the solution pointwise:

\[
u \leftarrow u + \omega\, D^{-1} \big( f - A u \big),
\]

where \(D = \text{diag}(A)\) is the diagonal of the operator and \(\omega \in (0, 1)\) is the relaxation weight. The diagonal is precomputed during the offline build phase.
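As a sketch, assuming a matrix-free operator apply_A and a precomputed inverse diagonal (both names illustrative):

```python
import jax.numpy as jnp

def jacobi_step(u, f, apply_A, d_inv, omega=0.9):
    # One weighted-Jacobi sweep: u <- u + omega * D^{-1} (f - A u),
    # with apply_A a matrix-free operator and d_inv the precomputed
    # inverse diagonal (zeroed on land cells in the masked case).
    return u + omega * d_inv * (f - apply_A(u))
```

Because the update touches every point independently, it vectorises trivially and can be iterated with jax.lax.fori_loop without any sequential dependencies.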
Why Jacobi, not Gauss-Seidel?
Red-black Gauss-Seidel is the traditional multigrid smoother, but it
requires sequential updates that are difficult to express efficiently
in JAX. Weighted Jacobi is fully parallel, composes naturally with
jax.lax.fori_loop, and converges well with \(\omega \approx 0.8\text{--}0.95\).
Restriction (Fine to Coarse)
Cell-centred full-weighting restriction averages each \(2 \times 2\) block of fine cells into one coarse cell, weighted by the mask:

\[
r^c_{J,I} = \frac{\sum_{(j,i) \in B_{J,I}} m_{j,i}\, r_{j,i}}{\max\!\Big(\sum_{(j,i) \in B_{J,I}} m_{j,i},\; 1\Big)},
\]

where \(B_{J,I}\) is the \(2 \times 2\) block of fine cells beneath coarse cell \((J,I)\).
This mask-weighted divisor prevents land cells from contaminating the coarse-grid values.
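In code, this mask-weighted restriction might look like the following (NumPy sketch; the function name is illustrative):

```python
import numpy as np

def restrict_2x2(r, mask):
    # Mask-weighted full-weighting: average each 2x2 block of fine
    # cells into one coarse cell, counting only wet cells in the divisor.
    rm = r * mask
    num = rm[0::2, 0::2] + rm[1::2, 0::2] + rm[0::2, 1::2] + rm[1::2, 1::2]
    den = mask[0::2, 0::2] + mask[1::2, 0::2] + mask[0::2, 1::2] + mask[1::2, 1::2]
    # All-land blocks get 0; elsewhere divide by the number of wet cells.
    return np.where(den > 0, num / np.maximum(den, 1e-12), 0.0)
```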
Prolongation (Coarse to Fine)
Bilinear prolongation maps each coarse cell to four fine sub-cells using 9/3/3/1 weights; for the lower-left sub-cell,

\[
e_{2J,2I} = \frac{9\, m_{J,I}\, e^c_{J,I} + 3\, m_{J-1,I}\, e^c_{J-1,I} + 3\, m_{J,I-1}\, e^c_{J,I-1} + m_{J-1,I-1}\, e^c_{J-1,I-1}}{\max\!\big(9\, m_{J,I} + 3\, m_{J-1,I} + 3\, m_{J,I-1} + m_{J-1,I-1},\; 1\big)}.
\]
The three other sub-cells use analogous stencils with shifted neighbours. The mask-weighted divisor again prevents land contamination.
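A sketch of the unmasked version of this stencil (edge-replicating out-of-domain coarse neighbours for simplicity; the masked version instead divides by the summed mask weights, and the function name is illustrative):

```python
import numpy as np

def prolong_bilinear(ec):
    # Map each coarse cell to its four fine sub-cells with 9/3/3/1 weights.
    # Out-of-domain coarse neighbours are edge-replicated in this sketch.
    p = np.pad(ec, 1, mode="edge")
    c = p[1:-1, 1:-1]
    n_, s_ = p[2:, 1:-1], p[:-2, 1:-1]       # j+1 / j-1 neighbours
    e_, w_ = p[1:-1, 2:], p[1:-1, :-2]       # i+1 / i-1 neighbours
    ne, nw = p[2:, 2:], p[2:, :-2]
    se, sw = p[:-2, 2:], p[:-2, :-2]
    ny, nx = ec.shape
    e = np.empty((2 * ny, 2 * nx))
    e[0::2, 0::2] = (9 * c + 3 * s_ + 3 * w_ + sw) / 16   # lower-left sub-cell
    e[0::2, 1::2] = (9 * c + 3 * s_ + 3 * e_ + se) / 16   # lower-right
    e[1::2, 0::2] = (9 * c + 3 * n_ + 3 * w_ + nw) / 16   # upper-left
    e[1::2, 1::2] = (9 * c + 3 * n_ + 3 * e_ + ne) / 16   # upper-right
    return e
```

Note the weights sum to 16, so constant fields pass through unchanged — a basic sanity check for any prolongation operator.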
Bottom Solver
At the coarsest level, the grid is small enough (typically \(8 \times 8\) to \(16 \times 16\)) that iterated Jacobi converges quickly. This avoids the complexity of a dense direct solve while keeping the implementation pure-JAX.
Grid Hierarchy Construction
The build_multigrid_solver factory performs offline precomputation
(using NumPy) to build the level hierarchy:
- Mask coarsening: each level's mask is the 4-point average of the fine mask, thresholded at 0.5.
- Coefficient interpolation: cell-centre \(c(x,y)\) is averaged to face coefficients at each level, then coarsened for the next level.
- Diagonal precomputation: \(D^{-1}\) is computed from the face coefficients and mask at each level.
- Grid spacing doubling: \(\Delta x\) and \(\Delta y\) double at each coarser level.
The resulting MultigridSolver is an immutable equinox.Module —
all arrays are frozen JAX arrays, and all integer parameters
(n_levels, n_pre, etc.) are static fields for efficient JIT
compilation.
Grid size constraint
Both grid dimensions must be divisible by \(2^{L-1}\) where \(L\) is the number of multigrid levels. The factory auto-detects \(L\) by halving until either dimension would drop below 8.
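One plausible sketch of this auto-detection rule (the actual factory logic may differ):

```python
def detect_n_levels(ny, nx, min_size=8):
    # Add levels while both dimensions halve cleanly and the halved
    # sizes stay >= min_size, mirroring the divisibility-by-2^(L-1)
    # constraint stated above.
    n_levels = 1
    while (ny % 2 == 0 and nx % 2 == 0
           and ny // 2 >= min_size and nx // 2 >= min_size):
        ny, nx = ny // 2, nx // 2
        n_levels += 1
    return n_levels
```

For example a \(128 \times 128\) grid yields the hierarchy \(128 \to 64 \to 32 \to 16 \to 8\), i.e. five levels.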
Differentiating Through the Solve
A key advantage of implementing multigrid in JAX is that we can compute gradients of a loss function through the linear solve:

\[
\ell(\theta) = \mathcal{L}\big(u^*(\theta)\big) \quad \text{with} \quad A(\theta)\, u^*(\theta) = f(\theta),
\]

where \(\theta\) parameterises the coefficient field \(c(x,y)\), the RHS \(f\), or both. finitevolX provides three differentiation strategies with different cost/accuracy trade-offs.
Strategy 1: Implicit Differentiation (Default)
The default __call__ method uses jax.lax.custom_linear_solve to
compute gradients via the implicit function theorem (IFT).
Given the linear system \(A u = f\), the gradient of a scalar loss \(\ell(u)\) with respect to the RHS is:

\[
\frac{\partial \ell}{\partial f} = A^{-T}\, \frac{\partial \ell}{\partial u}.
\]

Since \(A\) is symmetric (\(A = A^T\)), the adjoint solve is just another multigrid call with the same operator.
Properties:
- Forward: \(K\) V-cycles (identical to unrolled)
- Backward: 1 multigrid solve (the adjoint equation)
- Memory: \(O(1)\) — no iteration history stored
- Gradients: exact (up to solver tolerance)
When to use
This is the default and recommended mode. Use it whenever you need gradients through the solve (e.g. learning \(c(x,y)\) or \(\lambda\) from data).
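A toy version of this pattern, with a small dense matrix and jax.scipy's CG standing in for the multigrid solver (the wrapper name is illustrative):

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

def make_implicit_solver(matvec, solver):
    # Gradients come from the implicit function theorem: the backward
    # pass solves the (symmetric) adjoint system with the same solver,
    # never backpropagating through the solver's internal iterations.
    def solve(b):
        return jax.lax.custom_linear_solve(matvec, b, solver, symmetric=True)
    return solve

# Small SPD system standing in for the discretised Helmholtz operator.
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
solve = make_implicit_solver(
    lambda v: A @ v,
    lambda mv, b: cg(mv, b)[0],
)
# Gradient of sum(u) wrt the RHS: one adjoint solve, A^{-T} @ ones.
grad_b = jax.grad(lambda b: solve(b).sum())(jnp.array([1.0, 2.0]))
```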
Strategy 2: One-Step Differentiation
The solve_onestep method implements the approach of
Bolte, Pauwels & Vaiter (NeurIPS 2023). It runs \(K\) V-cycles
for convergence, but only differentiates through the last cycle:
writing one V-cycle as \(u_{k+1} = \Phi(u_k, \theta)\), the backward
pass treats \(u_{K-1}\) as constant and differentiates only
\(\Phi(u_{K-1}, \theta)\).
The gradient approximation error is \(O(\rho)\), where \(\rho\) is the V-cycle convergence rate (typically 0.1--0.3 for multigrid).
Properties:
- Forward: \(K\) V-cycles (identical)
- Backward: 1 V-cycle (autodiff through only the last)
- Memory: \(O(1\text{ V-cycle})\)
- Gradients: approximate, error \(O(\rho)\)
When to use
When you need cheap gradients and can tolerate small approximation error — e.g. in training loops where the solver is called many times and the gradient noise from one-step differentiation is small relative to stochastic gradient noise.
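The mechanism can be demonstrated on any fixed-point iteration; here a Babylonian square-root iteration stands in for the V-cycle, and the function names are illustrative:

```python
import jax

def solve_onestep(step, u0, theta, n_iter):
    # Iterate to convergence, but cut the tape before the final step so
    # the backward pass differentiates through one iteration only.
    u = u0
    for _ in range(n_iter - 1):
        u = step(u, theta)
    u = jax.lax.stop_gradient(u)
    return step(u, theta)

# Babylonian iteration for sqrt(a) as a stand-in fixed point:
step = lambda u, a: 0.5 * (u + a / u)
sqrt = lambda a: solve_onestep(step, 1.0, a, 20)
```

For this particular iteration the one-step gradient is already exact at the fixed point; for a V-cycle, the leftover \(O(\rho)\) term comes from the dependence of the pre-final iterate on \(\theta\), which stop_gradient discards.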
Strategy 3: Unrolled Differentiation
The solve_unrolled method differentiates through every V-cycle
iteration via jax.lax.fori_loop, replaying all \(K\) cycles in the
backward pass.
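A minimal sketch of this unrolled pattern, again with a Babylonian square-root iteration standing in for the V-cycle (names illustrative):

```python
import jax
import jax.numpy as jnp

def solve_unrolled(step, u0, theta, n_iter):
    # fori_loop with static bounds lowers to scan, so reverse-mode AD
    # replays every iteration and stores each iterate for the backward pass.
    return jax.lax.fori_loop(0, n_iter, lambda i, u: step(u, theta), u0)

# Stand-in fixed point: Babylonian iteration for sqrt(theta).
step = lambda u, a: 0.5 * (u + a / u)
sqrt = lambda a: solve_unrolled(step, jnp.ones(()), a, 20)
```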
Properties:
- Forward: \(K\) V-cycles
- Backward: \(K\) V-cycles (replay all iterations)
- Memory: \(O(K)\) — stores intermediate states for backprop
- Gradients: exact through the iteration process
When to use
When you specifically need gradients through the iteration dynamics itself (e.g. analysing convergence behaviour, or when the number of iterations is very small).
Comparison
| Mode | Backward cost | Memory | Gradient quality |
|---|---|---|---|
| Implicit (__call__) | 1 multigrid solve | \(O(1)\) | Exact (IFT) |
| One-step (solve_onestep) | 1 V-cycle autodiff | \(O(1)\) | Approximate, \(O(\rho)\) error |
| Unrolled (solve_unrolled) | \(K\) V-cycles | \(O(K)\) | Exact (through iterations) |
Multigrid as a Preconditioner
A single V-cycle is an excellent preconditioner for CG. This is useful when:
- You want the convergence guarantees of CG (Krylov method) but need faster convergence than a spectral preconditioner provides.
- The variable-coefficient problem is poorly conditioned.
- You want to combine multigrid's coarse-grid correction with CG's global optimality.
make_multigrid_preconditioner wraps one V-cycle as a closure compatible
with solve_cg. The CG iteration then converges in very few steps
(often 5--10 instead of hundreds).
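The wiring looks like the following sketch, with a simple diagonal (Jacobi) preconditioner standing in for the V-cycle closure — make_multigrid_preconditioner's actual signature is not shown here, and jax.scipy's cg stands in for solve_cg:

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Small SPD system standing in for the discretised Helmholtz operator.
A = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])
b = jnp.array([1.0, 2.0, 0.0])

# Stand-in preconditioner: one application of an approximate inverse.
# In finitevolX this closure would be one multigrid V-cycle.
d_inv = 1.0 / jnp.diag(A)
precond = lambda r: d_inv * r

u, _ = cg(lambda v: A @ v, b, M=precond)
```

The preconditioner only needs to approximate \(A^{-1}\) cheaply and symmetrically; CG supplies the Krylov convergence guarantees on top.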
Convergence Theory
For a constant-coefficient Poisson/Helmholtz problem on a rectangular domain, multigrid with weighted Jacobi smoothing achieves a convergence rate \(\rho \approx 0.1\text{--}0.3\) per V-cycle, independent of grid size. This means:
- 5 V-cycles reduce the residual by a factor of \(\rho^5 \approx 10^{-3}\text{--}10^{-5}\)
- The cost per V-cycle is \(O(N)\) where \(N = N_y \times N_x\): each coarser level has a quarter of the cells, so the level costs form a geometric series summing to less than \(\tfrac{4}{3} N\)
- Total cost: \(O(N)\) — optimal for elliptic solves
For variable coefficients and masked domains, the convergence rate depends on the coefficient contrast and domain geometry, but multigrid typically remains much faster than unpreconditioned CG.
Decision Guide
Is c(x,y) constant (or nearly so)?
├── Yes → Use spectral solver (fastest, O(N log N))
│ Or capacitance method for masked domains
└── No (variable coefficient) ↓
Is the domain rectangular (no mask)?
├── Yes → Multigrid standalone (method="multigrid")
│ build_multigrid_solver(np.ones(...), dx, dy, coeff=c)
└── No (masked/irregular domain) ↓
Is the domain simple with small N_b?
├── Yes → Multigrid standalone (handles masks natively)
└── No (complex domain, poor convergence) ↓
Default → Multigrid-preconditioned CG
make_multigrid_preconditioner() + solve_cg()
References
- Briggs, Henson & McCormick (2000) — A Multigrid Tutorial, 2nd ed. (standard reference for geometric multigrid)
- Trottenberg, Oosterlee & Schuller (2001) — Multigrid (comprehensive treatment including variable coefficients)
- Bolte, Pauwels & Vaiter (NeurIPS 2023) — One-step differentiation of iterative algorithms
- Blondel et al. (ICML 2022) — Efficient and modular implicit differentiation
- Louity — qgsw-pytorch (reference PyTorch multigrid Helmholtz implementation)