Part 7 — Conditional Gaussianization
Every flow so far learned a single density $p(x)$. Make the flow's parameters depend
on a context $y$ and it learns a whole family of densities $p(x \mid y)$ at once:
a conditional Gaussianizer that maps each conditional slice $p(x \mid y)$ of the
data to the same $\mathcal{N}(0, I)$, with a tractable conditional log-likelihood $\log p(x \mid y)$ and
one-pass conditional sampling. This part shows where to inject the context (the
base, the couplings, or both), how conditional density estimation works, and how a
conditional flow becomes an amortised posterior for inverse problems, which is the bridge
to the plug-and-play priors of Part 16. Built on
gauss_flows.
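The tractable conditional log-likelihood is the ordinary change of variables with the context threaded through. Writing $f(\cdot\,; y)$ for the Gaussianizing map at context $y$ (notation introduced here for illustration), a sketch:

$$
\log p(x \mid y) = \log \mathcal{N}\big(f(x; y);\, 0, I\big) + \log \left| \det \frac{\partial f(x; y)}{\partial x} \right|,
\qquad
x = f^{-1}(z; y),\quad z \sim \mathcal{N}(0, I).
$$

The Jacobian is taken with respect to $x$ only; $y$ is held fixed, so it shapes the map but never enters the determinant as a variable.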
The defining change from Parts 4-6: a parameter that was constant becomes a function of $y$,
and the same NLL training, sampling, and log-det machinery carries straight over.
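A minimal sketch of that change, assuming the simplest possible flow: one elementwise affine layer whose location and log-scale, two constants in an unconditional flow, become functions of $y$. The functions `mu` and `log_sigma` are hypothetical stand-ins for a small network, and the sketch uses plain NumPy rather than FlowJax, so it is not the gauss_flows API.

```python
import numpy as np

LOG_2PI = np.log(2.0 * np.pi)

# Hypothetical context-dependent parameters. In an unconditional affine flow
# these would be two learned constants; here they are functions of the context y.
def mu(y):
    return 1.5 * y

def log_sigma(y):
    return 0.3 * y - 0.5

def sample(z, y):
    """Sampling direction, one pass: z ~ N(0, 1) -> x given y."""
    return mu(y) + np.exp(log_sigma(y)) * z

def gaussianize(x, y):
    """Gaussianizing direction: x -> z for the conditional slice at y."""
    return (x - mu(y)) * np.exp(-log_sigma(y))

def log_prob(x, y):
    """log p(x | y) = log N(z; 0, 1) + log|dz/dx|, where log|dz/dx| = -log_sigma(y)."""
    z = gaussianize(x, y)
    return -0.5 * (z**2 + LOG_2PI) - log_sigma(y)

def nll(x, y):
    """Training objective: mean negative conditional log-likelihood over (x, y) pairs."""
    return -np.mean(log_prob(x, y))

rng = np.random.default_rng(0)
y = rng.uniform(-1.0, 1.0, 10_000)
z = rng.standard_normal(10_000)
x = sample(z, y)
```

Because only the parameters moved, `gaussianize(sample(z, y), y)` recovers `z` for every context, and `log_prob` is exactly the Gaussian log-density with mean `mu(y)` and standard deviation `exp(log_sigma(y))`.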
Notebooks
| # | notebook | master list | what you take away |
|---|---|---|---|
| 00 | Three ways to condition | 7.1, 7.4 | inject at the base, the couplings, or both — and how to choose |
| 01 | Conditional marginals & density estimation | 7.2, 7.3 | $y$-dependent CDF margins; calibrated on a heteroscedastic benchmark |
| 02 | Conditional flow as an amortised posterior | 7.5 | train once on $(x, y)$ pairs, sample $p(x \mid y)$ for any $y$; feeds Part 16 |
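The idea behind $y$-dependent CDF margins fits in one dimension. A minimal sketch, assuming a hypothetical heteroscedastic benchmark where $x \mid y$ is logistic with location `loc(y)` and scale `scale(y)`, and assuming the margin is the logistic CDF with those same parameters (in a real fit they would be learned). The map $z = \Phi^{-1}(F(x \mid y))$ sends every conditional slice to $\mathcal{N}(0, 1)$; a single context-free margin does not, and a slice-by-slice check in $y$ exposes the difference.

```python
import numpy as np
from statistics import NormalDist

rng = np.random.default_rng(0)
n = 20_000

# Hypothetical heteroscedastic benchmark: x | y ~ Logistic(loc(y), scale(y)).
def loc(y):
    return np.sin(2.0 * y)

def scale(y):
    return 0.1 + 0.4 * np.abs(y)

y = rng.uniform(-2.0, 2.0, n)
v = rng.uniform(1e-12, 1.0 - 1e-12, n)
x = loc(y) + scale(y) * np.log(v / (1.0 - v))  # inverse-CDF sampling of the logistic

probit = np.vectorize(NormalDist().inv_cdf)  # Phi^{-1}, from the standard library

def gaussianize_margin(x, y, loc_fn, scale_fn):
    """z = Phi^{-1}(F(x | y)) with a logistic CDF whose parameters depend on y."""
    u = 1.0 / (1.0 + np.exp(-(x - loc_fn(y)) / scale_fn(y)))
    return probit(np.clip(u, 1e-12, 1.0 - 1e-12))

# A context-free margin for contrast: one location and scale for every y.
def flat_loc(y):
    return np.zeros_like(y)

def flat_scale(y):
    return np.full_like(y, 0.5)

z_cond = gaussianize_margin(x, y, loc, scale)
z_flat = gaussianize_margin(x, y, flat_loc, flat_scale)

# Calibration check, slice by slice in the context.
narrow = np.abs(y) < 0.25  # tight conditional slices
wide = np.abs(y) > 1.75    # spread-out conditional slices
for name, zs in [("y-dependent", z_cond), ("context-free", z_flat)]:
    print(f"{name:>12}: std|narrow = {zs[narrow].std():.2f}, std|wide = {zs[wide].std():.2f}")
```

The $y$-dependent margin gives unit spread in both slices, while the flat margin under-disperses the tight slices near $y = 0$ and over-disperses the wide ones at large $|y|$.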
The headline: where to put the context
A conditional flow has three slots for $y$: the base (per-context
location/scale), the couplings (per-context shape), and, for
transforms that cannot natively read a context (rotations, normalisations), a FiLM-style
Conditioner wrapper. Notebook 00 fits all four combinations side by side and reads off
the rule of thumb: condition the base for shifts, the couplings for shape changes.
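The two non-base slots can be sketched in a few lines. This assumes a reading of the FiLM-style wrapper as feature-wise modulation $\gamma(y) \odot h + \beta(y)$ applied after a context-blind rotation, and a 2-D affine coupling whose conditioner sees the first coordinate together with $y$. The weight matrices are random stand-ins for trained networks, and none of the names come from gauss_flows. Each transform returns its log-determinant, which is checked against a finite-difference Jacobian.

```python
import numpy as np

rng = np.random.default_rng(0)
D, C = 2, 1  # data and context dimensions (illustrative)

# Hypothetical stand-ins for trained conditioner networks: fixed random linear maps.
W_film = 0.5 * rng.normal(size=(2 * D, C))
W_cpl = 0.5 * rng.normal(size=(2, 1 + C))

theta = 0.7
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])  # context-blind rotation

def film_forward(x, y):
    """Rotation, then per-context feature-wise affine: exp(log_gamma(y)) * (R x) + beta(y)."""
    log_gamma, beta = np.split(W_film @ y, 2)
    out = np.exp(log_gamma) * (R @ x) + beta
    return out, np.sum(log_gamma)  # log|det R| = 0 for a rotation

def film_inverse(u, y):
    log_gamma, beta = np.split(W_film @ y, 2)
    return R.T @ ((u - beta) * np.exp(-log_gamma))

def coupling_forward(x, y):
    """Affine coupling: x[0] passes through; the conditioner reads x[0] and the context y."""
    h = W_cpl @ np.concatenate([x[:1], y])
    log_s, t = np.tanh(h[0]), h[1]  # tanh bounds the log-scale
    out = np.array([x[0], x[1] * np.exp(log_s) + t])
    return out, log_s  # triangular Jacobian: the log-det is just log_s

def coupling_inverse(u, y):
    h = W_cpl @ np.concatenate([u[:1], y])  # u[0] == x[0], so the conditioner sees the same input
    log_s, t = np.tanh(h[0]), h[1]
    return np.array([u[0], (u[1] - t) * np.exp(-log_s)])

def numerical_log_det(f, x, y, eps=1e-6):
    """Central-difference Jacobian of f(., y) at x, to check the analytic log-dets."""
    J = np.zeros((D, D))
    for j in range(D):
        e = np.zeros(D)
        e[j] = eps
        J[:, j] = (f(x + e, y)[0] - f(x - e, y)[0]) / (2 * eps)
    return np.log(abs(np.linalg.det(J)))
```

The rotation contributes nothing to the log-determinant, so the wrapper's whole context dependence sits in $\sum_i \log \gamma_i(y)$; the coupling keeps its triangular Jacobian because $y$ only enters through the conditioner.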
Threads
- Back to Part 5. The conditioner that drove coupling layers is the same machinery; here it simply also reads the external context (cf. 5.17).
- Forward to Part 16. A conditional flow trained on $(x, y)$ pairs is an amortised posterior; notebook 02 is the toy that the plug-and-play inverse-problem solvers of Part 16 scale up.
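The amortised-posterior idea in its smallest form: a linear-Gaussian inverse problem (prior $x \sim \mathcal{N}(0, 1)$, observation $y = a x + \sigma \varepsilon$, with $a$ and $\sigma$ chosen arbitrarily here) and a conditional affine flow $x = w y + b + s z$. For this particular flow the NLL on simulated $(x, y)$ pairs has a closed-form minimiser (least squares for $w, b$, residual variance for $s^2$), so the sketch skips gradient descent. The fitted flow should match the analytic posterior $\mathcal{N}\big(\tfrac{a}{a^2+\sigma^2} y,\ \tfrac{\sigma^2}{a^2+\sigma^2}\big)$.

```python
import numpy as np

rng = np.random.default_rng(0)
a, sigma, n = 2.0, 0.5, 50_000  # hypothetical forward model y = a*x + noise

# Simulate the joint: prior x ~ N(0, 1), likelihood y | x ~ N(a*x, sigma^2).
x = rng.standard_normal(n)
y = a * x + sigma * rng.standard_normal(n)

# Conditional affine flow x = w*y + b + s*z with z ~ N(0, 1).
# Its NLL is minimised by OLS for (w, b) and s^2 = mean squared residual.
A = np.column_stack([y, np.ones(n)])
(w, b), *_ = np.linalg.lstsq(A, x, rcond=None)
s = np.sqrt(np.mean((x - (w * y + b)) ** 2))

# Analytic posterior of this linear-Gaussian problem, for comparison.
post_gain = a / (a**2 + sigma**2)
post_var = sigma**2 / (a**2 + sigma**2)

def sample_posterior(y_obs, k, sampler):
    """One-pass amortised sampling from p(x | y_obs): no per-observation fitting."""
    return w * y_obs + b + s * sampler.standard_normal(k)
```

Once fitted, posterior samples for a new observation cost one affine map, with no per-observation optimisation; that is the sense in which the posterior is amortised.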
Running
Same FlowJax stack as the earlier parts (no ODE this time, so these are fast).
Notebooks are paired (jupytext, py:percent) and set jax_enable_x64:
```bash
cd projects/gaussianization
PATH="$GF_VENV/bin:$PATH" "$GF_VENV/bin/jupyter" nbconvert --to notebook \
  --execute --inplace notebooks/07_conditional/0*.ipynb \
  --ExecutePreprocessor.timeout=900
```

where `$GF_VENV` is a virtualenv with gauss_flows, flowjax, optax, and a Jupyter
stack.