Skip to article frontmatterSkip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

Unary & binary ops

Part 2: Safe, Labeled Arithmetic (Unary & Binary Ops)

This notebook demonstrates how coordax provides safe, dimension-aware arithmetic operations that automatically handle broadcasting and alignment.

import coordax as cx
import jax.numpy as jnp
import numpy as np

# NOTE: cx.field() is preferred over cx.wrap() in coordax >= 0.2

2.1 Unit Conversion (Unary Operations)

Converting temperature from Kelvin to Celsius while preserving coordinate metadata.

n_lat, n_lon = 16, 32

# Seeded Generator so the synthetic field is reproducible across notebook runs
# (the legacy global np.random state is unseeded and differs every run).
rng = np.random.default_rng(seed=0)
temps_kelvin = rng.uniform(270, 310, size=(n_lat, n_lon))
# shape: (16, 32) | dtype: float64 | units: K  [270, 310 K]

lat_axis = cx.LabeledAxis('latitude', np.linspace(-90, 90, n_lat))
lon_axis = cx.LabeledAxis('longitude', np.linspace(-180, 180, n_lon))

temperature_K = cx.field(temps_kelvin, lat_axis, lon_axis)
# shape: (16, 32) | dims: ('latitude','longitude') | units: K
print(f"Temperature (K): {temperature_K.dims}, shape: {temperature_K.shape}")

# Unary operation — scalar arithmetic preserves coordinates
# T[°C] = T[K] − 273.15
temperature_C = temperature_K - 273.15
# shape: (16, 32) | dims: ('latitude','longitude') | units: °C
print(f"Temperature (C): {temperature_C.dims}")
print(f"Coordinates preserved: {list(temperature_C.axes.keys())}")

# Verify coordinates are identical
assert temperature_C.axes['latitude'] == temperature_K.axes['latitude']
assert temperature_C.axes['longitude'] == temperature_K.axes['longitude']
print("Coordinates preserved: True")
Temperature (K): ('latitude', 'longitude'), shape: (16, 32)
Temperature (C): ('latitude', 'longitude')
Coordinates preserved: ['latitude', 'longitude']
Coordinates preserved: True

More complex unary operations

# Python operators act on the Field directly; cx.cmap maps an arbitrary
# scalar jnp function over every named axis of the Field.
temp_squared = temperature_C ** 2      # shape: (16, 32) | units: °C² (demo only; not a physical quantity)
_temp_abs = cx.cmap(jnp.abs)(temperature_C)  # shape: (16, 32) | units: °C
_temp_exp = cx.cmap(jnp.exp)(temperature_C / 100)  # shape: (16, 32) | dimensionless

# The message claims *all* operations preserve dims, so check every result
# rather than only the first one.
all_preserve_dims = all(
    out.dims == temperature_C.dims for out in (temp_squared, _temp_abs, _temp_exp)
)
print(f"All operations preserve dims: {all_preserve_dims}")
All operations preserve dims: True

Custom unary transformation with cmap

def normalize_to_range(x, old_min, old_max, new_min=0.0, new_max=1.0):
    """Linearly rescale ``x`` from ``[old_min, old_max]`` onto ``[new_min, new_max]``.

    Works elementwise on anything supporting ``-``, ``/``, ``*`` and ``+`` with
    the bound values (Python floats, NumPy arrays, coordax Fields).

    Note: when ``old_max == old_min`` the denominator is zero, so the result
    is a division-by-zero error (Python floats) or inf/NaN (array types).
    """
    # Position of x within the old interval, as a fraction of its width.
    fraction = (x - old_min) / (old_max - old_min)
    # Stretch that fraction onto the new interval's width and shift to its start.
    return fraction * (new_max - new_min) + new_min

# Normalization is elementwise with scalar bounds, so ordinary arithmetic on the
# Field works directly and no cx.cmap is needed. The bounds must come from the
# whole field *before* normalizing: cx.cmap vmaps over every named axis, so a
# mapped function that computed min/max from its own input would see a single
# scalar, get min == max, and divide by zero (NaNs).
t_min = float(temperature_C.data.min())
t_max = float(temperature_C.data.max())
temp_normalized = normalize_to_range(temperature_C, t_min, t_max)
# shape: (16, 32) | dims: ('latitude','longitude') | range: [0, 1]

print(f"Normalized range: [{float(temp_normalized.data.min()):.2f}, {float(temp_normalized.data.max()):.2f}]")
print(f"Shape: {temp_normalized.shape}, Dims: {temp_normalized.dims}")
Normalized range: [0.00, 1.00]
Shape: (16, 32), Dims: ('latitude', 'longitude')

