
gaussianization tutorial

Part 7 — Conditional Gaussianization

Every flow so far learned a single density $p(x)$. Make the flow's parameters depend on a context $y$ and it learns a whole family of densities $p(x \mid y)$ at once. The result is a conditional Gaussianizer $T(\cdot \mid y)$ that maps each conditional slice of the data to the same $\mathcal{N}(0, I)$, with a tractable conditional log-likelihood and one-pass conditional sampling. This part shows where to inject the context (the base, the couplings, or both), how conditional density estimation works, and how a conditional flow becomes an amortised posterior for inverse problems, which is the bridge to the plug-and-play priors of Part 16. Built on gauss_flows.

The defining change from Parts 4-6: a parameter that was constant becomes a function of $y$,

$$p(x \mid y) = p_Z\big(T_\theta(x; y)\big)\,\big|\det J_{T_\theta(\cdot;y)}(x)\big|,$$

and the same NLL training, sampling, and log-det machinery carries straight over.
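A minimal sketch of that formula in one dimension, written with the Python standard library rather than FlowJax: a conditional affine flow $T(x; y) = (x - m(y))/s(y)$ with a standard normal base, so $\log p(x \mid y) = \log \phi(T(x; y)) - \log s(y)$. To have a ground truth to compare against, $m$ and $s$ are set to the exact posterior moments of a toy linear-Gaussian inverse problem (prior $x \sim \mathcal{N}(0, 1)$, $y = a x + \eta$, $\eta \sim \mathcal{N}(0, \sigma^2)$). The constants `A` and `NOISE` and all function names are illustrative assumptions, not the notebooks' code. In a trained flow, `shift` and `scale` would be networks of $y$ fitted by NLL; here they are closed-form.

```python
import math
import random
from statistics import NormalDist

STD_NORMAL = NormalDist()  # base density p_Z = N(0, 1)

# Illustrative inverse problem (assumed constants, not taken from the notebooks):
# prior x ~ N(0, 1), observation y = A * x + eta, eta ~ N(0, NOISE**2).
A, NOISE = 2.0, 0.5

def shift(y):
    # m(y): exact posterior mean E[x | y] for this linear-Gaussian model.
    return A * y / (A**2 + NOISE**2)

def scale(y):
    # s(y): exact posterior standard deviation (constant in y for this model).
    return math.sqrt(NOISE**2 / (A**2 + NOISE**2))

def forward(x, y):
    # Conditional Gaussianizer T(x; y): sends the slice p(x | y) to N(0, 1).
    return (x - shift(y)) / scale(y)

def inverse(z, y):
    # T^{-1}(z; y), used for one-pass conditional sampling.
    return shift(y) + scale(y) * z

def log_prob(x, y):
    # log p(x | y) = log p_Z(T(x; y)) + log|dT/dx|, and dT/dx = 1 / s(y).
    return math.log(STD_NORMAL.pdf(forward(x, y))) - math.log(scale(y))

def sample(y, n, rng):
    # Draw z ~ N(0, 1) and push it through the inverse map for this y.
    return [inverse(rng.gauss(0.0, 1.0), y) for _ in range(n)]

def log_posterior_bayes(x, y):
    # Reference value: log p(x) + log p(y | x) - log p(y), all Gaussian here.
    log_prior = math.log(STD_NORMAL.pdf(x))
    log_lik = math.log(NormalDist(A * x, NOISE).pdf(y))
    log_evidence = math.log(NormalDist(0.0, math.sqrt(A**2 + NOISE**2)).pdf(y))
    return log_prior + log_lik - log_evidence

y_obs = 1.3
for x in (-1.0, 0.0, 0.4, 2.0):
    print(f"x={x:+.1f}  flow={log_prob(x, y_obs):.6f}  bayes={log_posterior_bayes(x, y_obs):.6f}")
draws = sample(y_obs, 5, random.Random(0))
```

The same two functions serve every observation $y$ without refitting, which is the amortisation property; a non-Gaussian prior would need a learned, nonlinear $T$ in place of the closed-form affine map.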

Notebooks

| # | notebook | master list | what you take away |
|---|---|---|---|
| 00 | Three ways to condition | 7.1, 7.4 | inject $y$ at the base, the couplings, or both, and how to choose |
| 01 | Conditional marginals & density estimation | 7.2, 7.3 | $y$-dependent CDF margins; calibrated $p(x \mid y)$ on a heteroscedastic benchmark |
| 02 | Conditional flow as an amortised posterior | 7.5 | train once on $(x, y = Ax + \eta)$, sample $p(x \mid y)$ for any $y$; feeds Part 16 |
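The $y$-dependent CDF margins of notebook 01 can be illustrated with a single conditional margin $T(x; y) = \Phi^{-1}(F(x \mid y))$, whose log-det is $\log f(x \mid y) - \log \phi(T(x; y))$. The sketch below assumes the conditional CDF is known (a logistic with hypothetical `loc` and `spread` functions that make the noise grow with $|y|$) instead of estimating it, so it shows the transform rather than the notebook's fitting procedure. It then Gaussianizes exact draws from many slices and reports the pooled mean and standard deviation.

```python
import math
import random
from statistics import NormalDist

PHI = NormalDist()  # standard normal: PHI.cdf is Phi, PHI.inv_cdf is Phi^{-1}
EPS = 1e-15         # keeps CDF values strictly inside (0, 1) so inv_cdf is defined

# Hypothetical heteroscedastic benchmark: x | y ~ Logistic(loc(y), spread(y)).
def loc(y):
    return math.sin(y)

def spread(y):
    return 0.2 + 0.5 * abs(y)  # noise grows with |y|

def cond_cdf(x, y):
    t = (x - loc(y)) / spread(y)
    return 1.0 / (1.0 + math.exp(-t))

def cond_logpdf(x, y):
    t = (x - loc(y)) / spread(y)
    return -t - math.log(spread(y)) - 2.0 * math.log1p(math.exp(-t))

def forward(x, y):
    # Conditional CDF margin: T(x; y) = Phi^{-1}(F(x | y)).
    u = min(max(cond_cdf(x, y), EPS), 1.0 - EPS)
    return PHI.inv_cdf(u)

def log_det(x, y):
    # dT/dx = f(x | y) / phi(T(x; y)), so log|dT/dx| = log f(x | y) - log phi(T).
    return cond_logpdf(x, y) - math.log(PHI.pdf(forward(x, y)))

def inverse(z, y):
    # Sampling: x = F^{-1}(Phi(z) | y), using the logistic quantile function.
    u = min(max(PHI.cdf(z), EPS), 1.0 - EPS)
    return loc(y) + spread(y) * math.log(u / (1.0 - u))

# Gaussianize exact draws from many different conditional slices.
rng = random.Random(0)
zs = []
for _ in range(20000):
    y = rng.uniform(-2.0, 2.0)
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
    x = loc(y) + spread(y) * math.log(u / (1.0 - u))  # draw from p(x | y)
    zs.append(forward(x, y))
mean = sum(zs) / len(zs)
std = math.sqrt(sum((z - mean) ** 2 for z in zs) / len(zs))
print(f"Gaussianized slices: mean={mean:.3f}, std={std:.3f}")
```

Because each slice is pushed through its own CDF, the pooled $z$ values are standard normal even though the raw $x$ values are heteroscedastic. With learned margins in place of the known logistic, the same pooled check becomes a quick calibration test.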

The headline: where to put the context

A conditional flow has three slots for $y$: the base $p_Z(\cdot \mid y)$ (per-context location/scale), the couplings $T_{\theta(y)}$ (per-context shape), and, for transforms that cannot natively read a context (rotations, normalisations), a FiLM-style Conditioner wrapper. Notebook 00 fits all four combinations side by side and reads off the rule of thumb: condition the base for shifts, the couplings for shape changes.
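A sketch of the first two slots for 2-D $x$ and scalar $y$, in plain Python with hypothetical hand-set weights (none of these names are the gauss_flows or FlowJax API). The base $\mathcal{N}(\mu(y), \mathrm{diag}\,\sigma(y)^2)$ absorbs context-driven shifts and scales, while an affine coupling whose conditioner reads both $x_1$ and $y$ lets the shape of $x_2 \mid x_1$ change with the context. The conditioner simply takes $y$ as an extra input; the FiLM-style Conditioner wrapper for rotations and normalisations is not shown.

```python
import math

# Hypothetical hand-set weights; in a real conditional flow these are learned by NLL.
W_X, W_Y, B = 1.5, 0.8, 0.1   # conditioner hidden unit: reads x1 AND the context y
W_S, W_T = 0.4, -0.7          # heads giving the coupling's log-scale s and shift t

def coupling_params(x1, y):
    # Slot 2: the coupling's conditioner takes y as an extra input,
    # so the shape of x2 | x1 can change with the context.
    h = math.tanh(W_X * x1 + W_Y * y + B)
    return W_S * h, W_T * h

def base_params(y):
    # Slot 1: a conditional base N(mu(y), diag(sigma(y)^2)) for per-context location/scale.
    mus = (0.8 * y, -0.2 * y)
    sigmas = (math.exp(0.1 * y), 1.0)
    return mus, sigmas

def forward(x, y):
    # Normalising direction x -> z: x1 passes through, x2 gets an (x1, y)-dependent affine map.
    x1, x2 = x
    s, t = coupling_params(x1, y)
    z = (x1, x2 * math.exp(s) + t)
    return z, s  # the Jacobian is triangular, so log|det J| = s

def inverse(z, y):
    # Sampling direction z -> x; z1 == x1, so the same (s, t) can be recomputed.
    z1, z2 = z
    s, t = coupling_params(z1, y)
    return (z1, (z2 - t) * math.exp(-s))

def log_prob(x, y):
    # log p(x | y) = log p_Z(T(x; y) | y) + log|det J_T(x; y)|
    z, logdet = forward(x, y)
    mus, sigmas = base_params(y)
    log_base = sum(
        -0.5 * ((zi - m) / sd) ** 2 - math.log(sd) - 0.5 * math.log(2.0 * math.pi)
        for zi, m, sd in zip(z, mus, sigmas)
    )
    return log_base + logdet

# Usage: the same parameters give a different density for each context.
x = (0.3, -1.2)
for y in (-1.0, 0.0, 1.0):
    print(f"y={y:+.1f}  log p(x|y)={log_prob(x, y):.4f}")
```

With `y` removed from `coupling_params`, the context can only shift and rescale the latent Gaussian, which is the "condition the base for shifts" half of the rule of thumb; the coupling slot is what lets the conditional shape itself vary.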

Threads

Running

Same FlowJax stack as the earlier parts (no ODE this time, so these notebooks run quickly). Notebooks are paired (jupytext, py:percent) and set `jax_enable_x64`. To execute them in place:

```bash
cd projects/gaussianization
PATH="$GF_VENV/bin:$PATH" "$GF_VENV/bin/jupyter" nbconvert --to notebook \
  --execute --inplace notebooks/07_conditional/0*.ipynb \
  --ExecutePreprocessor.timeout=900
```

where $GF_VENV is a virtualenv with gauss_flows, flowjax, optax, and a Jupyter stack.