Coordinate-aware ops
# Part 4: Coordinate-Aware Analysis (Ops Using Coordinate Info)
This notebook demonstrates how to use physical coordinate values, rather than raw array indices, for selection, slicing, regridding, and interpolation.
import coordax as cx
import jax.numpy as jnp
import numpy as np
# NOTE: cx.field() is preferred over cx.wrap() in coordax >= 0.2
## 4.1 Selecting Data for a Specific Location
Select data by physical coordinate value rather than by array index.
# Create a global temperature field (small for fast execution)
# NOTE(review): np.random is unseeded, so the values (and the printed outputs
# in later cells) differ from run to run.
n_time, n_lat, n_lon = 12, 16, 32
# Base field: Gaussian noise with standard deviation 10 around a mean of 15.
temps = np.random.randn(n_time, n_lat, n_lon) * 10 + 15
lat_values = np.linspace(-90, 90, n_lat)
# Latitude-dependent signal, 20 * cos(lat): largest at the equator, 0 at the poles.
lat_effect = 20 * np.cos(np.deg2rad(lat_values))
# Broadcast the per-latitude offset across the time and longitude dimensions.
temps += lat_effect[np.newaxis, :, np.newaxis]
time_axis = cx.LabeledAxis('time', np.arange(n_time, dtype=float))
lat_axis = cx.LabeledAxis('latitude', lat_values)
# NOTE(review): np.linspace includes both -180 and 180, so the first and last
# longitude columns describe the same meridian; endpoint=False would give a
# non-duplicated periodic grid.
lon_axis = cx.LabeledAxis('longitude', np.linspace(-180, 180, n_lon))
# Attach the three labeled axes, in order, to the raw (time, lat, lon) array.
temperature = cx.field(temps, time_axis, lat_axis, lon_axis)
# shape: (12, 16, 32) | dims: ('time','latitude','longitude') | units: °C
print(f"Temperature field: {temperature.dims}, shape: {temperature.shape}")
# Output: Temperature field: ('time', 'latitude', 'longitude'), shape: (12, 16, 32)
### Find the nearest coordinate index
# Target location; negative longitude means west (47.6°N, 122.3°W).
target_lat = 47.6
target_lon = -122.3
lat_coords = lat_axis.ticks
lon_coords = lon_axis.ticks
# Nearest-neighbor lookup: take the index of the tick with the smallest
# absolute difference from the target. np.argmin returns the first minimum,
# so ties go to the lower index. Longitude differences are not wrapped
# around ±180 here.
lat_idx = int(np.argmin(np.abs(lat_coords - target_lat)))
lon_idx = int(np.argmin(np.abs(lon_coords - target_lon)))
actual_lat = lat_coords[lat_idx]
actual_lon = lon_coords[lon_idx]
print(f"Target: {target_lat}°N, {target_lon}°E")
print(f"Nearest grid point: {actual_lat:.2f}°N, {actual_lon:.2f}°E")
# Extract and re-wrap time series at that location
# Integer indices drop the latitude and longitude dimensions, leaving only time.
time_series_data = temperature.data[:, lat_idx, lon_idx]
# shape: (12,) | dtype: float64 | units: °C — point extraction
# (dtype assumes .data keeps the NumPy float64 input — TODO confirm)
time_series = cx.field(time_series_data, time_axis)
# shape: (12,) | dims: ('time',) | units: °C
print(f"Time series: {time_series.dims}, shape: {time_series.shape}")
print(f"Annual mean: {float(jnp.mean(time_series.data)):.2f} °C")
# Output:
# Target: 47.6°N, -122.3°E
# Nearest grid point: 42.00°N, -121.94°E
# Time series: ('time',), shape: (12,)
# Annual mean: 28.81 °C
### Helper: coordinate-based location selection
def select_location(field, lat_target, lon_target,
                    lat_dim='latitude', lon_dim='longitude', *,
                    lon_period=360.0):
    """Select data at a specific (lat, lon) by nearest-neighbor lookup.

    Args:
      field: coordax field containing `lat_dim` and `lon_dim`, whose axes
        expose their coordinate values via `.ticks`.
      lat_target: target latitude, in the same units as the latitude ticks.
      lon_target: target longitude, in the same units as the longitude ticks.
      lat_dim: name of the latitude dimension.
      lon_dim: name of the longitude dimension.
      lon_period: period of the longitude coordinate (360 for degrees).
        Longitude differences are wrapped into [-period/2, period/2), so a
        target just across the antimeridian (or a 0..360 vs -180..180 mismatch)
        still finds the geographically nearest tick. Pass None to use plain
        absolute differences.

    Returns:
      A new field holding the values at the nearest grid point, carrying the
      axes of every remaining dimension in their original order (e.g.
      ('time',) for a (time, latitude, longitude) input).

    NOTE(review): assumes every remaining dimension has an entry in
    `field.axes`; confirm for fields with unlabeled dimensions.
    """
    lat_coords = field.axes[lat_dim].ticks
    lon_coords = field.axes[lon_dim].ticks

    lon_delta = lon_coords - lon_target
    if lon_period is not None:
        # Measure longitude distance the short way around the circle.
        lon_delta = (lon_delta + lon_period / 2) % lon_period - lon_period / 2

    # argmin returns the first minimum, so ties resolve to the lower index.
    lat_idx = int(jnp.argmin(jnp.abs(lat_coords - lat_target)))
    lon_idx = int(jnp.argmin(jnp.abs(lon_delta)))

    # Build an indexer that fixes lat/lon wherever they sit in the dim order.
    indexer = [slice(None)] * field.ndim
    indexer[field.dims.index(lat_dim)] = lat_idx
    indexer[field.dims.index(lon_dim)] = lon_idx
    selected_data = field.data[tuple(indexer)]

    # Integer indexing dropped lat/lon; re-wrap with the remaining axes.
    remaining_axes = [
        field.axes[d] for d in field.dims if d not in (lat_dim, lon_dim)
    ]
    return cx.field(selected_data, *remaining_axes)