2.2 Applying a Time-Dependent Correction (Binary Operations)

Binary operations align operands by dimension name and automatically broadcast over any dimension that only one operand has.

n_time, n_sensors = 50, 5

# Seeded Generator for reproducible synthetic readings (mean 25, std 10);
# the legacy global np.random state is unseeded and differs every run.
rng = np.random.default_rng(seed=1)
measurements = rng.standard_normal((n_time, n_sensors)) * 10 + 25
# shape: (50, 5) | dtype: float64 | units: °C (simulated)

time_axis = cx.LabeledAxis('time', np.arange(n_time) * 60.0)    # seconds
sensor_axis = cx.LabeledAxis('sensor', np.arange(n_sensors, dtype=float))

data = cx.field(measurements, time_axis, sensor_axis)
# shape: (50, 5) | dims: ('time','sensor') | units: °C
print(f"Data: {data.dims}, shape: {data.shape}")
Data: ('time', 'sensor'), shape: (50, 5)

Binary operation with automatic broadcasting

# Drift correction varies only along 'time'; subtracting it from the
# (time, sensor) field broadcasts it across every sensor.
time_correction = cx.field(np.linspace(0, 2.0, n_time), time_axis)
# shape: (50,) | dims: ('time',) | units: °C  — drift correction
corrected_data = data - time_correction
# shape: (50, 5) | dims: ('time','sensor') | units: °C  — broadcasts ('time',) over 'sensor'
print(f"Corrected: {corrected_data.dims}, shape: {corrected_data.shape}")

# Per-sensor bias varies only along 'sensor'; adding it broadcasts over 'time'.
sensor_offsets = cx.field(np.array([0.1, -0.2, 0.3, -0.1, 0.0]), sensor_axis)
# shape: (5,) | dims: ('sensor',) | units: °C  — per-sensor bias
calibrated_data = corrected_data + sensor_offsets
print(f"Calibrated: {calibrated_data.dims}, shape: {calibrated_data.shape}")

# Full pipeline: drift-corrected readings, scaled per sensor, then offset.
# corrected_data already holds data - time_correction, so reuse it.
sensor_scales = cx.field(np.array([1.02, 0.98, 1.01, 0.99, 1.00]), sensor_axis)
fully_corrected = corrected_data * sensor_scales + sensor_offsets
# shape: (50, 5) | dims: ('time','sensor') | units: °C
print(f"Fully corrected: {fully_corrected.dims}, shape: {fully_corrected.shape}")
Corrected: ('time', 'sensor'), shape: (50, 5)
Calibrated: ('time', 'sensor'), shape: (50, 5)
Fully corrected: ('time', 'sensor'), shape: (50, 5)

Automatic alignment with different dimension orders

Coordax automatically aligns dimensions when operands have different axis orderings.

# Use distinct, non-constant values: with all-ones inputs any misalignment or
# accidental transpose would produce the same numbers and go unnoticed.
ts_values = np.arange(n_time * n_sensors, dtype=float).reshape(n_time, n_sensors)
st_values = 1000.0 * np.arange(n_sensors * n_time, dtype=float).reshape(n_sensors, n_time)
data_ts = cx.field(ts_values, 'time', 'sensor')   # (time, sensor)
data_st = cx.field(st_values, 'sensor', 'time')   # (sensor, time)

# Addition with automatic alignment — no manual transpose needed!
result = data_ts + data_st
# shape: (50, 5) | dims: ('time','sensor')  — alignment is automatic!
print(f"Result dims: {result.dims}, shape: {result.shape}")

# Check values, not just shape: element (t, s) must pair ts[t, s] with st[s, t].
np.testing.assert_allclose(np.asarray(result.data), ts_values + st_values.T)
Result dims: ('time', 'sensor'), shape: (50, 5)

2.3 Subtracting the Zonal Mean (Advanced Binary Operations)

A common operation in climate science: removing the longitudinal mean to study anomalies.

n_time, n_lat, n_lon = 6, 16, 32

# Seeded Generator so the synthetic field, and the anomaly residual printed
# later from it, are reproducible between runs.
rng = np.random.default_rng(seed=2)
temps = rng.standard_normal((n_time, n_lat, n_lon)) * 5 + 15
lat_vals = np.linspace(-90, 90, n_lat)
# Linear offset from +30 at the first latitude (-90) to -10 at the last (+90),
# broadcast over time and longitude.
lat_effect = np.linspace(30, -10, n_lat)
temps += lat_effect[np.newaxis, :, np.newaxis]

