Layer 1 — Observation Operators¶
Per Decision D9, averaging kernel and multi-instrument fusion are day-one
operators, not deferred features. Every class here satisfies
pipekit_cycle.ObservationOperator directly (Decision D8): __call__(state)
returns predicted observations, and linearize(state) returns an
AbstractLinearOperator representing the tangent-linear operator
\(H' = \partial H / \partial x\).
Protocol satisfaction¶
from pipekit_cycle import ObservationOperator
from lineax import JacobianLinearOperator, AbstractLinearOperator
# Every vardax obs operator implements:
#
# def __call__(self, state: Array) -> Array: ...
# def linearize(self, state: Array) -> AbstractLinearOperator: ...
#
# The default `linearize` uses `lineax.JacobianLinearOperator` (autodiff
# Jacobian). Operators with structure may override with a structured op
# (gaussx Kronecker, LowRank, BlockDiag) for efficiency.
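To make the two-method contract above concrete, here is a minimal NumPy sketch. The class name `SquareObs`, the dense-matrix return value, and the finite-difference check are illustrative assumptions; the real operators return lineax `AbstractLinearOperator` objects rather than plain arrays.

```python
import numpy as np


class SquareObs:
    """Toy nonlinear operator H(x) = x**2 (hypothetical, for illustration only)."""

    def __call__(self, state: np.ndarray) -> np.ndarray:
        return state**2

    def linearize(self, state: np.ndarray) -> np.ndarray:
        # Tangent-linear operator H'(x) = dH/dx = diag(2x), as a dense matrix.
        return np.diag(2.0 * state)


obs_op = SquareObs()
x = np.array([1.0, 2.0, 3.0])
u = np.array([0.5, -1.0, 0.25])
eps = 1e-6

# First-order Taylor check: (H(x + eps*u) - H(x)) / eps should approximate H'(x) u.
finite_diff = (obs_op(x + eps * u) - obs_op(x)) / eps
tangent = obs_op.linearize(x) @ u
```

The same finite-difference comparison is a cheap sanity check for any overridden `linearize`.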
MaskedIdentity¶
Default operator for any "we observe the state with gaps" use case.
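class MaskedIdentity(eqx.Module):
    """H(x) = mask ⊙ x."""

    def __call__(self, x: Array, mask: Array | None = None) -> Array:
        return x * mask if mask is not None else x

    def linearize(self, x: Array, mask: Array | None = None) -> AbstractLinearOperator:
        # lineax calls fn(x, args), so the mask is forwarded through `args`;
        # without it the Jacobian would be the unmasked identity.
        return JacobianLinearOperator(self.__call__, x, args=mask)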
Use case: SSH altimetry with along-track gaps, SST with cloud masks.
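As a small illustration of the gap-handling behaviour, the NumPy sketch below applies a mask to a four-cell state and writes out the corresponding Jacobian, which for a masked identity is simply diag(mask). The values are made up, and plain arrays stand in for the lineax operator types.

```python
import numpy as np

x = np.array([0.12, 0.35, -0.08, 0.21])  # e.g. SSH anomalies on four grid cells
mask = np.array([1.0, 0.0, 1.0, 1.0])    # cell 1 is unobserved (a gap)

y = mask * x                 # H(x) = mask ⊙ x
H_jac = np.diag(mask)        # Jacobian of a masked identity

# The operator is linear, so applying the Jacobian to x reproduces H(x).
y_from_jac = H_jac @ x
```

Note that the gap stays in the output as a zero rather than being dropped, so the observation vector keeps the state's shape.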
LinearObs¶
class LinearObs(eqx.Module):
    """H(x) = H_mat @ x."""

    H_mat: AbstractLinearOperator  # gaussx / lineax operator

    def __call__(self, x: Array) -> Array:
        # lineax applies an operator to a vector with .mv; `@` composes operators.
        return self.H_mat.mv(x)

    def linearize(self, x: Array) -> AbstractLinearOperator:
        return self.H_mat  # linear → Jacobian is the operator itself
Use case: any pre-computed linear projection (interpolation matrix, footprint matrix from Lagrangian transport).
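For the interpolation-matrix case, a pre-computed projection might be built as in the sketch below. The helper `linear_interp_matrix` is hypothetical and uses a dense NumPy matrix for readability; it assumes a sorted 1-D grid and observation locations that fall inside it.

```python
import numpy as np


def linear_interp_matrix(grid: np.ndarray, obs_locs: np.ndarray) -> np.ndarray:
    """Dense (n_obs, n_grid) matrix for 1-D linear interpolation (hypothetical helper)."""
    n_obs, n_grid = obs_locs.size, grid.size
    # Index of the grid point to the left of each observation, clipped to a valid cell.
    left = np.clip(np.searchsorted(grid, obs_locs, side="right") - 1, 0, n_grid - 2)
    # Fractional position of each observation inside its cell.
    w = (obs_locs - grid[left]) / (grid[left + 1] - grid[left])
    H = np.zeros((n_obs, n_grid))
    H[np.arange(n_obs), left] = 1.0 - w
    H[np.arange(n_obs), left + 1] = w
    return H


grid = np.linspace(0.0, 4.0, 5)          # model grid: 0, 1, 2, 3, 4
obs_locs = np.array([0.5, 2.25, 4.0])    # observation locations inside the grid
x = grid**2                              # a state on the grid: 0, 1, 4, 9, 16

H_mat = linear_interp_matrix(grid, obs_locs)
y = H_mat @ x                            # H(x) = H_mat @ x
```

Each row has at most two non-zero weights that sum to one, so in practice a sparse or otherwise structured operator would be the natural container.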
AveragingKernel¶
Decision D9: first-class day-one operator for any satellite L2 product.
Mathematical formulation¶
The averaging kernel maps a model state \(x\) (mixing-ratio profile, surface field) to the L2 retrieval-equivalent observation:
\[
\hat{y} = A \bigl( h \odot x + (1 - h) \odot x_a \bigr)
\]
where:
- \(x \in \mathbb{R}^N\): model state (gas profile, surface field)
- \(x_a \in \mathbb{R}^N\): retrieval prior (a priori state from L2 metadata)
- \(h \in \mathbb{R}^N\): weighting vector (often pressure-weighted; \(h = 1\) for surface-only obs)
- \(A \in \mathbb{R}^{N \times N}\): averaging kernel matrix
- \(\hat{y}\): predicted L2 observation
This operator is mandatory for any RTM-derived satellite product (TROPOMI CH₄, EMIT CH₄, OCO CO₂, MOPITT CO, …). Skipping it is the most common cause of bias in operational inversions.
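A small numeric walk-through may help here. It evaluates the same expression as `AveragingKernel.__call__` in the class contract, using a made-up three-level kernel, prior, and weighting vector; none of the numbers come from a real product.

```python
import numpy as np

A = np.array([[0.8, 0.1, 0.0],
              [0.1, 0.7, 0.1],
              [0.0, 0.2, 0.6]])             # made-up averaging kernel
x_a = np.array([1800.0, 1790.0, 1750.0])    # made-up retrieval prior (ppb)
h = np.array([1.0, 1.0, 0.5])               # made-up weighting vector
x = np.array([1850.0, 1800.0, 1760.0])      # model profile


def apply_ak(state: np.ndarray) -> np.ndarray:
    # Same expression as AveragingKernel.__call__: A @ (h * x + (1 - h) * x_a).
    return A @ (h * state + (1.0 - h) * x_a)


# Where h = 1 the model value is used as-is; where h < 1 it is blended with the prior.
blended = h * x + (1.0 - h) * x_a
y_hat = apply_ak(x)
```

When the model state equals the prior, the blend collapses to `x_a` and the prediction is just `A @ x_a`.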
Class contract¶
class AveragingKernel(eqx.Module):
    A: AbstractLinearOperator  # gaussx; may be Kronecker / LowRank / dense
    x_a: Float[Array, "N"]     # retrieval prior
    h: Float[Array, "N"]       # weighting vector

    def __call__(self, x: Array) -> Array:
        # .mv applies the operator to a vector; `@` is reserved for composition.
        return self.A.mv(self.h * x + (1.0 - self.h) * self.x_a)

    def linearize(self, x: Array) -> AbstractLinearOperator:
        # H'(x) = A · diag(h), independent of x because H is affine in x
        return self.A @ lineax.DiagonalLinearOperator(self.h)
A is stored as an AbstractLinearOperator so structure (e.g. Kronecker
for separable kernels) can be exploited by gaussx.ops.solve.
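To illustrate why keeping the structure matters, the sketch below applies a separable (Kronecker) kernel without ever forming the full matrix. The factor names `A_vert` and `A_horiz` are hypothetical, and the code uses plain NumPy rather than gaussx, whose API is not shown here.

```python
import numpy as np

rng = np.random.default_rng(0)
A_vert = rng.normal(size=(4, 4))    # hypothetical vertical factor
A_horiz = rng.normal(size=(3, 3))   # hypothetical horizontal factor
X = rng.normal(size=(4, 3))         # state on a (level, pixel) grid

# Dense route: build the full 12 x 12 kernel, then multiply.
dense = np.kron(A_vert, A_horiz) @ X.ravel()

# Structured route: never forms the Kronecker product.
# With row-major flattening, (A ⊗ B) vec(X) = vec(A X Bᵀ).
structured = (A_vert @ X @ A_horiz.T).ravel()
```

The structured route costs two small matrix products instead of one product with a matrix whose size grows as the square of the state dimension.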
Construction from L2 metadata¶
def averaging_kernel_from_l2(
    ds: xr.Dataset,
    *,
    ak_var: str = "averaging_kernel",
    prior_var: str = "x_a",
    weight_var: str = "h",
) -> AveragingKernel:
    """Build AveragingKernel from L2 sidecar metadata."""
    A_dense = jnp.asarray(ds[ak_var].values)
    return AveragingKernel(
        A=lineax.MatrixLinearOperator(A_dense),
        x_a=jnp.asarray(ds[prior_var].values),
        h=jnp.asarray(ds[weight_var].values),
    )
MultiInstrumentFusion¶
Decision D9: compose per-instrument operators at the likelihood level. No pre-regridding, no shared coordinate system imposed.
Mathematical formulation¶
Each instrument \(i \in \{\text{TROPOMI}, \text{EMIT}, \text{GHGSat}, \ldots\}\) has its own operator \(H_i\), mask \(m_i\), and error covariance \(R_i\). The fused observation cost sums the per-instrument misfits, each masked by \(m_i\) and weighted by \(R_i^{-1}\).
The per-pixel instrument index on Batch* selects the operator. Quality masks give unreliable pixels zero weight: they stay in the batch but contribute nothing to the likelihood.
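One way to picture the fused cost is the NumPy sketch below. It assumes the conventional weighted least-squares form \(J_{\mathrm{obs}}(x) = \sum_i \alpha_i \, \tfrac{1}{2} \, r_i^\top R_i^{-1} r_i\) with \(r_i = m_i \odot (y_i - H_i x)\), linear \(H_i\), and diagonal \(R_i\). The symbol \(y_i\) (the instrument's observations), the factor of one half, the dictionary layout, and all numbers are assumptions made for illustration.

```python
import numpy as np


def fused_cost(x: np.ndarray, instruments: dict) -> float:
    """Sum of weighted, masked Gaussian misfits over instruments (assumed form)."""
    total = 0.0
    for inst in instruments.values():
        # Masked pixels get a zero residual, so they add nothing to the cost.
        resid = inst["mask"] * (inst["y"] - inst["H"] @ x)
        # Diagonal R_i: dividing by the variances applies R_i^{-1}.
        total += inst["alpha"] * 0.5 * np.sum(resid**2 / inst["var"])
    return float(total)


x = np.array([1.0, 2.0])
instruments = {
    "TROPOMI": {
        "H": np.eye(2), "y": np.array([1.5, 2.0]),
        "var": np.array([0.25, 0.25]), "mask": np.array([1.0, 1.0]), "alpha": 1.0,
    },
    "EMIT": {
        "H": np.array([[0.5, 0.5]]), "y": np.array([9.0]),
        "var": np.array([1.0]), "mask": np.array([0.0]), "alpha": 1.0,  # flagged pixel
    },
}

J = fused_cost(x, instruments)
```

The flagged EMIT pixel has a large raw misfit, yet the total cost comes entirely from the TROPOMI term because the mask zeroes the EMIT residual.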
Class contract¶
class MultiInstrumentFusion(eqx.Module):
    """Compose per-instrument ObservationOperators at the likelihood level."""

    registry: InstrumentRegistry
    weights: dict[str, float] | None = None  # per-instrument α_i; default uniform

    def __call__(self, x: Array, batch: Batch) -> dict[str, Array]:
        """Return per-instrument predicted observations."""
        return {
            inst_id: spec.obs_op(x)
            for inst_id, spec in self.registry.entries.items()
        }

    def linearize(self, x: Array) -> dict[str, AbstractLinearOperator]:
        """Per-instrument tangent-linear operators (typically block-diagonal across instruments)."""
        return {
            inst_id: spec.obs_op.linearize(x)
            for inst_id, spec in self.registry.entries.items()
        }
Note: MultiInstrumentFusion.__call__ returns a dict[str, Array], not
a single Array — the cost function consumes the dict and combines at the
likelihood level. This is a slight departure from the strict
ObservationOperator.__call__(state) → obs signature, so vardax exposes
a to_observation_operator() adapter that flattens to a single output for
strict-protocol contexts.
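The flattening idea behind such an adapter can be sketched as follows. The helper `flatten_predictions` and the sorted-key ordering are assumptions made for illustration; the actual behaviour of `to_observation_operator()` is not specified in this section.

```python
import numpy as np


def flatten_predictions(per_instrument: dict[str, np.ndarray]) -> np.ndarray:
    """Concatenate per-instrument outputs in sorted-key order (hypothetical helper)."""
    return np.concatenate([np.ravel(per_instrument[k]) for k in sorted(per_instrument)])


preds = {
    "TROPOMI": np.array([1850.0, 1842.0]),
    "EMIT": np.array([1901.0]),
}
flat = flatten_predictions(preds)

# The matching tangent-linear operator stacks the per-instrument Jacobians
# row-wise in the same key order, so rows line up with entries of `flat`.
jacobians = {"TROPOMI": np.eye(2, 3), "EMIT": np.ones((1, 3))}
H_stacked = np.vstack([jacobians[k] for k in sorted(jacobians)])
```

Whatever ordering is chosen, it has to be deterministic and shared between the forward output and the stacked Jacobian; sorting the keys is one simple way to get that.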
InstrumentRegistry¶
class InstrumentSpec(eqx.Module):
    """Per-instrument (A, x_a, h, mask, R) tuple."""

    obs_op: ObservationOperator   # often AveragingKernel
    mask: Float[Array, "..."]     # quality mask
    R_op: AbstractLinearOperator  # obs-err covariance
    instrument_id: str = eqx.field(static=True)

class InstrumentRegistry(eqx.Module):
    """Keyed lookup of InstrumentSpec by instrument_id."""

    entries: dict[str, InstrumentSpec]

    @classmethod
    def from_l2_dict(cls, l2_datasets: dict[str, xr.Dataset]) -> "InstrumentRegistry":
        """Build registry from per-instrument L2 datasets."""
        entries = {}
        for inst_id, ds in l2_datasets.items():
            entries[inst_id] = InstrumentSpec(
                obs_op=averaging_kernel_from_l2(ds),
                mask=jnp.asarray(ds["qa_flag"].values),
                # A covariance holds variances, so the reported uncertainty
                # (assumed here to be a 1-sigma value) is squared.
                R_op=lineax.DiagonalLinearOperator(
                    jnp.asarray(ds["xch4_uncertainty"].values) ** 2
                ),
                instrument_id=inst_id,
            )
        return cls(entries=entries)
Bias-aware fusion (Epic 9, planned)¶
For multi-instrument joint inversions where instruments may disagree systematically:
class BiasAwareFusion(eqx.Module):
    """MultiInstrumentFusion with per-instrument bias as joint state."""

    base: MultiInstrumentFusion
    bias_prior: dict[str, GaussianPrior]  # b_i ~ N(0, σ_b²) per instrument

    def __call__(self, x: Array, bias: dict[str, Array], batch: Batch) -> dict[str, Array]:
        return {
            inst_id: self.base.registry.entries[inst_id].obs_op(x) + bias[inst_id]
            for inst_id in self.base.registry.entries
        }
Per-instrument bias becomes a joint state element in IncrementalFourDVar —
Posterior.mean carries both state and bias estimates.
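The joint-state idea can be pictured with an augmented linear system, sketched below for a single instrument. It assumes a linear operator and one scalar bias per instrument; the shapes, the names `z` and `H_aug`, and the numbers are illustrative, since this feature is still planned.

```python
import numpy as np

H = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [0.5, 0.5]])   # made-up instrument operator: 2 state elements -> 3 obs
x = np.array([2.0, 4.0])
b = 0.3                      # scalar bias for this instrument (assumed shape)

# Prediction as in BiasAwareFusion.__call__: obs_op(x) + bias.
y_pred = H @ x + b

# Augmented formulation: z = [x, b] and H_aug = [H | 1], so H_aug @ z = H @ x + b.
z = np.concatenate([x, [b]])
H_aug = np.hstack([H, np.ones((H.shape[0], 1))])
y_aug = H_aug @ z
```

Seen this way, the bias is just one more column in the tangent-linear operator, which is why it can be estimated alongside the state.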
Spectral / Fourier observation operators (planned)¶
For frequency-domain observations (along-track altimetry spectra, spectral SST products):
class SpectralObs(eqx.Module):
    """H(x) = F^{-1} · S · F · x: selective frequency observation."""

    F: AbstractLinearOperator  # FFT operator
    S: AbstractLinearOperator  # frequency selection / weighting
    ...
Reserved for Epic 4 follow-up; not in the v0.3 surface.
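Although the class is only planned, the composition F^{-1} · S · F is easy to try out with NumPy's real FFT. In the sketch below `S` is a made-up low-pass selection on a periodic 64-point signal; the cutoff and the test signal are assumptions chosen so the effect is easy to see.

```python
import numpy as np

n = 64
t = np.arange(n)
low = np.cos(2 * np.pi * 2 * t / n)     # wavenumber 2
high = np.cos(2 * np.pi * 20 * t / n)   # wavenumber 20
x = low + high

# S: keep wavenumbers 0..5 and drop the rest (a made-up low-pass selection).
S = (np.arange(n // 2 + 1) <= 5).astype(float)


def spectral_obs(state: np.ndarray) -> np.ndarray:
    # H(x) = F^{-1} · S · F · x, using the real FFT for a real-valued signal.
    return np.fft.irfft(S * np.fft.rfft(state), n=n)


y = spectral_obs(x)
```

Because both components sit exactly on FFT bins, the operator returns the low-wavenumber component and removes the high one. The operator is linear, so its `linearize` would return the operator itself, as with `LinearObs`.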
Test conformance¶
Every observation operator passes tests/test_pipekit_protocols.py:
def test_obs_op_satisfies_protocol(obs_op: ObservationOperator):
    assert isinstance(obs_op, ObservationOperator)
    x = sample_state()
    y = obs_op(x)
    assert y.shape == expected_obs_shape
    H_lin = obs_op.linearize(x)
    assert isinstance(H_lin, AbstractLinearOperator)
    # Adjoint (dot-product) test: <H u, v> == <u, Hᵀ v>,
    # with u in state space and v in observation space.
    u, v = sample_random_arrays()
    forward = H_lin.mv(u)      # lineax applies operators to vectors with .mv
    backward = H_lin.T.mv(v)
    assert jnp.allclose(jnp.dot(forward, v), jnp.dot(u, backward), atol=1e-5)
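The adjoint check at the end of the test is the standard dot-product identity, and it can be tried in isolation with a dense matrix standing in for the tangent-linear operator. The sketch below uses NumPy and random data; the shapes are arbitrary, and the sign-flipped adjoint is included only to show what a failure looks like.

```python
import numpy as np

rng = np.random.default_rng(42)
H = rng.normal(size=(5, 8))   # stand-in tangent-linear operator: 8 state -> 5 obs
u = rng.normal(size=8)        # perturbation in state space
v = rng.normal(size=5)        # perturbation in observation space

lhs = np.dot(H @ u, v)        # <H u, v>
rhs = np.dot(u, H.T @ v)      # <u, Hᵀ v>

# A deliberately wrong adjoint (sign flipped) breaks the identity.
rhs_wrong = np.dot(u, -H.T @ v)
```

Note that `u` and `v` live in different spaces, so the two sampled arrays generally have different lengths.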