Skip to article frontmatter | Skip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

Coordinate-aware ops

Part 4: Coordinate-Aware Analysis (Ops Using Coordinate Info)

This notebook demonstrates how to leverage physical coordinate values for selection, slicing, regridding, and interpolation.

import coordax as cx
import jax.numpy as jnp
import numpy as np

# NOTE: cx.field() is preferred over cx.wrap() in coordax >= 0.2

4.1 Selecting Data for a Specific Location

Select by physical coordinate value rather than array index.

# Create a global temperature field (small for fast execution)
n_time, n_lat, n_lon = 12, 16, 32

# Synthetic values: standard-normal noise scaled by 10 and offset by 15.
# The generator is not seeded, so the numbers differ from run to run.
temps = np.random.randn(n_time, n_lat, n_lon) * 10 + 15
lat_values = np.linspace(-90, 90, n_lat)
# 20 * cos(latitude): largest near the equator, ~0 at +/-90 degrees
lat_effect = 20 * np.cos(np.deg2rad(lat_values))
# Broadcast the per-latitude term over the time and longitude axes (in place)
temps += lat_effect[np.newaxis, :, np.newaxis]

# One labeled axis per array dimension, in (time, latitude, longitude) order
time_axis = cx.LabeledAxis('time', np.arange(n_time, dtype=float))
lat_axis = cx.LabeledAxis('latitude', lat_values)
lon_axis = cx.LabeledAxis('longitude', np.linspace(-180, 180, n_lon))

temperature = cx.field(temps, time_axis, lat_axis, lon_axis)
# shape: (12, 16, 32) | dims: ('time','latitude','longitude') | units: °C
# (the °C is a convention of this example; no unit is passed to cx.field)
print(f"Temperature field: {temperature.dims}, shape: {temperature.shape}")
Temperature field: ('time', 'latitude', 'longitude'), shape: (12, 16, 32)

Find nearest coordinate index

# Physical location to look up, in degrees (negative longitude = west)
target_lat = 47.6
target_lon = -122.3

# Coordinate values attached to the axes built earlier
lat_coords = lat_axis.ticks
lon_coords = lon_axis.ticks

# Nearest tick = smallest absolute difference in degrees.
# This is a plain difference: no wrap-around is applied at +/-180 longitude.
lat_idx = int(np.argmin(np.abs(lat_coords - target_lat)))
lon_idx = int(np.argmin(np.abs(lon_coords - target_lon)))

# Coordinates of the grid point that was actually selected
actual_lat = lat_coords[lat_idx]
actual_lon = lon_coords[lon_idx]

print(f"Target: {target_lat}°N, {target_lon}°E")
print(f"Nearest grid point: {actual_lat:.2f}°N, {actual_lon:.2f}°E")

# Extract and re-wrap time series at that location
# Positional indexing: assumes the array layout is (time, latitude, longitude),
# which is the order the axes were passed to cx.field when building `temperature`.
time_series_data = temperature.data[:, lat_idx, lon_idx]
# shape: (12,) | dtype: float64 | units: °C  — point extraction
# NOTE(review): float64 presumes cx.field keeps the NumPy dtype — confirm
time_series = cx.field(time_series_data, time_axis)
# shape: (12,) | dims: ('time',) | units: °C
print(f"Time series: {time_series.dims}, shape: {time_series.shape}")
# "Annual mean" here is simply the mean over all 12 time steps
print(f"Annual mean: {float(jnp.mean(time_series.data)):.2f} °C")
Target: 47.6°N, -122.3°E
Nearest grid point: 42.00°N, -121.94°E
Time series: ('time',), shape: (12,)
Annual mean: 28.81 °C

Helper: coordinate-based location selection

