Foundations — Field, axes, and coordinate-aware operations

The basic unit in coordax is a Field: a thin, JAX-native wrapper that pairs a jax.Array with a tuple of named axes. Each axis is either a LabeledAxis (numeric ticks, e.g. latitude values) or a SizedAxis (a name with a size but no ticks, e.g. RGB channels). The design is explicitly modeled after xarray.DataArray (Hoyer & Hamman, 2017), with two differences: labels must be numeric, so everything stays inside JAX's typed array world, and the whole object is a JAX pytree, so jit, vmap, grad, and scan work without custom registrations.

Data model

A Field is the pair

$$
\mathbf{F} = (\mathbf{X},\ \{a_1, a_2, \ldots, a_n\}),\qquad \mathbf{X} \in \mathbb{R}^{N_1 \times N_2 \times \cdots \times N_n},
$$

where each axis $a_i$ carries a name and, optionally, a coordinate vector $\mathbf{c}_i \in \mathbb{R}^{N_i}$. Positional axes (no name) are also allowed and behave like raw NumPy dimensions, which is useful for batch / channel dims that shouldn't participate in coordinate-aware dispatch.
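To make the data model concrete, here is a minimal sketch of an array-plus-axes container registered as a JAX pytree. It is not coordax's implementation: `ToyField`, its `dims` and `coords` attributes, and the `scale` helper are hypothetical names chosen for illustration. The only library calls used are `jax.tree_util.register_pytree_node_class` and `jax.jit`.

```python
import dataclasses

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True, eq=False)
class ToyField:
    """Hypothetical stand-in for a Field: an array plus one name per dimension."""

    data: jax.Array
    dims: tuple   # one entry per dimension; None marks a positional axis
    coords: dict  # axis name -> 1-D array of numeric ticks (optional per axis)

    def tree_flatten(self):
        # Arrays become pytree leaves; axis names are static metadata.
        names = tuple(sorted(self.coords))
        children = (self.data, tuple(self.coords[n] for n in names))
        aux = (self.dims, names)
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        dims, names = aux
        data, ticks = children
        return cls(data, dims, dict(zip(names, ticks)))


@jax.jit
def scale(field, factor):
    # The container passes through jit because it is a registered pytree.
    return ToyField(field.data * factor, field.dims, field.coords)


lat = jnp.linspace(-90.0, 90.0, 4)
field = ToyField(
    jnp.ones((4, 3)),
    dims=("latitude", "channel"),  # "channel" has a size but no ticks
    coords={"latitude": lat},
)
scaled = scale(field, 2.0)
```

Because the axis names travel as static auxiliary data and only the arrays are leaves, transformations such as `jit` trace the numeric content while the labels stay ordinary Python values. Numeric-only ticks are what allow the coordinate vectors to be leaves at all.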

Why named axes at all?

Two reasons, one practical and one mathematical.

Broadcasting rules

Binary ops $\mathbf{F}_1 \odot \mathbf{F}_2$ follow a simple rule: axes are matched by name, not by position. Unnamed (positional) axes follow the usual NumPy rules. This gives xarray-style broadcasting (Hoyer & Hamman, 2017) while keeping the fast-path tensor semantics that JAX relies on.
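The name-matching rule can be sketched in plain JAX. The helpers `align_to` and `add_by_name` below are hypothetical and not part of coordax. They handle named axes only, and the output axis order (the first operand's names, then any names unique to the second) is an assumption made for this example.

```python
import jax.numpy as jnp


def align_to(data, dims, out_dims):
    """Transpose and reshape `data` so its axes follow the order of `out_dims`."""
    # Reorder the axes this operand has into the order used by out_dims.
    present = [d for d in out_dims if d in dims]
    data = jnp.transpose(data, [dims.index(d) for d in present])
    # Insert length-1 axes for names this operand lacks.
    sizes = dict(zip(present, data.shape))
    return jnp.reshape(data, [sizes.get(d, 1) for d in out_dims])


def add_by_name(x, x_dims, y, y_dims):
    """Add two arrays, matching axes by name rather than by position."""
    out_dims = tuple(x_dims) + tuple(d for d in y_dims if d not in x_dims)
    out = align_to(x, x_dims, out_dims) + align_to(y, y_dims, out_dims)
    return out, out_dims


x = jnp.arange(6.0).reshape(2, 3)  # dims: ("lat", "lon")
y = x.T                            # same values, dims: ("lon", "lat")
z = jnp.ones(4)                    # dims: ("time",)

# Axis order differs, but the names line the operands up correctly.
same, same_dims = add_by_name(x, ("lat", "lon"), y, ("lon", "lat"))

# A name present in only one operand becomes a new output axis.
outer, outer_dims = add_by_name(x, ("lat", "lon"), z, ("time",))
```

Here `same` equals `2 * x` even though `y` is stored transposed, and `outer` has shape `(2, 3, 4)` with dims `("lat", "lon", "time")`. Positional NumPy broadcasting would have rejected both pairs of shapes.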

Numerical considerations

Notebooks

References
  1. Hoyer, S., & Hamman, J. (2017). xarray: N-D Labeled Arrays and Datasets in Python. Journal of Open Research Software, 5(1). doi:10.5334/jors.148