Skip to content

JAX Playground

A starter notebook that installs all of the necessary libraries and loads some simple 1D/2D regression data to play around with.

#@title Install Packages
# Reinstall JAX and jaxlib even if already present (pip resolves `--force` to
# `--force-reinstall`); this also reinstalls jaxlib's dependencies such as numpy and scipy.
# NOTE(review): versions are unpinned — consider pinning jax/jaxlib for reproducibility.
!pip install jax jaxlib --force
# NumPyro straight from the GitHub master branch (unpinned, so behaviour may drift).
!pip install "git+https://github.com/pyro-ppl/numpyro.git#egg=numpyro"
Collecting jaxlib
  Using cached https://files.pythonhosted.org/packages/a0/e2/7e2c7e5b2b2b06c0868f8408f5ed016f8ee83540381cfe43d96bf1e8463b/jaxlib-0.1.55-cp36-none-manylinux2010_x86_64.whl
Collecting numpy>=1.12
  Using cached https://files.pythonhosted.org/packages/63/97/af8a92864a04bfa48f1b5c9b1f8bf2ccb2847f24530026f26dd223de4ca0/numpy-1.19.2-cp36-cp36m-manylinux2010_x86_64.whl
Collecting scipy
  Using cached https://files.pythonhosted.org/packages/2b/a8/f4c66eb529bb252d50e83dbf2909c6502e2f857550f22571ed8556f62d95/scipy-1.5.2-cp36-cp36m-manylinux1_x86_64.whl
Collecting absl-py
  Using cached https://files.pythonhosted.org/packages/b9/07/f69dd3367368ad69f174bfe426a973651412ec11d48ec05c000f19fe0561/absl_py-0.10.0-py3-none-any.whl
Collecting six
  Using cached https://files.pythonhosted.org/packages/ee/ff/48bde5c0f013094d729fe4b0316ba2a24774b3ff1c52d924a8a4cb04078a/six-1.15.0-py2.py3-none-any.whl
ERROR: fancyimpute 0.4.3 requires tensorflow, which is not installed.
ERROR: nbclient 0.5.0 has requirement jupyter-client>=6.1.5, but you'll have jupyter-client 5.3.5 which is incompatible.
ERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible.
ERROR: albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.9 which is incompatible.
Installing collected packages: numpy, scipy, six, absl-py, jaxlib
  Found existing installation: scipy 1.5.2
    Uninstalling scipy-1.5.2:
      Successfully uninstalled scipy-1.5.2
  Found existing installation: six 1.15.0
    Uninstalling six-1.15.0:
      Successfully uninstalled six-1.15.0
  Found existing installation: absl-py 0.10.0
    Uninstalling absl-py-0.10.0:
      Successfully uninstalled absl-py-0.10.0
  Found existing installation: jaxlib 0.1.55
    Uninstalling jaxlib-0.1.55:
      Successfully uninstalled jaxlib-0.1.55
Successfully installed absl-py-0.10.0 jaxlib-0.1.55 numpy-1.19.2 scipy-1.5.2 six-1.15.0
#@title Load Packages
# TYPE HINTS
from typing import Tuple, Optional, Dict, Callable, Union

# JAX SETTINGS
import jax
import jax.numpy as np
from jax import random, lax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI
from numpyro.infer.autoguide import AutoDiagonalNormal
from numpyro.infer import SVI, ELBO
numpyro.set_platform("0")

# NUMPY SETTINGS
import numpy as onp
onp.set_printoptions(precision=3, suppress=True)

# MATPLOTLIB Settings
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# SEABORN SETTINGS
import seaborn as sns
sns.set_context(context='talk',font_scale=0.7)

# PANDAS SETTINGS
import pandas as pd
pd.set_option("display.max_rows", 120)
pd.set_option("display.max_columns", 120)

