
Report 11 — `pipekit-train`: training pipelines

UNEP
IMEO
MARS

Report 11 — pipekit-train: training pipelines for emulators and amortized inference

Status: Scoping proposal (committed)
Reading time: ~22 min
Decisions locked in: Thin orchestration layer over existing training tools (Lightning, Equinox+Optax, Keras), not a fourth training-loop implementation. Monorepo development. Trained models are first-class pipekit.Operators.
Audience: Anyone reviewing the train-side pipeline machinery
Companion reports: Report 10 (pipekit-cycle), Report 12 (pipekit-experiment), Report 13 (statecatalog); existing reports for context

Part 1 — Where pipekit-train sits in the stack

   Domain libraries           geotoolz │ xr_toolz
                                  ▲
                                  │
   Infrastructure       ┌─ pipekit-cycle ──────┐
                        │  pipekit-train       │ ← this report
                        │  pipekit-experiment  │
                        └─ statecatalog ───────┘
                                  ▲
                                  │
   Framework                   pipekit ◄── pipekit-array
                                  ▲
                                  │
   Substrate              georeader │ xarray ecosystem

pipekit-train is a thin orchestration layer, not a training-loop implementation. PyTorch Lightning, Equinox+Optax, and Keras 3 already do the heavy training-loop work; adding a fourth would duplicate that work for no gain. What’s missing is the glue between these tools and pipekit’s operators.

Three observations:

  1. The training pipeline is a YAML artifact, same as inference pipelines. A TrainingLoop config is just another operator graph. The “training run” is reproducible from the YAML + data hash + seed.

  2. The trained model is itself an operator. After training, the output is a pipekit.Operator (probably pipekit-array.ModelOp for traditional models, pipekit-jax.JaxModelOp for differentiable). Same composability as any other operator.

  3. The training-loop math is delegated. Lightning, Equinox+Optax, Keras 3 each have their own well-tested training-loop machinery. Pipekit-train wraps these; it doesn’t reimplement them.
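A minimal sketch of observations 2 and 3, assuming an operator is just a callable object and `Sequential` applies operators left to right. `Operator`, `Sequential`, `Scale`, `LinearModelOp` and `fit_least_squares` are hypothetical stand-ins, not pipekit's classes, and the closed-form least-squares fit plays the role of the delegated backend (Lightning, Optax, Keras) so the example stays dependency-free.

```python
from dataclasses import dataclass
from typing import Callable, Sequence


class Operator:
    """Hypothetical stand-in for pipekit.Operator: a callable transform."""

    def __call__(self, x):
        raise NotImplementedError


@dataclass
class Sequential(Operator):
    """Apply operators left to right (stand-in for pk.Sequential)."""
    ops: Sequence[Callable]

    def __call__(self, x):
        for op in self.ops:
            x = op(x)
        return x


@dataclass
class Scale(Operator):
    """A toy preprocessing operator."""
    factor: float

    def __call__(self, x):
        return x * self.factor


@dataclass
class LinearModelOp(Operator):
    """The trained model, wrapped so it composes like any other operator."""
    slope: float
    intercept: float

    def __call__(self, x):
        return self.slope * x + self.intercept


def fit_least_squares(pairs) -> LinearModelOp:
    """Stand-in 'backend': ordinary least squares for y = a*x + b.

    In pipekit-train this step would be delegated to Lightning / Optax / Keras.
    """
    n = len(pairs)
    mean_x = sum(x for x, _ in pairs) / n
    mean_y = sum(y for _, y in pairs) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in pairs)
    var = sum((x - mean_x) ** 2 for x, _ in pairs)
    slope = cov / var
    return LinearModelOp(slope=slope, intercept=mean_y - slope * mean_x)


# Train on pairs from y = 2x + 1, then drop the result into a pipeline.
model = fit_least_squares([(float(x), 2.0 * x + 1.0) for x in range(10)])
pipeline = Sequential([Scale(0.5), model])
```

The point is the return type: training hands back an object with the same calling convention as every other operator, so the inference pipeline does not need to know the model was trained at all.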

Part 2 — Three training shapes this enables

The L0-L4 vision needs three categories of training pipeline, all unified under the same framework but with different data shapes and loss patterns.

2.1 Direct supervised training (data exists)

Standard ML: pairs of (input, target) come from a catalog of observations + labels.

Data source: CatalogDataset — iterates a catalog, applies a preprocessing pipeline, yields (x, y) pairs.

2.2 Emulator training (simulator exists)

Surrogate-model training: pairs are generated by running a forward model (ForwardModel from pipekit-cycle).

Data source: SimulationDataset — wraps a ForwardModel, samples parameters from a prior, evaluates the model, yields (params, simulated_output) pairs.

2.3 Amortized inference training (likelihood-free)

Simulation-based inference (SBI) / neural posterior estimation: train a conditional density estimator on (parameters, simulated_observations) pairs so that at inference time it can produce p(parameters | observations) for any new observation.

Data source: SimulationDataset again — same as emulator training but the network learns the inverse mapping. Algorithmic core from pyrox / gaussflowx.

Three categories, two data shapes. The same training-loop machinery serves all three; the difference is in what’s at the input/output ends.
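To make the simulation side of "three categories, two data shapes" concrete: a sketch in which one prior-plus-simulator generator yields (parameters, output) pairs, and emulator training and amortized inference read the same pairs in opposite directions. `toy_simulator`, `simulate_pairs`, `emulator_view` and `posterior_view` are hypothetical names; the simulator is a one-line stand-in for a real forward model.

```python
import random


def toy_simulator(theta: float, rng: random.Random) -> float:
    """Hypothetical forward model: a noisy observation of 3 * theta."""
    return 3.0 * theta + rng.gauss(0.0, 0.1)


def simulate_pairs(n: int, seed: int):
    """SimulationDataset-style generator: sample the prior, run the simulator."""
    rng = random.Random(seed)
    for _ in range(n):
        theta = rng.uniform(-1.0, 1.0)          # prior draw
        yield theta, toy_simulator(theta, rng)   # (params, simulated_output)


def emulator_view(pairs):
    """Emulator training: learn params -> output (the forward direction)."""
    return [(theta, x) for theta, x in pairs]


def posterior_view(pairs):
    """Amortized inference: learn output -> params (the inverse direction)."""
    return [(x, theta) for theta, x in pairs]


pairs = list(simulate_pairs(5, seed=0))
forward_xy = emulator_view(pairs)
inverse_xy = posterior_view(pairs)
```

Because the generator is seeded, the same pairs can be regenerated (or cached) once and reused by both training jobs.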

Part 3 — What’s in pipekit-train

Five conceptual pieces, ~700 LOC of framework + adapters.

3.1 Source layout

pipekit-train/
  __init__.py             # public re-exports
  _src/
    dataset.py            # TrainingDataset, CachedDataset, CatalogDataset, SimulationDataset
    loss.py               # Loss Protocol + common losses (MSE, NLL, KL, custom)
    loop.py               # TrainingLoop, ValidationStep, EarlyStopping
    sweep.py              # HyperSweep, ParameterGrid
    callbacks.py          # Callbacks Protocol; common impls (Checkpoint, LogToExperiment)
    adapters/
      __init__.py
      lightning.py        # PyTorch Lightning adapter (extras-gated)
      equinox.py          # Equinox + Optax adapter (extras-gated)
      keras.py            # Keras 3 adapter (extras-gated)

3.2 The dataset abstractions (dataset.py)

from typing import Any, Iterator, Literal

from pipekit import Operator


class TrainingDataset(Operator):
    """Protocol-ish base: yields (input, target) pairs.
    
    Subclasses implement __iter__. The framework gives them caching,
    splitting, content-hashing, and pipekit composition for free.
    """
    seed: int = 0
    split: Literal["train", "val", "test"] = "train"
    
    def __iter__(self) -> Iterator[tuple[Any, Any]]: ...
    def content_hash(self) -> str:
        """Hash that identifies this dataset's contents.
        
        Used for caching and for the regulatory artifact.
        For catalogs: hash(catalog_uri, query, split, seed).
        For simulations: hash(forward_model_config, prior_config, n_samples, seed).
        """

class CachedDataset(TrainingDataset):
    """Disk-backed cache around any TrainingDataset.
    
    First epoch hits the source; subsequent epochs read from disk.
    Cache is keyed on the wrapped dataset's content_hash().
    """
    source: TrainingDataset
    cache_dir: str
    format: Literal["zarr", "parquet", "tfrecord"] = "zarr"

class CatalogDataset(TrainingDataset):
    """Yields (preprocessed_carrier, label) pairs from a catalog.
    
    Used for direct supervised training:
    - geocatalog over labeled scenes
    - geopatcher.SpatialPatcher for tiling
    - preprocessing pipeline applied per-chip
    """
    catalog: GeoCatalog | StateCatalog
    preprocess: Operator              # the inference-time preprocess pipeline
    target_op: Operator               # extracts label from the catalog row
    sampler: SpatialSampler | None    # geopatcher sampler for tiling

class SimulationDataset(TrainingDataset):
    """Yields (parameters, simulator_output) pairs.
    
    Used for emulator and amortized-inference training:
    - sample parameters from a prior
    - evaluate the forward model (pipekit-cycle.ForwardModel)
    - return (params, output)
    
    The output may be a full trajectory (via Cycle) or a single state
    (single forward step), depending on what the emulator should learn.
    """
    forward_model: ForwardModel       # from pipekit-cycle
    prior: Operator                   # samples parameter realizations
    n_samples: int
    cycle: Operator | None = None     # optional pipekit_cycle.Cycle for trajectories

The SimulationDataset is the bridge to pipekit-cycle. It’s what makes the “emulator replaces forward model” loop close: train the emulator on simulations of the forward model; the trained emulator drops into the same pipekit-cycle.Cycle with the same observation operators.
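A minimal sketch of `content_hash()` for a simulation dataset, following the `TrainingDataset` docstring (hash of forward-model config, prior config, `n_samples` and seed) and assuming both configs are JSON-serializable dicts. SHA-256 over canonical JSON with sorted keys is our choice, not something the report specifies, and `SimulationDatasetSketch` is hypothetical.

```python
import hashlib
import json
from dataclasses import dataclass


def _stable_hash(payload: dict) -> str:
    """SHA-256 of canonical JSON, so dict key order does not change the hash."""
    blob = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(blob.encode("utf-8")).hexdigest()


@dataclass(frozen=True)
class SimulationDatasetSketch:
    forward_model_config: dict
    prior_config: dict
    n_samples: int
    seed: int = 0

    def content_hash(self) -> str:
        # Mirrors hash(forward_model_config, prior_config, n_samples, seed).
        return _stable_hash({
            "kind": "simulation",
            "forward_model": self.forward_model_config,
            "prior": self.prior_config,
            "n_samples": self.n_samples,
            "seed": self.seed,
        })


chem = {"species": ["CH4", "NH3"], "dt": 3600.0}
chem_reordered = {"dt": 3600.0, "species": ["CH4", "NH3"]}
prior = {"distribution": "climatology"}

a = SimulationDatasetSketch(chem, prior, n_samples=10_000)
b = SimulationDatasetSketch(chem_reordered, prior, n_samples=10_000)
c = SimulationDatasetSketch(chem, prior, n_samples=10_000, seed=1)
```

`a` and `b` describe the same simulations and hash identically; changing only the seed (`c`) yields a different hash, so a cache or registry keyed on it will not confuse the two runs.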

3.3 The loss protocol (loss.py)

from typing import Any, Protocol, runtime_checkable

from pipekit import Operator


@runtime_checkable
class Loss(Protocol):
    """Loss function. Computed per-batch by the underlying training tool."""
    def __call__(self, predicted: Any, target: Any) -> float: ...

# Common implementations
class MSE(Operator): pass
class NLL(Operator):
    """Negative log-likelihood for distributional outputs.
    
    Used for emulator training when the emulator predicts a distribution
    (mean + variance), or for amortized inference when the posterior is
    explicitly a density.
    """

class KL(Operator):
    """KL divergence — for variational training."""

class Composite(Operator):
    """Weighted sum: loss = sum(w_i * L_i). Used for multi-objective."""
    components: list[tuple[float, Loss]]

The loss is a pipekit.Operator so it composes with the rest of the pipeline. The training tool (Lightning, Optax) sees it through its adapter as a callable.
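A minimal sketch of the `Loss` protocol with `MSE` and a `Composite` weighted sum, assuming losses receive flat lists of floats (a real adapter would pass backend tensors). The class names mirror the report but the bodies are illustrative, not pipekit-train's implementation; `AbsError` is a hypothetical second objective.

```python
from dataclasses import dataclass
from typing import Protocol, Sequence, runtime_checkable


@runtime_checkable
class Loss(Protocol):
    def __call__(self, predicted: Sequence[float], target: Sequence[float]) -> float: ...


@dataclass
class MSE:
    """Mean squared error."""
    def __call__(self, predicted, target) -> float:
        return sum((p - t) ** 2 for p, t in zip(predicted, target)) / len(target)


@dataclass
class AbsError:
    """Hypothetical second objective: mean absolute error."""
    def __call__(self, predicted, target) -> float:
        return sum(abs(p - t) for p, t in zip(predicted, target)) / len(target)


@dataclass
class Composite:
    """loss = sum(w_i * L_i) over (weight, loss) components."""
    components: list[tuple[float, Loss]]

    def __call__(self, predicted, target) -> float:
        return sum(w * loss(predicted, target) for w, loss in self.components)


loss = Composite([(0.5, MSE()), (2.0, AbsError())])
value = loss([1.0, 2.0], [0.0, 4.0])  # 0.5 * 2.5 + 2.0 * 1.5
```

Because `Composite` satisfies the same protocol as its components, composites can nest, and an adapter only ever needs to call one object.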

3.4 The training loop (loop.py)

The headline composable. The training loop is itself a StatefulOperator (from Report 2 Group M) — its carry state is (model, optimizer_state, step, epoch, metrics).

class TrainingLoop(StatefulOperator):
    """Train a model_op on a dataset, producing a trained model_op.
    
    Delegates to one of the backend adapters (Lightning, Equinox+Optax, Keras 3).
    Logs to experiment tracker via callbacks (Report 12).
    
    Output: a trained model_op (pipekit-array.ModelOp or pipekit-jax.JaxModelOp)
    that drops into inference pipelines.
    """
    model_op: Operator                # untrained; ModelOp / JaxModelOp shaped
    dataset: TrainingDataset
    loss: Loss
    optimizer_config: dict            # e.g. {"name": "adam", "lr": 1e-3}
    n_epochs: int
    batch_size: int
    val_dataset: TrainingDataset | None = None   # None → no validation pass
    backend: Literal["lightning", "equinox", "keras"] = "lightning"
    callbacks: list[Callback] | None = None   # None → no callbacks; avoid mutable default
    seed: int = 0

class ValidationStep(Operator):
    """Compute validation metrics on a held-out dataset.
    
    Used inside TrainingLoop and after for final-model evaluation.
    """
    model_op: Operator
    dataset: TrainingDataset
    metrics: list[Operator]           # composable: pa.metrics.RMSE(), etc.

class EarlyStopping(Operator):
    """Stop training when a monitored metric hasn't improved for N epochs.
    
    Implemented as a callback (works with all backends).
    """
    metric: str
    patience: int = 10
    mode: Literal["min", "max"] = "min"
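A sketch of the `EarlyStopping` rule (stop once `metric` has not improved for `patience` epochs, in `min` or `max` mode), written as a plain per-epoch update so it could sit behind any backend's callback API. `EarlyStoppingSketch`, its `update` method and the `stopping_epoch` helper are hypothetical; improvement here is strict, with no minimum delta.

```python
from dataclasses import dataclass, field
from typing import Literal, Optional


@dataclass
class EarlyStoppingSketch:
    """Patience-based early stopping; mirrors EarlyStopping's fields."""
    metric: str
    patience: int = 10
    mode: Literal["min", "max"] = "min"
    best: Optional[float] = field(default=None, init=False)
    bad_epochs: int = field(default=0, init=False)

    def _improved(self, value: float) -> bool:
        if self.best is None:
            return True
        return value < self.best if self.mode == "min" else value > self.best

    def update(self, metrics: dict) -> bool:
        """Feed one epoch's metrics; return True once training should stop."""
        value = metrics[self.metric]
        if self._improved(value):
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def stopping_epoch(history, metric="val_loss", patience=10, mode="min"):
    """Return the 0-based epoch at which training would stop, or None."""
    stopper = EarlyStoppingSketch(metric=metric, patience=patience, mode=mode)
    for epoch, value in enumerate(history):
        if stopper.update({metric: value}):
            return epoch
    return None
```

For example, `stopping_epoch([1.0, 0.8, 0.9, 0.85, 0.81], patience=3)` stops at epoch 4: the best value is at epoch 1 and the next three epochs fail to beat it.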

3.5 The backend adapters (adapters/)

Each adapter ~150 LOC. Takes a TrainingLoop config and runs the underlying tool’s training loop. Translates pipekit constructs (TrainingDataset, Loss, ValidationStep) into the tool’s idioms.

# adapters/lightning.py — for PyTorch models
def run_lightning(loop: TrainingLoop) -> Operator:
    """Translate TrainingLoop into a Lightning LitModel + Trainer.
    
    Returns the trained model_op (pipekit-array.ModelOp wrapping the
    trained Lightning module).
    """
    ...

# adapters/equinox.py — for JAX models
def run_equinox(loop: TrainingLoop) -> Operator:
    """Equinox+Optax training loop. Returns a pipekit-jax.JaxModelOp
    wrapping the trained eqx.Module.
    """
    ...

# adapters/keras.py — for multi-backend Keras 3 models
def run_keras(loop: TrainingLoop) -> Operator:
    """Keras 3 training. Backend (TF/JAX/torch) follows Keras 3 config.
    Returns a pipekit-array.ModelOp wrapping the trained Keras model.
    """
    ...

The adapter’s job: translate. It does NOT reimplement training logic. PyTorch Lightning’s Trainer.fit is well-tested and handles most operational concerns (distributed training, mixed precision, checkpointing, profiling). Pipekit-train just wires the pipekit constructs into it.
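A sketch of how a `TrainingLoop` might pick its adapter from the `backend` field and fail early, with an actionable message, when the matching extra is not installed. The `BACKEND_MODULES` table, `resolve_backend` and the wording of the install hint are assumptions; only the backend names and extras come from the report.

```python
import importlib.util
from typing import Callable, Dict

# Hypothetical table: backend name -> (module that must be importable, extra that provides it).
BACKEND_MODULES: Dict[str, tuple[str, str]] = {
    "lightning": ("lightning", "lightning"),
    "equinox": ("equinox", "equinox"),
    "keras": ("keras", "keras"),
}


def resolve_backend(
    backend: str,
    runners: Dict[str, Callable],
    modules: Dict[str, tuple[str, str]] = BACKEND_MODULES,
) -> Callable:
    """Return the adapter's run function, or raise an error naming the right fix."""
    if backend not in runners:
        known = ", ".join(sorted(runners))
        raise ValueError(f"unknown backend {backend!r}; expected one of: {known}")
    module_name, extra = modules[backend]
    # find_spec checks importability without actually importing the heavy package.
    if importlib.util.find_spec(module_name) is None:
        raise ImportError(
            f"backend {backend!r} needs the {module_name!r} package; "
            f"install it with: pip install 'pipekit-train[{extra}]'"
        )
    return runners[backend]
```

In use, `runners` would map `"lightning"`, `"equinox"` and `"keras"` to `run_lightning`, `run_equinox` and `run_keras`. Checking with `find_spec` keeps the minimum install importable while still pointing users at the missing extra rather than at a deep import traceback.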

3.6 Callbacks (callbacks.py)

from typing import Protocol, runtime_checkable

from pipekit import Operator


@runtime_checkable
class Callback(Protocol):
    """Per-epoch / per-step hooks. Adapter translates to backend's callback API."""
    def on_train_start(self, loop, state): ...
    def on_epoch_end(self, loop, state, metrics): ...
    def on_train_end(self, loop, state): ...

class Checkpoint(Operator):
    """Save (model_op, optimizer_state) every N epochs."""
    save_dir: str                     # required field, declared before the defaulted ones
    every_n_epochs: int = 1
    keep_last: int = 3

class LogToExperiment(Operator):
    """Hook into pipekit-experiment's experiment tracker.
    
    Per-step / per-epoch metrics flow to MLflow / W&B / etc. via the
    experiment adapter (Report 12).
    """
    experiment_name: str
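A sketch of the bookkeeping behind `Checkpoint(every_n_epochs=..., keep_last=...)`: which epochs get saved and which saved checkpoints are rotated out. It records epoch numbers instead of writing files; `CheckpointPlan` and the 1-based epoch counting are assumptions.

```python
from collections import deque
from dataclasses import dataclass, field


@dataclass
class CheckpointPlan:
    """Bookkeeping for Checkpoint(every_n_epochs=..., keep_last=...)."""
    every_n_epochs: int = 1
    keep_last: int = 3
    kept: deque = field(default_factory=deque)
    deleted: list = field(default_factory=list)

    def on_epoch_end(self, epoch: int) -> None:
        """Save on every N-th epoch (1-based), then drop the oldest beyond keep_last."""
        if epoch % self.every_n_epochs != 0:
            return
        self.kept.append(epoch)                       # real code: write (model_op, optimizer_state)
        while len(self.kept) > self.keep_last:
            self.deleted.append(self.kept.popleft())  # real code: delete that checkpoint


plan = CheckpointPlan(every_n_epochs=5, keep_last=3)
for epoch in range(1, 21):
    plan.on_epoch_end(epoch)
```

After 20 epochs the plan has saved at epochs 5, 10, 15 and 20 and kept only the last three, which is the behaviour the `keep_last=3` default implies.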

Part 4 — Worked examples

4.1 Direct supervised training (cloud-mask classifier)

import pipekit as pk
import pipekit_train as pt
import pipekit_array as pa
import geocatalog as gc
import geopatcher as gp
import geotoolz as gz

# The inference-time preprocess pipeline — re-used at train time
preprocess = pk.Sequential([
    gz.radiometry.ToFloat32(),
    gz.radiometry.PercentileClip(p_min=2, p_max=98),
])

# Catalog of labeled scenes (cloud-mask labels from human annotators)
catalog = gc.open_catalog("s3://imeo/labeled-scenes.parquet")

dataset = pt.CatalogDataset(
    catalog=catalog,
    preprocess=preprocess,
    target_op=gz.cloud.LoadLabel(label_key="cloud_mask"),
    sampler=gp.SpatialRegularStride((256, 256)),
)

loop = pt.TrainingLoop(
    model_op=pa.ModelOp(UNet(in_channels=4, out_channels=2)),  # UNet: user-defined torch.nn.Module (not shown)
    dataset=dataset,
    val_dataset=dataset.with_split("val"),
    loss=pt.NLL(),
    optimizer_config={"name": "adam", "lr": 1e-3},
    n_epochs=50,
    batch_size=16,
    backend="lightning",
    callbacks=[
        pt.EarlyStopping(metric="val_loss", patience=5),
        pt.Checkpoint(every_n_epochs=5, save_dir="ckpts/"),
        pt.LogToExperiment(experiment_name="cloud_mask_unet"),
    ],
    seed=42,
)

trained_model_op = loop.run()
# trained_model_op is now a pipekit.Operator, drops into any Sequential
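The example relies on `dataset.with_split("val")` yielding data disjoint from the training split. One way to provide the "splitting for free" promised in `TrainingDataset` is to assign each item to a split from a hash of its key and the seed, so assignments are stable across runs and machines. The `assign_split` helper, the scene-ID keys and the 80/10/10 fractions are hypothetical.

```python
import hashlib


def assign_split(item_key: str, seed: int = 0, val_frac: float = 0.1, test_frac: float = 0.1) -> str:
    """Deterministically map an item key (here a scene ID) to train/val/test."""
    digest = hashlib.sha256(f"{seed}:{item_key}".encode("utf-8")).digest()
    u = int.from_bytes(digest[:8], "big") / 2**64   # roughly uniform float in [0, 1]
    if u < test_frac:
        return "test"
    if u < test_frac + val_frac:
        return "val"
    return "train"


keys = [f"scene-{i:04d}" for i in range(1000)]
splits = {k: assign_split(k, seed=42) for k in keys}
```

Hashing the scene ID rather than the chip ID keeps all chips from one scene in the same split, so overlapping chips cannot leak between training and validation.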

4.2 Emulator training (chemistry transport surrogate)

import pipekit_train as pt
import pipekit_cycle as pc
import pipekit_array as pa   # needed for pa.ModelOp below
from plumax.adapters.pipekit import ChemistryForward

# Forward model — expensive physics
expensive_forward = ChemistryForward(species=["CH4", "NH3"], dt=3600.0)

# Sample initial states from a prior; generate training trajectories
dataset = pt.SimulationDataset(
    forward_model=expensive_forward,
    prior=AtmosphericPrior(distribution="climatology"),
    n_samples=10_000,
    cycle=pc.Cycle(step_op=expensive_forward, n_steps=24),
)

# Cache to disk so we don't regenerate every run
cached = pt.CachedDataset(source=dataset, cache_dir="s3://cache/chemistry_sim/")

# Train a neural emulator on the cached pairs
loop = pt.TrainingLoop(
    model_op=pa.ModelOp(EmulatorNet()),  # neural network
    dataset=cached,
    loss=pt.MSE(),
    optimizer_config={"name": "adam", "lr": 5e-4},
    n_epochs=100,
    batch_size=64,
    backend="lightning",
)

emulator_op = loop.run()
# emulator_op is a NeuralForward-compatible operator
# It drops directly into pipekit-cycle.Cycle as a forward_model replacement
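A toy version of the "emulator replaces forward model" loop: generate one-step pairs from a cheap stand-in simulator, fit a linear emulator by least squares, then run the same fixed-step cycle with either operator. `expensive_step`, `fit_linear_emulator` and `run_cycle` are hypothetical stand-ins for `ChemistryForward`, the trained `EmulatorNet` and `pc.Cycle`.

```python
import random


def expensive_step(state: float) -> float:
    """Stand-in for one step of the physics model (linear decay plus a source)."""
    return 0.9 * state + 0.5


def fit_linear_emulator(pairs):
    """Least-squares fit of next_state = a * state + b; returns a step operator."""
    n = len(pairs)
    mx = sum(s for s, _ in pairs) / n
    my = sum(t for _, t in pairs) / n
    a = sum((s - mx) * (t - my) for s, t in pairs) / sum((s - mx) ** 2 for s, _ in pairs)
    b = my - a * mx
    return lambda state: a * state + b


def run_cycle(step_op, initial_state: float, n_steps: int) -> list[float]:
    """Minimal stand-in for pc.Cycle: apply step_op n_steps times."""
    trajectory = [initial_state]
    for _ in range(n_steps):
        trajectory.append(step_op(trajectory[-1]))
    return trajectory


rng = random.Random(0)
initial_states = [rng.uniform(0.0, 10.0) for _ in range(500)]   # prior draws
pairs = [(s, expensive_step(s)) for s in initial_states]        # simulation pairs
emulator = fit_linear_emulator(pairs)

truth = run_cycle(expensive_step, initial_state=2.0, n_steps=24)
emulated = run_cycle(emulator, initial_state=2.0, n_steps=24)
```

Because the toy physics is linear and noise-free, the fit is essentially exact. A real neural emulator's one-step error compounds over a 24-step rollout, which is why comparing full trajectories, not just one-step pairs, is a useful check.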

4.3 Amortized inference training (neural posterior for plume sources)

import pipekit as pk
import pipekit_train as pt
from plumax.adapters.pipekit import PlumeForward, ColumnObs

# Simulator: (source_params) → simulated downwind concentrations
plume_forward = PlumeForward()
obs_op        = ColumnObs(instrument="TROPOMI")

# Generate (params, simulated_obs) pairs
dataset = pt.SimulationDataset(
    # SBI simulator chain: forward model, then the observation operator,
    # so the network trains on what the instrument would actually see
    forward_model=pk.Sequential([plume_forward, obs_op]),
    prior=SourcePrior(loc_bounds=(-10, 10, 30, 40), strength_loguniform=(1, 1e4)),  # user-supplied prior (not shown)
    n_samples=100_000,
)

# Train a conditional density estimator — pyrox / gaussflowx algorithmic core
from pyrox.adapters.pipekit_train import ConditionalNormalizingFlow

loop = pt.TrainingLoop(
    model_op=ConditionalNormalizingFlow(
        n_dim=4,                       # 4 source params
        condition_dim=...,             # depends on obs shape
        flow_type="masked_autoregressive",
    ),
    dataset=dataset,
    loss=pt.NLL(),                     # train as a likelihood
    optimizer_config={"name": "adam", "lr": 1e-3},
    n_epochs=200,
    batch_size=128,
    backend="equinox",                 # JAX-traceable
)

posterior_op = loop.run()
# posterior_op is amortized: takes observations, samples posterior parameters
# Drops into any inference pipeline that needs "given obs, give me sources"
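A toy amortized-inference run where the answer is known. With prior θ ~ N(0, 1) and simulator x = θ + ε, ε ~ N(0, σ²), the exact posterior is N(x / (1 + σ²), σ² / (1 + σ²)). Fitting a linear-Gaussian conditional density q(θ | x) = N(a·x + b, s²) by maximum likelihood (minimizing NLL, which for this family reduces to least squares plus the mean squared residual) on simulated pairs recovers that posterior, and the fitted estimator then serves any new observation without retraining. The closed-form fit stands in for the `ConditionalNormalizingFlow`; all names are hypothetical.

```python
import math
import random


def simulate(n: int, sigma: float, seed: int):
    """(theta, x) pairs: theta ~ N(0, 1) prior, x = theta + N(0, sigma^2) noise."""
    rng = random.Random(seed)
    pairs = []
    for _ in range(n):
        theta = rng.gauss(0.0, 1.0)
        pairs.append((theta, theta + rng.gauss(0.0, sigma)))
    return pairs


def fit_gaussian_posterior(pairs):
    """Maximum-likelihood fit of q(theta | x) = N(a * x + b, s2)."""
    n = len(pairs)
    mx = sum(x for _, x in pairs) / n
    mt = sum(t for t, _ in pairs) / n
    a = sum((x - mx) * (t - mt) for t, x in pairs) / sum((x - mx) ** 2 for _, x in pairs)
    b = mt - a * mx
    s2 = sum((t - (a * x + b)) ** 2 for t, x in pairs) / n
    return a, b, s2


def sample_posterior(x_obs: float, a: float, b: float, s2: float, n_draws: int, seed: int):
    """Amortized inference: draw theta | x_obs with no further training."""
    rng = random.Random(seed)
    return [rng.gauss(a * x_obs + b, math.sqrt(s2)) for _ in range(n_draws)]


sigma = 1.0
a, b, s2 = fit_gaussian_posterior(simulate(50_000, sigma, seed=0))
# Exact values for sigma = 1: slope 1 / (1 + sigma^2) = 0.5, variance sigma^2 / (1 + sigma^2) = 0.5
draws = sample_posterior(1.2, a, b, s2, n_draws=2_000, seed=1)
```

The expensive part (simulating and fitting) happens once; each new observation then costs only a few draws, which is the "train once, infer fast" trade the report describes.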

Part 5 — Dependencies and optional extras

[project]
name = "pipekit-train"
version = "0.1.0"
dependencies = [
    "pipekit>=0.1",
    "numpy>=2.0",
]

[project.optional-dependencies]
# Backend training tools — pick one
lightning = ["lightning>=2.4", "torch>=2.0"]
equinox   = ["equinox>=0.11", "optax>=0.2", "jax>=0.4.20"]
keras     = ["keras>=3.0"]

# Carrier libraries — pick what's needed
array     = ["pipekit-array>=0.1"]
jax       = ["pipekit-jax>=0.1"]   # deferred

# Data sources
catalog   = ["geocatalog>=0.1"]
cycle     = ["pipekit-cycle>=0.1"]   # for SimulationDataset

# Experiment tracking
experiment = ["pipekit-experiment>=0.1"]

Minimum install gives the TrainingDataset / Loss / TrainingLoop machinery but no backend; at least one backend extra must be installed. Same pattern as pipekit-array[numpy]: be explicit about which backend you want.

Part 6 — The training-pipeline reproducibility story

A training run produces a model. The model’s provenance must be recoverable:

  1. Training pipeline YAML. The TrainingLoop config serializes via Operator.state like any other operator. Includes model architecture (model_op’s config), dataset config, loss, optimizer, seed.

  2. Dataset content hash. TrainingDataset.content_hash() returns a hash identifying exactly which data the loop trained on. For catalog datasets: hash of catalog + query + split + seed. For simulation datasets: hash of forward_model + prior + n_samples + seed.

  3. Trained model artifact. Output of loop.run() is the trained ModelOp, content-addressed by (training_pipeline_hash, dataset_content_hash).

  4. Stored in model registry. pipekit-experiment (Report 12) handles registry storage.
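A sketch of step 3, content-addressing a trained model by `(training_pipeline_hash, dataset_content_hash)`. It assumes the `TrainingLoop` config is available as a JSON-serializable dict standing in for `Operator.state`; `config_hash` and `model_key` are hypothetical helpers, and SHA-256 over canonical JSON is our choice.

```python
import hashlib
import json


def config_hash(config: dict) -> str:
    """Hash a training-pipeline config independently of dict key order."""
    blob = json.dumps(config, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(blob.encode("utf-8")).hexdigest()


def model_key(training_config: dict, dataset_content_hash: str) -> str:
    """Content address for a trained model: (pipeline hash, dataset hash)."""
    pair = f"{config_hash(training_config)}:{dataset_content_hash}"
    return hashlib.sha256(pair.encode("utf-8")).hexdigest()


config = {
    "model_op": {"name": "EmulatorNet"},
    "loss": "MSE",
    "optimizer_config": {"name": "adam", "lr": 5e-4},
    "n_epochs": 100,
    "batch_size": 64,
    "backend": "lightning",
    "seed": 0,
}
key = model_key(config, dataset_content_hash="ab12cd")   # hypothetical dataset hash
```

Any change to the pipeline (including the seed) or to the data produces a new key, while reordering the same config does not.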

The reproducibility artifact for a trained model is:

# methane_emulator_v3.training-artifact.yaml
training_pipeline: ...                  # the TrainingLoop YAML
dataset_hash: "ab12cd..."               # what was trained on
training_run_id: "mlflow://runs/xyz"    # experiment tracker reference
trained_model_uri: "s3://models/methane_emulator_v3.pt"
trained_model_hash: "ef34gh..."         # content hash of trained weights
backend: "lightning"
backend_version: "2.4.0"
hardware: "1xA100"
duration_seconds: 7200

This is the analog of pipekit.repro.Artifact (use case 9, regulatory) but for training runs. Pipekit-experiment ships it.

Part 7 — Honest tradeoffs

7.1 What gets better

  1. Train and infer with the same operator config. The model trained at v3 is loaded as a pipekit.Operator; it drops into inference pipelines without rewrite.

  2. Training is reproducible. Dataset content hash + training pipeline YAML + seed → the same trained model: byte-identical when the backend runs in deterministic mode, otherwise equivalent up to GPU and library nondeterminism.

  3. Emulator training closes the loop with pipekit-cycle. SimulationDataset(forward_model=..., cycle=Cycle(...)) is one composable spec.

  4. Multi-backend. Same TrainingLoop config drives Lightning, Equinox+Optax, or Keras 3 — pick by user preference.

  5. Hyperparameter sweeps inherit experiment tracking. HyperSweep configs go through the same LogToExperiment callback.
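`HyperSweep` and `ParameterGrid` live in `sweep.py` but are not spelled out here. A minimal sketch of the grid half, assuming a sweep is a base `TrainingLoop` config plus a dict of value lists over top-level fields, expanded into one full config and run name per combination; nested keys such as `optimizer_config["lr"]` would need a path syntax not shown. `parameter_grid` and the run-naming scheme are hypothetical.

```python
import itertools
from typing import Any, Dict, Iterator, List


def parameter_grid(base: Dict[str, Any], grid: Dict[str, List[Any]]) -> Iterator[Dict[str, Any]]:
    """Yield one config per combination of grid values (Cartesian product)."""
    names = sorted(grid)                       # fixed order -> reproducible sweeps
    for values in itertools.product(*(grid[name] for name in names)):
        config = dict(base)                    # copy, so the base config is untouched
        config.update(zip(names, values))
        config["run_name"] = "-".join(f"{n}={v}" for n, v in zip(names, values))
        yield config


base = {"backend": "lightning", "seed": 42}
configs = list(parameter_grid(base, {"batch_size": [16, 64], "n_epochs": [50, 100]}))
```

Each child config is an ordinary `TrainingLoop` config, so it logs through the same `LogToExperiment` callback and gets its own content hash.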

7.2 What gets harder

  1. The adapter layer adds indirection. Users debugging a training failure may need to look at both pipekit-train and Lightning. Mitigation: clear error messages that point at the right layer; backend-specific debugging guides.

  2. Caching invariants are tricky. CachedDataset assumes the underlying source is deterministic given its config + seed. If a catalog mutates underneath while its URI and query stay the same, its content hash does not change and the cache silently goes stale. Mitigation: opt-in cache; explicit invalidation API.

  3. Multi-backend coverage is uneven. Lightning is the default and best-tested. Equinox/Optax has rougher edges. Keras 3 is newer. Mitigation: be honest about which backends are battle-tested; mark Equinox/Keras adapters as experimental in v0.1.

  4. SimulationDataset can be very expensive. Generating 100K plume simulations is hours of compute. Mitigation: aggressive caching (CachedDataset on Zarr); orchestrator-driven generation via pipekit.parallel.ProcessMap.
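A sketch of the `CachedDataset` contract and its main caveat: the cache is keyed on the source's `content_hash()`, which identifies data by configuration, so it stays correct only while that hash changes whenever the data does. It stores pairs as JSON in a local directory instead of Zarr; `CountingSource`, `CachedDatasetSketch` and `invalidate()` are hypothetical.

```python
import json
import os
import tempfile
from dataclasses import dataclass


@dataclass
class CountingSource:
    """Hypothetical deterministic source that counts full passes over its data."""
    n: int
    seed: int = 0
    reads: int = 0

    def content_hash(self) -> str:
        # Identifies the data by config only, like hash(catalog_uri, query, split, seed).
        return f"counting-n{self.n}-seed{self.seed}"

    def __iter__(self):
        self.reads += 1
        for i in range(self.n):
            yield (i + self.seed, 2 * (i + self.seed))


@dataclass
class CachedDatasetSketch:
    source: CountingSource
    cache_dir: str

    def _path(self) -> str:
        return os.path.join(self.cache_dir, f"{self.source.content_hash()}.json")

    def __iter__(self):
        path = self._path()
        if not os.path.exists(path):              # first epoch: read source, write cache
            with open(path, "w", encoding="utf-8") as f:
                json.dump(list(self.source), f)
        with open(path, encoding="utf-8") as f:   # every epoch: serve from disk
            for x, y in json.load(f):
                yield x, y

    def invalidate(self) -> None:
        """Explicit invalidation for when the data changed but the hash did not."""
        path = self._path()
        if os.path.exists(path):
            os.remove(path)


with tempfile.TemporaryDirectory() as tmp:
    source = CountingSource(n=3)
    cached = CachedDatasetSketch(source, tmp)
    epoch1 = list(cached)
    epoch2 = list(cached)
    reads_before_invalidate = source.reads   # 1: the second epoch came from disk
    cached.invalidate()
    list(cached)
    reads_after_invalidate = source.reads    # 2: the cache was rebuilt from the source
```

The source is read once however many epochs run; only an explicit `invalidate()` (or a changed hash) sends the next epoch back to the source.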

7.3 What doesn’t fit and isn’t tried

Part 8 — Effort and timing

8.1 Effort

Total: ~2-3 weeks for v0.1 (Lightning + Equinox).

8.2 Timing

Ship after pipekit-cycle (Report 10) is stable. The most valuable feature (SimulationDataset for emulator training) depends on ForwardModel from pipekit-cycle.

Realistic timeline: v0.3 of the ecosystem.

8.3 What this unblocks

  1. Emulator training as a pipeline. The trained emulator drops into pipekit-cycle as a forward model.

  2. Amortized inference training. Train once, infer fast forever.

  3. Reproducible model versioning. Every trained model has a hash that traces back to its training data and pipeline.

  4. The “ML at every level” claim becomes operational. Training L0/L1 cloud detectors, L2 learnable retrievals, L3 ML gap-filling, L4 neural emulators — all use the same machinery, same registry, same composability.

Part 9 — Recommendation

Ship pipekit-train as a separate sister package. Signals:

The package lives in the same monorepo: packages/pipekit-train/. Sibling of pipekit-cycle and pipekit-experiment.

This is the trainer-side counterpart to pipekit-cycle. Together they form the ML-augmented L3-L4 stack: pipekit-cycle for the inference loops, pipekit-train for the training loops, both producing artifacts that compose into the same operator graph machinery. Without it, training stays research code outside the framework. With it, the train→serve loop is tight, versioned, and composable.