Flow on two-moons
Gaussianization Flow on Two-Moons¶
A minimal demonstration of a Gaussianization flow — a normalising flow built from two kinds of bijectors stacked in blocks:
- Rotation. Either a Householder product (trainable, ) or a fixed orthogonal matrix (frozen at initialisation, optionally from the data PCA). Rotations have , so they are “free” — their only job is to redistribute non-Gaussianity across dimensions.
- Marginal Gaussianization. For each dimension independently, apply where is a trainable mixture-of-Gaussians CDF. This is where the flow does its work: after enough stacked blocks, the pushforward is approximately .
The model is trained end-to-end by maximising
with . In code, each bijector contributes via self.add_loss, and the main training loss is the base NLL on the latent .
Mathematical background¶
The 1-D probability integral transform¶
For a univariate random variable with strictly increasing CDF , the classical result is that . Composing with the standard-normal quantile gives
So in 1-D we can Gaussianize any continuous distribution exactly, given its true CDF. The catch is that we don’t know ; we have to estimate it.
Mixture-of-Gaussians CDF¶
We approximate with a mixture of normals,
with , , . The corresponding pdf is
For input dimensions we fit one such mixture per dim independently — parameter tensors all have shape in our layer, so for and that’s numbers. The forward map and its log-derivative on dim are
Log-determinant structure of a flow block¶
One block is rotation → marginal. With for any , the rotation has , so it contributes nothing. The marginal’s Jacobian is diagonal, so
with and . Stacking blocks sums such terms. Everything is differentiable, so we can train by maximum likelihood.
Iterative vs parametric Gaussianization¶
Two ways to choose the block parameters — same architecture, very different philosophies.
- Iterative Gaussianization (RBIG — Laparra & Malo 2011): greedy, non-gradient, one block at a time. Given the current state , fit PCA, rotate, fit a per-dim GMM on the rotated data, apply the marginal transform numerically, move on. Laparra shows the negentropy decreases at a geometric rate under mild conditions. No gradients ever flow — each block is fit once and frozen. Cheap, remarkably good, and a useful warm start (see notebook 03).
- Parametric Gaussianization: the flow is the same architecture, but all parameters across all blocks are learned jointly by gradient descent on the exact NLL. Gradients backprop through the composition, so later blocks can compensate for finite- approximation errors in earlier blocks — something RBIG fundamentally cannot do. This notebook trains a parametric Gaussianization flow from scratch.
Both target the same optimum ( on the latent), but the path to it is different. Empirically, parametric + IG warm-start is the best of both worlds — notebook 03 shows the numbers.
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from keras import ops
from sklearn.datasets import make_moons
from gaussianization.gauss_keras import (
base_nll_loss,
make_gaussianization_flow,
)
# --- global plot styling ------------------------------------------------------
sns.set_theme(context="poster", style="whitegrid", palette="deep", font_scale=0.85)
plt.rcParams.update(
{
"figure.dpi": 110,
"savefig.dpi": 110,
"savefig.bbox": "tight",
"savefig.pad_inches": 0.2,
"figure.constrained_layout.use": True,
"figure.constrained_layout.h_pad": 0.1,
"figure.constrained_layout.w_pad": 0.1,
"axes.grid.which": "both",
"grid.linewidth": 0.7,
"grid.alpha": 0.5,
"axes.edgecolor": "0.25",
"axes.linewidth": 1.1,
"axes.titleweight": "semibold",
"axes.labelpad": 6,
}
)
def style_axes(ax, *, aspect=None, grid=True):
if grid:
ax.minorticks_on()
ax.grid(True, which="major", linewidth=0.8, alpha=0.6)
ax.grid(True, which="minor", linewidth=0.4, alpha=0.3)
if aspect is not None:
ax.set_aspect(aspect)
return ax
def style_jointgrid(g, aspect="equal"):
style_axes(g.ax_joint, aspect=aspect)
style_axes(g.ax_marg_x, grid=False)
style_axes(g.ax_marg_y, grid=False)
for spine in ("top", "right"):
g.ax_marg_x.spines[spine].set_visible(False)
g.ax_marg_y.spines[spine].set_visible(False)
keras.utils.set_random_seed(0)
np.random.seed(0)WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1776869560.070645 410070 cpu_feature_guard.cc:227] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
1. Dataset: two moons¶
Standardised so each dimension has unit variance, which keeps the initial mixture-component scales (initialised to ) well-matched to the data.
X_raw, _ = make_moons(n_samples=5000, noise=0.05, random_state=0)
X = (X_raw - X_raw.mean(axis=0)) / X_raw.std(axis=0)
X = X.astype("float32")
palette = sns.color_palette("deep")
g = sns.jointplot(
x=X[:, 0], y=X[:, 1], kind="scatter", color=palette[0],
height=7.5, ratio=5, space=0.1,
joint_kws={"s": 10, "alpha": 0.55},
marginal_kws={"bins": 50, "color": palette[0], "alpha": 0.75, "edgecolor": "white"},
)
g.set_axis_labels("$x_1$", "$x_2$")
g.figure.suptitle("Two moons (standardised)", y=1.02)
style_jointgrid(g)
plt.show()/home/azureuser/.cache/uv/archive-v0/aYSeLZUlluhRY4DCBNG7F/lib/python3.13/site-packages/seaborn/axisgrid.py:1766: UserWarning: The figure layout has changed to tight
f.tight_layout()

