Reductions
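Part 3: Aggregations by Name (Functional Ops on Dims)
This notebook demonstrates dimension-aware reductions and numerical integration using the untag + cmap pattern: `untag` a named dimension to expose it as a positional axis, then `cx.cmap` a plain JAX function that operates on that axis while the remaining named dimensions are mapped over.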
import coordax as cx
import jax.numpy as jnp
import numpy as np
# NOTE: cx.field() is preferred over cx.wrap() in coordax >= 0.23.1
3.1 Calculating the Time-Mean Map
# Synthetic monthly temperatures: 3 years x 12 months on a 16 x 32 lat/lon grid.
n_years, n_months, n_lat, n_lon = 3, 12, 16, 32
n_time = n_years * n_months
# Gaussian noise with mean 15 °C and std 10 °C ...
temps = np.random.randn(n_time, n_lat, n_lon) * 10 + 15
lat_values = np.linspace(-90, 90, n_lat)
# ... plus a 20·cos(lat) equator-to-pole gradient, identical at every time/lon.
lat_effect = 20 * np.cos(np.deg2rad(lat_values))
temps += lat_effect[np.newaxis, :, np.newaxis]
time_axis = cx.LabeledAxis('time', np.arange(n_time, dtype=float))
lat_axis = cx.LabeledAxis('latitude', lat_values)
lon_axis = cx.LabeledAxis('longitude', np.linspace(-180, 180, n_lon))
temperature = cx.field(temps, time_axis, lat_axis, lon_axis)
# shape: (36, 16, 32) | dims: ('time','latitude','longitude') | units: °C
print(f"Temperature field: {temperature.dims}, shape: {temperature.shape}")
# Output:
#   Temperature field: ('time', 'latitude', 'longitude'), shape: (36, 16, 32)
Computing the time mean
# Step 1: Untag 'time' to expose it as the leading positional axis
temp_untagged = temperature.untag('time')
# shape: (36, 16, 32) — 'time' exposed as leading positional axis
# Step 2: Reduce over axis 0 (the untagged time axis). cmap maps the lambda
# over the still-named dims ('latitude', 'longitude'), so those survive in the
# result while the positional time axis is consumed by the mean.
time_mean = cx.cmap(lambda x: jnp.mean(x, axis=0))(temp_untagged)
# shape: (16, 32) | dims: ('latitude','longitude') | units: °C — mean over 36 steps
print(f"Time-mean map: {time_mean.dims}, shape: {time_mean.shape}")
# Output:
#   Time-mean map: ('latitude', 'longitude'), shape: (16, 32)
Seasonal climatology
# Reshape to (years, months, lat, lon) then average over years.
# The C-order reshape maps time index t to (year, month) = (t // n_months,
# t % n_months), i.e. the time axis is treated as year-major.
temps_reshaped = temps.reshape(n_years, n_months, n_lat, n_lon)
year_axis = cx.LabeledAxis('year', np.arange(n_years, dtype=float))
month_axis = cx.LabeledAxis('month', np.arange(n_months, dtype=float))
temp_reshaped = cx.field(temps_reshaped, year_axis, month_axis, lat_axis, lon_axis)
# shape: (3, 12, 16, 32) | dims: ('year','month','latitude','longitude') | units: °C
# Untag 'year' and average it away; 'month', 'latitude', 'longitude' remain named.
monthly_climatology = cx.cmap(lambda x: jnp.mean(x, axis=0))(
    temp_reshaped.untag('year')
)
# shape: (12, 16, 32) | dims: ('month','latitude','longitude') | units: °C — 3-yr average
print(f"Monthly climatology: {monthly_climatology.dims}, shape: {monthly_climatology.shape}")
# Output:
#   Monthly climatology: ('month', 'latitude', 'longitude'), shape: (12, 16, 32)
3.2 Area-Weighted Global Mean Temperature
# Synthetic single-snapshot field on a 16 x 32 lat/lon grid.
n_lat2, n_lon2 = 16, 32
lat_values2 = np.linspace(-90, 90, n_lat2)
lat_axis2 = cx.LabeledAxis('latitude', lat_values2)
lon_axis2 = cx.LabeledAxis('longitude', np.linspace(-180, 180, n_lon2))
# Gaussian noise around 15 °C (std 5 °C), plus a 20·cos(lat) gradient added
# to every longitude column (the column vector broadcasts across longitude).
temps2 = 15 + 5 * np.random.randn(n_lat2, n_lon2)
temps2 = temps2 + (20 * np.cos(np.deg2rad(lat_values2))).reshape(-1, 1)
temperature2 = cx.field(temps2, lat_axis2, lon_axis2)
# shape: (16, 32) | dims: ('latitude','longitude') | units: °C
Computing cosine-latitude weights
# On this equally spaced latitude grid, grid-cell area is proportional to
# cos(φ). The ±90° rows get weight ≈ 0, so the poles contribute almost nothing.
lat_weights_vals = np.cos(np.deg2rad(lat_values2))
lat_weights = cx.field(lat_weights_vals, lat_axis2)
# shape: (16,) | dims: ('latitude',) | units: dimensionless [cos φ, φ in radians]
print(f"Latitude weights: {lat_weights.dims}, shape: {lat_weights.shape}")
# Output:
#   Latitude weights: ('latitude',), shape: (16,)
# w(φ) = cos(φ) — latitude weights proportional to grid-cell area
# Weighted temperature — broadcasts (lat,) over 'longitude'
weighted_temp = temperature2 * lat_weights
# shape: (16, 32) | dims: ('latitude','longitude') | units: °C — (lat,) broadcasts over lon
# Sum over all spatial dims (both untagged, so cmap reduces the whole grid)
weighted_sum = cx.cmap(jnp.sum)(weighted_temp.untag('latitude', 'longitude'))
# Each latitude weight was applied to all n_lon2 longitude cells, so the sum
# of weights over the full grid is Σ_lat w · n_lon2.
total_weight = jnp.sum(lat_weights_vals) * n_lon2
# T̄ = Σ(w·T) / Σ(w) where w = cos(φ)
global_mean = weighted_sum.data / total_weight
print(f"Global mean temperature: {global_mean:.2f} °C")
# Output (random data — the value varies from run to run):
#   Global mean temperature: 30.85 °C
Alternative: zonal mean, then latitude-weighted average
# Average over longitude first (zonal mean), then take the cos-latitude-weighted
# mean of the zonal means. The weights do not vary with longitude, so this is
# algebraically identical to the direct weighted mean computed above.
zonal_mean = cx.cmap(lambda x: jnp.mean(x, axis=-1))(temperature2.untag('longitude'))
# shape: (16,) | dims: ('latitude',) | units: °C — mean over 32 lons
print(f"Zonal mean: {zonal_mean.dims}, shape: {zonal_mean.shape}")
weighted_zonal = zonal_mean * lat_weights
# shape: (16,) | dims: ('latitude',) | units: °C — ready for Σ(w·T̄_lon) / Σ(w)
numerator = cx.cmap(jnp.sum)(weighted_zonal.untag('latitude'))
denominator = cx.cmap(jnp.sum)(lat_weights.untag('latitude'))
global_mean_v2 = numerator.data / denominator.data
print(f"Global mean (v2): {global_mean_v2:.2f} °C")
print(f"Methods agree (tol 1e-3): {float(np.abs(global_mean - global_mean_v2)) < 1e-3}")
# Output (random data — the mean varies from run to run):
#   Zonal mean: ('latitude',), shape: (16,)
#   Global mean (v2): 30.85 °C
#   Methods agree (tol 1e-3): True
3.3 Integrating Ocean Heat Content (Trapezoid Rule)
n_depth = 20
n_lat3, n_lon3 = 8, 16
# Non-uniform depth grid: 10 levels over 0–100 m (~11 m spacing), a 25 m step,
# then 10 levels over 125–500 m (~42 m spacing). The trapezoid helpers below
# therefore must use the actual spacings rather than assume a constant dz.
depth_values = np.concatenate([
    np.linspace(0, 100, 10),
    np.linspace(125, 500, 10),
])
# Temperature profile: exponential decay with depth, from 25 °C at the surface
# toward 5 °C at depth with an e-folding scale of 200 m.
surface_temp, deep_temp, decay_scale = 25.0, 5.0, 200.0
temp_profile = (deep_temp + (surface_temp - deep_temp) *
                np.exp(-depth_values / decay_scale))[:, np.newaxis, np.newaxis]
lat_vals3 = np.linspace(-60, 60, n_lat3)
# Profile (depth,1,1) + latitude term (1,lat,1) + noise broadcast to (depth,lat,lon).
lat_variation = 5 * np.cos(np.deg2rad(lat_vals3))[np.newaxis, :, np.newaxis]
ocean_temps = temp_profile + lat_variation + np.random.randn(n_depth, n_lat3, n_lon3) * 0.5
depth_axis = cx.LabeledAxis('depth', depth_values)
lat_axis3 = cx.LabeledAxis('latitude', lat_vals3)
lon_axis3 = cx.LabeledAxis('longitude', np.linspace(-180, 180, n_lon3))
ocean_temperature = cx.field(ocean_temps, depth_axis, lat_axis3, lon_axis3)
# shape: (20, 8, 16) | dims: ('depth','latitude','longitude') | units: °C
print(f"Ocean temperature: {ocean_temperature.dims}, shape: {ocean_temperature.shape}")
# Output:
#   Ocean temperature: ('depth', 'latitude', 'longitude'), shape: (20, 8, 16)
Trapezoid-rule integration over depth
def trapezoid_integrate(values, x):
    """Integrate ``values`` over their first axis using the trapezoid rule.

    Under ``cx.cmap`` this receives 1-D slices along the untagged axis, but it
    also works directly on N-D arrays whose leading axis matches ``x``.

    Args:
      values: Array whose leading axis has length ``len(x)``.
      x: 1-D sample coordinates along that axis; spacing may be non-uniform.

    Returns:
      The integral over axis 0, with shape ``values.shape[1:]`` (a scalar for
      1-D input).
    """
    values = jnp.asarray(values)
    # Shape the spacings as (n-1, 1, ..., 1) so they broadcast along axis 0.
    # A bare (n-1,) vector would broadcast against the *last* axis of N-D input.
    dx = jnp.diff(jnp.asarray(x)).reshape((-1,) + (1,) * (values.ndim - 1))
    # Mean of each adjacent pair of samples = height of each trapezoid.
    avg_values = (values[:-1] + values[1:]) / 2.0
    return jnp.sum(avg_values * dx, axis=0)
ocean_temp_untagged = ocean_temperature.untag('depth')
# shape: (20, 8, 16) — 'depth' exposed as leading positional axis
# cmap maps over the named ('latitude', 'longitude') dims, so each call of the
# lambda integrates one depth profile.
temp_integrated = cx.cmap(
    lambda t: trapezoid_integrate(t, depth_values)
)(ocean_temp_untagged)
# shape: (8, 16) | dims: ('latitude','longitude') | units: °C·m — ∫₀^H T dz, H = depth_values[-1] = 500 m
print(f"Integrated temperature: {temp_integrated.dims}, shape: {temp_integrated.shape}")
# Output:
#   Integrated temperature: ('latitude', 'longitude'), shape: (8, 16)
# OHC = ρ·cₚ · ∫₀^H T(z) dz  [J/m²]. T is in °C, so this is heat content
# relative to 0 °C. Below it is divided by 1e9 to report GJ/m².
# ρ: seawater density [kg/m³]; cₚ: specific heat capacity [J/(kg·K)].
rho, c_p = 1025.0, 3850.0
ocean_heat_content = rho * c_p * temp_integrated / 1e9  # GJ/m²
# shape: (8, 16) | dims: ('latitude','longitude') | units: GJ/m²
print(f"Ocean heat content (GJ/m²) — mean: {float(jnp.mean(ocean_heat_content.data)):.2f}")
# Output (random data — the value varies from run to run):
#   Ocean heat content (GJ/m²) — mean: 32.03
Cumulative integration
def cumulative_trapezoid(values, x):
    """Cumulative trapezoid-rule integral of ``values`` along their first axis.

    Under ``cx.cmap`` this receives 1-D slices along the untagged axis, but it
    also works directly on N-D arrays whose leading axis matches ``x``.

    Args:
      values: Array whose leading axis has length ``len(x)``.
      x: 1-D sample coordinates along that axis; spacing may be non-uniform.

    Returns:
      An array with the same shape as ``values``. Entry ``i`` along axis 0 is
      the integral from ``x[0]`` to ``x[i]``, so entry 0 is zero.
    """
    values = jnp.asarray(values)
    # Shape the spacings as (n-1, 1, ..., 1) so they broadcast along axis 0.
    dx = jnp.diff(jnp.asarray(x)).reshape((-1,) + (1,) * (values.ndim - 1))
    avg_values = (values[:-1] + values[1:]) / 2.0
    increments = avg_values * dx
    # A zero slab shaped like one slice of `values` seeds the running sum. It
    # keeps the output the same length as the input (including a single sample),
    # lets the concatenation work for N-D input, and matches the increments' dtype.
    zero = jnp.zeros((1,) + values.shape[1:], dtype=increments.dtype)
    return jnp.concatenate([zero, jnp.cumsum(increments, axis=0)], axis=0)
# Running integral ∫₀^z T dz' for every (latitude, longitude) column.
cumulative_temp = cx.cmap(
    lambda t: cumulative_trapezoid(t, depth_values)
)(ocean_temp_untagged)
# The result keeps one value per depth level but the axis is positional again,
# so re-attach the 'depth' coordinate.
cumulative_temp = cumulative_temp.tag(depth_axis)
# shape: (20, 8, 16) | dims: ('depth','latitude','longitude') | units: °C·m (cumulative)
print(f"Cumulative temperature: {cumulative_temp.dims}, shape: {cumulative_temp.shape}")
# Output:
#   Cumulative temperature: ('depth', 'latitude', 'longitude'), shape: (20, 8, 16)
Key Patterns Summary
In the snippets below, `field`, `weights`, `total_weight`, `z_vals`, `x`, and `x_axis` are placeholders for your own field, weights, and coordinates.
1. Simple reduction
result = cx.cmap(lambda x: jnp.mean(x, axis=0))(field.untag('time'))
2. Weighted aggregation
weighted = field * weights  # broadcasting
result = cx.cmap(jnp.sum)(weighted.untag('lat', 'lon')) / total_weight
3. Numerical integration
integrated = cx.cmap(lambda f: trapezoid_integrate(f, z_vals))(field.untag('z'))
4. Cumulative operations
cumul = cx.cmap(lambda f: cumulative_trapezoid(f, x))(field.untag('x')).tag(x_axis)