
Gaussianization Flow on Two-Moons

A minimal demonstration of a Gaussianization flow — a normalising flow built from two kinds of bijectors stacked in blocks:

  1. Rotation. Either a Householder product (trainable, $Q = \prod_j (I - 2 v_j v_j^\top / \lVert v_j \rVert^2)$) or a fixed orthogonal matrix (frozen at initialisation, optionally from the data PCA). Rotations have $\log\lvert\det J\rvert = 0$, so they are “free” — their only job is to redistribute non-Gaussianity across dimensions (see the sketch after this list).
  2. Marginal Gaussianization. For each dimension independently, apply $z_i = \Phi^{-1}(F_i(x_i))$, where $F_i$ is a trainable mixture-of-Gaussians CDF. This is where the flow does its work: after enough stacked blocks, the pushforward is approximately $\mathcal{N}(0, I)$.
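
A minimal NumPy sketch of the Householder rotation in item 1 (illustrative only, not the library layer): build $Q$ from a few reflectors and check that it costs nothing in the log-det.

import numpy as np

rng = np.random.default_rng(0)
d, num_reflectors = 2, 2
V = rng.normal(size=(num_reflectors, d))  # the trainable reflectors v_j

# Q = prod_j (I - 2 v_j v_j^T / ||v_j||^2)
Q = np.eye(d)
for v in V:
    Q = (np.eye(d) - 2.0 * np.outer(v, v) / (v @ v)) @ Q

assert np.allclose(Q.T @ Q, np.eye(d))   # orthogonal by construction
print(np.log(np.abs(np.linalg.det(Q))))  # ~0.0: the rotation is "free"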

The model is trained end-to-end by maximising

$$\log p_X(x) \;=\; \log p_Z(f(x)) \;+\; \sum_\ell \log\lvert\det J_{f_\ell}(\cdot)\rvert$$

with $p_Z = \mathcal{N}(0, I)$. In code, each bijector contributes $-\mathbb{E}[\log\lvert\det J\rvert]$ via self.add_loss, and the main training loss is the base NLL on the latent $z$.
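
A minimal sketch of that wiring, assuming only the add_loss mechanism (the toy bijector below is illustrative, not one of the library's layers):

import keras
from keras import ops


class ScaleBijectorSketch(keras.layers.Layer):
    """Toy bijector z = exp(s) * x; diagonal Jacobian, log|det J| = sum_i s_i."""

    def build(self, input_shape):
        self.s = self.add_weight(shape=(input_shape[-1],), initializer="zeros")

    def call(self, x):
        self.add_loss(-ops.sum(self.s))  # contributes -E[log|det J|] to the total loss
        return x * ops.exp(self.s)

Keras adds every add_loss term to the compiled loss, so minimising (base NLL on $z$) plus the sum of these terms is exactly minimising the negative log-likelihood above.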

Mathematical background

The 1-D probability integral transform

For a univariate random variable $X \sim p$ with strictly increasing CDF $F$, the classical result is that $F(X) \sim \mathrm{Uniform}[0, 1]$. Composing with the standard-normal quantile $\Phi^{-1}$ gives

$$z \;=\; \Phi^{-1}\!\big(F(x)\big) \;\sim\; \mathcal{N}(0, 1).$$

So in 1-D we can Gaussianize any continuous distribution exactly, given its true CDF. The catch is that we don’t know $F$; we have to estimate it.
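
A quick 1-D sanity check of this result, using SciPy and a distribution whose true CDF we know:

import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
x = rng.exponential(scale=2.0, size=10_000)

u = stats.expon(scale=2.0).cdf(x)  # F(X) ~ Uniform[0, 1]
z = stats.norm.ppf(u)              # Phi^{-1}(F(X)) ~ N(0, 1)

print(f"mean {z.mean():+.3f}, std {z.std():.3f}")              # ~0, ~1
print(f"KS test vs N(0,1): p = {stats.kstest(z, 'norm').pvalue:.2f}")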

Mixture-of-Gaussians CDF

We approximate $F$ with a mixture of $K$ normals,

$$F(x;\,\pi, \mu, \sigma) \;=\; \sum_{k=1}^{K} \pi_k \,\Phi\!\big((x - \mu_k)/\sigma_k\big),$$

with $\pi \in \Delta^{K-1}$, $\mu \in \mathbb{R}^K$, $\sigma \in \mathbb{R}_{>0}^K$. The corresponding pdf is

$$f(x;\,\pi, \mu, \sigma) \;=\; \sum_{k} \pi_k \,\sigma_k^{-1}\,\varphi\!\big((x - \mu_k)/\sigma_k\big).$$

For $d$ input dimensions we fit one such mixture per dim independently — parameter tensors all have shape $(d, K)$ in our layer, so for $d = 2$ and $K = 8$ that’s $3 \cdot 2 \cdot 8 = 48$ numbers. The forward map and its log-derivative on dim $i$ are

$$z_i \;=\; \Phi^{-1}\!\big(F_i(x_i)\big), \qquad \log\lvert \partial z_i / \partial x_i \rvert \;=\; \log f_i(x_i) \;-\; \log \varphi(z_i).$$
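
Written out for a single dimension (a sketch in NumPy/SciPy; the layer vectorises the same arithmetic over its $(d, K)$ parameter tensors):

import numpy as np
from scipy import stats

pi = np.array([0.5, 0.3, 0.2])     # mixture weights on the simplex
mu = np.array([-1.0, 0.0, 1.5])    # component means
sigma = np.array([0.5, 1.0, 0.8])  # component scales

x = np.linspace(-3.0, 3.0, 7)
t = (x[:, None] - mu) / sigma
F = (pi * stats.norm.cdf(t)).sum(axis=-1)          # mixture CDF
f = (pi / sigma * stats.norm.pdf(t)).sum(axis=-1)  # mixture pdf

z = stats.norm.ppf(F)                            # z_i = Phi^{-1}(F_i(x_i))
log_abs_dzdx = np.log(f) - stats.norm.logpdf(z)  # log f_i(x_i) - log phi(z_i)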

Log-determinant structure of a flow block

One block is rotation → marginal. With $x \in \mathbb{R}^d$ for any $d$, the rotation $Q \in O(d)$ has $\log\lvert\det Q\rvert = 0$, so it contributes nothing. The marginal’s Jacobian is diagonal, so

$$\log\lvert\det J_{\text{block}}(x) \rvert \;=\; \sum_{i = 1}^{d} \big[\log f_i(x_i^{\text{rot}}) - \log \varphi(z_i)\big]$$

with $x^{\text{rot}} = Q x$ and $z_i = \Phi^{-1}(F_i(x_i^{\text{rot}}))$. Stacking $L$ blocks sums $L$ such terms. Everything is differentiable, so we can train by maximum likelihood.
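
Composing the two pieces for $d = 2$, one block end to end (continuing the sketch above and reusing pi, mu, sigma):

theta = 0.3  # any fixed rotation will do: log|det Q| = 0
Q = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])

def marginal_gaussianize(xi):
    t = (xi[:, None] - mu) / sigma
    F = (pi * stats.norm.cdf(t)).sum(axis=-1)
    f = (pi / sigma * stats.norm.pdf(t)).sum(axis=-1)
    zi = stats.norm.ppf(F)
    return zi, np.log(f) - stats.norm.logpdf(zi)

x = np.random.default_rng(1).normal(size=(5, 2))
x_rot = x @ Q.T                                      # rotation step
per_dim = [marginal_gaussianize(x_rot[:, i]) for i in range(2)]
z = np.stack([zi for zi, _ in per_dim], axis=-1)
log_det_block = sum(ld for _, ld in per_dim)         # per-sample log|det J|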

Iterative vs parametric Gaussianization

Two ways to choose the block parameters $\{(Q_\ell, \pi_\ell, \mu_\ell, \sigma_\ell)\}_{\ell=1}^L$ — same architecture, very different philosophies.

  • Iterative Gaussianization (RBIG — Laparra & Malo 2011): greedy, non-gradient, one block at a time. Given the current state $Y$, fit $Q_\ell = \mathrm{PCA}(Y)$, rotate, fit a per-dim GMM on the rotated data, apply the marginal transform numerically, move on. Laparra et al. show that the negentropy $J(Y_\ell) = D_\text{KL}(Y_\ell \,\|\, \mathcal{N}(0, I))$ decreases at a geometric rate under mild conditions. No gradients ever flow — each block is fit once and frozen. Cheap, remarkably good, and a useful warm start (see notebook 03); a one-iteration sketch follows after this list.
  • Parametric Gaussianization: the flow is the same architecture, but all parameters $(\pi, \mu, \sigma, Q)$ across all blocks are learned jointly by gradient descent on the exact NLL. Gradients backprop through the composition, so later blocks can compensate for finite-$K$ approximation errors in earlier blocks — something RBIG fundamentally cannot do. This notebook trains a parametric Gaussianization flow from scratch.

Both target the same optimum ($p_Z = \mathcal{N}(0, I)$ on the latent), but the path to it is different. Empirically, parametric + IG warm-start is the best of both worlds — notebook 03 shows the numbers.
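
For concreteness, one greedy RBIG iteration sketched in NumPy/SciPy, with a rank-based empirical marginal standing in for RBIG's fitted per-dimension GMMs (a simplification):

import numpy as np
from scipy import stats

def rbig_step(Y):
    _, eigvecs = np.linalg.eigh(np.cov(Y, rowvar=False))  # PCA rotation
    Y = Y @ eigvecs
    n = Y.shape[0]
    ranks = Y.argsort(axis=0).argsort(axis=0) + 1         # empirical CDF per dim
    return stats.norm.ppf(ranks / (n + 1))                # marginal Gaussianization

rng = np.random.default_rng(0)
Y = rng.exponential(size=(2000, 2))
for _ in range(5):
    Y = rbig_step(Y)  # each block is fit once and frozen; no gradients
# Y is now approximately N(0, I)

The greedy structure is the point here: each step only looks at the current state, never back through earlier blocks.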

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from keras import ops
from sklearn.datasets import make_moons

from gaussianization.gauss_keras import (
    base_nll_loss,
    make_gaussianization_flow,
)

# --- global plot styling ------------------------------------------------------
sns.set_theme(context="poster", style="whitegrid", palette="deep", font_scale=0.85)
plt.rcParams.update(
    {
        "figure.dpi": 110,
        "savefig.dpi": 110,
        "savefig.bbox": "tight",
        "savefig.pad_inches": 0.2,
        "figure.constrained_layout.use": True,
        "figure.constrained_layout.h_pad": 0.1,
        "figure.constrained_layout.w_pad": 0.1,
        "axes.grid.which": "both",
        "grid.linewidth": 0.7,
        "grid.alpha": 0.5,
        "axes.edgecolor": "0.25",
        "axes.linewidth": 1.1,
        "axes.titleweight": "semibold",
        "axes.labelpad": 6,
    }
)


def style_axes(ax, *, aspect=None, grid=True):
    if grid:
        ax.minorticks_on()
        ax.grid(True, which="major", linewidth=0.8, alpha=0.6)
        ax.grid(True, which="minor", linewidth=0.4, alpha=0.3)
    if aspect is not None:
        ax.set_aspect(aspect)
    return ax


def style_jointgrid(g, aspect="equal"):
    style_axes(g.ax_joint, aspect=aspect)
    style_axes(g.ax_marg_x, grid=False)
    style_axes(g.ax_marg_y, grid=False)
    for spine in ("top", "right"):
        g.ax_marg_x.spines[spine].set_visible(False)
        g.ax_marg_y.spines[spine].set_visible(False)


keras.utils.set_random_seed(0)
np.random.seed(0)

1. Dataset: two moons

Standardised so each dimension has zero mean and unit variance, which keeps the initial mixture-component scales (initialised to $\sigma = 1$) well-matched to the data.

X_raw, _ = make_moons(n_samples=5000, noise=0.05, random_state=0)
X = (X_raw - X_raw.mean(axis=0)) / X_raw.std(axis=0)
X = X.astype("float32")

palette = sns.color_palette("deep")

g = sns.jointplot(
    x=X[:, 0], y=X[:, 1], kind="scatter", color=palette[0],
    height=7.5, ratio=5, space=0.1,
    joint_kws={"s": 10, "alpha": 0.55},
    marginal_kws={"bins": 50, "color": palette[0], "alpha": 0.75, "edgecolor": "white"},
)
g.set_axis_labels("$x_1$", "$x_2$")
g.figure.suptitle("Two moons (standardised)", y=1.02)
style_jointgrid(g)
plt.show()
[Figure: two-moons jointplot with marginal histograms]

2. Build the flow

Eight blocks of (Householder, MixtureCDFGaussianization), preceded by a data-dependent PCA rotation that decorrelates the input before the first marginal layer sees it. Mixture means are initialised from the data quantiles so the first forward pass puts components across the observed range.

flow = make_gaussianization_flow(
    input_dim=2,
    num_blocks=8,
    num_reflectors=2,
    num_components=8,
    pca_init_data=X,
    mixture_init_data=X,
)

_ = flow(ops.convert_to_tensor(X[:4]))

print(f"number of bijectors          : {len(flow.bijector_layers)}")
print(
    "total trainable parameters   : "
    f"{int(sum(np.prod(w.shape) for w in flow.trainable_weights))}"
)
number of bijectors          : 17
total trainable parameters   : 416

3. Train

End-to-end gradient descent with Adam. The base-distribution NLL ($-\mathbb{E}[\log \mathcal{N}(z;\, 0, I)]$) is the main loss; the Jacobian log-dets flow in as add_loss contributions that Keras automatically aggregates.
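
For reference, a hedged sketch of what base_nll_loss computes, assuming it is exactly the standard-normal NLL on the latent (the real function is imported from gaussianization.gauss_keras above):

import math
from keras import ops

def base_nll_loss_sketch(y_true, z):
    # -E[log N(z; 0, I)]; y_true is a dummy (fit is called with (X, X)).
    d = z.shape[-1]
    return ops.mean(0.5 * ops.sum(z * z, axis=-1)) + 0.5 * d * math.log(2.0 * math.pi)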

flow.compile(
    optimizer=keras.optimizers.Adam(learning_rate=5e-3),
    loss=base_nll_loss,
)

history = flow.fit(
    X, X, batch_size=512, epochs=200, verbose=0,
)

loss_curve = np.asarray(history.history["loss"])

fig, axes = plt.subplots(1, 2, figsize=(17, 6.5))
c = palette[0]

ax = axes[0]
ax.plot(loss_curve, color=c, linewidth=2.2)
ax.set_xlabel("epoch")
ax.set_ylabel("training loss  $(-\\log p_X + \\text{const})$")
ax.set_title("Full training curve")
style_axes(ax)

ax = axes[1]
tail_start = 20
tail = loss_curve[tail_start:]
ax.plot(range(tail_start, len(loss_curve)), tail, color=c, linewidth=2.2)
span = tail.max() - tail.min()
ax.set_ylim(tail.min() - 0.05 * span, tail.max() + 0.05 * span)
ax.set_xlabel("epoch")
ax.set_ylabel("training loss")
ax.set_title(f"Zoomed  (epochs $\\geq$ {tail_start})")
style_axes(ax)

plt.show()
print(f"final loss: {loss_curve[-1]:+.4f}   (min: {loss_curve.min():+.4f})")
[Figure: training loss, full curve and zoomed tail]
final loss: +1.2327   (min: +1.2327)

4. Pushforward: is $z = f(x)$ actually Gaussian?

The whole point of Gaussianization is that after training, the flow maps the data to an approximate standard normal. A jointplot shows the pushforward scatter together with per-dimension histograms; the black reference curves are $\varphi(z)$ so we can eyeball the fit.

z = ops.convert_to_numpy(flow(ops.convert_to_tensor(X)))

mean_z = z.mean(axis=0)
cov_z = np.cov(z, rowvar=False)
print("pushforward diagnostics")
print(f"  mean:      [{mean_z[0]:+.3f}, {mean_z[1]:+.3f}]   target = [0, 0]")
print(f"  cov[0,0]:  {cov_z[0, 0]:.3f}                  target = 1")
print(f"  cov[1,1]:  {cov_z[1, 1]:.3f}                  target = 1")
print(f"  cov[0,1]:  {cov_z[0, 1]:+.3f}                  target = 0")

g = sns.jointplot(
    x=z[:, 0], y=z[:, 1], kind="scatter", color=palette[1],
    height=7.5, ratio=5, space=0.1,
    joint_kws={"s": 10, "alpha": 0.55},
    marginal_kws={
        "bins": 60, "color": palette[1], "alpha": 0.75,
        "edgecolor": "white", "stat": "density",
    },
    xlim=(-4, 4), ylim=(-4, 4),
)
zz = np.linspace(-4, 4, 300)
phi = np.exp(-0.5 * zz**2) / np.sqrt(2 * np.pi)
g.ax_marg_x.plot(zz, phi, color="black", linewidth=2.0, label="$\\mathcal{N}(0, 1)$")
g.ax_marg_y.plot(phi, zz, color="black", linewidth=2.0)
g.ax_marg_x.legend(loc="upper left", frameon=False, fontsize=11)
g.set_axis_labels("$z_1$", "$z_2$")
g.figure.suptitle("Pushforward  $z = f(x)$", y=1.02)
style_jointgrid(g)
plt.show()
pushforward diagnostics
  mean:      [-0.006, -0.011]   target = [0, 0]
  cov[0,0]:  0.955                  target = 1
  cov[1,1]:  0.932                  target = 1
  cov[0,1]:  +0.040                  target = 0
[Figure: pushforward jointplot with N(0, 1) reference curves on the marginals]

5. Learned density $\log p_X(x)$

Evaluate the flow’s exact density on a grid via flow.log_prob. For a normalising flow this is cheap: one forward pass plus the log-det accumulation — no Monte-Carlo integration is needed.

grid = 220
xs = np.linspace(-2.5, 2.5, grid).astype("float32")
ys = np.linspace(-2.5, 2.5, grid).astype("float32")
xx, yy = np.meshgrid(xs, ys)
pts = np.stack([xx.ravel(), yy.ravel()], axis=-1).astype("float32")

log_px = ops.convert_to_numpy(flow.log_prob(ops.convert_to_tensor(pts)))
log_px = log_px.reshape(grid, grid)

fig, axes = plt.subplots(1, 2, figsize=(17, 7.5))

ax = axes[0]
pcm = ax.pcolormesh(xx, yy, np.exp(log_px), cmap="magma", shading="auto")
ax.scatter(X[:1000, 0], X[:1000, 1], s=2, alpha=0.25, color="white")
ax.set_title("Density  $p_X(x)$")
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
style_axes(ax, aspect="equal")
fig.colorbar(pcm, ax=ax, shrink=0.85)

ax = axes[1]
log_px_vis = np.clip(log_px, -10, None)
pcm = ax.pcolormesh(xx, yy, log_px_vis, cmap="viridis", shading="auto")
ax.scatter(X[:1000, 0], X[:1000, 1], s=2, alpha=0.25, color="white")
ax.set_title("Log-density  $\\log p_X(x)$")
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
style_axes(ax, aspect="equal")
fig.colorbar(pcm, ax=ax, shrink=0.85)

plt.show()
[Figure: density and log-density heatmaps with data overlay]
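
A cheap sanity check on the same grid: Riemann-sum the density over the window. The window clips a little tail mass, so the result should land slightly below 1.

cell = (xs[1] - xs[0]) * (ys[1] - ys[0])
print(f"approx. integral of p_X over the window: {np.exp(log_px).sum() * cell:.3f}")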

6. Sampling from the flow

Draw $z \sim \mathcal{N}(0, I)$ and invert the composition: $x = f^{-1}(z)$. The inverse of each MixtureCDFGaussianization layer uses a vectorised bisection on $F_i(x_i) = \Phi(z_i)$; rotations invert via the transpose — each Householder reflection is its own inverse (applied in reverse order), and FixedOrtho uses the literal transpose.
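
The bisection idea in isolation: a NumPy sketch of inverting a monotone mixture CDF (illustrative; the layer's actual implementation may differ in details).

import numpy as np
from scipy import stats

pi, mu, sigma = np.array([0.6, 0.4]), np.array([-1.0, 1.0]), np.array([0.5, 0.7])

def mixture_cdf(x):
    return (pi * stats.norm.cdf((x[..., None] - mu) / sigma)).sum(axis=-1)

def invert_cdf(u, lo=-20.0, hi=20.0, iters=60):
    lo, hi = np.full_like(u, lo), np.full_like(u, hi)
    for _ in range(iters):             # bracket halves every step
        mid = 0.5 * (lo + hi)
        below = mixture_cdf(mid) < u
        lo = np.where(below, mid, lo)
        hi = np.where(below, hi, mid)
    return 0.5 * (lo + hi)

z = np.array([-2.0, 0.0, 1.5])
x = invert_cdf(stats.norm.cdf(z))      # solves F(x) = Phi(z)
print(np.abs(mixture_cdf(x) - stats.norm.cdf(z)).max())  # ~0 to float precision

Bisection is slower than Newton but unconditionally safe for a monotone CDF, which is presumably why it suits a vectorised, batched inverse.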

samples = ops.convert_to_numpy(flow.sample(num_samples=5000, seed=123))

for data, title, col in [
    (X, "Data", palette[0]),
    (samples, "Samples  $x = f^{-1}(z), \\; z \\sim \\mathcal{N}(0, I)$", palette[2]),
]:
    g = sns.jointplot(
        x=data[:, 0], y=data[:, 1], kind="scatter", color=col,
        height=7.5, ratio=5, space=0.1,
        joint_kws={"s": 10, "alpha": 0.55},
        marginal_kws={
            "bins": 50, "color": col, "alpha": 0.75,
            "edgecolor": "white",
        },
        xlim=(-2.5, 2.5), ylim=(-2.5, 2.5),
    )
    g.set_axis_labels("$x_1$", "$x_2$")
    g.figure.suptitle(title, y=1.02)
    style_jointgrid(g)
    plt.show()
[Figures: data and flow-samples jointplots, side by side]

7. Round-trip inversion

A flow is a bijection by construction, so $f^{-1}(f(x)) = x$ up to floating-point error and the $\varepsilon$-clamp on $F(x) \in [\varepsilon, 1-\varepsilon]$ that keeps $\Phi^{-1}$ bounded. We verify this numerically on the training set.

x_t = ops.convert_to_tensor(X)
z_fwd = flow(x_t)
x_rt = ops.convert_to_numpy(flow.invert(z_fwd))

err = np.linalg.norm(x_rt - X, axis=-1)
print("round-trip L2 error  ||f^{-1}(f(x)) - x||_2")
print(f"  median : {np.median(err):.2e}")
print(f"  95th%  : {np.percentile(err, 95):.2e}")
print(f"  max    : {err.max():.2e}")

fig, ax = plt.subplots(figsize=(10, 5.5))
ax.hist(
    np.log10(err + 1e-12), bins=50, color=palette[3],
    alpha=0.85, edgecolor="white", linewidth=0.5,
)
ax.set_xlabel("$\\log_{10}$ round-trip error")
ax.set_ylabel("count")
ax.set_title("Round-trip error distribution")
style_axes(ax)
plt.show()
round-trip L2 error  ||f^{-1}(f(x)) - x||_2
  median : 1.75e-06
  95th%  : 4.83e-06
  max    : 3.94e-05
[Figure: histogram of log10 round-trip error]

8. Layer-by-layer Gaussianization

The intermediate states $x,\, f_1(x),\, (f_2 \circ f_1)(x), \ldots$ show how the distribution progressively deforms towards a standard normal. Rotations shuffle the mass across axes; marginal layers pull the per-dimension CDFs towards $\Phi$.

states = flow.forward_with_intermediates(ops.convert_to_tensor(X[:2500]))
states_np = [ops.convert_to_numpy(s) for s in states]

n_states = len(states_np)
ncol = 4
nrow = int(np.ceil(n_states / ncol))
fig, axes = plt.subplots(nrow, ncol, figsize=(4.5 * ncol, 4.5 * nrow))
axes = axes.ravel()

for i, (state, ax) in enumerate(zip(states_np, axes)):
    ax.scatter(state[:, 0], state[:, 1], s=5, alpha=0.6, color=palette[0])
    lim = 3.4 if i > 0 else 2.8
    ax.set_xlim(-lim, lim)
    ax.set_ylim(-lim, lim)
    if i == 0:
        ax.set_title("Input  $x$")
    elif i == n_states - 1:
        ax.set_title(f"Output  $z$  (after {n_states - 1} layers)")
    else:
        ax.set_title(f"After layer {i}")
    style_axes(ax, aspect="equal")

for ax in axes[n_states:]:
    ax.axis("off")

plt.show()
[Figure: grid of intermediate states, from input x to output z]

What to look at

  • Pushforward jointplot: the scatter should look isotropic and the marginals should sit right under the black $\mathcal{N}(0, 1)$ curves.
  • Density heatmap: should trace the two-moons shape. This is a learned bivariate density, evaluated in closed form on a grid.
  • Sample jointplots: data and samples side-by-side — the 2-D morphology and the per-dim marginals should match.
  • Round-trip error: numerical precision baseline for the bijection; typical error is $\lesssim 10^{-4}$ with a few tail outliers from the CDF clamp.
  • Layer-by-layer: the intuition behind a Gaussianization flow — each (Rotation, Marginal) block nudges the distribution closer to $\mathcal{N}(0, I)$.