Finite-volume advection
Part 5.3: Finite Volume Advection on a Staggered Grid¶
Staggered grids (Arakawa C-grid) place velocities at cell faces and scalars at cell centers. Coordax’s distinct named axes prevent accidental mixing of the two grids.
import coordax as cx
import jax.numpy as jnp
# Earth radius; multiplied by cos(latitude) and the longitude spacing in
# radians to obtain the zonal grid spacing dx.
R_EARTH = 6.371e6 # meters
# NOTE: cx.field() is preferred over cx.wrap() in coordax >= 0.2
Staggered Grid Configuration¶
T-grid (centers): |--T--|--T--|--T--|
U-grid (interior faces): |--U--|--U--|
U[i] lies between T[i] and T[i+1].
# Grid sizes. The U-grid has one point fewer than the T-grid in longitude
# because only the faces lying between two adjacent centers are represented.
n_lat = 16
n_lon_t = 32 # T-grid (cell centers)
n_lon_u = n_lon_t - 1 # U-grid (interior faces only)
print(f"T-grid: {n_lat} × {n_lon_t} (lat × lon_centers)")
print(f"U-grid: {n_lat} × {n_lon_u} (lat × lon_faces)")
# Coordinate values, in degrees (they are passed through jnp.deg2rad later).
lat_values = jnp.linspace(-75, 75, n_lat) # avoid poles to prevent dx→0
lon_t_values = jnp.linspace(0, 360, n_lon_t, endpoint=False)
dlon = float(360.0 / n_lon_t)
# Face i sits half a grid spacing past center i. Dropping the last center
# means there is no wrap-around face between the last and the first center.
lon_u_values = lon_t_values[:-1] + dlon / 2.0 # faces between centers
# The two longitude axes get different dimension names ('longitude' vs
# 'longitude_u'), so fields on the two grids report different dims.
lat_axis = cx.LabeledAxis('latitude', lat_values)
lon_t_axis = cx.LabeledAxis('longitude', lon_t_values)
lon_u_axis = cx.LabeledAxis('longitude_u', lon_u_values)
print(f"\nT-axis: '{lon_t_axis.dims[0]}', len={len(lon_t_values)}")
print(f"U-axis: '{lon_u_axis.dims[0]}', len={len(lon_u_values)}")
print("✓ Distinct coordinate systems prevent accidental mixing!")
T-grid: 16 × 32 (lat × lon_centers)
U-grid: 16 × 31 (lat × lon_faces)
T-axis: 'longitude', len=32
U-axis: 'longitude_u', len=31
✓ Distinct coordinate systems prevent accidental mixing!
Temperature Field (T-grid, cell centers)¶
# Initial temperature on the T-grid (cell centers).
# indexing='ij' makes both meshes (n_lat, n_lon_t), matching the
# (lat_axis, lon_t_axis) order passed to cx.field below.
lat_mesh_t, lon_mesh_t = jnp.meshgrid(lat_values, lon_t_values, indexing='ij')
# T0: constant offset; DeltaT: amplitude of the cos(latitude) term;
# wave_amp / wave_k: relative amplitude and wavenumber of the sinusoidal
# modulation in longitude.
T0, DeltaT, wave_amp, wave_k = 288.0, 30.0, 0.2, 3
T_data = (T0
+ DeltaT * jnp.cos(jnp.deg2rad(lat_mesh_t))
* (1 + wave_amp * jnp.sin(wave_k * jnp.deg2rad(lon_mesh_t))))
T_field = cx.field(T_data, lat_axis, lon_t_axis)
# shape: (16, 32) | dims: ('latitude','longitude') | units: K
print(f"T field: {T_field.dims}, shape: {T_field.shape}")
print(f"T range: {float(T_field.data.min()):.1f} — {float(T_field.data.max()):.1f} K")
T field: ('latitude', 'longitude'), shape: (16, 32)
T range: 294.2 — 323.9 K
Velocity Field (U-grid, cell faces)¶
# Zonal velocity on the U-grid (interior faces).
# The meshes are built from lon_u_values, so they are (n_lat, n_lon_u).
lat_mesh_u, lon_mesh_u = jnp.meshgrid(lat_values, lon_u_values, indexing='ij')
# U0: peak speed of the Gaussian jet; lat_jet / sigma_lat: its center and
# width in degrees of latitude; beta: relative amplitude of the longitude
# modulation (wave_k is reused from the temperature cell).
U0, lat_jet, sigma_lat, beta = 20.0, 45.0, 15.0, 0.3
U_data = (U0
* jnp.exp(-((lat_mesh_u - lat_jet) / sigma_lat)**2)
* (1 + beta * jnp.cos(wave_k * jnp.deg2rad(lon_mesh_u))))
U_field = cx.field(U_data, lat_axis, lon_u_axis)
# shape: (16, 31) | dims: ('latitude','longitude_u') | units: m/s
print(f"U field: {U_field.dims}, shape: {U_field.shape}")
print(f"U range: {float(U_field.data.min()):.1f} — {float(U_field.data.max()):.1f} m/s")
print(f"\n✓ T: {T_field.dims}")
print(f"✓ U: {U_field.dims} ← different coordinate system!")
U field: ('latitude', 'longitude_u'), shape: (16, 31)
U range: 0.0 — 26.0 m/s
✓ T: ('latitude', 'longitude')
✓ U: ('latitude', 'longitude_u') ← different coordinate system!
Interpolate Temperature from T-grid to U-grid¶
Centered average of adjacent cell centers.
def interpolate_centers_to_interior_faces(field_centers, axis_name='longitude'):
    """
    Average neighbouring cell-center values onto the interior faces.

    T_face[i] = 0.5 * (T_center[i] + T_center[i+1])

    Args:
        field_centers: Object exposing ``dims`` (sequence of dimension
            names) and ``data`` (array indexed in that order).
        axis_name: Name of the dimension to interpolate along.

    Returns:
        Raw array (not a field) whose length along ``axis_name`` is one
        less than the input; all other axes are unchanged.

    Raises:
        ValueError: If ``axis_name`` is not one of ``field_centers.dims``.
    """
    axis = field_centers.dims.index(axis_name)
    values = field_centers.data

    def along_axis(window):
        # Take everything on every axis except `window` on the target axis.
        return tuple(
            window if dim == axis else slice(None)
            for dim in range(values.ndim)
        )

    lower = values[along_axis(slice(0, -1))]    # T_center[i]
    upper = values[along_axis(slice(1, None))]  # T_center[i+1]
    return 0.5 * (lower + upper)
