Skip to article frontmatter | Skip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

Finite-difference gradients

Part 5: Coordinate-Aware Gradient Operators

This notebook demonstrates how to compute derivatives with respect to physical coordinates (not just array indices) using finite differences.

import coordax as cx
import jax.numpy as jnp

# NOTE: cx.field() is preferred over cx.wrap() in coordax >= 0.2

5.1 Finite Differences on a Periodic Cartesian Domain

Centered finite differences on a uniform periodic grid.

# Uniform grid of N points on [0, L). endpoint=False leaves x = L out, so the
# last sample is one spacing short of L and the periodic wrap does not
# duplicate x = 0.
N = 64
L = 10.0  # meters
x_values = jnp.linspace(0.0, L, N, endpoint=False)

# k0 = 2π/L is the fundamental wavenumber of the domain; the test signal is a
# sum of its 3rd and 7th harmonics, so it is exactly periodic over [0, L).
k0 = 2.0 * jnp.pi / L
temperature_data = jnp.sin(3.0 * k0 * x_values) + 0.5 * jnp.cos(7.0 * k0 * x_values)
# Closed-form derivative of temperature_data, kept as the reference solution:
# d/dx[0.5·cos(7·k0·x)] = −3.5·k0·sin(7·k0·x).
dT_dx_analytical = (3.0 * k0 * jnp.cos(3.0 * k0 * x_values)
                    - 3.5 * k0 * jnp.sin(7.0 * k0 * x_values))

# Attach the physical x coordinates to the raw samples.
x_axis = cx.LabeledAxis('x', x_values)
temperature = cx.field(temperature_data, x_axis)
# shape: (64,) | dims: ('x',) | units: K (synthetic test function)

print(f"Temperature field: shape={temperature.shape}, dims={temperature.dims}")
print(f"x range: {float(x_values.min()):.3f} to {float(x_values.max()):.3f} m")
Temperature field: shape=(64,), dims=('x',)
x range: 0.000 to 9.844 m

Grid spacing from coordinates

# Obtain the grid spacing two independent ways and confirm they agree:
#   * from the domain definition, L / N;
#   * from the first two ticks stored on the axis (representative of the whole
#     grid only because the grid is uniform).
x_coords = x_axis.ticks
dx = float(L / N)
dx_from_coords = float(x_coords[1] - x_coords[0])
print(f"dx (from domain): {dx:.6f} m")
print(f"dx (from coords): {dx_from_coords:.6f} m")
print(f"Match: {jnp.allclose(dx, dx_from_coords)}")
dx (from domain): 0.156250 m
dx (from coords): 0.156250 m
Match: True

Periodic centered finite difference

def centered_diff_periodic(values, dx, axis=-1):
    """Centered FD on a periodic uniform grid: (f[i+1] - f[i-1]) / (2 dx).

    Args:
        values: Array of samples. It is treated as periodic along ``axis``:
            the neighbour after the last sample is the first sample.
        dx: Uniform grid spacing along ``axis``.
        axis: Axis to differentiate along. Defaults to the last axis, which
            for 1-D input gives the same result as rolling without an axis.

    Returns:
        Array with the same shape as ``values`` holding the derivative
        estimate at every grid point.
    """
    # ∂f/∂x[i] ≈ (f[i+1] − f[i−1]) / (2·Δx)   [2nd-order, O(Δx²)]
    # Roll along an explicit axis: with no axis, jnp.roll shifts the
    # *flattened* array, so multi-dimensional input would wrap across rows
    # and mix unrelated samples into the stencil.
    forward = jnp.roll(values, shift=-1, axis=axis)
    backward = jnp.roll(values, shift=1, axis=axis)
    return (forward - backward) / (2.0 * dx)


# Differentiate the raw samples, then re-attach the same x axis so the
# derivative carries the same coordinates as the field it came from.
dT_dx_data = centered_diff_periodic(temperature.data, dx)
dT_dx = cx.field(dT_dx_data, x_axis)
# shape: (64,) | dims: ('x',) | units: K/m  — ∂T/∂x

