Example - 2D Plane#
This is my notebook where I play around with all things normalizing flows with Pyro. I use the following packages:
PyTorch
Pyro
PyTorch Lightning
Wandb
#@title Install Packages
# %%capture
!pip install --upgrade --quiet pyro-ppl tqdm wandb corner loguru pytorch-lightning lightning-bolts torchtyping einops plum-dispatch pyyaml==5.4.1 nflows
!pip install --upgrade --quiet scipy
!git clone https://github.com/jejjohnson/survae_flows_lib.git
!pip install survae_flows_lib/. --use-feature=in-tree-build
#@title Import Packages
# TYPE HINTS
from typing import Tuple, Optional, Dict, Callable, Union
from pprint import pprint
# PyTorch
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from einops import rearrange
# Pyro Imports
import pyro.distributions as dist
import pyro.distributions.transforms as T
# PyTorch Lightning Imports
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pl_bolts.datamodules import SklearnDataModule
# wandb imports
import wandb
from tqdm.notebook import trange, tqdm
from pytorch_lightning.loggers import TensorBoardLogger
# Logging Settings
from loguru import logger
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
logger.info("Using device: {}".format(device))
# NUMPY SETTINGS
import numpy as np
np.set_printoptions(precision=3, suppress=True)
# MATPLOTLIB Settings
import corner
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# SEABORN SETTINGS
import seaborn as sns
sns.set_context(context='talk',font_scale=0.7)
# PANDAS SETTINGS
import pandas as pd
pd.set_option("display.max_rows", 120)
pd.set_option("display.max_columns", 120)
Helpful Functions#
Torch 2 Numpy#
def torch_2_numpy(X):
    """Convert a torch.Tensor to a numpy array, handling gradients and GPU tensors."""
    if isinstance(X, np.ndarray):
        return X
    return X.detach().cpu().numpy()
2D Toy Data#
def get_toy_data(n_samples=1000, seed=123):
rng = np.random.RandomState(seed=seed)
x = np.abs(2 * rng.randn(n_samples, 1))
y = np.sin(x) + 0.25 * rng.randn(n_samples, 1)
data = np.hstack((x, y))
return data
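The target is a noisy sine wave: $x = |2\varepsilon|$ and $y = \sin(x) + 0.25\,\varepsilon'$ with $\varepsilon, \varepsilon' \sim \mathcal{N}(0, 1)$, so the density lives on a curved ridge over the positive half-plane.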
X = get_toy_data(5_000, 123)
fig = corner.corner(X, color="blue")
fig.suptitle("Sine Wave")
plt.show()
Model#
from survae.transforms.bijections.elementwise_nonlinear import GaussianMixtureCDF, InverseGaussCDF
from survae.transforms.bijections.linear_orthogonal import LinearHouseholder
from survae.distributions import StandardUniform, StandardNormal
from survae.flows import Flow
import pytorch_lightning as pl
Naive GF Initialization#
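Each Gaussianization flow block applies a marginal Gaussianization followed by a rotation: a Gaussian mixture CDF $F_{\text{GMM}}$ maps each dimension to the unit interval, the inverse Gaussian CDF $\Phi^{-1}$ maps it to a standard normal marginal, and a Householder rotation $R$ mixes the dimensions so the next block sees fresh marginals:

$$z = R \, \Phi^{-1}\big(F_{\text{GMM}}(x)\big).$$

The naive version below stacks these blocks with default parameters, without looking at the data.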
def init_gf_layers(shape: Tuple[int], num_mixtures: int, num_layers: int = 5, num_reflections: int = 2, **kwargs):
    transforms = []
    with trange(num_layers) as pbar:
        for _ in pbar:
            # MARGINAL UNIFORMIZATION
            ilayer = GaussianMixtureCDF(shape, num_mixtures=num_mixtures)
            # save layer
            transforms.append(ilayer)
            # ELEMENT-WISE INVERSE GAUSSIAN CDF
            ilayer = InverseGaussCDF()
            # save layer
            transforms.append(ilayer)
            # HOUSEHOLDER TRANSFORM
            ilayer = LinearHouseholder(shape[0], num_householder=num_reflections)
            # save layer
            transforms.append(ilayer)
    return transforms
shape = (2,)
# base distribution
base_dist = StandardNormal(shape)
# init GF
transforms = init_gf_layers(shape=shape, num_mixtures=6, num_layers=12, num_reflections=2)
# flow model
model = Flow(
base_dist=base_dist,
transforms=transforms
)
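As a quick sanity check, even the untrained flow should be invertible. A minimal sketch using the `forward_transform`/`inverse_transform` methods that we also use below (the mixture-CDF inverse is computed numerically, so expect agreement only up to a small tolerance):

# sketch: the flow should round-trip x -> z -> x, even before training
x_check = torch.Tensor(get_toy_data(16, seed=42))
with torch.no_grad():
    z_check, _ = model.forward_transform(x_check)
    x_rec = model.inverse_transform(z_check)
print((x_check - x_rec).abs().max())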
Training#
Dataset#
# Data
X_train = get_toy_data(10_000, 123)
train_loader = DataLoader(torch.Tensor(X_train), batch_size=128, shuffle=True)
Loss#
def nll_loss(model, data):
return - model.log_prob(data).mean()
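This is maximum likelihood training. For a flow $z = f(x)$ with base density $p_z$, the change-of-variables formula gives

$$-\log p_x(x) = -\log p_z\big(f(x)\big) - \log\big|\det J_f(x)\big|,$$

which is what `model.log_prob` evaluates: the base log-density of the transformed sample plus the accumulated log-det-Jacobian of the transforms.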
PyTorch Lightning Trainer#
import pytorch_lightning as pl
class Learner2DPlane(pl.LightningModule):
def __init__(self, model:nn.Module):
super().__init__()
self.model = model
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
# loss function
# loss = -self.model.log_prob(batch).mean()
loss = nll_loss(self.model, batch)
self.log("train_loss", loss)
return {'loss': loss}
def configure_optimizers(self):
return torch.optim.Adam(self.model.parameters(), lr=1e-2)
def train_dataloader(self):
return train_loader
# initialize trainer
learn = Learner2DPlane(model)
n_epochs = 20
tb_logger = TensorBoardLogger("tb_logs", name='mg_no_init')  # avoid shadowing loguru's `logger`
# initialize trainer
trainer = pl.Trainer(min_epochs=1, max_epochs=n_epochs, gpus=1 if torch.cuda.is_available() else 0, enable_progress_bar=True, logger=tb_logger)
Logging#
# %load_ext tensorboard
# %tensorboard --logdir tb_logs/
Training#
# train model
trainer.fit(learn)
Results#
Latent Domain#
with torch.no_grad():
X_ = torch.Tensor(X)
X_ = X_.to(learn.device)
X_r, ldj = learn.model.forward_transform(X_)
fig = corner.corner(torch_2_numpy(X_r))
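If the flow has learned the density, these latent samples should be close to a standard normal. A quick check of the first two moments (a sketch, reusing `X_r` from above):

# sketch: latent marginals should be roughly zero-mean, unit-variance
print(X_r.mean(dim=0), X_r.std(dim=0))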
Inverse#
with torch.no_grad():
# X_ = X_.to(device)
X_approx = learn.model.inverse_transform(X_r)
fig = corner.corner(torch_2_numpy(X_approx))
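`X_approx` should match the original data up to the numerical error of the iterative mixture-CDF inverse; a quick check:

# sketch: maximum absolute round-trip error
print((X_ - X_approx).abs().max())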
Samples#
with torch.no_grad():
    X_samples = learn.model.sample(5_000)
fig = corner.corner(torch_2_numpy(X_samples))
Better Initialization#
Notice that the naive scheme above never looks at the data: the Gaussian mixture CDFs start from generic parameters. Here we instead initialize each layer from the data itself, propagating the samples through the stack as we build it (the same greedy, layer-wise scheme used by RBIG).
def init_gf_layers_rbig(shape: Tuple[int], num_mixtures: int, num_reflections: int = 2, num_layers: int = 5, X=None, **kwargs):
    transforms = []
    X = torch.Tensor(X)
    with trange(num_layers) as pbar:
        for _ in pbar:
            # MARGINAL UNIFORMIZATION (initialized from the data)
            ilayer = GaussianMixtureCDF(shape, X=torch_2_numpy(X), num_mixtures=num_mixtures)
            # forward transform
            X, _ = ilayer.forward(X)
            # save layer
            transforms.append(ilayer)
            # ELEMENT-WISE INVERSE GAUSSIAN CDF
            ilayer = InverseGaussCDF()
            # forward transform
            X, _ = ilayer.forward(X)
            # save layer
            transforms.append(ilayer)
            # HOUSEHOLDER TRANSFORM
            ilayer = LinearHouseholder(shape[0], num_householder=num_reflections)
            # forward transform
            X, _ = ilayer.forward(X)
            # save layer
            transforms.append(ilayer)
    return transforms
shape = (2,)
# base distribution
base_dist = StandardNormal(shape)
# init GF
transforms = init_gf_layers_rbig(shape=shape, X=X, num_mixtures=6, num_layers=12, num_reflections=2)
# flow model
model = Flow(
base_dist=base_dist,
transforms=transforms
)
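Because each layer was fit to the data as the stack was built, the model should start at a much lower negative log-likelihood than the naive one. A quick check before any gradient step (a sketch, reusing `nll_loss` and `X_train` from above):

# sketch: NLL of the data-initialized model, before training
with torch.no_grad():
    print(nll_loss(model, torch.Tensor(X_train)))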
# initialize trainer
learn = Learner2DPlane(model)
n_epochs = 20
tb_logger = TensorBoardLogger("tb_logs", name='mg_rbig_init')
# initialize trainer
trainer = pl.Trainer(min_epochs=1, max_epochs=n_epochs, gpus=1 if torch.cuda.is_available() else 0, enable_progress_bar=True, logger=tb_logger)
# train model
trainer.fit(learn)
Results#
Latent Domain#
with torch.no_grad():
X_ = torch.Tensor(X)
X_ = X_.to(learn.device)
X_r, ldj = learn.model.forward_transform(X_)
fig = corner.corner(torch_2_numpy(X_r))
Inverse#
with torch.no_grad():
# X_ = X_.to(device)
X_approx = learn.model.inverse_transform(X_r)
fig = corner.corner(torch_2_numpy(X_approx))
Samples#
with torch.no_grad():
    X_samples = learn.model.sample(5_000)
fig = corner.corner(torch_2_numpy(X_samples))
Log Probability#
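To visualize the learned density, we evaluate `log_prob` on a dense grid covering the data range, exponentiate, and compare against a 2D histogram of the true samples.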
def generate_2d_grid(data: np.ndarray, n_grid: int = 1_000, buffer: float = 0.01) -> np.ndarray:
xline = np.linspace(data[:, 0].min() - buffer, data[:, 0].max() + buffer, n_grid)
yline = np.linspace(data[:, 1].min() - buffer, data[:, 1].max() + buffer, n_grid)
xgrid, ygrid = np.meshgrid(xline, yline)
xyinput = np.concatenate([xgrid.reshape(-1, 1), ygrid.reshape(-1, 1)], axis=1)
return xyinput
# sampled data
xyinput = generate_2d_grid(X, 500, buffer=0.1)
from matplotlib import cm
def plot_2d_grid(X_plot, X_grid, X_log_prob):
    # estimated density: exponentiate the log-probabilities
    probs = np.exp(X_log_prob)
    cmap = cm.magma
fig, ax = plt.subplots(ncols=2, figsize=(12, 5))
h = ax[0].hist2d(
X_plot[:, 0], X_plot[:, 1], bins=512, cmap=cmap, density=True, vmin=0.0, vmax=1.0
)
ax[0].set_title("True Density")
ax[0].set(
xlim=[X_plot[:, 0].min(), X_plot[:, 0].max()],
ylim=[X_plot[:, 1].min(), X_plot[:, 1].max()],
)
h1 = ax[1].scatter(
X_grid[:, 0], X_grid[:, 1], s=1, c=probs, cmap=cmap, #vmin=0.0, vmax=1.0
)
ax[1].set(
xlim=[X_grid[:, 0].min(), X_grid[:, 0].max()],
ylim=[X_grid[:, 1].min(), X_grid[:, 1].max()],
)
# plt.colorbar(h1)
ax[1].set_title("Estimated Density")
plt.tight_layout()
plt.show()
return fig, ax
with torch.no_grad():
X_ = torch.Tensor(xyinput)
# X_ = X_.to(device)
X_log_prob = learn.model.log_prob(X_)
X_log_prob = torch_2_numpy(X_log_prob)
plot_2d_grid(torch_2_numpy(X), torch_2_numpy(xyinput), torch_2_numpy(X_log_prob))