Exact GP with MCMC¶
In this example we show how to use NUTS to sample from the posterior over the hyperparameters of a Gaussian process.
Source: NumPyro Example
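Concretely, with hyperparameters θ = (var, length, noise), NUTS targets the posterior p(θ | X, y) ∝ p(y | X, θ) p(θ), where each hyperparameter gets a LogNormal(0, 10) prior and the marginal likelihood is y | X, θ ~ N(0, K_θ(X, X) + (noise + jitter) I), with K_θ the squared exponential (RBF) kernel defined below.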
import sys
from pyprojroot import here

sys.path.append(str(here()))

import time
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import jax.random as random
from jax import vmap

import numpy as onp
import matplotlib.pyplot as plt
import seaborn as sns

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

from src.models.jaxgp.data import get_data
from src.models.jaxgp.exact import predictive_mean, predictive_variance
from src.models.jaxgp.kernels import gram, rbf_kernel, ard_kernel
from src.models.jaxgp.loss import marginal_likelihood
from src.models.jaxgp.mean import zero_mean
from src.models.jaxgp.utils import cholesky_factorization, get_factorizations, saturate

sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

%load_ext autoreload
%autoreload 2
Data¶
@dataclass
class args:
    num_train: int = 30
    num_test: int = 1_000
    smoke_test: bool = False
    input_noise: float = 0.15
    output_noise: float = 0.15
    num_chains: int = 1
    num_warmup: int = 1_000
    num_samples: int = 1_000
    device: str = "cpu"
numpyro.set_platform(args.device)
numpyro.set_host_device_count(args.num_chains)
input_cov = jnp.array([args.input_noise]).reshape(-1, 1)
X, y, Xtest, ytest = get_data(
N=args.num_train,
input_noise=args.input_noise,
output_noise=args.output_noise,
N_test=args.num_test,
)
GP Model¶
# squared exponential kernel with diagonal noise term
def kernel(X, Z, var, length, noise, jitter=1.0e-6, include_noise=True):
deltaXsq = jnp.power((X[:, None] - Z) / length, 2.0)
k = var * jnp.exp(-0.5 * deltaXsq)
if include_noise:
k += (noise + jitter) * jnp.eye(X.shape[0])
return k
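As a quick sanity check (a minimal sketch on made-up inputs, not part of the original example), the kernel should return an (N, N) Gram matrix whose diagonal equals var + noise + jitter when include_noise=True:

x_demo = jnp.linspace(0.0, 1.0, 5)
K_demo = kernel(x_demo, x_demo, var=1.0, length=0.5, noise=0.1)
assert K_demo.shape == (5, 5)
# diagonal = var * exp(0) + noise + jitter
assert jnp.allclose(jnp.diag(K_demo), 1.0 + 0.1 + 1.0e-6)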
def model(X, Y):
# set uninformative log-normal priors on our three kernel hyperparameters
var = numpyro.sample("kernel_var", dist.LogNormal(0.0, 10.0))
noise = numpyro.sample("kernel_noise", dist.LogNormal(0.0, 10.0))
length = numpyro.sample("kernel_length", dist.LogNormal(0.0, 10.0))
# compute kernel
k = kernel(X, X, var, length, noise)
# sample Y according to the standard gaussian process formula
numpyro.sample(
"Y",
dist.MultivariateNormal(
loc=jnp.zeros(X.shape[0]),
covariance_matrix=k
),
obs=Y,
)
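Before running MCMC, it can help to trace the model once with NumPyro's effect handlers to see which sample sites it exposes; a small sketch using the standard handlers API:

from numpyro import handlers

# execute the model with a fixed RNG seed and record every sample site
tr = handlers.trace(handlers.seed(model, random.PRNGKey(1))).get_trace(X, y)
for name, site in tr.items():
    print(name, site["value"].shape)
# expected sites: kernel_var, kernel_noise, kernel_length, and the observed Y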
Inference¶
# helper function for doing hmc inference
def run_inference(model, args, rng_key, X, Y):
start = time.time()
kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=True,
    )
mcmc.run(rng_key, X, Y)
mcmc.print_summary()
print("\nMCMC elapsed time:", time.time() - start)
return mcmc.get_samples()
Training¶
# do inference
rng_key, rng_key_predict = random.split(random.PRNGKey(0))
samples = run_inference(model, args, rng_key, X, y)
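mcmc.print_summary() reports, for each hyperparameter, the posterior mean, standard deviation, quantiles, effective sample size (n_eff), and split R-hat; R-hat values close to 1.0 indicate the chain has mixed well.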
Predictions¶
def predict(rng_key, X, Y, X_test, var, length, noise):
# compute kernels between train and test data, etc.
k_pp = kernel(X_test, X_test, var, length, noise, include_noise=True)
k_pX = kernel(X_test, X, var, length, noise, include_noise=False)
k_XX = kernel(X, X, var, length, noise, include_noise=True)
K_xx_inv = jnp.linalg.inv(k_XX)
K = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
    sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(K), 0.0)) * jax.random.normal(
        rng_key, X_test.shape[:1]
    )
mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, Y))
# we return both the mean function and a sample from the posterior predictive for the
# given set of hyperparameters
return mean, mean + sigma_noise
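At 30 training points the explicit jnp.linalg.inv above is harmless, but solving against a Cholesky factor is the numerically safer way to evaluate the same predictive equations. A sketch of that variant (not used in the rest of this notebook):

from jax.scipy.linalg import cho_factor, cho_solve

def predict_chol(rng_key, X, Y, X_test, var, length, noise):
    # same predictive equations as `predict`, but via a Cholesky solve
    # instead of forming the explicit inverse of k_XX
    k_pp = kernel(X_test, X_test, var, length, noise, include_noise=True)
    k_pX = kernel(X_test, X, var, length, noise, include_noise=False)
    k_XX = kernel(X, X, var, length, noise, include_noise=True)
    chol = cho_factor(k_XX, lower=True)
    mean = k_pX @ cho_solve(chol, Y)
    K = k_pp - k_pX @ cho_solve(chol, k_pX.T)
    sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(K), 0.0)) * jax.random.normal(
        rng_key, X_test.shape[:1]
    )
    return mean, mean + sigma_noise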
# do prediction
vmap_args = (
random.split(rng_key_predict, args.num_samples * args.num_chains),
samples["kernel_var"],
samples["kernel_length"],
samples["kernel_noise"],
)
means, predictions = vmap(
lambda rng_key, var, length, noise: predict(
        rng_key, X, y, Xtest, var, length, noise
)
)(*vmap_args)
mean_prediction = onp.mean(means, axis=0)
percentiles = onp.percentile(predictions, [5.0, 95.0], axis=0)
Results¶
# make plots
fig, ax = plt.subplots(1, 1)
# plot training data
ax.plot(X, y, "kx")
# plot 90% confidence level of predictions
ax.fill_between(Xtest, percentiles[0, :], percentiles[1, :], color="lightblue")
# plot mean prediction
ax.plot(Xtest, mean_prediction, "blue", ls="solid", lw=2.0)
ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")
# plt.savefig("numpyro_gp_plot.png")
plt.tight_layout()
Experiment¶
GP Model - Uncertain Inputs¶
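Here the observed inputs are treated as noisy measurements of latent inputs: the model below places the prior X ~ Normal(Xmu, 0.3) on the true inputs and samples them jointly with the kernel hyperparameters. Note that the prior scale 0.3 is a modelling choice; the data were generated with input_noise = 0.15.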
def emodel(Xmu, Y):
    # set uninformative log-normal priors on our three kernel hyperparameters
    var = numpyro.sample("kernel_var", dist.LogNormal(0.0, 10.0))
    noise = numpyro.sample("kernel_noise", dist.LogNormal(0.0, 10.0))
    length = numpyro.sample("kernel_length", dist.LogNormal(0.0, 10.0))
    # treat the true inputs as latent variables centered on the observed Xmu
    X = numpyro.sample("X", dist.Normal(Xmu, 0.3))
    # compute kernel
    k = kernel(X, X, var, length, noise)
    # sample Y according to the standard gaussian process formula
    numpyro.sample(
        "Y",
        dist.MultivariateNormal(
            loc=jnp.zeros(X.shape[0]),
            covariance_matrix=k,
        ),
        obs=Y,
    )
rng_key, rng_key_predict = random.split(random.PRNGKey(0))
# Run inference scheme
samples = run_inference(emodel, args, rng_key, X, y)
# do prediction
vmap_args = (
    random.split(rng_key_predict, args.num_samples * args.num_chains),
    samples["kernel_var"],
    samples["kernel_length"],
    samples["kernel_noise"],
)
means, predictions = vmap(
    lambda rng_key, var, length, noise: predict(
        rng_key, X, y, Xtest, var, length, noise
    )
)(*vmap_args)
mean_prediction = onp.mean(means, axis=0)
percentiles = onp.percentile(predictions, [5.0, 95.0], axis=0)
# make plots
fig, ax = plt.subplots(1, 1)
# plot training data
ax.plot(X, y, "kx")
# plot 90% confidence level of predictions
ax.fill_between(Xtest, percentiles[0, :], percentiles[1, :], color="lightblue")
# plot mean prediction
ax.plot(Xtest, mean_prediction, "blue", ls="solid", lw=2.0)
ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")
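The second variant below re-expresses the same latent-input model in non-centered form: instead of sampling X directly around Xmu, it samples an offset Xstd ~ Normal(0, 0.3) and sets X = Xmu + Xstd. The two parameterizations define the same joint distribution, but non-centered forms often mix better under NUTS when the latent variables are weakly identified.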
def emodel(Xmu, Y):
    # set uninformative log-normal priors on our three kernel hyperparameters
    var = numpyro.sample("kernel_var", dist.LogNormal(0.0, 10.0))
    noise = numpyro.sample("kernel_noise", dist.LogNormal(0.0, 10.0))
    length = numpyro.sample("kernel_length", dist.LogNormal(0.0, 10.0))
    # non-centered parameterization: sample an offset and shift the observed inputs
    Xstd = numpyro.sample("Xstd", dist.Normal(0.0, 0.3), sample_shape=(Xmu.shape[0],))
    X = Xmu + Xstd
    # compute kernel
    k = kernel(X, X, var, length, noise)
    # sample Y according to the standard gaussian process formula
    numpyro.sample(
        "Y",
        dist.MultivariateNormal(
            loc=jnp.zeros(X.shape[0]),
            covariance_matrix=k,
        ),
        obs=Y,
    )
rng_key, rng_key_predict = random.split(random.PRNGKey(0))
# Run inference scheme
samples = run_inference(emodel, args, rng_key, X, y)
# do prediction
vmap_args = (
    random.split(rng_key_predict, args.num_samples * args.num_chains),
    samples["kernel_var"],
    samples["kernel_length"],
    samples["kernel_noise"],
)
means, predictions = vmap(
    lambda rng_key, var, length, noise: predict(
        rng_key, X, y, Xtest, var, length, noise
    )
)(*vmap_args)
mean_prediction = onp.mean(means, axis=0)
percentiles = onp.percentile(predictions, [5.0, 95.0], axis=0)
# make plots
fig, ax = plt.subplots(1, 1)
# plot training data
ax.plot(X, y, "kx")
# plot 90% confidence level of predictions
ax.fill_between(Xtest, percentiles[0, :], percentiles[1, :], color="lightblue")
# plot mean prediction
ax.plot(Xtest, mean_prediction, "blue", ls="solid", lw=2.0)
ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")