A differentiable JAX/Equinox RTM — design and staged roadmap

UNEP
IMEO
MARS

RTM 4 JAX — a differentiable JAX/Equinox RTM for methane and multi-gas remote sensing

Bottom line up front. Build the RTM as a typed, Equinox-modular pipeline (atmosphere → spectroscopy → optical_properties → solver → instrument → loss) that exposes per-layer {τ, ω, phase moments, surface kernels} as the canonical differentiation interface — the same interface VLIDORT linearises by hand, but obtained for free from jax.jacrev/jax.vjp. Stage delivery in five versions that each ship a working retrieval: v0 Beer–Lambert SWIR matched-filter target generator, v1 instrument-aware clear-sky OE retrieval, v2 in-JAX line-by-line (PreMODIT-style), v3 single-scattering with differentiable Mie, v4 multi-stream + doubling-adding with implicit differentiation through linear solves via Lineax. The closest single existing reference is ExoJAX (Kawahara et al. 2022/2025), whose db→opa→art→sop factoring transfers almost wholesale to Earth remote sensing; the closest non-differentiable architectural reference is ARTS (agenda + workspace) and VLIDORT (analytical K-matrix). The user-supplied paper (Korkin et al. 2022, Comp. Phys. Comm. 271, 108198, “A practical guide to writing a radiative transfer code”) is a pedagogical Gauss–Seidel scalar scattering tutorial — not differentiable, not LBL — but contributes a clean “make-it-right-then-fast” modular skeleton, the Fourier-azimuth decomposition, and the analytic single-scattering + correction trick.


1. The reference paper, and how to read it

The DOI fragment S0010465521003106 resolves to Korkin, Sayer, Ibrahim & Lyapustin (2022), “A practical guide to writing a radiative transfer code”, Computer Physics Communications 271, 108198 (USRA/NASA GSFC; MIT-licensed code at github.com/korkins/gsit). The companion gsit is a ~268-line Python+Numba implementation of a scalar, monochromatic, plane-parallel, Lambertian-surface, multi-stream Gauss–Seidel solver. It is explicitly not a differentiable, line-by-line, or HITRAN-aware code, and the authors view spectroscopy and Mie as “plug-ins” outside the RT solver. That framing is itself the most important lesson.

Transferable design takeaways:

The Korkin paper is best treated as the scaffolding lesson for v0–v1. For algorithmic depth in spectroscopy and scattering you will lean on ExoJAX, py4CAtS, VLIDORT, and SHDOM instead.


2. Landscape: what exists, what to inherit, what to displace

Classical Fortran codes — the algorithmic ancestors

The most consequential of the surveyed classical codes is VLIDORT (Spurr, RT Solutions). It is the gold standard for retrieval-grade RT because it ships analytical Jacobians (“K-matrix”) with respect to per-layer extinction Δₙ, single-scattering albedo ωₙ, phase-function moments Bₙₗ, and surface kernels, which is exactly what jax.jacrev produces for free in a JAX rewrite. VLIDORT’s hand-derived eigenvector-perturbation machinery (and the small-denominator Taylor expansions used when μₖ → μ₀⁻¹ in the streaming multipliers, with ε ≈ 10⁻³ switchover) marks the single biggest implementation pitfall to anticipate in JAX: the derivative rule for jax.scipy.linalg.eigh contains 1/(λᵢ − λⱼ) factors and blows up at repeated or near-repeated eigenvalues, and the streaming multipliers become 0/0 under naive AD. Reproduce Spurr’s Taylor-series branch with a guarded jnp.where that sanitises the denominator of the untaken branch (a bare jnp.where, or a jax.lax.cond that is lowered to a select under vmap, still lets NaN reach the gradient), or as a jax.custom_jvp.
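The small-denominator guard can be sketched in a few lines. This is an illustrative stand-in, not VLIDORT's actual multiplier: the function names `stable_ratio` and `multiplier`, the generic form (exp(−aτ) − exp(−bτ))/(b − a), and the second-order Taylor truncation are all assumptions chosen for the example; only the ε ≈ 10⁻³ switchover comes from the text.

```python
import jax
import jax.numpy as jnp

def stable_ratio(x, eps=1e-3):
    """g(x) = (1 - exp(-x)) / x, with a Taylor branch for |x| < eps."""
    small = jnp.abs(x) < eps
    x_safe = jnp.where(small, 1.0, x)        # keep the untaken branch finite
    exact = -jnp.expm1(-x_safe) / x_safe
    taylor = 1.0 - x / 2.0 + x**2 / 6.0      # truncation error O(x^3)
    return jnp.where(small, taylor, exact)

def multiplier(a, b, tau):
    """(exp(-a*tau) - exp(-b*tau)) / (b - a), finite as b -> a."""
    return tau * jnp.exp(-a * tau) * stable_ratio((b - a) * tau)

value = multiplier(1.0, 1.0, 0.5)                         # degenerate case b == a
grad_b = jax.grad(multiplier, argnums=1)(1.0, 1.0, 0.5)   # stays finite
```

The inner `jnp.where` on `x_safe` is what keeps the gradient clean: without it the untaken exact branch evaluates 0/0 and the NaN propagates through the backward pass even though the forward value looks fine.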

DISORT (Stamnes et al.) contributes the canonical 2N→N eigenproblem reduction (the reduced matrix is not symmetric as it stands, so jax.scipy.linalg.eigh applies only after symmetrisation), the exponential-scaling transformation that prevents exp(+τ) overflow in boundary value problems, and delta-M + Nakajima–Tanaka (TMS) single-scatter correction for forward-peaked phase functions. These must all survive autodiff cleanly.
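Of the three DISORT ingredients, delta-M scaling is the simplest to show surviving autodiff, because it is pure algebra. The sketch below assumes phase moments normalised so that the zeroth moment equals 1 (Legendre coefficients without the 2ℓ+1 factor); the function name `delta_m_scale` and the Henyey–Greenstein test case are illustrative choices.

```python
import jax
import jax.numpy as jnp

def delta_m_scale(tau, ssa, moments, n_keep):
    """Delta-M scaling; moments[l] are Legendre coefficients with moments[0] == 1."""
    f = moments[n_keep]                           # truncated forward-peak fraction
    tau_s = (1.0 - ssa * f) * tau
    ssa_s = (1.0 - f) * ssa / (1.0 - ssa * f)
    mom_s = (moments[:n_keep] - f) / (1.0 - f)
    return tau_s, ssa_s, mom_s

g = 0.8
hg_moments = g ** jnp.arange(8)                   # Henyey-Greenstein: chi_l = g**l
tau_s, ssa_s, mom_s = delta_m_scale(1.0, 0.9, hg_moments, n_keep=2)

# Derivative of the scaled optical depth with respect to single-scattering albedo
dtau_dssa = jax.grad(lambda w: delta_m_scale(1.0, w, hg_moments, 2)[0])(0.9)
```

With `n_keep=2` this reduces to the delta-Eddington case (f = g², scaled asymmetry g/(1+g)), and the absorption optical depth (1 − ω)τ is unchanged by the scaling, which makes a convenient regression check.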

ARTS (Eriksson, Buehler et al.) contributes the workspace + agenda architecture: a strongly-typed registry of physical quantities with user-rewireable function pipelines. This maps almost one-to-one onto an Equinox-module DAG with jaxtyping-annotated pytree workspaces. The right design pattern is: agendas = Equinox-module compositions; workspace methods = eqx.Module.__call__ returning typed pytree slices.

libRadtran (Mayer & Kylling) contributes the solver-pluggability pattern (DISORT, two-stream, polRadtran, MYSTIC MC, SHDOM all consume the same optical setup) and the reptran “representative wavelength” band parameterisation for broadband applications.

py4CAtS (Schreier, DLR) is the closest Python ancestor: a pipeline of typed dataclasses xs → ac → od → ri (cross section → absorption coefficient → optical depth → radiance) wrapping HITRAN/GEISA I/O and Humliček/Weideman Voigt. Copy this typing scheme directly; replace NumPy with jax.numpy and the dataclasses with eqx.Module.

RTTOV contributes the convention that forward, tangent linear (jvp), adjoint (vjp), and K (jacobian) belong together as four faces of the same model — JAX provides all four automatically. RTTOV’s regression-coefficient fast model is also the right reference for an eventual “RTTOV-mode” in which a small neural network or polynomial replaces explicit LBL during operational throughput.
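The four faces map directly onto JAX transforms. The toy forward model below (three state elements, five channels, made-up absorption weights) is only a placeholder for the RTM; the transforms themselves are standard JAX calls.

```python
import jax
import jax.numpy as jnp

def forward(x):
    """Toy forward model: 3-element state -> 5 channel transmittances."""
    k = jnp.linspace(0.1, 1.0, 15).reshape(5, 3)   # made-up absorption weights
    return jnp.exp(-k @ x)

x = jnp.array([0.5, 1.0, 2.0])
dx = jnp.array([1.0, 0.0, -1.0])     # perturbation in state space
dy = jnp.arange(1.0, 6.0)            # perturbation in radiance space

y, tl = jax.jvp(forward, (x,), (dx,))    # forward + tangent linear
_, adjoint = jax.vjp(forward, x)
(ad,) = adjoint(dy)                      # adjoint
K = jax.jacfwd(forward)(x)               # K-matrix, shape (5, 3)
```

The classical dot-product test used for hand-written tangent-linear/adjoint pairs, ⟨K dx, dy⟩ = ⟨dx, Kᵀ dy⟩, holds here by construction and is a cheap sanity check to keep in the regression suite.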

SHDOM (Evans) is the open 3-D RT reference; its Picard fixed-point iteration is a natural fit for jaxopt.FixedPointIteration / optimistix with implicit differentiation. 6S / OSOAA (successive orders) are conceptually clean — each scattering order is an explicit operator with an algebraic Jacobian — and useful as alternative engines, especially OSOAA’s Cox-Munk air–sea interface for ocean coupling.
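The Picard pattern with implicit differentiation can be illustrated without any solver library. The sketch below is a pure-JAX stand-in for what jaxopt or optimistix provide, not their API: `fixed_point`, `picard_step`, the 3×3 kernel `K`, and the fixed iteration count are all assumptions made for the example. The backward pass solves the adjoint fixed point instead of differentiating through the forward iterations.

```python
from functools import partial
import jax
import jax.numpy as jnp

def _iterate(step, params, x0, n_iter):
    x, _ = jax.lax.scan(lambda x, _: (step(params, x), None), x0, None, length=n_iter)
    return x

@partial(jax.custom_vjp, nondiff_argnums=(0, 3))
def fixed_point(step, params, x0, n_iter):
    return _iterate(step, params, x0, n_iter)

def _fwd(step, params, x0, n_iter):
    x_star = _iterate(step, params, x0, n_iter)
    return x_star, (params, x_star)          # no forward iterates are stored

def _bwd(step, n_iter, res, x_bar):
    params, x_star = res
    _, vjp_x = jax.vjp(lambda x: step(params, x), x_star)
    # Adjoint fixed point: w = x_bar + (d step / d x)^T w
    w = _iterate(lambda _, w: x_bar + vjp_x(w)[0], None, x_bar, n_iter)
    _, vjp_p = jax.vjp(lambda p: step(p, x_star), params)
    return vjp_p(w)[0], jnp.zeros_like(x_star)

fixed_point.defvjp(_fwd, _bwd)

def picard_step(params, J):
    """One scattering order: J <- omega * K @ J + source (toy source iteration)."""
    omega, K, source = params
    return omega * (K @ J) + source

K = jnp.array([[0.2, 0.1, 0.0], [0.1, 0.2, 0.1], [0.0, 0.1, 0.2]])
source = jnp.array([1.0, 0.5, 0.25])

def total_J(omega):
    return fixed_point(picard_step, (omega, K, source), jnp.zeros(3), 100).sum()

grad_implicit = jax.grad(total_J)(0.8)
```

Memory in the backward pass is independent of the number of iterations, which is the property that matters for 3-D grids; a production version would replace the fixed iteration count with a convergence test.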

SCIATRAN is the operational reference for UV–Vis–NIR–SWIR trace gas retrievals (GOME, SCIAMACHY, TROPOMI AMF LUTs); it already includes Raman/Ring inelastic scattering and pseudo-spherical geometry but is closed-source. MODTRAN is the proprietary band-model standard; do not re-implement it, but support correlated-k as an alternative substitution layer for the optical-depth provider.

JAX-native and differentiable efforts

ExoJAX (Kawahara et al., 2022 ApJS 258:31; 2025 ApJ, “ExoJAX2”) is the closest and most important reference. It already implements in pure JAX everything v0–v2 of this roadmap needs:

| Component | ExoJAX implementation | Lesson |
|---|---|---|
| Voigt profile | lpf.voigt / voigtone via Voigt–Hjerting hjert(x,a), pure jnp, no custom_vjp needed | FP64 required at line center; vmap over ν is the unit pattern |
| LBL many-line opacity | Three tiers: OpaDirect (full LPF), OpaModit (FFT/DIT on ESLOG grid), OpaPremodit (precomputed Line Basis Density, ≳10⁵–10⁸ lines on one GPU) | For the SWIR CH₄ bands near 1.65 and 2.3 µm (~10⁴–10⁵ HITRAN lines), PreMODIT-class precomputation is essential |
| Memory | Opart layer-wise jax.lax.scan makes memory O(1) in N_layer | Adopt for all RT solvers; critical for HMC retrievals |
| Line list I/O | Uses radis.api (vaex/HDF5) for HITRAN/HITEMP/ExoMol | Don’t reinvent: wrap radis or HAPI |
| RT | ArtEmisPure (intensity-based n-stream + flux-based 2-stream), ArtEmisScat (Toon flux-adding), ArtTransPure | Toon two-stream is the natural v4 starting point |
| Mie | OpaMie precomputed miegrid over log-normal PSDs | Better than recurrence-on-the-fly for retrievals |
| Stack | Pure functional JAX (NOT Equinox), NumPyro + JAXNS downstream | The clearest opportunity for the new project: re-frame in Equinox PyTrees for cleaner static/dynamic separation and filter_jit |

ExoJAX is built for exoplanet transmission/emission with high-T atmospheres; the Earth-RS use case differs in (i) lower-T line widths and tighter auto_trange around 200–320 K, (ii) operational latency requirements, (iii) downstream coupling to instrument noise and plume retrieval rather than HMC posteriors over T-profiles. The algorithmic core is reusable; the surrounding system is not.

Other JAX/differentiable RT efforts worth knowing:


3. Staged roadmap

Each version ships a runnable retrieval against real data, accumulates regression benchmarks, and adds one fidelity axis at a time. Never refactor and add features in the same release.

v0 — Clear-sky Beer–Lambert SWIR (4–6 weeks)

Concrete use case: AVIRIS-NG / EMIT 2210–2410 nm CH₄ matched-filter target spectrum generator, replacing the MODTRAN lookup currently used by MAG1C and EMIT operational. Deliver scene-specific unit absorption spectra t(SZA, VZA, elevation, column H₂O) that drop directly into the user’s existing matched-filter code, with end-to-end gradients of t w.r.t. atmospheric state.
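One common reading of the unit absorption spectrum is the derivative of log-radiance with respect to the column enhancement, evaluated at zero enhancement. The sketch below takes that reading under a Beer–Lambert model; `toa_radiance`, the made-up `sigma_col` values, and the toy units are assumptions, and the real v0 generator would use the full forward model instead.

```python
import jax
import jax.numpy as jnp

def toa_radiance(delta_xch4, sigma_col, mu0, mu, albedo=0.3, f0=1.0):
    """Beer-Lambert radiance for a column enhancement delta_xch4 (toy units)."""
    amf = 1.0 / mu0 + 1.0 / mu
    return f0 * mu0 / jnp.pi * albedo * jnp.exp(-amf * sigma_col * delta_xch4)

sigma_col = jnp.array([0.0, 0.02, 0.10, 0.05])   # made-up absorption per unit enhancement
mu0, mu = 0.8, 1.0

def log_radiance(delta_xch4, mu0_):
    return jnp.log(toa_radiance(delta_xch4, sigma_col, mu0_, mu))

# t = d ln L / d(enhancement) at zero enhancement, one value per channel
unit_absorption = jax.jacfwd(log_radiance, argnums=0)(0.0, mu0)

# Sensitivity of t to the solar geometry: a second derivative, also by autodiff
dt_dmu0 = jax.jacfwd(jax.jacfwd(log_radiance, argnums=0), argnums=1)(0.0, mu0)
```

For this toy model the result is analytic, t = −(1/μ₀ + 1/μ)·σ, so the example doubles as a unit test; the second call shows how gradients of t with respect to scene state fall out of the same function.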

Deliverables.

Minimal API surface. See §7 for the actual code sketch.

v1 — Geometry, surface, instrument, noise (4–6 weeks)

v2 — Differentiable line-by-line in JAX (8–12 weeks)

The most spectroscopically demanding stage. Implement three opacity engines behind a common OpacityProvider protocol:

  1. OpaDirect-style LPF: per-line Voigt profile via Voigt–Hjerting hjert(x,a) (port from ExoJAX lpf.voigt), vmap over lines and ν. For ν grids ≲10⁴ and lines ≲10⁴, this is fine; useful for high-resolution validation.

  2. MODIT-style DIT/FFT: lineshape-density matrix on ESLOG ν grid + FFT convolution. Adopt van den Bekerom–Pannier formulation; include Lorentzian-wing aliasing correction.

  3. PreMODIT-style LBD: precompute Line Basis Density over coarse (E_lower, γ_self/γ_air) grid; runtime cost becomes O(N_grid_cells · N_ν log N_ν), independent of N_lines.

Physics components, all in pure JAX:

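Validation: transmittance through AFGL US-Standard, Tropical, Sub-Arctic-Winter against py4CAtS and HAPI at 0.01 cm⁻¹, target relative error <0.1% in 2210–2410 nm, 1590–1690 nm (MethaneSAT band), 760 nm (O₂ A-band). Use FP64 throughout: FP32 carries a relative precision of about 10⁻⁷, which near 4300 cm⁻¹ is a wavenumber spacing of roughly 5×10⁻⁴ cm⁻¹, so line-centre offsets ν − ν₀ are truncated at a level that is not negligible against Doppler widths of order 10⁻² cm⁻¹, and the resulting truncation can alias at TROPOMI resolution.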
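The FP32 concern can be checked directly by asking NumPy for the gap between adjacent representable wavenumbers. The 4300 cm⁻¹ reference point is an illustrative choice inside the 2210–2410 nm window; `jax_enable_x64` is the standard JAX switch for double precision.

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)     # set before creating any arrays

nu0 = 4300.0                                   # cm^-1, inside the 2.3 um CH4 window
ulp32 = float(np.spacing(np.float32(nu0)))     # gap between adjacent FP32 values
ulp64 = float(np.spacing(np.float64(nu0)))     # gap between adjacent FP64 values

nu = jnp.linspace(4150.0, 4525.0, 5)           # float64 once x64 is enabled
print(f"FP32 spacing: {ulp32:.3e} cm^-1, FP64 spacing: {ulp64:.3e} cm^-1")
```

The FP32 spacing at this wavenumber is 2⁻¹¹ cm⁻¹, so any line-centre offset computed in single precision is quantised at that level before the Voigt profile is even evaluated.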

Gradient validation: finite-difference check of ∂(transmittance)/∂(VMR, T, p, line strength) on a coarse grid; this is the most important regression test in the entire roadmap.
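A minimal version of this regression test compares the autodiff Jacobian with central differences. The transmittance function, cross-section values, and step size below are illustrative assumptions; the pattern carries over unchanged to T, p, and line-strength derivatives.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)    # finite differences need FP64 headroom

def transmittance(vmr, sigma, n_air_dz, amf=2.0):
    """Beer-Lambert transmittance; vmr: (n_layer,), sigma: (n_layer, n_nu)."""
    tau = jnp.sum(sigma * (vmr * n_air_dz)[:, None], axis=0)
    return jnp.exp(-amf * tau)

def fd_jacobian(f, x, h=1e-6):
    """Central-difference Jacobian, one column per state element."""
    cols = [(f(x + h * e) - f(x - h * e)) / (2.0 * h) for e in jnp.eye(x.size)]
    return jnp.stack(cols, axis=1)

sigma = jnp.array([[0.1, 0.4, 0.2], [0.3, 0.1, 0.5]])   # made-up cross sections
n_air_dz = jnp.array([1.0, 0.5])                         # made-up layer columns
vmr = jnp.array([0.8, 1.2])

f = lambda v: transmittance(v, sigma, n_air_dz)
K_ad = jax.jacfwd(f)(vmr)                                # shape (n_nu, n_layer)
K_fd = fd_jacobian(f, vmr)
rel_err = jnp.max(jnp.abs(K_ad - K_fd) / jnp.abs(K_ad))
```

`jax.test_util.check_grads` offers a ready-made alternative; the explicit version is shown because it makes the tolerance and step size visible in the test itself.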

v3 — Single scattering + Rayleigh + aerosols (8–10 weeks)

Validation: Rayleigh + Lambertian against Coulson–Dave–Sekera analytic tables; Mie cross-sections against miepython; AERONET aerosol cases against 6SV.

v4 — Multiple scattering with implicit differentiation (12–16 weeks)

Add solvers behind a common RTSolver protocol, all consuming the same OpticalProperties pytree.

  1. Two-stream (Eddington / delta-Eddington / hemispheric mean / quadrature): closed-form 2×2 per-layer matrices with multi-layer adding (Toon et al. 1989). Implement as jax.lax.scan over layers; fully AD-safe.

  2. Toon flux-adding with delta-scaling: the ExoJAX ArtEmisScat/ArtReflectEmis scheme (Robinson & Crisp 2018). Differentiable end-to-end.

  3. Doubling–adding for plane-parallel atmospheres: doubling kernel R, T = combine(R,T,R,T) is a chain of small dense linear ops on (N_streams)² matrices, AD-friendly; the adding step composes layer (R,T) pairs.

  4. Discrete ordinates (DISORT-like) with N streams: layer eigendecomposition + block-tridiagonal boundary-value solve. This is where implicit differentiation matters. Three sub-decisions:

  5. Successive orders of scattering (SOS): each iteration is an explicit operator → jax.lax.scan; can be wrapped in a fixed-point solver with implicit differentiation (optimistix.fixed_point with the optimistix.FixedPointIteration solver, or jaxopt.FixedPointIteration).

  6. Delta-M scaling + TMS single-scatter correction as separate modules; both algebraically differentiable.
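The doubling step from the list above can be sketched as follows. The example assumes a homogeneous layer with identical reflection and transmission from both sides, and tests it on a conservative two-flux slab whose closed-form reflectance is x/(1+x); the names `double_layer`, `doubling`, and `slab_reflectance` are illustrative, and the matrices are 1×1 only to keep the check analytic.

```python
import jax
import jax.numpy as jnp

def double_layer(R, T):
    """Combine two identical homogeneous layers with reflection R and transmission T."""
    eye = jnp.eye(R.shape[0])
    down = jnp.linalg.solve(eye - R @ R, T)    # downward field at the interface
    return R + T @ R @ down, T @ down

def doubling(R0, T0, n_doublings):
    (R, T), _ = jax.lax.scan(lambda c, _: (double_layer(*c), None),
                             (R0, T0), None, length=n_doublings)
    return R, T

def slab_reflectance(b_tau, n_doublings=10):
    """Conservative two-flux slab, started from a thin layer of thickness b_tau / 2**n."""
    x0 = b_tau / 2.0**n_doublings
    R0 = x0 / (1.0 + x0) * jnp.ones((1, 1))
    T0 = 1.0 / (1.0 + x0) * jnp.ones((1, 1))
    R, T = doubling(R0, T0, n_doublings)
    return R[0, 0], T[0, 0]

R, T = slab_reflectance(1.0)
dR = jax.grad(lambda s: slab_reflectance(s)[0])(1.0)
```

For multi-stream problems R and T become (N_streams)² matrices and the same two functions apply; the linear solve inside `double_layer` is differentiated by JAX's built-in rule, so no custom derivative is needed at this stage.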

Differentiability strategy summary for v4: scan-based solvers (two-stream, SOS) need no special treatment; the discrete-ordinates block-tridiagonal solve uses Lineax-mediated implicit differentiation; eigendecompositions need a degeneracy-safe custom_jvp branch.
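What implicit differentiation through the boundary-value solve amounts to can be shown with a hand-written adjoint rule. This is a pure-JAX illustration of the idea, not the Lineax API: `implicit_solve` is a hypothetical name, the dense 4×4 tridiagonal matrix stands in for the block-tridiagonal system, and a real implementation would call a banded solver in both passes.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def implicit_solve(A, b):
    return jnp.linalg.solve(A, b)

def _fwd(A, b):
    x = jnp.linalg.solve(A, b)
    return x, (A, x)

def _bwd(res, x_bar):
    A, x = res
    b_bar = jnp.linalg.solve(A.T, x_bar)   # one adjoint solve with the transposed system
    A_bar = -jnp.outer(b_bar, x)           # from dx = A^{-1} (db - dA x)
    return A_bar, b_bar

implicit_solve.defvjp(_fwd, _bwd)

A = jnp.array([[4.0, 1.0, 0.0, 0.0],
               [1.0, 4.0, 1.0, 0.0],
               [0.0, 1.0, 4.0, 1.0],
               [0.0, 0.0, 1.0, 4.0]])
b = jnp.array([1.0, 2.0, 3.0, 4.0])

loss = lambda A_, b_: jnp.sum(implicit_solve(A_, b_) ** 2)
gA, gb = jax.grad(loss, argnums=(0, 1))(A, b)
```

The backward pass costs one extra solve regardless of how the forward solve was carried out, so the factorisation internals never enter the differentiated graph.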

v5+ — Polarisation, ocean–atmosphere, 3-D, emulators

Vector RT (Stokes I, Q, U, V) by promoting scalars to length-4 pytrees and replacing scalar phase functions with 4×4 phase matrices; OSOAA-style Cox-Munk air–sea interface as a boundary operator; 3-D RT via either an SHDOM-like spherical-harmonics + discrete-ordinates iteration (Picard fixed point + Lineax) or differentiable Monte Carlo via Mitsuba 3 / a JAX MC port (use jax.random keys and the replay trick); emulator integration via gpyroX — train GP/NN surrogates on the v4 forward model for operational throughput, expose them behind the same RTSolver protocol so retrievals can pick LBL vs surrogate at runtime.


4. Equinox-specific design patterns

Module structure vs pure functions

Use eqx.Module for objects that carry calibration state (instrument SRF, noise covariances, surface BRDF kernel parameters, precomputed cross-section tables, PreMODIT LBDs, line lists). Use plain functions for stateless transforms (Voigt evaluation given line parameters, Chapman factor, delta-M scaling). The rule is: if you would want to vmap the same operation across many instances, it’s a module; if across many inputs with one set of parameters, it’s a function.

Pytree design

AtmosphericState (eqx.Module)
  ├─ pressure: Float[Array, "n_layer"]
  ├─ temperature: Float[Array, "n_layer"]
  ├─ vmr: dict[str, Float[Array, "n_layer"]]   # 'CH4', 'CO2', 'H2O', ...
  └─ altitude: Float[Array, "n_layer+1"]       # level grid

Geometry (eqx.Module)
  ├─ sza, vza, raa: Float[Array, "n_pix"]
  └─ surface_elevation: Float[Array, "n_pix"]

Surface (eqx.Module, abstract)
  └─ subclasses: Lambertian, RossLi, RPV, Hapke, CoxMunk

Instrument (eqx.Module)
  ├─ srf: Callable                              # static
  ├─ wavelength_grid: Float[Array, "n_band"]   # fixed array (not differentiated)
  ├─ spectral_shift: Float[Array, ""]           # learnable
  ├─ fwhm: Float[Array, ""]                     # learnable
  └─ noise: NoiseModel                          # eqx submodule

OpticalProperties (eqx.Module)         # the canonical differentiable interface
  ├─ tau: Float[Array, "n_layer n_nu"]
  ├─ ssa: Float[Array, "n_layer n_nu"]
  └─ phase_moments: Float[Array, "n_layer n_nu n_mom"]

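Annotate all numeric fields as Float[Array, ...] (use jaxtyping). Reserve eqx.field(static=True) for hashable, non-array metadata (gas names, stream counts, solver flags): static fields become part of the JIT cache key, so JAX arrays must not be marked static. Fixed arrays (line-list arrays after loading, SRF samples, the ν grid) stay ordinary array fields; they do not trigger retracing as long as shape and dtype are unchanged, and they are kept out of gradients by differentiating only the selected leaves or by wrapping them in jax.lax.stop_gradient.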

Filter-transforms

eqx.filter_jit and eqx.filter_grad are the default everywhere. filter_jit traces every array leaf and treats all other leaves as static; filter_grad differentiates the floating-point array leaves of its first argument. Batched retrievals over EMIT swaths use:

@eqx.filter_jit
def retrieve_pixel(state, instrument, surface, geom, y_obs):
    ...

retrieve_swath = eqx.filter_vmap(retrieve_pixel, in_axes=(None, None, None, 0, 0))

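For Jacobians, pick the mode by shape: eqx.filter_jacfwd costs one forward pass per input element and eqx.filter_jacrev one reverse pass per output element. Typical OE retrievals have a state vector of tens to a few hundred entries, usually no larger than the number of spectral channels, so filter_jacfwd is usually the cheaper choice for the K-matrix; use filter_jacrev (or filter_grad) when the output is smaller than the input, as for a scalar cost function. Always apply vmap over pixel batches outside the Jacobian transform, not inside.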
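The batching rule looks like this in plain JAX (the Equinox filter_* variants compose the same way). The toy per-pixel model, its two-element state, and the made-up weights are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

def forward(state, mu0):
    """Toy per-pixel forward model: 2-element state -> 6 channels."""
    k = jnp.linspace(0.05, 0.6, 12).reshape(6, 2)   # made-up absorption weights
    return mu0 * jnp.exp(-(k @ state) / mu0)

states = jnp.array([[1.0, 0.5], [0.8, 0.7], [1.2, 0.4]])   # (n_pix, n_state)
mu0s = jnp.array([0.9, 0.8, 0.7])                          # (n_pix,)

per_pixel_K = jax.jacfwd(forward, argnums=0)               # (6, 2) for one pixel
K_swath = jax.jit(jax.vmap(per_pixel_K))(states, mu0s)     # (n_pix, 6, 2)

# Reverse mode gives the same numbers but needs six passes per pixel instead of two
K_rev = jax.vmap(jax.jacrev(forward, argnums=0))(states, mu0s)
```

Keeping `vmap` outermost yields one block-diagonal K per pixel; placing the Jacobian transform outside the batch would instead build the full cross-pixel Jacobian, almost all of which is zero.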

Protocol-based solver swapping

Define abstract base classes / typing.Protocol for OpacityProvider, RTSolver, BRDF, NoiseModel. A retrieval driver then accepts any concrete combination at construction time, and eqx.filter_jit happily compiles distinct combinations. Concrete subclasses:

Integration with the user’s stack


5. Concrete retrieval applications

Hyperspectral multi-gas (EMIT, PRISMA, EnMAP, AVIRIS-NG, Tanager-1)

The v0 forward model already replaces the MODTRAN/6S lookup that generates the matched-filter target spectrum t in MAG1C and Foote et al. 2021’s generalised MF. Because t now has analytic gradients in atmospheric state, you can implement an albedo- and water-corrected MF that linearises t around per-pixel (or per-cluster) state estimates — a smooth interpolation between MAG1C-class MFs and full IMAP-DOAS. v1+v2 enables full Rodgers optimal estimation (state = layer VMRs of CH₄/CO₂/H₂O/N₂O/CO, T-shift, surface polynomial, spectral shift) with K-matrix Jacobians for free. Benchmark against Thorpe et al. 2023 (Sci. Adv.) on EMIT plume scenes, Thorpe et al. 2017 IMAP-DOAS on AVIRIS-NG, and MAG1C on synthetic plumes.

TROPOMI-style operational (Sentinel-5P)

v3–v4 is required: TROPOMI retrievals need aerosol/cirrus scattering and the O₂ A-band 760 nm + SWIR 2.3 µm joint inversion (RemoTeC/SRON; Lorente et al. 2021 AMT). Build the SCIATRAN/VLIDORT analogue: pseudo-spherical multi-stream discrete ordinates + Raman/Ring (v5), with retrieval state {profile XCH₄, aerosol height/AOT/effective size, surface polynomial}. DOAS-style retrievals fall out as a special case: WFM-DOAS = linearised differential cross-section + low-order polynomial baseline; IMAP-DOAS = full OE in a narrow SWIR window. Use the same Equinox forward model and pick the loss function and regularisation accordingly.

Multispectral methane (Sentinel-2, Landsat)

v0 suffices: Beer–Lambert with effective AMF (1/μ₀ + 1/μ) and Sentinel-2 B11/B12 SRFs. Implement Varon et al. 2021 (AMT 14, 2771) SBMP/MBSP/MBMP exactly; the advantage of doing this in your differentiable stack is that the same forward model trains a Sentinel-2 NN detector (MethaNet/STARCOP-style) end-to-end with physical regularisation. Connect to plumax for synthetic plume injection.

Cloud screening / detection

Downstream classifier on top of v3 outputs: cloud-fraction retrieval from the v3 single-scattering forward model, with detection thresholds calibrated against MODIS/VIIRS cloud masks. Differentiable forward enables joint cloud-fraction + trace-gas retrieval rather than sequential masking.

Plume-to-emission coupling (the MARS payoff)

The end-to-end gradient ∂y/∂e (TOA radiance with respect to emission rate) is the unique capability that motivates the entire stack. Construction: plumax(e, u, x_source) → ΔVMR(x); RTM(state + ΔVMR, geometry, instrument) → y_pred; loss(y_pred, y_obs). One call to eqx.filter_grad gives ∂loss/∂{e, u, x_source, atmospheric_nuisance, instrument_nuisance}. This is the basis for maximum-likelihood emission retrieval that does not pass through the IME/CSF heuristics of Varon et al. 2018 — those become Gauss–Newton initialisations.
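The chain can be mocked up end to end with toy components. Nothing below is the plumax API: `toy_plume` is a hypothetical stand-in with an arbitrary decay law, `toy_rtm` is a single-band Beer–Lambert factor, and all constants are made up. The point is only that one gradient call reaches the emission rate through both models.

```python
import jax
import jax.numpy as jnp

x = jnp.linspace(10.0, 200.0, 32)        # downwind distance of each pixel, m

def toy_plume(e, u, sigma_y=20.0):
    """Hypothetical stand-in for plumax: column enhancement scaling as e / u."""
    return e / (u * sigma_y * jnp.sqrt(2.0 * jnp.pi)) * jnp.exp(-x / 500.0)

def toy_rtm(delta_col, k_abs=0.5, amf=2.2):
    """Beer-Lambert band radiance relative to the background."""
    return jnp.exp(-amf * k_abs * delta_col)

def loss(params, y_obs):
    e, u = params
    return jnp.mean((toy_rtm(toy_plume(e, u)) - y_obs) ** 2)

y_obs = toy_rtm(toy_plume(2.0, 3.0))     # synthetic truth: e = 2, u = 3

g_true = jax.grad(loss)((2.0, 3.0), y_obs)   # gradient vanishes at the truth
g_low = jax.grad(loss)((1.0, 3.0), y_obs)    # underestimated emission rate
```

In this toy only the ratio e/u is identifiable, which mirrors the well-known wind-speed degeneracy of emission retrievals; in the full stack the gradient with respect to u carries additional information through the plume shape.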


6. Validation strategy

Per-stage validation matrix:

| Stage | Reference | Benchmark | Target accuracy |
|---|---|---|---|
| v0 | py4CAtS, HAPI | AFGL US Standard transmittance, 2210–2410 nm | <0.1% relative |
| v0 | MAG1C / Foote 2020 | Synthetic plume RMSE on Permian/Turkmenistan | <2% of MAG1C |
| v1 | EMIT operational (Thorpe 2023) | Real plume XCH₄ retrievals | <5 ppb median bias |
| v2 | py4CAtS, HAPI, LBLRTM | AFGL Tropical/MidLatSummer/SubArcticWinter at 0.01 cm⁻¹ across 700–4000 cm⁻¹ | <0.1% relative |
| v2 | HITRAN line-by-line | Single-line Voigt at edge cases (high γ_L/γ_D, near continuum) | <10⁻⁶ absolute |
| v3 | 6SV, Coulson tables | Rayleigh + Lambertian TOA radiance | <0.05% |
| v3 | miepython | Mie efficiencies and asymmetry parameter | <10⁻⁴ relative |
| v4 | DISORT, VLIDORT | AFGL + aerosol cases, K-matrix Jacobians | <0.5% radiance, <2% Jacobian |
| v4 | CIRC Case 1–7 (Oreopoulos 2012) | Clear-sky LW/SW fluxes | match LBLRTM within CIRC tolerance |
| v5 | I3RC, Mitsuba 3 / Eradiate | 3-D cumulus cases | tolerable MC noise band |
| All | finite differences | gradcheck on small reduced states | <10⁻⁴ relative gradient error |

Standard benchmarks beyond the table: AFGL 1986 atmospheres (rayference/afgl1986); CIRC clear-sky cases (Oreopoulos et al. 2012 JGR D06118); RAMI for surface BRDF; HITRAN-2020/2024 (Gordon et al. 2022 JQSRT 277, 107949); MT_CKD v4 H₂O continuum (Mlawer et al. 2023); Coulson–Dave–Sekera tables for polarised Rayleigh.

Differentiability and physical-consistency tests: gradcheck on every public forward function; reciprocity (BRDF and bidirectional radiance); energy conservation in the non-absorbing limit (reflected plus transmitted flux equals the incident flux, normalised to 1); reproducibility of VLIDORT analytical Jacobians on Spurr’s published test suite where accessible.


6b. Differentiability strategy, component by component

Where naive AD works: Voigt–Hjerting function via jnp (per ExoJAX), all line-strength/broadening formulas, Chapman factor, Rayleigh σ, Henyey–Greenstein, Lambertian and most kernel BRDFs, layer-wise jax.lax.scan for two-stream/SOS, FFT convolution for SRF and MODIT-style opacity. Use eqx.filter_grad directly.

Where custom_jvp / custom_vjp is needed:

Where implicit differentiation through linear/non-linear solvers is needed:

Where checkpointing / jax.checkpoint is needed for memory:

Jacobians (K-matrix) vs adjoint:


7. v0 starter code sketch

from __future__ import annotations
from typing import Protocol
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

# --- Pytree state ---------------------------------------------------------

class AtmosphericState(eqx.Module):
    pressure:    Float[Array, "n_layer"]
    temperature: Float[Array, "n_layer"]
    altitude:    Float[Array, "n_layer_plus_one"]
    vmr_ch4:     Float[Array, "n_layer"]
    vmr_h2o:     Float[Array, "n_layer"]
    vmr_co2:     Float[Array, "n_layer"]
    vmr_n2o:     Float[Array, "n_layer"]

class Geometry(eqx.Module):
    mu0: Float[Array, ""]          # cos(SZA)
    mu:  Float[Array, ""]          # cos(VZA)

class LambertianSurface(eqx.Module):
    albedo: Float[Array, "n_nu"]   # spectral albedo

class Instrument(eqx.Module):
    # Fixed arrays stay ordinary fields: eqx.field(static=True) is only for
    # hashable non-array metadata, and JAX arrays are not hashable.
    nu_hi:          Float[Array, "n_nu_hi"]
    nu_sensor:      Float[Array, "n_band"]
    srf_kernel:     Float[Array, "n_band n_nu_hi"]
    spectral_shift: Float[Array, ""]           # learnable wavenumber shift

    def convolve(self, radiance_hi):
        # SRF integration; SRF kernel pre-normalised
        shifted = jnp.interp(self.nu_hi + self.spectral_shift,
                             self.nu_hi, radiance_hi)
        return self.srf_kernel @ shifted

# --- Opacity provider (v0: precomputed cross-section LUT) -----------------

class OpacityProvider(Protocol):
    def __call__(self, state: AtmosphericState,
                 nu: Float[Array, "n_nu"]) -> Float[Array, "n_layer n_nu"]:
        ...

class LookupTableOpacity(eqx.Module):
    # cross sections σ_g(p, T, ν) in m^2 per molecule, precomputed on a coarse (p, T) grid;
    # grids are ordinary array fields (not static) and must be ascending
    p_grid:  Float[Array, "n_p"]      # Pa
    T_grid:  Float[Array, "n_T"]      # K
    nu_grid: Float[Array, "n_nu"]
    sigma:   dict[str, Float[Array, "n_p n_T n_nu"]]    # fixed after loading

    def __call__(self, state, nu):
        # bilinear interpolation in (log p, T) for each layer and gas
        layer_tau = jnp.zeros((state.pressure.shape[0], nu.shape[0]))
        dz = jnp.diff(state.altitude)
        N_air = state.pressure / (1.380649e-23 * state.temperature)  # m^-3
        for gas, vmr in [("CH4", state.vmr_ch4), ("H2O", state.vmr_h2o),
                         ("CO2", state.vmr_co2), ("N2O", state.vmr_n2o)]:
            sig = bilinear(self.sigma[gas], jnp.log(state.pressure),
                           state.temperature, jnp.log(self.p_grid), self.T_grid)
            #  sig shape: (n_layer, n_nu)
            layer_tau = layer_tau + sig * (vmr * N_air * dz)[:, None]
        return layer_tau

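def bilinear(table, x, y, x_grid, y_grid):
    # AD-safe bilinear interpolation of table[n_x, n_y, n_nu] at per-layer (x, y);
    # assumes ascending grids and clamps queries to the grid edges
    def cell(q, grid):
        q = jnp.clip(q, grid[0], grid[-1])
        i = jnp.clip(jnp.searchsorted(grid, q) - 1, 0, grid.shape[0] - 2)
        w = (q - grid[i]) / (grid[i + 1] - grid[i])
        return i, w
    i, wx = cell(x, x_grid)
    j, wy = cell(y, y_grid)
    wx, wy = wx[:, None], wy[:, None]
    return ((1 - wx) * (1 - wy) * table[i, j] + wx * (1 - wy) * table[i + 1, j]
            + (1 - wx) * wy * table[i, j + 1] + wx * wy * table[i + 1, j + 1])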

# --- v0 forward model: Beer–Lambert TOA radiance --------------------------

class BeerLambertSolver(eqx.Module):
    solar_irradiance: Float[Array, "n_nu_hi"]   # fixed input; an ordinary array field, not static

    def __call__(self, tau_layer, surface, geom):
        tau_col = tau_layer.sum(axis=0)                  # (n_nu_hi,)
        amf = 1.0 / geom.mu0 + 1.0 / geom.mu
        T = jnp.exp(-amf * tau_col)
        # Lambertian TOA radiance (no scattering)
        L = self.solar_irradiance * geom.mu0 / jnp.pi * surface.albedo * T
        return L

class ForwardModel(eqx.Module):
    opacity:    OpacityProvider   # concrete field; eqx.AbstractVar would make the class abstract
    solver:     BeerLambertSolver
    instrument: Instrument

    def __call__(self, state, surface, geom):
        tau = self.opacity(state, self.instrument.nu_hi)
        L_hi = self.solver(tau, surface, geom)
        L_sensor = self.instrument.convolve(L_hi)
        return L_sensor

# --- JIT + Jacobians ------------------------------------------------------

@eqx.filter_jit
def predict(fm: ForwardModel, state, surface, geom):
    return fm(state, surface, geom)

# K-matrix Jacobian w.r.t. selected scalar parameters of the state
def K_xch4_shift_albedo(fm, state, surface, geom):
    def f(xch4_scale, shift, alb_scale):
        s = eqx.tree_at(lambda s: s.vmr_ch4, state, state.vmr_ch4 * xch4_scale)
        i = eqx.tree_at(lambda i: i.spectral_shift, fm.instrument, shift)
        srf = eqx.tree_at(lambda s: s.albedo, surface, surface.albedo * alb_scale)
        fm2 = eqx.tree_at(lambda m: m.instrument, fm, i)
        return fm2(s, srf, geom)
    # three scalar inputs, n_band outputs: forward mode needs only three JVPs
    return jax.jacfwd(f, argnums=(0, 1, 2))(1.0, 0.0, 1.0)

# Batched retrieval over an EMIT swath
batched_predict = eqx.filter_vmap(predict, in_axes=(None, 0, 0, 0))

This skeleton is JIT-compatible, jacrev/jacfwd-clean, and structurally extends to v1–v4: replace LookupTableOpacity with PreModitOpacity; replace BeerLambertSolver with SingleScatteringSolver or DiscreteOrdinatesSolver; replace LambertianSurface with RossLiBRDF. The retrieval driver (Levenberg–Marquardt with Lineax for normal equations, or NumPyro HMC) sits on top unchanged.


8. Reading list and benchmarks

Foundational textbooks. Liou, An Introduction to Atmospheric Radiation (2nd ed., 2002); Thomas & Stamnes, Radiative Transfer in the Atmosphere and Ocean (Cambridge, 2nd ed., 2017) — the canonical pedagogical references for v1–v4; Mishchenko, Travis & Lacis, Scattering, Absorption, and Emission of Light by Small Particles (Cambridge, 2002, free PDF) for vector RT and particle scattering; Bohren & Huffman, Absorption and Scattering of Light by Small Particles (Wiley, 1983) for Mie; Rodgers, Inverse Methods for Atmospheric Sounding (World Scientific, 2000) — every retrieval choice in this design ultimately answers to chapter 4 of Rodgers; Chandrasekhar, Radiative Transfer (Dover, 1960) for the analytic single-scattering and Rayleigh foundations the Korkin paper builds on.

Spectroscopy. Gordon et al. 2022 JQSRT 277, 107949 (HITRAN-2020); Rothman et al. 2010 JQSRT 111, 2139 (HITEMP); Mlawer et al. 2023 JQSRT 306, 108645 (MT_CKD v4); Schreier et al. 2019 Atmosphere 10 (py4CAtS); Kochanov et al. 2016 JQSRT 177, 15 (HAPI).

Classical RT codes. Stamnes et al. 1988 Appl. Opt. 27, 2502 (DISORT); Spurr 2006 JQSRT 102, 316 (VLIDORT — read this for the linearisation playbook); Spurr & Christi 2014 (profile vs bulk Jacobians); Eriksson et al. 2011 JQSRT 112, 1551 and Buehler et al. 2005 (ARTS); Mayer & Kylling 2005 ACP 5, 1855 and Emde et al. 2016 GMD 9, 1647 (libRadtran); Rozanov et al. 2014 JQSRT 133, 13 (SCIATRAN); Bourassa et al. 2008 JQSRT 109, 52 (SASKTRAN); Saunders et al. 2018 GMD 11, 2717 (RTTOV); Evans 1998 JAS 55, 429 (SHDOM); Kotchenova & Vermote 2007 Appl. Opt. 46, 4455 (6SV); Berk et al. 2014 SPIE 9088 (MODTRAN6).

Differentiable / modern. Kawahara et al. 2022 ApJS 258, 31 and Kawahara et al. 2025 ApJ (ExoJAX I, II); Ukkonen 2020 JAMES and Ukkonen et al. 2023 GMD 16, 3241 (RRTMGP-NN); Zhang et al. 2019 ACM TOG 38, “Differential theory of radiative transfer”; Salesin et al. 2024 JQSRT 314, 108847 (differentiable atmosphere–ocean Mitsuba 3); Doicu & Efremenko 2019 MDPI Atmosphere (linearised 3-D SHDOM); Brodrick et al. 2021 RSE (sRTMnet); Verrelst et al. 2016 RSE (GP emulators); Larosa et al. 2024 GMD 17, 2053 (PyRTlib); Jackson et al. 2025 APL Photonics 11, 046114 (PyMieDiff).

Methane retrievals. Foote et al. 2020 IEEE TGRS 58 (MAG1C); Foote et al. 2021 RSE (scene-specific MF); Thorpe et al. 2014, 2017 AMT (IMAP-DOAS for AVIRIS-NG); Thorpe et al. 2023 Sci. Adv. 9, eadh2391 (EMIT operational); Lorente et al. 2021 AMT 14, 665 (RemoTeC TROPOMI); Varon et al. 2018 AMT 11, 5673 (IME) and 2021 AMT 14, 2771 (Sentinel-2); Jongaramrungruang et al. 2022 RSE (MethaNet); Cusworth et al. 2022 PNAS (PRISMA point sources); Chan Miller et al. 2024 AMT 17, 5429 (MethaneSAT proxy).

Benchmarks. Anderson et al. 1986 AFGL-TR-86-0110 (atmospheres; CSV at github.com/rayference/afgl1986); Oreopoulos et al. 2012 JGR 117, D06118 (CIRC); Cahalan et al. 2005 BAMS 86 (I3RC); RAMI at rami-benchmark.jrc.ec.europa.eu.

The user’s reference paper. Korkin, Sayer, Ibrahim & Lyapustin 2022 Comp. Phys. Comm. 271, 108198. Code: github.com/korkins/gsit. Read for the modular skeleton and the “make it right then make it fast” philosophy; do not expect spectroscopy or differentiability lessons.


Conclusion: where the bets are

The strongest single bet in this design is to make OpticalProperties = {τ, ω, B_ℓ, surface_kernels} the canonical differentiable interface and let jax.jacrev replace VLIDORT’s hand-derived K-matrix machinery. This is the architectural insight that makes a JAX/Equinox RTM strictly better than the heritage Fortran codes for retrievals — not faster (VLIDORT is fast), not more accurate (DISORT is the reference), but vastly easier to evolve: every change in spectroscopy, surface model, or aerosol parameterisation gets exact Jacobians for free, with no derivative-code refactor. The risks are concentrated in two places — eigendecomposition near-degeneracy in v4 and Voigt/Mie gradient stability — and both have published mitigations (VLIDORT Taylor branches, PyMieDiff recurrences) that port cleanly into jax.custom_jvp.

The closest existing system is ExoJAX, and the cleanest description of the project’s contribution is “ExoJAX for Earth, in Equinox, coupled to plumax for end-to-end emission retrieval”. The largest gap in the literature it fills is an open, autodiff-native, retrieval-grade Earth-RS RTM that the methane community can iterate without proprietary or registration-gated tools. Stage v0 ships in weeks and immediately replaces MODTRAN/6S in matched-filter target generation; v2 closes the spectroscopy gap with HITRAN; v4 closes the scattering gap with DISORT/VLIDORT-equivalence; v5+ opens the door to differentiable 3-D RT for cloud tomography and joint plume/atmosphere inversion that no operational code currently offers.