
Iterative-Gaussianization warm start

Notebooks 01 and 02 trained Gaussianization / coupling flows from random initialisation. That works, but the first hundred or so epochs are spent pulling the flow from a near-identity map towards the data.

Iterative Gaussianization (RBIG — Laparra & Malo 2011) is a gradient-free, greedy procedure that warm-starts the flow by walking block-by-block through the architecture and fitting each block’s parameters directly:

  1. Rotation block — fit PCA on the current state $Y$; push the eigenvectors into the rotation layer (for Householder this uses a QR decomposition to extract reflector vectors).
  2. Marginal block — fit a per-dimension mixture-of-Gaussians on the rotated data with scikit-learn’s GaussianMixture, and assign $(\pi, \mu, \sigma)$ directly into the MixtureCDFGaussianization layer weights.
  3. Coupling block — fit a per-dimension GMM on the $x_b$ half of the current $Y$; zero the conditioner’s final Dense kernel and set its bias to the flattened GMM params. At initialisation the conditioner output is constant in $x_a$, so the coupling layer acts like a diagonal marginal with the fitted mixture. Gradient training then lets the conditioner learn to modulate on $x_a$.
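For steps 2–3, the per-dimension fit and parameter extraction look roughly like this (a standalone scikit-learn sketch on synthetic 1-D data — the actual copying of these values into layer weights is what initialize_flow_from_ig handles):

```python
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)
# Synthetic 1-D marginal: two well-separated Gaussian clumps.
y = np.concatenate([rng.normal(-2.0, 0.5, 500), rng.normal(1.0, 1.0, 500)])

# Per-dimension mixture fit, as used for the marginal / coupling blocks.
gmm = GaussianMixture(n_components=2, random_state=0).fit(y[:, None])
pi = gmm.weights_                          # mixture weights pi
mu = gmm.means_.ravel()                    # component means mu
sigma = np.sqrt(gmm.covariances_.ravel())  # component stddevs sigma
print(pi.round(2), mu.round(2), sigma.round(2))
```

The extracted $(\pi, \mu, \sigma)$ are what get written into the layer weights; nothing here depends on gradients.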

Between blocks the current $Y$ is propagated forward through the newly-parameterised bijector, so the next block sees data that has already been partially Gaussianized. A classical result (Laparra & Malo 2011, Chen & Gopinath 2000) is that this drives $Y$ towards $\mathcal{N}(0, I)$ at a geometric rate — even with no gradient descent at all, the flow is already a decent density estimator after a handful of blocks.
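That loop can be sketched from scratch. The following is a minimal NumPy/SciPy illustration of the rotate-then-Gaussianize sweep — it uses rank-based (empirical-CDF) marginal Gaussianization rather than the GMM-CDF layers, so it is an assumption-laden stand-in for the library code, not the implementation:

```python
import numpy as np
from scipy.stats import norm
from sklearn.datasets import make_moons

Y, _ = make_moons(n_samples=2000, noise=0.05, random_state=0)
Y = (Y - Y.mean(0)) / Y.std(0)

def rbig_sweep(Y):
    # 1. Rotation: PCA axes of the current state.
    _, _, Vt = np.linalg.svd(Y - Y.mean(0), full_matrices=False)
    Y = Y @ Vt.T
    # 2. Marginal Gaussianization: empirical CDF per dim, then inverse normal CDF.
    n = Y.shape[0]
    ranks = Y.argsort(0).argsort(0)   # per-dimension ranks
    u = (ranks + 0.5) / n             # empirical CDF values in (0, 1)
    return norm.ppf(u)

for _ in range(5):
    Y = rbig_sweep(Y)

print(np.round(np.cov(Y, rowvar=False), 2))  # close to the identity
```

After a handful of sweeps the covariance is near $I$ and the marginals are exactly standard normal by construction, which is the geometric-rate convergence the references prove.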

The point of this notebook: a Keras flow initialised this way starts at a much lower NLL than random init and reaches its training optimum in a fraction of the epochs.

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from keras import ops
from sklearn.datasets import make_moons

from gaussianization.gauss_keras import (
    base_nll_loss,
    initialize_flow_from_ig,
    make_coupling_flow,
    make_gaussianization_flow,
)

# --- global plot styling ------------------------------------------------------
sns.set_theme(context="poster", style="whitegrid", palette="deep", font_scale=0.85)
plt.rcParams.update(
    {
        "figure.dpi": 110,
        "savefig.dpi": 110,
        "savefig.bbox": "tight",
        "savefig.pad_inches": 0.2,
        "figure.constrained_layout.use": True,
        "figure.constrained_layout.h_pad": 0.1,
        "figure.constrained_layout.w_pad": 0.1,
        "axes.grid.which": "both",
        "grid.linewidth": 0.7,
        "grid.alpha": 0.5,
        "axes.edgecolor": "0.25",
        "axes.linewidth": 1.1,
        "axes.titleweight": "semibold",
        "axes.labelpad": 6,
    }
)


def style_axes(ax, *, aspect=None, grid=True):
    if grid:
        ax.minorticks_on()
        ax.grid(True, which="major", linewidth=0.8, alpha=0.6)
        ax.grid(True, which="minor", linewidth=0.4, alpha=0.3)
    if aspect is not None:
        ax.set_aspect(aspect)
    return ax


def style_jointgrid(g, aspect="equal"):
    style_axes(g.ax_joint, aspect=aspect)
    style_axes(g.ax_marg_x, grid=False)
    style_axes(g.ax_marg_y, grid=False)
    for spine in ("top", "right"):
        g.ax_marg_x.spines[spine].set_visible(False)
        g.ax_marg_y.spines[spine].set_visible(False)


palette = sns.color_palette("deep")
COLOR_RANDOM = palette[3]
COLOR_IG = palette[2]

keras.utils.set_random_seed(0)
np.random.seed(0)

1. Dataset

Two moons, standardised. Same data as notebooks 01/02 so the comparison is apples-to-apples.

X_raw, _ = make_moons(n_samples=5000, noise=0.05, random_state=0)
X = (X_raw - X_raw.mean(axis=0)) / X_raw.std(axis=0)
X = X.astype("float32")

2. Diagonal flow — random vs IG init

Build two identical diagonal flows, warm-start one of them from IG, and train both with the same optimiser and schedule.

def build_diagonal_flow():
    return make_gaussianization_flow(
        input_dim=2,
        num_blocks=8,
        num_reflectors=2,
        num_components=8,
        pca_init_data=X,
        mixture_init_data=X,
    )


keras.utils.set_random_seed(1)
flow_random = build_diagonal_flow()
_ = flow_random(ops.convert_to_tensor(X[:4]))

keras.utils.set_random_seed(1)
flow_ig = build_diagonal_flow()
_ = flow_ig(ops.convert_to_tensor(X[:4]))
initialize_flow_from_ig(flow_ig, X)
array([[ 0.21370773, 0.68232536], [-1.675768 , -0.09274741], [ 0.5233947 , 1.6713623 ], ..., [ 1.3039094 , -0.8307837 ], [ 0.17580122, -0.15649901], [ 0.17348166, -0.39425728]], shape=(5000, 2), dtype=float32)

Before any training: what did IG init already buy us?

RBIG on its own produces a density estimate. Below we compare the mean log-likelihood (per sample) of the IG-initialised flow against the random-init flow, both before a single gradient step.

lp_random0 = float(np.mean(ops.convert_to_numpy(
    flow_random.log_prob(ops.convert_to_tensor(X))
)))
lp_ig0 = float(np.mean(ops.convert_to_numpy(
    flow_ig.log_prob(ops.convert_to_tensor(X))
)))
print("mean log-likelihood before training:")
print(f"  random init : {lp_random0:+.3f} nats/sample")
print(f"  IG init     : {lp_ig0:+.3f} nats/sample  (improvement: {lp_ig0 - lp_random0:+.3f})")
mean log-likelihood before training:
  random init : -881.520 nats/sample
  IG init     : -1.652 nats/sample  (improvement: +879.868)

Pushforward at IG-init time

With no gradient training whatsoever, the pushforward of the IG-initialised flow should already be approximately $\mathcal{N}(0, I)$. Jointplots make this immediately visible — the IG scatter should look isotropic and the marginals should line up with the black $\mathcal{N}(0, 1)$ curves, while the random-init pushforward won’t.

z_ig0 = ops.convert_to_numpy(flow_ig(ops.convert_to_tensor(X)))
z_rnd0 = ops.convert_to_numpy(flow_random(ops.convert_to_tensor(X)))

zz = np.linspace(-4, 4, 300)
phi = np.exp(-0.5 * zz**2) / np.sqrt(2 * np.pi)

for z_sample, title, col in [
    (z_rnd0, "Random init  $z = f(x)$  (no training)", COLOR_RANDOM),
    (z_ig0, "IG init  $z = f(x)$  (no training)", COLOR_IG),
]:
    g = sns.jointplot(
        x=z_sample[:, 0], y=z_sample[:, 1], kind="scatter", color=col,
        height=7.5, ratio=5, space=0.1,
        joint_kws={"s": 10, "alpha": 0.55},
        marginal_kws={
            "bins": 60, "color": col, "alpha": 0.75,
            "edgecolor": "white", "stat": "density",
        },
        xlim=(-4, 4), ylim=(-4, 4),
    )
    g.ax_marg_x.plot(zz, phi, color="black", linewidth=2.0, label="$\\mathcal{N}(0, 1)$")
    g.ax_marg_y.plot(phi, zz, color="black", linewidth=2.0)
    g.ax_marg_x.legend(loc="upper left", frameon=False, fontsize=11)
    g.set_axis_labels("$z_1$", "$z_2$")
    g.figure.suptitle(title, y=1.02)
    style_jointgrid(g)
    plt.show()

mean_ig0 = z_ig0.mean(axis=0)
cov_ig0 = np.cov(z_ig0, rowvar=False)
print("IG-init pushforward (pre-training)")
print(f"  mean:    [{mean_ig0[0]:+.3f}, {mean_ig0[1]:+.3f}]   target = [0, 0]")
print(f"  cov[0,0]:{cov_ig0[0, 0]:.3f}    cov[1,1]:{cov_ig0[1, 1]:.3f}    target = 1, 1")
print(f"  cov[0,1]:{cov_ig0[0, 1]:+.3f}                          target = 0")
IG-init pushforward (pre-training)
  mean:    [+0.000, -0.000]   target = [0, 0]
  cov[0,0]:1.003    cov[1,1]:1.006    target = 1, 1
  cov[0,1]:+0.004                          target = 0

Training comparison

Both flows, same optimiser, same 150 epochs. The IG curve starts at roughly the level the random curve only reaches at the end of training, and finishes slightly below it — the zoomed panel on the right makes the small end-of-training gap legible.

flow_random.compile(optimizer=keras.optimizers.Adam(5e-3), loss=base_nll_loss)
flow_ig.compile(optimizer=keras.optimizers.Adam(5e-3), loss=base_nll_loss)

hist_random = flow_random.fit(X, X, batch_size=512, epochs=150, verbose=0)
hist_ig = flow_ig.fit(X, X, batch_size=512, epochs=150, verbose=0)

rnd_curve = np.asarray(hist_random.history["loss"])
ig_curve = np.asarray(hist_ig.history["loss"])

fig, axes = plt.subplots(1, 2, figsize=(17, 6.5))

ax = axes[0]
ax.plot(rnd_curve, label="random init", color=COLOR_RANDOM, linewidth=2.2)
ax.plot(ig_curve, label="IG init", color=COLOR_IG, linewidth=2.2)
ax.set_xlabel("epoch")
ax.set_ylabel("training loss  $(-\\log p_X + \\text{const})$")
ax.set_title("Diagonal flow — full")
ax.legend(frameon=True)
style_axes(ax)

ax = axes[1]
tail_start = 15
tail_rnd = rnd_curve[tail_start:]
tail_ig = ig_curve[tail_start:]
both = np.concatenate([tail_rnd, tail_ig])
span = both.max() - both.min()
ax.plot(range(tail_start, len(rnd_curve)), tail_rnd,
        label="random init", color=COLOR_RANDOM, linewidth=2.2)
ax.plot(range(tail_start, len(ig_curve)), tail_ig,
        label="IG init", color=COLOR_IG, linewidth=2.2)
ax.set_ylim(both.min() - 0.05 * span, both.max() + 0.05 * span)
ax.set_xlabel("epoch")
ax.set_ylabel("training loss")
ax.set_title(f"Zoomed  (epochs $\\geq$ {tail_start})")
ax.legend(frameon=True)
style_axes(ax)

plt.show()
print(f"final loss:  random init = {rnd_curve[-1]:+.3f}   "
      f"IG init = {ig_curve[-1]:+.3f}")
final loss:  random init = +1.232   IG init = +1.149

Density and samples after training

grid = 220
xs = np.linspace(-2.5, 2.5, grid).astype("float32")
ys = np.linspace(-2.5, 2.5, grid).astype("float32")
xx, yy = np.meshgrid(xs, ys)
pts = np.stack([xx.ravel(), yy.ravel()], axis=-1).astype("float32")

log_px_rnd = ops.convert_to_numpy(
    flow_random.log_prob(ops.convert_to_tensor(pts))
).reshape(grid, grid)
log_px_ig = ops.convert_to_numpy(
    flow_ig.log_prob(ops.convert_to_tensor(pts))
).reshape(grid, grid)

fig, axes = plt.subplots(1, 2, figsize=(17, 7.5))
for ax, lp, title in zip(
    axes, [log_px_rnd, log_px_ig], ["Random init — density", "IG init — density"]
):
    pcm = ax.pcolormesh(xx, yy, np.exp(lp), cmap="magma", shading="auto")
    ax.scatter(X[:800, 0], X[:800, 1], s=2, alpha=0.25, color="white")
    ax.set_title(title)
    ax.set_xlabel("$x_1$")
    ax.set_ylabel("$x_2$")
    style_axes(ax, aspect="equal")
    fig.colorbar(pcm, ax=ax, shrink=0.85)
plt.show()

samples_rnd = ops.convert_to_numpy(flow_random.sample(num_samples=3000, seed=2))
samples_ig = ops.convert_to_numpy(flow_ig.sample(num_samples=3000, seed=2))

for s, title, col in [
    (samples_rnd, "Random init — samples", COLOR_RANDOM),
    (samples_ig, "IG init — samples", COLOR_IG),
]:
    g = sns.jointplot(
        x=s[:, 0], y=s[:, 1], kind="scatter", color=col,
        height=7.5, ratio=5, space=0.1,
        joint_kws={"s": 10, "alpha": 0.55},
        marginal_kws={"bins": 50, "color": col, "alpha": 0.75, "edgecolor": "white"},
        xlim=(-2.5, 2.5), ylim=(-2.5, 2.5),
    )
    g.set_axis_labels("$x_1$", "$x_2$")
    g.figure.suptitle(title, y=1.02)
    style_jointgrid(g)
    plt.show()

3. Coupling flow — random vs IG init

Same experiment with a coupling flow. IG init is weaker here in isolation — at initialisation the conditioner is not exercised at all, it just emits a constant bias — but it still lets training skip the phase where a random conditioner makes the coupling transform effectively random: the flow starts out behaving as a well-fit diagonal marginal, and gradient training only has to learn the conditioning.
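The "constant bias" behaviour is just linear algebra, shown here in miniature with hypothetical shapes (the real conditioner is a small MLP whose final Dense is the layer being zeroed):

```python
import numpy as np

# A dense layer with a zero kernel ignores its input: y = x @ W + b = b.
# So at IG init the conditioner emits its bias (the flattened GMM params)
# for every x_a, and the coupling layer reduces to a diagonal marginal.
W = np.zeros((2, 6))                          # zeroed final-Dense kernel
b = np.linspace(-1.0, 1.0, 6)                 # bias := flattened mixture params
x_a = np.random.default_rng(0).normal(size=(5, 2))  # any conditioning input
out = x_a @ W + b
print(np.allclose(out, np.broadcast_to(b, out.shape)))  # True: constant in x_a
```

Once training perturbs W away from zero, the output varies with x_a and the layer becomes genuinely conditional.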

def build_coupling_flow():
    return make_coupling_flow(
        input_dim=2,
        num_blocks=4,
        num_components=8,
        hidden=(64, 64),
    )


keras.utils.set_random_seed(3)
cpl_random = build_coupling_flow()
_ = cpl_random(ops.convert_to_tensor(X[:4]))

keras.utils.set_random_seed(3)
cpl_ig = build_coupling_flow()
_ = cpl_ig(ops.convert_to_tensor(X[:4]))
initialize_flow_from_ig(cpl_ig, X)
array([[ 0.33592415, -0.63961166], [-1.7532271 , -0.11788122], [ 0.52066326, -1.6266781 ], ..., [ 1.1722953 , 0.8738398 ], [ 0.15655102, 0.20274866], [ 0.06255252, 0.3414428 ]], shape=(5000, 2), dtype=float32)
cpl_random.compile(optimizer=keras.optimizers.Adam(3e-3), loss=base_nll_loss)
cpl_ig.compile(optimizer=keras.optimizers.Adam(3e-3), loss=base_nll_loss)

hist_c_random = cpl_random.fit(X, X, batch_size=512, epochs=100, verbose=0)
hist_c_ig = cpl_ig.fit(X, X, batch_size=512, epochs=100, verbose=0)

crnd_curve = np.asarray(hist_c_random.history["loss"])
cig_curve = np.asarray(hist_c_ig.history["loss"])

fig, axes = plt.subplots(1, 2, figsize=(17, 6.5))

ax = axes[0]
ax.plot(crnd_curve, label="random init", color=COLOR_RANDOM, linewidth=2.2)
ax.plot(cig_curve, label="IG init", color=COLOR_IG, linewidth=2.2)
ax.set_xlabel("epoch")
ax.set_ylabel("training loss  $(-\\log p_X + \\text{const})$")
ax.set_title("Coupling flow — full")
ax.legend(frameon=True)
style_axes(ax)

ax = axes[1]
tail_start = 10
tail_rnd = crnd_curve[tail_start:]
tail_ig = cig_curve[tail_start:]
both = np.concatenate([tail_rnd, tail_ig])
span = both.max() - both.min()
ax.plot(range(tail_start, len(crnd_curve)), tail_rnd,
        label="random init", color=COLOR_RANDOM, linewidth=2.2)
ax.plot(range(tail_start, len(cig_curve)), tail_ig,
        label="IG init", color=COLOR_IG, linewidth=2.2)
ax.set_ylim(both.min() - 0.05 * span, both.max() + 0.05 * span)
ax.set_xlabel("epoch")
ax.set_ylabel("training loss")
ax.set_title(f"Zoomed  (epochs $\\geq$ {tail_start})")
ax.legend(frameon=True)
style_axes(ax)

plt.show()
print(f"final loss:  random init = {crnd_curve[-1]:+.3f}   "
      f"IG init = {cig_curve[-1]:+.3f}")
final loss:  random init = +1.400   IG init = +1.092

Coupling density and samples after training

log_px_crnd = ops.convert_to_numpy(
    cpl_random.log_prob(ops.convert_to_tensor(pts))
).reshape(grid, grid)
log_px_cig = ops.convert_to_numpy(
    cpl_ig.log_prob(ops.convert_to_tensor(pts))
).reshape(grid, grid)

fig, axes = plt.subplots(1, 2, figsize=(17, 7.5))
for ax, lp, title in zip(
    axes, [log_px_crnd, log_px_cig], ["Coupling — random init", "Coupling — IG init"]
):
    pcm = ax.pcolormesh(xx, yy, np.exp(lp), cmap="magma", shading="auto")
    ax.scatter(X[:800, 0], X[:800, 1], s=2, alpha=0.25, color="white")
    ax.set_title(title)
    ax.set_xlabel("$x_1$")
    ax.set_ylabel("$x_2$")
    style_axes(ax, aspect="equal")
    fig.colorbar(pcm, ax=ax, shrink=0.85)
plt.show()

csamples_rnd = ops.convert_to_numpy(cpl_random.sample(num_samples=3000, seed=4))
csamples_ig = ops.convert_to_numpy(cpl_ig.sample(num_samples=3000, seed=4))

for s, title, col in [
    (csamples_rnd, "Coupling  random init — samples", COLOR_RANDOM),
    (csamples_ig, "Coupling  IG init — samples", COLOR_IG),
]:
    g = sns.jointplot(
        x=s[:, 0], y=s[:, 1], kind="scatter", color=col,
        height=7.5, ratio=5, space=0.1,
        joint_kws={"s": 10, "alpha": 0.55},
        marginal_kws={"bins": 50, "color": col, "alpha": 0.75, "edgecolor": "white"},
        xlim=(-2.5, 2.5), ylim=(-2.5, 2.5),
    )
    g.set_axis_labels("$x_1$", "$x_2$")
    g.figure.suptitle(title, y=1.02)
    style_jointgrid(g)
    plt.show()

4. Verifying the zero-kernel contract

Immediately after initialize_flow_from_ig every MixtureCDFCoupling’s final Dense has a zero kernel (so the conditioner emits its bias for any input — the coupling layer behaves as a diagonal marginal). During gradient training those kernels become non-zero as the conditioner learns to modulate on $x_a$.

Below: build a fresh coupling flow, IG-init it, check the kernels before any training, then train briefly and check again.

from gaussianization.gauss_keras.bijectors import MixtureCDFCoupling


def max_final_kernel(flow):
    """Max |W| of the final Dense kernel in each coupling conditioner."""
    out = []
    for b in flow.bijector_layers:
        if isinstance(b, MixtureCDFCoupling):
            last = None
            for sub in b.conditioner.layers:
                if hasattr(sub, "kernel") and hasattr(sub, "bias"):
                    last = sub  # remember the last Dense-like sublayer
            out.append(float(np.max(np.abs(ops.convert_to_numpy(last.kernel)))))
    return out


keras.utils.set_random_seed(42)
cpl_probe = build_coupling_flow()
_ = cpl_probe(ops.convert_to_tensor(X[:4]))
kernels_before = max_final_kernel(cpl_probe)
print(f"fresh flow (before IG init):      max|W| per coupling = {[f'{k:.3f}' for k in kernels_before]}")

initialize_flow_from_ig(cpl_probe, X)
kernels_post_ig = max_final_kernel(cpl_probe)
print(f"right after IG init (no train):   max|W| per coupling = {[f'{k:.1e}' for k in kernels_post_ig]}")
print(f"all zero after IG init : {all(k < 1e-6 for k in kernels_post_ig)}")

cpl_probe.compile(optimizer=keras.optimizers.Adam(3e-3), loss=base_nll_loss)
_ = cpl_probe.fit(X, X, batch_size=512, epochs=20, verbose=0)
kernels_after_train = max_final_kernel(cpl_probe)
print(f"after 20 epochs of training:      max|W| per coupling = {[f'{k:.3f}' for k in kernels_after_train]}")
print("(non-zero is expected — the conditioner has started to modulate on x_a)")
fresh flow (before IG init):      max|W| per coupling = ['0.000', '0.000', '0.000', '0.000', '0.000', '0.000', '0.000', '0.000']
right after IG init (no train):   max|W| per coupling = ['0.0e+00', '0.0e+00', '0.0e+00', '0.0e+00', '0.0e+00', '0.0e+00', '0.0e+00', '0.0e+00']
all zero after IG init : True
after 20 epochs of training:      max|W| per coupling = ['0.283', '0.394', '0.373', '0.376', '0.556', '0.181', '0.145', '0.226']
(non-zero is expected — the conditioner has started to modulate on x_a)

Recap

  • RBIG (Laparra & Malo 2011) is a non-gradient greedy warm-start that fits each block’s parameters in sequence (PCA for rotation, sklearn GaussianMixture per dim for marginal/coupling) and propagates the data forward.
  • For the diagonal flow the init is exact: the layer’s $(\pi, \mu, \sigma)$ parameters are exactly the GMM fit, and the pre-training pushforward is already near $\mathcal{N}(0, I)$.
  • For the coupling flow the conditioner’s inner Dense weights stay random, but the final Dense is forced to kernel=0, bias = fitted mixture params. So at init the coupling layer behaves like a diagonal marginal layer with the IG-fitted mixture; gradient training then lets the conditioner become genuinely conditional.
  • End result in both cases: lower starting NLL, faster convergence, and fewer chances for the optimiser to get stuck in a bad local minimum.