2. Build the flow¶
Eight blocks of (Householder, MixtureCDFGaussianization), preceded by a data-dependent PCA rotation that decorrelates the input before the first marginal layer sees it. Mixture means are initialised from the data quantiles so the first forward pass puts components across the observed range.
flow = make_gaussianization_flow(
input_dim=2,
num_blocks=8,
num_reflectors=2,
num_components=8,
pca_init_data=X,
mixture_init_data=X,
)
_ = flow(ops.convert_to_tensor(X[:4]))
print(f"number of bijectors : {len(flow.bijector_layers)}")
print(
"total trainable parameters : "
f"{int(sum(np.prod(w.shape) for w in flow.trainable_weights))}"
)E0000 00:00:1776869565.345186 410070 cuda_platform.cc:52] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
number of bijectors : 17
total trainable parameters : 416
3. Train¶
End-to-end gradient descent with Adam. The base-distribution NLL () is the main loss; the Jacobian log-dets flow in as add_loss contributions that Keras automatically aggregates.
flow.compile(
optimizer=keras.optimizers.Adam(learning_rate=5e-3),
loss=base_nll_loss,
)
history = flow.fit(
X, X, batch_size=512, epochs=200, verbose=0,
)
loss_curve = np.asarray(history.history["loss"])
fig, axes = plt.subplots(1, 2, figsize=(17, 6.5))
c = palette[0]
ax = axes[0]
ax.plot(loss_curve, color=c, linewidth=2.2)
ax.set_xlabel("epoch")
ax.set_ylabel("training loss $(-\\log p_X + \\text{const})$")
ax.set_title("Full training curve")
style_axes(ax)
ax = axes[1]
tail_start = 20
tail = loss_curve[tail_start:]
ax.plot(range(tail_start, len(loss_curve)), tail, color=c, linewidth=2.2)
span = tail.max() - tail.min()
ax.set_ylim(tail.min() - 0.05 * span, tail.max() + 0.05 * span)
ax.set_xlabel("epoch")
ax.set_ylabel("training loss")
ax.set_title(f"Zoomed (epochs $\\geq$ {tail_start})")
style_axes(ax)
plt.show()
print(f"final loss: {loss_curve[-1]:+.4f} (min: {loss_curve.min():+.4f})")
final loss: +1.2327 (min: +1.2327)
4. Pushforward: is actually Gaussian?¶
The whole point of Gaussianization is that after training, the flow maps the data to an approximate standard normal. A jointplot shows the pushforward scatter together with per-dimension histograms; the black reference curves are so we can eyeball the fit.
z = ops.convert_to_numpy(flow(ops.convert_to_tensor(X)))
mean_z = z.mean(axis=0)
cov_z = np.cov(z, rowvar=False)
print("pushforward diagnostics")
print(f" mean: [{mean_z[0]:+.3f}, {mean_z[1]:+.3f}] target = [0, 0]")
print(f" cov[0,0]: {cov_z[0, 0]:.3f} target = 1")
print(f" cov[1,1]: {cov_z[1, 1]:.3f} target = 1")
print(f" cov[0,1]: {cov_z[0, 1]:+.3f} target = 0")
g = sns.jointplot(
x=z[:, 0], y=z[:, 1], kind="scatter", color=palette[1],
height=7.5, ratio=5, space=0.1,
joint_kws={"s": 10, "alpha": 0.55},
marginal_kws={
"bins": 60, "color": palette[1], "alpha": 0.75,
"edgecolor": "white", "stat": "density",
},
xlim=(-4, 4), ylim=(-4, 4),
)
zz = np.linspace(-4, 4, 300)
phi = np.exp(-0.5 * zz**2) / np.sqrt(2 * np.pi)
g.ax_marg_x.plot(zz, phi, color="black", linewidth=2.0, label="$\\mathcal{N}(0, 1)$")
g.ax_marg_y.plot(phi, zz, color="black", linewidth=2.0)
g.ax_marg_x.legend(loc="upper left", frameon=False, fontsize=11)
g.set_axis_labels("$z_1$", "$z_2$")
g.figure.suptitle("Pushforward $z = f(x)$", y=1.02)
style_jointgrid(g)
plt.show()pushforward diagnostics
mean: [-0.006, -0.011] target = [0, 0]
cov[0,0]: 0.955 target = 1
cov[1,1]: 0.932 target = 1
cov[0,1]: +0.040 target = 0
/home/azureuser/.cache/uv/archive-v0/aYSeLZUlluhRY4DCBNG7F/lib/python3.13/site-packages/seaborn/axisgrid.py:1766: UserWarning: The figure layout has changed to tight
f.tight_layout()

