Parameterized Marginal Gaussianization#
This is my notebook where I play around with all things normalizing flows with Pyro. I use the following packages:
PyTorch
Pyro
PyTorch Lightning
Wandb
#@title Install Packages
# %%capture
!pip install --upgrade --quiet pyro-ppl tqdm wandb corner loguru pytorch-lightning lightning-bolts torchtyping einops plum-dispatch pyyaml==5.4.1 nflows
!pip install --upgrade --quiet scipy
!git clone https://github.com/jejjohnson/survae_flows_lib.git
!pip install survae_flows_lib/. --use-feature=in-tree-build
#@title Import Packages
# TYPE HINTS
from typing import Tuple, Optional, Dict, Callable, Union
from pprint import pprint
# PyTorch
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
# Pyro Imports
import pyro.distributions as dist
import pyro.distributions.transforms as T
# PyTorch Lightning Imports
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pl_bolts.datamodules import SklearnDataModule
# wandb imports
import wandb
from tqdm.notebook import trange, tqdm
from pytorch_lightning.loggers import TensorBoardLogger
# Logging Settings
from loguru import logger
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
logger.info("Using device: {}".format(device))
# NUMPY SETTINGS
import numpy as np
np.set_printoptions(precision=3, suppress=True)
# MATPLOTLIB Settings
import corner
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# SEABORN SETTINGS
import seaborn as sns
sns.set_context(context='talk',font_scale=0.7)
# PANDAS SETTINGS
import pandas as pd
pd.set_option("display.max_rows", 120)
pd.set_option("display.max_columns", 120)
Helpful Functions#
Generate 2D Grid#
def generate_2d_grid(data: np.ndarray, n_grid: int = 1_000, buffer: float = 0.01) -> np.ndarray:
xline = np.linspace(data[:, 0].min() - buffer, data[:, 0].max() + buffer, n_grid)
yline = np.linspace(data[:, 1].min() - buffer, data[:, 1].max() + buffer, n_grid)
xgrid, ygrid = np.meshgrid(xline, yline)
xyinput = np.concatenate([xgrid.reshape(-1, 1), ygrid.reshape(-1, 1)], axis=1)
return xyinput
Plot 2D Grid#
from matplotlib import cm
def plot_2d_grid(X_plot, X_grid, X_log_prob):
# Estimated Density
cmap = cm.magma # "Reds"
probs = np.exp(X_log_prob)
# probs = np.clip(probs, 0.0, 1.0)
# probs = np.clip(probs, None, 0.0)
fig, ax = plt.subplots(ncols=2, figsize=(12, 5))
h = ax[0].hist2d(
X_plot[:, 0], X_plot[:, 1], bins=512, cmap=cmap, density=True, vmin=0.0, vmax=1.0
)
ax[0].set_title("True Density")
ax[0].set(
xlim=[X_plot[:, 0].min(), X_plot[:, 0].max()],
ylim=[X_plot[:, 1].min(), X_plot[:, 1].max()],
)
h1 = ax[1].scatter(
X_grid[:, 0], X_grid[:, 1], s=1, c=probs, cmap=cmap, #vmin=0.0, vmax=1.0
)
ax[1].set(
xlim=[X_grid[:, 0].min(), X_grid[:, 0].max()],
ylim=[X_grid[:, 1].min(), X_grid[:, 1].max()],
)
# plt.colorbar(h1)
ax[1].set_title("Estimated Density")
plt.tight_layout()
plt.show()
return fig, ax
Torch 2 Numpy#
def torch_2_numpy(X):
    """Convert a torch Tensor (possibly on the GPU or requiring grad) to a numpy array."""
    if not isinstance(X, np.ndarray):
        X = X.detach().cpu().numpy()
    return X
2D Toy Data#
def get_toy_data(n_samples=1000, seed=123):
rng = np.random.RandomState(seed=seed)
x = np.abs(2 * rng.randn(n_samples, 1))
y = np.sin(x) + 0.25 * rng.randn(n_samples, 1)
data = np.hstack((x, y))
return data
X = get_toy_data(5_000, 123)
# get marginal data
X_1 = X[:, 0][:, None]
X_2 = X[:, 1][:, None]
# # Data
# ds = CheckerboardDataset
# test = Dataset(get_toy_data(2_000, 100)[0])
# train_loader = DataLoader(train, batch_size=64, shuffle=False)
# test_loader = DataLoader(test, batch_size=256, shuffle=True)
fig = corner.corner(X, color="blue")
fig.suptitle("Sine Wave")
plt.show()
Model#
Parameterized Marginal Gaussianization#
Univariate GMM#

$$
p_{\boldsymbol{\theta}}(x) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x; \mu_k, \sigma_k^2)
$$

where:
\(K\) - # of components
\(\pi_k\) - the mixture weight for the \(k\)-th component
\(\boldsymbol{\theta} = \{ \boldsymbol{\mu}_K, \boldsymbol{\sigma}_K, \boldsymbol{\pi}_K \}\)
We will use the scikit-learn implementation to learn the parameters of the Gaussian mixture model. It uses the Expectation-Maximization (EM) algorithm to fit the parameters.
from sklearn.mixture import GaussianMixture, BayesianGaussianMixture
n_components = 4
random_state = 123
covariance_type = "diag"
# init gmm model
mg_bijection = GaussianMixture(
n_components=n_components,
random_state=random_state,
covariance_type=covariance_type
)
mg_bijection.fit(X_1)
Viz - PDF#
x_domain = np.linspace(X_1.min(), X_1.max(), 100)
probs = mg_bijection.score_samples(x_domain[:, None])
fig, ax = plt.subplots()
ax.plot(x_domain, np.exp(probs), label="Estimated Density")
ax.hist(X_1.ravel(), density=True, bins=50, color="Red", label="Samples")
plt.legend()
plt.show()
CDF - From Scratch#

$$
F_{\boldsymbol{\theta}}(x) = \sum_{k=1}^{K} \pi_k \, \Phi\!\left(\frac{x - \mu_k}{\sigma_k}\right)
$$

where \(\Phi\) is the CDF of the standard Gaussian distribution.
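Because the derivative of the CDF is the PDF, the log-determinant of the Jacobian of this transform is simply the mixture log-density; this is exactly what the PDF section below computes and what the flow later uses as its ldj term:

$$
\frac{dF_{\boldsymbol{\theta}}}{dx}(x) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x; \mu_k, \sigma_k^2)
\quad\Longrightarrow\quad
\log\left|\frac{dz}{dx}\right| = \log p_{\boldsymbol{\theta}}(x)
$$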
Trick: Stabilization#
We often like to put everything in terms of logs: the weights become logits and the scales become log-scales. This reparameterization keeps the computation numerically stable and keeps the gradient updates well behaved regardless of the scale of the parameters.
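As a concrete sketch of the idea (the function and parameter names below are my own, purely for illustration), the mixture log-density can be computed entirely in log space with a logsumexp:
import torch
from torch.distributions import Normal

def mixture_log_prob_stable(x, logits, means, log_scales):
    """Univariate Gaussian mixture log-density, computed entirely in log space."""
    log_weights = torch.log_softmax(logits, dim=-1)          # log pi_k, normalized by construction
    components = Normal(means, log_scales.exp())             # exp keeps the scales positive
    log_probs = components.log_prob(x.unsqueeze(-1))         # (..., K) component log-densities
    return torch.logsumexp(log_weights + log_probs, dim=-1)  # numerically stable mixture log-density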
# extract parameters
logit_weights = mg_bijection.weights_[None, :]
means = mg_bijection.means_.T
sigmas = np.sqrt(mg_bijection.covariances_).T
# assert shapes
assert_shape = (1, n_components)
assert logit_weights.shape == assert_shape
assert means.shape == assert_shape
assert sigmas.shape == assert_shape
pi = (D, K)
mu = (D, K)
sigma = (D, K)
where:
K - number of mixture components
D - dimensionality of the data
# convert to tensors
logit_weights = torch.Tensor(logit_weights)
means = torch.Tensor(means)
sigmas = torch.Tensor(sigmas)
import torch.nn.functional as F
import torch.distributions as dist
from scipy import stats
from scipy.special import log_softmax, logsumexp
from einops import repeat
# create the mixture component distributions
mixture_dist = dist.Normal(means, sigmas)
# Calculate the CDF
x_cdfs = mixture_dist.cdf(torch.Tensor(x_domain).unsqueeze(-1))
# calculate mixture cdf
z_cdfs = logit_weights * x_cdfs
# sum mixture distributions
z_cdf = z_cdfs.sum(axis=-1)
fig, ax = plt.subplots()
ax.plot(torch_2_numpy(x_domain), torch_2_numpy(z_cdf), label="CDF")
plt.legend()
plt.show()
# Calculate the CDF
x_cdfs = mixture_dist.cdf(torch.Tensor(X_1))
# calculate mixture cdf
z_cdfs = logit_weights * x_cdfs
# sum mixture distributions
z_cdf = z_cdfs.sum(axis=-1)
fig, ax = plt.subplots()
ax.hist(torch_2_numpy(z_cdf), density=True, bins=50, color="Red", label="Uniform Domain")
plt.legend()
plt.show()
PDF - From Scratch#
We are going to use the same function as before. The only difference is that we will use the PDF instead of the CDF of a Gaussian.
# Calculate the PDF
x_pdfs = mixture_dist.log_prob(torch.Tensor(x_domain[:, None]))
# log softmax of weights
# log_weights = log_softmax(logit_weights, axis=-1)
log_weights = torch.log(logit_weights)
# calculate mixture cdf
z_logpdfs = log_weights + x_pdfs
# sum mixture distributions
z_logpdf = torch.logsumexp(z_logpdfs, axis=-1)
fig, ax = plt.subplots()
ax.plot(x_domain, torch_2_numpy(z_logpdf.exp()), label="PDF (Ours)", linewidth=4)
ax.plot(x_domain, torch_2_numpy(np.exp(probs)), label="PDF (GMM)", linestyle="dotted", linewidth=4)
ax.hist(X_1.ravel(), density=True, bins=50, color="Red", label="Samples")
plt.legend()
plt.show()
Inverse CDF#
Don’t do this#
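Note that the inverse CDF (quantile function) of a mixture is not the weighted sum of the component inverse CDFs:

$$
F_{\boldsymbol{\theta}}^{-1}(z) \;\neq\; \sum_{k=1}^{K} \pi_k \left( \mu_k + \sigma_k \, \Phi^{-1}(z) \right)
$$

So the transform below is only a crude approximation, hence the warning in the title; a proper inverse needs an iterative scheme such as the bisection search in the next section.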
z_domain = torch.linspace(0.01, 0.99, 100)
# Calculate the CDF
x_icdfs = mixture_dist.icdf(z_domain.unsqueeze(-1))
# calculate mixture cdf
xs = logit_weights * x_icdfs
# sum mixture distributions
x_domain = xs.sum(axis=-1)
fig, ax = plt.subplots()
ax.plot(z_domain, x_domain, label="Inverse CDF")
plt.legend()
plt.show()
# Calculate the CDF
x_icdfs = mixture_dist.icdf(z_cdf.unsqueeze(-1))
# calculate mixture cdf
xs = logit_weights * x_icdfs
# sum mixture distributions
x_approx = xs.sum(axis=-1)
fig, ax = plt.subplots()
ax.hist(torch_2_numpy(X_1), density=True, bins=50, color="Blue", label="Original Data")
ax.hist(torch_2_numpy(x_approx), density=True, bins=50, color="Red", label="Inverse Transform")
plt.legend()
plt.show()
Bisection Search#
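The mixture CDF is strictly increasing, so we can invert it numerically: given a target \(z\), we repeatedly halve a bracketing interval until \(F_{\boldsymbol{\theta}}(x) \approx z\). Below is a minimal sketch of the idea (my own illustrative helper, not the survae bisection_inverse used in the next cell, which additionally handles batching and tolerances):
import torch

def bisection_invert_sketch(fn, z, lower, upper, max_iters=100, eps=1e-10):
    # invert a monotonically increasing elementwise function by bisection (illustrative only)
    lower = lower.expand_as(z).clone()
    upper = upper.expand_as(z).clone()
    for _ in range(max_iters):
        mid = 0.5 * (lower + upper)
        too_low = fn(mid) < z                     # fn(mid) below the target -> root lies above mid
        lower = torch.where(too_low, mid, lower)
        upper = torch.where(too_low, upper, mid)
        if (upper - lower).max() < eps:
            break
    return 0.5 * (lower + upper)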
from survae.transforms.bijections.functional.iterative_inversion import bisection_inverse
def mix_cdf(x, logit_weights, means, sigmas):
    # mixture component distributions built from the passed-in parameters
    mix_dist = dist.Normal(means, sigmas)
    # component CDFs
    x_cdfs = mix_dist.cdf(x.unsqueeze(-1))
    # calculate mixture cdf
    z_cdfs = logit_weights * x_cdfs
    # sum mixture distributions
    z_cdf = z_cdfs.sum(axis=-1)
    return z_cdf
# initialize the parameters
max_scales = torch.sum(sigmas, dim=-1, keepdim=True)
init_lower, _ = (means - 20 * max_scales).min(dim=-1)
init_upper, _ = (means + 20 * max_scales).max(dim=-1)
x_approx = bisection_inverse(
fn=lambda x: mix_cdf(x, logit_weights, means, sigmas),
z=z_cdf,
init_x=torch.zeros_like(z_cdf),
init_lower=init_lower,
init_upper=init_upper,
eps=1e-10,
max_iters=100,
)
# library implementation of the same transform (also imported below in the Bijections section)
from survae.transforms.bijections.functional.mixtures.gaussian_mixture import gaussian_mixture_transform

x_approx_ = gaussian_mixture_transform(z_cdf.unsqueeze(-1), logit_weights, means, sigmas.log(), inverse=True)
fig, ax = plt.subplots()
ax.hist(torch_2_numpy(X_1), density=True, bins=50, color="Blue", label="Original Data")
ax.hist(torch_2_numpy(x_approx), density=True, bins=50, color="Red", label="Inverse Transform")
plt.legend()
plt.show()
Bijections#
from survae.transforms.bijections.functional.mixtures.gaussian_mixture import gaussian_mixture_transform
fig, ax = plt.subplots()
ax.scatter(torch_2_numpy(z_cdf), torch_2_numpy(x_approx), s=1, label="Inverse CDF (bisection)")
ax.scatter(torch_2_numpy(z_cdf), torch_2_numpy(x_approx_), s=1, label="Inverse CDF (survae)")
plt.legend()
plt.show()
fig, ax = plt.subplots()
ax.scatter(torch_2_numpy(X_1), torch_2_numpy(z_cdf), s=1, label="CDF")
plt.legend()
plt.show()
SurVAE Flows Function#
Now we will gather the code above into a single functional form so that we only need to call it within our bijector class.
# from survae.transforms.bijections.functional.mixtures import gaussian_mixture_transform
from torch.distributions import Normal
def gaussian_mixture_transform(inputs, logit_weights, means, log_scales):
    # mixture component distributions (renamed to avoid shadowing the `dist` module)
    mixture_dist = Normal(means, log_scales.exp())
    def mix_cdf(x):
        # forward transform: mixture CDF evaluated elementwise
        return torch.sum(logit_weights * mixture_dist.cdf(x.unsqueeze(-1)), dim=-1)
    def mix_log_pdf(x):
        # log-det-Jacobian: mixture log-density
        return torch.logsumexp(logit_weights.log() + mixture_dist.log_prob(x.unsqueeze(-1)), dim=-1)
    z = mix_cdf(inputs)
    ldj = mix_log_pdf(inputs)
    return z, ldj
Initialization#
I will also create an initialization function which will allow us to initialize the parameters from the GMM model.
from survae.transforms.bijections.functional.mixtures.gaussian_mixture import init_marginal_mixture_weights
init_marginal_mixture_weights??
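For reference, here is a rough sketch of what such an initialization could do under the hood: fit an independent scikit-learn GaussianMixture to each marginal and return the stacked weights, means, and scales. The helper below is my own illustration and may differ from the actual survae implementation:
import numpy as np
from sklearn.mixture import GaussianMixture

def init_mixture_params_sketch(X, n_components, seed=123):
    # fit a univariate GMM per dimension; returns arrays of shape (D, K) (illustrative only)
    weights, means, sigmas = [], [], []
    for d in range(X.shape[1]):
        gmm = GaussianMixture(
            n_components=n_components, covariance_type="diag", random_state=seed
        )
        gmm.fit(X[:, d][:, None])
        weights.append(gmm.weights_)                      # (K,) mixture weights
        means.append(gmm.means_.ravel())                  # (K,) component means
        sigmas.append(np.sqrt(gmm.covariances_).ravel())  # (K,) component standard deviations
    return np.stack(weights), np.stack(means), np.stack(sigmas)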
num_mixtures = 4
logit_weights, means, sigmas = init_marginal_mixture_weights(X, num_mixtures)
X_mu, ldj = gaussian_mixture_transform(
torch.Tensor(X),
torch.Tensor(logit_weights),
torch.Tensor(means),
torch.Tensor(np.log(sigmas))
)
fig = corner.corner(X_mu.numpy())
PyTorch Class#
from survae.transforms.bijections.elementwise_nonlinear import GaussianMixtureCDF
from survae.transforms.bijections.elementwise_nonlinear import InverseGaussCDF
GaussianMixtureCDF??
InverseGaussCDF??
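To make the idea concrete, here is a rough sketch of what an elementwise Gaussian-mixture-CDF bijection could look like as a PyTorch module, assuming a bijection only needs to return \((z, \text{ldj})\) from its forward pass. This is not the survae implementation, just an illustration built from the functional transform above:
import torch
from torch import nn
from torch.distributions import Normal

class GaussianMixtureCDFSketch(nn.Module):
    # elementwise Gaussian mixture CDF with trainable parameters (illustrative only)

    def __init__(self, num_features, num_mixtures=4):
        super().__init__()
        self.logit_weights = nn.Parameter(torch.zeros(num_features, num_mixtures))
        self.means = nn.Parameter(torch.randn(num_features, num_mixtures))
        self.log_scales = nn.Parameter(torch.zeros(num_features, num_mixtures))

    def forward(self, x):
        weights = torch.softmax(self.logit_weights, dim=-1)
        components = Normal(self.means, self.log_scales.exp())
        # z = mixture CDF(x), applied elementwise per feature
        z = torch.sum(weights * components.cdf(x.unsqueeze(-1)), dim=-1)
        # log|dz/dx| = mixture log-pdf(x), summed over features
        ldj = torch.logsumexp(
            torch.log_softmax(self.logit_weights, dim=-1) + components.log_prob(x.unsqueeze(-1)),
            dim=-1,
        ).sum(dim=-1)
        return z, ldj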
Flow Model#
So to have a flow model, we need two components:
Base Distribution: \(p_Z = \mathbb{P}_Z\)
This will describe the distribution we want in the transform domain. In this case we will choose a standard Gaussian: the Gaussian mixture CDF uniformizes each marginal, and the inverse Gaussian CDF then maps those uniform values to a Gaussian.
Bijections: \(f = f_L \circ f_{L-1} \circ \ldots \circ f_1\)
The list of bijections. These are the functions we compose together to map our data to the base distribution.
from survae.distributions import StandardUniform, StandardNormal
from survae.flows import Flow
# base distribution
base_dist = StandardNormal((2,))
# transforms
transforms = [
GaussianMixtureCDF((2,), None, 4),
InverseGaussCDF()
]
# flow model
model = Flow(
base_dist=base_dist,
transforms=transforms
)
Forward Transformation#
with torch.no_grad():
X_mu, ldj = model.forward_transform(torch.Tensor(X))
fig = corner.corner(torch_2_numpy(X_mu))
Inverse Transformation#
with torch.no_grad():
    X_approx = model.inverse_transform(X_mu)
fig = corner.corner(torch_2_numpy(X_approx))
Training#
Dataset#
# # Data
X_train = get_toy_data(5_000, 123)
train_loader = DataLoader(torch.Tensor(X_train), batch_size=128, shuffle=True)
Loss#
def nll_loss(model, data):
return - model.log_prob(data).mean()
Pytorch-Lightning Trainer#
import pytorch_lightning as pl
class Learner2DPlane(pl.LightningModule):
def __init__(self, model:nn.Module):
super().__init__()
self.model = model
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
# loss function
# loss = -self.model.log_prob(batch).mean()
loss = nll_loss(self.model, batch)
self.log("train_loss", loss)
return {'loss': loss}
def configure_optimizers(self):
return torch.optim.Adam(self.model.parameters(), lr=1e-3)
def train_dataloader(self):
return train_loader
# initialize the learner
learn = Learner2DPlane(model)
n_epochs = 50
# use a separate name so we don't shadow the loguru logger
tb_logger = TensorBoardLogger("tb_logs", name='mg_no_init')
# initialize trainer
trainer = pl.Trainer(min_epochs=1, max_epochs=n_epochs, gpus=1 if torch.cuda.is_available() else 0, enable_progress_bar=True, logger=tb_logger)
Logging#
# %load_ext tensorboard
# %tensorboard --logdir tb_logs/
Training#
# train model
trainer.fit(learn)
Results#
Latent Domain#
with torch.no_grad():
X_ = torch.Tensor(X)
# X_ = X_.to(device)
X_r, ldj = learn.model.forward_transform(X_)
fig = corner.corner(torch_2_numpy(X_r))
Inverse#
with torch.no_grad():
# X_ = X_.to(device)
X_approx = learn.model.inverse_transform(X_r)
fig = corner.corner(torch_2_numpy(X_approx))