
Observation Operators

Observation operators map model state to observation space. All of them satisfy the pipekit-cycle ObservationOperator Protocol — __call__(state) produces the predicted observation, and linearize(state) returns the tangent-linear (Jacobian) operator that the incremental and posterior-covariance machinery needs. Because conformance is structural, custom operators plug in by matching that signature; nothing here needs to be subclassed. See Observation Operators in the Mathematical Reference for the modelling background.

Core operators

LinearObs wraps an explicit observation matrix; MaskedIdentity handles the ubiquitous "observe a subset of grid points" case (satellite tracks, sparse sensor networks); AveragingKernel implements the smoothing kernels of retrieval products such as atmospheric-composition column averages.

vardax — Modular variational data assimilation with learned components.

All public symbols are re-exported from the private _src subpackage so that user code imports from the top-level namespace:

import vardax

model = vardax.FourDVarNet1D(...)

LinearObs

Bases: Module

H(x) = H_mat @ x.

Examples:

>>> import jax.numpy as jnp, lineax as lx
>>> import vardax as vdx
>>> H = lx.MatrixLinearOperator(jnp.eye(3)[:2])
>>> obs = vdx.LinearObs(H_mat=H)
>>> obs(jnp.array([1.0, 2.0, 3.0]))
Array([1., 2.], dtype=float32)
>>> obs.linearize(jnp.zeros(3)) is H
True
Source code in src/vardax/_src/obs_operators/linear.py
class LinearObs(eqx.Module):
    """``H(x) = H_mat @ x``.

    Examples:
        >>> import jax.numpy as jnp, lineax as lx
        >>> import vardax as vdx
        >>> H = lx.MatrixLinearOperator(jnp.eye(3)[:2])
        >>> obs = vdx.LinearObs(H_mat=H)
        >>> obs(jnp.array([1.0, 2.0, 3.0]))
        Array([1., 2.], dtype=float32)
        >>> obs.linearize(jnp.zeros(3)) is H
        True
    """

    H_mat: lx.AbstractLinearOperator

    def __call__(self, x: Float[Array, ...]) -> Float[Array, ...]:
        return self.H_mat.mv(x)

    def linearize(self, x: Float[Array, ...]) -> lx.AbstractLinearOperator:
        """Tangent-linear operator at ``x``: just ``H_mat`` (linear)."""
        return self.H_mat

linearize

linearize(x: Float[Array, ...]) -> AbstractLinearOperator

Tangent-linear operator at x: just H_mat (linear).

Source code in src/vardax/_src/obs_operators/linear.py
def linearize(self, x: Float[Array, ...]) -> lx.AbstractLinearOperator:
    """Tangent-linear operator at ``x``: just ``H_mat`` (linear)."""
    return self.H_mat

MaskedIdentity

Bases: Module

\(H(x) = m \odot x\) with optional element-wise mask.

Stateless: no parameters. The mask can be supplied at call time (per-batch) rather than baked in.

Examples:

>>> import jax.numpy as jnp
>>> import vardax as vdx
>>> op = vdx.MaskedIdentity()
>>> x = jnp.array([1.0, 2.0, 3.0])
>>> op(x, mask=jnp.array([1.0, 0.0, 1.0]))
Array([1., 0., 3.], dtype=float32)
>>> type(op.linearize(x)).__name__
'JacobianLinearOperator'
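The mask is a call argument rather than a field, so one operator can serve a batch whose samples have different coverage. The sketch below uses `jax.vmap`. `masked_identity` is a hypothetical free function that reproduces the `MaskedIdentity.__call__` logic shown in this class's source listing, so the block runs without vardax installed.

```python
import jax
import jax.numpy as jnp


def masked_identity(x, mask=None):
    # Same logic as MaskedIdentity.__call__: identity, or element-wise product with the mask.
    return x if mask is None else x * mask


xs = jnp.arange(6.0).reshape(2, 3)           # batch of 2 states, 3 grid points each
masks = jnp.array([[1.0, 0.0, 1.0],          # sample 0: middle point unobserved
                   [0.0, 1.0, 1.0]])         # sample 1: first point unobserved

ys = jax.vmap(masked_identity)(xs, masks)    # per-sample mask, mapped over axis 0
# ys == [[0., 0., 2.], [0., 4., 5.]]
```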
Source code in src/vardax/_src/obs_operators/masked.py
class MaskedIdentity(eqx.Module):
    r"""$H(x) = m \odot x$ with optional element-wise mask.

    Stateless: no parameters. The mask can be supplied at call time
    (per-batch) rather than baked in.

    Examples:
        >>> import jax.numpy as jnp
        >>> import vardax as vdx
        >>> op = vdx.MaskedIdentity()
        >>> x = jnp.array([1.0, 2.0, 3.0])
        >>> op(x, mask=jnp.array([1.0, 0.0, 1.0]))
        Array([1., 0., 3.], dtype=float32)
        >>> type(op.linearize(x)).__name__
        'JacobianLinearOperator'
    """

    def __call__(
        self,
        x: Float[Array, ...],
        mask: Float[Array, ...] | None = None,
    ) -> Float[Array, ...]:
        if mask is None:
            return x
        return x * mask

    def linearize(self, x: Float[Array, ...]) -> lx.AbstractLinearOperator:
        """Tangent-linear operator at ``x``: identity (since $H$ is linear)."""
        return lx.JacobianLinearOperator(lambda y, _args=None: y, x)

linearize

linearize(x: Float[Array, ...]) -> AbstractLinearOperator

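Tangent-linear operator at x: the identity. Note that linearize takes no mask argument, so this is the tangent-linear of the unmasked call (mask=None). For \(H(x) = m \odot x\), the exact tangent-linear is \(\mathrm{diag}(m)\).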

Source code in src/vardax/_src/obs_operators/masked.py
def linearize(self, x: Float[Array, ...]) -> lx.AbstractLinearOperator:
    """Tangent-linear operator at ``x``: identity (since $H$ is linear)."""
    return lx.JacobianLinearOperator(lambda y, _args=None: y, x)

AveragingKernel

Bases: Module

RTM L2-style averaging-kernel observation operator: \(H(x) = A\,(h \odot x + (1 - h) \odot x_a)\).

Attributes:

A (AbstractLinearOperator): Averaging kernel matrix, stored as a lineax.AbstractLinearOperator. May be MatrixLinearOperator for dense kernels or a structured operator (Kronecker for separable kernels, LowRank for retrievals with low DOF).

x_a (Float[Array, N]): Retrieval prior, same shape as the model state.

h (Float[Array, N]): Weighting vector (often pressure-weighted).

Source code in src/vardax/_src/obs_operators/averaging_kernel.py
class AveragingKernel(eqx.Module):
    """RTM L2-style averaging-kernel obs operator.

    Attributes:
        A: Averaging kernel matrix, stored as an
            ``lineax.AbstractLinearOperator``. May be ``MatrixLinearOperator``
            for dense kernels or a structured operator (Kronecker for
            separable kernels, LowRank for retrievals with low DOF).
        x_a: Retrieval prior, same shape as the model state.
        h: Weighting vector (often pressure-weighted).
    """

    A: lx.AbstractLinearOperator
    x_a: Float[Array, N]  # ty:ignore[unresolved-reference]
    h: Float[Array, N]  # ty:ignore[unresolved-reference]

    def __call__(self, x: Float[Array, N]) -> Float[Array, N]:  # ty:ignore[unresolved-reference]
        inner = self.h * x + (1.0 - self.h) * self.x_a
        return self.A.mv(inner)

    def linearize(self, x: Float[Array, N]) -> lx.AbstractLinearOperator:  # ty:ignore[unresolved-reference]
        r"""Tangent-linear operator at ``x``: $A \cdot \mathrm{diag}(h)$."""
        diag_h = lx.DiagonalLinearOperator(self.h)
        return self.A @ diag_h

linearize

linearize(x: Float[Array, N]) -> AbstractLinearOperator

Tangent-linear operator at x: \(A \cdot \mathrm{diag}(h)\).

Source code in src/vardax/_src/obs_operators/averaging_kernel.py
def linearize(self, x: Float[Array, N]) -> lx.AbstractLinearOperator:  # ty:ignore[unresolved-reference]
    r"""Tangent-linear operator at ``x``: $A \cdot \mathrm{diag}(h)$."""
    diag_h = lx.DiagonalLinearOperator(self.h)
    return self.A @ diag_h

Multi-instrument fusion

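Assimilating several instruments at once: each instrument is described by an InstrumentSpec and registered in an InstrumentRegistry. MultiInstrumentFusion evaluates the per-instrument operators and returns their predictions as a dict keyed by instrument ID, so each instrument keeps its own observation space. When a single operator over the concatenated observation vector is required, to_observation_operator() provides a flattened wrapper.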


InstrumentSpec

Bases: Module

Per-instrument (obs_op, mask, R_op, instrument_id) bundle.

Attributes:

obs_op (Any): ObservationOperator-conforming operator (typically AveragingKernel).

mask (Float[Array, ...]): Quality mask with a shape compatible with the instrument's observation space: 1 for valid pixels, 0 for flagged or dropped ones.

R_op (AbstractLinearOperator): Observation-error covariance as a lineax.AbstractLinearOperator. Often a DiagonalLinearOperator built from the per-pixel retrieval uncertainty, with the error variances on the diagonal.

instrument_id (str): Identifier (e.g. "TROPOMI", "EMIT", "GHGSat").

Source code in src/vardax/_src/obs_operators/multi_instrument.py
class InstrumentSpec(eqx.Module):
    """Per-instrument ``(obs_op, mask, R_op, id)`` bundle.

    Attributes:
        obs_op: ``ObservationOperator``-conforming operator (typically
            [`AveragingKernel`][vardax.AveragingKernel]).
        mask: Quality mask of shape compatible with the instrument's
            observation space. ``1`` for valid pixels, ``0`` for
            flagged/dropped.
        R_op: Observation-error covariance as a
            ``lineax.AbstractLinearOperator``. Often a
            ``DiagonalLinearOperator`` keyed on the per-pixel
            retrieval uncertainty.
        instrument_id: Identifier (e.g. ``"TROPOMI"``, ``"EMIT"``,
            ``"GHGSat"``).
    """

    obs_op: Any  # ObservationOperator (we don't import the Protocol to avoid a cycle)
    mask: Float[Array, ...]
    R_op: lx.AbstractLinearOperator
    instrument_id: str = eqx.field(static=True)

InstrumentRegistry

Bases: Module

Keyed lookup of InstrumentSpec by instrument_id.

Attributes:

entries (dict[str, InstrumentSpec]): Specs keyed by instrument_id, i.e. dict[instrument_id, InstrumentSpec].

Source code in src/vardax/_src/obs_operators/multi_instrument.py
class InstrumentRegistry(eqx.Module):
    """Keyed lookup of ``InstrumentSpec`` by ``instrument_id``.

    Attributes:
        entries: ``dict[instrument_id, InstrumentSpec]``.
    """

    entries: dict[str, InstrumentSpec] = eqx.field(default_factory=dict)

MultiInstrumentFusion

Bases: Module

Compose per-instrument operators at the likelihood level.

__call__ returns dict[instrument_id, predicted_obs] — one array per instrument. The cost function consumes the dict and sums per-instrument terms with their respective \(R_i^{-1}\). There is no shared coordinate system; each instrument keeps its native footprint and resolution.

For strict pipekit_cycle.ObservationOperator contexts, where a single observation vector and a single linear operator are required, call .to_observation_operator() to get a flattened wrapper.

Attributes:

registry (InstrumentRegistry): Per-instrument InstrumentRegistry.

weights (dict[str, float] | None): Optional {instrument_id: alpha} mapping. None ⇒ uniform \(\alpha_i = 1\).

Source code in src/vardax/_src/obs_operators/multi_instrument.py
class MultiInstrumentFusion(eqx.Module):
    r"""Compose per-instrument operators at the likelihood level.

    ``__call__`` returns ``dict[instrument_id, predicted_obs]`` — one
    array per instrument. The cost function consumes the dict and
    sums per-instrument terms with their respective $R_i^{-1}$.
    There is no shared coordinate system; each instrument keeps its
    native footprint and resolution.

    For strict ``pipekit_cycle.ObservationOperator`` contexts where a
    single observation vector + single linear operator are required,
    call ``.to_observation_operator()`` for a flattened wrapper.

    Attributes:
        registry: Per-instrument
            [`InstrumentRegistry`][vardax.InstrumentRegistry].
        weights: Optional ``{instrument_id: alpha}`` mapping. ``None``
            ⇒ uniform $\alpha_i = 1$.
    """

    registry: InstrumentRegistry
    weights: dict[str, float] | None = None

    def __call__(self, x: Float[Array, ...]) -> dict[str, Float[Array, ...]]:
        return {
            inst_id: spec.obs_op(x) for inst_id, spec in self.registry.entries.items()
        }

    def linearize(self, x: Float[Array, ...]) -> dict[str, lx.AbstractLinearOperator]:
        """Per-instrument tangent-linear operators.

        Returns ``{instrument_id: H_i'(x)}``. The fused tangent linear
        is the block-diagonal stack of these — assembled lazily by
        the cost function or via ``to_observation_operator()``.
        """
        return {
            inst_id: spec.obs_op.linearize(x)
            for inst_id, spec in self.registry.entries.items()
        }

    def to_observation_operator(self) -> _FlattenedMultiInstrument:
        """Adapt to the strict ``pipekit_cycle.ObservationOperator`` protocol.

        Returns a wrapper that concatenates per-instrument outputs and
        exposes a block-diagonal linear operator. Use this when the
        consumer requires a single ``(state) -> Array`` signature.
        """
        return _FlattenedMultiInstrument(fusion=self)

linearize

linearize(
    x: Float[Array, ...],
) -> dict[str, AbstractLinearOperator]

Per-instrument tangent-linear operators.

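Returns {instrument_id: H_i'(x)}. Every instrument observes the same state x, so the fused tangent linear is these operators stacked row-wise, one row block per instrument. It is assembled lazily by the cost function or via to_observation_operator(). The block-diagonal structure belongs to the combined observation-error covariance built from the per-instrument R_op.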

Source code in src/vardax/_src/obs_operators/multi_instrument.py
def linearize(self, x: Float[Array, ...]) -> dict[str, lx.AbstractLinearOperator]:
    """Per-instrument tangent-linear operators.

    Returns ``{instrument_id: H_i'(x)}``. The fused tangent linear
    is the block-diagonal stack of these — assembled lazily by
    the cost function or via ``to_observation_operator()``.
    """
    return {
        inst_id: spec.obs_op.linearize(x)
        for inst_id, spec in self.registry.entries.items()
    }

to_observation_operator

to_observation_operator() -> _FlattenedMultiInstrument

Adapt to the strict pipekit_cycle.ObservationOperator protocol.

Returns a wrapper that concatenates per-instrument outputs and exposes a single linear operator for the concatenated map. Use this when the consumer requires a single (state) -> Array signature.

Source code in src/vardax/_src/obs_operators/multi_instrument.py
def to_observation_operator(self) -> _FlattenedMultiInstrument:
    """Adapt to the strict ``pipekit_cycle.ObservationOperator`` protocol.

    Returns a wrapper that concatenates per-instrument outputs and
    exposes a block-diagonal linear operator. Use this when the
    consumer requires a single ``(state) -> Array`` signature.
    """
    return _FlattenedMultiInstrument(fusion=self)