5. Learned density ¶
Evaluate the flow’s exact density on a grid via flow.log_prob. For a normalising flow this is cheap: one forward pass plus the log-det accumulation — no Monte-Carlo integration is needed.
grid = 220
xs = np.linspace(-2.5, 2.5, grid).astype("float32")
ys = np.linspace(-2.5, 2.5, grid).astype("float32")
xx, yy = np.meshgrid(xs, ys)
pts = np.stack([xx.ravel(), yy.ravel()], axis=-1).astype("float32")
log_px = ops.convert_to_numpy(flow.log_prob(ops.convert_to_tensor(pts)))
log_px = log_px.reshape(grid, grid)
fig, axes = plt.subplots(1, 2, figsize=(17, 7.5))
ax = axes[0]
pcm = ax.pcolormesh(xx, yy, np.exp(log_px), cmap="magma", shading="auto")
ax.scatter(X[:1000, 0], X[:1000, 1], s=2, alpha=0.25, color="white")
ax.set_title("Density $p_X(x)$")
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
style_axes(ax, aspect="equal")
fig.colorbar(pcm, ax=ax, shrink=0.85)
ax = axes[1]
log_px_vis = np.clip(log_px, -10, None)
pcm = ax.pcolormesh(xx, yy, log_px_vis, cmap="viridis", shading="auto")
ax.scatter(X[:1000, 0], X[:1000, 1], s=2, alpha=0.25, color="white")
ax.set_title("Log-density $\\log p_X(x)$")
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
style_axes(ax, aspect="equal")
fig.colorbar(pcm, ax=ax, shrink=0.85)
plt.show()
6. Sampling from the flow¶
Draw and invert the composition: . The inverse of each MixtureCDFGaussianization layer uses a vectorised bisection on ; rotations are their own transpose (Householder) or the literal transpose (FixedOrtho).
samples = ops.convert_to_numpy(flow.sample(num_samples=5000, seed=123))
for data, title, col in [
(X, "Data", palette[0]),
(samples, "Samples $x = f^{-1}(z), \\; z \\sim \\mathcal{N}(0, I)$", palette[2]),
]:
g = sns.jointplot(
x=data[:, 0], y=data[:, 1], kind="scatter", color=col,
height=7.5, ratio=5, space=0.1,
joint_kws={"s": 10, "alpha": 0.55},
marginal_kws={
"bins": 50, "color": col, "alpha": 0.75,
"edgecolor": "white",
},
xlim=(-2.5, 2.5), ylim=(-2.5, 2.5),
)
g.set_axis_labels("$x_1$", "$x_2$")
g.figure.suptitle(title, y=1.02)
style_jointgrid(g)
plt.show()/home/azureuser/.cache/uv/archive-v0/aYSeLZUlluhRY4DCBNG7F/lib/python3.13/site-packages/seaborn/axisgrid.py:1766: UserWarning: The figure layout has changed to tight
f.tight_layout()

