Part 2: Safe, Labeled Arithmetic (Unary & Binary Ops)
This notebook demonstrates how coordax provides safe, dimension-aware arithmetic operations that automatically handle broadcasting and alignment.
import coordax as cx
import jax.numpy as jnp
import numpy as np
# NOTE: cx.field() is preferred over cx.wrap() in coordax >= 0.2
2.1 Unit Conversion (Unary Operations)
Converting temperature from Kelvin to Celsius while preserving coordinate metadata.
n_lat, n_lon = 16, 32
temps_kelvin = np.random.uniform(270, 310, size=(n_lat, n_lon))
# shape: (16, 32) | dtype: float64 | units: K | range: [270, 310]
lat_axis = cx.LabeledAxis('latitude', np.linspace(-90, 90, n_lat))
lon_axis = cx.LabeledAxis('longitude', np.linspace(-180, 180, n_lon))
temperature_K = cx.field(temps_kelvin, lat_axis, lon_axis)
# shape: (16, 32) | dims: ('latitude','longitude') | units: K
print(f"Temperature (K): {temperature_K.dims}, shape: {temperature_K.shape}")
# Unary operation — scalar arithmetic preserves coordinates
# T[°C] = T[K] − 273.15
temperature_C = temperature_K - 273.15
# shape: (16, 32) | dims: ('latitude','longitude') | units: °C
print(f"Temperature (C): {temperature_C.dims}")
print(f"Axis names: {list(temperature_C.axes.keys())}")
# Verify that the coordinates are identical
assert temperature_C.axes['latitude'] == temperature_K.axes['latitude']
assert temperature_C.axes['longitude'] == temperature_K.axes['longitude']
print("Coordinates preserved: True")
Temperature (K): ('latitude', 'longitude'), shape: (16, 32)
Temperature (C): ('latitude', 'longitude')
Axis names: ['latitude', 'longitude']
Coordinates preserved: True
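To see why scalar arithmetic cannot disturb the labels, here is a toy sketch in plain NumPy. The LabeledArray class is hypothetical and is not how coordax implements Field; it only illustrates the idea that dimension names and coordinate ticks travel alongside the data and are reused unchanged when a scalar is applied elementwise.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class LabeledArray:
    """Hypothetical toy: an array plus dimension names and coordinate ticks."""
    data: np.ndarray
    dims: tuple
    coords: dict  # dim name -> 1-D array of tick values

    def __sub__(self, scalar):
        # Elementwise op with a scalar: the data changes, the labels are reused as-is.
        return LabeledArray(self.data - scalar, self.dims, self.coords)


lat = np.linspace(-90, 90, 4)
lon = np.linspace(-180, 180, 8)
t_kelvin = LabeledArray(
    np.full((4, 8), 300.0),
    ("latitude", "longitude"),
    {"latitude": lat, "longitude": lon},
)
t_celsius = t_kelvin - 273.15  # T[°C] = T[K] - 273.15

print(t_celsius.dims)                          # ('latitude', 'longitude')
print(t_celsius.coords is t_kelvin.coords)     # True: same coordinate objects
print(round(float(t_celsius.data[0, 0]), 2))   # 26.85
```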
More complex unary operations
temp_squared = temperature_C ** 2  # shape: (16, 32) | units: °C² (demo only; not a physical quantity)
temp_abs = cx.cmap(jnp.abs)(temperature_C)  # shape: (16, 32) | units: °C
temp_exp = cx.cmap(jnp.exp)(temperature_C / 100)  # shape: (16, 32) | dimensionless
all_same_dims = all(f.dims == temperature_C.dims for f in (temp_squared, temp_abs, temp_exp))
print(f"All operations preserve dims: {all_same_dims}")
All operations preserve dims: True
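cx.cmap lifts a function written for scalars so that it is mapped over every named axis. For an elementwise function such as abs or exp, the result should therefore match applying the function to the whole array at once. The NumPy sketch below checks that equivalence, using np.vectorize as a stand-in for the mapping; it is an analogy on synthetic data, not coordax's implementation.

```python
import math

import numpy as np

rng = np.random.default_rng(0)
temps_c = rng.uniform(-3.0, 37.0, size=(16, 32))  # synthetic °C values

# Apply a scalar-only function element by element (analogous to mapping
# over every named axis), then compare with the vectorized NumPy ufunc.
per_element_exp = np.vectorize(math.exp)(temps_c / 100)
whole_array_exp = np.exp(temps_c / 100)

per_element_abs = np.vectorize(abs)(temps_c)
whole_array_abs = np.abs(temps_c)

print(per_element_exp.shape)                           # (16, 32): shape is unchanged
print(np.allclose(per_element_exp, whole_array_exp))   # True
print(np.allclose(per_element_abs, whole_array_abs))   # True
```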
Custom unary transformation (and when not to use cmap)
def normalize_to_range(x, old_min, old_max, new_min=0.0, new_max=1.0):
    """Normalize values from [old_min, old_max] to [new_min, new_max]."""
    return (x - old_min) / (old_max - old_min) * (new_max - new_min) + new_min

# old_min and old_max are scalars computed from the whole field, so the
# transformation is plain elementwise arithmetic, which a Field supports
# directly. Do not compute the min/max inside a cx.cmap-wrapped function:
# cmap maps over every named axis, so the function would see one scalar,
# x.min() would equal x.max(), and the division would produce NaN.
t_min = float(temperature_C.data.min())
t_max = float(temperature_C.data.max())
temp_normalized = normalize_to_range(temperature_C, t_min, t_max)
# shape: (16, 32) | dims: ('latitude','longitude') | range: [0, 1]
print(f"Normalized range: [{float(temp_normalized.data.min()):.2f}, {float(temp_normalized.data.max()):.2f}]")
print(f"Shape: {temp_normalized.shape}, Dims: {temp_normalized.dims}")
Normalized range: [0.00, 1.00]
Shape: (16, 32), Dims: ('latitude', 'longitude')
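The pitfall of computing the minimum and maximum inside a mapped function is easy to reproduce without coordax. In the NumPy sketch below, normalize_self is a hypothetical helper that takes min and max from its own input. Called on the whole array it works; called element by element (which is what a function sees when it is mapped over every axis), each call receives a single value, so max - min is zero and the result is NaN.

```python
import numpy as np


def normalize_self(x):
    """Hypothetical helper: rescale x to [0, 1] using its own min and max."""
    x = np.asarray(x, dtype=float)
    return (x - x.min()) / (x.max() - x.min())


values = np.array([[270.0, 290.0], [300.0, 310.0]])

# Whole-array call: min and max are taken over all elements.
good = normalize_self(values)

# Element-by-element call: each call sees one value, so max == min
# and the division is 0 / 0.
with np.errstate(invalid="ignore"):  # silence the 0/0 RuntimeWarning
    bad = np.array([[normalize_self(v) for v in row] for row in values])

print(good.tolist())         # [[0.0, 0.5], [0.75, 1.0]]
print(np.isnan(bad).all())   # True
```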
2.2 Applying a Time-Dependent Correction (Binary Operations)
Binary operations between fields match dimensions by name: a dimension present in only one operand is broadcast across the other.
n_time, n_sensors = 50, 5
measurements = np.random.randn(n_time, n_sensors) * 10 + 25
# shape: (50, 5) | dtype: float64 | units: °C (simulated)
time_axis = cx.LabeledAxis('time', np.arange(n_time) * 60.0) # seconds
sensor_axis = cx.LabeledAxis('sensor', np.arange(n_sensors, dtype=float))
data = cx.field(measurements, time_axis, sensor_axis)
# shape: (50, 5) | dims: ('time','sensor') | units: °C
print(f"Data: {data.dims}, shape: {data.shape}")
Data: ('time', 'sensor'), shape: (50, 5)
Binary operation with automatic broadcasting
# Time-dependent drift correction — shape (n_time,) broadcasts over 'sensor'
time_correction = cx.field(np.linspace(0, 2.0, n_time), time_axis)
# shape: (50,) | dims: ('time',) | units: °C — drift correction
corrected_data = data - time_correction
# shape: (50, 5) | dims: ('time','sensor') | units: °C — broadcasts ('time',) over 'sensor'
print(f"Corrected: {corrected_data.dims}, shape: {corrected_data.shape}")
# Sensor-specific calibration offset — shape (n_sensors,) broadcasts over 'time'
sensor_offsets = cx.field(np.array([0.1, -0.2, 0.3, -0.1, 0.0]), sensor_axis)
# shape: (5,) | dims: ('sensor',) | units: °C — per-sensor bias
calibrated_data = corrected_data + sensor_offsets
print(f"Calibrated: {calibrated_data.dims}, shape: {calibrated_data.shape}")
# Sensor-specific scaling (multiplicative)
sensor_scales = cx.field(np.array([1.02, 0.98, 1.01, 0.99, 1.00]), sensor_axis)
fully_corrected = (data - time_correction) * sensor_scales + sensor_offsets
# shape: (50, 5) | dims: ('time','sensor') | units: °C
print(f"Fully corrected: {fully_corrected.dims}, shape: {fully_corrected.shape}")
Corrected: ('time', 'sensor'), shape: (50, 5)
Calibrated: ('time', 'sensor'), shape: (50, 5)
Fully corrected: ('time', 'sensor'), shape: (50, 5)
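For comparison, here is what name-based broadcasting saves you from writing by hand. This NumPy sketch reproduces the full correction (data - time_correction) * sensor_scales + sensor_offsets on synthetic data of the same shape. Plain NumPy broadcasts by trailing position, so the per-time correction has to be reshaped into a column with np.newaxis before it lines up with the time axis.

```python
import numpy as np

n_time, n_sensors = 50, 5
rng = np.random.default_rng(0)
data = rng.normal(25.0, 10.0, size=(n_time, n_sensors))    # (time, sensor)
time_correction = np.linspace(0.0, 2.0, n_time)            # (time,)
sensor_scales = np.array([1.02, 0.98, 1.01, 0.99, 1.00])   # (sensor,)
sensor_offsets = np.array([0.1, -0.2, 0.3, -0.1, 0.0])     # (sensor,)

# NumPy aligns trailing axes, so (50, 5) - (50,) raises: the (time,)
# vector must become a (50, 1) column to broadcast over sensors.
try:
    _ = data - time_correction
except ValueError:
    print("positional broadcast fails: (50, 5) vs (50,)")

fully_corrected = (data - time_correction[:, np.newaxis]) * sensor_scales + sensor_offsets

# Reference result computed element by element.
expected = np.empty_like(data)
for t in range(n_time):
    for s in range(n_sensors):
        expected[t, s] = (data[t, s] - time_correction[t]) * sensor_scales[s] + sensor_offsets[s]

print(fully_corrected.shape)                   # (50, 5)
print(np.allclose(fully_corrected, expected))  # True
```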
Automatic alignment with different dimension orders
Coordax automatically aligns dimensions when operands have different axis orderings.
data_ts = cx.field(np.ones((n_time, n_sensors)), 'time', 'sensor') # (time, sensor)
data_st = cx.field(np.ones((n_sensors, n_time)), 'sensor', 'time') # (sensor, time)
# Addition with automatic alignment — no manual transpose needed!
result = data_ts + data_st
# shape: (50, 5) | dims: ('time','sensor') — alignment is automatic!
print(f"Result dims: {result.dims}, shape: {result.shape}")
Result dims: ('time', 'sensor'), shape: (50, 5)
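Aligning two operands by name amounts to transposing one of them into the other's dimension order before the arithmetic. The helper align_to below is hypothetical (not part of coordax) and sketches that idea with np.transpose, using the same (time, sensor) and (sensor, time) shapes as above.

```python
import numpy as np


def align_to(data, dims, target_dims):
    """Hypothetical helper: transpose `data` (labeled by `dims`) into `target_dims` order."""
    if sorted(dims) != sorted(target_dims):
        raise ValueError(f"dimension names differ: {dims} vs {target_dims}")
    return np.transpose(data, [dims.index(d) for d in target_dims])


n_time, n_sensors = 50, 5
data_ts = np.ones((n_time, n_sensors))  # dims ('time', 'sensor')
data_st = np.arange(n_sensors * n_time, dtype=float).reshape(n_sensors, n_time)  # dims ('sensor', 'time')

result = data_ts + align_to(data_st, ("sensor", "time"), ("time", "sensor"))
print(result.shape)  # (50, 5)
# Element (t, s) of the result pairs data_ts[t, s] with data_st[s, t].
print(result[3, 2] == 1.0 + data_st[2, 3])  # True
```

Matching by name matters most when two dimensions happen to have the same length: positional arithmetic would then pair the wrong elements without raising an error.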
2.3 Subtracting the Zonal Mean (Advanced Binary Operations)
A common operation in climate science is removing the zonal (longitudinal) mean, i.e. the average around each latitude circle, so that the remaining anomalies can be studied.
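Written out, with T(t, φ, λ) the temperature at time t, latitude φ and longitude λ, and N_λ the number of longitude points, the zonal mean [T] and the anomaly T' computed below are:

```latex
[T](t, \phi) = \frac{1}{N_\lambda} \sum_{j=1}^{N_\lambda} T(t, \phi, \lambda_j),
\qquad
T'(t, \phi, \lambda) = T(t, \phi, \lambda) - [T](t, \phi)
```

Averaging T' over longitude therefore gives zero by construction, up to floating-point rounding.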
n_time, n_lat, n_lon = 6, 16, 32
temps = np.random.randn(n_time, n_lat, n_lon) * 5 + 15
lat_vals = np.linspace(-90, 90, n_lat)
lat_effect = np.linspace(30, -10, n_lat)
temps += lat_effect[np.newaxis, :, np.newaxis]
time_axis = cx.LabeledAxis('time', np.arange(n_time, dtype=float))
lat_axis = cx.LabeledAxis('latitude', lat_vals)
lon_axis = cx.LabeledAxis('longitude', np.linspace(-180, 180, n_lon))
temperature = cx.field(temps, time_axis, lat_axis, lon_axis)
# shape: (6, 16, 32) | dims: ('time','latitude','longitude') | units: °C
print(f"Temperature field: {temperature.dims}, shape: {temperature.shape}")
Temperature field: ('time', 'latitude', 'longitude'), shape: (6, 16, 32)
Compute and subtract the zonal mean
# Step 1: Untag longitude to expose it as a positional axis
temp_lon_untagged = temperature.untag('longitude')
# shape: (6, 16, 32) — 'longitude' is now positional (trailing axis)
# Step 2: Apply mean over the trailing (longitude) axis
zonal_mean = cx.cmap(lambda x: jnp.mean(x, axis=-1))(temp_lon_untagged)
# shape: (6, 16) | dims: ('time','latitude') | units: °C — mean over 32 lons
print(f"Zonal mean: {zonal_mean.dims}, shape: {zonal_mean.shape}")
# Step 3: Subtract — (time, lat, lon) - (time, lat) broadcasts automatically
temperature_anomaly = temperature - zonal_mean
# shape: (6, 16, 32) | dims: ('time','latitude','longitude') | units: °C
# Broadcasting: (time,lat) is tiled across the 32 longitude points
print(f"Anomaly: {temperature_anomaly.dims}, shape: {temperature_anomaly.shape}")
# Verify: zonal mean of anomaly ≈ 0
anomaly_check = cx.cmap(lambda x: jnp.mean(x, axis=-1))(
    temperature_anomaly.untag('longitude')
)
print(f"Anomaly zonal mean (should be ~0): max={float(jnp.abs(anomaly_check.data).max()):.2e}")
Zonal mean: ('time', 'latitude'), shape: (6, 16)
Anomaly: ('time', 'latitude', 'longitude'), shape: (6, 16, 32)
Anomaly zonal mean (should be ~0): max=7.15e-06
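The same steps in plain NumPy make the positional bookkeeping explicit. This sketch uses synthetic data shaped like the example; keepdims=True keeps a length-1 longitude axis so the mean broadcasts back over the 32 longitudes, which is the step coordax performs by name.

```python
import numpy as np

n_time, n_lat, n_lon = 6, 16, 32
rng = np.random.default_rng(0)
temps = rng.normal(15.0, 5.0, size=(n_time, n_lat, n_lon))  # (time, lat, lon)

# Zonal mean: average over the longitude axis (last axis here).
zonal_mean = temps.mean(axis=-1, keepdims=True)  # (6, 16, 1)

# Anomaly: the (6, 16, 1) mean broadcasts across the 32 longitudes.
anomaly = temps - zonal_mean  # (6, 16, 32)

print(zonal_mean.shape, anomaly.shape)  # (6, 16, 1) (6, 16, 32)
# By construction the anomaly averages to zero along longitude.
print(np.abs(anomaly.mean(axis=-1)).max() < 1e-10)  # True
```

The residual of about 7e-06 printed by the coordax version is consistent with JAX computing in float32 by default (64-bit mode is opt-in); in float64 NumPy the residual is many orders of magnitude smaller.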
Temporal mean and global spatial mean
# Temporal mean: mean over 'time' axis
temporal_mean = cx.cmap(lambda x: jnp.mean(x, axis=0))(temperature.untag('time'))
# shape: (16, 32) | dims: ('latitude','longitude') | units: °C
print(f"Temporal mean: {temporal_mean.dims}, shape: {temporal_mean.shape}")
temporal_anomaly = temperature - temporal_mean
# shape: (6, 16, 32) | dims: ('time','latitude','longitude') | units: °C
print(f"Temporal anomaly: {temporal_anomaly.dims}, shape: {temporal_anomaly.shape}")
# Global spatial mean per time step
global_mean_per_time = cx.cmap(
    lambda x: jnp.mean(x, axis=(-2, -1))
)(temperature.untag('latitude', 'longitude'))
# shape: (6,) | dims: ('time',) | units: °C; global mean at each time step
print(f"Global mean: {global_mean_per_time.dims}, shape: {global_mean_per_time.shape}")
Temporal mean: ('latitude', 'longitude'), shape: (16, 32)
Temporal anomaly: ('time', 'latitude', 'longitude'), shape: (6, 16, 32)
Global mean: ('time',), shape: (6,)
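For reference, the two reductions above correspond to the NumPy sketch below (synthetic data, same shapes). The global mean in the example is an unweighted average over grid points; on a regular latitude-longitude grid, rows near the poles cover less area, and a common alternative is to weight each latitude by cos(latitude). The weighted variant is an addition for comparison, not part of the original example.

```python
import numpy as np

n_time, n_lat, n_lon = 6, 16, 32
rng = np.random.default_rng(0)
temps = rng.normal(15.0, 5.0, size=(n_time, n_lat, n_lon))  # (time, lat, lon)
lat_vals = np.linspace(-90, 90, n_lat)

# Temporal mean over time, then the anomaly broadcast back over time.
temporal_mean = temps.mean(axis=0)              # (16, 32)
temporal_anomaly = temps - temporal_mean        # (6, 16, 32)

# Unweighted global mean per time step (what the coordax example computes).
global_mean = temps.mean(axis=(-2, -1))         # (6,)

# Area-weighted alternative: weight each latitude row by cos(latitude).
weights = np.cos(np.deg2rad(lat_vals))          # (16,)
global_mean_weighted = np.average(temps.mean(axis=-1), axis=-1, weights=weights)  # (6,)

print(temporal_mean.shape, temporal_anomaly.shape, global_mean.shape)  # (16, 32) (6, 16, 32) (6,)
print(global_mean_weighted.shape)  # (6,)
```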
Key Patterns Summary
- Unary operations (field * 2, field - 273.15, cx.cmap(jnp.abs)(field)): preserve dims and coordinates
- Binary operations: broadcast automatically over dimensions matched by name
- Automatic alignment: different dimension orderings are handled without a manual transpose
- Reductions: field.untag('dim'), then cx.cmap(lambda x: jnp.mean(x, axis=...))(untagged)
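The broadcasting and alignment patterns above can be condensed into one toy function. binary_op_by_name is hypothetical, written in plain NumPy, and is not coordax's implementation; it sketches the rule the examples rely on: take the union of dimension names, expand and transpose each operand into that order, then let NumPy broadcast the length-1 axes.

```python
import numpy as np


def binary_op_by_name(op, a, a_dims, b, b_dims):
    """Hypothetical toy: apply `op` to two arrays, matching axes by name.

    Output dims are a_dims followed by any names that only b has.
    """
    out_dims = tuple(a_dims) + tuple(d for d in b_dims if d not in a_dims)

    def expand(x, dims):
        # Append missing names as length-1 axes, then transpose into out_dims order.
        missing = [d for d in out_dims if d not in dims]
        x = x.reshape(x.shape + (1,) * len(missing))
        all_dims = list(dims) + missing
        return np.transpose(x, [all_dims.index(d) for d in out_dims])

    return op(expand(a, a_dims), expand(b, b_dims)), out_dims


temps = np.arange(6 * 4, dtype=float).reshape(6, 4)  # ('time', 'longitude')
zonal_mean = temps.mean(axis=1)                      # ('time',)

anomaly, dims = binary_op_by_name(np.subtract, temps, ("time", "longitude"),
                                  zonal_mean, ("time",))
print(dims, anomaly.shape)  # ('time', 'longitude') (6, 4)

swapped, dims2 = binary_op_by_name(np.add, temps, ("time", "longitude"),
                                   temps.T, ("longitude", "time"))
print(dims2, np.array_equal(swapped, 2 * temps))  # ('time', 'longitude') True
```

This toy only checks sizes, through NumPy's own broadcasting, and ignores coordinate values entirely.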