# LOGGING SETTINGS
import sys
import logging
logging.basicConfig(
    level=logging.INFO, 
    stream=sys.stdout,
    format='%(asctime)s:%(levelname)s:%(message)s'
)
logger = logging.getLogger()
#logger.setLevel(logging.INFO)

%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
#@title Data
def get_data(
    N: int = 30,
    input_noise: float = 0.15,
    output_noise: float = 0.15,
    N_test: int = 400,
    seed: int = 0,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Toy 1D regression data set with noisy inputs and noisy, standardised outputs.

    Targets ``x + 0.2 x^3 + 0.5 (0.5 + x)^2 sin(4x)`` are evaluated on an evenly
    spaced grid over [-1, 1], corrupted with Gaussian output noise and then
    standardised to zero mean and unit standard deviation. Gaussian input noise
    is added to the grid only after the targets have been computed.

    Args:
        N: number of training points.
        input_noise: std of the Gaussian noise added to the inputs.
        output_noise: std of the Gaussian noise added to the targets
            (before standardisation).
        N_test: number of evenly spaced test inputs on [-1.2, 1.2].
        seed: seed of a private NumPy ``RandomState``; the global NumPy RNG
            state is left untouched.

    Returns:
        ``(X, Y, X_test)`` with shapes ``(N, 1)``, ``(N, 1)`` and ``(N_test, 1)``.
    """
    # Private generator: reseeding the global RNG here would silently reset the
    # random stream of every other cell that uses onp.random.
    rng = onp.random.RandomState(seed)

    X = np.linspace(-1, 1, N)
    Y = X + 0.2 * np.power(X, 3.0) + 0.5 * np.power(0.5 + X, 2.0) * np.sin(4.0 * X)
    Y += output_noise * rng.randn(N)
    # Standardise the targets.
    Y -= np.mean(Y)
    Y /= np.std(Y)

    # Corrupt the inputs after the targets were computed from the clean grid.
    X += input_noise * rng.randn(N)

    assert X.shape == (N,)
    assert Y.shape == (N,)

    X_test = np.linspace(-1.2, 1.2, N_test)

    return X[:, None], Y[:, None], X_test[:, None]
seed = 123
rng = onp.random.RandomState(seed)

n_samples = 50
# Standard deviation of the Gaussian noise added to the observations.
noise = 0.01

X = np.linspace(-2 * np.pi, 2 * np.pi, n_samples)
y = np.sinc(X) + noise * rng.randn(n_samples)

# Plot the noisy observations together with the noise-free function they come from.
x_dense = np.linspace(X.min(), X.max(), 500)
fig, ax = plt.subplots()
ax.scatter(X, y, label='data')
ax.plot(x_dense, np.sinc(x_dense), color='k', linestyle='--', label='true function')
ax.set_xlabel('x')
ax.set_ylabel('y = f(x)')
ax.legend();

NumPyro Model

# One-dimensional squared exponential (RBF) kernel; no noise term is added here.
def squared_exp_cov_1D(X, Y, variance, lengthscale):
    """Squared-exponential covariance between two sets of 1D inputs.

    Entry (i, j) is ``variance * exp(-0.5 * ((X[i] - Y[j]) / lengthscale) ** 2)``,
    obtained by broadcasting ``X[:, None]`` against ``Y``; for 1D arrays of
    lengths n and m the result is (n, m). Observation noise, if any, must be
    added to the diagonal by the caller.
    """
    scaled_diff = (X[:, None] - Y) / lengthscale
    return variance * np.exp(-0.5 * np.power(scaled_diff, 2.0))

# GP regression model.
def GP(X, y):
    """NumPyro model: zero-mean GP regression with a squared-exponential kernel.

    Samples the kernel variance ("kernel_var"), lengthscale ("kernel_length") and
    observation-noise std ("sigma") from log-normal priors, then scores ``y``
    under a multivariate normal whose covariance is the kernel matrix on ``X``
    plus ``sigma**2`` on the diagonal (site "y"). Returns None.
    """
    n_points = X.shape[0]

    # Log-normal priors on the kernel hyperparameters; the statement order
    # defines the model's site order, so keep it.
    variance = numpyro.sample("kernel_var", dist.LogNormal(0.0, 0.1))
    lengthscale = numpyro.sample("kernel_length", dist.LogNormal(0.0, 1.0))
    sigma = numpyro.sample("sigma", dist.LogNormal(0.0, 1.0))

    noise_cov = np.power(sigma, 2) * np.eye(n_points)
    cov = squared_exp_cov_1D(X, X, variance, lengthscale) + noise_cov

    # Marginal likelihood of the observations under the GP prior.
    numpyro.sample(
        "y",
        dist.MultivariateNormal(loc=np.zeros(n_points), covariance_matrix=cov),
        obs=y,
    )

Inference - SVI

from numpyro import handlers

# Smoke test: run the model once under a fixed seed and record its sample sites.
# The model returns None, so inspecting the trace is what actually shows something.
model_trace = handlers.trace(handlers.seed(GP, rng_seed=0)).get_trace(X, y)
{name: site["value"].shape for name, site in model_trace.items()}
None
%%time
from numpyro.infer.autoguide import AutoLaplaceApproximation, AutoMultivariateNormal, AutoDiagonalNormal
from numpyro.infer import TraceMeanField_ELBO

# Compile
guide = AutoDiagonalNormal(GP)
optimizer = numpyro.optim.Adam(step_size=0.01)
svi = SVI(GP, guide, optimizer, loss=TraceMeanField_ELBO())
init_state = svi.init(random.PRNGKey(1), X, y)
CPU times: user 130 ms, sys: 98.2 ms, total: 228 ms
Wall time: 122 ms
%%time

# Run optimizer for 1000 iteratons.
state, losses = lax.scan(lambda state, i: 
                         svi.update(state, X, y),
                         init_state, np.arange(5000))

# Extract surrogate posterior.
params = svi.get_params(state)
plt.plot(losses);
plt.title("Negative ELBO (Loss)");
/usr/local/lib/python3.6/dist-packages/numpyro/infer/elbo.py:105: UserWarning: Failed to verify mean field restriction on the guide. To eliminate this warning, ensure model and guide sites occur in the same order.
Model sites:
  kernel_var
  kernel_length
  sigmaGuide sites:
  kernel_length
  kernel_var
  sigma
  "Guide sites:\n  " + "\n  ".join(guide_sites))
CPU times: user 2.47 s, sys: 407 ms, total: 2.88 s
Wall time: 2.24 s

Sample Posterior Distribution

n_test_samples = 100

# `guide.get_posterior(params)` is a distribution over the flattened *unconstrained*
# latent vector (log-space for these positive parameters, sites in alphabetical
# order: kernel_length, kernel_var, sigma), so its draws are not the hyperparameters.
# `sample_posterior` maps draws back to the constrained space, keyed by site name.
posterior_samples = guide.sample_posterior(
    random.PRNGKey(seed), params, sample_shape=(n_test_samples,)
)
# Columns ordered [variance, lengthscale, noise std].
advi_samples = np.stack(
    [posterior_samples["kernel_var"], posterior_samples["kernel_length"], posterior_samples["sigma"]],
    axis=-1,
)

Parameters

labels = ["Variance", "Length Scale", "Noise"]
fig, ax = plt.subplots(ncols=len(labels), nrows=1, figsize=(15, 3))

# One density-normalised histogram per hyperparameter.
# NOTE(review): assumes advi_samples columns are ordered [variance, lengthscale, noise].
for col, (axis, name) in enumerate(zip(ax, labels)):
    axis.hist(advi_samples[:, col], density=True, bins=20, label=name)
    axis.legend()
plt.show()

Predictive Mean & Variance

def predict(kernel, X, Y, X_test, var, length, noise, return_cov=False):
    """GP posterior predictive at ``X_test`` for one fixed set of hyperparameters.

    Args:
        kernel: covariance function called as ``kernel(A, B, var, length)``.
        X: training inputs.
        Y: training targets.
        X_test: test inputs.
        var: kernel variance.
        length: kernel lengthscale.
        noise: observation-noise std; ``noise**2`` is added to the diagonal of
            the training covariance.
        return_cov: if True, also return the predictive covariance of the
            latent function at ``X_test`` (observation noise not added).

    Returns:
        The predictive mean ``k(X_test, X) (k(X, X) + noise^2 I)^{-1} Y``, or
        ``(mean, cov)`` when ``return_cov`` is True.
    """
    k_pX = kernel(X_test, X, var, length)
    k_XX = kernel(X, X, var, length) + np.eye(X.shape[0]) * np.power(noise, 2)

    # Solve the linear systems instead of forming an explicit inverse: cheaper and
    # numerically more stable when the kernel matrix is ill-conditioned.
    mean = np.matmul(k_pX, np.linalg.solve(k_XX, Y))
    if not return_cov:
        # k(X_test, X_test) is only needed for the covariance, so skip it.
        return mean

    k_pp = kernel(X_test, X_test, var, length)
    cov = k_pp - np.matmul(k_pX, np.linalg.solve(k_XX, np.transpose(k_pX)))
    return mean, cov
n_test_samples = 10
# Constrained hyperparameter draws. `guide.get_posterior(params)` would instead give
# draws of the flattened *unconstrained* latent (log-space, sites in alphabetical
# order kernel_length, kernel_var, sigma), which must not be fed to `predict`.
posterior_samples = guide.sample_posterior(
    random.PRNGKey(seed), params, sample_shape=(n_test_samples,)
)
# Columns ordered [variance, lengthscale, noise std].
advi_samples = np.stack(
    [posterior_samples["kernel_var"], posterior_samples["kernel_length"], posterior_samples["sigma"]],
    axis=-1,
)

X_test = np.linspace(-2.2 * np.pi, 2.2 * np.pi, 1_000)

# GP predictive mean for each hyperparameter draw -> (n_test_samples, len(X_test)).
preds = [
    predict(squared_exp_cov_1D, X, y, X_test, var_i, length_i, noise_i)
    for var_i, length_i, noise_i in advi_samples
]
predictions = np.vstack(preds)
predictions.shape
(10, 1000)
# Summarize the function posterior: pointwise mean and central `ci`% interval
# across the hyperparameter draws. Each row of `predictions` is a GP mean for one
# draw, so this band reflects hyperparameter uncertainty only; with few draws the
# percentiles are coarse.
ci = 95
ci_lower = (100 - ci) / 2
ci_upper = (100 + ci) / 2
preds_mean = predictions.mean(0)
preds_lower = np.percentile(predictions, ci_lower, axis=0)
preds_upper = np.percentile(predictions, ci_upper, axis=0)

# Quick look at the three summary curves.
plt.plot(X_test, preds_mean, label='mean')
plt.plot(X_test, preds_lower, label=f'{ci_lower:g}th percentile')
plt.plot(X_test, preds_upper, label=f'{ci_upper:g}th percentile')
plt.legend();
[<matplotlib.lines.Line2D at 0x7fa2508dd550>]
# Final figure: training data, mean prediction and the interval over hyperparameter draws.
fig, ax = plt.subplots()
ax.scatter(X, y, color='k', s=10, label='data')
ax.plot(X_test, preds_mean, label='mean prediction')
ax.fill_between(
    X_test.squeeze(), 
    preds_lower.squeeze(), 
    preds_upper.squeeze(), 
    color='darkorange', alpha=0.2,
    label=f'{ci}% interval over hyperparameter draws',
)
ax.set_xlabel('x')
ax.set_ylabel('y = f(x)')
ax.legend()
plt.tight_layout()
plt.show()