# The same point extraction as above, done via the reusable helper.
location_series = select_location(temperature, target_lat, target_lon)
# shape: (12,) | dims: ('time',) | units: °C — point extraction at (47.6°N, 122.3°W)
print(f"Selected time series: {location_series.dims}, shape: {location_series.shape}")
# Output: Selected time series: ('time',), shape: (12,)
## 4.2 Slicing a Region of Interest
Extract a spatial subset using coordinate boundary values; both bounds are inclusive.
# Region bounds in coordinate units (both ends inclusive, see searchsorted below).
lat_min, lat_max = 42.0, 50.0
lon_min, lon_max = -125.0, -115.0
lat_coords = lat_axis.ticks
lon_coords = lon_axis.ticks
# np.searchsorted requires ascending ticks. side='left' on the lower bound
# keeps a tick equal to the minimum, and side='right' on the upper bound keeps
# a tick equal to the maximum.
# NOTE(review): if no tick falls inside a range, the slice is empty; a region
# that crosses the ±180 meridian is not handled by this pattern.
lat_start = int(np.searchsorted(lat_coords, lat_min, side='left'))
lat_end = int(np.searchsorted(lat_coords, lat_max, side='right'))
lon_start = int(np.searchsorted(lon_coords, lon_min, side='left'))
lon_end = int(np.searchsorted(lon_coords, lon_max, side='right'))
regional_data = temperature.data[:, lat_start:lat_end, lon_start:lon_end]
# Slice the tick arrays with the same bounds so the axes stay aligned with the data.
regional_lat_axis = cx.LabeledAxis('latitude', lat_coords[lat_start:lat_end])
regional_lon_axis = cx.LabeledAxis('longitude', lon_coords[lon_start:lon_end])
regional_temperature = cx.field(regional_data, time_axis, regional_lat_axis, regional_lon_axis)
# shape: (12, lat_slice, lon_slice) | dims: ('time','latitude','longitude')
# lat_slice / lon_slice depend on grid; subset of original (12,16,32) covering [42°N–50°N]×[125°W–115°W]
print(f"Regional temperature: {regional_temperature.dims}, shape: {regional_temperature.shape}")
# Spatial mean time series for the region
# untag() makes latitude/longitude positional, so axis=(-2, -1) in the lambda
# refers to them; presumably cmap maps the function over the remaining named
# dim (time) — confirm against the coordax docs.
regional_mean_ts = cx.cmap(
    lambda x: jnp.mean(x, axis=(-2, -1))
)(regional_temperature.untag('latitude', 'longitude'))
# shape: (12,) | dims: ('time',) | units: °C — spatial mean over region
print(f"Regional mean time series shape: {regional_mean_ts.shape}")
# Output:
# Regional temperature: ('time', 'latitude', 'longitude'), shape: (12, 1, 1)
# Regional mean time series shape: (12,)
## 4.3 Regridding / Interpolation
Resample data from a coarse grid onto a finer grid using coordinate values.
# Synthetic coarse-grid field: noise plus a cos(latitude) signal, as for the
# global field, but without a time dimension.
n_lat_coarse, n_lon_coarse = 8, 16
temps_coarse = np.random.randn(n_lat_coarse, n_lon_coarse) * 10 + 15
lat_coarse = np.linspace(-90, 90, n_lat_coarse)
lon_coarse = np.linspace(-180, 180, n_lon_coarse)
# Column vector of per-latitude offsets, broadcast across longitude.
temps_coarse += (20 * np.cos(np.deg2rad(lat_coarse)))[:, np.newaxis]
lat_axis_coarse = cx.LabeledAxis('latitude', lat_coarse)
lon_axis_coarse = cx.LabeledAxis('longitude', lon_coarse)
temp_coarse = cx.field(temps_coarse, lat_axis_coarse, lon_axis_coarse)
# shape: (8, 16) | dims: ('latitude','longitude') | units: °C (coarse grid)
# Target grid: twice as many points along each axis, spanning the same extent.
n_lat_fine, n_lon_fine = 16, 32
lat_fine = np.linspace(-90, 90, n_lat_fine)
lon_fine = np.linspace(-180, 180, n_lon_fine)
print(f"Coarse grid: {temp_coarse.shape}")
print(f"Target fine grid: ({n_lat_fine}, {n_lon_fine})")
# Output:
# Coarse grid: (8, 16)
# Target fine grid: (16, 32)
### Simple nearest-neighbor regridding (for demonstration)
def nearest_neighbor_regrid(coarse_data, lat_old, lon_old, lat_new, lon_new):
    """Nearest-neighbor regridding using broadcasting.

    Each target coordinate is mapped to the index of the closest source
    coordinate (smallest absolute difference; ties go to the lower index).
    The source values are then gathered on the outer product of those
    latitude and longitude indices.

    Args:
      coarse_data: array whose last two axes are (latitude, longitude) on the
        source grid. Any leading axes (e.g. time) are carried through.
      lat_old, lon_old: 1-D source coordinates matching the last two axes of
        `coarse_data`. Any array-like is accepted.
      lat_new, lon_new: 1-D target coordinates. Any array-like is accepted.

    Returns:
      Array of shape `coarse_data.shape[:-2] + (len(lat_new), len(lon_new))`.

    Note: longitude differences are not wrapped around ±180, so both grids
    should use the same longitude convention.
    """
    lat_old, lon_old = jnp.asarray(lat_old), jnp.asarray(lon_old)
    lat_new, lon_new = jnp.asarray(lat_new), jnp.asarray(lon_new)
    # (n_new, n_old) distance matrices reduced to one source index per target.
    lat_idx = jnp.argmin(jnp.abs(lat_new[:, None] - lat_old[None, :]), axis=1)
    lon_idx = jnp.argmin(jnp.abs(lon_new[:, None] - lon_old[None, :]), axis=1)
    # The Ellipsis keeps any leading axes; the (n_lat, 1) and (1, n_lon) index
    # arrays broadcast to an (n_lat, n_lon) gather over the last two axes.
    return coarse_data[..., lat_idx[:, None], lon_idx[None, :]]