# Pointwise discrepancy between the FD estimate and the closed-form derivative.
error = jnp.abs(dT_dx.data - dT_dx_analytical)
print(f"Max absolute error: {jnp.max(error):.2e}")
print(f"Mean absolute error: {jnp.mean(error):.2e}")
Max absolute error: 1.89e-01
Mean absolute error: 1.09e-01

Reusable helper using cx.field(...)

def fd_derivative_periodic(field, dim='x'):
    """Differentiate a field along ``dim`` with a periodic centered stencil.

    The spacing is read from the first two ticks of the ``dim`` axis, so the
    grid is taken to be uniform along that dimension; neighbours wrap around
    at both ends.

    Args:
        field: Field exposing ``axes``, ``dims`` and ``data``.
        dim: Name of the dimension to differentiate along.

    Returns:
        A new field built from the derivative values and the same axes, in
        the same order, as ``field``.
    """
    ticks = field.axes[dim].ticks
    spacing = ticks[1] - ticks[0]
    axis = field.dims.index(dim)
    data = field.data

    # (f[i+1] − f[i−1]) / (2·Δx), with wrap-around supplied by roll.
    ahead = jnp.roll(data, -1, axis=axis)
    behind = jnp.roll(data, 1, axis=axis)
    derivative = (ahead - behind) / (2.0 * spacing)

    out_axes = [field.axes[name] for name in field.dims]
    return cx.field(derivative, *out_axes)


# Sanity check: the field-level helper should reproduce the array-level
# result stored in dT_dx.
dT_dx_fn = fd_derivative_periodic(temperature, dim='x')
print(f"Helper vs manual diff: {jnp.max(jnp.abs(dT_dx_fn.data - dT_dx.data)):.2e}")
Helper vs manual diff: 0.00e+00

5.2 Finite Differences with Non-Uniform Spacing

Non-uniform vertical grid (denser near the surface).

# Stretched vertical grid: squaring a uniform [0, 1] parameter makes the level
# spacing grow with height, so levels cluster near z = 0.
n_levels = 30
z_max = 20_000.0  # m
z_values = z_max * jnp.linspace(0, 1, n_levels) ** 2

# Exponential pressure profile with surface value P0 and e-folding height H,
# together with its exact derivative dP/dz = −(P0/H)·exp(−z/H) as reference.
P0, H = 101325.0, 7000.0  # Pa, m
pressure_data = P0 * jnp.exp(-z_values / H)
dP_dz_analytical = -(P0 / H) * jnp.exp(-z_values / H)

# Attach the physical heights to the pressure samples.
z_axis = cx.LabeledAxis('z', z_values)
pressure = cx.field(pressure_data, z_axis)
# shape: (30,) | dims: ('z',) | units: Pa  — P(z) = P₀·exp(−z/H)

print(f"Pressure profile: shape={pressure.shape}, z range: {float(z_values.min()):.1f} to {float(z_values.max()):.1f} m")
Pressure profile: shape=(30,), z range: 0.0 to 20000.0 m

Centered FD for non-uniform grids

def centered_diff_nonuniform(values, coords):
    """First derivative on a 1-D, possibly non-uniform grid.

    Interior points use the two-neighbour slope
    ``(f[i+1] - f[i-1]) / (x[i+1] - x[i-1])``; the first and last points fall
    back to one-sided two-point slopes. Where the spacing on either side of a
    point differs, the interior slope's leading error term is proportional to
    that spacing difference.

    Args:
        values: Samples of the function at ``coords``.
        coords: Coordinate of each sample, same length as ``values``.

    Returns:
        Array of derivative estimates, one per input sample.
    """
    # One-sided slopes at the two ends of the grid.
    slope_first = (values[1] - values[0]) / (coords[1] - coords[0])
    slope_last = (values[-1] - values[-2]) / (coords[-1] - coords[-2])

    # Interior: rise and run both taken across the two neighbours of each point.
    rise = values[2:] - values[:-2]
    run = coords[2:] - coords[:-2]
    slope_inner = rise / run

    pieces = [jnp.array([slope_first]), slope_inner, jnp.array([slope_last])]
    return jnp.concatenate(pieces)


