
MCMC eGP

TLDR

A quick experiment on how the error bars of a fully Bayesian GP (i.e. a GP fit with MCMC) change with the prior we put on the inputs. I compare three cases: no prior on the inputs, a modest prior on the inputs, and the exact, known prior on the inputs. The results differ from what I'm used to because here the GP is trained knowing the input priors: the error bars shrink, which makes sense in hindsight.

TODO: do the MCMC where we approximate the posterior after training the GP with uncertain inputs.
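To fix notation (mine, not in the original; read off `get_data` below), the data are generated as

$$
y_i = f(x_i) + \varepsilon_i, \quad \varepsilon_i \sim \mathcal{N}(0, \sigma_y^2), \qquad
\tilde{x}_i = x_i + \delta_i, \quad \delta_i \sim \mathcal{N}(0, \sigma_x^2),
$$

with $f \sim \mathcal{GP}(0, k)$, and only the noisy inputs $\tilde{x}_i$ observed. The three cases differ in how the model treats the inputs: fixed at $\tilde{x}_i$, or given a Gaussian prior centered at $\tilde{x}_i$.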

*(figures: posterior predictive plots for the three cases; see below)*


Experiment

Setup
!pip install jax jaxlib numpyro
Requirement already satisfied: jax in /usr/local/lib/python3.6/dist-packages (0.1.62)
Requirement already satisfied: jaxlib in /usr/local/lib/python3.6/dist-packages (0.1.42)
Collecting numpyro
ERROR: numpyro 0.2.4 has requirement jax==0.1.57, but you'll have jax 0.1.62 which is incompatible.
ERROR: numpyro 0.2.4 has requirement jaxlib==0.1.37, but you'll have jaxlib 0.1.42 which is incompatible.
Installing collected packages: numpyro
Successfully installed numpyro-0.2.4
#@title packages

import time

import numpy as onp
from dataclasses import dataclass

import jax
from jax import vmap
import jax.numpy as np
import jax.random as random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

sns.reset_defaults()
sns.set_context(context='talk', font_scale=0.7)
%matplotlib inline
#@title Data

def get_data(N=30, sigma_inputs=0.15, sigma_obs=0.15, N_test=400):
    onp.random.seed(0)
    X = np.linspace(-10, 10, N)
    # Y = X + 0.2 * np.power(X, 3.0) + 0.5 * np.power(0.5 + X, 2.0) * np.sin(4.0 * X)
    Y = np.sin(1.0 * np.pi / 1.6 * np.cos(5 + .5 * X))
    Y += sigma_obs * onp.random.randn(N)     # observation noise on the outputs
    X += sigma_inputs * onp.random.randn(N)  # input noise: we only observe the jittered X
    Y -= np.mean(Y)
    Y /= np.std(Y)

    assert X.shape == (N,)
    assert Y.shape == (N,)

    X_test = np.linspace(-11, 11, N_test) 
    X_test += sigma_inputs * onp.random.randn(N_test)

    return X, Y, X_test
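Not in the original notebook, but a quick plot of one draw makes the errors-in-variables setup concrete (both the inputs and the outputs are jittered):

```python
# sanity check (my addition): visualize one draw of the noisy data
Xd, Yd, _ = get_data()
plt.plot(Xd, Yd, 'kx')
plt.title("Training data: noise on both X and Y")
plt.show()
```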

GP Model

#@title GP Model

# squared exponential kernel with diagonal noise term
def kernel(X, Z, var, length, noise, jitter=1.0e-6, include_noise=True):
    deltaXsq = np.power((X[:, None] - Z) / length, 2.0)
    k = var * np.exp(-0.5 * deltaXsq)
    if include_noise:
        k += (noise + jitter) * np.eye(X.shape[0])
    return k

def model(Xmu, Y):
    # set uninformative log-normal priors on our three kernel hyperparameters
    var = numpyro.sample("kernel_var", dist.LogNormal(0.0, 10.0))
    noise = numpyro.sample("kernel_noise", dist.LogNormal(0.0, 10.0))
    length = numpyro.sample("kernel_length", dist.LogNormal(0.0, 10.0))
    X = Xmu  # no prior on the inputs: condition on the observed X as-is

    # compute kernel
    k = kernel(X, X, var, length, noise)

    # sample Y according to the standard gaussian process formula
    numpyro.sample("Y", dist.MultivariateNormal(loc=np.zeros(X.shape[0]), covariance_matrix=k),
                   obs=Y)
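In symbols, `kernel` above is the squared exponential kernel with an optional diagonal noise term,

$$
k(x, z) = \sigma_f^2 \exp\!\left(-\frac{(x - z)^2}{2\ell^2}\right) + \left(\sigma_n^2 + \text{jitter}\right)\delta_{xz},
$$

with `var` $= \sigma_f^2$, `length` $= \ell$, and `noise` $= \sigma_n^2$; note that `noise` enters as a variance, not a standard deviation. `model` then puts broad log-normal priors on all three hyperparameters and scores $Y \sim \mathcal{N}(0, K_{XX})$.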
# helper function for doing hmc inference
def run_inference(model, args, rng_key, X, Y):
    start = time.time()
    nuts_kernel = NUTS(model)  # avoid shadowing the GP `kernel` function above
    mcmc = MCMC(nuts_kernel, args.num_warmup, args.num_samples,
                num_chains=args.num_chains, progress_bar=True)
    mcmc.run(rng_key, X, Y)
    mcmc.print_summary()
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc.get_samples()


# do GP prediction for a given set of hyperparameters. this makes use of the well-known
# formula for gaussian process predictions
def predict(rng_key, X, Y, X_test, var, length, noise):
    # compute kernels between train and test data, etc.
    k_pp = kernel(X_test, X_test, var, length, noise, include_noise=True)
    k_pX = kernel(X_test, X, var, length, noise, include_noise=False)
    k_XX = kernel(X, X, var, length, noise, include_noise=True)
    K_xx_inv = np.linalg.inv(k_XX)
    K = k_pp - np.matmul(k_pX, np.matmul(K_xx_inv, np.transpose(k_pX)))
    sigma_noise = np.sqrt(np.clip(np.diag(K), a_min=0.)) * jax.random.normal(rng_key, X_test.shape[:1])
    mean = np.matmul(k_pX, np.matmul(K_xx_inv, Y))
    # we return both the mean function and a sample from the posterior predictive for the
    # given set of hyperparameters
    return mean, mean + sigma_noise
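The "well-known formula" is just the Gaussian conditional: with $K_{XX}$ including the noise term,

$$
\mu_* = K_{*X} K_{XX}^{-1} y, \qquad
\Sigma_* = K_{**} - K_{*X} K_{XX}^{-1} K_{X*},
$$

and `predict` returns $\mu_*$ together with one draw from $\mathcal{N}\big(\mu_*, \operatorname{diag}(\Sigma_*)\big)$ as a posterior-predictive sample. Forming $K_{XX}^{-1}$ explicitly is fine at this problem size, but a Cholesky solve is the usual, more stable route; a minimal sketch (my substitution, not in the original notebook):

```python
from jax.scipy.linalg import cho_factor, cho_solve

def predict_chol(rng_key, X, Y, X_test, var, length, noise):
    # same quantities as `predict`, but solved via a Cholesky
    # factorization of k_XX instead of forming its inverse
    k_pp = kernel(X_test, X_test, var, length, noise, include_noise=True)
    k_pX = kernel(X_test, X, var, length, noise, include_noise=False)
    k_XX = kernel(X, X, var, length, noise, include_noise=True)
    cf = cho_factor(k_XX)
    mean = np.matmul(k_pX, cho_solve(cf, Y))
    K = k_pp - np.matmul(k_pX, cho_solve(cf, np.transpose(k_pX)))
    sigma_noise = np.sqrt(np.clip(np.diag(K), a_min=0.)) * jax.random.normal(rng_key, X_test.shape[:1])
    return mean, mean + sigma_noise
```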

Inference

@dataclass
class args:
    num_data: int = 60
    num_warmup: int = 100
    num_chains: int = 1
    num_samples: int = 1_000
    device: str = 'cpu'
    sigma_inputs: float = 0.3
    sigma_obs: float = 0.05

numpyro.set_platform(args.device)
X, Y, X_test = get_data(args.num_data, sigma_inputs=args.sigma_inputs, sigma_obs=args.sigma_obs)
rng_key, rng_key_predict = random.split(random.PRNGKey(0))

# Run inference scheme
samples = run_inference(model, args, rng_key, X, Y, )
sample: 100%|██████████| 1100/1100 [00:11<00:00, 96.81it/s, 7 steps of size 6.48e-01. acc. prob=0.94]



                     mean       std    median      5.0%     95.0%     n_eff     r_hat
  kernel_length      1.97      0.23      1.97      1.58      2.34    650.87      1.00
   kernel_noise      0.04      0.01      0.04      0.02      0.05    637.46      1.00
     kernel_var      1.15      0.65      0.98      0.34      1.97    563.69      1.00

Number of divergences: 0

MCMC elapsed time: 14.073462963104248

Predictions

# do prediction
vmap_args = (random.split(rng_key_predict, args.num_samples * args.num_chains), samples['kernel_var'],
                samples['kernel_length'], samples['kernel_noise'])
means, predictions = vmap(lambda rng_key, var, length, noise:
                            predict(rng_key, X, Y, X_test, var, length, noise))(*vmap_args)

mean_prediction = onp.mean(means, axis=0)
percentiles = onp.percentile(predictions, [5.0, 95.0], axis=0)
# make plots
fig, ax = plt.subplots(1, 1)

# plot training data
ax.plot(X, Y, 'kx')
# plot 90% confidence level of predictions
ax.fill_between(X_test, percentiles[0, :], percentiles[1, :], color='lightblue')
# plot mean prediction
ax.plot(X_test, mean_prediction, 'blue', ls='solid', lw=2.0)
ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")

*(figure: mean predictions with 90% CI, no input prior)*

GP Model - Uncertain Inputs

def emodel(Xmu, Y):
    # set uninformative log-normal priors on our three kernel hyperparameters
    var = numpyro.sample("kernel_var", dist.LogNormal(0.0, 10.0))
    noise = numpyro.sample("kernel_noise", dist.LogNormal(0.0, 10.0))
    length = numpyro.sample("kernel_length", dist.LogNormal(0.0, 10.0))
    # Gaussian prior on the inputs; 0.3 matches the sigma_inputs used to generate the data
    X = numpyro.sample("X", dist.Normal(Xmu, 0.3))


    # compute kernel
    k = kernel(X, X, var, length, noise)

    # sample Y according to the standard gaussian process formula
    numpyro.sample("Y", dist.MultivariateNormal(loc=np.zeros(X.shape[0]), covariance_matrix=k),
                   obs=Y)
rng_key, rng_key_predict = random.split(random.PRNGKey(0))

# Run inference scheme
samples = run_inference(emodel, args, rng_key, X, Y, )
sample: 100%|██████████| 1100/1100 [00:19<00:00, 56.73it/s, 15 steps of size 2.14e-01. acc. prob=0.93]



                     mean       std    median      5.0%     95.0%     n_eff     r_hat
           X[0]    -10.01      0.25     -9.98    -10.40     -9.59    597.56      1.00
           X[1]     -9.91      0.27     -9.92    -10.30     -9.45    887.43      1.00
           X[2]     -9.65      0.27     -9.67    -10.03     -9.16    475.77      1.00
           X[3]     -9.40      0.25     -9.41     -9.85     -9.04    937.07      1.00
           X[4]     -8.83      0.23     -8.83     -9.27     -8.51    759.05      1.00
           X[5]     -8.33      0.19     -8.31     -8.61     -8.01    463.77      1.00
           X[6]     -8.31      0.19     -8.30     -8.60     -8.01    556.40      1.00
           X[7]     -7.77      0.13     -7.77     -7.98     -7.57    554.29      1.00
           X[8]     -7.42      0.12     -7.42     -7.62     -7.23    426.67      1.00
           X[9]     -7.03      0.11     -7.03     -7.21     -6.84    363.45      1.00
          X[10]     -6.60      0.12     -6.61     -6.79     -6.41    370.03      1.00
          X[11]     -6.29      0.13     -6.29     -6.51     -6.10    423.12      1.00
          X[12]     -5.80      0.16     -5.80     -6.05     -5.54    461.56      1.00
          X[13]     -5.45      0.18     -5.46     -5.74     -5.15    665.71      1.00
          X[14]     -5.00      0.25     -5.02     -5.41     -4.61   1038.74      1.00
          X[15]     -4.97      0.23     -4.98     -5.32     -4.59    782.89      1.00
          X[16]     -4.92      0.32     -4.93     -5.42     -4.40   1370.20      1.00
          X[17]     -4.41      0.29     -4.41     -4.87     -3.91   1293.31      1.00
          X[18]     -4.00      0.31     -3.99     -4.49     -3.50   1020.15      1.00
          X[19]     -3.57      0.28     -3.55     -4.06     -3.14    686.66      1.00
          X[20]     -3.42      0.36     -3.40     -3.96     -2.79    798.33      1.00
          X[21]     -2.57      0.36     -2.57     -3.12     -1.96    760.66      1.00
          X[22]     -2.34      0.28     -2.32     -2.82     -1.92    800.18      1.00
          X[23]     -2.68      0.27     -2.66     -3.07     -2.18    779.41      1.00
          X[24]     -1.61      0.16     -1.61     -1.89     -1.38    410.05      1.01
          X[25]     -1.65      0.16     -1.65     -1.91     -1.38    460.06      1.00
          X[26]     -1.16      0.13     -1.16     -1.36     -0.96    352.57      1.01
          X[27]     -0.81      0.12     -0.80     -0.98     -0.59    363.98      1.01
          X[28]     -0.32      0.12     -0.33     -0.51     -0.13    388.07      1.01
          X[29]      0.11      0.12      0.10     -0.10      0.29    393.42      1.02
          X[30]      0.41      0.13      0.40      0.18      0.62    534.79      1.01
          X[31]      0.93      0.19      0.92      0.63      1.24    897.67      1.00
          X[32]      1.08      0.21      1.07      0.75      1.42    398.94      1.00
          X[33]      1.39      0.33      1.36      0.87      1.90    443.17      1.00
          X[34]      1.67      0.28      1.67      1.25      2.13    822.54      1.00
          X[35]      2.07      0.30      2.07      1.60      2.55    735.84      1.00
          X[36]      2.18      0.30      2.18      1.69      2.65    546.85      1.00
          X[37]      3.10      0.31      3.10      2.63      3.64   1153.64      1.00
          X[38]      2.92      0.28      2.93      2.46      3.35   1498.27      1.00
          X[39]      3.33      0.30      3.34      2.82      3.79   1393.60      1.00
          X[40]      4.02      0.26      4.03      3.64      4.48    743.00      1.00
          X[41]      3.50      0.31      3.50      3.01      3.99    895.25      1.00
          X[42]      3.93      0.33      3.96      3.41      4.47    667.17      1.00
          X[43]      4.27      0.19      4.27      3.96      4.60    626.95      1.00
          X[44]      4.89      0.16      4.88      4.62      5.14    340.23      1.00
          X[45]      5.36      0.13      5.36      5.17      5.58    373.83      1.00
          X[46]      5.78      0.12      5.78      5.58      5.96    307.39      1.00
          X[47]      6.05      0.12      6.06      5.86      6.24    329.60      1.00
          X[48]      6.65      0.13      6.64      6.45      6.86    343.38      1.00
          X[49]      6.94      0.15      6.94      6.68      7.18    385.53      1.00
          X[50]      7.56      0.26      7.53      7.18      8.04    551.86      1.00
          X[51]      7.61      0.26      7.59      7.17      8.03    524.45      1.00
          X[52]      7.63      0.21      7.63      7.25      7.95   1001.54      1.00
          X[53]      8.40      0.30      8.38      7.91      8.89    576.53      1.00
          X[54]      8.28      0.33      8.32      7.69      8.78   1210.06      1.00
          X[55]      8.95      0.26      8.95      8.50      9.34   1332.43      1.00
          X[56]      9.25      0.27      9.25      8.77      9.65   1931.84      1.00
          X[57]      9.27      0.27      9.28      8.82      9.69   1255.06      1.00
          X[58]      9.89      0.28      9.91      9.46     10.36    613.94      1.00
          X[59]     10.22      0.28     10.20      9.78     10.66    925.47      1.00
  kernel_length      1.95      0.19      1.96      1.63      2.23    440.62      1.01
   kernel_noise      0.00      0.00      0.00      0.00      0.01    222.75      1.00
     kernel_var      1.18      0.64      1.02      0.39      2.01    423.49      1.01

Number of divergences: 0

MCMC elapsed time: 23.19466996192932
# do prediction (note: still conditioning on the observed X, not the posterior X samples)
vmap_args = (random.split(rng_key_predict, args.num_samples * args.num_chains), samples['kernel_var'],
                samples['kernel_length'], samples['kernel_noise'])
means, predictions = vmap(lambda rng_key, var, length, noise:
                            predict(rng_key, X, Y, X_test, var, length, noise))(*vmap_args)

mean_prediction = onp.mean(means, axis=0)
percentiles = onp.percentile(predictions, [5.0, 95.0], axis=0)
# make plots
fig, ax = plt.subplots(1, 1)

# plot training data
ax.plot(X, Y, 'kx')
# plot 90% confidence level of predictions
ax.fill_between(X_test, percentiles[0, :], percentiles[1, :], color='lightblue')
# plot mean prediction
ax.plot(X_test, mean_prediction, 'blue', ls='solid', lw=2.0)
ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")

*(figure: mean predictions with 90% CI, Gaussian prior on the inputs)*

def emodel(Xmu, Y):
    # set uninformative log-normal priors on our three kernel hyperparameters
    var = numpyro.sample("kernel_var", dist.LogNormal(0.0, 10.0))
    noise = numpyro.sample("kernel_noise", dist.LogNormal(0.0, 10.0))
    length = numpyro.sample("kernel_length", dist.LogNormal(0.0, 10.0))
    # same Gaussian prior as above, but non-centered: X = Xmu + delta, delta ~ N(0, 0.3^2)
    Xstd = numpyro.sample("Xstd", dist.Normal(0.0, 0.3), sample_shape=(Xmu.shape[0],))
    X = Xmu + Xstd


    # compute kernel
    k = kernel(X, X, var, length, noise)

    # sample Y according to the standard gaussian process formula
    numpyro.sample("Y", dist.MultivariateNormal(loc=np.zeros(X.shape[0]), covariance_matrix=k),
                   obs=Y)
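As written, this third model encodes the same prior as the second one, just in a non-centered parameterization, which can help NUTS when the latent inputs are weakly identified:

$$
X = X_\mu + \delta, \quad \delta \sim \mathcal{N}(0,\, 0.3^2)
\quad\Longleftrightarrow\quad
X \sim \mathcal{N}(X_\mu,\, 0.3^2).
$$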
rng_key, rng_key_predict = random.split(random.PRNGKey(0))

# Run inference scheme
samples = run_inference(emodel, args, rng_key, X, Y, )
sample: 100%|██████████| 1100/1100 [00:17<00:00, 62.81it/s, 15 steps of size 2.65e-01. acc. prob=0.89]



                     mean       std    median      5.0%     95.0%     n_eff     r_hat
        Xstd[0]      0.20      0.26      0.22     -0.23      0.62    792.72      1.00
        Xstd[1]     -0.15      0.26     -0.15     -0.63      0.22    951.20      1.00
        Xstd[2]     -0.09      0.26     -0.11     -0.50      0.37    832.81      1.00
        Xstd[3]      0.10      0.24      0.10     -0.28      0.49    982.06      1.00
        Xstd[4]     -0.25      0.22     -0.26     -0.63      0.09    934.13      1.00
        Xstd[5]      0.09      0.19      0.11     -0.19      0.42    531.47      1.00
        Xstd[6]      0.13      0.18      0.15     -0.18      0.43    452.94      1.00
        Xstd[7]     -0.28      0.13     -0.28     -0.49     -0.06    495.33      1.00
        Xstd[8]      0.14      0.13      0.15     -0.05      0.35    365.29      1.00
        Xstd[9]     -0.10      0.12     -0.10     -0.30      0.09    306.55      1.00
       Xstd[10]     -0.21      0.12     -0.20     -0.42     -0.02    304.69      1.00
       Xstd[11]     -0.05      0.13     -0.05     -0.26      0.18    284.53      1.00
       Xstd[12]     -0.21      0.17     -0.21     -0.49      0.05    415.29      1.00
       Xstd[13]      0.52      0.19      0.50      0.20      0.81    540.49      1.00
       Xstd[14]      0.12      0.24      0.11     -0.29      0.47    898.49      1.01
       Xstd[15]      0.15      0.23      0.14     -0.23      0.52   1165.26      1.00
       Xstd[16]     -0.08      0.33     -0.10     -0.60      0.44    904.75      1.00
       Xstd[17]     -0.00      0.29     -0.00     -0.53      0.43   1652.78      1.00
       Xstd[18]     -0.01      0.32      0.00     -0.49      0.54   1462.54      1.00
       Xstd[19]     -0.02      0.28     -0.01     -0.47      0.47    903.68      1.00
       Xstd[20]      0.14      0.36      0.17     -0.47      0.66    905.66      1.00
       Xstd[21]      0.05      0.36      0.06     -0.54      0.61    648.84      1.00
       Xstd[22]      0.04      0.30      0.07     -0.44      0.51   1011.25      1.00
       Xstd[23]     -0.01      0.27      0.00     -0.43      0.45   1237.85      1.00
       Xstd[24]     -0.20      0.16     -0.20     -0.45      0.06    419.10      1.00
       Xstd[25]     -0.69      0.16     -0.68     -0.98     -0.46    379.36      1.00
       Xstd[26]     -0.33      0.13     -0.33     -0.54     -0.12    320.10      1.00
       Xstd[27]      0.09      0.13      0.09     -0.10      0.30    245.93      1.01
       Xstd[28]      0.50      0.13      0.50      0.30      0.71    253.37      1.01
       Xstd[29]     -0.04      0.13     -0.04     -0.24      0.18    259.11      1.01
       Xstd[30]      0.36      0.14      0.36      0.13      0.57    296.81      1.01
       Xstd[31]      0.07      0.19      0.06     -0.23      0.37    539.63      1.00
       Xstd[32]      0.18      0.21      0.17     -0.18      0.50    868.61      1.00
       Xstd[33]     -0.09      0.31     -0.14     -0.54      0.45    551.52      1.00
       Xstd[34]      0.04      0.27      0.04     -0.35      0.53   1343.46      1.00
       Xstd[35]     -0.01      0.29     -0.01     -0.48      0.51   1573.42      1.00
       Xstd[36]     -0.04      0.29     -0.04     -0.51      0.44   1578.87      1.00
       Xstd[37]      0.02      0.31      0.03     -0.48      0.53   2398.18      1.00
       Xstd[38]     -0.00      0.29     -0.00     -0.45      0.47   1411.13      1.00
       Xstd[39]     -0.01      0.30     -0.01     -0.49      0.48   2119.89      1.00
       Xstd[40]     -0.11      0.25     -0.10     -0.49      0.33    537.76      1.00
       Xstd[41]      0.00      0.30      0.01     -0.48      0.51    934.64      1.00
       Xstd[42]      0.09      0.32      0.12     -0.49      0.55   1000.19      1.00
       Xstd[43]     -0.61      0.20     -0.60     -0.92     -0.28    716.21      1.00
       Xstd[44]      0.31      0.15      0.32      0.08      0.57    487.44      1.00
       Xstd[45]     -0.49      0.12     -0.48     -0.68     -0.28    426.20      1.00
       Xstd[46]      0.30      0.12      0.30      0.11      0.49    383.15      1.00
       Xstd[47]      0.34      0.12      0.34      0.15      0.53    329.32      1.00
       Xstd[48]     -0.21      0.13     -0.22     -0.41     -0.01    383.74      1.00
       Xstd[49]     -0.12      0.15     -0.13     -0.37      0.10    392.93      1.00
       Xstd[50]      0.03      0.25      0.01     -0.37      0.40    668.25      1.00
       Xstd[51]      0.05      0.25      0.03     -0.37      0.46    928.56      1.00
       Xstd[52]      0.26      0.22      0.24     -0.12      0.60    776.53      1.00
       Xstd[53]     -0.14      0.33     -0.17     -0.62      0.48    672.39      1.00
       Xstd[54]      0.06      0.32      0.08     -0.49      0.54   1436.10      1.00
       Xstd[55]      0.06      0.26      0.05     -0.35      0.49   1649.02      1.00
       Xstd[56]     -0.02      0.29     -0.02     -0.47      0.44   1683.10      1.00
       Xstd[57]     -0.01      0.27     -0.01     -0.42      0.48   1384.29      1.00
       Xstd[58]      0.05      0.28      0.06     -0.37      0.54   1061.70      1.00
       Xstd[59]     -0.06      0.26     -0.08     -0.46      0.39   1705.82      1.00
  kernel_length      1.93      0.21      1.94      1.57      2.26    224.70      1.01
   kernel_noise      0.00      0.00      0.00      0.00      0.01    262.87      1.00
     kernel_var      1.15      0.62      1.00      0.41      1.95    344.59      1.00

Number of divergences: 0

MCMC elapsed time: 19.7586088180542
# do prediction (again conditioning on the observed X, not the posterior X samples)
vmap_args = (random.split(rng_key_predict, args.num_samples * args.num_chains), samples['kernel_var'],
                samples['kernel_length'], samples['kernel_noise'])
means, predictions = vmap(lambda rng_key, var, length, noise:
                            predict(rng_key, X, Y, X_test, var, length, noise))(*vmap_args)

mean_prediction = onp.mean(means, axis=0)
percentiles = onp.percentile(predictions, [5.0, 95.0], axis=0)
# make plots
fig, ax = plt.subplots(1, 1)

# plot training data
ax.plot(X, Y, 'kx')
# plot 90% confidence level of predictions
ax.fill_between(X_test, percentiles[0, :], percentiles[1, :], color='lightblue')
# plot mean prediction
ax.plot(X_test, mean_prediction, 'blue', ls='solid', lw=2.0)
ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")

*(figure: mean predictions with 90% CI, non-centered input prior)*

Results

*(figures: the three posterior predictive plots, one per case)*