# The helper returns a raw array; re-wrap it on the U-grid axes because the
# averaged values live on the faces, not on the centers.
T_at_U_data = interpolate_centers_to_interior_faces(T_field, axis_name='longitude')
T_at_U = cx.field(T_at_U_data, lat_axis, lon_u_axis)
# shape: (16, 31) | dims: ('latitude','longitude_u') | units: K — T interpolated to faces
print(f"T at U-points: {T_at_U.dims}, shape: {T_at_U.shape}")
print(f"Interpolation error (mean): {float(jnp.abs(T_field.data.mean() - T_at_U.data.mean())):.2e} K")
T at U-points: ('latitude', 'longitude_u'), shape: (16, 31)
Interpolation error (mean): 3.81e-02 K
Advective Flux: Centered Scheme¶
# Centered flux: velocity times the face-averaged temperature. Both operands
# are on the U-grid, so the raw arrays multiply element-wise.
flux_centered_data = U_field.data * T_at_U.data
flux_centered = cx.field(flux_centered_data, lat_axis, lon_u_axis)
# shape: (16, 31) | dims: ('latitude','longitude_u') | units: K·m/s — F = U·T̄
print(f"Centered flux: {flux_centered.dims}, shape: {flux_centered.shape}")
print(f"Flux range: {float(flux_centered.data.min()):.1f} — {float(flux_centered.data.max()):.1f} K·m/s")
Centered flux: ('latitude', 'longitude_u'), shape: (16, 31)
Flux range: 0.0 — 8040.9 K·m/s
Advective Flux: Upwind Scheme¶
def upwind_flux_interior(U_field, T_field_centers, axis_name='longitude'):
    """
    First-order upwind advective flux on the interior faces.

    For each face i the transported value is taken from the upstream cell:
        U[i] > 0   -> T_center[i]    (cell on the low-index side)
        otherwise  -> T_center[i+1]  (cell on the high-index side)
    When U[i] == 0 the second branch is selected, but the flux is zero
    either way.

    Args:
        U_field: Object exposing ``data``, the face velocities. Its array
            must have the shape of the center array with one element fewer
            along ``axis_name``.
        T_field_centers: Object exposing ``dims`` and ``data`` for the
            cell-centered scalar.
        axis_name: Name, in ``T_field_centers.dims``, of the advection axis.

    Returns:
        Raw array ``U * T_upwind`` with the shape of the face velocities.
    """
    axis = T_field_centers.dims.index(axis_name)
    centers = T_field_centers.data
    velocity = U_field.data

    def along_axis(window):
        # Full slice on every axis except `window` on the advection axis.
        return tuple(
            window if dim == axis else slice(None)
            for dim in range(centers.ndim)
        )

    upstream = jnp.where(
        velocity > 0,
        centers[along_axis(slice(0, -1))],    # T_center[i]
        centers[along_axis(slice(1, None))],  # T_center[i+1]
    )
    # F_up[i] = U[i] · T_up[i]
    return velocity * upstream
