Parameterized Marginal Gaussianization#

This is my notebook where I play around with all things normalizing flows with Pyro. I use the following packages:

  • PyTorch

  • Pyro

  • PyTorch Lightning

  • Wandb

#@title Install Packages
# %%capture

!pip install --upgrade --quiet pyro-ppl tqdm wandb corner loguru pytorch-lightning lightning-bolts torchtyping einops plum-dispatch pyyaml==5.4.1 nflows
!pip install --upgrade --quiet scipy
!git clone https://github.com/jejjohnson/survae_flows_lib.git
!pip install survae_flows_lib/. --use-feature=in-tree-build
#@title Import Packages

# TYPE HINTS
from typing import Tuple, Optional, Dict, Callable, Union
from pprint import pprint

# PyTorch
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

# Pyro Imports
import pyro.distributions as dist
import pyro.distributions.transforms as T

# PyTorch Lightning Imports
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pl_bolts.datamodules import SklearnDataModule

# wandb imports
import wandb
from tqdm.notebook import trange, tqdm
from pytorch_lightning.loggers import TensorBoardLogger


# Logging Settings
from loguru import logger
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
logger.info("Using device: {}".format(device))

# NUMPY SETTINGS
import numpy as np
np.set_printoptions(precision=3, suppress=True)

# MATPLOTLIB Settings
import corner
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# SEABORN SETTINGS
import seaborn as sns
sns.set_context(context='talk',font_scale=0.7)

# PANDAS SETTINGS
import pandas as pd
pd.set_option("display.max_rows", 120)
pd.set_option("display.max_columns", 120)

Helpful Functions#

Generate 2D Grid#

def generate_2d_grid(data: np.ndarray, n_grid: int = 1_000, buffer: float = 0.01) -> np.ndarray:

    xline = np.linspace(data[:, 0].min() - buffer, data[:, 0].max() + buffer, n_grid)
    yline = np.linspace(data[:, 1].min() - buffer, data[:, 1].max() + buffer, n_grid)
    xgrid, ygrid = np.meshgrid(xline, yline)
    xyinput = np.concatenate([xgrid.reshape(-1, 1), ygrid.reshape(-1, 1)], axis=1)
    return xyinput

Plot 2D Grid#

from matplotlib import cm

def plot_2d_grid(X_plot, X_grid, X_log_prob):

    # estimated density on the grid
    probs = np.exp(X_log_prob)

    cmap = cm.magma  # "Reds"

    fig, ax = plt.subplots(ncols=2, figsize=(12, 5))
    h = ax[0].hist2d(
        X_plot[:, 0], X_plot[:, 1], bins=512, cmap=cmap, density=True, vmin=0.0, vmax=1.0
    )
    ax[0].set_title("True Density")
    ax[0].set(
        xlim=[X_plot[:, 0].min(), X_plot[:, 0].max()],
        ylim=[X_plot[:, 1].min(), X_plot[:, 1].max()],
    )


    h1 = ax[1].scatter(
        X_grid[:, 0], X_grid[:, 1], s=1, c=probs, cmap=cmap, #vmin=0.0, vmax=1.0
    )
    ax[1].set(
        xlim=[X_grid[:, 0].min(), X_grid[:, 0].max()],
        ylim=[X_grid[:, 1].min(), X_grid[:, 1].max()],
    )
    # plt.colorbar(h1)
    ax[1].set_title("Estimated Density")


    plt.tight_layout()
    plt.show()
    return fig, ax

Torch 2 Numpy#

def torch_2_numpy(X):

    # convert a torch.Tensor (possibly on GPU and/or requiring grad) to a numpy array
    if isinstance(X, torch.Tensor):
        X = X.detach().cpu().numpy()

    return X

2D Toy Data#

def get_toy_data(n_samples=1000, seed=123):
    rng = np.random.RandomState(seed=seed)

    x = np.abs(2 * rng.randn(n_samples, 1))
    y = np.sin(x) + 0.25 * rng.randn(n_samples, 1)
    data = np.hstack((x, y))

    return data

X = get_toy_data(5_000, 123)


# get marginal data
X_1 = X[:, 0][:, None]
X_2 = X[:, 1][:, None]


fig = corner.corner(X, color="blue")
fig.suptitle("Sine Wave")
plt.show()

Model#

Parameterized Marginal Gaussianization#

Univariate GMM#

\[ p(\mathbf{x};\boldsymbol{\theta}) = \sum_k^K \pi_k \mathcal{N}(x| \mu_k, \sigma_k) \]

where:

  • \(K\) - # of components

  • \(\pi_k\) - the mixture weight of the \(k\)-th component, with \(\sum_k \pi_k = 1\)

  • \(\boldsymbol{\theta} = \{ \boldsymbol{\mu}_K, \boldsymbol{\sigma}_K, \boldsymbol{\pi}_K \}\)

We will use the scikit-learn implementation to learn the parameters for the Gaussian mixture model. They use the Expectation-Maximization (EM) scheme to solve for the parameters.

from sklearn.mixture import GaussianMixture, BayesianGaussianMixture
n_components = 4
random_state = 123
covariance_type = "diag"

# init gmm model
mg_bijection = GaussianMixture(
    n_components=n_components, 
    random_state=random_state, 
    covariance_type=covariance_type
)
mg_bijection.fit(X_1)

Viz - PDF#

x_domain = np.linspace(X_1.min(), X_1.max(), 100)
probs = mg_bijection.score_samples(x_domain[:, None])
fig, ax = plt.subplots()
ax.plot(x_domain, np.exp(probs), label="Estimated Density")
ax.hist(X_1.ravel(), density=True, bins=50, color="Red", label="Samples")
plt.legend()
plt.show()

CDF - From Scratch#

\[ F(x) = \sum_k^K \pi_k \Phi(x| \mu_k, \sigma_k) \]

where \(\Phi\) is the CDF of the Gaussian distribution.

Trick - Stabilization#

We often like to put everything in terms of logs. Working in log space keeps the computation numerically stable and keeps the gradient updates well behaved regardless of the scale of the function. With the log-sum-exp trick, the log of the mixture CDF is

\[ \log U(x) = \log \sum_k^K \pi_k \Phi(x|\mu_k, \sigma_k) = \underset{k}{\mathrm{logsumexp}} \left[ \log \pi_k + \log \Phi(x|\mu_k, \sigma_k) \right] \]
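As a quick numerical sanity check of the log-sum-exp identity above (the numbers here are arbitrary, purely for illustration):

# check that log(sum_k pi_k * c_k) == logsumexp(log pi_k + log c_k)
from scipy.special import logsumexp

pi_k = np.array([0.2, 0.3, 0.5])    # mixture weights
c_k = np.array([0.10, 0.45, 0.80])  # per-component CDF values at some x

direct = np.log(np.sum(pi_k * c_k))
stable = logsumexp(np.log(pi_k) + np.log(c_k))

assert np.allclose(direct, stable)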
# extract parameters (note: despite the name, `logit_weights` holds the mixture weights themselves, not logits)
logit_weights = mg_bijection.weights_[None, :]
means = mg_bijection.means_.T
sigmas = np.sqrt(mg_bijection.covariances_).T

# convert to tensors

# assert shapes
assert_shape = (1, n_components)

assert logit_weights.shape == assert_shape
assert means.shape == assert_shape
assert sigmas.shape == assert_shape
  • pi = (D, K)

  • mu = (D, K)

  • sigma = (D, K)

where:

  • K - number of mixture components

  • D - dimensionality of the data

# convert to tensors
logit_weights = torch.Tensor(logit_weights)
means = torch.Tensor(means)
sigmas = torch.Tensor(sigmas)


# note: this rebinds `dist` (previously pyro.distributions) to torch.distributions
import torch.distributions as dist
# create base dist
mixture_dist = dist.Normal(means, sigmas)

# Calculate the CDF
x_cdfs = mixture_dist.cdf(torch.Tensor(x_domain).unsqueeze(-1))

# calculate mixture cdf
z_cdfs = logit_weights * x_cdfs

# sum mixture distributions
z_cdf = z_cdfs.sum(axis=-1)
fig, ax = plt.subplots()
ax.plot(torch_2_numpy(x_domain), torch_2_numpy(z_cdf), label="CDF")
plt.legend()
plt.show()
# Calculate the CDF
x_cdfs = mixture_dist.cdf(torch.Tensor(X_1))

# calculate mixture cdf
z_cdfs = logit_weights * x_cdfs

# sum mixture distributions
z_cdf = z_cdfs.sum(axis=-1)
fig, ax = plt.subplots()
ax.hist(torch_2_numpy(z_cdf), density=True, bins=50, color="Red", label="Uniform Domain")
plt.legend()
plt.show()

PDF - From Scratch#

We are going to use the same function as before. The only difference is that we will use the PDF of the Gaussian instead of its CDF.

\[ \log \nabla_x U(x) = \log \sum_k^K \pi_k \mathcal{N}(x|\mu_k, \sigma_k) = \underset{k}{\mathrm{logsumexp}} \left[ \log \pi_k + \log \mathcal{N}(x|\mu_k, \sigma_k) \right] \]
# Calculate the per-component log-PDF
x_pdfs = mixture_dist.log_prob(torch.Tensor(x_domain[:, None]))

# log softmax of weights
# log_weights = log_softmax(logit_weights, axis=-1)
log_weights = torch.log(logit_weights)

# per-component log weight + log-PDF
z_logpdfs = log_weights + x_pdfs

# log-sum-exp over components gives the mixture log-PDF
z_logpdf = torch.logsumexp(z_logpdfs, axis=-1)
fig, ax = plt.subplots()
ax.plot(x_domain, torch_2_numpy(z_logpdf.exp()), label="PDF (Ours)", linewidth=4)
ax.plot(x_domain, torch_2_numpy(np.exp(probs)), label="PDF (GMM)", linestyle="dotted", linewidth=4)
ax.hist(X_1.ravel(), density=True, bins=50, color="Red", label="Samples")
plt.legend()
plt.show()

Inverse CDF#

Don’t do this#

Taking a weighted sum of the per-component inverse CDFs is not the inverse of the mixture CDF (the quantile function of a Gaussian mixture has no closed form), so the transform below is only a crude approximation.
z_domain = torch.linspace(0.01, 0.99, 100)
# Calculate the inverse CDF of each component
x_icdfs = mixture_dist.icdf(z_domain.unsqueeze(-1))

# weighted sum of the component inverse CDFs (naive)
xs = logit_weights * x_icdfs

# sum over components (note: this overwrites the earlier x_domain grid)
x_domain = xs.sum(axis=-1)
fig, ax = plt.subplots()

ax.plot(z_domain, x_domain, label="Inverse CDF")
plt.legend()
plt.show()
# Calculate the inverse CDF of each component at the transformed data
x_icdfs = mixture_dist.icdf(z_cdf.unsqueeze(-1))

# weighted sum of the component inverse CDFs (naive)
xs = logit_weights * x_icdfs

# sum mixture distributions
x_approx = xs.sum(axis=-1)
fig, ax = plt.subplots()
ax.hist(torch_2_numpy(X_1), density=True, bins=50, color="Blue", label="Original Data")
ax.hist(torch_2_numpy(x_approx), density=True, bins=50, color="Red", label="Inverse Transform")
plt.legend()
plt.show()

Bijections#

from survae.transforms.bijections.functional.mixtures.gaussian_mixture import gaussian_mixture_transform
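The figure below compares the naive inverse against a proper numerical inverse of the mixture CDF. Since the exact call signature of the library's `gaussian_mixture_transform` is not shown here, the following is a minimal bisection sketch (an assumption, not the library's API) that defines the `x_approx` and `x_approx_` used in the plot.

# minimal sketch (assumption, not the library API): invert the monotone mixture CDF with bisection

def mixture_cdf(x):
    # forward mixture CDF: sum_k pi_k * Phi(x | mu_k, sigma_k)
    return (logit_weights * mixture_dist.cdf(x.unsqueeze(-1))).sum(dim=-1)

def bisection_inverse(z, lower, upper, n_iters=60):
    # elementwise bisection: find x such that mixture_cdf(x) = z
    lo = torch.full_like(z, lower)
    hi = torch.full_like(z, upper)
    for _ in range(n_iters):
        mid = 0.5 * (lo + hi)
        go_right = mixture_cdf(mid) < z  # the root lies above mid
        lo = torch.where(go_right, mid, lo)
        hi = torch.where(go_right, hi, mid)
    return 0.5 * (lo + hi)

# naive inverse: weighted sum of the component quantiles over z_domain
x_approx = (logit_weights * mixture_dist.icdf(z_domain.unsqueeze(-1))).sum(dim=-1)

# bisection-based inverse of the true mixture CDF over z_domain
x_approx_ = bisection_inverse(z_domain, float(X_1.min()) - 10.0, float(X_1.max()) + 10.0)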

fig, ax = plt.subplots()

ax.plot(z_domain, x_approx, label="Inverse CDF")
ax.plot(z_domain, x_approx_, label="Inverse CDF (Bisection)")
plt.legend()
plt.show()

fig, ax = plt.subplots()
ax.plot(torch_2_numpy(x_approx_), torch_2_numpy(z_domain), label="CDF")
plt.legend()
plt.show()

SurVAE Flows Function#

Now, we will translate the code into PyTorch. I will create a functional form so that we just need to call it within our Bijector class.

# from survae.transforms.bijections.functional.mixtures import gaussian_mixture_transform
from torch.distributions import Normal

def gaussian_mixture_transform(inputs, logit_weights, means, log_scales):

    # per-component Gaussians (renamed to avoid shadowing the `dist` module)
    mixture_dist = Normal(means, log_scales.exp())

    def mix_cdf(x):
        # forward transform: mixture CDF, z = sum_k pi_k * Phi(x | mu_k, sigma_k)
        return torch.sum(logit_weights * mixture_dist.cdf(x.unsqueeze(-1)), dim=-1)

    def mix_log_pdf(x):
        # log-determinant Jacobian: log of the mixture PDF via log-sum-exp
        return torch.logsumexp(logit_weights.log() + mixture_dist.log_prob(x.unsqueeze(-1)), dim=-1)

    z = mix_cdf(inputs)
    ldj = mix_log_pdf(inputs)

    return z, ldj

Initialization#

I will also use an initialization function which allows us to initialize the parameters from a GMM fit to each marginal.
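As a rough mental model, a hypothetical sketch of such an initializer (not the library's actual code, which you can inspect with `??` below) fits a univariate scikit-learn GaussianMixture to each marginal and stacks the parameters into `(D, K)` arrays:

# hypothetical sketch of a per-marginal GMM initializer (not the library's code)
from sklearn.mixture import GaussianMixture

def init_mixture_weights_sketch(X, n_components, random_state=123):
    weights, means, sigmas = [], [], []
    for d in range(X.shape[1]):
        # fit a univariate GMM to the d-th marginal
        gmm = GaussianMixture(
            n_components=n_components,
            covariance_type="diag",
            random_state=random_state,
        ).fit(X[:, d][:, None])
        weights.append(gmm.weights_)                      # (K,)
        means.append(gmm.means_.ravel())                  # (K,)
        sigmas.append(np.sqrt(gmm.covariances_).ravel())  # (K,)
    # stack to (D, K): one row of mixture parameters per dimension
    return np.stack(weights), np.stack(means), np.stack(sigmas)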

from survae.transforms.bijections.functional.mixtures.gaussian_mixture import init_marginal_mixture_weights
init_marginal_mixture_weights??
num_mixtures = 4

logit_weights, means, sigmas = init_marginal_mixture_weights(X, num_mixtures)
X_mu, ldj = gaussian_mixture_transform(
    torch.Tensor(X), 
    torch.Tensor(logit_weights), 
    torch.Tensor(means), 
    torch.Tensor(np.log(sigmas))
)
fig = corner.corner(X_mu.numpy())

PyTorch Class#

from survae.transforms.bijections.elementwise_nonlinear import GaussianMixtureCDF
from survae.transforms.bijections.elementwise_nonlinear import InverseGaussCDF

GaussianMixtureCDF??
InverseGaussCDF??

Flow Model#

So to have a flow model, we need two components:


Base Distribution: \(p_Z = \mathbb{P}_Z\)

This describes the distribution we want in the transformed domain. The Gaussian mixture CDF uniformizes each marginal, and the inverse Gaussian CDF then maps those uniform values to a Gaussian, so the base distribution here is a standard Normal.


Bijections: \(f = f_L \circ f_{L-1} \circ \ldots \circ f_1\)

The list of bijections: the functions we compose in order to map our data to the base distribution (and back again via the inverse).

from survae.distributions import StandardUniform, StandardNormal
from survae.flows import Flow


# base distribution
base_dist = StandardNormal((2,))

# transforms
transforms = [
              GaussianMixtureCDF((2,), None, 4),
              InverseGaussCDF()
]

# flow model
model = Flow(
    base_dist=base_dist,
    transforms=transforms
)

Forward Transformation#

with torch.no_grad():

    X_mu, ldj = model.forward_transform(torch.Tensor(X))

fig = corner.corner(torch_2_numpy(X_mu))

Inverse Transformation#

with torch.no_grad():

    X_approx = model.inverse_transform(X_mu)

fig = corner.corner(torch_2_numpy(X_approx))

Training#

Dataset#

# # Data
X_train = get_toy_data(5_000, 123)

train_loader = DataLoader(torch.Tensor(X_train), batch_size=128, shuffle=True)

Loss#

def nll_loss(model, data):
    return - model.log_prob(data).mean()
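For a flow model like the one above, `log_prob` is the standard change-of-variables objective: the base log-density evaluated at the transformed point plus the accumulated log-determinant Jacobian of the bijections, so the loss is the negative of

\[ \log p_X(\mathbf{x}) = \log p_Z\big(f(\mathbf{x})\big) + \sum_{\ell=1}^L \log \left| \det \frac{\partial f_\ell}{\partial \mathbf{x}_{\ell-1}} \right| \]

averaged over the batch.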

Pytorch-Lightning Trainer#

import pytorch_lightning as pl

class Learner2DPlane(pl.LightningModule):
    def __init__(self, model:nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        
        # loss function
        # loss = -self.model.log_prob(batch).mean()
        loss = nll_loss(self.model, batch)
        
        self.log("train_loss", loss)
        
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=1e-3)

    def train_dataloader(self):
        return train_loader
# initialize the learner
learn = Learner2DPlane(model)
n_epochs = 50
# note: this rebinds `logger` (previously the loguru logger) to a TensorBoard logger
logger = TensorBoardLogger("tb_logs", name='mg_no_init')

# initialize trainer
trainer = pl.Trainer(min_epochs=1, max_epochs=n_epochs, gpus=1, enable_progress_bar=True, logger=logger)

Logging#

# %load_ext tensorboard
# %tensorboard --logdir tb_logs/

Training#

# train model
trainer.fit(learn)

Results#

Latent Domain#

with torch.no_grad():
    X_ = torch.Tensor(X)
    # X_ = X_.to(device)
    X_r, ldj = learn.model.forward_transform(X_)


fig = corner.corner(torch_2_numpy(X_r))

Inverse#

with torch.no_grad():
    # X_ = X_.to(device)
    X_approx = learn.model.inverse_transform(X_r)

fig = corner.corner(torch_2_numpy(X_approx))