# Regrid the raw array, then re-wrap it with axes built from the fine-grid
# coordinates so the labels match the new shape.
temps_fine_data = nearest_neighbor_regrid(
    temps_coarse, lat_coarse, lon_coarse, lat_fine, lon_fine
)
lat_axis_fine = cx.LabeledAxis('latitude', lat_fine)
lon_axis_fine = cx.LabeledAxis('longitude', lon_fine)
temp_fine = cx.field(temps_fine_data, lat_axis_fine, lon_axis_fine)
# shape: (16, 32) | dims: ('latitude','longitude') | units: °C (fine grid, nearest-neighbor)
print(f"Interpolated fine grid: {temp_fine.shape}")
# Output: Interpolated fine grid: (16, 32)
## 4.4 Sparse Weather Station Interpolation (Pattern)
This section demonstrates the workflow for interpolating irregularly spaced station data onto a regular grid. Nearest-station lookup is shown here; Gaussian-process (GP) regression follows the same pattern.
# Synthetic station network: uniformly random locations within ±60° latitude
# and across all longitudes.
n_stations = 20
station_lats = np.random.uniform(-60, 60, n_stations)
station_lons = np.random.uniform(-180, 180, n_stations)
# The same 20 * cos(latitude) signal as the gridded fields, plus noise with
# standard deviation 2.
station_temps = (15.0
                 + 20 * np.cos(np.deg2rad(station_lats))
                 + np.random.randn(n_stations) * 2.0)