def select_location(field, lat_target, lon_target,
                    lat_dim='latitude', lon_dim='longitude'):
    """Extract the values of ``field`` at the grid point closest to a location.

    The closest point is found independently along each of the two
    dimensions by minimising the absolute difference between the axis
    ticks and the requested coordinate.

    Args:
        field: Field whose ``axes`` mapping contains ``lat_dim`` and
            ``lon_dim`` entries exposing ``.ticks``.
        lat_target: Latitude to look up, in the same units as the ticks.
        lon_target: Longitude to look up, in the same units as the ticks.
        lat_dim: Name of the latitude dimension.
        lon_dim: Name of the longitude dimension.

    Returns:
        A field created with ``cx.field`` from the selected data and the
        axes of every dimension other than ``lat_dim`` and ``lon_dim``.
    """
    def _nearest_tick(dim, target):
        # Position of the tick with the smallest absolute distance to target
        ticks = field.axes[dim].ticks
        return int(jnp.argmin(jnp.abs(ticks - target)))

    lat_hit = _nearest_tick(lat_dim, lat_target)
    lon_hit = _nearest_tick(lon_dim, lon_target)

    # Start from "take everything" and pin only the two spatial dimensions
    dim_names = field.dims
    index = [slice(None) for _ in range(field.ndim)]
    index[dim_names.index(lat_dim)] = lat_hit
    index[dim_names.index(lon_dim)] = lon_hit
    point_values = field.data[tuple(index)]

    # Integer indexing removed the two spatial dimensions; keep the other axes
    kept_axes = [
        field.axes[name]
        for name in dim_names
        if name != lat_dim and name != lon_dim
    ]
    return cx.field(point_values, *kept_axes)


# Same lookup as the manual argmin steps, done through the helper with its
# default dimension names ('latitude', 'longitude')
location_series = select_location(temperature, target_lat, target_lon)
# shape: (12,) | dims: ('time',) | units: °C  — point extraction at (47.6°N, 122.3°W)
print(f"Selected time series: {location_series.dims}, shape: {location_series.shape}")
Selected time series: ('time',), shape: (12,)

4.2 Slicing a Region of Interest

Extract a spatial subset using coordinate boundary values.

# Bounding box in degrees: [lat_min, lat_max] x [lon_min, lon_max]
lat_min, lat_max = 42.0, 50.0
lon_min, lon_max = -125.0, -115.0

lat_coords = lat_axis.ticks
lon_coords = lon_axis.ticks

# searchsorted needs ascending input; both tick arrays come from np.linspace
# with increasing endpoints. side='left' for the lower bound and side='right'
# for the upper bound make both ends of the box inclusive.
lat_start = int(np.searchsorted(lat_coords, lat_min, side='left'))
lat_end   = int(np.searchsorted(lat_coords, lat_max, side='right'))
lon_start = int(np.searchsorted(lon_coords, lon_min, side='left'))
lon_end   = int(np.searchsorted(lon_coords, lon_max, side='right'))

# Positional slicing: assumes (time, latitude, longitude) array layout
regional_data = temperature.data[:, lat_start:lat_end, lon_start:lon_end]

# Slice the ticks with the same bounds so the coordinates stay aligned with the data
regional_lat_axis = cx.LabeledAxis('latitude', lat_coords[lat_start:lat_end])
regional_lon_axis = cx.LabeledAxis('longitude', lon_coords[lon_start:lon_end])

regional_temperature = cx.field(regional_data, time_axis, regional_lat_axis, regional_lon_axis)
# shape: (12, lat_slice, lon_slice) | dims: ('time','latitude','longitude')
# lat_slice / lon_slice depend on grid; subset of original (12,16,32) covering [42°N–50°N]×[125°W–115°W]
print(f"Regional temperature: {regional_temperature.dims}, shape: {regional_temperature.shape}")

# Spatial mean time series for the region
# untag() drops the latitude/longitude labels so the lambda can reduce over the
# last two positional axes.
# NOTE(review): presumably cx.cmap maps the lambda over the remaining named
# 'time' dimension — confirm against the coordax docs.
regional_mean_ts = cx.cmap(
    lambda x: jnp.mean(x, axis=(-2, -1))
)(regional_temperature.untag('latitude', 'longitude'))
# shape: (12,) | dims: ('time',) | units: °C  — spatial mean over region
print(f"Regional mean time series shape: {regional_mean_ts.shape}")
Regional temperature: ('time', 'latitude', 'longitude'), shape: (12, 1, 1)
Regional mean time series shape: (12,)