# Upwind flux from the raw center values; the result lives on the U-grid,
# so it is wrapped with the face axes.
flux_upwind_data = upwind_flux_interior(U_field, T_field, axis_name='longitude')
flux_upwind = cx.field(flux_upwind_data, lat_axis, lon_u_axis)
# shape: (16, 31) | dims: ('latitude','longitude_u') | units: K·m/s — upwind flux
print(f"Upwind flux: {flux_upwind.dims}, shape: {flux_upwind.shape}")
# Point-wise difference between the two flux estimates on the same faces.
diff = flux_upwind.data - flux_centered.data
print(f"Max |upwind - centered|: {jnp.max(jnp.abs(diff)):.1f} K·m/s")
Upwind flux: ('latitude', 'longitude_u'), shape: (16, 31)
Max |upwind - centered|: 31.8 K·m/s
Flux Divergence: Faces → Interior Centers¶
def flux_divergence_to_centers(flux_at_faces, axis_name='longitude_u'):
    """
    Difference face fluxes to get the delta-flux at the interior centers.

    div(F)[i] ∝ F[i] − F[i-1]   (face right − face left)

    The result is the undivided first difference; dividing by the grid
    spacing is left to the caller.

    Args:
        flux_at_faces: Object exposing ``dims`` (sequence of dimension
            names) and ``data`` (array indexed in that order).
        axis_name: Name of the face dimension to difference along.

    Returns:
        Raw array whose length along ``axis_name`` is one less than the
        input (n faces -> n − 1 enclosed centers); other axes unchanged.

    Raises:
        ValueError: If ``axis_name`` is not one of ``flux_at_faces.dims``.
    """
    axis_pos = flux_at_faces.dims.index(axis_name)
    # jnp.diff computes data[1:] − data[:-1] along the chosen axis, which is
    # exactly the right-face-minus-left-face difference.
    return jnp.diff(flux_at_faces.data, axis=axis_pos)