time_axis = cx.LabeledAxis('time', np.arange(n_time, dtype=float))
lat_axis = cx.LabeledAxis('latitude', lat_vals)
lon_axis = cx.LabeledAxis('longitude', np.linspace(-180, 180, n_lon))

temperature = cx.field(temps, time_axis, lat_axis, lon_axis)
# shape: (6, 16, 32) | dims: ('time','latitude','longitude') | units: °C
print(f"Temperature field: {temperature.dims}, shape: {temperature.shape}")
Temperature field: ('time', 'latitude', 'longitude'), shape: (6, 16, 32)

Compute and subtract the zonal mean

# Step 1: Untag longitude to expose it as a positional axis
temp_lon_untagged = temperature.untag('longitude')
# shape: (6, 16, 32) — 'longitude' is now positional (trailing axis)

# Step 2: Apply mean over the trailing (longitude) axis.
# cx.cmap maps the lambda over the remaining named axes (time, latitude), so
# the lambda receives only the positional longitude axis; hence axis=-1.
zonal_mean = cx.cmap(lambda x: jnp.mean(x, axis=-1))(temp_lon_untagged)
# shape: (6, 16) | dims: ('time','latitude') | units: °C  — mean over 32 lons
print(f"Zonal mean: {zonal_mean.dims}, shape: {zonal_mean.shape}")

# Step 3: Subtract — (time, lat, lon) - (time, lat) broadcasts automatically
temperature_anomaly = temperature - zonal_mean
# shape: (6, 16, 32) | dims: ('time','latitude','longitude') | units: °C
# Broadcasting: (time,lat) is tiled across the 32 longitude points
print(f"Anomaly: {temperature_anomaly.dims}, shape: {temperature_anomaly.shape}")

# Verify: zonal mean of anomaly ≈ 0. Floating-point rounding leaves a small
# nonzero residual, so report the largest magnitude rather than testing == 0.
anomaly_check = cx.cmap(lambda x: jnp.mean(x, axis=-1))(
    temperature_anomaly.untag('longitude')
)
print(f"Anomaly zonal mean (should be ~0): max={float(jnp.abs(anomaly_check.data).max()):.2e}")
Zonal mean: ('time', 'latitude'), shape: (6, 16)
Anomaly: ('time', 'latitude', 'longitude'), shape: (6, 16, 32)
Anomaly zonal mean (should be ~0): max=7.15e-06

Temporal mean and global spatial mean

# Temporal mean: untag 'time' so it becomes the only positional axis, then
# average over it (axis=0); cmap maps over the remaining named axes.
temporal_mean = cx.cmap(lambda x: jnp.mean(x, axis=0))(temperature.untag('time'))
# shape: (16, 32) | dims: ('latitude','longitude') | units: °C
print(f"Temporal mean: {temporal_mean.dims}, shape: {temporal_mean.shape}")

# (time, lat, lon) - (lat, lon): the temporal mean is broadcast over every time step
temporal_anomaly = temperature - temporal_mean
# shape: (6, 16, 32) | dims: ('time','latitude','longitude') | units: °C
print(f"Temporal anomaly: {temporal_anomaly.dims}, shape: {temporal_anomaly.shape}")

# Area-weighted global spatial mean per time step. On a regular lat-lon grid,
# cell area scales with cos(latitude), so a plain mean over (latitude,
# longitude) over-weights the small polar cells. Averaging over longitude first
# is exact (all cells at one latitude have equal area); a cos(lat)-weighted
# average over latitude then completes the area-weighted mean.
lat_weights = jnp.cos(jnp.deg2rad(lat_vals))
global_mean_per_time = cx.cmap(
    lambda x: jnp.sum(jnp.mean(x, axis=-1) * lat_weights) / jnp.sum(lat_weights)
)(temperature.untag('latitude', 'longitude'))
# shape: (6,) | dims: ('time',) | units: °C  — global mean at each timestep
print(f"Global mean: {global_mean_per_time.dims}, shape: {global_mean_per_time.shape}")
Temporal mean: ('latitude', 'longitude'), shape: (16, 32)
Temporal anomaly: ('time', 'latitude', 'longitude'), shape: (6, 16, 32)
Global mean: ('time',), shape: (6,)

Key Patterns Summary

  1. Unary operations (field * 2, field - 273.15, cx.cmap(jnp.abs)(field)) — preserve dims/coords
  2. Binary operations align operands by dimension name and broadcast over dimensions present in only one operand
  3. Automatic alignment — different dimension orderings are handled without manual transpose
  4. Reductions — field.untag('dim'), then cx.cmap(lambda x: jnp.mean(x, axis=...))(untagged)