Parameterized Rotations#

This is my notebook where I experiment with all things normalizing flows in Pyro. I use the following packages:

  • PyTorch

  • Pyro

  • PyTorch Lightning

  • Wandb

#@title Install Packages
# %%capture

!pip install --upgrade --quiet pyro-ppl tqdm wandb corner loguru pytorch-lightning lightning-bolts torchtyping einops plum-dispatch pyyaml==5.4.1 nflows
!pip install --upgrade --quiet scipy
!git clone https://github.com/jejjohnson/survae_flows_lib.git
!pip install survae_flows_lib/. --use-feature=in-tree-build
#@title Import Packages

# TYPE HINTS
from typing import Tuple, Optional, Dict, Callable, Union
from pprint import pprint

# PyTorch
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

# Pyro Imports
import pyro.distributions as dist
import pyro.distributions.transforms as T

# PyTorch Lightning Imports
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pl_bolts.datamodules import SklearnDataModule

# wandb imports
import wandb
from tqdm.notebook import trange, tqdm
from pytorch_lightning.loggers import TensorBoardLogger


# Logging Settings
from loguru import logger
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
logger.info("Using device: {}".format(device))

# NUMPY SETTINGS
import numpy as np
np.set_printoptions(precision=3, suppress=True)

# MATPLOTLIB Settings
import corner
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# SEABORN SETTINGS
import seaborn as sns
sns.set_context(context='talk',font_scale=0.7)

# PANDAS SETTINGS
import pandas as pd
pd.set_option("display.max_rows", 120)
pd.set_option("display.max_columns", 120)

Helpful Functions#

Generate 2D Grid#

def generate_2d_grid(data: np.ndarray, n_grid: int = 1_000, buffer: float = 0.01) -> np.ndarray:

    xline = np.linspace(data[:, 0].min() - buffer, data[:, 0].max() + buffer, n_grid)
    yline = np.linspace(data[:, 1].min() - buffer, data[:, 1].max() + buffer, n_grid)
    xgrid, ygrid = np.meshgrid(xline, yline)
    xyinput = np.concatenate([xgrid.reshape(-1, 1), ygrid.reshape(-1, 1)], axis=1)
    return xyinput

Plot 2D Grid#

from matplotlib import cm

def plot_2d_grid(X_plot, X_grid, X_log_prob):

    # estimated density from the log-probabilities
    cmap = cm.magma
    probs = np.exp(X_log_prob)

    fig, ax = plt.subplots(ncols=2, figsize=(12, 5))
    h = ax[0].hist2d(
        X_plot[:, 0], X_plot[:, 1], bins=512, cmap=cmap, density=True, vmin=0.0, vmax=1.0
    )
    ax[0].set_title("True Density")
    ax[0].set(
        xlim=[X_plot[:, 0].min(), X_plot[:, 0].max()],
        ylim=[X_plot[:, 1].min(), X_plot[:, 1].max()],
    )


    h1 = ax[1].scatter(
        X_grid[:, 0], X_grid[:, 1], s=1, c=probs, cmap=cmap, #vmin=0.0, vmax=1.0
    )
    ax[1].set(
        xlim=[X_grid[:, 0].min(), X_grid[:, 0].max()],
        ylim=[X_grid[:, 1].min(), X_grid[:, 1].max()],
    )
    # plt.colorbar(h1)
    ax[1].set_title("Estimated Density")


    plt.tight_layout()
    plt.show()
    return fig, ax

Torch 2 Numpy#

def torch_2_numpy(X):
    """Convert a torch.Tensor (possibly requiring grad or on the GPU) to a numpy array."""

    if not isinstance(X, np.ndarray):
        X = X.detach().cpu().numpy()

    return X

2D Toy Data#

def get_toy_data(n_samples=1000, seed=123):
    rng = np.random.RandomState(seed=seed)

    x = np.abs(2 * rng.randn(n_samples, 1))
    y = np.sin(x) + 0.25 * rng.randn(n_samples, 1)
    data = np.hstack((x, y))

    return data

X = get_toy_data(5_000, 123)

init_X = torch.Tensor(X)

fig = corner.corner(X, color="blue")
fig.suptitle("Sine Wave")
plt.show()

Parameterized Marginal Gaussianization#

from survae.transforms.bijections.elementwise_nonlinear import GaussianMixtureCDF
from survae.transforms.bijections.elementwise_nonlinear import InverseGaussCDF


# marginal gaussianization
mu_bijector = GaussianMixtureCDF((2,), None, 4)

# inverse gaussian cdf
icdf_bijector = InverseGaussCDF()


with torch.no_grad():

    X_mu, _ = mu_bijector.forward(init_X)
    X_mg, _ = icdf_bijector.forward(X_mu)


fig = corner.corner(torch_2_numpy(X_mu), color="blue")
fig.suptitle("Marginal Uniformization")
plt.show()

fig = corner.corner(torch_2_numpy(X_mg), color="blue")
fig.suptitle("Marginal Gaussianization")
plt.show()

Orthogonal Parameterization#

HouseHolder Reflections#

Algorithm

\(\mathbf{H}_k\) is a reflection matrix defined by

\[ \mathbf{H}_k = \mathbf{I} - 2 \frac{\mathbf{v}_k\mathbf{v}_k^\top}{||\mathbf{v}_k||_2^2} \]

where \(\mathbf{v}_k \in \mathbb{R}^{D}\).
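As a quick sanity check of this formula, the following minimal sketch builds a single reflection matrix from an arbitrary vector (here called `v_k`, just for illustration) and verifies that it is orthogonal and its own inverse:

# build one Householder reflection from the formula above (v_k is arbitrary)
v_k = torch.randn(2)
H_k = torch.eye(2) - 2.0 * torch.ger(v_k, v_k) / torch.sum(v_k ** 2)

# a reflection is symmetric and orthogonal, hence its own inverse: H_k @ H_k = I
torch.testing.assert_allclose(H_k @ H_k, torch.eye(2), rtol=1e-5, atol=1e-5)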

num_reflections = 3
num_dimensions = 2

# create vectors, v
v_vectors = torch.ones(num_reflections, num_dimensions)

# calc denominator
squared_norms = torch.sum(v_vectors ** 2, dim=-1)

# initialize loop
Q = torch.eye(num_dimensions)

Multiply all matrices together.

\[ \mathbf{R} = \mathbf{H}_K \mathbf{H}_{K-1}\ldots \mathbf{H}_{1} \]

where \(\mathbf{R},\mathbf{H}_k \in \mathbb{R}^{D \times D}\) are orthogonal matrices.

# loop through all vectors
for v_vector, squared_norm in zip(v_vectors, squared_norms):

    # Inner product.
    temp = Q @ v_vector  

    # Outer product.
    temp = torch.ger(temp, (2.0 / squared_norm) * v_vector)  
    Q = Q - temp
# check dimensions
assert Q.shape == (num_dimensions, num_dimensions)

# check it's orthogonal: Q @ Q.T should be the identity
torch.testing.assert_allclose(Q @ Q.T, torch.eye(num_dimensions), rtol=1e-5, atol=1e-5)
X_r = init_X @ Q
fig = corner.corner(X_r.cpu().numpy(), color="blue")
fig.suptitle("Sine Wave")
plt.show()
def householder_product(vectors: torch.Tensor) -> torch.Tensor:
    """
    Args:
        vectors [K,D] - q vectors for the reflections
    
    Returns:
        R [D, D] - householder reflections
    """
    num_reflections, num_dimensions = vectors.shape

    squared_norms = torch.sum(vectors ** 2, dim=-1)

    # initialize reflection
    H = torch.eye(num_dimensions)

    for vector, squared_norm in zip(vectors, squared_norms):
        temp = H @ vector  # Inner product.
        temp = torch.ger(temp, (2.0 / squared_norm) * vector)  # Outer product.
        H = H - temp

    return H


# initialize vectors
v_vectors = torch.ones(num_reflections, num_dimensions)
v_vectors = torch.nn.init.orthogonal_(v_vectors)


# householder product
R = householder_product(v_vectors)

# inverse householder product
reverse_idx = torch.arange(num_reflections - 1, -1, -1)
R_inv = householder_product(v_vectors[reverse_idx])


# check orthogonality and that the reversed product gives the inverse
torch.testing.assert_allclose(R @ R.T, torch.eye(num_dimensions), rtol=1e-5, atol=1e-5)
torch.testing.assert_allclose(R @ R_inv, torch.eye(num_dimensions), rtol=1e-5, atol=1e-5)

Cost: applying the \(K\) reflections directly to \(N\) data points costs \(\mathcal{O}(KDN)\), while explicitly forming \(\mathbf{R} \in \mathbb{R}^{D \times D}\) costs \(\mathcal{O}(KD^2)\).
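To see where the \(\mathcal{O}(KDN)\) figure comes from, here is a minimal matrix-free sketch (the helper name `apply_householder` is mine, for illustration) that applies the reflections one rank-1 update at a time, without ever forming \(\mathbf{R}\):

def apply_householder(x: torch.Tensor, vectors: torch.Tensor) -> torch.Tensor:
    """Apply x @ (H_1 H_2 ... H_K) without ever forming the D x D matrix."""
    z = x
    for v in vectors:
        # each reflection is a rank-1 update: O(ND) per reflection
        z = z - torch.ger(z @ v, (2.0 / torch.sum(v ** 2)) * v)
    return z


# should match multiplying by the explicit matrix R from above
torch.testing.assert_allclose(apply_householder(init_X, v_vectors), init_X @ R, rtol=1e-4, atol=1e-4)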

PyTorch Class#

from survae.transforms.bijections import Bijection

class LinearHouseholder(Bijection):
    """
    Linear orthogonal bijection parameterized by a product of Householder reflections.
    """

    def __init__(self, num_features: int, num_reflections: int = 2):
        super(LinearHouseholder, self).__init__()
        self.num_features = num_features
        self.num_reflections = num_reflections
        

        # initialize vectors param
        vectors = torch.randn(num_reflections, num_features)
        self.vectors = nn.Parameter(vectors)

        # initialize parameter to be orthogonal
        nn.init.orthogonal_(self.vectors)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

        # get rotation matrix
        R = householder_product(self.vectors)

        # Z = x @ R
        z = torch.mm(x, R)

        # ldj -> identity
        batch_size = x.shape[0]
        ldj = x.new_zeros(batch_size)

        return z, ldj

    def inverse(self, z):
        # get rotation matrix (in reverse)
        reverse_idx = torch.arange(self.num_reflections - 1, -1, -1)
        R = householder_product(self.vectors[reverse_idx])
        
        x = torch.mm(z, R)

        return x

with torch.no_grad():
    hh_bijector = LinearHouseholder(2, 2)
    X_r, ldj = hh_bijector.forward(init_X)
    X_approx = hh_bijector.inverse(X_r)

# check the inverse
torch.testing.assert_allclose(X_approx, init_X, rtol=1e-5, atol=1e-5, )
from survae.transforms.bijections.linear_orthogonal import FastHouseholder, LinearHouseholder

Flow Model#

So to have a flow model, we need two components:


Base Distribution: \(p_Z = \mathbb{P}_Z\)

This describes the distribution we want in the transform (latent) domain. In this case, we choose the standard Gaussian distribution because we are trying to Gaussianize our data.


Bijections: \(f = f_L \circ f_{L-1} \circ \ldots \circ f_1\)

The list of bijections. These are the functions we compose together to map our data to the base distribution.

from survae.distributions import StandardUniform, StandardNormal
from survae.flows import Flow

# parameters
features_shape = (2,)
num_mixtures = 4
num_reflections = 2

# base distribution
base_dist = StandardNormal(features_shape)

# transforms
transforms = [
    GaussianMixtureCDF(features_shape, None, num_mixtures),
    InverseGaussCDF(),
    LinearHouseholder(features_shape[0], num_reflections),
]

# flow model
model = Flow(
    base_dist=base_dist,
    transforms=transforms
)

Training#

Dataset#

# # Data
X_train = get_toy_data(5_000, 123)

train_loader = DataLoader(torch.Tensor(X_train), batch_size=128, shuffle=True)

Loss#

def nll_loss(model, data):
    return - model.log_prob(data).mean()
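For reference, the per-sample loss is the negative log-likelihood given by the change-of-variables formula:

\[ -\log p_{X}(\mathbf{x}) = -\log p_{Z}\big(f(\mathbf{x})\big) - \log \left| \det \nabla_{\mathbf{x}} f(\mathbf{x}) \right| \]

where \(f = f_L \circ \ldots \circ f_1\) is the composition of bijections and \(p_Z\) is the base distribution; we average this over the batch.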

PyTorch Lightning Trainer#

import pytorch_lightning as pl

class Learner2DPlane(pl.LightningModule):
    def __init__(self, model:nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        
        # loss function
        # loss = -self.model.log_prob(batch).mean()
        loss = nll_loss(self.model, batch)
        
        self.log("train_loss", loss)
        
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=1e-3)

    def train_dataloader(self):
        return train_loader
# initialize the learner
learn = Learner2DPlane(model)
n_epochs = 50

# optional TensorBoard logging
# logger = TensorBoardLogger("tb_logs", name='mg_no_init')

# initialize trainer (gpus=1 assumes a GPU is available)
trainer = pl.Trainer(min_epochs=1, max_epochs=n_epochs, gpus=1, enable_progress_bar=True)

Logging#

# %load_ext tensorboard
# %tensorboard --logdir tb_logs/

Training#

# train model
trainer.fit(learn)

Results#

Latent Domain#

with torch.no_grad():
    X_ = torch.Tensor(X)
    # X_ = X_.to(device)
    X_r, ldj = learn.model.forward_transform(X_)


fig = corner.corner(torch_2_numpy(X_r))

Inverse#

with torch.no_grad():
    # X_ = X_.to(device)
    X_approx = learn.model.inverse_transform(X_r)

fig = corner.corner(torch_2_numpy(X_approx))
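As a closing sanity check, we can measure the reconstruction error of the inverse transform and reuse the helper functions from above to visualize the learned density. This is a minimal sketch: the grid size is kept modest, and the marginal Gaussianization inverse may be iterative, so a small non-zero reconstruction error is expected.

# reconstruction error of the inverse transform
print("max reconstruction error:", (X_approx - X_).abs().max().item())

# evaluate the learned density on a 2D grid and compare against the data
X_grid = generate_2d_grid(X, n_grid=200, buffer=0.1)

with torch.no_grad():
    X_log_prob = learn.model.log_prob(torch.Tensor(X_grid))

fig, ax = plot_2d_grid(X, X_grid, torch_2_numpy(X_log_prob))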