delta_flux_data = flux_divergence_to_centers(flux_upwind, axis_name='longitude_u')
# Only centers with a face on both sides receive a value, so the first and
# last T-grid longitudes are dropped from the axis used to label the result.
lon_t_interior_axis = cx.LabeledAxis('longitude', lon_t_values[1:-1])
delta_flux = cx.field(delta_flux_data, lat_axis, lon_t_interior_axis)
# shape: (16, 30) | dims: ('latitude','longitude') | units: K·m/s — interior centers only
print(f"Delta-flux (interior centers): {delta_flux.dims}, shape: {delta_flux.shape}")
Delta-flux (interior centers): ('latitude', 'longitude'), shape: (16, 30)
Temperature Tendency ∂T/∂t¶
# Zonal grid spacing in meters: it shrinks with cos(latitude), so it is a
# per-latitude vector rather than a single number.
dlon_rad = jnp.deg2rad(dlon)
dx = R_EARTH * jnp.cos(jnp.deg2rad(lat_values)) * dlon_rad # (n_lat,)
# Add a trailing axis so dx broadcasts across the longitude axis of the
# (n_lat, n_lon) delta-flux array.
dx_broadcast = dx[:, None]
# ∂T/∂t = −div(F) / Δx [K/s] Δx(φ) = R·cos(φ)·Δλ
dT_dt_data = -delta_flux.data / dx_broadcast
dT_dt = cx.field(dT_dt_data, lat_axis, lon_t_interior_axis)
# shape: (16, 30) | dims: ('latitude','longitude') | units: K/s
print(f"∂T/∂t: {dT_dt.dims}, shape: {dT_dt.shape}")
print(f"Range: {float(dT_dt.data.min()):.2e} — {float(dT_dt.data.max()):.2e} K/s")
print(f"Max heating: {float(dT_dt.data.max()) * 3600:.2f} K/h")
∂T/∂t: ('latitude', 'longitude'), shape: (16, 30)
Range: -1.19e-03 — 1.22e-03 K/s
Max heating: 4.38 K/h
Forward Euler Time Step¶
# CFL check
dt = 100.0 # seconds
# Pairing the largest |U| with the smallest dx gives an upper bound on the
# Courant number over the whole grid.
cfl = float(jnp.abs(U_field.data).max()) * dt / float(dx.min())
print(f"CFL = {cfl:.3f} (must be < 1 for stability)")
# Forward Euler: add dt·∂T/∂t to the interior longitude columns only. The
# first and last columns keep their initial values because no tendency was
# computed there. `.at[...].add(...)` returns a new array; T_field itself
# is not modified.
T_new_data = T_field.data.at[:, 1:-1].add(dt * dT_dt.data)
T_new = cx.field(T_new_data, lat_axis, lon_t_axis)
# shape: (16, 32) | dims: ('latitude','longitude') | units: K — full T-grid restored
print(f"Updated T: {T_new.dims}, shape: {T_new.shape}")
print(f"Updated T range: {float(T_new.data.min()):.1f} — {float(T_new.data.max()):.1f} K")
# Area-weighted conservation check
# cos(latitude) is used as the weight of each cell when summing T before
# and after the step.
cos_lat = jnp.cos(jnp.deg2rad(lat_mesh_t))
mass_i = float(jnp.sum(T_field.data * cos_lat))
mass_f = float(jnp.sum(T_new.data * cos_lat))
print(f"Mass conservation: {abs(mass_f - mass_i) / abs(mass_i):.2e} relative change")
CFL = 0.008 (must be < 1 for stability)
Updated T: ('latitude', 'longitude'), shape: (16, 32)
Updated T range: 294.2 — 323.9 K
Mass conservation: 1.87e-06 relative change
Complete One-Step Advection Function¶
def advection_step_staggered(T_field, U_field, dt, lat_values, dlon_deg, scheme='upwind'):
    """
    One forward-Euler advection step on an Arakawa C-grid.

    The scalar is advected along longitude only. Face fluxes are built on
    the interior faces, differenced back to the interior centers, divided
    by the latitude-dependent grid spacing and added to T.

    Layout assumption: the data of ``T_field`` is ordered
    (latitude, longitude) and that of ``U_field`` (latitude, longitude_u);
    the broadcasting of dx and the interior update below rely on it.

    Args:
        T_field: Cell-centered scalar field with axes 'latitude' and
            'longitude'.
        U_field: Face velocity field with axes 'latitude' and
            'longitude_u' (one point fewer than 'longitude').
        dt: Time step in seconds.
        lat_values: Latitudes in degrees, one per row of ``T_field``.
        dlon_deg: Longitude spacing of the T-grid in degrees.
        scheme: 'upwind' (upstream cell value at each face) or 'centered'
            (average of the two neighbouring cells).

    Returns:
        Updated T field on the same axes as ``T_field``. Only the interior
        longitude columns change; the first and last columns are copied
        through unchanged.

    Raises:
        ValueError: If ``scheme`` is neither 'upwind' nor 'centered'.
    """
    # Reject unknown names up front so a typo cannot silently select the
    # centered scheme.
    if scheme not in ('upwind', 'centered'):
        raise ValueError(
            f"Unknown scheme {scheme!r}; expected 'upwind' or 'centered'."
        )

    lat_axis = T_field.axes['latitude']
    lon_t_axis = T_field.axes['longitude']
    lon_u_axis = U_field.axes['longitude_u']

    if scheme == 'upwind':
        flux_data = upwind_flux_interior(U_field, T_field, axis_name='longitude')
    else:
        T_at_U_data = interpolate_centers_to_interior_faces(T_field, axis_name='longitude')
        flux_data = U_field.data * T_at_U_data

    # Wrap the flux so the divergence helper can look the axis up by name.
    flux = cx.field(flux_data, lat_axis, lon_u_axis)
    delta_flux_data = flux_divergence_to_centers(flux, axis_name='longitude_u')

    # dx(φ) = R·cos(φ)·Δλ, one value per latitude; [:, None] broadcasts it
    # across the longitude axis.
    dlon_rad = jnp.deg2rad(dlon_deg)
    dx = R_EARTH * jnp.cos(jnp.deg2rad(lat_values)) * dlon_rad
    dT_dt_data = -delta_flux_data / dx[:, None]

    # Tendencies exist only at interior centers, so only those are updated.
    T_new_data = T_field.data.at[:, 1:-1].add(dt * dT_dt_data)
    return cx.field(T_new_data, lat_axis, lon_t_axis)
# Run the packaged step with the same time step (100 s) and scheme (upwind)
# used in the manual forward Euler cell.
T_updated = advection_step_staggered(
T_field, U_field, dt=100.0, lat_values=lat_values, dlon_deg=dlon, scheme='upwind'
)
# shape: (16, 32) | dims: ('latitude','longitude') | units: K
print(f"Complete step — T: {T_updated.dims}, shape: {T_updated.shape}")
print(f"T range: {float(T_updated.data.min()):.1f} — {float(T_updated.data.max()):.1f} K")
Complete step — T: ('latitude', 'longitude'), shape: (16, 32)
T range: 294.2 — 323.9 K
Key Patterns Summary¶
- T-grid (`longitude`) and U-grid (`longitude_u`) are distinct coordinate systems
- Interpolation: `T_face[i] = 0.5*(T[i] + T[i+1])` (centers → interior faces)
- Upwind flux: choose upstream value based on sign of velocity
- Flux divergence: `F[i] - F[i-1]` at interior centers
- Tendency: `∂T/∂t = -delta_F / dx`; update `T[:, 1:-1]` only