Fair learning with frozen Gaussianization flows — design doc
Three fairness penalties built from a frozen Gaussianization flow
1. TL;DR
Replace the CKA fairness penalty in keras-fairkl’s
FairModelWrapper — a Keras port of the fair-kernel-learning idea of
Pérez-Suay et al. (2017) — with a family of fairness losses
built from a frozen Gaussianization flow. The flow is trained once on an
auxiliary dataset, its weights are frozen, and it is then used as a
differentiable Gaussian-space probe inside the downstream task’s
optimisation loop. Gaussianisation lets us replace bandwidth-tuned
RBF kernels (CKA, HSIC) with parametric, scale-invariant penalties,
two of them in closed form; the flow absorbs the kernel choice.
Three penalties, in order of strictness:
Table 1: Output-side fairness losses; see Section 4 (Mathematical formulation) for the math.
| Loss | Captures | Closed form? | Joint flow needed? | Class |
|---|---|---|---|---|
| G-XCOV | 2nd-moment dependence in Gaussianised space (linear CKA there) | yes | no: two marginal flows | GaussianizedXCovLoss |
| G-MI | MI assuming joint-Gaussian after Gaussianisation | yes | no: two marginal flows | GaussianizedMutualInfoLoss |
| G-TC | Full MI / total correlation, no joint-Gaussian assumption | no: via flow NLL | yes: one joint flow over $(z, q)$ | GaussianizedTotalCorrelationLoss |
All three are differentiable w.r.t. the downstream model parameters and
plug into FairModelWrapper via its fairness_loss=... argument.
2. The mental model: three pictures
2.1 Stage 1 (one-time): pretrain the probes
2.2 Stage 2 (every step): the fair training loop
The trick, and the load-bearing claim of this whole experiment, is that the flow’s weights are frozen but the flow’s input is the predictor’s output, so gradients still propagate from the fairness loss back through the flow to θ.
2.3 Where each loss penalises
3. Why Gaussianisation helps (the one-paragraph theory)
A Gaussianization flow, in the lineage of Chen & Gopinath (2000), Laparra et al. (2011), and Meng et al. (2020), is a diffeomorphism (smooth, with a smooth inverse). Applied to $z$ and $q$ separately, it therefore preserves all statistical dependence: $I(G_z(z); G_q(q)) = I(z; q)$. What it changes is the shape of the marginals: after training, each output coordinate of the flow is approximately $\mathcal{N}(0, 1)$. Three consequences:
Bandwidth-free dependence measures. CKA and HSIC need a kernel bandwidth; the “right” bandwidth depends on the scale of the data. In Gaussianised space the scale is fixed at 1, so a linear kernel (or a unit-bandwidth RBF) suffices. The flow absorbs the bandwidth choice into its mixture-CDF parameters during pretraining.
Gaussian-joint assumption becomes nearly free. Closed-form MI ($I(u; v) = -\tfrac{1}{2}\log(1 - \rho^2)$ in the scalar case) requires assuming the joint is Gaussian. On raw data this is badly wrong. After marginal Gaussianisation it is much closer to true: the marginals are (approximately) Gaussian and only the copula can remain non-Gaussian, so the closed-form MI estimate becomes a usable surrogate.
Compatibility with frozen-flow autodiff. All flow components (MixtureCDFGaussianization, Householder) are smooth in their inputs. Stopping the gradient on the flow’s weights does not stop gradient flow through the flow’s inputs, so the predictor’s parameters θ still receive a gradient signal from $\mathcal{L}_{\text{fair}}$ via the chain rule.
The flow is therefore exactly the right thing to freeze: a fixed, smooth, scale-normalising, differentiable preprocessor that turns “measure non-linear dependence in the data” into “measure linear dependence between near-Gaussian variables.”
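A minimal numpy sketch of consequence 2, under stated assumptions: the trained flow is replaced by a rank-based empirical Gaussianisation (`rank_gaussianize` is a hypothetical helper, not part of the package), and the toy data have a Gaussian copula with latent correlation ρ = 0.6 but skewed and heavy-tailed marginals. Pearson correlation on the raw values is distorted by the marginals; after Gaussianisation it recovers ρ, and the closed-form Gaussian MI computed from it becomes meaningful.

```python
import numpy as np
from statistics import NormalDist


def rank_gaussianize(x: np.ndarray) -> np.ndarray:
    """Map a 1-D sample to approximately N(0, 1) through its empirical CDF.

    Hypothetical stand-in for a trained marginal flow: it is monotone, so it
    changes the marginal shape but leaves the rank (copula) structure intact.
    """
    x = np.asarray(x, dtype=float)
    n = x.shape[0]
    ranks = np.argsort(np.argsort(x))      # ranks 0 .. n-1
    probs = (ranks + 0.5) / n              # strictly inside (0, 1)
    inv_cdf = NormalDist().inv_cdf
    return np.array([inv_cdf(p) for p in probs])


rng = np.random.default_rng(0)
n, rho = 50_000, 0.6
latent = rng.multivariate_normal([0.0, 0.0], [[1.0, rho], [rho, 1.0]], size=n)
z = np.exp(latent[:, 0])     # skewed "predictor output"
q = latent[:, 1] ** 3        # heavy-tailed "sensitive attribute"

raw_corr = np.corrcoef(z, q)[0, 1]
u, v = rank_gaussianize(z), rank_gaussianize(q)
gauss_corr = np.corrcoef(u, v)[0, 1]
# Closed-form Gaussian MI, only trustworthy on the Gaussianised pair.
gauss_mi = -0.5 * np.log(1.0 - gauss_corr ** 2)

print(f"Pearson on raw data:           {raw_corr:.3f}")
print(f"Pearson after Gaussianisation: {gauss_corr:.3f} (latent rho = {rho})")
```

Because the copula in this toy is exactly Gaussian, the Gaussianised pair is jointly Gaussian and the closed-form MI is exact up to sampling error; with a non-Gaussian copula it would only be a surrogate.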
4. Mathematical formulation
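Let $z = f_\theta(x) \in \mathbb{R}^{d_z}$ be the predictor output and $q \in \mathbb{R}^{d_q}$ the sensitive attribute. With frozen marginal flows $G_z$ and $G_q$, define $u = G_z(z)$ and $v = G_q(q)$ in Gaussianised space, and let $\Sigma_{uu}$, $\Sigma_{vv}$, $\Sigma_{uv}$ denote sample (cross-)covariances on a batch of size $n$, with $\Sigma_{vu} = \Sigma_{uv}^{\top}$.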
4.1 G-XCOV: linear CKA in Gaussianised space
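$$
\mathcal{L}_{\text{G-XCOV}} = \frac{\lVert \Sigma_{uv} \rVert_F^2}{\lVert \Sigma_{uu} \rVert_F \, \lVert \Sigma_{vv} \rVert_F} \tag{1}
$$
Equation (1) is exact linear CKA
(Cortes et al., 2012; Kornblith et al., 2019) applied to the Gaussianised
features. In the scalar case $d_z = d_q = 1$ it collapses to $\rho^2$, where ρ is
the Gaussianised cross-correlation. The un-normalised numerator
$\lVert \Sigma_{uv} \rVert_F^2$ is, up to a constant factor, HSIC with linear
kernels (Gretton et al., 2005) on the Gaussianised features, so
this single loss covers both “linear CKA in Gaussianised space”
and “HSIC with linear kernels in Gaussianised space”, depending
on whether you toggle normalize.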
Bounded in $[0, 1]$, smooth, second-moment only. The gradient stays bounded, so the loss is gentle near perfect dependence.
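A numpy sketch of (1), assuming the inputs are already Gaussianised batches (the real loss would obtain them from the frozen flows inside the training graph; `g_xcov` is a hypothetical name). It checks the two properties stated above: the scalar case equals ρ², and the value is invariant to rescaling.

```python
import numpy as np


def g_xcov(u: np.ndarray, v: np.ndarray, normalize: bool = True) -> float:
    """Linear CKA (normalize=True) or linear HSIC (normalize=False) between
    Gaussianised batches u of shape (n, d_u) and v of shape (n, d_v)."""
    u = np.asarray(u, dtype=float).reshape(len(u), -1)
    v = np.asarray(v, dtype=float).reshape(len(v), -1)
    n = u.shape[0]
    u = u - u.mean(axis=0)                  # centre each column
    v = v - v.mean(axis=0)
    s_uv = u.T @ v / (n - 1)                # cross-covariance Sigma_uv
    num = float(np.sum(s_uv ** 2))          # ||Sigma_uv||_F^2
    if not normalize:
        return num                          # linear HSIC, up to a constant
    s_uu = u.T @ u / (n - 1)
    s_vv = v.T @ v / (n - 1)
    # np.linalg.norm of a 2-D array defaults to the Frobenius norm
    return num / float(np.linalg.norm(s_uu) * np.linalg.norm(s_vv))


rng = np.random.default_rng(0)
uv = rng.multivariate_normal([0.0, 0.0], [[1.0, 0.5], [0.5, 1.0]], size=5_000)
u, v = uv[:, 0], uv[:, 1]
rho_hat = np.corrcoef(u, v)[0, 1]
print(g_xcov(u, v), rho_hat ** 2)   # equal in the scalar case
```

Setting `normalize=False` returns the bare numerator, i.e. the linear-HSIC variant.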
4.2 G-MI: closed-form Gaussian mutual information
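If $(u, v)$ were jointly Gaussian, mutual information has the Gel’fand–Yaglom closed form (Gel’fand & Yaglom, 1957):
$$
I(u; v) = \frac{1}{2} \log \frac{\det \Sigma_{uu} \, \det \Sigma_{vv}}{\det \Sigma}, \qquad \Sigma = \begin{pmatrix} \Sigma_{uu} & \Sigma_{uv} \\ \Sigma_{vu} & \Sigma_{vv} \end{pmatrix}. \tag{2}
$$
After Gaussianisation $\Sigma_{uu} \approx I$ and $\Sigma_{vv} \approx I$, so (2) simplifies to
$$
\mathcal{L}_{\text{G-MI}} = -\frac{1}{2} \log \det\left( I - \Sigma_{uv} \Sigma_{vu} \right). \tag{3}
$$
In (3) the determinant is taken over the $d_z \times d_z$ matrix $I - \Sigma_{uv}\Sigma_{vu}$; in the scalar case the loss reduces to $-\tfrac{1}{2}\log(1 - \rho^2)$. The gradient diverges as $\rho^2 \to 1$, so G-MI is much sharper than G-XCOV at high dependence. We clip the eigenvalues of $I - \Sigma_{uv}\Sigma_{vu}$ at a small $\varepsilon$ for numerical safety; this caps the loss at $-\tfrac{d_z}{2}\log\varepsilon$.
The closed form requires the joint to be Gaussian after Gaussianisation. Marginal Gaussianisation gets us most of the way there, but the residual copula is still arbitrary, so G-MI underestimates true MI whenever the dependence has structure beyond second-order correlation (quadratic-in-$q$, XOR-style, multi-modal). That gap is exactly what G-TC closes.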
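A numpy sketch of (3) with the eigenvalue clip, assuming identity-covariance Gaussianised inputs (`g_mi` is a hypothetical name, and `numpy.linalg.eigvalsh` stands in for `ops.linalg.eigh`). The second half reproduces the failure mode just described: a deterministic quadratic relation gives $\Sigma_{uv} \approx 0$, so the closed form reports almost no dependence.

```python
import numpy as np


def g_mi(u: np.ndarray, v: np.ndarray, eps: float = 1e-6) -> float:
    """Closed-form Gaussian MI, -1/2 log det(I - S_uv S_vu).

    Assumes the Gaussianised batches already have (approximately) identity
    covariance. Eigenvalues are clipped at eps, which caps the loss at
    -(d_u / 2) * log(eps).
    """
    u = np.asarray(u, dtype=float).reshape(len(u), -1)
    v = np.asarray(v, dtype=float).reshape(len(v), -1)
    n = u.shape[0]
    u = u - u.mean(axis=0)
    v = v - v.mean(axis=0)
    s_uv = u.T @ v / (n - 1)
    m = np.eye(u.shape[1]) - s_uv @ s_uv.T           # symmetric (d_u, d_u)
    lam = np.clip(np.linalg.eigvalsh(m), eps, None)   # clipped eigenvalues
    return float(-0.5 * np.sum(np.log(lam)))


rng = np.random.default_rng(0)
n = 50_000
uv = rng.multivariate_normal([0.0, 0.0], [[1.0, 0.6], [0.6, 1.0]], size=n)
mi_linear = g_mi(uv[:, 0], uv[:, 1])   # population value: -0.5 * log(1 - 0.36) ≈ 0.223

# Quadratic dependence: u is a deterministic function of v, yet Sigma_uv ≈ 0.
v = rng.standard_normal(n)
u = (v ** 2 - 1.0) / np.sqrt(2.0)      # unit variance, uncorrelated with v
mi_quadratic = g_mi(u, v)              # near 0 although u is fully determined by v
```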
4.3 G-TC: total correlation under a frozen joint flow
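Pretrain a joint flow $G_{zq}$ on the empirical product distribution of the baseline data: draw pairs $(z_i, q_{\pi(i)})$, where π is a random permutation. By construction these pairs are independent, so a well-fit $G_{zq}$ Gaussianises independent draws to $\mathcal{N}(0, I)$. Freeze $G_{zq}$.
At downstream training time, evaluate the same frozen $G_{zq}$ on the actual (potentially dependent) pairs $(z_i, q_i)$:
$$
\mathcal{L}_{\text{G-TC}} = -\frac{1}{n} \sum_{i=1}^{n} \Big[ \log \mathcal{N}\big(G_{zq}(z_i, q_i);\, 0, I\big) + \log \big\lvert \det J_{G_{zq}}(z_i, q_i) \big\rvert \Big]. \tag{4}
$$
When $(z, q)$ is independent like the baseline, $G_{zq}(z, q) \sim \mathcal{N}(0, I)$ and (4) estimates the joint entropy $H(z, q)$. When $(z, q)$ carries dependence, $G_{zq}$ no longer Gaussianises the joint and the expected NLL exceeds $H(z, q)$.
By a change-of-variables argument, $\mathbb{E}[\mathcal{L}_{\text{G-TC}}] = H(z, q) + \mathrm{KL}\big(p(z, q) \,\|\, p_{G_{zq}}(z, q)\big)$. A well-fit $G_{zq}$ models $p(z)\,p(q)$ (provided the predictor’s output marginal still matches the baseline’s), so the excess over $H(z, q)$ is exactly the KL divergence between $p(z, q)$ and $p(z)\,p(q)$, i.e. the mutual information $I(z; q)$; see Watanabe (1960) for the total-correlation framing. Because $H(z, q) = H(z) + H(q) - I(z; q)$ itself moves with θ, it is this excess, not the raw NLL, that isolates the dependence. Unlike G-MI, the measure does not assume the joint is Gaussian: the flow itself learns the copula during pretraining. The price is needing a richer pretraining stage.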
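The decomposition of the expected NLL can be checked on a toy where every term is known in closed form. Assumptions (not the real pipeline): $(z, q)$ is bivariate Gaussian with unit variances and correlation ρ, and the “frozen joint flow” is the identity map, so the model density is exactly the product of the two $\mathcal{N}(0, 1)$ marginals, which is what a perfectly fit $G_{zq}$ on shuffled pairs would represent here. `product_nll` is a hypothetical helper.

```python
import numpy as np

# Toy check of E[NLL] = H(z, q) + KL(p(z, q) || p_G), where KL = I(z; q)
# because the model density p_G is the product of the two N(0, 1) marginals.
rng = np.random.default_rng(0)
rho, n = 0.6, 200_000
zq = rng.multivariate_normal([0.0, 0.0], [[1.0, rho], [rho, 1.0]], size=n)


def product_nll(x: np.ndarray) -> np.ndarray:
    """Per-row -log N(x; 0, I_2); the identity flow has log|det J| = 0."""
    return 0.5 * np.sum(x ** 2, axis=1) + np.log(2.0 * np.pi)


nll = product_nll(zq).mean()                                   # Monte-Carlo E[NLL]
joint_entropy = 1.0 + np.log(2.0 * np.pi) + 0.5 * np.log(1.0 - rho ** 2)  # H of N(0, Sigma), d = 2
mutual_info = -0.5 * np.log(1.0 - rho ** 2)

print(f"E[NLL]           ~ {nll:.4f}")
print(f"H(z, q)          = {joint_entropy:.4f}")
print(f"E[NLL] - H(z, q) ~ {nll - joint_entropy:.4f}  vs  I(z; q) = {mutual_info:.4f}")
```

In this toy the Monte-Carlo E[NLL] sits near 1 + log 2π ≈ 2.838 for every ρ: a product-form model only sees the marginals, so the dependence becomes visible only once H(z, q) is subtracted.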
4.4 Comparison table
| Property | G-XCOV | G-MI | G-TC |
|---|---|---|---|
| Order of dependence | 2nd moment | All (joint-Gaussian) | All (no joint assumption) |
| Range | $[0, 1]$ | $[0, -\tfrac{d_z}{2}\log\varepsilon]$ after the ε clip | unnormalised (an NLL) |
| Gradient at high dep | Bounded | Diverges | Bounded if flow well-fit |
| Closed form | yes | yes | no — needs flow forward pass |
| Pretraining | 2 marginal flows | 2 marginal flows | 2 marginal or 1 joint flow |
| Compute / batch | one matmul | one matmul + eigh | full joint-flow forward |
| Sensitive to copula structure beyond 2nd order | no | no | yes |
4.5 Deferred (stretch)
- G-HSIC-RBF. HSIC with a unit-bandwidth RBF kernel in Gaussianised space. Identical to the existing CKA code but with $(u, v)$ in place of $(z, q)$. Useful as an ablation to separate “flow as preprocessor” from “linear vs RBF after the flow”.
- DR-MI. True mutual information via Monte-Carlo marginalisation through the joint flow. Compare with the neural estimator of Belghazi et al. (2018). Expensive; G-TC already captures the same signal at lower cost.
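A sketch of the deferred G-HSIC-RBF ablation in numpy, assuming Gaussianised inputs and the biased HSIC estimator of Gretton et al. (2005) with unit bandwidth (`rbf_gram` and `hsic_rbf` are hypothetical names, not the existing CKA code). The toy uses quadratic dependence, which the linear G-XCOV numerator would miss but an RBF kernel picks up.

```python
import numpy as np


def rbf_gram(x: np.ndarray, bandwidth: float = 1.0) -> np.ndarray:
    """Gram matrix exp(-||x_i - x_j||^2 / (2 * bandwidth^2))."""
    x = np.asarray(x, dtype=float).reshape(len(x), -1)
    sq = np.sum(x ** 2, axis=1)
    d2 = np.maximum(sq[:, None] + sq[None, :] - 2.0 * x @ x.T, 0.0)
    return np.exp(-d2 / (2.0 * bandwidth ** 2))


def hsic_rbf(u: np.ndarray, v: np.ndarray, bandwidth: float = 1.0) -> float:
    """Biased HSIC estimate tr(K H L H) / (n - 1)^2 with RBF kernels on both sides."""
    n = len(u)
    h = np.eye(n) - np.full((n, n), 1.0 / n)   # centring matrix
    k_u, k_v = rbf_gram(u, bandwidth), rbf_gram(v, bandwidth)
    return float(np.trace(k_u @ h @ k_v @ h) / (n - 1) ** 2)


rng = np.random.default_rng(0)
n = 400
v = rng.standard_normal(n)                  # already ~N(0, 1), as after G_q
u_dep = (v ** 2 - 1.0) / np.sqrt(2.0)       # quadratic dependence, zero correlation
u_ind = rng.standard_normal(n)              # independent control
print(hsic_rbf(u_dep, v), hsic_rbf(u_ind, v))
```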
5. Hypotheses we’re testing
Each loss family makes a falsifiable prediction about what kind of dependence it can suppress, and at what cost in accuracy.
6. Experiment design
6.1 Two-stage pipeline (reprise of the ASCII diagrams)
Stage 1 — pretrain + freeze.
- Train $G_z$ on the predictor-output distribution of an unconstrained baseline (e.g. for classification, the sigmoid probabilities of the baseline MLP, not the raw binary labels; otherwise the flow is fit to a 2-point support and then evaluated off-support on the actual predictions).
- Train $G_q$ on the marginal of $q$.
- Train $G_{zq}$ on the shuffled product distribution of the same baseline data (independent pairs by construction).
- Freeze all weights.
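Building the shuffled product distribution for the joint flow amounts to one random permutation. A sketch, assuming the baseline outputs are sigmoid probabilities of a model that leaks q (`product_pairs` is a hypothetical helper, not the internals of `fit_and_freeze_joint`): both marginals are kept exactly, while the z-q correlation drops to sampling noise.

```python
import numpy as np


def product_pairs(z: np.ndarray, q: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Return an (n, 2) array of pairs (z_i, q_pi(i)) for a random permutation pi.

    Both marginals are preserved exactly; any z-q dependence is destroyed.
    """
    z = np.asarray(z, dtype=float)
    q = np.asarray(q, dtype=float)
    return np.column_stack([z, q[rng.permutation(len(q))]])


rng = np.random.default_rng(0)
n = 10_000
q = rng.binomial(1, 0.5, n).astype(float)
logits = 2.0 * q + rng.standard_normal(n)   # a baseline that leaks q
z = 1.0 / (1.0 + np.exp(-logits))           # sigmoid probabilities, not labels
pairs = product_pairs(z, q, rng)
print(np.corrcoef(z, q)[0, 1], np.corrcoef(pairs[:, 0], pairs[:, 1])[0, 1])
```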
Stage 2 — fair downstream training.
- Drop in `FairModelWrapper(base, mu=μ, fairness_loss=...)` with any of the three new losses, then `compile(..., loss="...")` as usual. The wrapper handles the dict packing `{"x": X, "q": q}`.
6.2 What gets penalised, where the gradient goes
A short walk through one optimiser step. Solid arrows are the forward pass; dashed arrows are the backward (autodiff) pass, each labelled with the chain-rule factor it carries.
The key chain-rule fact, read off the dashed path above: stop_gradient
(or trainable=False) blocks gradients into the parameters of the
flow, but does not block gradients into its inputs. The
mixture-CDF Jacobian $\partial u / \partial z$ is still smooth and
non-zero, so the predictor still receives a fairness signal
$\partial \mathcal{L}_{\text{fair}} / \partial \theta = (\partial \mathcal{L}_{\text{fair}} / \partial u)(\partial u / \partial z)(\partial z / \partial \theta)$.
Without this property the whole scheme collapses. $G_z$, $G_q$, and
$G_{zq}$ themselves receive no gradient: they have zero trainable_weights.
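The chain-rule fact can be checked numerically. In this sketch a hypothetical frozen flow u = Φ⁻¹(F(z)), with F a fixed two-component logistic-mixture CDF, sits between a linear predictor z = Xθ and the scalar G-XCOV loss. The flow parameters are constants that nothing updates, yet the analytic gradient (dL/du)(du/dz)(dz/dθ) is non-zero and matches central finite differences. All names and parameter values are illustrative assumptions, not the package’s MixtureCDFGaussianization.

```python
import numpy as np
from statistics import NormalDist

STD_NORMAL = NormalDist()
# Frozen parameters of a 2-component logistic-mixture CDF (hypothetical values
# standing in for a pretrained marginal flow; nothing below updates them).
W = np.array([0.3, 0.7])
MU = np.array([-1.0, 0.5])
S = np.array([0.8, 1.2])


def frozen_flow(z: np.ndarray):
    """u = Phi^{-1}(F(z)) and its input Jacobian du/dz = f(z) / phi(u)."""
    t = (z[:, None] - MU) / S
    sig = 1.0 / (1.0 + np.exp(-t))
    cdf = np.clip((W * sig).sum(axis=1), 1e-12, 1.0 - 1e-12)
    pdf = (W * sig * (1.0 - sig) / S).sum(axis=1)
    u = np.array([STD_NORMAL.inv_cdf(p) for p in cdf])
    phi = np.array([STD_NORMAL.pdf(x) for x in u])
    return u, pdf / phi


def xcov_loss(u: np.ndarray, q: np.ndarray):
    """Scalar G-XCOV rho^2(u, q) and its gradient with respect to u."""
    uc, qc = u - u.mean(), q - q.mean()
    s, a, b = uc @ qc, uc @ uc, qc @ qc
    loss = s ** 2 / (a * b)
    dldu = 2.0 * s * qc / (a * b) - 2.0 * s ** 2 * uc / (a ** 2 * b)
    return loss, dldu


rng = np.random.default_rng(0)
n, d = 256, 3
X = rng.standard_normal((n, d))
q = rng.binomial(1, 0.5, n).astype(float)
X[:, 0] += q                               # feature 0 leaks the sensitive attribute
theta = np.array([1.0, 0.5, -0.3])         # linear predictor z = X @ theta

# Chain rule through the frozen flow: only theta receives a gradient.
u, dudz = frozen_flow(X @ theta)
_, dldu = xcov_loss(u, q)
grad_chain = X.T @ (dldu * dudz)


def loss_at(th: np.ndarray) -> float:
    return xcov_loss(frozen_flow(X @ th)[0], q)[0]


# Central finite differences through the same frozen flow, as a check.
h = 1e-5
grad_fd = np.array([(loss_at(theta + h * e) - loss_at(theta - h * e)) / (2 * h)
                    for e in np.eye(d)])
print(grad_chain, grad_fd)
```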
6.3 File layout
```text
projects/gaussianization/
├── src/gaussianization/fair/
│   ├── __init__.py        # public API
│   ├── losses.py          # GaussianizedXCovLoss / MutualInfoLoss / TotalCorrelationLoss
│   ├── pretrain.py        # fit_and_freeze, fit_and_freeze_joint
│   ├── freeze.py          # freeze_flow helper
│   └── metrics.py         # numpy fairness eval metrics
├── tests/test_fair.py     # 15 tests including closed-form checks
├── notebooks/fair_gauss/
│   ├── 05_fair_gauss_pretrain.ipynb
│   ├── 06_fair_gauss_synthetic.ipynb   # G-XCOV + G-MI + G-TC + CKA
│   ├── 07_fair_gauss_adult.ipynb       # same on UCI Adult
│   └── _style.py
└── docs/fair_gaussianization_experiment.md   # this file
```
7. Datasets
| Dataset | Sensitive | Task | Use |
|---|---|---|---|
| Synthetic regression | $q$ | regression | Notebook 06; the structure is exactly the fairkl fair_model_wrapper benchmark. |
| UCI Adult Census (OpenML id adult v2) | gender | binary classification | Notebook 07; ~49k rows, 5 numeric features + gender as a feature. |
| (Future) Engineered quadratic-dependence dataset | $q$ (continuous) | regression | H3 test: $y$ determined by $q^2$, so $\mathrm{corr}(y, q) = 0$ but MI > 0. Distinguishes G-MI from G-TC. |
| (Future) COMPAS | race | classification | Second real-data check. |
8. Evaluation plan
For every (dataset, loss, μ, seed) combination report:
- Predictive metrics. RMSE / R² (regression). Accuracy, ROC-AUC, log-loss (classification).
- Fairness metrics (numpy-side, neutral judge). Demographic-parity difference, equalized-odds difference, |Pearson($z$, $q$)|. These are computed at evaluation time and are not used as the training loss, so they are independent of the training penalty.
- Diagnostic: training-time loss value. Each loss’s own value through training (so we can see G-MI saturate at the eps clip if it does, etc.).
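A numpy sketch of the three evaluation metrics for a binary group. The definitions are assumptions about what the metrics module computes: demographic parity as the absolute gap in positive rates, and equalized odds as the larger of the TPR and FPR gaps (the convention of Fairlearn’s `equalized_odds_difference`). Function names are hypothetical.

```python
import numpy as np


def demographic_parity_diff(y_pred, group) -> float:
    """|P(y_hat = 1 | g = 1) - P(y_hat = 1 | g = 0)| for binary y_pred and group."""
    y_pred, group = np.asarray(y_pred, dtype=float), np.asarray(group)
    return float(abs(y_pred[group == 1].mean() - y_pred[group == 0].mean()))


def equalized_odds_diff(y_true, y_pred, group) -> float:
    """Larger of the TPR gap (y = 1) and the FPR gap (y = 0) between the two groups."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred, dtype=float)
    group = np.asarray(group)
    gaps = []
    for label in (0, 1):
        m = y_true == label
        gaps.append(abs(y_pred[m & (group == 1)].mean() - y_pred[m & (group == 0)].mean()))
    return float(max(gaps))


def abs_pearson(z, q) -> float:
    """|Pearson correlation| between a continuous output z and the attribute q."""
    return float(abs(np.corrcoef(z, q)[0, 1]))


y_true = np.array([0, 0, 1, 1, 0, 0, 1, 1])
group = np.array([0, 0, 0, 0, 1, 1, 1, 1])
y_pred = np.array([0, 1, 1, 1, 0, 0, 0, 0])
print(demographic_parity_diff(y_pred, group))       # 0.75
print(equalized_odds_diff(y_true, y_pred, group))   # 1.0
```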
Comparison axes
- Loss family: `cka | g_xcov | g_mi | g_tc` × μ.
- Flow depth (`num_blocks ∈ {2, 4, 8}`): does a deeper joint flow improve G-TC on the quadratic-dependence test (H3)?
- Pretraining set: same-as-task vs held-out vs i.i.d. Gaussian (the latter is a control that collapses G-XCOV to vanilla cross-covariance).
- Batch size: G-MI and G-TC may be more batch-sensitive than G-XCOV (eigh + flow forward).
Statistical rigour. 3 seeds for the headline figures, 5 for any “X beats Y at matched fairness” claim, paired-bootstrap CIs.
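A sketch of a paired percentile bootstrap for “X beats Y” claims, assuming `a[i]` and `b[i]` are the same metric for two methods under the same seed (or on the same test example); the function name and defaults are hypothetical. Resampling the pairs together keeps the per-seed correlation between the methods, which is what makes the interval paired. The example arrays are synthetic, not results.

```python
import numpy as np


def paired_bootstrap_ci(a, b, n_boot: int = 10_000, alpha: float = 0.05, seed: int = 0):
    """Percentile bootstrap CI for mean(a - b), resampling matched pairs together."""
    diff = np.asarray(a, dtype=float) - np.asarray(b, dtype=float)
    rng = np.random.default_rng(seed)
    idx = rng.integers(0, len(diff), size=(n_boot, len(diff)))   # resampled pair indices
    boot_means = diff[idx].mean(axis=1)
    lo, hi = np.quantile(boot_means, [alpha / 2, 1 - alpha / 2])
    return float(diff.mean()), (float(lo), float(hi))


# Synthetic per-seed metric values for two methods run on the same 5 seeds.
rng = np.random.default_rng(1)
b = rng.normal(0.30, 0.02, size=5)
a = b - 0.01 + rng.normal(0.0, 0.002, size=5)
mean_diff, (lo, hi) = paired_bootstrap_ci(a, b)
print(mean_diff, (lo, hi))
```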
Magnitude calibration follow-up. Each loss has a different natural
scale. To make the loss-family axis a fair comparison, scale μ
per-family by the loss value at the unconstrained baseline, so that
effective μ puts everyone at the same relative pressure. Report
both raw-μ and calibrated-μ Pareto fronts.
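The calibration rule is one division per family. A sketch, assuming the “loss value at the unconstrained baseline” is a single positive scalar per family; the dictionary values below are placeholders, not measurements, and `calibrated_mu` is a hypothetical helper.

```python
def calibrated_mu(mu: float, baseline_loss: dict, floor: float = 1e-12) -> dict:
    """Per-family effective mu = mu / L_family(baseline), so every family starts
    at the same relative pressure: mu_eff * L_family(baseline) == mu."""
    return {fam: mu / max(val, floor) for fam, val in baseline_loss.items()}


# Placeholder baseline values for illustration only (not measured results).
baseline = {"cka": 0.4, "g_xcov": 0.25, "g_mi": 0.8, "g_tc": 2.0}
mu_eff = calibrated_mu(2.0, baseline)
print(mu_eff)
```

An NLL-valued loss such as G-TC need not be positive, so dividing by its raw baseline value would need an offset first; the `floor` only guards against division by zero.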
9. Risks & open questions
| Risk | Mitigation |
|---|---|
| Eigendecomposition (ops.linalg.eigh) is slow for large $d_z$ | We expect $d_z = 1$ in practice (regression / single-logit classification), where the scalar fast path is taken. |
| Frozen-flow gradient is silently zeroed | Existing test_loss_gradients_flow_to_predictor parametrised over G-XCOV and G-MI asserts non-zero gradients to θ. |
| Comparing μ’s across losses is misleading | The doc warns about it; notebooks plot each loss as its own Pareto curve, no point-to-point comparison at matched μ. |
Open questions for follow-ups.
- Does fine-tuning the marginal flows during downstream training (a light EMA update) actually help, or does it leak into the predictor’s gradient and corrupt the fairness signal?
- G-TC requires a joint flow. Can a single $G_{zq}$, pretrained once, serve every downstream task on the same data, or does it need to be re-trained per architecture?
- For multi-class sensitive attributes (race in COMPAS), does the one-hot encoding of $q$ work, or do we need a categorical flow head?
10. Milestones
| # | Milestone | Acceptance |
|---|---|---|
| ✅ M1 | Skeleton: fair/{losses,freeze,pretrain,metrics}.py + tests pass on synthetic data | pytest tests/test_fair.py green; lint + typecheck clean |
| ✅ M2 | Notebook 05: pretrain + freeze + 4 diagnostics | Executed, committed with figures |
| ✅ M3 | Notebook 06: synthetic Pareto with G-XCOV vs CKA | Pareto curve from RMSE 0.11 to 1.35 |
| ✅ M4 | Notebook 07: Adult Pareto with G-XCOV vs CKA | Pareto traced; CKA beats G-XCOV terminal fairness as expected |
| 🟡 M5 | G-MI loss + G-TC loss + extended tests | This commit |
| 🟡 M6 | Notebooks 06 and 07 re-executed with G-MI and G-TC trajectories | Pareto plots show 4 curves (CKA / G-XCOV / G-MI / G-TC) |
| ⏳ M7 | H3 quadratic-dependence experiment isolating G-TC’s advantage | New notebook 08_quadratic_dependence.ipynb |
| ⏳ M8 | Hydra config + DVC stage for reproducibility | pixi run dvc repro regenerates all figures |
| ⏳ M9 | Magnitude-calibration follow-up | Calibrated-μ Pareto plots alongside raw-μ |
Appendix A: Minimum working example
```python
import os
os.environ.setdefault("KERAS_BACKEND", "jax")  # must be set before importing keras

import keras
import numpy as np
from fairkl.models import FairModelWrapper
from gaussianization.fair import (
    GaussianizedMutualInfoLoss,
    GaussianizedXCovLoss,
    GaussianizedTotalCorrelationLoss,
    fit_and_freeze,
    fit_and_freeze_joint,
)

# Synthetic data with q as a feature
rng = np.random.default_rng(0)
n = 4000
q = rng.binomial(1, 0.5, n).astype("float32")
x = rng.standard_normal((n, 3)).astype("float32")
y = (np.tanh(x[:, 0]) + 0.5 * x[:, 1] + 3 * q
     + 0.1 * rng.standard_normal(n)).astype("float32")
X = np.concatenate([x, q.reshape(-1, 1)], axis=1)

# Stage 1: pretrain + freeze probes.
# For brevity the flows are fitted on the targets y; Section 6.1 prescribes
# fitting them on an unconstrained baseline's predictions instead.
flow_y, _ = fit_and_freeze(y.reshape(-1, 1), num_blocks=4, epochs=40, seed=0)
flow_q, _ = fit_and_freeze(q.reshape(-1, 1), num_blocks=2, epochs=40, seed=0)
flow_yq, _ = fit_and_freeze_joint(y, q, num_blocks=4, epochs=40, seed=0)

# Stage 2: pick a loss and train
mlp = keras.Sequential([keras.layers.Dense(32, "relu"),
                        keras.layers.Dense(32, "relu"),
                        keras.layers.Dense(1)])
fair = FairModelWrapper(
    mlp, mu=2.0,
    fairness_loss=GaussianizedMutualInfoLoss(flow_z=flow_y, flow_q=flow_q),
    # or: GaussianizedXCovLoss(flow_z=flow_y, flow_q=flow_q),
    # or: GaussianizedTotalCorrelationLoss(joint_flow=flow_yq),
)
fair.compile(optimizer="adam", loss="mse")
fair.fit(X, y, q=q, epochs=40, batch_size=256)
```
References
- Pérez-Suay, A., Laparra, V., Mateo-García, G., Muñoz-Marí, J., Gómez-Chova, L., & Camps-Valls, G. (2017). Fair Kernel Learning. Joint European Conference on Machine Learning and Knowledge Discovery in Databases (ECML PKDD). 10.1007/978-3-319-71249-9_21
- Chen, S. S., & Gopinath, R. A. (2000). Gaussianization. Advances in Neural Information Processing Systems (NeurIPS).
- Laparra, V., Camps-Valls, G., & Malo, J. (2011). Iterative Gaussianization: From ICA to Random Rotations. IEEE Transactions on Neural Networks, 22(4), 537–549. 10.1109/TNN.2011.2106511
- Meng, C., Song, Y., Song, J., & Ermon, S. (2020). Gaussianization Flows. International Conference on Artificial Intelligence and Statistics (AISTATS).
- Cortes, C., Mohri, M., & Rostamizadeh, A. (2012). Algorithms for Learning Kernels Based on Centered Alignment. Journal of Machine Learning Research, 13, 795–828.
- Kornblith, S., Norouzi, M., Lee, H., & Hinton, G. (2019). Similarity of Neural Network Representations Revisited. International Conference on Machine Learning (ICML).
- Gretton, A., Bousquet, O., Smola, A., & Schölkopf, B. (2005). Measuring Statistical Dependence with Hilbert–Schmidt Norms. Algorithmic Learning Theory (ALT). 10.1007/11564089_7
- Gel’fand, I. M., & Yaglom, A. M. (1957). Computation of the Amount of Information about a Random Function Contained in Another Such Function. American Mathematical Society Translations, 12, 199–246.
- Watanabe, S. (1960). Information Theoretical Analysis of Multivariate Correlation. IBM Journal of Research and Development, 4(1), 66–82. 10.1147/rd.41.0066
- Belghazi, M. I., Baratin, A., Rajeswar, S., Ozair, S., Bengio, Y., Hjelm, R. D., & Courville, A. (2018). Mutual Information Neural Estimation. International Conference on Machine Learning (ICML).