/home/azureuser/.cache/uv/archive-v0/aYSeLZUlluhRY4DCBNG7F/lib/python3.13/site-packages/seaborn/axisgrid.py:1766: UserWarning: The figure layout has changed to tight
f.tight_layout()

7. Round-trip inversion¶
A flow is a bijection by construction, so up to floating-point and the -clamp on that keeps bounded. We verify this numerically on the training set.
x_t = ops.convert_to_tensor(X)
z_fwd = flow(x_t)
x_rt = ops.convert_to_numpy(flow.invert(z_fwd))
err = np.linalg.norm(x_rt - X, axis=-1)
print("round-trip L2 error ||f^{-1}(f(x)) - x||_2")
print(f" median : {np.median(err):.2e}")
print(f" 95th% : {np.percentile(err, 95):.2e}")
print(f" max : {err.max():.2e}")
fig, ax = plt.subplots(figsize=(10, 5.5))
ax.hist(
np.log10(err + 1e-12), bins=50, color=palette[3],
alpha=0.85, edgecolor="white", linewidth=0.5,
)
ax.set_xlabel("$\\log_{10}$ round-trip error")
ax.set_ylabel("count")
ax.set_title("Round-trip error distribution")
style_axes(ax)
plt.show()round-trip L2 error ||f^{-1}(f(x)) - x||_2
median : 1.75e-06
95th% : 4.83e-06
max : 3.94e-05

8. Layer-by-layer Gaussianization¶
The intermediate states show how the distribution progressively deforms towards a standard normal. Rotations shuffle the mass across axes; marginal layers pull the per-dimension CDFs towards Φ.
states = flow.forward_with_intermediates(ops.convert_to_tensor(X[:2500]))
states_np = [ops.convert_to_numpy(s) for s in states]
n_states = len(states_np)
ncol = 4
nrow = int(np.ceil(n_states / ncol))
fig, axes = plt.subplots(nrow, ncol, figsize=(4.5 * ncol, 4.5 * nrow))
axes = axes.ravel()
for i, (state, ax) in enumerate(zip(states_np, axes)):
ax.scatter(state[:, 0], state[:, 1], s=5, alpha=0.6, color=palette[0])
lim = 3.4 if i > 0 else 2.8
ax.set_xlim(-lim, lim)
ax.set_ylim(-lim, lim)
if i == 0:
ax.set_title("Input $x$")
elif i == n_states - 1:
ax.set_title(f"Output $z$ (after {n_states - 1} layers)")
else:
ax.set_title(f"After layer {i}")
style_axes(ax, aspect="equal")
for ax in axes[n_states:]:
ax.axis("off")
plt.show()
What to look at¶
- Pushforward jointplot: the scatter should look isotropic and the marginals should sit right under the black curves.
- Density contour: should trace the two-moons shape. This is a learned bivariate density, evaluated in closed form on a grid.
- Sample jointplots: data and samples side-by-side — the 2-D morphology and the per-dim marginals should match.
- Round-trip error: numerical precision baseline for the bijection; typical error is with a few tail outliers from the CDF clamp.
- Layer-by-layer: the intuition behind a Gaussianization flow — each
(Rotation, Marginal)block nudges the distribution closer to .