print(f"Stations: {n_stations}, temp range: {station_temps.min():.1f} — {station_temps.max():.1f} °C")
# Regular target grid covering the same latitude band as the stations.
n_lat_grid, n_lon_grid = 8, 16
grid_lats = np.linspace(-60, 60, n_lat_grid)
grid_lons = np.linspace(-180, 180, n_lon_grid)
def nearest_station_interp(s_lats, s_lons, s_temps, g_lat, g_lon, *,
                           lon_period=360.0):
    """Nearest-station interpolation from sparse points to a regular grid.

    Distance uses the equirectangular approximation
    sqrt(dlat**2 + (dlon * cos(lat))**2) in degrees, with cos evaluated at the
    grid point's latitude. Longitude differences are wrapped into
    [-lon_period/2, lon_period/2), so stations just across the ±180 meridian
    count as nearby rather than a full circle away.

    Args:
      s_lats, s_lons: 1-D arrays of station latitudes/longitudes in degrees.
      s_temps: 1-D array of station values, the same length as `s_lats`.
      g_lat, g_lon: 1-D grid latitudes/longitudes in degrees.
      lon_period: longitude period used for wrapping; None disables wrapping
        and uses plain longitude differences.

    Returns:
      Array of shape (len(g_lat), len(g_lon)). Each entry is the value of the
      nearest station, with ties going to the lower station index.
    """
    lat_mesh, lon_mesh = jnp.meshgrid(g_lat, g_lon, indexing='ij')
    # Differences from every grid point to every station: (n_lat, n_lon, n_stations).
    dlat = lat_mesh[:, :, None] - s_lats[None, None, :]
    dlon = lon_mesh[:, :, None] - s_lons[None, None, :]
    if lon_period is not None:
        # Take the short way around the circle (e.g. 170 vs -175 is 15 apart).
        dlon = (dlon + lon_period / 2) % lon_period - lon_period / 2
    # Shrink east-west distances toward the poles by cos(latitude).
    dlon = dlon * jnp.cos(jnp.deg2rad(lat_mesh[:, :, None]))
    dist = jnp.sqrt(dlat**2 + dlon**2)  # (n_lat, n_lon, n_stations)
    nearest_idx = jnp.argmin(dist, axis=-1)  # (n_lat, n_lon)
    return s_temps[nearest_idx]
temp_grid = nearest_station_interp(
    station_lats, station_lons, station_temps,
    grid_lats, grid_lons
)
# Wrap the raw (lat, lon) result with axes built from the grid coordinates.
grid_lat_axis = cx.LabeledAxis('latitude', grid_lats)
grid_lon_axis = cx.LabeledAxis('longitude', grid_lons)
temp_interpolated = cx.field(temp_grid, grid_lat_axis, grid_lon_axis)
# shape: (8, 16) | dims: ('latitude','longitude') | units: °C (nearest-station interpolation)
print(f"Interpolated grid: {temp_interpolated.dims}, shape: {temp_interpolated.shape}")
# Every grid value is copied from some station, so this range always lies
# within the station range printed by the "Stations:" line.
print(f"Temperature range: {float(temp_grid.min()):.1f} — {float(temp_grid.max()):.1f} °C")
# Output of this cell:
# Stations: 20, temp range: 21.4 — 36.0 °C
# Interpolated grid: ('latitude', 'longitude'), shape: (8, 16)
# Temperature range: 21.4 — 36.0 °C
## Key Patterns Summary
- **Coordinate-based selection**: use `axis.ticks` + `argmin` to find the nearest index.
- **Regional slicing**: use `np.searchsorted` on coordinate values to find slice bounds.
- **Regridding**: interpolate `.data`, then re-wrap the result with new coordinate axes.
- **Sparse-to-grid**: use nearest-neighbor lookup or GP regression, then wrap the result.

For production interpolation, consider interpax (differentiable, JIT-compatible), or GPJax for probabilistic interpolation.