4.3 Regridding / Interpolation

Interpolate data from a coarse grid to a finer grid using coordinate values.

# Source (coarse) grid: 8 latitudes x 16 longitudes
n_lat_coarse, n_lon_coarse = 8, 16
# Unseeded noise scaled by 10 and offset by 15, as for the global field
temps_coarse = np.random.randn(n_lat_coarse, n_lon_coarse) * 10 + 15
lat_coarse = np.linspace(-90, 90, n_lat_coarse)
lon_coarse = np.linspace(-180, 180, n_lon_coarse)
# Add 20 * cos(latitude), broadcast across the longitude axis
temps_coarse += (20 * np.cos(np.deg2rad(lat_coarse)))[:, np.newaxis]

lat_axis_coarse = cx.LabeledAxis('latitude', lat_coarse)
lon_axis_coarse = cx.LabeledAxis('longitude', lon_coarse)
temp_coarse = cx.field(temps_coarse, lat_axis_coarse, lon_axis_coarse)
# shape: (8, 16) | dims: ('latitude','longitude') | units: °C  (coarse grid)

# Target (fine) grid: same coordinate extent, twice the points per dimension
n_lat_fine, n_lon_fine = 16, 32
lat_fine = np.linspace(-90, 90, n_lat_fine)
lon_fine = np.linspace(-180, 180, n_lon_fine)

print(f"Coarse grid: {temp_coarse.shape}")
print(f"Target fine grid: ({n_lat_fine}, {n_lon_fine})")
Coarse grid: (8, 16)
Target fine grid: (16, 32)

Simple nearest-neighbor interpolation (for demonstration)

def nearest_neighbor_regrid(coarse_data, lat_old, lon_old, lat_new, lon_new):
    """Resample gridded data onto new coordinates by nearest-neighbor lookup.

    Each new latitude/longitude is matched to the old coordinate with the
    smallest absolute difference; the two dimensions are matched
    independently of each other.

    Args:
        coarse_data: Array indexed as ``coarse_data[lat, lon]``.
        lat_old: 1-D latitudes of the rows of ``coarse_data``.
        lon_old: 1-D longitudes of the columns of ``coarse_data``.
        lat_new: 1-D target latitudes.
        lon_new: 1-D target longitudes.

    Returns:
        Array of shape ``(len(lat_new), len(lon_new))`` holding the value
        of the nearest source cell for every target cell.
    """
    def _closest_source(targets, sources):
        # (n_targets, n_sources) table of absolute coordinate gaps
        gaps = jnp.abs(targets[:, None] - sources[None, :])
        return jnp.argmin(gaps, axis=1)

    rows = _closest_source(lat_new, lat_old)
    cols = _closest_source(lon_new, lon_old)
    # Column vector x row vector broadcasts to the full 2-D lookup
    return coarse_data[rows[:, None], cols[None, :]]

# Regrid the raw array (not the wrapped field): the helper works on plain
# data plus the old and new coordinate vectors
temps_fine_data = nearest_neighbor_regrid(
    temps_coarse, lat_coarse, lon_coarse, lat_fine, lon_fine
)

# Re-wrap the result with axes that describe the new (fine) coordinates
lat_axis_fine = cx.LabeledAxis('latitude', lat_fine)
lon_axis_fine = cx.LabeledAxis('longitude', lon_fine)
temp_fine = cx.field(temps_fine_data, lat_axis_fine, lon_axis_fine)
# shape: (16, 32) | dims: ('latitude','longitude') | units: °C  (fine grid, nearest-neighbor)

print(f"Interpolated fine grid: {temp_fine.shape}")
Interpolated fine grid: (16, 32)

4.4 Sparse Weather Station Interpolation (Pattern)

Demonstrates the workflow for interpolating irregularly spaced station data onto a regular grid. The example uses nearest-station lookup; the same wrap-the-result pattern applies to Gaussian-process (GP) regression.

