Foundations — Field, axes, and coordinate-aware operations
The basic unit in coordax is a Field: a thin, JAX-native wrapper that pairs
a jax.Array with a tuple of axes. Each named axis is either a
LabeledAxis (numeric ticks, e.g. latitude values) or a SizedAxis (a name
with a size but no ticks, e.g. RGB channels). This design is explicitly modeled
after xarray.DataArray (Hoyer & Hamman, 2017), but with one constraint
added and one capability gained: labels must be numeric (so everything
stays inside JAX’s typed array world), and the whole object is a JAX pytree,
so jit, vmap, grad, and scan work without custom registrations.
Data model
A Field is the pair (data, axes): a jax.Array together with a tuple of axes, where each axis carries a name and, optionally, a coordinate vector of ticks. Positional axes (no name) are also allowed and behave like raw NumPy dimensions, which is useful for batch / channel dims that shouldn’t participate in coordinate-aware dispatch.
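The pairing can be pictured with a small stand-in class. This is a sketch in plain NumPy, not coordax's implementation: `ToyAxis`, `ToyField`, and their attributes are hypothetical names used only to illustrate the data model, under the assumption that there is one axis per array dimension.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np


@dataclass(frozen=True)
class ToyAxis:
    """Hypothetical axis: an optional name, a size, and optional numeric ticks."""
    name: Optional[str]                  # None marks a positional axis
    size: int
    ticks: Optional[np.ndarray] = None   # None means "sized, but unlabeled"

    def __post_init__(self):
        if self.ticks is not None and len(self.ticks) != self.size:
            raise ValueError(f"axis {self.name!r}: {len(self.ticks)} ticks for size {self.size}")


@dataclass(frozen=True)
class ToyField:
    """Hypothetical field: an array plus one axis per array dimension."""
    data: np.ndarray
    axes: Tuple[ToyAxis, ...]

    def __post_init__(self):
        shape = tuple(ax.size for ax in self.axes)
        if shape != self.data.shape:
            raise ValueError(f"axes describe shape {shape}, data has {self.data.shape}")

    @property
    def dims(self):
        return tuple(ax.name for ax in self.axes)


lat = ToyAxis("lat", 3, np.array([-45.0, 0.0, 45.0]))   # labeled: has ticks
rgb = ToyAxis("channel", 3)                              # sized: name only
batch = ToyAxis(None, 2)                                 # positional: no name
field = ToyField(np.zeros((2, 3, 3)), (batch, lat, rgb))
print(field.dims)  # (None, 'lat', 'channel')
```

The shape check in `ToyField.__post_init__` is the point of the sketch: once the axes travel with the array, a mismatch between metadata and data can be rejected at construction time.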
Why named axes at all?
Two reasons, one practical and one mathematical.
- Practical: it turns “axis 0” bugs into compile-time errors. A velocity field on a grid can be reduced along 'lat' by name; no more counting axes after a vmap.
- Mathematical: many numerical operators (derivatives, reductions, broadcasts) are naturally defined on a coordinate, not an index. A centered difference needs the ticks; a mean over latitude needs the metric. Carrying the axis around means the operator can dispatch on it.
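To make the mathematical point concrete, here is a NumPy sketch (not coordax code) of a derivative that uses the ticks and a latitude mean that uses an area weight. The `dims` tuple, the grid values, and the choice of cos(lat) as the weight on a sphere are assumptions made for illustration. The axis index is looked up by name rather than counted by hand.

```python
import numpy as np

# Illustrative grid: names and ticks are assumptions for this sketch.
dims = ("time", "lat")
time = np.array([0.0, 1.0, 2.0])
lat = np.array([-60.0, -30.0, 0.0, 30.0, 60.0])   # degrees, the 'lat' ticks

# A field that varies linearly with latitude: f(t, lat) = 2 * lat + t
data = 2.0 * lat[None, :] + time[:, None]

lat_axis = dims.index("lat")   # look the axis up by name, not by position

# Finite-difference derivative along 'lat': np.gradient takes the ticks,
# so the spacing comes from the coordinate, not from the index.
dfdlat = np.gradient(data, lat, axis=lat_axis)

# Area-weighted mean over latitude, using cos(lat) as the weight.
weights = np.cos(np.deg2rad(lat))
lat_mean = np.average(data, axis=lat_axis, weights=weights)

print(dfdlat.shape, lat_mean.shape)  # (3, 5) (3,)
```

Both operators need information that a bare array does not carry: the tick values for the derivative and a latitude-dependent weight for the mean.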
Broadcasting rules
Binary ops follow a simple rule: axes are
matched by name, not by position. Unnamed (positional) axes follow the
usual NumPy rules. This gives xarray-style broadcasting
(Hoyer & Hamman, 2017) while keeping the fast-path tensor semantics that
JAX relies on.
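The name-matching rule can be sketched in a few lines of NumPy. This is an illustration of the idea, not coordax's implementation: `align_by_name` and `add_by_name` are hypothetical helpers, dimension names are passed as plain tuples, and only named axes are handled (positional axes are left out for brevity).

```python
import numpy as np


def align_by_name(data, dims, target_dims):
    """Transpose and expand `data` so its axes follow `target_dims`."""
    # Move the existing axes into the order they have in target_dims.
    present = [d for d in target_dims if d in dims]
    data = np.transpose(data, [dims.index(d) for d in present])
    # Insert length-1 axes for the names this operand lacks.
    index = tuple(slice(None) if d in dims else np.newaxis for d in target_dims)
    return data[index]


def add_by_name(a, a_dims, b, b_dims):
    """Add two arrays, matching axes by name rather than by position."""
    out_dims = tuple(a_dims) + tuple(d for d in b_dims if d not in a_dims)
    result = align_by_name(a, a_dims, out_dims) + align_by_name(b, b_dims, out_dims)
    return result, out_dims


a = np.arange(6.0).reshape(2, 3)        # dims ('lat', 'lon')
b = np.array([10.0, 20.0, 30.0])        # dims ('lon',)
c = np.array([100.0, 200.0])            # dims ('lat',)

ab, ab_dims = add_by_name(a, ("lat", "lon"), b, ("lon",))
# Plain `a + c` fails under positional rules: shapes (2, 3) and (2,) do not broadcast.
ac, ac_dims = add_by_name(a, ("lat", "lon"), c, ("lat",))
```

Because operands are aligned by name before the arithmetic, the order in which each array happens to store its axes stops mattering.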
Numerical considerations
- Tick dtype. LabeledAxis ticks must be floating; monotone integer indices should use SizedAxis instead. This trips people coming from xarray, which happily labels with strings.
- Alignment cost. Binary ops between fields with different tick vectors do not interpolate; they raise, so mismatched grids fail loudly instead of silently broadcasting zeros. Reindex (sel/isel) or use reindex_like before combining.
- Positional vs named mixing. A Field with a positional axis behaves like NumPy along that axis: coordinate-aware ops (reductions by name, sel) simply skip it. This is the intended “escape hatch” for batch dimensions.
- JIT. Axis names are static metadata (compile-time). Axis sizes are static too, since JAX fixes array shapes at trace time; only the array values are traced (runtime). Don’t put Python-level axis manipulation inside a jit-compiled function if the axis names depend on data.
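The fail-loudly alignment policy can be illustrated with a small NumPy check. This is a sketch of the behavior described above, not coordax's code: `check_same_ticks` is a hypothetical helper, and the reconciliation step uses plain boolean indexing as a stand-in for a label-based selection.

```python
import numpy as np


def check_same_ticks(name, ticks_a, ticks_b):
    """Raise if two axes share a name but not identical ticks."""
    if not np.array_equal(ticks_a, ticks_b):
        raise ValueError(f"axis {name!r} has mismatched ticks; reindex before combining")


coarse = np.array([-60.0, 0.0, 60.0])
fine = np.array([-60.0, -30.0, 0.0, 30.0, 60.0])

check_same_ticks("lat", coarse, coarse.copy())   # identical ticks: passes silently

message = None
try:
    check_same_ticks("lat", coarse, fine)        # different grids: raises
except ValueError as err:
    message = str(err)

# Reconcile the grids first: keep the fine-grid entries whose ticks also
# appear on the coarse grid, then the check passes.
keep = np.isin(fine, coarse)
check_same_ticks("lat", coarse, fine[keep])
```

The comparison is exact, so ticks that differ only by floating-point rounding would also be rejected; whether to tolerate that is a design choice this sketch does not address.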
Notebooks
- 01_create_datasets: building Field objects for RGB images, time-series, and spatio-temporal lat-lon data. Covers cx.field(), LabeledAxis, SizedAxis, and the wrap/untag round-trip.
- 02_ops_unary_binary: how +, *, jnp.where, and the rest of the NumPy API lift onto Field; the name-matching broadcasting rule in practice.
- 03_ops_coordinates: positional slicing (isel), label slicing (sel), reindexing, and CartesianProduct for stacking independent axes into one.
- 04_reductions: sum, mean, max, std dispatched by axis name; the pattern that enables zonal means and time averages without index arithmetic.
References
- Hoyer, S., & Hamman, J. (2017). xarray: N-D Labeled Arrays and Datasets in Python. Journal of Open Research Software, 5(1). 10.5334/jors.148