
Exact GP with MCMC

In this example we show how to use NUTS to sample from the posterior over the hyperparameters of a Gaussian process. We then extend the model to uncertain inputs by treating the true input locations as latent variables.

Source: NumPyro example

import sys
from pyprojroot import here
sys.path.append(str(here()))

import time
from dataclasses import dataclass

import jax
import jax.numpy as jnp
import jax.random as random
from jax import vmap

import matplotlib.pyplot as plt
import numpy as onp
import seaborn as sns

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

from src.models.jaxgp.data import get_data

%load_ext autoreload
%autoreload 2

Data

@dataclass
class args:
    num_train: int = 30
    num_test: int = 1_000
    smoke_test: bool = False
    input_noise: float = 0.15
    output_noise: float = 0.15
    num_chains: int = 1
    num_warmup: int = 1_000
    num_samples: int = 1_000
    device: str = "cpu"

numpyro.set_platform(args.device)
numpyro.set_host_device_count(args.num_chains)
input_cov = jnp.array([args.input_noise]).reshape(-1, 1)
X, y, X_test, y_test = get_data(
    N=args.num_train,
    input_noise=args.input_noise,
    output_noise=args.output_noise,
    N_test=args.num_test,
)
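
get_data lives in this repository's src package. For readers without it, a minimal stand-in with the same signature might look like the sketch below; the noisy 1-D sinusoid is an assumption for illustration, not the repository's actual data.

import jax.numpy as jnp
import numpy as onp


def get_data(N=30, input_noise=0.15, output_noise=0.15, N_test=1_000):
    """Hypothetical stand-in: a sinusoid observed through noisy inputs/outputs."""
    rng = onp.random.RandomState(123)
    X = jnp.linspace(-10.0, 10.0, N)
    y = jnp.sin(X + input_noise * rng.randn(N)) + output_noise * rng.randn(N)
    X_test = jnp.linspace(-11.0, 11.0, N_test)
    y_test = jnp.sin(X_test)
    return X, y, X_test, y_test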

GP Model

# squared exponential kernel with diagonal noise term
def kernel(X, Z, var, length, noise, jitter=1.0e-6, include_noise=True):
    deltaXsq = jnp.power((X[:, None] - Z) / length, 2.0)
    k = var * jnp.exp(-0.5 * deltaXsq)
    if include_noise:
        k += (noise + jitter) * jnp.eye(X.shape[0])
    return k
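
As a quick sanity check (not part of the original notebook), kernel should return an (N, M) Gram matrix, and in the square case the added noise/jitter keeps it positive definite, so a Cholesky factorization returns no NaNs:

# hypothetical toy inputs; any 1-D arrays work
Xa = jnp.linspace(0.0, 1.0, 5)
Xb = jnp.linspace(0.0, 1.0, 3)

assert kernel(Xa, Xb, 1.0, 1.0, 0.1, include_noise=False).shape == (5, 3)

K = kernel(Xa, Xa, 1.0, 1.0, 0.1)  # (5, 5) with noise + jitter on the diagonal
L = jnp.linalg.cholesky(K)         # jnp.linalg.cholesky yields NaNs iff K is not PD
assert not jnp.isnan(L).any()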
def model(X, Y):
    # note: the model is deliberately not jax.jit-ed; NUTS compiles the log
    # density internally, and jitting a NumPyro model interferes with its
    # effect handlers

    # set uninformative log-normal priors on our three kernel hyperparameters
    var = numpyro.sample("kernel_var", dist.LogNormal(0.0, 10.0))
    noise = numpyro.sample("kernel_noise", dist.LogNormal(0.0, 10.0))
    length = numpyro.sample("kernel_length", dist.LogNormal(0.0, 10.0))

    # compute kernel
    k = kernel(X, X, var, length, noise)

    # sample Y according to the standard gaussian process formula
    numpyro.sample(
        "Y",
        dist.MultivariateNormal(
            loc=jnp.zeros(X.shape[0]), 
            covariance_matrix=k
        ),
        obs=Y,
    )
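
Before spending time on MCMC, the model can be smoke-tested by evaluating its log joint density at fixed hyperparameters. This is a sketch using numpyro.infer.util.log_density, not part of the original notebook:

from numpyro.infer.util import log_density

test_params = {"kernel_var": 1.0, "kernel_noise": 0.1, "kernel_length": 1.0}
logp, _ = log_density(model, (X, y), {}, test_params)
print(logp)  # should be a finite scalar; NaN/inf points to a modeling bug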

Inference

# helper function for running HMC (NUTS) inference
def run_inference(model, args, rng_key, X, Y):
    start = time.time()
    nuts_kernel = NUTS(model)  # named to avoid shadowing the kernel() function above
    mcmc = MCMC(
        nuts_kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=True,
    )
    mcmc.run(rng_key, X, Y)
    mcmc.print_summary()
    print("\nMCMC elapsed time:", time.time() - start)
    return mcmc.get_samples()

Training

# do inference
rng_key, rng_key_predict = random.split(random.PRNGKey(0))


samples = run_inference(model, args, rng_key, X, y)
sample: 100%|██████████| 2000/2000 [42:33<00:00,  1.28s/it, 1023 steps of size 4.98e-06. acc. prob=0.95]
                     mean       std    median      5.0%     95.0%     n_eff     r_hat
  kernel_length 1960118272.00    256.13 1960118528.00 1960118528.00 1960118528.00      0.50      1.00
   kernel_noise      0.00      0.00      0.00      0.00      0.00       nan       nan
     kernel_var      2.63      0.00      2.63      2.63      2.63      6.13      1.00

Number of divergences: 0

MCMC elapsed time: 2557.5514323711395
/home/emmanuel/.conda/envs/egp/lib/python3.8/site-packages/numpyro/diagnostics.py:172: RuntimeWarning: invalid value encountered in true_divide
  rho_k = 1. - (var_within - gamma_k_c.mean(axis=0)) / var_estimator

Note that this chain did not mix: kernel_length drifted to roughly 2e9 with an n_eff of 0.5, and kernel_noise collapsed to zero, which is what triggers the divide-by-zero in the diagnostics above. The latent-input models in the experiment below sample far more efficiently.

Predictions

def predict(rng_key, X, Y, X_test, var, length, noise):
    # compute kernels between train and test data, etc.
    k_pp = kernel(X_test, X_test, var, length, noise, include_noise=True)
    k_pX = kernel(X_test, X, var, length, noise, include_noise=False)
    k_XX = kernel(X, X, var, length, noise, include_noise=True)
    K_xx_inv = jnp.linalg.inv(k_XX)
    K = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
    sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(K), a_min=0.0)) * jax.random.normal(
        rng_key, X_test.shape[:1]
    )
    mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, Y))
    # we return both the mean function and a sample from the posterior predictive for the
    # given set of hyperparameters
    return mean, mean + sigma_noise
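
For this small problem the explicit jnp.linalg.inv is fine, but a Cholesky solve is cheaper and numerically safer. Below is a sketch of the same predictive equations via jax.scipy.linalg; an alternative for reference, not the notebook's code path:

from jax.scipy.linalg import cho_factor, cho_solve


def predict_chol(rng_key, X, Y, X_test, var, length, noise):
    k_pp = kernel(X_test, X_test, var, length, noise, include_noise=True)
    k_pX = kernel(X_test, X, var, length, noise, include_noise=False)
    k_XX = kernel(X, X, var, length, noise, include_noise=True)

    # one factorization replaces the explicit inverse in both solves
    c = cho_factor(k_XX)
    mean = k_pX @ cho_solve(c, Y)
    cov = k_pp - k_pX @ cho_solve(c, k_pX.T)

    sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(cov), a_min=0.0))
    return mean, mean + sigma_noise * jax.random.normal(rng_key, X_test.shape[:1])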
# do prediction
vmap_args = (
    random.split(rng_key_predict, args.num_samples * args.num_chains),
    samples["kernel_var"],
    samples["kernel_length"],
    samples["kernel_noise"],
)
means, predictions = vmap(
    lambda rng_key, var, length, noise: predict(
        rng_key, X, y, X_test, var, length, noise
    )
)(*vmap_args)
mean_prediction = onp.mean(means, axis=0)
percentiles = onp.percentile(predictions, [5.0, 95.0], axis=0)

Results

# make plots
fig, ax = plt.subplots(1, 1)

# plot training data
ax.plot(X, y, "kx")
# plot 90% confidence level of predictions
ax.fill_between(X_test, percentiles[0, :], percentiles[1, :], color="lightblue")
# plot mean prediction
ax.plot(X_test, mean_prediction, "blue", ls="solid", lw=2.0)
ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")

# plt.savefig("numpyro_gp_plot.png")
plt.tight_layout()

Experiment

GP Model - Uncertain Inputs

Here we treat the noisy inputs as latent variables: each latent input gets a Normal prior centered on its observed value, and NUTS samples the latent inputs jointly with the kernel hyperparameters.

def emodel(Xmu, Y):
    # set uninformative log-normal priors on our three kernel hyperparameters
    var = numpyro.sample("kernel_var", dist.LogNormal(0.0, 10.0))
    noise = numpyro.sample("kernel_noise", dist.LogNormal(0.0, 10.0))
    length = numpyro.sample("kernel_length", dist.LogNormal(0.0, 10.0))

    # latent "true" inputs, centered on the observed noisy inputs
    X = numpyro.sample("X", dist.Normal(Xmu, 0.3))

    # compute kernel on the latent inputs
    k = kernel(X, X, var, length, noise)

    # sample Y according to the standard Gaussian process formula
    numpyro.sample(
        "Y",
        dist.MultivariateNormal(loc=jnp.zeros(X.shape[0]), covariance_matrix=k),
        obs=Y,
    )
rng_key, rng_key_predict = random.split(random.PRNGKey(0))

# Run inference scheme
samples = run_inference(emodel, args, rng_key, X, y)
sample: 100%|██████████| 1100/1100 [00:19<00:00, 56.73it/s, 15 steps of size 2.14e-01. acc. prob=0.93] 
                     mean       std    median      5.0%     95.0%     n_eff     r_hat
           X[0]    -10.01      0.25     -9.98    -10.40     -9.59    597.56      1.00
           X[1]     -9.91      0.27     -9.92    -10.30     -9.45    887.43      1.00
           X[2]     -9.65      0.27     -9.67    -10.03     -9.16    475.77      1.00
           X[3]     -9.40      0.25     -9.41     -9.85     -9.04    937.07      1.00
           X[4]     -8.83      0.23     -8.83     -9.27     -8.51    759.05      1.00
           X[5]     -8.33      0.19     -8.31     -8.61     -8.01    463.77      1.00
           X[6]     -8.31      0.19     -8.30     -8.60     -8.01    556.40      1.00
           X[7]     -7.77      0.13     -7.77     -7.98     -7.57    554.29      1.00
           X[8]     -7.42      0.12     -7.42     -7.62     -7.23    426.67      1.00
           X[9]     -7.03      0.11     -7.03     -7.21     -6.84    363.45      1.00
          X[10]     -6.60      0.12     -6.61     -6.79     -6.41    370.03      1.00
          X[11]     -6.29      0.13     -6.29     -6.51     -6.10    423.12      1.00
          X[12]     -5.80      0.16     -5.80     -6.05     -5.54    461.56      1.00
          X[13]     -5.45      0.18     -5.46     -5.74     -5.15    665.71      1.00
          X[14]     -5.00      0.25     -5.02     -5.41     -4.61   1038.74      1.00
          X[15]     -4.97      0.23     -4.98     -5.32     -4.59    782.89      1.00
          X[16]     -4.92      0.32     -4.93     -5.42     -4.40   1370.20      1.00
          X[17]     -4.41      0.29     -4.41     -4.87     -3.91   1293.31      1.00
          X[18]     -4.00      0.31     -3.99     -4.49     -3.50   1020.15      1.00
          X[19]     -3.57      0.28     -3.55     -4.06     -3.14    686.66      1.00
          X[20]     -3.42      0.36     -3.40     -3.96     -2.79    798.33      1.00
          X[21]     -2.57      0.36     -2.57     -3.12     -1.96    760.66      1.00
          X[22]     -2.34      0.28     -2.32     -2.82     -1.92    800.18      1.00
          X[23]     -2.68      0.27     -2.66     -3.07     -2.18    779.41      1.00
          X[24]     -1.61      0.16     -1.61     -1.89     -1.38    410.05      1.01
          X[25]     -1.65      0.16     -1.65     -1.91     -1.38    460.06      1.00
          X[26]     -1.16      0.13     -1.16     -1.36     -0.96    352.57      1.01
          X[27]     -0.81      0.12     -0.80     -0.98     -0.59    363.98      1.01
          X[28]     -0.32      0.12     -0.33     -0.51     -0.13    388.07      1.01
          X[29]      0.11      0.12      0.10     -0.10      0.29    393.42      1.02
          X[30]      0.41      0.13      0.40      0.18      0.62    534.79      1.01
          X[31]      0.93      0.19      0.92      0.63      1.24    897.67      1.00
          X[32]      1.08      0.21      1.07      0.75      1.42    398.94      1.00
          X[33]      1.39      0.33      1.36      0.87      1.90    443.17      1.00
          X[34]      1.67      0.28      1.67      1.25      2.13    822.54      1.00
          X[35]      2.07      0.30      2.07      1.60      2.55    735.84      1.00
          X[36]      2.18      0.30      2.18      1.69      2.65    546.85      1.00
          X[37]      3.10      0.31      3.10      2.63      3.64   1153.64      1.00
          X[38]      2.92      0.28      2.93      2.46      3.35   1498.27      1.00
          X[39]      3.33      0.30      3.34      2.82      3.79   1393.60      1.00
          X[40]      4.02      0.26      4.03      3.64      4.48    743.00      1.00
          X[41]      3.50      0.31      3.50      3.01      3.99    895.25      1.00
          X[42]      3.93      0.33      3.96      3.41      4.47    667.17      1.00
          X[43]      4.27      0.19      4.27      3.96      4.60    626.95      1.00
          X[44]      4.89      0.16      4.88      4.62      5.14    340.23      1.00
          X[45]      5.36      0.13      5.36      5.17      5.58    373.83      1.00
          X[46]      5.78      0.12      5.78      5.58      5.96    307.39      1.00
          X[47]      6.05      0.12      6.06      5.86      6.24    329.60      1.00
          X[48]      6.65      0.13      6.64      6.45      6.86    343.38      1.00
          X[49]      6.94      0.15      6.94      6.68      7.18    385.53      1.00
          X[50]      7.56      0.26      7.53      7.18      8.04    551.86      1.00
          X[51]      7.61      0.26      7.59      7.17      8.03    524.45      1.00
          X[52]      7.63      0.21      7.63      7.25      7.95   1001.54      1.00
          X[53]      8.40      0.30      8.38      7.91      8.89    576.53      1.00
          X[54]      8.28      0.33      8.32      7.69      8.78   1210.06      1.00
          X[55]      8.95      0.26      8.95      8.50      9.34   1332.43      1.00
          X[56]      9.25      0.27      9.25      8.77      9.65   1931.84      1.00
          X[57]      9.27      0.27      9.28      8.82      9.69   1255.06      1.00
          X[58]      9.89      0.28      9.91      9.46     10.36    613.94      1.00
          X[59]     10.22      0.28     10.20      9.78     10.66    925.47      1.00
  kernel_length      1.95      0.19      1.96      1.63      2.23    440.62      1.01
   kernel_noise      0.00      0.00      0.00      0.00      0.01    222.75      1.00
     kernel_var      1.18      0.64      1.02      0.39      2.01    423.49      1.01

Number of divergences: 0

MCMC elapsed time: 23.19466996192932
# do prediction
vmap_args = (
    random.split(rng_key_predict, args.num_samples * args.num_chains),
    samples["kernel_var"],
    samples["kernel_length"],
    samples["kernel_noise"],
)
means, predictions = vmap(
    lambda rng_key, var, length, noise: predict(
        rng_key, X, y, X_test, var, length, noise
    )
)(*vmap_args)

mean_prediction = onp.mean(means, axis=0)
percentiles = onp.percentile(predictions, [5.0, 95.0], axis=0)
# make plots
fig, ax = plt.subplots(1, 1)

# plot training data
ax.plot(X, y, "kx")
# plot 90% confidence level of predictions
ax.fill_between(X_test, percentiles[0, :], percentiles[1, :], color="lightblue")
# plot mean prediction
ax.plot(X_test, mean_prediction, "blue", ls="solid", lw=2.0)
ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")
plt.tight_layout()
The same latent-input model can also be written in non-centered form: sample zero-mean offsets Xstd and add them to the observed inputs, which often mixes better under HMC.

def emodel(Xmu, Y):
    # set uninformative log-normal priors on our three kernel hyperparameters
    var = numpyro.sample("kernel_var", dist.LogNormal(0.0, 10.0))
    noise = numpyro.sample("kernel_noise", dist.LogNormal(0.0, 10.0))
    length = numpyro.sample("kernel_length", dist.LogNormal(0.0, 10.0))

    # non-centered parameterization of the latent inputs:
    # X = Xmu + Xstd with Xstd ~ Normal(0, 0.3)
    Xstd = numpyro.sample("Xstd", dist.Normal(0.0, 0.3), sample_shape=(Xmu.shape[0],))
    X = Xmu + Xstd

    # compute kernel on the latent inputs
    k = kernel(X, X, var, length, noise)

    # sample Y according to the standard Gaussian process formula
    numpyro.sample(
        "Y",
        dist.MultivariateNormal(loc=jnp.zeros(X.shape[0]), covariance_matrix=k),
        obs=Y,
    )
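
For reference (not in the original notebook), NumPyro can derive this non-centered form automatically from the earlier centered emodel, i.e. the version that samples site "X" directly, using numpyro.infer.reparam.LocScaleReparam:

from numpyro.handlers import reparam
from numpyro.infer.reparam import LocScaleReparam

# apply to the centered model above (the emodel that samples site "X" directly);
# centered=0.0 requests the fully non-centered parameterization
noncentered_model = reparam(emodel, config={"X": LocScaleReparam(centered=0.0)})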
rng_key, rng_key_predict = random.split(random.PRNGKey(0))

# Run inference scheme
samples = run_inference(emodel, args, rng_key, X, y)
sample: 100%|██████████| 1100/1100 [00:17<00:00, 62.81it/s, 15 steps of size 2.65e-01. acc. prob=0.89] 
                     mean       std    median      5.0%     95.0%     n_eff     r_hat
        Xstd[0]      0.20      0.26      0.22     -0.23      0.62    792.72      1.00
        Xstd[1]     -0.15      0.26     -0.15     -0.63      0.22    951.20      1.00
        Xstd[2]     -0.09      0.26     -0.11     -0.50      0.37    832.81      1.00
        Xstd[3]      0.10      0.24      0.10     -0.28      0.49    982.06      1.00
        Xstd[4]     -0.25      0.22     -0.26     -0.63      0.09    934.13      1.00
        Xstd[5]      0.09      0.19      0.11     -0.19      0.42    531.47      1.00
        Xstd[6]      0.13      0.18      0.15     -0.18      0.43    452.94      1.00
        Xstd[7]     -0.28      0.13     -0.28     -0.49     -0.06    495.33      1.00
        Xstd[8]      0.14      0.13      0.15     -0.05      0.35    365.29      1.00
        Xstd[9]     -0.10      0.12     -0.10     -0.30      0.09    306.55      1.00
       Xstd[10]     -0.21      0.12     -0.20     -0.42     -0.02    304.69      1.00
       Xstd[11]     -0.05      0.13     -0.05     -0.26      0.18    284.53      1.00
       Xstd[12]     -0.21      0.17     -0.21     -0.49      0.05    415.29      1.00
       Xstd[13]      0.52      0.19      0.50      0.20      0.81    540.49      1.00
       Xstd[14]      0.12      0.24      0.11     -0.29      0.47    898.49      1.01
       Xstd[15]      0.15      0.23      0.14     -0.23      0.52   1165.26      1.00
       Xstd[16]     -0.08      0.33     -0.10     -0.60      0.44    904.75      1.00
       Xstd[17]     -0.00      0.29     -0.00     -0.53      0.43   1652.78      1.00
       Xstd[18]     -0.01      0.32      0.00     -0.49      0.54   1462.54      1.00
       Xstd[19]     -0.02      0.28     -0.01     -0.47      0.47    903.68      1.00
       Xstd[20]      0.14      0.36      0.17     -0.47      0.66    905.66      1.00
       Xstd[21]      0.05      0.36      0.06     -0.54      0.61    648.84      1.00
       Xstd[22]      0.04      0.30      0.07     -0.44      0.51   1011.25      1.00
       Xstd[23]     -0.01      0.27      0.00     -0.43      0.45   1237.85      1.00
       Xstd[24]     -0.20      0.16     -0.20     -0.45      0.06    419.10      1.00
       Xstd[25]     -0.69      0.16     -0.68     -0.98     -0.46    379.36      1.00
       Xstd[26]     -0.33      0.13     -0.33     -0.54     -0.12    320.10      1.00
       Xstd[27]      0.09      0.13      0.09     -0.10      0.30    245.93      1.01
       Xstd[28]      0.50      0.13      0.50      0.30      0.71    253.37      1.01
       Xstd[29]     -0.04      0.13     -0.04     -0.24      0.18    259.11      1.01
       Xstd[30]      0.36      0.14      0.36      0.13      0.57    296.81      1.01
       Xstd[31]      0.07      0.19      0.06     -0.23      0.37    539.63      1.00
       Xstd[32]      0.18      0.21      0.17     -0.18      0.50    868.61      1.00
       Xstd[33]     -0.09      0.31     -0.14     -0.54      0.45    551.52      1.00
       Xstd[34]      0.04      0.27      0.04     -0.35      0.53   1343.46      1.00
       Xstd[35]     -0.01      0.29     -0.01     -0.48      0.51   1573.42      1.00
       Xstd[36]     -0.04      0.29     -0.04     -0.51      0.44   1578.87      1.00
       Xstd[37]      0.02      0.31      0.03     -0.48      0.53   2398.18      1.00
       Xstd[38]     -0.00      0.29     -0.00     -0.45      0.47   1411.13      1.00
       Xstd[39]     -0.01      0.30     -0.01     -0.49      0.48   2119.89      1.00
       Xstd[40]     -0.11      0.25     -0.10     -0.49      0.33    537.76      1.00
       Xstd[41]      0.00      0.30      0.01     -0.48      0.51    934.64      1.00
       Xstd[42]      0.09      0.32      0.12     -0.49      0.55   1000.19      1.00
       Xstd[43]     -0.61      0.20     -0.60     -0.92     -0.28    716.21      1.00
       Xstd[44]      0.31      0.15      0.32      0.08      0.57    487.44      1.00
       Xstd[45]     -0.49      0.12     -0.48     -0.68     -0.28    426.20      1.00
       Xstd[46]      0.30      0.12      0.30      0.11      0.49    383.15      1.00
       Xstd[47]      0.34      0.12      0.34      0.15      0.53    329.32      1.00
       Xstd[48]     -0.21      0.13     -0.22     -0.41     -0.01    383.74      1.00
       Xstd[49]     -0.12      0.15     -0.13     -0.37      0.10    392.93      1.00
       Xstd[50]      0.03      0.25      0.01     -0.37      0.40    668.25      1.00
       Xstd[51]      0.05      0.25      0.03     -0.37      0.46    928.56      1.00
       Xstd[52]      0.26      0.22      0.24     -0.12      0.60    776.53      1.00
       Xstd[53]     -0.14      0.33     -0.17     -0.62      0.48    672.39      1.00
       Xstd[54]      0.06      0.32      0.08     -0.49      0.54   1436.10      1.00
       Xstd[55]      0.06      0.26      0.05     -0.35      0.49   1649.02      1.00
       Xstd[56]     -0.02      0.29     -0.02     -0.47      0.44   1683.10      1.00
       Xstd[57]     -0.01      0.27     -0.01     -0.42      0.48   1384.29      1.00
       Xstd[58]      0.05      0.28      0.06     -0.37      0.54   1061.70      1.00
       Xstd[59]     -0.06      0.26     -0.08     -0.46      0.39   1705.82      1.00
  kernel_length      1.93      0.21      1.94      1.57      2.26    224.70      1.01
   kernel_noise      0.00      0.00      0.00      0.00      0.01    262.87      1.00
     kernel_var      1.15      0.62      1.00      0.41      1.95    344.59      1.00

Number of divergences: 0

MCMC elapsed time: 19.7586088180542
# do prediction
vmap_args = (
    random.split(rng_key_predict, args.num_samples * args.num_chains),
    samples["kernel_var"],
    samples["kernel_length"],
    samples["kernel_noise"],
)
means, predictions = vmap(
    lambda rng_key, var, length, noise: predict(
        rng_key, X, y, X_test, var, length, noise
    )
)(*vmap_args)

mean_prediction = onp.mean(means, axis=0)
percentiles = onp.percentile(predictions, [5.0, 95.0], axis=0)
# make plots
fig, ax = plt.subplots(1, 1)

# plot training data
ax.plot(X, y, "kx")
# plot 90% confidence level of predictions
ax.fill_between(X_test, percentiles[0, :], percentiles[1, :], color="lightblue")
# plot mean prediction
ax.plot(X_test, mean_prediction, "blue", ls="solid", lw=2.0)
ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")
plt.tight_layout()