# Synthetic, irregularly placed stations (unseeded, so they differ between runs)
n_stations = 20
station_lats = np.random.uniform(-60, 60, n_stations)
station_lons = np.random.uniform(-180, 180, n_stations)
# Station value = 15 + 20 * cos(latitude) + Gaussian noise with standard deviation 2
station_temps = (15.0
                 + 20 * np.cos(np.deg2rad(station_lats))
                 + np.random.randn(n_stations) * 2.0)

print(f"Stations: {n_stations}, temp range: {station_temps.min():.1f} — {station_temps.max():.1f} °C")

# Regular target grid spanning the same latitude band as the stations
n_lat_grid, n_lon_grid = 8, 16
grid_lats = np.linspace(-60, 60, n_lat_grid)
grid_lons = np.linspace(-180, 180, n_lon_grid)


def nearest_station_interp(s_lats, s_lons, s_temps, g_lat, g_lon):
    """Nearest-station interpolation from sparse points to a regular grid.

    Every grid point receives the value of the station that is closest in
    an approximate planar distance measured in degrees: the longitude
    difference is wrapped across the antimeridian and scaled by the cosine
    of the grid point's latitude.

    Args:
        s_lats: 1-D station latitudes in degrees.
        s_lons: 1-D station longitudes in degrees.
        s_temps: 1-D station values, aligned with ``s_lats`` / ``s_lons``.
        g_lat: 1-D grid latitudes in degrees.
        g_lon: 1-D grid longitudes in degrees.

    Returns:
        Array of shape ``(len(g_lat), len(g_lon))`` holding, for each grid
        point, the value of its nearest station.
    """
    lat_mesh, lon_mesh = jnp.meshgrid(g_lat, g_lon, indexing='ij')

    # Offsets from every grid point to every station: (n_lat, n_lon, n_stations)
    dlat = lat_mesh[:, :, None] - s_lats[None, None, :]
    dlon = lon_mesh[:, :, None] - s_lons[None, None, :]
    # Wrap into [-180, 180) so that a station just across the antimeridian
    # (e.g. 179 vs -180) counts as 1 degree away rather than 359.
    dlon = (dlon + 180.0) % 360.0 - 180.0
    # Meridians converge toward the poles: shrink longitude offsets accordingly
    dlon = dlon * jnp.cos(jnp.deg2rad(lat_mesh[:, :, None]))

    dist = jnp.sqrt(dlat**2 + dlon**2)  # (n_lat, n_lon, n_stations)
    nearest_idx = jnp.argmin(dist, axis=-1)  # (n_lat, n_lon)
    return s_temps[nearest_idx]


# Fill the regular grid with the value of the nearest station
temp_grid = nearest_station_interp(
    station_lats, station_lons, station_temps,
    grid_lats, grid_lons
)

# Wrap the gridded result with axes that describe the target grid
grid_lat_axis = cx.LabeledAxis('latitude', grid_lats)
grid_lon_axis = cx.LabeledAxis('longitude', grid_lons)
temp_interpolated = cx.field(temp_grid, grid_lat_axis, grid_lon_axis)
# shape: (8, 16) | dims: ('latitude','longitude') | units: °C  (nearest-station interpolation)

print(f"Interpolated grid: {temp_interpolated.dims}, shape: {temp_interpolated.shape}")
# Every grid value is copied from a station, so this range cannot exceed the station range
print(f"Temperature range: {float(temp_grid.min()):.1f} — {float(temp_grid.max()):.1f} °C")
Stations: 20, temp range: 21.4 — 36.0 °C
Interpolated grid: ('latitude', 'longitude'), shape: (8, 16)
Temperature range: 21.4 — 36.0 °C

Key Patterns Summary

  1. Coordinate-based selection — use axis.ticks + argmin to find nearest index
  2. Regional slicing — use np.searchsorted on coordinate values to find slice bounds
  3. Regridding — interpolate .data then re-wrap with new coordinate axes
  4. Sparse-to-grid — nearest-neighbor or GP regression, then wrap result

For production interpolation, consider interpax (differentiable, JIT-compatible) or GPJax for probabilistic interpolation.