
Fair learning with frozen Gaussianization flows — design doc

Three fairness penalties built from a frozen Gaussianization flow


1. TL;DR

Replace the CKA fairness penalty in keras-fairkl’s FairModelWrapper, a Keras port of the fair-kernel-learning idea of Pérez-Suay et al. (2017), with a family of fairness losses built from a frozen Gaussianization flow. The flow is trained once on an auxiliary dataset, its weights are frozen, and it is then used as a differentiable Gaussian-space probe inside the downstream task’s optimisation loop. Gaussianisation lets us replace bandwidth-tuned RBF kernels (CKA, HSIC) with parametric, scale-invariant penalties, two of which are closed-form; the flow absorbs the kernel choice.

Three penalties, in order of strictness:

Table 1: Output-side fairness losses; see section 4 (Mathematical formulation) for the math.

| Loss | Captures | Closed form? | Joint flow needed? | Class |
|---|---|---|---|---|
| G-XCOV | 2nd-moment dependence in Gaussianised space (linear CKA there) | yes | no (two marginal flows) | GaussianizedXCovLoss |
| G-MI | MI assuming joint-Gaussian after Gaussianisation | yes | no (two marginal flows) | GaussianizedMutualInfoLoss |
| G-TC | Full MI / total correlation, no joint-Gaussian assumption | no (via flow NLL) | yes (one joint flow over $(z, q)$) | GaussianizedTotalCorrelationLoss |

All three are differentiable w.r.t. the downstream model parameters and plug into FairModelWrapper via its fairness_loss=... argument.


2. The mental model — three pictures

2.1 Stage 1 (one-time): pretrain the probes

2.2 Stage 2 (every step): the fair training loop

The trick, and the load-bearing claim of this whole experiment, is that the flow’s weights are frozen but its input is the predictor’s output, so gradients still propagate from the loss back through $T_z$ to θ.

2.3 Where each loss penalises


3. Why Gaussianisation helps (the one-paragraph theory)

A Gaussianization flow $T: \mathbb{R}^d \to \mathbb{R}^d$, in the lineage of Chen & Gopinath (2000), Laparra et al. (2011) and Meng et al. (2020), is a diffeomorphism: smooth, invertible, and with a smooth inverse. Applied separately to each variable, such maps preserve all statistical dependence, $T_z(Z) \perp T_q(Q) \iff Z \perp Q$, and leave mutual information unchanged. What they change is the shape of the marginals: after training, each marginal of $T(X)$ is approximately $\mathcal{N}(0, 1)$. Three consequences:

  1. Bandwidth-free dependence measures. CKA and HSIC need a kernel bandwidth; the “right” bandwidth depends on the scale of the data. In Gaussianised space the scale is fixed at 1, so a linear kernel (or a unit-bandwidth RBF) suffices. The flow absorbs the bandwidth choice into its mixture-CDF parameters during pretraining.

  2. Gaussian-joint assumption becomes much cheaper. Closed-form MI ($-\tfrac{1}{2}\log\det(I - CC^\top)$) requires assuming the joint $(Z, Q)$ is Gaussian. On raw data this is wildly wrong. After marginal Gaussianisation it is much closer to true: the marginals are Gaussian by construction (up to pretraining error) and only the copula remains non-Gaussian, so the closed-form MI estimate becomes a usable surrogate.

  3. Compatibility with frozen-flow autodiff. All flow components (MixtureCDFGaussianization, Householder) are smooth in their inputs. Stopping gradients on the flow’s weights does not stop gradients through the flow’s inputs, so the predictor’s parameters θ still receive a gradient signal from $\nabla_z \mathcal{L}_{\text{fair}}(T_z(z), T_q(q))$ via the chain rule.

The flow is therefore exactly the right thing to freeze: a fixed, smooth, scale-normalising, differentiable preprocessor that turns “measure non-linear dependence in the data” into “measure linear dependence between near-Gaussian variables.”
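A quick numerical illustration of consequences 1 and 2, as a minimal sketch. It assumes a rank-based (empirical-CDF) Gaussianisation as a stand-in for a trained MixtureCDFGaussianization flow; `rank_gaussianize` and `pearson` are hypothetical helpers, not part of the gaussianization package, and the rank transform is not differentiable, so it could not be used inside the training loop. A monotone warp of $z$ changes the raw Pearson correlation with $q$ substantially, while the correlation measured in Gaussianised space does not move.

```python
from statistics import NormalDist

import numpy as np

_STD_NORMAL = NormalDist()


def rank_gaussianize(x):
    """Map each column of x to ~N(0, 1) via its empirical CDF (flow stand-in)."""
    x = np.asarray(x, dtype=float).reshape(len(x), -1)
    ranks = np.argsort(np.argsort(x, axis=0), axis=0)  # 0 .. n-1 within each column
    u = (ranks + 0.5) / x.shape[0]                      # strictly inside (0, 1)
    return np.vectorize(_STD_NORMAL.inv_cdf)(u)


def pearson(a, b):
    return float(np.corrcoef(np.ravel(a), np.ravel(b))[0, 1])


rng = np.random.default_rng(0)
n = 5000
q = rng.standard_normal(n)
z = 0.6 * q + 0.8 * rng.standard_normal(n)  # correlation about 0.6
z_warped = np.exp(3.0 * z)                   # same dependence, very different scale

raw = (pearson(z, q), pearson(z_warped, q))
gauss = (
    pearson(rank_gaussianize(z), rank_gaussianize(q)),
    pearson(rank_gaussianize(z_warped), rank_gaussianize(q)),
)
print("raw Pearson:         ", raw)
print("Gaussianised Pearson:", gauss)
```

Because the exponential warp preserves ranks, the two Gaussianised correlations are identical, whereas the raw correlation of the warped output collapses because a handful of extreme values dominate it.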


4. Mathematical formulation

Let $z = f_\theta(X) \in \mathbb{R}^{d_z}$ be the predictor output and $q \in \mathbb{R}^{d_q}$ the sensitive attribute. Define $Z = T_z(z)$ and $Q = T_q(q)$ in Gaussianised space, and let $C = \widehat{\mathrm{Cov}}(Z, Q)$, $S_z = \widehat{\mathrm{Cov}}(Z, Z)$ and $S_q = \widehat{\mathrm{Cov}}(Q, Q)$ denote sample (cross-)covariances on a batch of size $n$.

4.1 G-XCOV — linear-CKA in Gaussianised space

$$
\mathcal{L}_{\text{G-XCOV}} \;=\; \frac{\lVert C \rVert_F^2}{\lVert S_z \rVert_F \, \lVert S_q \rVert_F + \varepsilon} \quad\in\; [0, 1]. \tag{1}
$$

Equation (1) is exact linear CKA (Cortes et al., 2012; Kornblith et al., 2019) applied to the Gaussianised features. In $d_z = d_q = 1$ it collapses to $\rho^2$, where $\rho$ is the Gaussianised cross-correlation. The un-normalised numerator $\lVert C \rVert_F^2$ equals the biased empirical HSIC with linear kernels (Gretton et al., 2005) on the Gaussianised features, when $C$ uses the $1/(n-1)$ normalisation. This single loss therefore covers both “linear CKA in Gaussianised space” and “HSIC with linear kernels in Gaussianised space”, depending on whether you toggle normalize.

Bounded, smooth, second-moment only. The gradient $\partial \rho^2 / \partial \rho = 2\rho$ is bounded, so the loss is gentle near perfect dependence.
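A minimal NumPy sketch of equation (1), assuming the inputs are already Gaussianised (for example, outputs of frozen $T_z$ and $T_q$). `batch_covariances`, `g_xcov` and `hsic_linear_biased` are hypothetical helpers, not the GaussianizedXCovLoss API. Covariances use the $1/(n-1)$ normalisation, which makes the un-normalised numerator coincide with the biased HSIC estimate $(n-1)^{-2}\operatorname{tr}(KHLH)$ of Gretton et al. (2005).

```python
import numpy as np


def batch_covariances(z, q):
    """Return C = Cov(Z, Q), S_z, S_q for a batch, with 1/(n-1) normalisation."""
    z = np.asarray(z, dtype=float).reshape(len(z), -1)
    q = np.asarray(q, dtype=float).reshape(len(q), -1)
    zc, qc = z - z.mean(axis=0), q - q.mean(axis=0)
    n = z.shape[0]
    return zc.T @ qc / (n - 1), zc.T @ zc / (n - 1), qc.T @ qc / (n - 1)


def g_xcov(z_gauss, q_gauss, normalize=True, eps=1e-8):
    """Equation (1) on already-Gaussianised features; ||C||_F^2 if normalize=False."""
    c, s_z, s_q = batch_covariances(z_gauss, q_gauss)
    numerator = float(np.sum(c ** 2))  # squared Frobenius norm of C
    if not normalize:
        return numerator
    # np.linalg.norm of a 2-D array defaults to the Frobenius norm.
    return numerator / (np.linalg.norm(s_z) * np.linalg.norm(s_q) + eps)


def hsic_linear_biased(z, q):
    """Biased empirical HSIC (n-1)^{-2} tr(K H L H) with linear kernels."""
    z = np.asarray(z, dtype=float).reshape(len(z), -1)
    q = np.asarray(q, dtype=float).reshape(len(q), -1)
    n = z.shape[0]
    h = np.eye(n) - np.full((n, n), 1.0 / n)  # centring matrix
    k_mat, l_mat = z @ z.T, q @ q.T
    return float(np.trace(k_mat @ h @ l_mat @ h)) / (n - 1) ** 2


rng = np.random.default_rng(1)
q = rng.standard_normal((400, 1))
z = 0.5 * q + rng.standard_normal((400, 1))
rho = np.corrcoef(z[:, 0], q[:, 0])[0, 1]

print(g_xcov(z, q), rho ** 2)                                   # 1-D: collapses to rho^2
print(g_xcov(z, q, normalize=False), hsic_linear_biased(z, q))  # numerator == linear HSIC
```

The explicit $n \times n$ HSIC computation is only there to check the identity; in a training loop the $d_z \times d_q$ cross-covariance is the cheap route.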

4.2 G-MI — closed-form Gaussian mutual information

If $(Z, Q)$ were jointly Gaussian, mutual information would have the Gel’fand–Yaglom closed form (Gel’fand & Yaglom, 1957):

$$
I(Z; Q) \;=\; -\tfrac{1}{2}\log\det\bigl(I_{d_z} - C\, S_q^{-1}\, C^\top\, S_z^{-1}\bigr). \tag{2}
$$

After Gaussianisation $S_z \approx I_{d_z}$ and $S_q \approx I_{d_q}$, so (2) simplifies to

$$
\mathcal{L}_{\text{G-MI}} \;=\; -\tfrac{1}{2}\log\det\bigl(I_{d_z} - C\, C^\top\bigr) \quad\in\; [0, +\infty). \tag{3}
$$

In $d_z = d_q = 1$, (3) is $-\tfrac{1}{2}\log(1 - \rho^2)$. The gradient $\partial \mathcal{L} / \partial \rho = \rho / (1 - \rho^2)$ diverges as $\lvert\rho\rvert \to 1$, so G-MI is much sharper than G-XCOV at high dependence. We clip the eigenvalues of $I - CC^\top$ at a small $\varepsilon$ for numerical safety; this caps the loss at $-\tfrac{d}{2}\log\varepsilon$ with $d = d_z$.
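A NumPy sketch of equation (3) with the eigenvalue clipping described above, under the same assumption that the inputs are already Gaussianised and approximately standardised. `g_mi` and its default `eps` are hypothetical; `np.linalg.eigvalsh` stands in for ops.linalg.eigh, and the $d_z = 1$ branch mirrors the scalar fast path mentioned in section 9.

```python
import numpy as np


def g_mi(z_gauss, q_gauss, eps=1e-6):
    """Equation (3): -1/2 log det(I - C C^T), assuming S_z ~ I and S_q ~ I.

    Eigenvalues of I - C C^T are clipped at eps, capping the loss at
    -(d_z / 2) * log(eps).
    """
    z = np.asarray(z_gauss, dtype=float).reshape(len(z_gauss), -1)
    q = np.asarray(q_gauss, dtype=float).reshape(len(q_gauss), -1)
    n = z.shape[0]
    c = (z - z.mean(axis=0)).T @ (q - q.mean(axis=0)) / (n - 1)  # d_z x d_q
    if c.shape[0] == 1:
        # Scalar fast path (d_z = 1): I - C C^T is 1 x 1, no eigendecomposition.
        return float(-0.5 * np.log(max(1.0 - float(np.sum(c ** 2)), eps)))
    eigvals = np.linalg.eigvalsh(np.eye(c.shape[0]) - c @ c.T)
    return float(-0.5 * np.sum(np.log(np.clip(eigvals, eps, None))))


rng = np.random.default_rng(4)
q = rng.standard_normal((2000, 1))
z = 0.7 * q + rng.standard_normal((2000, 1))
# Standardise exactly so that C equals the sample correlation rho.
z = (z - z.mean()) / z.std(ddof=1)
q = (q - q.mean()) / q.std(ddof=1)
rho = np.corrcoef(z[:, 0], q[:, 0])[0, 1]

print(g_mi(z, q), -0.5 * np.log(1.0 - rho ** 2))  # 1-D: -1/2 log(1 - rho^2)
print(g_mi(q, q), -0.5 * np.log(1e-6))            # perfect dependence hits the eps cap
```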

The closed form requires the joint to be Gaussian after Gaussianisation. Marginal Gaussianisation gets us most of the way there, but the residual copula is still arbitrary, so G-MI underestimates true MI whenever the dependence has structure beyond second-order correlation (quadratic-in-$Z$, XOR-style, multi-modal). That gap is exactly what G-TC is designed to close.

4.3 G-TC — total correlation under a frozen joint flow

Pretrain a joint flow $T_{zq}: \mathbb{R}^{d_z + d_q} \to \mathbb{R}^{d_z + d_q}$ on the empirical product distribution of the baseline data: draw pairs $(z^{(0)}_i, q_{\pi(i)})$, where $\pi$ is a random permutation. The permutation breaks the pairing, so these pairs are (approximately) independent, and a well-fit $T_{zq}$ Gaussianises such independent draws to $\mathcal{N}(0, I_{d_z + d_q})$. Freeze $T_{zq}$.

At downstream training time, evaluate the same frozen $T_{zq}$ on the actual (potentially dependent) pair $(z, q)$:

$$
\mathcal{L}_{\text{G-TC}} \;=\; -\frac{1}{n}\sum_{i=1}^{n} \log p_{\mathcal{N}(0, I)}\!\bigl(T_{zq}(z_i, q_i)\bigr). \tag{4}
$$

When $(z, q)$ is independent like the baseline, $T_{zq}(z, q) \sim \mathcal{N}(0, I)$ and (4) equals, in expectation, the entropy of $\mathcal{N}(0, I_{d_z + d_q})$, a constant in θ. When $(z, q)$ carries dependence, $T_{zq}$ no longer Gaussianises the joint and the NLL is no longer pinned to this value.

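By the change-of-variables formula, the frozen flow defines a density $\hat p(z, q) = p_{\mathcal{N}(0, I)}\bigl(T_{zq}(z, q)\bigr)\,\bigl\lvert\det \partial T_{zq}(z, q) / \partial(z, q)\bigr\rvert$ that approximates $p(z)\,p(q)$. The expected full NLL under $\hat p$ (base term plus log-Jacobian) exceeds the joint entropy $H(z, q)$ by $\mathrm{KL}\bigl(p(z, q)\,\Vert\,\hat p\bigr) \approx \mathrm{KL}\bigl(p(z, q)\,\Vert\,p(z)\,p(q)\bigr)$, i.e. the mutual information; see Watanabe (1960) for the total-correlation framing. Equation (4) keeps only the base term and does not subtract $H(z, q)$, so it is a surrogate for this KL rather than an exact estimate. Unlike G-MI it does not assume the joint is Gaussian: the reference density is learned by the flow during pretraining. The price is needing a richer pretraining stage.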
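The mechanics of the permutation step and of equation (4) can be sketched in NumPy, assuming a deliberately simple stand-in for $T_{zq}$: a per-coordinate affine standardisation fitted on the permuted pairs and then frozen. All helper names are hypothetical. An affine map only matches marginal means and variances, so it cannot react to dependence (the last line shows the value is unchanged on the dependent pairs); detecting dependence needs a genuine joint Gaussianization flow such as the one fit_and_freeze_joint is meant to produce.

```python
import numpy as np


def permuted_product_sample(z0, q0, rng):
    """Pair z0_i with q0_{pi(i)} for a random permutation pi (breaks dependence)."""
    perm = rng.permutation(len(q0))
    return np.column_stack([z0, q0[perm]])


def fit_frozen_affine_flow(pairs):
    """Hypothetical stand-in for T_zq: standardisation fitted once, then frozen."""
    mean, std = pairs.mean(axis=0), pairs.std(axis=0)
    return lambda x: (x - mean) / std


def g_tc_base_nll(flow, z, q):
    """Equation (4): mean of -log N(T_zq(z_i, q_i); 0, I) over the batch."""
    u = flow(np.column_stack([z, q]))
    d = u.shape[1]
    return float(np.mean(0.5 * np.sum(u ** 2, axis=1)) + 0.5 * d * np.log(2.0 * np.pi))


rng = np.random.default_rng(2)
n = 20000
q0 = rng.standard_normal(n)
z0 = 2.0 + 3.0 * q0 + rng.standard_normal(n)   # dependent baseline pair
pairs = permuted_product_sample(z0, q0, rng)   # approximately independent pairs
flow = fit_frozen_affine_flow(pairs)
entropy = np.log(2.0 * np.pi * np.e)           # entropy of N(0, I_2): (2/2) log(2 pi e)

print(np.corrcoef(pairs.T)[0, 1])              # close to 0 after the permutation
print(g_tc_base_nll(flow, pairs[:, 0], pairs[:, 1]), entropy)
print(g_tc_base_nll(flow, z0, q0))             # affine stand-in: unchanged on dependent data
```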

4.4 Comparison table

| Property | G-XCOV | G-MI | G-TC |
|---|---|---|---|
| Order of dependence | 2nd moment | All (joint-Gaussian) | All (no joint assumption) |
| Range | $[0, 1]$ | $[0, -\tfrac{d}{2}\log\varepsilon]$ | $[H_{\mathcal{N}(0,I)}, +\infty)$ |
| Gradient at high dependence | Bounded | Diverges | Bounded if flow well-fit |
| Closed form | yes | yes | no (needs flow forward pass) |
| Pretraining | 2 marginal flows | 2 marginal flows | 2 marginal or 1 joint flow |
| Compute per batch | one matmul | one matmul + eigh | full joint-flow forward |
| Sensitive to copula structure beyond 2nd order | no | no | yes |

4.5 Deferred (stretch)


5. Hypotheses we’re testing

Each loss family makes a falsifiable prediction about what kind of dependence it can suppress, and at what cost in accuracy.


6. Experiment design

6.1 Two-stage pipeline (reprise of the ASCII diagrams)

Stage 1 — pretrain + freeze.

Stage 2 — fair downstream training.

6.2 What gets penalised, where the gradient goes

A short walk through one optimiser step. Solid arrows are the forward pass; dashed arrows are the backward (autodiff) pass, each labelled with the chain-rule factor it carries.

The key chain-rule fact, read off the dashed path above: stop_gradient (or trainable=False) blocks gradients into the flow’s parameters but not into its inputs. The mixture-CDF Jacobian $\partial T_z(z)/\partial z$ is still smooth and non-zero, so the predictor still receives a fairness signal; without this property the whole scheme collapses. $T_z$, $T_q$ and $T_{zq}$ themselves receive no updates: they have no trainable_weights.
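The chain-rule fact can be checked on a scalar toy. This sketch assumes a one-dimensional mixture-CDF layer $T(z) = \Phi^{-1}\bigl(\sum_k w_k \Phi((z - \mu_k)/s_k)\bigr)$ with hypothetical, fixed parameters, a toy predictor $z = \theta x$, and a toy penalty $L = T(z)^2$. It uses the standard library's `statistics.NormalDist`, not the MixtureCDFGaussianization code, and compares the analytic gradient $\partial L/\partial\theta = 2\,T(z)\,T'(z)\,x$ with central finite differences.

```python
from statistics import NormalDist

import numpy as np

STD_NORMAL = NormalDist()

# Hypothetical frozen parameters of a 1-D mixture-CDF Gaussianization layer.
WEIGHTS = (0.3, 0.7)
MEANS = (-1.0, 2.0)
SCALES = (0.5, 1.5)


def frozen_flow(z):
    """T(z) = Phi^{-1}(F_mix(z)); the mixture parameters above are never updated."""
    cdf = sum(w * STD_NORMAL.cdf((z - m) / s) for w, m, s in zip(WEIGHTS, MEANS, SCALES))
    return STD_NORMAL.inv_cdf(cdf)


def frozen_flow_jacobian(z):
    """dT/dz = f_mix(z) / phi(T(z)) > 0: smooth and non-zero although T is frozen."""
    pdf = sum(w * STD_NORMAL.pdf((z - m) / s) / s for w, m, s in zip(WEIGHTS, MEANS, SCALES))
    return pdf / STD_NORMAL.pdf(frozen_flow(z))


def fairness_grad_theta(theta, x):
    """Chain rule for the toy predictor z = theta * x and penalty L = T(z)^2."""
    z = theta * x
    return 2.0 * frozen_flow(z) * frozen_flow_jacobian(z) * x


theta, x, h = 1.6, 1.5, 1e-5
analytic = fairness_grad_theta(theta, x)
numeric = (frozen_flow((theta + h) * x) ** 2 - frozen_flow((theta - h) * x) ** 2) / (2 * h)
print(analytic, numeric)
```

Only θ would be updated here; the mixture parameters play the role of the flow's frozen weights, yet the gradient reaching θ is clearly non-zero.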

6.3 File layout

```text
projects/gaussianization/
├── src/gaussianization/fair/
│   ├── __init__.py          # public API
│   ├── losses.py            # GaussianizedXCovLoss / MutualInfoLoss / TotalCorrelationLoss
│   ├── pretrain.py          # fit_and_freeze, fit_and_freeze_joint
│   ├── freeze.py            # freeze_flow helper
│   └── metrics.py           # numpy fairness eval metrics
├── tests/test_fair.py       # 15 tests including closed-form checks
├── notebooks/fair_gauss/
│   ├── 05_fair_gauss_pretrain.ipynb
│   ├── 06_fair_gauss_synthetic.ipynb   # G-XCOV + G-MI + G-TC + CKA
│   ├── 07_fair_gauss_adult.ipynb       # same on UCI Adult
│   └── _style.py
└── docs/fair_gaussianization_experiment.md   # this file
```

7. Datasets

| Dataset | Sensitive $q$ | Task | Use |
|---|---|---|---|
| Synthetic regression: $y = \tanh(x_1) + 0.5 x_2 + 3 q + \varepsilon$, $q \sim \mathrm{Bern}(0.5)$ | $q$ | regression | Notebook 06; the structure is exactly the fairkl fair_model_wrapper benchmark. |
| UCI Adult Census (OpenML id adult v2) | gender | binary classification | Notebook 07; ~49k rows, 5 numeric features + gender as a feature. |
| (Future) Engineered quadratic-dependence dataset | $q$ | regression | H3 test: $z$ determined by $q^2$, so $\rho \approx 0$ but MI > 0. Distinguishes G-MI from G-TC. |
| (Future) COMPAS | race | classification | Second real-data check. |
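A sketch of why the planned quadratic-dependence dataset separates the losses. It assumes a symmetric sensitive attribute, $q \sim \mathcal{N}(0, 1)$ (with the Bernoulli $q$ of the synthetic-regression row, $q^2 = q$ and the construction would not work), and uses a rank-based (empirical-CDF) stand-in for the trained marginal flows; all names are hypothetical. The Gaussianised correlation, and with it G-XCOV ($\rho^2$) and the one-dimensional G-MI, is close to zero even though $z$ is almost a deterministic function of $q$.

```python
from statistics import NormalDist

import numpy as np

_STD_NORMAL = NormalDist()


def rank_gaussianize(x):
    """Empirical-CDF Gaussianisation of a 1-D array (stand-in for a marginal flow)."""
    x = np.asarray(x, dtype=float)
    ranks = np.argsort(np.argsort(x))
    return np.vectorize(_STD_NORMAL.inv_cdf)((ranks + 0.5) / len(x))


rng = np.random.default_rng(3)
n = 20000
q = rng.standard_normal(n)                  # symmetric sensitive attribute
z = q ** 2 + 0.1 * rng.standard_normal(n)   # z determined (up to noise) by q^2

zg, qg = rank_gaussianize(z), rank_gaussianize(q)
rho_g = np.corrcoef(zg, qg)[0, 1]
g_mi_1d = -0.5 * np.log(1.0 - rho_g ** 2)   # equation (3) with d_z = d_q = 1
rho_sq = np.corrcoef(z, q ** 2)[0, 1]       # the dependence is obvious once q^2 is used
print(rho_g, g_mi_1d, rho_sq)
```

The symmetry $q \mapsto -q$ leaves $z$ unchanged, which is why any correlation-based measure, before or after marginal Gaussianisation, sees almost nothing here.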

8. Evaluation plan

For every (dataset, loss, μ, seed) combination report:

Comparison axes

  1. Loss family: cka | g_xcov | g_mi | g_tc × $\mu \in \{0, 0.1, 0.5, 2, 10, 50, 200\}$.
  2. Flow depth (num_blocks ∈ {2, 4, 8}) — does a deeper joint flow improve G-TC on the quadratic-dependence test (H3)?
  3. Pretraining set — same-as-task vs held-out vs i.i.d. Gaussian (the latter collapses G-XCOV to vanilla cross-cov; control).
  4. Batch size — G-MI and G-TC may be more batch-sensitive than G-XCOV (eigh + flow forward).

Statistical rigour. 3 seeds for the headline figures, 5 for any “X beats Y at matched fairness” claim, paired-bootstrap CIs.
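A minimal paired-bootstrap sketch for the “X beats Y at matched fairness” comparisons. It assumes the pairing unit is the seed (loss X and loss Y evaluated on the same seeds) and uses a percentile interval; `paired_bootstrap_ci`, `n_boot` and `alpha` are hypothetical names and defaults, not part of the project code.

```python
import numpy as np


def paired_bootstrap_ci(metric_a, metric_b, n_boot=10_000, alpha=0.05, seed=0):
    """Percentile bootstrap CI for mean(metric_a - metric_b) over paired units.

    metric_a[i] and metric_b[i] must come from the same unit (e.g. the same seed);
    resampling the per-unit differences keeps that pairing intact.
    """
    diffs = np.asarray(metric_a, dtype=float) - np.asarray(metric_b, dtype=float)
    rng = np.random.default_rng(seed)
    idx = rng.integers(0, len(diffs), size=(n_boot, len(diffs)))  # resample units
    boot_means = diffs[idx].mean(axis=1)
    lo, hi = np.quantile(boot_means, [alpha / 2, 1 - alpha / 2])
    return float(diffs.mean()), float(lo), float(hi)
```

Resampling per-seed differences, rather than each loss's runs independently, is what makes the interval paired; with only 3 to 5 seeds the percentile interval is coarse and should be read as a rough guide.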

Magnitude calibration follow-up. Each loss has a different natural scale. To make the loss-family axis a fair comparison, scale μ per family by dividing it by that family’s loss value at the unconstrained (μ = 0) baseline, so that the effective μ puts every family under the same relative pressure. Report both raw-μ and calibrated-μ Pareto fronts.
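One reading of the calibration rule, as a sketch: it assumes that each family's μ is the raw μ divided by that family's fairness-loss value on the unconstrained model, so that $\mu_{\text{family}} \cdot \mathcal{L}_{\text{family}}(\text{baseline})$ is the same for every family. `calibrated_mu` is hypothetical, and the baseline values in the usage line are placeholders, not measured results.

```python
def calibrated_mu(mu_raw, baseline_losses, floor=1e-12):
    """Map a raw mu to per-family mus with equal relative pressure at the baseline.

    Assumption: mu_family = mu_raw / L_family(unconstrained model), so that
    mu_family * L_family(baseline) == mu_raw for every family. `floor` guards
    against division by a zero baseline loss.
    """
    return {family: mu_raw / max(loss0, floor) for family, loss0 in baseline_losses.items()}


# Placeholder baseline loss values, for illustration only.
mus = calibrated_mu(2.0, {"g_xcov": 0.25, "g_mi": 0.5, "g_tc": 4.0})
print(mus)  # {'g_xcov': 8.0, 'g_mi': 4.0, 'g_tc': 0.5}
```

G-TC's minimum is $H_{\mathcal{N}(0,I)}$ rather than 0, so for that family one may prefer to subtract the floor before dividing; the sketch does not.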


9. Risks & open questions

| Risk | Mitigation |
|---|---|
| Eigendecomposition (ops.linalg.eigh) is slow for large $d_z$ | We expect $d_z = 1$ in practice (regression / single-logit classification), so the scalar fast path is taken. |
| Frozen-flow gradient is silently zeroed | The existing test_loss_gradients_flow_to_predictor, parametrised over G-XCOV and G-MI, asserts non-zero gradients to θ. |
| Comparing μ’s across losses is misleading | The doc warns about it; notebooks plot each loss as its own Pareto curve, with no point-to-point comparison at matched μ. |

Open questions for follow-ups.


10. Milestones

| # | Milestone | Acceptance |
|---|---|---|
| ✅ M1 | Skeleton: fair/{losses,freeze,pretrain,metrics}.py + tests pass on synthetic data | pytest tests/test_fair.py green; lint + typecheck clean |
| ✅ M2 | Notebook 05: pretrain + freeze + 4 diagnostics | Executed, committed with figures |
| ✅ M3 | Notebook 06: synthetic Pareto with G-XCOV vs CKA | Pareto curve from RMSE 0.11 to 1.35 |
| ✅ M4 | Notebook 07: Adult Pareto with G-XCOV vs CKA | Pareto traced; CKA beats G-XCOV on terminal fairness, as expected |
| 🟡 M5 | G-MI loss + G-TC loss + extended tests | This commit |
| 🟡 M6 | Notebooks 06 and 07 re-executed with G-MI and G-TC trajectories | Pareto plots show 4 curves (CKA / G-XCOV / G-MI / G-TC) |
| ⏳ M7 | H3 quadratic-dependence experiment isolating G-TC’s advantage | New notebook 08_quadratic_dependence.ipynb |
| ⏳ M8 | Hydra config + DVC stage for reproducibility | pixi run dvc repro regenerates all figures |
| ⏳ M9 | Magnitude-calibration follow-up | Calibrated-μ Pareto plots alongside raw-μ |

Appendix A — Minimum working example

```python
import os
os.environ.setdefault("KERAS_BACKEND", "jax")  # must be set before importing keras

import keras
import numpy as np

from fairkl.models import FairModelWrapper
from gaussianization.fair import (
    GaussianizedMutualInfoLoss,
    GaussianizedXCovLoss,
    GaussianizedTotalCorrelationLoss,
    fit_and_freeze,
    fit_and_freeze_joint,
)

# Synthetic data with q as a feature
rng = np.random.default_rng(0)
n = 4000
q = rng.binomial(1, 0.5, n).astype("float32")
x = rng.standard_normal((n, 3)).astype("float32")
y = (np.tanh(x[:, 0]) + 0.5 * x[:, 1] + 3 * q
     + 0.1 * rng.standard_normal(n)).astype("float32")
X = np.concatenate([x, q.reshape(-1, 1)], axis=1)

# Stage 1: pretrain + freeze probes (marginal flows for G-XCOV / G-MI, joint flow for G-TC)
flow_y, _  = fit_and_freeze(y.reshape(-1, 1), num_blocks=4, epochs=40, seed=0)
flow_q, _  = fit_and_freeze(q.reshape(-1, 1), num_blocks=2, epochs=40, seed=0)
flow_yq, _ = fit_and_freeze_joint(y, q, num_blocks=4, epochs=40, seed=0)

# Stage 2: pick a loss and train; only the MLP's weights are updated
mlp = keras.Sequential([keras.layers.Dense(32, "relu"),
                        keras.layers.Dense(32, "relu"),
                        keras.layers.Dense(1)])
fair = FairModelWrapper(
    mlp, mu=2.0,
    fairness_loss=GaussianizedMutualInfoLoss(flow_z=flow_y, flow_q=flow_q),
    # or: GaussianizedXCovLoss(flow_z=flow_y, flow_q=flow_q),
    # or: GaussianizedTotalCorrelationLoss(joint_flow=flow_yq),
)
fair.compile(optimizer="adam", loss="mse")
fair.fit(X, y, q=q, epochs=40, batch_size=256)
```

References
  1. Pérez-Suay, A., Laparra, V., Mateo-García, G., Muñoz-Marí, J., Gómez-Chova, L., & Camps-Valls, G. (2017). Fair Kernel Learning. Joint European Conference on Machine Learning and Knowledge Discovery in Databases (ECML PKDD). 10.1007/978-3-319-71249-9_21
  2. Chen, S. S., & Gopinath, R. A. (2000). Gaussianization. Advances in Neural Information Processing Systems (NeurIPS).
  3. Laparra, V., Camps-Valls, G., & Malo, J. (2011). Iterative Gaussianization: From ICA to Random Rotations. IEEE Transactions on Neural Networks, 22(4), 537–549. 10.1109/TNN.2011.2106511
  4. Meng, C., Song, Y., Song, J., & Ermon, S. (2020). Gaussianization Flows. International Conference on Artificial Intelligence and Statistics (AISTATS).
  5. Cortes, C., Mohri, M., & Rostamizadeh, A. (2012). Algorithms for Learning Kernels Based on Centered Alignment. Journal of Machine Learning Research, 13, 795–828.
  6. Kornblith, S., Norouzi, M., Lee, H., & Hinton, G. (2019). Similarity of Neural Network Representations Revisited. International Conference on Machine Learning (ICML).
  7. Gretton, A., Bousquet, O., Smola, A., & Schölkopf, B. (2005). Measuring Statistical Dependence with Hilbert–Schmidt Norms. Algorithmic Learning Theory (ALT). 10.1007/11564089_7
  8. Gel’fand, I. M., & Yaglom, A. M. (1957). Computation of the Amount of Information about a Random Function Contained in Another Such Function. American Mathematical Society Translations, 12, 199–246.
  9. Watanabe, S. (1960). Information Theoretical Analysis of Multivariate Correlation. IBM Journal of Research and Development, 4(1), 66–82. 10.1147/rd.41.0066
  10. Belghazi, M. I., Baratin, A., Rajeswar, S., Ozair, S., Bengio, Y., Hjelm, R. D., & Courville, A. (2018). Mutual Information Neural Estimation. International Conference on Machine Learning (ICML).