Multigrid Helmholtz Solver
Geometric multigrid solver for variable-coefficient Helmholtz equations on masked Arakawa C-grids.
Solver
finitevolx.MultigridSolver
Bases: Module
Geometric multigrid V-cycle solver for the variable-coefficient Helmholtz equation:

    div(c(x, y) grad u) - lambda u = rhs
Supports spatially varying coefficients c(x, y) on staggered faces
and masked (irregular) domains.
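To make the operator concrete, here is a minimal NumPy sketch of applying A u = div(c grad u) - lambda u on a masked, cell-centred grid. It assumes that c is moved onto the faces by a two-point mean and that faces on the outer boundary or next to land carry zero flux. Those choices, and the names face_coefficients and apply_helmholtz, are illustrative assumptions, not the finitevolx implementation.

```python
import numpy as np

def face_coefficients(c, mask):
    """Average cell-centre c(x, y) onto staggered faces.

    Assumption: faces on the outer boundary or touching a land cell get 0 (no flux).
    """
    ny, nx = c.shape
    cx = np.zeros((ny, nx + 1))  # x-faces (west/east edges of each cell)
    cy = np.zeros((ny + 1, nx))  # y-faces (south/north edges of each cell)
    cx[:, 1:-1] = 0.5 * (c[:, :-1] + c[:, 1:]) * mask[:, :-1] * mask[:, 1:]
    cy[1:-1, :] = 0.5 * (c[:-1, :] + c[1:, :]) * mask[:-1, :] * mask[1:, :]
    return cx, cy

def apply_helmholtz(u, cx, cy, mask, dx, dy, lam):
    """A u = div(c grad u) - lam * u on fluid cells; zero on land cells."""
    u = u * mask
    fx = np.zeros_like(cx)
    fy = np.zeros_like(cy)
    fx[:, 1:-1] = cx[:, 1:-1] * (u[:, 1:] - u[:, :-1]) / dx  # c du/dx on x-faces
    fy[1:-1, :] = cy[1:-1, :] * (u[1:, :] - u[:-1, :]) / dy  # c du/dy on y-faces
    div = (fx[:, 1:] - fx[:, :-1]) / dx + (fy[1:, :] - fy[:-1, :]) / dy
    return (div - lam * u) * mask

# Usage: a 6 x 8 basin with a small land island and a random positive coefficient.
mask = np.ones((6, 8))
mask[2:4, 3:5] = 0.0
c = 1.0 + np.random.default_rng(0).random((6, 8))
cx, cy = face_coefficients(c, mask)
Au = apply_helmholtz(np.ones((6, 8)), cx, cy, mask, dx=0.5, dy=0.25, lam=2.0)
```

In this sketch each face coefficient couples its two neighbouring cells with the same weight, so the assembled matrix is symmetric, which is the property the implicit-differentiation mode relies on.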
V-Cycle Algorithm
The V-cycle is a recursive algorithm that visits coarser grids to correct the low-frequency error that the smoother cannot damp efficiently:

    Level 0 (fine)    *---smooth---*                                          *---smooth---*
                          restrict |                                          | prolong
    Level 1                        *---smooth---*                *---smooth---*
                                       restrict |                | prolong
    Level 2 (coarse)                            *--bottom solve--*
At each level except the coarsest:
- Pre-smooth (nu_1 = n_pre weighted Jacobi iterations): damp high-frequency error.
- Compute the residual r = rhs - A u.
- Restrict the residual to the coarse grid (2x2 averaging).
- Recurse: solve for the error on the coarse grid.
- Prolongate the coarse correction back to the fine grid (bilinear interpolation) and add it to u.
- Post-smooth (nu_2 = n_post weighted Jacobi iterations): damp any high-frequency error introduced by the prolongation.
On the coarsest level, n_coarse Jacobi iterations act as the bottom solver.
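The transfer operators in the list above can be sketched in NumPy. Restriction averages each 2x2 block of cells. Cell-centred bilinear prolongation uses weights 3/4 and 1/4 along each axis (9/16, 3/16, 3/16, 1/16 in 2D). This sketch ignores the land mask and replicates edge values at the domain boundary. Both choices are assumptions, and restrict and prolongate are hypothetical helper names.

```python
import numpy as np

def restrict(r):
    """Fine -> coarse: average each 2x2 block of cells."""
    ny, nx = r.shape
    return r.reshape(ny // 2, 2, nx // 2, 2).mean(axis=(1, 3))

def _prolong_axis(e, axis):
    """Cell-centred linear interpolation along one axis (edge values replicated)."""
    e = np.moveaxis(e, axis, 0)
    p = np.concatenate([e[:1], e, e[-1:]], axis=0)  # one ghost cell at each end
    out = np.empty((2 * e.shape[0],) + e.shape[1:])
    out[0::2] = 0.25 * p[:-2] + 0.75 * p[1:-1]  # fine cell 2I: 3/4 of coarse I, 1/4 of I-1
    out[1::2] = 0.75 * p[1:-1] + 0.25 * p[2:]   # fine cell 2I+1: 3/4 of coarse I, 1/4 of I+1
    return np.moveaxis(out, 0, axis)

def prolongate(e_c):
    """Coarse -> fine bilinear interpolation, applied separably in y, then x."""
    return _prolong_axis(_prolong_axis(e_c, 0), 1)
```

Applying the 1D stencil once per axis is what makes the 2D interpolation bilinear: the product of the 3/4, 1/4 weights gives the familiar 9/16, 3/16, 3/16, 1/16 stencil.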
The recursion is statically unrolled at JAX trace time because
each level has a different array shape. All integer parameters
(n_levels, n_pre, etc.) are eqx.field(static=True),
so the unrolled structure is visible to the XLA compiler.
Differentiation Modes
Three solve modes trade off backward-pass cost against gradient accuracy (K = n_cycles):
- __call__: implicit differentiation via jax.lax.custom_linear_solve(symmetric=True). The backward pass solves the adjoint system A^T v = dL/du with one multigrid call. Since A is symmetric, this costs the same as the forward pass. O(1) memory; gradients are exact for the linear system being solved, so their accuracy is limited only by how well the V-cycles approximate A^{-1} (i.e. it depends on n_cycles and the smoother settings).
- solve_onestep: one-step differentiation (Bolte, Pauwels & Vaiter, NeurIPS 2023). Runs K V-cycles, applies stop_gradient after the first K-1, then autodiffs through the last cycle only. O(1 V-cycle) memory; approximate gradients with error O(rho), where rho is the per-cycle convergence rate.
- solve_unrolled: unrolled differentiation via jax.lax.fori_loop. The backward pass replays all K iterations. O(K) memory; exact gradients through the iterations (they differentiate the forward computation itself, so gradient accuracy matches the forward solve accuracy).
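The three modes can be compared on a toy linear problem. Write one cycle as u <- u + B (rhs - A u), where B stands in for a V-cycle's approximate inverse. Then:
- the implicit (IFT) target Jacobian of u with respect to rhs is A^{-1};
- the one-step estimate is B;
- the unrolled Jacobian is sum_{k<K} (I - BA)^k B.

The sketch below uses weighted Jacobi for B purely as a stand-in. Its rho is about 0.6, worse than a real V-cycle's, which makes the gap easy to see. It then checks the error bounds rho * ||A^{-1}|| and rho^K * ||A^{-1}||, which hold here because I - BA is symmetric. The function name mode_jacobians is hypothetical.

```python
import numpy as np

def mode_jacobians(A, B, K):
    """d u_K / d rhs for K cycles of u <- u + B (rhs - A u), starting from u_0 = 0.

    B stands in for one V-cycle's approximate inverse (an assumption).
    """
    n = A.shape[0]
    E = np.eye(n) - B @ A                      # per-cycle error propagation
    exact = np.linalg.inv(A)                   # implicit-function-theorem target
    onestep = B                                # only the last cycle is differentiated
    unrolled = sum(np.linalg.matrix_power(E, k) @ B for k in range(K))  # all K cycles
    return exact, onestep, unrolled, E

n, K = 30, 5
A = 4.0 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)  # SPD, constant diagonal 4
B = (0.8 / 4.0) * np.eye(n)                              # weighted Jacobi, omega = 0.8
exact, onestep, unrolled, E = mode_jacobians(A, B, K)
rho = np.max(np.abs(np.linalg.eigvalsh(E)))              # per-cycle convergence rate
inv_norm = np.linalg.norm(exact, 2)
err_onestep = np.linalg.norm(exact - onestep, 2)         # <= rho * ||A^-1||
err_unrolled = np.linalg.norm(exact - unrolled, 2)       # <= rho**K * ||A^-1||
```

The one-step error shrinks linearly with rho, while the unrolled error shrinks like rho^K. This is why one-step gradients become attractive once each cycle is a strong solver such as a V-cycle.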
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| levels | tuple of MultigridLevel | Precomputed level data, finest (index 0) to coarsest (index n_levels - 1). | required |
| n_levels | int | Number of multigrid levels. | required |
| n_pre | int | Pre-smoothing iterations (weighted Jacobi). | required |
| n_post | int | Post-smoothing iterations (weighted Jacobi). | required |
| n_coarse | int | Jacobi iterations on the coarsest grid (bottom solver). | required |
| omega | float | Jacobi relaxation weight (typically 0.8-0.95). | required |
| n_cycles | int | Number of V-cycles per solve. | required |
Source code in finitevolx/_src/solvers/multigrid.py
__call__(rhs)
Solve A u = rhs with implicit differentiation.
The forward pass runs K = n_cycles V-cycles (identical to solve_unrolled).
The backward pass uses jax.lax.custom_linear_solve with
symmetric=True to compute gradients via the implicit function
theorem (IFT) rather than unrolling through V-cycle iterations.
For a scalar loss L(u), the gradient with respect to the RHS is:

    dL/d(rhs) = A^{-T} dL/du = A^{-1} dL/du    (since A = A^T)

This adjoint solve is just another multigrid call, so the backward pass costs the same as the forward pass and needs only O(1) extra memory (no iteration history is stored).
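A quick finite-difference check of the adjoint formula. A small dense symmetric negative-definite matrix stands in for the multigrid operator (an assumption for illustration), and the loss is L(u) = 0.5 ||u - target||^2.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 6
G = rng.standard_normal((n, n))
A = -(G @ G.T + n * np.eye(n))  # symmetric negative definite, like div(c grad) - lambda
target = rng.standard_normal(n)

def loss(rhs):
    u = np.linalg.solve(A, rhs)
    return 0.5 * np.sum((u - target) ** 2)

rhs = rng.standard_normal(n)
u = np.linalg.solve(A, rhs)
dL_du = u - target
grad_ift = np.linalg.solve(A, dL_du)  # A^{-1} dL/du, valid because A = A^T

# Central differences, one RHS component at a time.
eps = 1e-6
grad_fd = np.array([(loss(rhs + eps * e) - loss(rhs - eps * e)) / (2 * eps) for e in np.eye(n)])
```

The two gradients agree to finite-difference accuracy. In the solver, the second np.linalg.solve is replaced by another multigrid solve.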
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rhs | Float[Array, 'Ny Nx'] | Right-hand side of the linear system. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'Ny Nx'] | Approximate solution. |
solve_onestep(rhs)
Solve with one-step differentiation (Bolte et al., NeurIPS 2023).
Runs the same K = n_cycles V-cycles as the other modes, but autodiff only sees the last
cycle. The first K-1 cycles are wrapped in jax.lax.stop_gradient,
so they add nothing to the backward-pass cost.
The forward result is identical to solve_unrolled. The
gradient approximation error is O(rho), where rho is the per-cycle
convergence rate (typically 0.1-0.3 for multigrid).
Gradient structure:

    u_0     = 0
    u_1     = V(u_0, rhs)
    ...
    u_{K-1} = V(u_{K-2}, rhs)    <-- stop_gradient here
    u_K     = V(u_{K-1}, rhs)    <-- autodiff traces only this
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rhs | Float[Array, 'Ny Nx'] | Right-hand side. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'Ny Nx'] | Approximate solution. |
References
Bolte, Pauwels & Vaiter (NeurIPS 2023). "One-step differentiation of iterative algorithms." https://arxiv.org/abs/2305.13768
solve_unrolled(rhs)
Solve by unrolling all V-cycles through jax.lax.fori_loop.
The backward pass differentiates through every iteration, storing intermediate states for replay. This costs O(n_cycles) memory.
Use this mode when you specifically need gradients through the
iteration dynamics themselves. For most applications, prefer
__call__ (implicit differentiation), which gives gradients of the
linear solve (as accurate as the V-cycles' approximation of A^{-1})
at O(1) memory cost.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| rhs | Float[Array, 'Ny Nx'] | Right-hand side. | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'Ny Nx'] | Approximate solution. |
v_cycle(u, rhs, level_idx=0)
Execute a single multigrid V-cycle starting at level_idx.
Algorithm:

    if level_idx == n_levels - 1:                          # coarsest level
        return jacobi(u, rhs, n_coarse)                    # bottom solve
    u   = jacobi(u, rhs, n_pre)                            # 1. pre-smooth
    r   = rhs - A(u)                                       # 2. compute residual
    r_c = restrict(r)                                      # 3. restrict to coarse grid
    e_c = v_cycle(zeros_like(r_c), r_c, level_idx + 1)     # 4. recurse (solve A_c e_c = r_c)
    u   = u + prolongate(e_c)                              # 5. correct with coarse error
    u   = jacobi(u, rhs, n_post)                           # 6. post-smooth
    return u
The recursion unrolls statically at JAX trace time because
level_idx and n_levels are Python ints (static fields).
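For experimentation outside JAX, the six steps can be written as a plain NumPy recursion. This sketch makes several assumptions that the solver itself does not share:
- a constant coefficient c = 1 and no land mask;
- equal spacing dx = dy = h;
- zero-flux outer walls;
- the operator rediscretised with doubled spacing on each coarser level.

All function names are hypothetical, and the code is not the finitevolx implementation.

```python
import numpy as np

def apply_A(u, h, lam):
    """5-point Laplacian minus lam * u; zero-flux walls via edge padding (assumption)."""
    p = np.pad(u, 1, mode="edge")
    lap = (p[1:-1, 2:] + p[1:-1, :-2] + p[2:, 1:-1] + p[:-2, 1:-1] - 4.0 * u) / h**2
    return lap - lam * u

def diagonal(shape, h, lam):
    """Diagonal of apply_A: wall cells have fewer interior neighbours."""
    nb = np.full(shape, 4.0)
    nb[0, :] -= 1
    nb[-1, :] -= 1
    nb[:, 0] -= 1
    nb[:, -1] -= 1
    return -nb / h**2 - lam

def jacobi(u, rhs, h, lam, n_iter, omega):
    """Weighted Jacobi: u <- u + omega * D^{-1} (rhs - A u)."""
    d = diagonal(u.shape, h, lam)
    for _ in range(n_iter):
        u = u + omega * (rhs - apply_A(u, h, lam)) / d
    return u

def restrict(r):
    ny, nx = r.shape
    return r.reshape(ny // 2, 2, nx // 2, 2).mean(axis=(1, 3))  # 2x2 averaging

def _prolong_axis(e, axis):
    e = np.moveaxis(e, axis, 0)
    p = np.concatenate([e[:1], e, e[-1:]], axis=0)
    out = np.empty((2 * e.shape[0],) + e.shape[1:])
    out[0::2] = 0.25 * p[:-2] + 0.75 * p[1:-1]
    out[1::2] = 0.75 * p[1:-1] + 0.25 * p[2:]
    return np.moveaxis(out, 0, axis)

def prolongate(e_c):
    return _prolong_axis(_prolong_axis(e_c, 0), 1)  # bilinear, cell-centred

def v_cycle(u, rhs, h, lam, level, n_levels, n_pre=6, n_post=6, n_coarse=50, omega=0.95):
    if level == n_levels - 1:
        return jacobi(u, rhs, h, lam, n_coarse, omega)   # bottom solve
    u = jacobi(u, rhs, h, lam, n_pre, omega)             # 1. pre-smooth
    r = rhs - apply_A(u, h, lam)                         # 2. residual
    r_c = restrict(r)                                    # 3. restrict
    e_c = v_cycle(np.zeros_like(r_c), r_c, 2 * h, lam,   # 4. recurse on the error
                  level + 1, n_levels, n_pre, n_post, n_coarse, omega)
    u = u + prolongate(e_c)                              # 5. coarse-grid correction
    return jacobi(u, rhs, h, lam, n_post, omega)         # 6. post-smooth

# Usage: 32 x 32 grid, 3 levels (32 -> 16 -> 8), 5 V-cycles from u = 0.
rng = np.random.default_rng(0)
rhs = rng.standard_normal((32, 32))
h, lam = 1.0 / 32, 100.0
u = np.zeros_like(rhs)
for _ in range(5):
    u = v_cycle(u, rhs, h, lam, level=0, n_levels=3)
```

Here the Python-level recursion does what static unrolling does at JAX trace time: each call sees a differently shaped array, and the depth is fixed by the integer n_levels.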
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| u | Float[Array, 'Ny Nx'] | Initial guess (typically zeros for the error equation on coarse grids, or the current iterate on the fine grid). | required |
| rhs | Float[Array, 'Ny Nx'] | Right-hand side (original RHS on the fine grid, or the restricted residual on coarser grids). | required |
| level_idx | int | Current level (0 = finest, n_levels-1 = coarsest). | 0 |
Returns:
| Type | Description |
|---|---|
| Float[Array, 'Ny Nx'] | Improved solution after one V-cycle. |
Factory
finitevolx.build_multigrid_solver(mask, dx, dy, lambda_=0.0, coeff=None, n_levels=None, n_pre=6, n_post=6, n_coarse=50, omega=0.95, n_cycles=5)
Build a multigrid solver with precomputed level hierarchies.
This is an offline function (it runs once, on CPU, with NumPy) that constructs the entire multigrid hierarchy:
- Mask coarsening: at each level, the cell mask is coarsened by 2x by averaging each 2x2 block of cells; a coarse cell is fluid if the average is >= 0.5.
- Coefficient interpolation: the cell-centre coefficient c(x, y) is averaged to the staggered face coefficients cx, cy at each level, then coarsened for the next level.
- Diagonal precomputation: the inverse diagonal D^{-1} of the Helmholtz operator is computed at each level for the Jacobi smoother.
- Grid spacing doubling: dx and dy double at each coarser level.
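The offline construction steps can be sketched in NumPy. Several details are assumptions for illustration only:
- the helper names and the dictionary layout of each level;
- zero face coefficients next to land;
- c = 1 when coeff is None;
- a fixed n_levels instead of automatic selection.

```python
import numpy as np

def coarsen_mask(mask):
    """2x coarsening: average each 2x2 block; fluid if the fluid fraction is >= 0.5."""
    ny, nx = mask.shape
    frac = mask.reshape(ny // 2, 2, nx // 2, 2).mean(axis=(1, 3))
    return (frac >= 0.5).astype(mask.dtype)

def coarsen_cells(c):
    """2x coarsening of a cell-centre field by 2x2 averaging."""
    ny, nx = c.shape
    return c.reshape(ny // 2, 2, nx // 2, 2).mean(axis=(1, 3))

def face_coeffs(c, mask):
    """Interior face coefficients; faces touching land are zero (assumption)."""
    cx = 0.5 * (c[:, :-1] + c[:, 1:]) * mask[:, :-1] * mask[:, 1:]  # (Ny, Nx-1)
    cy = 0.5 * (c[:-1, :] + c[1:, :]) * mask[:-1, :] * mask[1:, :]  # (Ny-1, Nx)
    return cx, cy

def inverse_diagonal(cx, cy, mask, dx, dy, lam):
    """D^{-1} of div(c grad) - lam for the Jacobi smoother; zero on land cells."""
    d = np.full(mask.shape, -lam, dtype=float)
    d[:, :-1] -= cx / dx**2  # coupling to the east neighbour
    d[:, 1:] -= cx / dx**2   # coupling to the west neighbour
    d[:-1, :] -= cy / dy**2  # coupling to the north neighbour
    d[1:, :] -= cy / dy**2   # coupling to the south neighbour
    inv = np.zeros_like(d)
    ok = (mask > 0) & (d != 0)
    inv[ok] = 1.0 / d[ok]
    return inv

def build_hierarchy(mask, dx, dy, lam=0.0, coeff=None, n_levels=4):
    mask = np.asarray(mask, dtype=float)
    c = np.ones_like(mask) if coeff is None else np.asarray(coeff, dtype=float)
    levels = []
    for k in range(n_levels):
        if k > 0:  # coarsen mask and coefficient by 2x, double the spacing
            mask, c = coarsen_mask(mask), coarsen_cells(c)
            dx, dy = 2 * dx, 2 * dy
        cx, cy = face_coeffs(c, mask)
        levels.append({"mask": mask, "cx": cx, "cy": cy, "dx": dx, "dy": dy,
                       "inv_diag": inverse_diagonal(cx, cy, mask, dx, dy, lam)})
    return levels

levels = build_hierarchy(np.ones((64, 64)), dx=1.0, dy=2.0, lam=1.0, n_levels=4)
```

With a 64 x 64 fluid mask this reproduces the level shapes and spacings of the hierarchy example that follows.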
Grid hierarchy example (64x64, automatic level count):

    Level 0: 64 x 64   (dx, dy)        <- finest (solve here)
    Level 1: 32 x 32   (2*dx, 2*dy)
    Level 2: 16 x 16   (4*dx, 4*dy)
    Level 3:  8 x  8   (8*dx, 8*dy)    <- coarsest (bottom solve)
The returned MultigridSolver is an immutable equinox.Module
with frozen JAX arrays. All subsequent calls (forward solves,
gradients, JIT compilation) use the precomputed hierarchy.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mask | array, shape (Ny, Nx), or ArakawaCGridMask | Domain mask (1 = fluid, 0 = land). | required |
| dx | float | Fine-grid spacing in x (metres or non-dimensional). | required |
| dy | float | Fine-grid spacing in y (metres or non-dimensional). | required |
| lambda_ | float | Helmholtz parameter (>= 0). Use 0.0 for pure Poisson (Laplacian only). For QG PV inversion, | 0.0 |
| coeff | array, shape (Ny, Nx), or None | Spatially varying coefficient | None |
| n_levels | int or None | Number of multigrid levels. If None, the level count is chosen automatically. | None |
| n_pre | int | Number of pre-smoothing Jacobi iterations per V-cycle. More smoothing improves the convergence rate but increases the cost per cycle. | 6 |
| n_post | int | Number of post-smoothing Jacobi iterations per V-cycle. More smoothing improves the convergence rate but increases the cost per cycle. | 6 |
| n_coarse | int | Number of Jacobi iterations on the coarsest grid (bottom solver). The coarsest grid is small (typically 8x8), so this is cheap. | 50 |
| omega | float | Jacobi relaxation weight (0 < omega < 1). Under-relaxation improves the stability of the smoother. | 0.95 |
| n_cycles | int | Number of V-cycles applied per solve. 5 cycles typically reduce the residual by 3-5 orders of magnitude. | 5 |
Returns:
| Type | Description |
|---|---|
| MultigridSolver | Ready-to-use solver (JIT-compilable |
Raises:
| Type | Description |
|---|---|
| ValueError | If grid dimensions are not divisible by |
Preconditioner
finitevolx.make_multigrid_preconditioner(mg_solver)
Return a preconditioner closure that applies a single multigrid V-cycle.
The returned callable approximates A^{-1} r by running one V-cycle
from a zero initial guess. That is sufficient for a preconditioner:
it does not need to converge, only to be a good approximation of A^{-1}.
It is compatible with finitevolx._src.solvers.iterative.solve_cg:
pass the returned closure as the preconditioner argument. CG then
converges in very few iterations (typically 5-10 instead of hundreds)
because multigrid captures both the high- and low-frequency components of
the inverse.
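The preconditioner pattern is a closure that maps a residual r to an approximate A^{-1} r, starting from a zero initial guess. The NumPy sketch below pairs a textbook preconditioned CG loop with such a closure. A full V-cycle would not fit here, so a few weighted-Jacobi sweeps from zero stand in for it. pcg and make_sweep_preconditioner are hypothetical names, not finitevolx functions.

```python
import numpy as np

def pcg(apply_A, b, precond, tol=1e-10, max_iter=500):
    """Preconditioned conjugate gradients for a symmetric definite operator."""
    x = np.zeros_like(b)
    r = b - apply_A(x)
    z = precond(r)
    p = z.copy()
    rz = r @ z
    b_norm = np.linalg.norm(b)
    for k in range(1, max_iter + 1):
        Ap = apply_A(p)
        alpha = rz / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        if np.linalg.norm(r) <= tol * b_norm:  # relative residual test
            return x, k
        z = precond(r)
        rz_new = r @ z
        p = z + (rz_new / rz) * p
        rz = rz_new
    return x, max_iter

def make_sweep_preconditioner(apply_A, diag, n_sweeps=3, omega=0.8):
    """Closure r -> approx A^{-1} r via weighted-Jacobi sweeps (stand-in for one V-cycle)."""
    def precond(r):
        z = np.zeros_like(r)  # zero initial guess
        for _ in range(n_sweeps):
            z = z + omega * (r - apply_A(z)) / diag
        return z
    return precond

n = 50
A = 2.01 * np.eye(n) - np.eye(n, k=1) - np.eye(n, k=-1)  # SPD test matrix
b = np.random.default_rng(0).standard_normal(n)

def apply_A(v):
    return A @ v

x, iters = pcg(apply_A, b, make_sweep_preconditioner(apply_A, np.diag(A)))
x_exact, iters_exact = pcg(apply_A, b, lambda r: np.linalg.solve(A, r))  # ideal preconditioner
```

With the ideal preconditioner, CG stops after a single iteration. This illustrates why a closer approximation of A^{-1}, such as a V-cycle, tends to cut the CG iteration count.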
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mg_solver | MultigridSolver | A pre-built multigrid solver (from :func: | required |
Returns:
| Type | Description |
|---|---|
| callable | |
Examples:
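>>> from finitevolx import build_multigrid_solver, make_multigrid_preconditioner
>>> from finitevolx._src.solvers.iterative import solve_cg
>>> mg = build_multigrid_solver(mask, dx, dy, lambda_=10.0)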
>>> precond = make_multigrid_preconditioner(mg)
>>> u, info = solve_cg(A, rhs, preconditioner=precond)