Skip to article frontmatterSkip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

Reductions

Part 3: Aggregations by Name (Functional Ops on Dims)

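This notebook demonstrates dimension-aware reductions and numerical integration using the `untag` + `cmap` pattern: untag the dimension you want to reduce so it becomes a positional axis, then apply an ordinary array function to that axis with `cmap`.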

import coordax as cx
import jax.numpy as jnp
import numpy as np

# NOTE: cx.field() is preferred over cx.wrap() in coordax >= 0.2

3.1 Calculating the Time-Mean Map

# Synthetic monthly data: n_years x n_months time steps of (lat, lon) maps.
n_years, n_months, n_lat, n_lon = 3, 12, 16, 32
n_time = n_years * n_months

# Random noise around 15 °C plus a 20·cos(latitude) bump (warmest at the equator).
# NOTE(review): no RNG seed is set, so the values in `temps` (and any numbers
# derived from them) differ between runs.
temps = np.random.randn(n_time, n_lat, n_lon) * 10 + 15
lat_values = np.linspace(-90, 90, n_lat)
lat_effect = 20 * np.cos(np.deg2rad(lat_values))
# Broadcast the (lat,) profile over the time and longitude axes.
temps += lat_effect[np.newaxis, :, np.newaxis]

# One labeled coordinate axis per array dimension, in array order.
time_axis = cx.LabeledAxis('time', np.arange(n_time, dtype=float))
lat_axis = cx.LabeledAxis('latitude', lat_values)
lon_axis = cx.LabeledAxis('longitude', np.linspace(-180, 180, n_lon))

temperature = cx.field(temps, time_axis, lat_axis, lon_axis)
# shape: (36, 16, 32) | dims: ('time','latitude','longitude') | units: °C
print(f"Temperature field: {temperature.dims}, shape: {temperature.shape}")
Temperature field: ('time', 'latitude', 'longitude'), shape: (36, 16, 32)

Computing the time mean

# Step 1: Untag 'time' to expose it as the leading positional axis
temp_untagged = temperature.untag('time')
# shape: (36, 16, 32) — 'time' exposed as leading positional axis

# Step 2: Reduce over axis 0 (the untagged time axis). The still-named
# 'latitude'/'longitude' dims pass through cmap unchanged, as the output
# dims below show.
time_mean = cx.cmap(lambda x: jnp.mean(x, axis=0))(temp_untagged)
# shape: (16, 32) | dims: ('latitude','longitude') | units: °C  — mean over 36 steps
print(f"Time-mean map: {time_mean.dims}, shape: {time_mean.shape}")
Time-mean map: ('latitude', 'longitude'), shape: (16, 32)

Seasonal climatology

# Reshape to (years, months, lat, lon) then average over years.
# Assumes the time axis is ordered year-major (index = year * n_months + month),
# so a plain C-order reshape splits it into (year, month).
temps_reshaped = temps.reshape(n_years, n_months, n_lat, n_lon)

year_axis = cx.LabeledAxis('year', np.arange(n_years, dtype=float))
month_axis = cx.LabeledAxis('month', np.arange(n_months, dtype=float))

# Rebuild the field from the raw array, with 'year'/'month' replacing 'time'.
temp_reshaped = cx.field(temps_reshaped, year_axis, month_axis, lat_axis, lon_axis)
# shape: (3, 12, 16, 32) | dims: ('year','month','latitude','longitude') | units: °C

# Untag 'year' and average over it, leaving one mean map per calendar month.
monthly_climatology = cx.cmap(lambda x: jnp.mean(x, axis=0))(
    temp_reshaped.untag('year')
)
# shape: (12, 16, 32) | dims: ('month','latitude','longitude') | units: °C — 3-yr average
print(f"Monthly climatology: {monthly_climatology.dims}, shape: {monthly_climatology.shape}")
Monthly climatology: ('month', 'latitude', 'longitude'), shape: (12, 16, 32)

3.2 Area-Weighted Global Mean Temperature

# A fresh single-snapshot (lat, lon) field for the area-weighting example.
n_lat2, n_lon2 = 16, 32
lat_values2 = np.linspace(-90, 90, n_lat2)
temps2 = np.random.randn(n_lat2, n_lon2) * 5 + 15
# Same 20·cos(latitude) profile as in 3.1, broadcast over longitude.
temps2 += (20 * np.cos(np.deg2rad(lat_values2)))[:, np.newaxis]

lat_axis2 = cx.LabeledAxis('latitude', lat_values2)
lon_axis2 = cx.LabeledAxis('longitude', np.linspace(-180, 180, n_lon2))
temperature2 = cx.field(temps2, lat_axis2, lon_axis2)
# shape: (16, 32) | dims: ('latitude','longitude') | units: °C

Computing cosine-latitude weights

# w(φ) = cos(φ)  — latitude weights proportional to grid-cell area
# (on a regular lat/lon grid, cells shrink toward the poles).
lat_weights_vals = np.cos(np.deg2rad(lat_values2))
lat_weights = cx.field(lat_weights_vals, lat_axis2)
# shape: (16,) | dims: ('latitude',) | units: dimensionless  [cos φ, φ in radians]
print(f"Latitude weights: {lat_weights.dims}, shape: {lat_weights.shape}")

# Weighted temperature — broadcasts (lat,) over 'longitude' by dim name
weighted_temp = temperature2 * lat_weights
# shape: (16, 32) | dims: ('latitude','longitude') | units: °C  — (lat,) broadcasts over lon

# Sum over all spatial dims
weighted_sum = cx.cmap(jnp.sum)(weighted_temp.untag('latitude', 'longitude'))
# Σ(w) over every (lat, lon) cell: w depends only on latitude, so the
# per-latitude sum is counted once per longitude.
total_weight = jnp.sum(lat_weights_vals) * n_lon2
# T̄ = Σ(w·T) / Σ(w)   where w = cos(φ)
global_mean = weighted_sum.data / total_weight
print(f"Global mean temperature: {global_mean:.2f} °C")
Latitude weights: ('latitude',), shape: (16,)
Global mean temperature: 30.85 °C

Alternative: zonal mean first, then a latitude-weighted average

# Average over longitude first: untag 'longitude' and reduce it (axis=-1;
# presumably the only positional axis inside cmap).
zonal_mean = cx.cmap(lambda x: jnp.mean(x, axis=-1))(temperature2.untag('longitude'))
# shape: (16,) | dims: ('latitude',) | units: °C  — mean over 32 lons
print(f"Zonal mean: {zonal_mean.dims}, shape: {zonal_mean.shape}")

weighted_zonal = zonal_mean * lat_weights
# shape: (16,) | dims: ('latitude',) | units: °C  — ready for Σ(w·T̄_lon) / Σ(w)
numerator = cx.cmap(jnp.sum)(weighted_zonal.untag('latitude'))
denominator = cx.cmap(jnp.sum)(lat_weights.untag('latitude'))
global_mean_v2 = numerator.data / denominator.data
print(f"Global mean (v2): {global_mean_v2:.2f} °C")
# Since w depends only on latitude, Σ_lat w·mean_lon(T) / Σ_lat w equals
# Σ w·T / (n_lon·Σ_lat w), so both methods should agree up to rounding.
print(f"Methods agree (tol 1e-3): {float(np.abs(global_mean - global_mean_v2)) < 1e-3}")
Zonal mean: ('latitude',), shape: (16,)
Global mean (v2): 30.85 °C
Methods agree (tol 1e-3): True

3.3 Integrating Ocean Heat Content (Trapezoid Rule)

n_depth = 20
n_lat3, n_lon3 = 8, 16

# Non-uniform depth grid (presumably metres): 10 levels in 0–100 and 10 in
# 125–500, so the integration below must use the actual level spacing.
depth_values = np.concatenate([
    np.linspace(0, 100, 10),
    np.linspace(125, 500, 10),
])

# Temperature profile: exponential decay with depth,
# T(z) = deep + (surface - deep) * exp(-z / decay_scale), shaped (depth, 1, 1).
surface_temp, deep_temp, decay_scale = 25.0, 5.0, 200.0
temp_profile = (deep_temp + (surface_temp - deep_temp) *
                np.exp(-depth_values / decay_scale))[:, np.newaxis, np.newaxis]

lat_vals3 = np.linspace(-60, 60, n_lat3)
# Meridional variation shaped (1, lat, 1), plus Gaussian noise (std 0.5) per cell.
lat_variation = 5 * np.cos(np.deg2rad(lat_vals3))[np.newaxis, :, np.newaxis]
ocean_temps = temp_profile + lat_variation + np.random.randn(n_depth, n_lat3, n_lon3) * 0.5

depth_axis = cx.LabeledAxis('depth', depth_values)
lat_axis3 = cx.LabeledAxis('latitude', lat_vals3)
lon_axis3 = cx.LabeledAxis('longitude', np.linspace(-180, 180, n_lon3))
ocean_temperature = cx.field(ocean_temps, depth_axis, lat_axis3, lon_axis3)
# shape: (20, 8, 16) | dims: ('depth','latitude','longitude') | units: °C
print(f"Ocean temperature: {ocean_temperature.dims}, shape: {ocean_temperature.shape}")
Ocean temperature: ('depth', 'latitude', 'longitude'), shape: (20, 8, 16)

Trapezoid-rule integration over depth

def trapezoid_integrate(values, x):
    """Integrate ``values`` over its first axis using the trapezoid rule.

    Typically called on 1-D slices inside ``cx.cmap``, but N-D arrays are also
    handled correctly: integration always runs over axis 0.

    Args:
      values: Array whose leading axis is sampled at the coordinates ``x``.
      x: 1-D array of sample coordinates (spacing may be non-uniform), with
        ``len(x) == values.shape[0]``.

    Returns:
      ``values`` with its leading axis integrated out (a scalar for 1-D input).

    Raises:
      ValueError: if ``x`` is not 1-D or its length differs from the leading
        axis of ``values``.
    """
    values = jnp.asarray(values)
    x = jnp.asarray(x)
    if x.ndim != 1 or values.ndim == 0 or values.shape[0] != x.shape[0]:
        raise ValueError(
            "x must be 1-D with length matching the first axis of values; "
            f"got x.shape={x.shape}, values.shape={values.shape}"
        )
    # Shape the spacings as (n-1, 1, ..., 1) so they broadcast along axis 0.
    # A bare (n-1,) vector would broadcast against the *last* axis instead,
    # giving a shape error or silently wrong results for N-D input.
    dx = jnp.diff(x)
    dx = dx.reshape(dx.shape + (1,) * (values.ndim - 1))
    avg_values = (values[:-1] + values[1:]) / 2.0
    return jnp.sum(avg_values * dx, axis=0)


ocean_temp_untagged = ocean_temperature.untag('depth')
# shape: (20, 8, 16) — 'depth' exposed as leading positional axis
# Integrate each (lat, lon) column over depth; depth_values supplies the
# non-uniform sample positions, so the spacing comes from the coordinates.
temp_integrated = cx.cmap(
    lambda t: trapezoid_integrate(t, depth_values)
)(ocean_temp_untagged)
# shape: (8, 16) | dims: ('latitude','longitude') | units: °C·m  — ∫₀^H T dz over the full column (H = 500)
print(f"Integrated temperature: {temp_integrated.dims}, shape: {temp_integrated.shape}")

# Ocean heat content per unit area: OHC = ρ·cₚ · ∫₀^H T(z) dz   [J/m²]
# T is in °C, so this is heat content relative to a 0 °C reference.
rho, c_p = 1025.0, 3850.0  # seawater density [kg/m³] and specific heat [J/(kg·K)]
ocean_heat_content = rho * c_p * temp_integrated / 1e9  # GJ/m²
# shape: (8, 16) | dims: ('latitude','longitude') | units: GJ/m²
print(f"Ocean heat content (GJ/m²) — mean: {float(jnp.mean(ocean_heat_content.data)):.2f}")
Integrated temperature: ('latitude', 'longitude'), shape: (8, 16)
Ocean heat content (GJ/m²) — mean: 32.03

Cumulative integration

def cumulative_trapezoid(values, x):
    """Cumulatively integrate ``values`` over its first axis (trapezoid rule).

    Entry ``k`` of the result is the trapezoid-rule integral from ``x[0]`` to
    ``x[k]``. The output therefore has the same shape as ``values`` and its
    first slice is zero. Typically called on 1-D slices inside ``cx.cmap``,
    but N-D arrays are also handled: integration always runs over axis 0.

    Args:
      values: Array whose leading axis is sampled at the coordinates ``x``.
      x: 1-D array of sample coordinates (spacing may be non-uniform), with
        ``len(x) == values.shape[0]``.

    Returns:
      Array with the same shape as ``values`` holding the running integral.

    Raises:
      ValueError: if ``x`` is not 1-D or its length differs from the leading
        axis of ``values``.
    """
    values = jnp.asarray(values)
    x = jnp.asarray(x)
    if x.ndim != 1 or values.ndim == 0 or values.shape[0] != x.shape[0]:
        raise ValueError(
            "x must be 1-D with length matching the first axis of values; "
            f"got x.shape={x.shape}, values.shape={values.shape}"
        )
    # Shape the spacings as (n-1, 1, ..., 1) so they broadcast along axis 0,
    # not against the trailing axis.
    dx = jnp.diff(x)
    dx = dx.reshape(dx.shape + (1,) * (values.ndim - 1))
    increments = (values[:-1] + values[1:]) / 2.0 * dx
    # Leading zero slice (the integral from x[0] to itself). It has the
    # trailing shape of ``values``; a bare jnp.zeros(1) only concatenates
    # with 1-D results.
    start = jnp.zeros((1,) + values.shape[1:], dtype=increments.dtype)
    return jnp.concatenate([start, jnp.cumsum(increments, axis=0)], axis=0)


# Running integral from the surface: entry k holds the trapezoid integral of
# T from depth_values[0] down to depth_values[k] for each (lat, lon) column.
cumulative_temp = cx.cmap(
    lambda t: cumulative_trapezoid(t, depth_values)
)(ocean_temp_untagged)
# The output keeps a leading positional axis of length n_depth; re-attach the
# depth coordinate so it becomes a named dim again.
cumulative_temp = cumulative_temp.tag(depth_axis)
# shape: (20, 8, 16) | dims: ('depth','latitude','longitude') | units: °C·m (cumulative)
print(f"Cumulative temperature: {cumulative_temp.dims}, shape: {cumulative_temp.shape}")
Cumulative temperature: ('depth', 'latitude', 'longitude'), shape: (20, 8, 16)

Key Patterns Summary

1. Simple reduction

# Untag the dim to reduce, then reduce positional axis 0 inside cmap.
result = cx.cmap(lambda x: jnp.mean(x, axis=0))(field.untag('time'))

2. Weighted aggregation

# Weights broadcast by dim name; total_weight must be Σ(w) over every summed cell.
weighted = field * weights          # broadcasting
result = cx.cmap(jnp.sum)(weighted.untag('lat', 'lon')) / total_weight

3. Numerical integration

# z_vals: coordinate values of dim 'z', in the same order as the data.
integrated = cx.cmap(lambda f: trapezoid_integrate(f, z_vals))(field.untag('z'))

4. Cumulative operations

# Re-tag the positional output axis with the original coordinate axis.
cumul = cx.cmap(lambda f: cumulative_trapezoid(f, x))(field.untag('x')).tag(x_axis)