# Part 1: 1D Marginal Transforms
The atomic operation of Gaussianization: turn one coordinate’s distribution into
a standard normal via $z = \Phi^{-1}(\hat F(x))$ for some monotone CDF estimator $\hat F$.
Every method in the curriculum stacks these 1D maps between rotations, so Part 1
builds them in every way that matters: estimating $\hat F$, computing the Jacobian
$dz/dx$ the flow needs, inverting the map, and training it. Everything is grounded in
`rbig` and `gauss_flows`.
Each notebook keeps the Part 0 pattern: derive the idea, then confirm it against the packages.
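Here is a minimal NumPy/SciPy sketch of the ECDF version of this map (rank → uniform → normal). It assumes ranks are divided by $n+1$, so that no sample maps to exactly 0 or 1, where $\Phi^{-1}$ is infinite. `ecdf_gaussianize` is a hypothetical helper, not an `rbig` or `gauss_flows` function.

```python
import numpy as np
from scipy.stats import norm, rankdata


def ecdf_gaussianize(x: np.ndarray) -> np.ndarray:
    """Map a 1D sample to approximately N(0, 1) via z = Phi^{-1}(F_hat(x)).

    Hypothetical helper: F_hat is the empirical CDF, rescaled by n + 1 so
    every value lies strictly inside (0, 1).
    """
    n = x.shape[0]
    u = rankdata(x) / (n + 1)  # ranks 1..n -> (0, 1); ties get average ranks
    return norm.ppf(u)         # inverse standard-normal CDF


rng = np.random.default_rng(0)
x = rng.exponential(scale=2.0, size=5000)  # skewed, clearly non-Gaussian input
z = ecdf_gaussianize(x)
print(z.mean(), z.std())
```

The ECDF is a step function, so its derivative is zero almost everywhere and the map has no useful Jacobian of its own. A flow that needs a log-det therefore requires a smooth or parametric $\hat F$ instead.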
## Notebooks
| # | notebook | master list | what you take away |
|---|---|---|---|
| 00 | ECDF & histograms | 1.1–1.3 | rank→uniform→normal; Glivenko–Cantelli; the ECDF’s degenerate Jacobian |
| 01 | KDE & Gaussian-mixture CDFs | 1.4–1.6 | smooth CDFs; analytic mixture log-det; choosing the KDE bandwidth / number of mixture components (BIC) |
| 02 | Monotone-spline CDFs | 1.7–1.8 | monotonicity (PCHIP vs overshooting cubic splines); rational-quadratic splines (RQS) with exact inverse + analytic log-det |
| 03 | Learnable mixture-CDF | 1.9 | the marginal as a trainable layer; end-to-end MLE |
| 04 | Inversion strategies | 1.10–1.12 | bisection vs Newton; safeguarded hybrid; differentiating the inverse (unroll/one-step/adjoint); batched vmap |
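The monotonicity point from notebook 02 is easy to reproduce with SciPy's interpolators. Fit an ordinary cubic spline and a PCHIP interpolant to the same step-like CDF knots, then compare them on a fine grid. This sketch uses `scipy.interpolate` rather than the `gauss_flows` spline code, and the knot values are made up for illustration.

```python
import numpy as np
from scipy.interpolate import CubicSpline, PchipInterpolator

# Knots of a monotone, step-like CDF (illustrative values).
xk = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
yk = np.array([0.0, 0.0, 0.0, 1.0, 1.0, 1.0])

grid = np.linspace(xk[0], xk[-1], 2001)
cubic = CubicSpline(xk, yk)        # C2 cubic spline (not-a-knot end conditions)
pchip = PchipInterpolator(xk, yk)  # shape-preserving piecewise cubic Hermite

# The C2 spline overshoots around the jump: it leaves [0, 1] and is not monotone.
print("cubic range:", cubic(grid).min(), cubic(grid).max())
print("cubic min slope:", cubic(grid, 1).min())

# PCHIP stays inside [0, 1] with a non-negative slope, so it is a valid CDF.
print("pchip range:", pchip(grid).min(), pchip(grid).max())
print("pchip min slope:", pchip(grid, 1).min())
```

PCHIP sets a knot's slope to zero whenever an adjacent secant is flat or the secants change sign. That rule keeps every cubic piece monotone, at the cost of only $C^1$ smoothness.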
Each estimator notebook also carries a Jacobian / log-determinant section
($dz/dx = \hat f(x)/\phi(z)$ and $\log|dz/dx| = \log \hat f(x) - \log \phi(z)$, where $\hat f = \hat F'$),
since that per-coordinate term is what a flow sums in `log_prob`.
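Differentiating $z = \Phi^{-1}(\hat F(x))$ gives $dz/dx = \hat f(x)/\phi(z)$, so the log-det term has a closed form for any smooth estimator. The sketch below checks it against a central finite difference for a Gaussian-kernel KDE, whose CDF is an equal-weight Gaussian mixture. The bandwidth `h = 0.3`, the synthetic data, and the helper names are assumptions for illustration, not `gauss_flows` API.

```python
import numpy as np
from scipy.stats import norm


def kde_cdf_and_pdf(x, data, h):
    """Gaussian-kernel KDE: F_hat(x) = mean_i Phi((x - x_i) / h), f_hat = F_hat'."""
    u = (x[:, None] - data[None, :]) / h
    cdf = norm.cdf(u).mean(axis=1)
    pdf = norm.pdf(u).mean(axis=1) / h
    return cdf, pdf


def forward_and_logdet(x, data, h):
    """z = Phi^{-1}(F_hat(x)) and log|dz/dx| = log f_hat(x) - log phi(z)."""
    cdf, pdf = kde_cdf_and_pdf(x, data, h)
    z = norm.ppf(cdf)
    logdet = np.log(pdf) - norm.logpdf(z)
    return z, logdet


rng = np.random.default_rng(0)
data = np.concatenate([rng.normal(-2.0, 0.5, 200), rng.normal(1.0, 1.0, 300)])
x = np.linspace(-3.0, 3.0, 7)
h = 0.3

z, logdet = forward_and_logdet(x, data, h)

# Independent check: central finite difference of z with respect to x.
eps = 1e-5
z_plus, _ = forward_and_logdet(x + eps, data, h)
z_minus, _ = forward_and_logdet(x - eps, data, h)
fd_logdet = np.log((z_plus - z_minus) / (2 * eps))

print(np.max(np.abs(logdet - fd_logdet)))
```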
## Threads from Part 0
- The change-of-variables log-det (Part 0, 00) becomes concrete here: each estimator’s density $\hat f = \hat F'$ supplies its log-det, $\log \hat f(x) - \log \phi(z)$.
- The forward/inverse trade-off (Part 0, 02) is realised: smooth CDFs invert by root-finding, while splines invert in closed form.
- Differentiating the root-finding inverse, including the still-open gauss_flows#111 zero-gradient pitfall, is covered in notebook 04.
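Smooth CDFs such as a Gaussian mixture have no closed-form inverse, so $x = \hat F^{-1}(\Phi(z))$ must be found numerically. Below is a minimal sketch of one safeguarded hybrid: take a Newton step when it lands inside the current bracket, and bisect otherwise. The mixture parameters, tolerances, and helper names are illustrative assumptions, not the notebook 04 or `gauss_flows` implementation. The last lines compute $dx/dz = \phi(z)/\hat f(x)$ from the implicit function theorem at the converged root, rather than by differentiating through the loop's iterations.

```python
import numpy as np
from scipy.stats import norm

# Illustrative two-component Gaussian mixture (weights, means, scales).
W = np.array([0.3, 0.7])
MU = np.array([-2.0, 1.0])
SIG = np.array([0.5, 1.0])


def mix_cdf(x):
    return float(np.sum(W * norm.cdf((x - MU) / SIG)))


def mix_pdf(x):
    return float(np.sum(W * norm.pdf((x - MU) / SIG) / SIG))


def inverse_cdf(u, lo=-50.0, hi=50.0, tol=1e-12, max_iter=200):
    """Solve mix_cdf(x) = u with Newton steps safeguarded by bisection."""
    x = 0.5 * (lo + hi)
    for _ in range(max_iter):
        r = mix_cdf(x) - u
        if abs(r) < tol:
            break
        # The CDF is increasing, so the sign of r says which side the root is on.
        if r > 0:
            hi = x
        else:
            lo = x
        p = mix_pdf(x)
        x_newton = x - r / p if p > 0 else np.nan
        # Accept the Newton step only if it lands strictly inside the bracket.
        x = x_newton if lo < x_newton < hi else 0.5 * (lo + hi)
    return x


z = 0.8
x = inverse_cdf(norm.cdf(z))
dx_dz = norm.pdf(z) / mix_pdf(x)  # implicit-function-theorem derivative
print(x, mix_cdf(x) - norm.cdf(z), dx_dz)
```

The derivative depends only on the root, not on the path the solver took to reach it, so it stays well-defined no matter how many bisection steps were needed.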
## Running
Use the same uv environment as
Part 0 (`rbig` + `gauss_flows` + a Jupyter
stack), with `interpax` added for `gauss_flows.HistogramCDF`:
```bash
cd projects/gaussianization
uv pip install --python .venv-tutorials/bin/python interpax
.venv-tutorials/bin/jupyter nbconvert --to notebook --execute --inplace \
  notebooks/01_marginal_transforms/0*.ipynb --ExecutePreprocessor.timeout=600
```

Notebooks are paired with jupytext (`py:percent` format) and set `jax_enable_x64`.