# The ticks stored on the z axis are the physical heights used as FD coordinates.
z_coords = z_axis.ticks
# ∂P/∂z[i] ≈ (P[i+1] − P[i−1]) / (z[i+1] − z[i−1])   (non-uniform spacing)
dP_dz_data = centered_diff_nonuniform(pressure.data, z_coords)
dP_dz = cx.field(dP_dz_data, z_axis)
# shape: (30,) | dims: ('z',) | units: Pa/m  — ∂P/∂z

# Absolute error against the closed-form derivative; the relative error divides
# by |exact|, which is safe here because −(P0/H)·exp(−z/H) is never zero.
error_z = jnp.abs(dP_dz.data - dP_dz_analytical)
print(f"Max absolute error: {jnp.max(error_z):.3f} Pa/m")
print(f"Max relative error: {jnp.max(error_z / jnp.abs(dP_dz_analytical)):.2%}")
Max absolute error: 0.086 Pa/m
Max relative error: 10.34%

5.3 2D Gradients (∇ operator) in Cartesian Coordinates

# 2-D grid. x drops its endpoint (treated as the periodic direction below);
# y keeps both end points (treated as non-periodic below).
n_y, n_x = 32, 64
Lx, Ly = 10.0, 5.0
x_vals_2d = jnp.linspace(0.0, Lx, n_x, endpoint=False)
y_vals_2d = jnp.linspace(0.0, Ly, n_y)

# indexing='xy' yields meshes of shape (n_y, n_x), matching the ('y', 'x')
# axis order used when the field is built.
x_mesh, y_mesh = jnp.meshgrid(x_vals_2d, y_vals_2d, indexing='xy')
T0, DeltaT = 15.0, 20.0

# Separable test field: one cosine period in y times three sine periods in x,
# offset by T0.
temp_data_2d = (T0
                + DeltaT * jnp.cos(2.0 * jnp.pi * y_mesh / Ly)
                * jnp.sin(3.0 * 2.0 * jnp.pi * x_mesh / Lx))

# Closed-form partial derivatives of temp_data_2d, kept as reference solutions.
dT_dx_an = (DeltaT * jnp.cos(2.0 * jnp.pi * y_mesh / Ly)
            * (3.0 * 2.0 * jnp.pi / Lx)
            * jnp.cos(3.0 * 2.0 * jnp.pi * x_mesh / Lx))
dT_dy_an = (DeltaT * (-(2.0 * jnp.pi / Ly))
            * jnp.sin(2.0 * jnp.pi * y_mesh / Ly)
            * jnp.sin(3.0 * 2.0 * jnp.pi * x_mesh / Lx))

# Axes are passed in the same order as the data's array axes: y first, then x.
y_axis_2d = cx.LabeledAxis('y', y_vals_2d)
x_axis_2d = cx.LabeledAxis('x', x_vals_2d)
temperature_2d = cx.field(temp_data_2d, y_axis_2d, x_axis_2d)
# shape: (32, 64) | dims: ('y','x') | units: K  — T = T₀ + ΔT·cos(2πy/Ly)·sin(6πx/Lx)

print(f"2D Temperature field: shape={temperature_2d.shape}, dims={temperature_2d.dims}")
2D Temperature field: shape=(32, 64), dims=('y', 'x')

Gradient in y (non-periodic) and x (periodic)

# y-gradient using jnp.gradient with coordinate spacing
# untag('y') drops the 'y' label before the call; presumably cx.cmap then maps
# the lambda over the remaining named dim ('x'), so each call sees a 1-D
# y-profile and axis=0 is y — confirm against the coordax docs.
# NOTE(review): despite its name, dT_dy_data holds the object returned by
# cx.cmap (it is re-tagged on the next line), not a raw array like dT_dx_data.
dT_dy_data = cx.cmap(
    lambda t: jnp.gradient(t, y_vals_2d, axis=0)
)(temperature_2d.untag('y'))
dT_dy = dT_dy_data.tag(y_axis_2d)
# shape: (32, 64) | dims: ('y','x') | units: K/m  — ∂T/∂y via jnp.gradient

