Multilayer vs. 3D Discretization
This page explains the crucial distinction between multilayer operators
(via jax.vmap over independent 2D layers) and true 3D operators in
finitevolX, and shows how to use each correctly.
Two Different Meanings of a Leading Axis
When you have an array of shape [N, Ny, Nx] in ocean or atmosphere
modelling, the leading axis N can mean one of two fundamentally different
things:
| Use case | Leading axis meaning | Example models |
|---|---|---|
| Multilayer / baroclinic / modal | Layer or mode index | QG n-layer, baroclinic SW |
| True 3D discretization | Vertical spatial dimension z | Primitive equations |
finitevolX provides explicit support for both, but they must not be
confused: a multilayer field has no ghost layers along the leading axis,
whereas a true 3D field reserves its first and last slices as ghost shells.
Multilayer: multilayer(fn)
Concept
In a multilayer model (e.g., n-layer quasi-geostrophic or baroclinic
shallow water), the fluid column is divided into nl discrete layers.
Each layer is a complete, physically real 2D horizontal field with its own
dynamics. There are no "ghost" layers — every slice k = 0 … nl-1
holds a real physical field.
The JAX-idiomatic way to apply any 2D horizontal operator to all layers at
once is jax.vmap, which vectorises the call over the leading axis without
a Python-level loop. finitevolX exposes this pattern through the
multilayer helper:
import jax.numpy as jnp
import finitevolx as fvx
grid = fvx.ArakawaCGrid2D.from_interior(64, 64, 1e4, 1e4)
diff2d = fvx.Difference2D(grid=grid)
nl = 3 # number of layers
h = jnp.ones((nl, grid.Ny, grid.Nx)) # [nl, Ny, Nx]
# Apply diff_x_T_to_U to every layer in one vectorised call
dh_dx = fvx.multilayer(diff2d.diff_x_T_to_U)(h) # shape [nl, Ny, Nx]
For operators that take multiple arguments (e.g., divergence), pass the
bound method directly. jax.vmap maps every positional argument over its
own leading axis, so each argument must carry the layer axis:
u = jnp.ones((nl, grid.Ny, grid.Nx))
v = jnp.ones((nl, grid.Ny, grid.Nx))
div = fvx.multilayer(diff2d.divergence)(u, v)
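The batching rule for several arguments can be illustrated with plain jax.vmap. This is a minimal sketch, not finitevolX code: divergence_2d and scale_by_mask are hypothetical stand-ins, and the in_axes usage shown applies to jax.vmap itself; whether multilayer forwards such an option is not stated here.

```python
import jax
import jax.numpy as jnp


def divergence_2d(u, v, dx=1.0, dy=1.0):
    """Hypothetical 2D divergence: backward differences written to the
    interior, ghost ring left at zero."""
    out = jnp.zeros_like(u)
    du_dx = (u[1:-1, 1:-1] - u[1:-1, :-2]) / dx
    dv_dy = (v[1:-1, 1:-1] - v[:-2, 1:-1]) / dy
    return out.at[1:-1, 1:-1].set(du_dx + dv_dy)


def scale_by_mask(field, mask):
    """Hypothetical operator taking one per-layer field and one shared 2D mask."""
    return field * mask


nl, Ny, Nx = 3, 6, 8
x = jnp.arange(Nx, dtype=jnp.float32)
slopes = jnp.arange(1, nl + 1, dtype=jnp.float32)

# Layer k holds u = (k + 1) * x, so its x-derivative is (k + 1).
u = slopes[:, None, None] * jnp.broadcast_to(x, (nl, Ny, Nx))
v = jnp.zeros((nl, Ny, Nx))

# Default in_axes=0: u and v are both split along their leading axis.
div = jax.vmap(divergence_2d)(u, v)

# A 2D argument shared by all layers needs in_axes=None for that slot.
mask = jnp.ones((Ny, Nx)).at[0, :].set(0.0)
masked = jax.vmap(scale_by_mask, in_axes=(0, None))(u, mask)
```

With the default in_axes, passing a 2D mask alongside a [nl, Ny, Nx] field would make jax.vmap try to map over the mask's Ny axis, which is why shared arguments are marked with None.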
What multilayer does
- Wraps fn with jax.vmap, batching over the first axis.
- The 2D stencil (ghost-cell ring, interior writes) is completely unchanged.
- All nl slices are processed; there is no "outer ghost layer" concept.
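The three points above can be checked with a small stand-alone sketch. It assumes that a multilayer-style helper is equivalent to jax.vmap over axis 0; diff_x_2d and multilayer_sketch are hypothetical names, not the finitevolX implementations.

```python
import jax
import jax.numpy as jnp


def diff_x_2d(h, dx=1.0):
    """Hypothetical 2D forward difference in x: interior written, ghost ring zero."""
    out = jnp.zeros_like(h)
    return out.at[1:-1, 1:-1].set((h[1:-1, 2:] - h[1:-1, 1:-1]) / dx)


def multilayer_sketch(fn):
    """Assumed behaviour of a multilayer-style helper: vmap over axis 0."""
    return jax.vmap(fn)


nl, Ny, Nx = 3, 5, 7
h = jnp.arange(nl * Ny * Nx, dtype=jnp.float32).reshape(nl, Ny, Nx)

# Batched call versus an explicit loop over the layers.
batched = multilayer_sketch(diff_x_2d)(h)
looped = jnp.stack([diff_x_2d(h[k]) for k in range(nl)])

same_as_loop = bool(jnp.array_equal(batched, looped))
# The first and last layers are written like any other layer.
outer_layers_written = bool(
    jnp.all(batched[0, 1:-1, 1:-1] == 1.0) and jnp.all(batched[-1, 1:-1, 1:-1] == 1.0)
)
# The horizontal ghost ring of the 2D stencil is untouched in every layer.
ring_untouched = bool(jnp.all(batched[:, 0, :] == 0.0))
```

The batched result matches the per-layer loop, which is the sense in which the 2D stencil is unchanged: only the way the calls are dispatched differs.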
True 3D: Difference3D
Concept
In a true 3D primitive-equation model, the leading axis is a spatial
vertical coordinate. The discretization uses the same ghost-cell convention
as the horizontal dimensions: the top (k = 0) and bottom
(k = Nz-1) layers are ghost shells that must be filled by vertical boundary
conditions before chaining operators. Only the interior k = 1 … Nz-2
layers are written by the stencil.
grid3d = fvx.ArakawaCGrid3D.from_interior(64, 64, 10, 1e4, 1e4, 50.0)
diff3d = fvx.Difference3D(grid=grid3d)
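# h has shape [Nz, Ny, Nx]; Nz = 12 total (10 interior + 2 ghost).
# The field varies in x so that the interior derivative is non-zero.
h = jnp.arange(grid3d.Nz * grid3d.Ny * grid3d.Nx, dtype=float).reshape(
    grid3d.Nz, grid3d.Ny, grid3d.Nx
)
dh_dx = diff3d.diff_x_T_to_U(h)
# dh_dx[0] = 0 ← ghost layer (k=0), not written
# dh_dx[1:-1] ≠ 0 ← interior layers, written (in the horizontal interior)
# dh_dx[-1] = 0 ← ghost layer (k=Nz-1), not written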
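The interior-only write and the need to fill the vertical ghost shells before chaining can be sketched without finitevolX. Both functions below are hypothetical: diff_x_3d imitates an explicit [1:-1, ...] slice write, and fill_vertical_ghosts assumes a zero-gradient vertical boundary condition, which is only one possible choice.

```python
import jax.numpy as jnp


def diff_x_3d(h, dx=1.0):
    """Hypothetical 3D forward difference in x; writes only the interior box."""
    out = jnp.zeros_like(h)
    dh = (h[1:-1, 1:-1, 2:] - h[1:-1, 1:-1, 1:-1]) / dx
    return out.at[1:-1, 1:-1, 1:-1].set(dh)


def fill_vertical_ghosts(f):
    """Assumed zero-gradient vertical BC: copy the nearest interior level
    into the top (k = 0) and bottom (k = Nz-1) ghost shells."""
    f = f.at[0].set(f[1])
    return f.at[-1].set(f[-2])


Nz, Ny, Nx = 6, 5, 7
h = jnp.arange(Nz * Ny * Nx, dtype=jnp.float32).reshape(Nz, Ny, Nx)

dh_dx = diff_x_3d(h)                        # ghost shells still zero
dh_dx_filled = fill_vertical_ghosts(dh_dx)  # ready for a further vertical stencil
```

Without the fill step, any later operator that reads across k would pick up the zeros left in the ghost shells rather than values consistent with the boundary condition.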
Key Differences at a Glance
| Property | multilayer(fn) | Difference3D |
|---|---|---|
| Leading axis meaning | Layer / mode index | Vertical z dimension |
| Ghost layers at k=0, k=N-1 | ❌ None; all slices are real | ✅ Required; top/bottom are ghost shells |
| Which slices are written | All k = 0 … nl-1 | Only interior k = 1 … Nz-2 |
| Underlying mechanism | jax.vmap(fn) over axis 0 | Explicit [1:-1, …] slice writes |
| Correct for | Baroclinic / QG layered models | Primitive-equation 3D models |
Equivalence in the Interior
Although the two approaches differ at the boundary layers, they produce identical results in the horizontal interior for the shared interior z-levels:
import jax.numpy as jnp
import numpy as np
import finitevolx as fvx
grid2d = fvx.ArakawaCGrid2D.from_interior(8, 8, 1.0, 1.0)
grid3d = fvx.ArakawaCGrid3D.from_interior(8, 8, 4, 1.0, 1.0, 1.0)
diff2d = fvx.Difference2D(grid=grid2d)
diff3d = fvx.Difference3D(grid=grid3d)
Nz = grid3d.Nz # 6 total (4 interior + 2 ghost)
h = jnp.arange(Nz * grid3d.Ny * grid3d.Nx, dtype=float).reshape(
    Nz, grid3d.Ny, grid3d.Nx
)
ml_result = fvx.multilayer(diff2d.diff_x_T_to_U)(h)
d3_result = diff3d.diff_x_T_to_U(h)
# Interior z-levels AND horizontal interior agree exactly
np.testing.assert_allclose(
    ml_result[1:-1, 1:-1, 1:-1],
    d3_result[1:-1, 1:-1, 1:-1],
)
# But the boundary layers differ:
# Difference3D: k=0, k=5 are zero (ghost shells)
print(d3_result[0].max()) # 0.0
print(d3_result[-1].max()) # 0.0
# multilayer: k=0, k=5 are real (non-zero) layers
print(ml_result[0].max()) # non-zero
print(ml_result[-1].max()) # non-zero
Summary
- Use multilayer(fn) when your leading axis indexes independent physical layers or modes. Every layer is real; jax.vmap is the correct, efficient way to batch over them.
- Use Difference3D when your leading axis is a spatial vertical coordinate that requires top/bottom ghost shells for boundary conditions.
- The two are not interchangeable. Applying Difference3D to a true multilayer field silently zeros out the top and bottom layers, which are physically real in that context.