Jax PlayGround¶
My starting notebook, where I install all of the necessary libraries and load some easy 1D/2D regression data to play around with.
#@title Install Packages
!pip install jax jaxlib --force-reinstall
!pip install "git+https://github.com/pyro-ppl/numpyro.git#egg=numpyro"
#@title Load Packages
# TYPE HINTS
from typing import Tuple, Optional, Dict, Callable, Union
# JAX SETTINGS
import jax
import jax.numpy as np
from jax import random, lax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI
from numpyro.infer.autoguide import AutoDiagonalNormal
numpyro.set_platform("cpu")
# NUMPY SETTINGS
import numpy as onp
onp.set_printoptions(precision=3, suppress=True)
# MATPLOTLIB Settings
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# SEABORN SETTINGS
import seaborn as sns
sns.set_context(context='talk', font_scale=0.7)
# PANDAS SETTINGS
import pandas as pd
pd.set_option("display.max_rows", 120)
pd.set_option("display.max_columns", 120)
# LOGGING SETTINGS
import sys
import logging
logging.basicConfig(
level=logging.INFO,
stream=sys.stdout,
format='%(asctime)s:%(levelname)s:%(message)s'
)
logger = logging.getLogger()
#logger.setLevel(logging.INFO)
%load_ext autoreload
%autoreload 2
#@title Data
def get_data(
    N: int = 30,
    input_noise: float = 0.15,
    output_noise: float = 0.15,
    N_test: int = 400,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    onp.random.seed(0)
    X = np.linspace(-1, 1, N)
    Y = X + 0.2 * np.power(X, 3.0) + 0.5 * np.power(0.5 + X, 2.0) * np.sin(4.0 * X)
    Y += output_noise * onp.random.randn(N)
    Y -= np.mean(Y)
    Y /= np.std(Y)
    X += input_noise * onp.random.randn(N)
    assert X.shape == (N,)
    assert Y.shape == (N,)
    X_test = np.linspace(-1.2, 1.2, N_test)
    return X[:, None], Y[:, None], X_test[:, None]
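As a quick sanity check of the helper above (my own snippet, not used further below), the returned arrays should be column vectors:
X_demo, Y_demo, X_test_demo = get_data(N=30, N_test=400)
print(X_demo.shape, Y_demo.shape, X_test_demo.shape)  # (30, 1) (30, 1) (400, 1)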
seed = 123
rng = onp.random.RandomState(seed)
n_samples = 50
noise = 0.05
X = np.linspace(-2*np.pi, 2*np.pi, n_samples)
y = np.sinc(X) + 0.01 * rng.randn(n_samples)
# Plot the training data.
plt.scatter(X, y, label='data')
plt.xlabel('x')
plt.ylabel('y = f(x)')
plt.legend();
NumPyro Model¶
# One-dimensional squared exponential (RBF) kernel; the noise term is added in the model below.
def squared_exp_cov_1D(X, Y, variance, lengthscale):
    deltaXsq = np.power((X[:, None] - Y) / lengthscale, 2.0)
    K = variance * np.exp(-0.5 * deltaXsq)
    return K
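As a quick check of the kernel (a small sketch with made-up hyperparameters), the diagonal of K(x, x) should equal the variance:
x_demo = np.linspace(-1.0, 1.0, 5)
K_demo = squared_exp_cov_1D(x_demo, x_demo, 1.0, 0.5)
print(K_demo.shape)     # (5, 5)
print(np.diag(K_demo))  # every diagonal entry equals the variance (1.0)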
# GP model.
def GP(X, y):
    # Set informative log-normal priors on the kernel hyperparameters.
    variance = numpyro.sample("kernel_var", dist.LogNormal(0.0, 0.1))
    lengthscale = numpyro.sample("kernel_length", dist.LogNormal(0.0, 1.0))
    sigma = numpyro.sample("sigma", dist.LogNormal(0.0, 1.0))
    # Compute the kernel and add observation noise on the diagonal.
    K = squared_exp_cov_1D(X, X, variance, lengthscale)
    K += np.eye(X.shape[0]) * np.power(sigma, 2)
    # Sample y according to the standard Gaussian process marginal likelihood.
    numpyro.sample(
        "y",
        dist.MultivariateNormal(loc=np.zeros(X.shape[0]), covariance_matrix=K),
        obs=y,
    )
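Concretely, the model above places the standard GP marginal likelihood on the observations,
$$y \mid \sigma_f^2, \ell, \sigma \sim \mathcal{N}\big(0,\; K_{\sigma_f^2,\ell}(X, X) + \sigma^2 I\big),$$
with log-normal priors on the kernel variance $\sigma_f^2$, the lengthscale $\ell$, and the noise $\sigma$.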
Inference - SVI¶
from numpyro import handlers
with handlers.seed(rng_seed=0):
    i = GP(X, y)
    print(i)
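To see exactly which sample sites the model defines (and their shapes), numpyro's trace handler can be used; a small sketch:
trace = handlers.trace(handlers.seed(GP, random.PRNGKey(0))).get_trace(X, y)
for name, site in trace.items():
    print(name, site["value"].shape)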
%%time
from numpyro.infer.autoguide import AutoLaplaceApproximation, AutoMultivariateNormal, AutoDiagonalNormal
from numpyro.infer import TraceMeanField_ELBO
# Set up the guide, optimizer, and SVI object.
guide = AutoDiagonalNormal(GP)
optimizer = numpyro.optim.Adam(step_size=0.01)
svi = SVI(GP, guide, optimizer, loss=TraceMeanField_ELBO())
init_state = svi.init(random.PRNGKey(1), X, y)
%%time
# Run the optimizer for 5000 iterations.
state, losses = lax.scan(
    lambda state, i: svi.update(state, X, y),
    init_state, np.arange(5000)
)
# Extract surrogate posterior.
params = svi.get_params(state)
plt.plot(losses);
plt.title("Negative ELBO (Loss)");
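Depending on the numpyro version, the manual lax.scan loop above can be replaced by the built-in SVI.run helper, which returns the optimized parameters and the per-step losses (a sketch, assuming a recent numpyro release):
svi_result = svi.run(random.PRNGKey(1), 5000, X, y)
params_alt, losses_alt = svi_result.params, svi_result.losses
plt.plot(losses_alt);
plt.title("Negative ELBO (Loss)");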
Sample Posterior Distribution¶
n_test_samples = 100
advi_samples = guide.get_posterior(params).sample(random.PRNGKey(seed), (n_test_samples, ))
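Note that guide.get_posterior(params).sample draws from the Normal approximation in unconstrained space. To look at the hyperparameters on their positive (constrained) scale instead, the guide also exposes sample_posterior, which returns a dict keyed by site name (a sketch, assuming a recent numpyro version):
posterior_samples = guide.sample_posterior(
    random.PRNGKey(seed), params, sample_shape=(n_test_samples,)
)
print({name: s.shape for name, s in posterior_samples.items()})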
Parameters¶
fig, ax = plt.subplots(ncols=3, nrows=1, figsize=(15, 3))
labels = ["Variance", "Length Scale", "Noise"]
for i, ilabel in enumerate(labels):
    ax[i].hist(advi_samples[:, i], density=True, bins=20, label=ilabel)
    ax[i].legend()
plt.show()
Predictive Mean & Variance¶
def predict(kernel, X, Y, X_test, var, length, noise):
    # Compute kernels between train and test data, etc.
    k_pp = kernel(X_test, X_test, var, length)
    k_pX = kernel(X_test, X, var, length)
    k_XX = kernel(X, X, var, length)
    k_XX += np.eye(X.shape[0]) * np.power(noise, 2)
    K_xx_inv = np.linalg.inv(k_XX)
    K = k_pp - np.matmul(k_pX, np.matmul(K_xx_inv, np.transpose(k_pX)))
    mean = np.matmul(k_pX, np.matmul(K_xx_inv, Y))
    # Return the posterior predictive mean for the given set of hyperparameters.
    return mean
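Since the heading promises a variance as well, here is a small variant of predict that also returns the marginal predictive standard deviation (a sketch in the same notation, not used below):
def predict_with_std(kernel, X, Y, X_test, var, length, noise):
    # Posterior predictive mean and marginal standard deviation of the GP.
    k_pp = kernel(X_test, X_test, var, length)
    k_pX = kernel(X_test, X, var, length)
    k_XX = kernel(X, X, var, length) + np.eye(X.shape[0]) * np.power(noise, 2)
    K_xx_inv = np.linalg.inv(k_XX)
    cov = k_pp - np.matmul(k_pX, np.matmul(K_xx_inv, np.transpose(k_pX)))
    mean = np.matmul(k_pX, np.matmul(K_xx_inv, Y))
    std = np.sqrt(np.clip(np.diag(cov), 0.0))  # clip tiny negative values from round-off
    return mean, std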
n_test_samples = 10
advi_samples = guide.get_posterior(params).sample(random.PRNGKey(seed), (n_test_samples, ))
X_test = np.linspace(-2.2 * np.pi, 2.2 * np.pi, 1_000)
preds = [predict(squared_exp_cov_1D, X, y, X_test, iparam[0], iparam[1], iparam[2]) for iparam in advi_samples]
predictions = np.vstack(preds)
predictions.shape
# Summarize function posterior.
ci = 95
ci_lower = (100 - ci) / 2
ci_upper = (100 + ci) / 2
preds_mean = predictions.mean(0)
preds_lower = np.percentile(predictions, ci_lower, axis=0)
preds_upper = np.percentile(predictions, ci_upper, axis=0)
plt.plot(X_test, preds_mean)
plt.plot(X_test, preds_lower)
plt.plot(X_test, preds_upper)
fig, ax = plt.subplots()
ax.plot(X_test, preds_mean)
ax.fill_between(
X_test.squeeze(),
preds_lower.squeeze(),
preds_upper.squeeze(),
color='darkorange', alpha=0.2,
label='95% Confidence'
)
# ax.set_ylim([-3, 3])
# ax.set_xlim([-10.2, 10.2])
ax.legend()
plt.tight_layout()
plt.show()