# x-gradient using periodic centered FD
# Spacing comes from the first two x ticks (uniform grid); roll supplies the
# wrap-around neighbours along the x array axis.
dx_2d = float(x_vals_2d[1] - x_vals_2d[0])
axis_pos_x = temperature_2d.dims.index('x')
dT_dx_data = (
    jnp.roll(temperature_2d.data, -1, axis=axis_pos_x)
    - jnp.roll(temperature_2d.data,  1, axis=axis_pos_x)
) / (2.0 * dx_2d)
dT_dx_2d = cx.field(dT_dx_data, y_axis_2d, x_axis_2d)
# shape: (32, 64) | dims: ('y','x') | units: K/m  — ∂T/∂x (periodic centered FD)

# Worst-case pointwise error of each component against its closed form.
print(f"y gradient accuracy — max error: {jnp.max(jnp.abs(dT_dy.data - dT_dy_an)):.3e}")
print(f"x gradient accuracy — max error: {jnp.max(jnp.abs(dT_dx_2d.data - dT_dx_an)):.3e}")

# Gradient magnitude
grad_mag_data = jnp.sqrt(dT_dx_2d.data**2 + dT_dy.data**2)
grad_magnitude = cx.field(grad_mag_data, y_axis_2d, x_axis_2d)
# shape: (32, 64) | dims: ('y','x') | units: K/m  — |∇T| = √((∂T/∂x)² + (∂T/∂y)²)
# NOTE(review): the printed unit label is "[T]/m" while the comments here use K/m.
print(f"Gradient magnitude range: {float(grad_magnitude.data.min()):.2f} to {float(grad_magnitude.data.max()):.2f} [T]/m")
y gradient accuracy — max error: 2.538e+00
x gradient accuracy — max error: 5.427e-01
Gradient magnitude range: 1.88 to 37.16 [T]/m

Advanced: Second Derivative (Laplacian)

def fd_second_derivative_periodic(field, dim='x'):
    """Second derivative along ``dim`` with a periodic three-point stencil.

    Evaluates ``(f[i+1] - 2 f[i] + f[i-1]) / dx²``. The spacing ``dx`` is read
    from the first two ticks of the ``dim`` axis, so the grid is taken to be
    uniform along that dimension; neighbours wrap around at both ends.

    Args:
        field: Field exposing ``axes``, ``dims`` and ``data``.
        dim: Name of the dimension to differentiate along.

    Returns:
        A new field built from the second-derivative values and the same
        axes, in the same order, as ``field``.
    """
    ticks = field.axes[dim].ticks
    spacing = ticks[1] - ticks[0]
    axis = field.dims.index(dim)
    data = field.data

    # Wrap-around neighbours on either side of every sample.
    ahead = jnp.roll(data, -1, axis=axis)
    behind = jnp.roll(data, 1, axis=axis)
    stencil = ahead - 2.0 * data + behind
    second = stencil / (spacing * spacing)

    out_axes = [field.axes[name] for name in field.dims]
    return cx.field(second, *out_axes)


# Apply the second-derivative helper to the 1-D periodic temperature field;
# the result keeps the input's shape and dims, as the print below shows.
d2T_dx2 = fd_second_derivative_periodic(temperature, dim='x')
# shape: (64,) | dims: ('x',) | units: K/m²  — ∂²T/∂x² = (T[i+1] − 2T[i] + T[i−1])/Δx²  (periodic, uniform Δx)
print(f"Second derivative shape: {d2T_dx2.shape}, dims: {d2T_dx2.dims}")
Second derivative shape: (64,), dims: ('x',)

Key Patterns Summary

  • Periodic uniform grids: (roll(f,-1) - roll(f,+1)) / (2*dx) with dx from coordinates
  • Non-uniform grids: (f[i+1] - f[i-1]) / (x[i+1] - x[i-1]) or jnp.gradient(f, coords, axis=...)
  • Multi-dimensional: compute derivatives in each dimension separately; combine for gradient magnitude
  • cmap pattern: cx.cmap(deriv_fn)(field.untag('dim')).tag(coord_axis)