Report 11 — pipekit-train: training pipelines for emulators and amortized inference
| Status | Scoping proposal — committed |
| Reading time | ~22 min |
| Decisions locked in | Thin orchestration layer over existing training tools (Lightning, Equinox+Optax, Keras), not a fourth training-loop implementation. Monorepo development. Trained models are first-class pipekit.Operators. |
| Audience | Anyone reviewing the train-side pipeline machinery |
| Companion reports | Report 10 (pipekit-cycle), Report 12 (pipekit-experiment), Report 13 (statecatalog); existing reports for context |
Part 1 — Where pipekit-train sits in the stack
Domain libraries     geotoolz │ xr_toolz
                          ▲
                          │
Infrastructure     ┌─ pipekit-cycle ────────┐
                   │  pipekit-train         │  ← this report
                   │  pipekit-experiment    │
                   └─ statecatalog ─────────┘
                          ▲
                          │
Framework               pipekit ◄── pipekit-array
                          ▲
                          │
Substrate            georeader │ xarray ecosystem

pipekit-train is a thin orchestration layer, not a training-loop implementation. PyTorch Lightning, Equinox+Optax and Keras 3 already do the heavy training-loop work; a fourth implementation would only duplicate well-tested machinery. What’s missing is the glue that:
Generates training data from forward models (via pipekit-cycle) or catalogs (via geocatalog)
Versions the training config alongside the pipeline
Produces trained models as pipekit.Operators that drop into inference pipelines
Logs to experiment trackers (via pipekit-experiment)
Three observations:
The training pipeline is a YAML artifact, same as inference pipelines. A TrainingLoop config is just another operator graph. The “training run” is reproducible from the YAML + data hash + seed (modulo backend nondeterminism).
The trained model is itself an operator. After training, the output is a pipekit.Operator (probably pipekit-array.ModelOp for traditional models, pipekit-jax.JaxModelOp for differentiable JAX models). It has the same composability as any other operator.
The training-loop math is delegated. Lightning, Equinox+Optax and Keras 3 each have their own well-tested training-loop machinery. Pipekit-train wraps these; it doesn’t reimplement them.
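The second observation can be made concrete with a minimal, self-contained sketch in plain Python. FnOperator, Sequential and fit_linear are hypothetical stand-ins, not the pipekit API. The sketch only illustrates that the output of training exposes the same call interface as any other pipeline step, so it composes without glue code.

```python
from dataclasses import dataclass
from typing import Any, Callable, Sequence


@dataclass(frozen=True)
class FnOperator:
    """Hypothetical stand-in for pipekit.Operator: a named callable."""
    name: str
    fn: Callable[[Any], Any]

    def __call__(self, x: Any) -> Any:
        return self.fn(x)


@dataclass(frozen=True)
class Sequential:
    """Hypothetical stand-in for pk.Sequential: applies operators in order."""
    ops: Sequence[FnOperator]

    def __call__(self, x: Any) -> Any:
        for op in self.ops:
            x = op(x)
        return x


def fit_linear(xs: list, ys: list) -> FnOperator:
    """'Train' y = a*x + b by least squares; the result is itself an operator."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    a = sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / sum((x - mx) ** 2 for x in xs)
    b = my - a * mx
    return FnOperator("trained_linear", lambda x: a * x + b)


# Train, then drop the trained model into an inference pipeline unchanged.
model_op = fit_linear([0.0, 1.0, 2.0, 3.0], [1.0, 3.0, 5.0, 7.0])
pipeline = Sequential([FnOperator("scale", lambda x: x / 10.0), model_op])
print(pipeline(20.0))  # prints 5.0
```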
Part 2 — Three training shapes this enables
The L0-L4 vision needs three categories of training pipeline, all unified under the same framework but with different data shapes and loss patterns.
2.1 Direct supervised training (data exists)
Standard ML: pairs of (input, target) come from a catalog of observations + labels.
Cloud-mask classifier trained on labeled scenes
Super-resolution model trained on co-registered pairs
Retrieval network trained on (radiance, geophysical-variable) labels from in-situ matchups
Data source: CatalogDataset — iterates a catalog, applies a preprocessing pipeline, yields (x, y) pairs.
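The following minimal sketch shows what CatalogDataset’s iteration amounts to. It assumes a catalog query result is simply a sequence of row dicts with hypothetical "data", "label" and "split" keys. catalog_pairs is an illustrative name, not the proposed API. The real class would also provide tiling through the geopatcher sampler and content hashing.

```python
from typing import Any, Callable, Iterable, Iterator


def catalog_pairs(
    rows: Iterable[dict],
    preprocess: Callable[[Any], Any],
    target_op: Callable[[dict], Any],
    split: str = "train",
) -> Iterator[tuple]:
    """Yield (preprocessed_input, label) pairs for the rows in one split."""
    for row in rows:
        if row["split"] != split:
            continue                    # splitting: keep only the requested split
        x = preprocess(row["data"])     # same preprocess as at inference time
        y = target_op(row)              # label extracted from the catalog row
        yield x, y


rows = [
    {"data": [1.0, 5.0], "label": 0, "split": "train"},
    {"data": [2.0, 8.0], "label": 1, "split": "val"},
    {"data": [3.0, 9.0], "label": 1, "split": "train"},
]
pairs = list(catalog_pairs(
    rows,
    preprocess=lambda values: [v / 10.0 for v in values],
    target_op=lambda row: row["label"],
))
```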
2.2 Emulator training (simulator exists)
Surrogate-model training: pairs are generated by running a forward model (ForwardModel from pipekit-cycle).
Neural radiative transfer surrogate (cheap NN replacing expensive RT calculation)
Chemistry emulator (NN replacing GEOS-Chem forward step)
Plume dispersion emulator (NN replacing FLEXPART)
Data source: SimulationDataset — wraps a ForwardModel, samples parameters from a prior, evaluates the model, yields (params, simulated_output) pairs.
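The SimulationDataset idea can be sketched in a few lines. This sketch assumes that a prior is a function of a NumPy random generator and that the forward model is a plain function of a parameter vector. simulation_pairs, toy_forward and toy_prior are hypothetical names. The toy simulator, an exponential decay sampled at five times, stands in for an expensive physics model.

```python
import numpy as np


def simulation_pairs(forward_model, sample_prior, n_samples: int, seed: int = 0):
    """Yield (params, simulator_output) pairs: sample the prior, run the model."""
    rng = np.random.default_rng(seed)   # seeded, so the pairs are reproducible
    for _ in range(n_samples):
        params = sample_prior(rng)
        yield params, forward_model(params)


# Toy simulator: params = (amplitude, rate), observed at five times.
times = np.linspace(0.0, 4.0, 5)


def toy_forward(params):
    amplitude, rate = params
    return amplitude * np.exp(-rate * times)


def toy_prior(rng):
    # Independent uniform priors on amplitude in [0.5, 2] and rate in [0.1, 1].
    return rng.uniform(low=[0.5, 0.1], high=[2.0, 1.0])


pairs = list(simulation_pairs(toy_forward, toy_prior, n_samples=3, seed=42))
```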
2.3 Amortized inference training (likelihood-free)
Simulation-based inference (SBI) / neural posterior estimation: train a conditional density estimator on (parameters, simulated_observations) pairs so that at inference time it can produce p(parameters | observations) for any new observation.
Neural posterior for plume source attribution (given downwind concentrations, sample sources)
Neural likelihood for retrieval inversion
Neural ratio estimator for model comparison
Data source: SimulationDataset again — same as emulator training but the network learns the inverse mapping. Algorithmic core from pyrox / gaussflowx.
Three categories, two data shapes: catalog-derived (input, target) pairs and simulation-derived (params, output) pairs. The same training-loop machinery serves all three; what differs is what sits at the input and output ends.
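The role swap between emulator and amortized-inference training fits in a few lines. to_training_pair is a hypothetical helper, not part of the proposed API. It shows only that both tasks consume identical simulated pairs and differ in which end becomes the network input. For a conditional density estimator, read (network_input, target) as (condition, variable whose density is learned).

```python
from typing import Any, Tuple


def to_training_pair(params: Any, sim_output: Any, task: str) -> Tuple[Any, Any]:
    """Arrange one simulated (params, output) pair as (network_input, target)."""
    if task == "emulator":
        return params, sim_output   # learn params -> output (replace the simulator)
    if task == "amortized":
        return sim_output, params   # learn output -> params (invert the simulator)
    raise ValueError(f"unknown task: {task!r}")


# One simulated pair, arranged for each task.
params, sim_output = (1.0, 0.5), [1.0, 0.61, 0.37]
emulator_xy = to_training_pair(params, sim_output, "emulator")
amortized_xy = to_training_pair(params, sim_output, "amortized")
```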
Part 3 — What’s in pipekit-train
Five conceptual pieces, ~700 LOC of framework + adapters.
3.1 Source layout
pipekit-train/
    __init__.py          # public re-exports
    _src/
        dataset.py       # TrainingDataset, CachedDataset, CatalogDataset, SimulationDataset
        loss.py          # Loss Protocol + common losses (MSE, NLL, KL, custom)
        loop.py          # TrainingLoop, ValidationStep, EarlyStopping
        sweep.py         # HyperSweep, ParameterGrid
        callbacks.py     # Callback Protocol; common impls (Checkpoint, LogToExperiment)
    adapters/
        __init__.py
        lightning.py     # PyTorch Lightning adapter (extras-gated)
        equinox.py       # Equinox + Optax adapter (extras-gated)
        keras.py         # Keras 3 adapter (extras-gated)
3.2 The dataset abstractions (dataset.py)
from typing import Any, Iterator, Literal

# Operator, GeoCatalog, StateCatalog, SpatialSampler and ForwardModel come from
# pipekit, geocatalog, statecatalog, geopatcher and pipekit-cycle respectively.


class TrainingDataset(Operator):
    """Protocol-ish base: yields (input, target) pairs.

    Subclasses implement __iter__. The framework gives them caching,
    splitting, content-hashing, and pipekit composition for free.
    """
    seed: int = 0
    split: Literal["train", "val", "test"] = "train"

    def __iter__(self) -> Iterator[tuple[Any, Any]]: ...

    def content_hash(self) -> str:
        """Hash that identifies this dataset's contents.

        Used for caching and for the regulatory artifact.
        For catalogs: hash(catalog_uri, query, split, seed).
        For simulations: hash(forward_model_config, prior_config, n_samples, seed).
        """


class CachedDataset(TrainingDataset):
    """Disk-backed cache around any TrainingDataset.

    First epoch hits the source; subsequent epochs read from disk.
    Cache is keyed on the wrapped dataset's content_hash().
    """
    source: TrainingDataset
    cache_dir: str
    format: Literal["zarr", "parquet", "tfrecord"] = "zarr"


class CatalogDataset(TrainingDataset):
    """Yields (preprocessed_carrier, label) pairs from a catalog.

    Used for direct supervised training:
      - geocatalog over labeled scenes
      - geopatcher.SpatialPatcher for tiling
      - preprocessing pipeline applied per-chip
    """
    catalog: GeoCatalog | StateCatalog
    preprocess: Operator                    # the inference-time preprocess pipeline
    target_op: Operator                     # extracts label from the catalog row
    sampler: SpatialSampler | None = None   # geopatcher sampler for tiling; None = whole scenes


class SimulationDataset(TrainingDataset):
    """Yields (parameters, simulator_output) pairs.

    Used for emulator and amortized-inference training:
      - sample parameters from a prior
      - evaluate the forward model (pipekit-cycle.ForwardModel)
      - return (params, output)
    The output may be a full trajectory (via Cycle) or a single state
    (single forward step), depending on what the emulator should learn.
    """
    forward_model: ForwardModel             # from pipekit-cycle
    prior: Operator                         # samples parameter realizations
    n_samples: int
    cycle: Operator | None = None           # optional pipekit_cycle.Cycle for trajectories
The SimulationDataset is the bridge to pipekit-cycle. It is what closes the “emulator replaces forward model” loop: train the emulator on simulations of the forward model, and the trained emulator drops into the same pipekit-cycle.Cycle with the same observation operators.
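The content_hash() docstring lists which fields feed each hash, and CachedDataset is keyed on that hash. Both behaviours can be sketched under one assumption: a dataset’s config is expressible as a JSON-serializable dict. content_hash and CachedPairs are hypothetical names. To stay dependency-free, the sketch pickles pairs to a local directory instead of writing Zarr.

```python
import hashlib
import json
import os
import pickle
import tempfile
from typing import Any, Callable, Iterator


def content_hash(config: dict) -> str:
    """Stable SHA-256 of a JSON-serializable config; key order does not matter."""
    canonical = json.dumps(config, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()


class CachedPairs:
    """Disk cache around a source of (input, target) pairs, keyed on content_hash.

    The first iteration runs the source and writes the pairs to disk; later
    iterations only read the file. Pickle stands in for Zarr here.
    """

    def __init__(self, make_source: Callable[[], Iterator[tuple]],
                 source_config: dict, cache_dir: str):
        self.make_source = make_source
        self.path = os.path.join(cache_dir, content_hash(source_config) + ".pkl")

    def __iter__(self) -> Iterator[tuple]:
        if not os.path.exists(self.path):
            pairs = list(self.make_source())    # first epoch: hit the source
            with open(self.path, "wb") as f:
                pickle.dump(pairs, f)
        with open(self.path, "rb") as f:
            yield from pickle.load(f)           # every epoch: read from disk


# Usage: a toy source that records how often it is actually evaluated.
calls = []


def toy_source():
    calls.append(1)
    return iter([(0, 0), (1, 1)])


cfg = {"forward_model": "toy", "n_samples": 2, "seed": 0}
with tempfile.TemporaryDirectory() as cache_dir:
    ds = CachedPairs(toy_source, cfg, cache_dir)
    first, second = list(ds), list(ds)   # the source runs only once
```

Changing any hashed field, such as the seed, yields a new key and therefore a fresh cache file. A catalog that mutates under an unchanged query does not, which is the staleness risk Part 7 flags.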
3.3 The loss protocol (loss.py)
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class Loss(Protocol):
    """Loss function. Computed per-batch by the underlying training tool."""

    # Returns a scalar in the backend's array type (torch.Tensor / jax.Array),
    # not a Python float: converting to float would cut the gradient path.
    def __call__(self, predicted: Any, target: Any) -> Any: ...


# Common implementations
class MSE(Operator):
    """Mean squared error."""


class NLL(Operator):
    """Negative log-likelihood for distributional outputs.

    Used for emulator training when the emulator predicts a distribution
    (mean + variance), or for amortized inference when the posterior is
    explicitly a density.
    """


class KL(Operator):
    """KL divergence (for variational training)."""


class Composite(Operator):
    """Weighted sum: loss = sum(w_i * L_i). Used for multi-objective training."""
    components: list[tuple[float, Loss]]
The loss is a pipekit.Operator, so it composes with the rest of the pipeline. The training tool (Lightning, Optax) sees it through its adapter as a callable.
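Here are illustrative NumPy versions of the losses, written for readability. They are sketches, not the package’s code: inside a real backend they would be written against torch or jax.numpy so gradients can flow. gaussian_nll assumes the model predicts a (mean, variance) pair, matching the NLL docstring’s emulator case, and computes mean(½[log(2πσ²) + (y − μ)²/σ²]). mae is a hypothetical second objective added only to exercise the weighted sum.

```python
import numpy as np


def mse(predicted, target):
    """Mean squared error."""
    return np.mean((np.asarray(predicted) - np.asarray(target)) ** 2)


def mae(predicted, target):
    """Mean absolute error (a second objective for the Composite demo)."""
    return np.mean(np.abs(np.asarray(predicted) - np.asarray(target)))


def gaussian_nll(predicted, target):
    """Mean Gaussian negative log-likelihood; `predicted` is a (mean, variance) pair."""
    mean, var = (np.asarray(p, dtype=float) for p in predicted)
    y = np.asarray(target, dtype=float)
    return np.mean(0.5 * (np.log(2.0 * np.pi * var) + (y - mean) ** 2 / var))


def composite(components):
    """Weighted sum of losses: loss = sum(w_i * L_i)."""
    def loss(predicted, target):
        return sum(w * fn(predicted, target) for w, fn in components)
    return loss


# 1.0 * MSE + 0.5 * MAE on a toy batch.
total = composite([(1.0, mse), (0.5, mae)])([1, 2, 3], [1, 2, 5])
```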
3.4 The training loop (loop.py)
The headline composable. The training loop is itself a StatefulOperator (from Report 2 Group M) — its carry state is (model, optimizer_state, step, epoch, metrics).
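The following hypothetical NumPy sketch makes that carry state concrete. It trains linear least squares with momentum SGD. The carry bundles the model weights, the optimizer state (momentum velocity), the step and epoch counters, and the per-epoch validation losses. Each update returns a new carry rather than mutating the old one. The early-stopping check mirrors EarlyStopping(metric="val_loss", mode="min"). The names, learning rate and momentum are assumptions for illustration; a real adapter delegates all of this to Lightning, Optax or Keras.

```python
from dataclasses import dataclass, replace

import numpy as np


@dataclass(frozen=True)
class Carry:
    """Hypothetical carry state: (model, optimizer_state, step, epoch, metrics)."""
    params: np.ndarray      # model weights w for y ≈ x @ w
    velocity: np.ndarray    # optimizer state for SGD with momentum
    step: int = 0
    epoch: int = 0
    metrics: tuple = ()     # validation loss recorded at the end of each epoch


def sgd_step(carry: Carry, x: np.ndarray, y: np.ndarray,
             lr: float = 0.1, beta: float = 0.5) -> Carry:
    """One momentum-SGD step on the batch MSE; returns a new carry."""
    grad = 2.0 * x.T @ (x @ carry.params - y) / len(y)
    velocity = beta * carry.velocity + grad
    return replace(carry, params=carry.params - lr * velocity,
                   velocity=velocity, step=carry.step + 1)


def train(x, y, x_val, y_val, n_epochs=200, batch_size=8, patience=5, seed=0) -> Carry:
    """Epoch loop over shuffled mini-batches with patience-based early stopping."""
    rng = np.random.default_rng(seed)               # seeded batch shuffling
    carry = Carry(params=np.zeros(x.shape[1]), velocity=np.zeros(x.shape[1]))
    best, since_best = np.inf, 0
    for epoch in range(n_epochs):
        order = rng.permutation(len(y))
        for start in range(0, len(y), batch_size):
            idx = order[start:start + batch_size]
            carry = sgd_step(carry, x[idx], y[idx])
        val_loss = float(np.mean((x_val @ carry.params - y_val) ** 2))
        carry = replace(carry, epoch=epoch + 1, metrics=carry.metrics + (val_loss,))
        # on_epoch_end check, as EarlyStopping(metric="val_loss", mode="min") would do
        if val_loss < best - 1e-12:
            best, since_best = val_loss, 0
        else:
            since_best += 1
            if since_best >= patience:
                break
    return carry


# Noiseless toy problem with true weights (2, -1).
data_rng = np.random.default_rng(1)
true_w = np.array([2.0, -1.0])
x_train = data_rng.normal(size=(64, 2))
x_val = data_rng.normal(size=(16, 2))
carry = train(x_train, x_train @ true_w, x_val, x_val @ true_w)
```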
from typing import Literal


class TrainingLoop(StatefulOperator):
    """Train a model_op on a dataset, producing a trained model_op.

    Delegates to one of the backend adapters (Lightning, Equinox+Optax, Keras 3).
    Logs to the experiment tracker via callbacks (Report 12).
    Output: a trained model_op (pipekit-array.ModelOp or pipekit-jax.JaxModelOp)
    that drops into inference pipelines.
    """
    model_op: Operator                          # untrained; ModelOp / JaxModelOp shaped
    dataset: TrainingDataset
    loss: Loss
    optimizer_config: dict                      # e.g. {"name": "adam", "lr": 1e-3}
    n_epochs: int
    batch_size: int
    val_dataset: TrainingDataset | None = None  # optional, so it may be omitted
    backend: Literal["lightning", "equinox", "keras"] = "lightning"
    callbacks: list[Callback] | None = None     # None → no callbacks; avoid mutable default
    seed: int = 0


class ValidationStep(Operator):
    """Compute validation metrics on a held-out dataset.

    Used inside TrainingLoop and after for final-model evaluation.
    """
    model_op: Operator
    dataset: TrainingDataset
    metrics: list[Operator]                     # composable: pa.metrics.RMSE(), etc.


class EarlyStopping(Operator):
    """Stop training when a monitored metric hasn't improved for N epochs.

    Implemented as a callback (works with all backends).
    """
    metric: str
    patience: int = 10
    mode: Literal["min", "max"] = "min"
3.5 The backend adapters (adapters/)
Each adapter is ~150 LOC. It takes a TrainingLoop config and runs the underlying tool’s training loop, translating pipekit constructs (TrainingDataset, Loss, ValidationStep) into the tool’s idioms.
# adapters/lightning.py: for PyTorch models
def run_lightning(loop: TrainingLoop) -> Operator:
    """Translate TrainingLoop into a LightningModule + Trainer.

    Returns the trained model_op (pipekit-array.ModelOp wrapping the
    trained Lightning module).
    """
    ...


# adapters/equinox.py: for JAX models
def run_equinox(loop: TrainingLoop) -> Operator:
    """Equinox+Optax training loop. Returns a pipekit-jax.JaxModelOp
    wrapping the trained eqx.Module.
    """
    ...


# adapters/keras.py: for multi-backend Keras 3 models
def run_keras(loop: TrainingLoop) -> Operator:
    """Keras 3 training. Backend (TF/JAX/torch) follows Keras 3 config.

    Returns a pipekit-array.ModelOp wrapping the trained Keras model.
    """
    ...
The adapter’s job is to translate. It does NOT reimplement training logic. PyTorch Lightning’s Trainer.fit is well-tested and handles most operational concerns (distributed training, mixed precision, checkpointing, profiling). Pipekit-train just wires the pipekit constructs into it.
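The dispatch from TrainingLoop.backend to an adapter can be as small as a registry. The following is a hypothetical sketch; register_adapter, ADAPTERS and run do not come from the source. It applies the “clear error messages that point at the right layer” mitigation by naming the missing backend instead of failing deep inside a training tool.

```python
from types import SimpleNamespace
from typing import Callable, Dict

# Hypothetical registry: backend name -> adapter function (TrainingLoop -> Operator).
ADAPTERS: Dict[str, Callable] = {}


def register_adapter(name: str):
    """Decorator that records an adapter under a backend name."""
    def wrap(fn: Callable) -> Callable:
        ADAPTERS[name] = fn
        return fn
    return wrap


def run(loop):
    """Dispatch on loop.backend; fail loudly if no adapter is registered."""
    try:
        adapter = ADAPTERS[loop.backend]
    except KeyError:
        known = ", ".join(sorted(ADAPTERS)) or "none"
        raise ValueError(
            f"no adapter for backend={loop.backend!r} (registered: {known}); "
            "is the matching pipekit-train extra installed?"
        ) from None
    return adapter(loop)


@register_adapter("toy")
def run_toy(loop):
    return f"trained with {loop.backend}"


result = run(SimpleNamespace(backend="toy"))
```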
3.6 Callbacks (callbacks.py)
from typing import Protocol, runtime_checkable


@runtime_checkable
class Callback(Protocol):
    """Per-epoch / per-step hooks. Adapter translates to backend's callback API."""
    def on_train_start(self, loop, state): ...
    def on_epoch_end(self, loop, state, metrics): ...
    def on_train_end(self, loop, state): ...


class Checkpoint(Operator):
    """Save (model_op, optimizer_state) every N epochs."""
    save_dir: str               # required, so it precedes the defaulted fields
    every_n_epochs: int = 1
    keep_last: int = 3


class LogToExperiment(Operator):
    """Hook into pipekit-experiment's experiment tracker.

    Per-step / per-epoch metrics flow to MLflow / W&B / etc. via the
    experiment adapter (Report 12).
    """
    experiment_name: str
Part 4 — Worked examples
4.1 Direct supervised training (cloud-mask classifier)
import pipekit as pk
import pipekit_train as pt
import pipekit_array as pa
import geocatalog as gc
import geopatcher as gp
import geotoolz as gz

# The inference-time preprocess pipeline, re-used at train time
preprocess = pk.Sequential([
    gz.radiometry.ToFloat32(),
    gz.radiometry.PercentileClip(p_min=2, p_max=98),
])

# Catalog of labeled scenes (cloud-mask labels from human annotators)
catalog = gc.open_catalog("s3://imeo/labeled-scenes.parquet")

dataset = pt.CatalogDataset(
    catalog=catalog,
    preprocess=preprocess,
    target_op=gz.cloud.LoadLabel(label_key="cloud_mask"),
    sampler=gp.SpatialRegularStride((256, 256)),
)

loop = pt.TrainingLoop(
    # UNet: a user-defined torch.nn.Module (not shown)
    model_op=pa.ModelOp(UNet(in_channels=4, out_channels=2)),
    dataset=dataset,
    val_dataset=dataset.with_split("val"),
    loss=pt.NLL(),
    optimizer_config={"name": "adam", "lr": 1e-3},
    n_epochs=50,
    batch_size=16,
    backend="lightning",
    callbacks=[
        pt.EarlyStopping(metric="val_loss", patience=5),
        pt.Checkpoint(every_n_epochs=5, save_dir="ckpts/"),
        pt.LogToExperiment(experiment_name="cloud_mask_unet"),
    ],
    seed=42,
)

trained_model_op = loop.run()
# trained_model_op is now a pipekit.Operator; it drops into any Sequential
4.2 Emulator training (chemistry transport surrogate)
import pipekit_train as pt
import pipekit_array as pa   # needed for pa.ModelOp below
import pipekit_cycle as pc
from plumax.adapters.pipekit import ChemistryForward

# AtmosphericPrior and EmulatorNet are user-defined (not shown): a prior over
# initial atmospheric states and a torch.nn.Module emulator.

# Forward model: expensive physics
expensive_forward = ChemistryForward(species=["CH4", "NH3"], dt=3600.0)

# Sample initial states from a prior; generate training trajectories
dataset = pt.SimulationDataset(
    forward_model=expensive_forward,
    prior=AtmosphericPrior(distribution="climatology"),
    n_samples=10_000,
    cycle=pc.Cycle(step_op=expensive_forward, n_steps=24),  # 24 hourly steps = 1 simulated day
)

# Cache to disk so we don't regenerate every run
cached = pt.CachedDataset(source=dataset, cache_dir="s3://cache/chemistry_sim/")

# Train a neural emulator on the cached pairs
loop = pt.TrainingLoop(
    model_op=pa.ModelOp(EmulatorNet()),   # neural network
    dataset=cached,
    loss=pt.MSE(),
    optimizer_config={"name": "adam", "lr": 5e-4},
    n_epochs=100,
    batch_size=64,
    backend="lightning",
)

emulator_op = loop.run()
# emulator_op is a NeuralForward-compatible operator.
# It drops directly into pipekit-cycle.Cycle as a forward_model replacement.
4.3 Amortized inference training (neural posterior for plume sources)
import pipekit as pk
import pipekit_train as pt
from plumax.adapters.pipekit import PlumeForward, ColumnObs
from pyrox.adapters.pipekit_train import ConditionalNormalizingFlow

# SourcePrior is user-defined (not shown): a prior over plume source parameters.

# Simulator: (source_params) → simulated downwind concentrations
plume_forward = PlumeForward()
obs_op = ColumnObs(instrument="TROPOMI")

# Generate (params, simulated_obs) pairs. For SBI the simulator chain is
# forward model + observation operator, so the network sees what the
# instrument would see. (Assumes a Sequential of the two satisfies the
# ForwardModel interface that SimulationDataset expects.)
dataset = pt.SimulationDataset(
    forward_model=pk.Sequential([plume_forward, obs_op]),
    prior=SourcePrior(loc_bounds=(-10, 10, 30, 40), strength_loguniform=(1, 1e4)),
    n_samples=100_000,
)

# Train a conditional density estimator (pyrox / gaussflowx algorithmic core)
loop = pt.TrainingLoop(
    model_op=ConditionalNormalizingFlow(
        n_dim=4,             # 4 source params
        condition_dim=...,   # depends on obs shape
        flow_type="masked_autoregressive",
    ),
    dataset=dataset,
    loss=pt.NLL(),           # -log q(params | obs): neural posterior estimation
    optimizer_config={"name": "adam", "lr": 1e-3},
    n_epochs=200,
    batch_size=128,
    backend="equinox",       # JAX-traceable
)

posterior_op = loop.run()
# posterior_op is amortized: it takes observations and samples posterior parameters.
# It drops into any inference pipeline that needs "given obs, give me sources".
Part 5 — Dependencies and optional extras
[project]
name = "pipekit-train"
version = "0.1.0"
dependencies = [
"pipekit>=0.1",
"numpy>=2.0",
]
[project.optional-dependencies]
# Backend training tools — pick one
lightning = ["lightning>=2.4", "torch>=2.0"]
equinox = ["equinox>=0.11", "optax>=0.2", "jax>=0.4.20"]
keras = ["keras>=3.0"]
# Carrier libraries — pick what's needed
array = ["pipekit-array>=0.1"]
jax = ["pipekit-jax>=0.1"] # deferred
# Data sources
catalog = ["geocatalog>=0.1"]
cycle = ["pipekit-cycle>=0.1"] # for SimulationDataset
# Experiment tracking
experiment = ["pipekit-experiment>=0.1"]
Minimum install gives the TrainingDataset / Loss / TrainingLoop machinery but no backend; at least one backend extra must be installed to actually train. Same pattern as pipekit-array[numpy]: be explicit about which backend you want.
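The following sketch shows how pipekit-train might detect which backend extras are present. It uses only importlib.util.find_spec, so nothing is actually imported. The module names per extra are read off the extras table above, and installed_backends is a hypothetical helper.

```python
import importlib.util

# Top-level modules each backend extra is expected to provide (from the extras table).
BACKEND_MODULES = {
    "lightning": ("lightning", "torch"),
    "equinox": ("equinox", "optax", "jax"),
    "keras": ("keras",),
}


def installed_backends(requirements=None) -> list:
    """Backends whose required top-level modules are all importable."""
    if requirements is None:
        requirements = BACKEND_MODULES
    return [
        name
        for name, modules in requirements.items()
        if all(importlib.util.find_spec(m) is not None for m in modules)
    ]
```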
Part 6 — The training-pipeline reproducibility story
A training run produces a model. The model’s provenance must be recoverable:
Training pipeline YAML. The TrainingLoop config serializes via Operator.state like any other operator. It includes the model architecture (model_op’s config), dataset config, loss, optimizer and seed.
Dataset content hash. TrainingDataset.content_hash() returns a hash identifying exactly which data the loop trained on. For catalog datasets: hash of catalog + query + split + seed. For simulation datasets: hash of forward_model + prior + n_samples + seed.
Trained model artifact. The output of loop.run() is the trained model_op (ModelOp or JaxModelOp), content-addressed by (training_pipeline_hash, dataset_content_hash).
Stored in model registry. pipekit-experiment (Report 12) handles registry storage.
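Content-addressing a trained model from its two provenance hashes could look like the following sketch. It assumes the TrainingLoop config serializes to a JSON-compatible dict; model_address and sha256_json are hypothetical names. Because the address is derived from inputs, two runs that differ only by backend nondeterminism map to the same address. That is why the artifact below also records trained_model_hash over the weights themselves.

```python
import hashlib
import json


def sha256_json(obj) -> str:
    """SHA-256 over canonical JSON (sorted keys, no whitespace)."""
    data = json.dumps(obj, sort_keys=True, separators=(",", ":")).encode("utf-8")
    return hashlib.sha256(data).hexdigest()


def model_address(training_pipeline: dict, dataset_content_hash: str) -> str:
    """Address of a trained model from (training_pipeline_hash, dataset_content_hash)."""
    pipeline_hash = sha256_json(training_pipeline)
    return sha256_json([pipeline_hash, dataset_content_hash])


pipeline = {
    "backend": "lightning",
    "loss": "MSE",
    "n_epochs": 100,
    "seed": 0,
    "optimizer_config": {"name": "adam", "lr": 5e-4},
}
addr = model_address(pipeline, dataset_content_hash="ab12cd")
```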
The reproducibility artifact for a trained model is:
# methane_emulator_v3.training-artifact.yaml
training_pipeline: ... # the TrainingLoop YAML
dataset_hash: "ab12cd..." # what was trained on
training_run_id: "mlflow://runs/xyz" # experiment tracker reference
trained_model_uri: "s3://models/methane_emulator_v3.pt"
trained_model_hash: "ef34gh..." # content hash of trained weights
backend: "lightning"
backend_version: "2.4.0"
hardware: "1xA100"
duration_seconds: 7200
This is the analog of pipekit.repro.Artifact (use case 9, regulatory) but for training runs. Pipekit-experiment ships it.
Part 7 — Honest tradeoffs
7.1 What gets better
Train and infer with the same operator config. The model trained at v3 is loaded as a pipekit.Operator; it drops into inference pipelines without rewrite.
Training is reproducible. Dataset content hash + training pipeline YAML + seed → the same trained model: byte-identical weights when the backend runs deterministically, otherwise equal up to numerical nondeterminism (e.g. non-deterministic GPU kernels).
Emulator training closes the loop with pipekit-cycle. SimulationDataset(forward_model=..., cycle=Cycle(...)) is one composable spec.
Multi-backend. The same TrainingLoop config schema drives Lightning, Equinox+Optax, or Keras 3; pick by user preference, as long as the backend matches the model_op (a PyTorch module goes through Lightning, an eqx.Module through Equinox+Optax).
Hyperparameter sweeps inherit experiment tracking. HyperSweep configs go through the same LogToExperiment callback.
7.2 What gets harder
The adapter layer adds indirection. Users debugging a training failure may need to look at both pipekit-train and Lightning. Mitigation: clear error messages that point at the right layer; backend-specific debugging guides.
Caching invariants are tricky. CachedDataset assumes the underlying source is deterministic given its config + seed. If a catalog mutates underneath, the cache goes stale, because the content hash covers the catalog URI and query, not the rows they return. Mitigation: opt-in cache; explicit invalidation API.
Multi-backend coverage is uneven. Lightning is the default and best-tested. Equinox/Optax has rougher edges. Keras 3 is newer. Mitigation: be honest about which backends are battle-tested; mark Equinox/Keras adapters as experimental in v0.1.
SimulationDataset can be very expensive. Generating 100K plume simulations can take hours of compute. Mitigation: aggressive caching (CachedDataset on Zarr); orchestrator-driven generation via pipekit.parallel.ProcessMap.
7.3 What doesn’t fit and isn’t tried
Distributed training across nodes. Single-node multi-GPU works via Lightning’s built-in support. Multi-node is the orchestrator’s job.
Online / streaming training. Pipekit-train assumes batch training. Online learning is out of scope.
Reinforcement learning. No environment-step / reward loop. Different abstraction entirely.
Foundation-model fine-tuning at scale. Lightning + LoRA / PEFT plug in naturally; pipekit-train doesn’t add machinery for it.
Part 8 — Effort and timing
8.1 Effort
Day 1-2: dataset.py (TrainingDataset, CatalogDataset, SimulationDataset, CachedDataset).
Day 3: loss.py, callbacks.py (common losses and callbacks).
Day 4-5: loop.py (TrainingLoop, ValidationStep, EarlyStopping).
Day 6-8: adapters/lightning.py, the first adapter. Test end-to-end with a toy classifier.
Day 9: adapters/equinox.py, the second adapter. Test with a toy JAX model.
Day 10: Documentation, smoke tests with both adapters.
Day 11-12 (optional): sweep.py (hyperparameter sweep machinery).
Day 13-14: adapters/keras.py if time permits, or defer to v0.2.
Total: ~2-3 weeks for v0.1 (Lightning + Equinox).
8.2 Timing
Ship after pipekit-cycle (Report 10) is stable. The most valuable feature (SimulationDataset for emulator training) depends on ForwardModel from pipekit-cycle.
Realistic timeline: v0.3 of the ecosystem.
8.3 What this unblocks
Emulator training as a pipeline. The trained emulator drops into pipekit-cycle as a forward model.
Amortized inference training. Train once, infer fast forever.
Reproducible model versioning. Every trained model has a hash that traces back to its training data and pipeline.
The “ML at every level” claim becomes operational. Training L0/L1 cloud detectors, L2 learnable retrievals, L3 ML gap-filling, L4 neural emulators — all use the same machinery, same registry, same composability.
Part 9 — Recommendation
Ship pipekit-train as a separate sister package. Signals:
Heavy backend dependencies (PyTorch, JAX, Optax, Keras) — pipekit core can’t take these on
Multiple training-tool backends with different idioms — adapter pattern is the right shape
Distinct from inference-time concerns — keeping it separate clarifies what pipekit core is
Algorithm libraries (pyrox, gaussflowx) plug in via their own [pipekit_train] extras, the same pattern as everywhere else
The package lives in the same monorepo: packages/pipekit-train/. Sibling of pipekit-cycle and pipekit-experiment.
This is the trainer-side counterpart to pipekit-cycle. Together they form the ML-augmented L3-L4 stack: pipekit-cycle for the inference loops, pipekit-train for the training loops, both producing artifacts that compose into the same operator graph machinery. Without it, training stays research code outside the framework. With it, the train→serve loop is tight, versioned, and composable.