Skip to article frontmatter
Skip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

Finite-volume advection

Part 5.3: Finite Volume Advection on a Staggered Grid

Staggered grids (Arakawa C-grid) place velocities at cell faces and scalars at cell centers. Coordax’s distinct named axes prevent accidental mixing of the two grids.

import coordax as cx
import jax.numpy as jnp

# Spherical-Earth radius in meters; it sets the zonal cell width
# dx = R_EARTH · cos(lat) · dlon used by the advection code below.
R_EARTH = 6_371_000.0

# NOTE: cx.field() is preferred over cx.wrap() in coordax >= 0.2

Staggered Grid Configuration

T-grid (centers):  |--T--|--T--|--T--|
U-grid (int faces):   |--U--|--U--|

U[i] lies between T[i] and T[i+1].

# Grid sizes.  The U-grid keeps only the interior faces that separate two
# neighbouring T-cells, so it has one point fewer in longitude.
n_lat, n_lon_t = 16, 32
n_lon_u = n_lon_t - 1

print(f"T-grid: {n_lat} × {n_lon_t}  (lat × lon_centers)")
print(f"U-grid: {n_lat} × {n_lon_u}  (lat × lon_faces)")

# Stop at ±75° latitude: dx ∝ cos(lat) would collapse towards zero at the poles.
lat_values = jnp.linspace(-75, 75, n_lat)

# Uniform longitude spacing in degrees; the centers cover [0, 360).
dlon = 360.0 / n_lon_t
lon_t_values = jnp.linspace(0, 360, n_lon_t, endpoint=False)

# Face i sits half a cell east of center i, i.e. midway between centers i and i+1.
lon_u_values = lon_t_values[:-1] + 0.5 * dlon

# Distinct dimension names ('longitude' vs 'longitude_u') keep the grids apart.
lat_axis = cx.LabeledAxis('latitude', lat_values)
lon_t_axis = cx.LabeledAxis('longitude', lon_t_values)
lon_u_axis = cx.LabeledAxis('longitude_u', lon_u_values)

print(f"\nT-axis: '{lon_t_axis.dims[0]}', len={lon_t_values.shape[0]}")
print(f"U-axis: '{lon_u_axis.dims[0]}', len={lon_u_values.shape[0]}")
print("✓ Distinct coordinate systems prevent accidental mixing!")
T-grid: 16 × 32  (lat × lon_centers)
U-grid: 16 × 31  (lat × lon_faces)

T-axis: 'longitude', len=32
U-axis: 'longitude_u', len=31
✓ Distinct coordinate systems prevent accidental mixing!

Temperature Field (T-grid, cell centers)

# Temperature on the T-grid: background T0 plus a cos(lat) equator-to-pole
# contrast, modulated in longitude by a wavenumber-`wave_k` sine wave.
T0 = 288.0       # K, background temperature
DeltaT = 30.0    # K, equator-minus-pole contrast
wave_amp = 0.2   # fractional zonal modulation
wave_k = 3       # zonal wavenumber

lat_mesh_t, lon_mesh_t = jnp.meshgrid(lat_values, lon_t_values, indexing='ij')
meridional_t = DeltaT * jnp.cos(jnp.deg2rad(lat_mesh_t))
zonal_t = 1 + wave_amp * jnp.sin(wave_k * jnp.deg2rad(lon_mesh_t))
T_data = T0 + meridional_t * zonal_t

T_field = cx.field(T_data, lat_axis, lon_t_axis)
# shape: (16, 32) | dims: ('latitude','longitude') | units: K
print(f"T field: {T_field.dims}, shape: {T_field.shape}")
print(f"T range: {float(T_field.data.min()):.1f} — {float(T_field.data.max()):.1f} K")
T field: ('latitude', 'longitude'), shape: (16, 32)
T range: 294.2 — 323.9 K

Velocity Field (U-grid, cell faces)

# Zonal wind on the U-grid: a Gaussian jet centred at `lat_jet`, modulated in
# longitude with the same wavenumber as the temperature wave.
U0 = 20.0          # m/s, jet-core speed before zonal modulation
lat_jet = 45.0     # deg, jet latitude
sigma_lat = 15.0   # deg, distance at which the jet falls to 1/e
beta = 0.3         # fractional zonal modulation

lat_mesh_u, lon_mesh_u = jnp.meshgrid(lat_values, lon_u_values, indexing='ij')
jet_profile = U0 * jnp.exp(-((lat_mesh_u - lat_jet) / sigma_lat)**2)
zonal_mod_u = 1 + beta * jnp.cos(wave_k * jnp.deg2rad(lon_mesh_u))
U_data = jet_profile * zonal_mod_u

U_field = cx.field(U_data, lat_axis, lon_u_axis)
# shape: (16, 31) | dims: ('latitude','longitude_u') | units: m/s
print(f"U field: {U_field.dims}, shape: {U_field.shape}")
print(f"U range: {float(U_field.data.min()):.1f} — {float(U_field.data.max()):.1f} m/s")
print(f"\n✓ T: {T_field.dims}")
print(f"✓ U: {U_field.dims}  ← different coordinate system!")
U field: ('latitude', 'longitude_u'), shape: (16, 31)
U range: 0.0 — 26.0 m/s

✓ T: ('latitude', 'longitude')
✓ U: ('latitude', 'longitude_u')  ← different coordinate system!

Interpolate Temperature from T-grid to U-grid

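Each interior face takes the centered (arithmetic) average of its two adjacent cell centers, T_face[i] = ½·(T[i] + T[i+1]). Only the n − 1 interior faces are produced. The wrap-around face between the last and first longitude is omitted.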

def interpolate_centers_to_interior_faces(field_centers, axis_name='longitude'):
    """
    Average each pair of neighbouring cell centers onto the face between them.

    Along ``axis_name`` an input of length n yields n - 1 interior faces:

        T_face[i] = 0.5 * (T_center[i] + T_center[i+1])

    Args:
        field_centers: field-like object with ``.dims`` (sequence of dimension
            names) and ``.data`` (array whose axes follow ``.dims``).
        axis_name: dimension along which to interpolate.

    Returns:
        Plain array of face values; the caller attaches the face axis.

    Raises:
        ValueError: if ``axis_name`` is not in ``field_centers.dims``.
    """
    axis = field_centers.dims.index(axis_name)
    values = field_centers.data

    def along_axis(window):
        # Apply `window` on `axis`; leading axes are taken whole explicitly,
        # trailing axes implicitly.
        return values[(slice(None),) * axis + (window,)]

    return 0.5 * (along_axis(slice(None, -1)) + along_axis(slice(1, None)))


# Map T from cell centers onto the interior U-faces and label it with U-grid axes.
T_at_U_data = interpolate_centers_to_interior_faces(T_field, axis_name='longitude')
T_at_U = cx.field(T_at_U_data, lat_axis, lon_u_axis)
# shape: (16, 31) | dims: ('latitude','longitude_u') | units: K  — T interpolated to faces
print(f"T at U-points: {T_at_U.dims}, shape: {T_at_U.shape}")

# Gap between the domain mean on centers and on faces.  Averaging neighbouring
# pairs counts the first and last center of each row with half weight, so this
# number reflects that edge weighting rather than a pointwise interpolation error.
mean_gap = jnp.abs(T_field.data.mean() - T_at_U.data.mean())
print(f"Interpolation error (mean): {float(mean_gap):.2e} K")
T at U-points: ('latitude', 'longitude_u'), shape: (16, 31)
Interpolation error (mean): 3.81e-02 K

Advective Flux: Centered Scheme

# Centered flux at the interior faces: F = U · T̄, with T̄ the face-averaged T.
flux_centered_data = T_at_U.data * U_field.data
flux_centered = cx.field(flux_centered_data, lat_axis, lon_u_axis)
# shape: (16, 31) | dims: ('latitude','longitude_u') | units: K·m/s  — F = U·T̄
print(f"Centered flux: {flux_centered.dims}, shape: {flux_centered.shape}")

flux_min = float(flux_centered.data.min())
flux_max = float(flux_centered.data.max())
print(f"Flux range: {flux_min:.1f} — {flux_max:.1f} K·m/s")
Centered flux: ('latitude', 'longitude_u'), shape: (16, 31)
Flux range: 0.0 — 8040.9 K·m/s

Advective Flux: Upwind Scheme

def upwind_flux_interior(U_field, T_field_centers, axis_name='longitude'):
    """
    First-order upwind advective flux ``U * T_upstream`` on interior faces.

    For the face between centers i and i+1 the upstream temperature is
        T_center[i]    when U[i] > 0,
        T_center[i+1]  otherwise (U[i] <= 0).

    Args:
        U_field: field-like object whose ``.data`` holds face velocities,
            aligned with the face dimension implied by ``axis_name``.
        T_field_centers: field-like object with ``.dims`` and ``.data`` for
            the cell-center values.
        axis_name: name of the center dimension along which faces lie.

    Returns:
        Plain array ``U * T_upstream`` with one fewer entry than the centers
        along ``axis_name``.

    Raises:
        ValueError: if ``axis_name`` is not in ``T_field_centers.dims``.
    """
    pos = T_field_centers.dims.index(axis_name)
    temps = T_field_centers.data
    velocity = U_field.data

    # Leading axes are kept whole; trailing axes are implicitly kept whole.
    head = (slice(None),) * pos
    west_of_face = temps[head + (slice(None, -1),)]
    east_of_face = temps[head + (slice(1, None),)]

    upstream = jnp.where(velocity > 0, west_of_face, east_of_face)
    return velocity * upstream


# Upwind flux on the same interior faces, for comparison with the centered one.
flux_upwind_data = upwind_flux_interior(U_field, T_field, axis_name='longitude')
flux_upwind = cx.field(flux_upwind_data, lat_axis, lon_u_axis)
# shape: (16, 31) | dims: ('latitude','longitude_u') | units: K·m/s  — upwind flux
print(f"Upwind flux: {flux_upwind.dims}, shape: {flux_upwind.shape}")

# Largest pointwise disagreement between the two schemes.
diff = flux_upwind.data - flux_centered.data
max_abs_diff = jnp.max(jnp.abs(diff))
print(f"Max |upwind - centered|: {max_abs_diff:.1f} K·m/s")
Upwind flux: ('latitude', 'longitude_u'), shape: (16, 31)
Max |upwind - centered|: 31.8 K·m/s

Flux Divergence: Faces → Interior Centers

def flux_divergence_to_centers(flux_at_faces, axis_name='longitude_u'):
    """
    Difference face fluxes onto the cell centers that lie between them.

    Faces are indexed so that face i separates centers i and i+1. Interior
    center k (1 <= k <= n_faces - 1) therefore receives

        div(F)[k] ∝ F[k] − F[k-1]   (right face − left face)

    The result has one fewer entry than the input along ``axis_name``
    (n_faces − 1, i.e. n_centers − 2 for interior-only faces). Its element j
    belongs to center j + 1.

    Args:
        flux_at_faces: field-like object with ``.dims`` (sequence of dimension
            names) and ``.data`` (array whose axes follow ``.dims``).
        axis_name: name of the face dimension to difference along.

    Returns:
        Plain array of flux differences (not divided by the cell width).

    Raises:
        ValueError: if ``axis_name`` is not in ``flux_at_faces.dims``.
    """
    axis_pos = flux_at_faces.dims.index(axis_name)
    # jnp.diff(F, axis=a)[j] == F[j+1] - F[j]: exactly right-minus-left.
    return jnp.diff(flux_at_faces.data, axis=axis_pos)


# ΔF at each interior T-cell = flux through its right face − flux through its
# left face.  The two end cells have no outer face, so they are excluded.
delta_flux_data = flux_divergence_to_centers(flux_upwind, axis_name='longitude_u')

# Interior cells keep the 'longitude' dimension name but drop the end ticks.
interior_lon_ticks = lon_t_values[1:-1]
lon_t_interior_axis = cx.LabeledAxis('longitude', interior_lon_ticks)
delta_flux = cx.field(delta_flux_data, lat_axis, lon_t_interior_axis)
# shape: (16, 30) | dims: ('latitude','longitude') | units: K·m/s  — interior centers only
print(f"Delta-flux (interior centers): {delta_flux.dims}, shape: {delta_flux.shape}")
Delta-flux (interior centers): ('latitude', 'longitude'), shape: (16, 30)

Temperature Tendency ∂T/∂t

# Zonal cell width per latitude row: Δx(φ) = R·cos(φ)·Δλ, shape (n_lat,).
dlon_rad = jnp.deg2rad(dlon)
dx = R_EARTH * jnp.cos(jnp.deg2rad(lat_values)) * dlon_rad
dx_broadcast = dx[:, jnp.newaxis]   # (n_lat, 1): broadcasts across longitude

# Flux-form tendency ∂T/∂t = −ΔF / Δx   [K/s]
dT_dt_data = -delta_flux.data / dx_broadcast
dT_dt = cx.field(dT_dt_data, lat_axis, lon_t_interior_axis)
# shape: (16, 30) | dims: ('latitude','longitude') | units: K/s
print(f"∂T/∂t: {dT_dt.dims}, shape: {dT_dt.shape}")

tend_lo = float(dT_dt.data.min())
tend_hi = float(dT_dt.data.max())
print(f"Range: {tend_lo:.2e} — {tend_hi:.2e} K/s")
print(f"Max heating: {tend_hi * 3600:.2f} K/h")
∂T/∂t: ('latitude', 'longitude'), shape: (16, 30)
Range: -1.19e-03 — 1.22e-03 K/s
Max heating: 4.38 K/h

Forward Euler Time Step

# Courant number from the fastest wind and the narrowest cell.  This bound is
# conservative because the two need not occur at the same latitude.
dt = 100.0  # seconds
u_max = float(jnp.abs(U_field.data).max())
dx_min = float(dx.min())
cfl = u_max * dt / dx_min
print(f"CFL = {cfl:.3f}  (must be < 1 for stability)")

# Forward Euler on the interior columns only; the two end columns keep their T.
T_new_data = T_field.data.at[:, 1:-1].add(dt * dT_dt.data)
T_new = cx.field(T_new_data, lat_axis, lon_t_axis)
# shape: (16, 32) | dims: ('latitude','longitude') | units: K  — full T-grid restored
print(f"Updated T: {T_new.dims}, shape: {T_new.shape}")
print(f"Updated T range: {float(T_new.data.min()):.1f} — {float(T_new.data.max()):.1f} K")

# cos(lat)-weighted sum of T before and after the step.
cos_lat = jnp.cos(jnp.deg2rad(lat_mesh_t))
mass_i = float((T_field.data * cos_lat).sum())
mass_f = float((T_new.data * cos_lat).sum())
rel_change = abs(mass_f - mass_i) / abs(mass_i)
print(f"Mass conservation: {rel_change:.2e} relative change")
CFL = 0.008  (must be < 1 for stability)
Updated T: ('latitude', 'longitude'), shape: (16, 32)
Updated T range: 294.2 — 323.9 K
Mass conservation: 1.87e-06 relative change

Complete One-Step Advection Function

def advection_step_staggered(T_field, U_field, dt, lat_values, dlon_deg, scheme='upwind'):
    """
    Advance T by one forward-Euler advection step on an Arakawa C-grid.

    The zonal flux is evaluated on the interior U-faces (``longitude_u``),
    differenced onto the interior T-cells, and divided by the zonal cell
    width ``dx = R_EARTH * cos(lat) * dlon``. Only ``T[:, 1:-1]`` is updated.
    The first and last longitude columns are returned unchanged because the
    U-grid has no face outside them.

    Args:
        T_field: temperature on the T-grid. The data layout is assumed to be
            (latitude, longitude); the ``[:, 1:-1]`` update and the
            ``dx[:, None]`` broadcast below rely on it.
        U_field: zonal velocity on the U-grid ('longitude_u'), with one fewer
            longitude point than ``T_field``.
        dt: time step, in the same time unit as U (seconds in this notebook).
        lat_values: latitude of each row, in degrees.
        dlon_deg: uniform longitude spacing, in degrees.
        scheme: 'upwind' (default) or 'centered' (arithmetic mean of the two
            neighbouring centers).

    Returns:
        A new field with the same latitude/longitude axes as ``T_field``.

    Raises:
        ValueError: if ``scheme`` is not 'upwind' or 'centered'.
    """
    # Previously any unrecognised value silently fell back to 'centered'.
    if scheme not in ('upwind', 'centered'):
        raise ValueError(
            f"Unknown advection scheme {scheme!r}; expected 'upwind' or 'centered'."
        )

    lat_axis = T_field.axes['latitude']
    lon_t_axis = T_field.axes['longitude']
    lon_u_axis = U_field.axes['longitude_u']

    if scheme == 'upwind':
        flux_data = upwind_flux_interior(U_field, T_field, axis_name='longitude')
    else:  # 'centered'
        T_at_U_data = interpolate_centers_to_interior_faces(T_field, axis_name='longitude')
        flux_data = U_field.data * T_at_U_data

    # Wrap the face fluxes so the divergence helper can locate 'longitude_u'.
    flux = cx.field(flux_data, lat_axis, lon_u_axis)
    delta_flux_data = flux_divergence_to_centers(flux, axis_name='longitude_u')

    # Zonal cell width shrinks with latitude: dx = R·cos(φ)·Δλ, shape (n_lat,).
    dlon_rad = jnp.deg2rad(dlon_deg)
    dx = R_EARTH * jnp.cos(jnp.deg2rad(lat_values)) * dlon_rad
    dT_dt_data = -delta_flux_data / dx[:, None]

    # Forward Euler on the interior centers; boundary columns are held fixed.
    T_new_data = T_field.data.at[:, 1:-1].add(dt * dT_dt_data)
    return cx.field(T_new_data, lat_axis, lon_t_axis)


# The same single step, packaged as one reusable function call.
T_updated = advection_step_staggered(
    T_field,
    U_field,
    dt=100.0,
    lat_values=lat_values,
    dlon_deg=dlon,
    scheme='upwind',
)
# shape: (16, 32) | dims: ('latitude','longitude') | units: K
print(f"Complete step — T: {T_updated.dims}, shape: {T_updated.shape}")

updated_lo = float(T_updated.data.min())
updated_hi = float(T_updated.data.max())
print(f"T range: {updated_lo:.1f} — {updated_hi:.1f} K")
Complete step — T: ('latitude', 'longitude'), shape: (16, 32)
T range: 294.2 — 323.9 K

Key Patterns Summary

  1. T-grid (longitude) and U-grid (longitude_u) are distinct coordinate systems
  2. Interpolation: T_face[i] = 0.5*(T[i] + T[i+1]) (centers → interior faces)
  3. Upwind flux: choose upstream value based on sign of velocity
  4. Flux divergence: F[i] - F[i-1] at interior centers
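  5. Tendency: ∂T/∂t = -delta_F / dx; update T[:, 1:-1] only. The first and last longitude columns have no outer face and are held fixed. The fluxes through the two outermost interior faces therefore change the interior cells with no opposite change in the end cells, so total (area-weighted) T is only approximately conserved; a periodic wrap-around face would be needed for exact conservation.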