GP from Scratch
This post goes through how to build a GP regression model from scratch, covering both the formulation and the code. I did this a long time ago, but I've learned a lot about GPs since then, so I'm putting all of that knowledge together to get an implementation that runs parallel to the theory. I'm also interested in furthering my research on uncertain GPs, where I look at how to handle input errors in GPs.
Materials
The full code can be found in the colab notebook. Later I will refactor everything into a script so I can use it in the future.
Good News
It took me approximately 12 hours in total to code this up from scratch. That's significantly better than last time, which easily took me a week and some change, and I still had problems with the code afterwards. That's progress, no?
Resources
Quite a few tutorials inspired me to write this one.
- Blog Post -
Excellent blog post that goes over GPs step by step. Necessary equations only.
- Blog Post Series - Peter Roelants
Good blog post series that goes through the finer details of GPs using TensorFlow.
Definition
The first thing to understand about GPs is that we are actively placing a distribution \mathcal{P}(f) over functions f, where a function can be thought of as an infinitely long vector of function values f=[f_1, f_2, \ldots]. A GP generalizes the multivariate Gaussian distribution to infinitely many variables.
A GP is a collection of random variables f_1, f_2, \ldots, any finite number of which have a joint Gaussian distribution.
A GP defines a distribution over functions p(f), which can be used for Bayesian regression. (Zoubin)
Another nice definition is:
Gaussian Process: Any set of function variables \{f_n \}^{N}_{n=1} has a joint Gaussian distribution with mean function m. (Deisenroth)
The nice thing is that this distribution is fully specified by a mean function \mu and a covariance (kernel) function, which gives the covariance matrix \mathbf{K} for any finite set of inputs.
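To make the finite-dimensional view concrete, here is a small NumPy sketch (NumPy standing in for the JAX used elsewhere in this post). Evaluating a zero mean function and an RBF kernel at five inputs gives an ordinary mean vector and covariance matrix, and that matrix is symmetric positive semi-definite, as any covariance must be. The input grid and gamma=1.0 are arbitrary choices for illustration.

```python
import numpy as np

def rbf(x1, x2, gamma=1.0):
    # k(x, x') = exp(-gamma * ||x - x'||^2), same parameterization as rbf_kernel in this post
    return np.exp(-gamma * np.sum((x1 - x2) ** 2))

def zero_mean(X):
    # m(x) = 0 for every input row
    return np.zeros(X.shape[0])

# Any finite set of inputs: N=5 points in D=1 dimension
X = np.linspace(-2.0, 2.0, 5).reshape(-1, 1)

mu = zero_mean(X)                                       # mean vector, shape (5,)
K = np.array([[rbf(xi, xj) for xj in X] for xi in X])  # covariance matrix, shape (5, 5)

# A valid covariance matrix is symmetric positive semi-definite
eigvals = np.linalg.eigvalsh(K)
print(mu.shape, K.shape, eigvals.min() >= -1e-10)
```

So "a GP at N inputs" is nothing more exotic than an N-dimensional Gaussian with mean vector m(x_i) and covariance entries k(x_i, x_j).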
Bayesian Inference Problem
Objective
Suppose we have a dataset \mathcal{D}= \left\{ (x_i, y_i) \right\}^N_{i=1}=(X,y).
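The training code later calls a get_data helper that isn't shown in this post. As a hypothetical stand-in, the sketch below builds a toy 1D dataset in the same (X, y) layout, with X of shape (N, D) and y of shape (N, 1). The name make_toy_data, the sine target, the noise level and the input ranges are all assumptions.

```python
import numpy as np

def make_toy_data(n_train=50, n_test=100, noise_std=0.1, seed=0):
    """Hypothetical helper: noisy samples of sin(x) in the (X, y) layout used in this post."""
    rng = np.random.default_rng(seed)
    X = np.sort(rng.uniform(-3.0, 3.0, size=(n_train, 1)), axis=0)       # inputs, shape (N, 1)
    y = np.sin(X) + noise_std * rng.standard_normal(size=(n_train, 1))  # noisy targets, shape (N, 1)
    Xtest = np.linspace(-4.0, 4.0, n_test).reshape(-1, 1)               # test inputs, shape (M, 1)
    ytest = np.sin(Xtest)                                                # noise-free test targets
    return X, y, Xtest, ytest

X, y, Xtest, ytest = make_toy_data()
print(X.shape, y.shape, Xtest.shape)  # (50, 1) (50, 1) (100, 1)
```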
Model¶
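The prior on f is a GP distribution and the likelihood is Gaussian, therefore the posterior on f is also a GP:
p(f|\mathcal{D}) = \frac{p(\mathcal{D}|f)\,p(f)}{p(\mathcal{D})}
So we can make predictions:
p(y_*|x_*, \mathcal{D}) = \int p(y_*|x_*, f)\,p(f|\mathcal{D})\,df
We can also do model comparison by way of the marginal likelihood (evidence), which lets us compare and tune the covariance functions:
p(\mathcal{D}) = \int p(\mathcal{D}|f)\,p(f)\,df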
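Bayesian Treatment
So how does this look in terms of Bayes' theorem? In words:
\text{posterior} = \frac{\text{likelihood} \times \text{prior}}{\text{marginal likelihood}}
And mathematically:
p(f|X,y) = \frac{p(y|f,X)\,p(f|X,\theta)}{p(y|X)}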
where:
- Prior: p(f|X, \theta)=\mathcal{GP}(m_\theta, \mathbf{K}_\theta)
- Likelihood (noise model): p(y|f,X)=\mathcal{N}(y|f(x), \sigma_n^2\mathbf{I})
- Marginal Likelihood (Evidence): p(y|X)=\int_f p(y|f,X)p(f|X)df
- Posterior: p(f|X,y) = \mathcal{GP}(\mu_*, \mathbf{K}_*)
Gaussian Process Regression
We only need a few elements to define a GP regression model: a mean function \mu, a covariance (kernel) function that produces the matrix \mathbf{K}_\theta, and some data \mathcal{D}.
Code
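class GPR:
    def __init__(self, mu, kernel, X, y, noise_variance=1e-6):
        self.mu = mu                          # mean function
        self.kernel = kernel                  # covariance (kernel) function
        self.x_train = X                      # training inputs, shape (N, D)
        self.y_train = y                      # training targets, shape (N, 1)
        self.noise_variance = noise_variance  # likelihood noise sigma_n^2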
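Gaussian Process Prior
This is the basis of the GP method. Under the assumption mentioned above, the prior over functions is
p(f|X,\theta)=\mathcal{GP}(m_\theta, \mathbf{K}_\theta)
where m_\theta is a mean function and \mathbf{K}_\theta is the covariance matrix produced by a covariance (kernel) function.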
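In theory, we treat a function as a vector of function values that goes on forever, f=[f_1, f_2, \ldots]. In practice, we look at the distribution over a finite set of function values, e.g. f_i=f(x_i). So let's look at the joint distribution between N function values f_N and all the other function values f_\infty. This is Gaussian, so we can write the joint distribution roughly as:
\begin{bmatrix} f_N \\ f_\infty \end{bmatrix} \sim \mathcal{N}\left( \begin{bmatrix} \mu_N \\ \mu_\infty \end{bmatrix}, \begin{bmatrix} \Sigma_{NN} & \Sigma_{N\infty} \\ \Sigma_{\infty N} & \Sigma_{\infty\infty} \end{bmatrix} \right)
where \Sigma_{NN}\in \mathbb{R}^{N\times N} and \Sigma_{\infty\infty} \in \mathbb{R}^{\infty \times \infty} (or, to be more precise, \Sigma_{mm}\in\mathbb{R}^{m\times m} with m\rightarrow \infty).
Again, any marginal of a joint Gaussian distribution is still Gaussian. So if we integrate out the infinitely many function values f_\infty, we get:
p(f_N) = \int p(f_N, f_\infty)\,df_\infty = \mathcal{N}(f_N|\mu_N, \Sigma_{NN})
We can get even more specific and split f_N into training values f_{\text{train}} and testing values f_{\text{test}}. It's simply a matter of manipulating joint Gaussian distributions, so again, calculating the marginals:
\begin{bmatrix} f_{\text{train}} \\ f_{\text{test}} \end{bmatrix} \sim \mathcal{N}\left( \begin{bmatrix} \mu_{\text{train}} \\ \mu_{\text{test}} \end{bmatrix}, \begin{bmatrix} \Sigma_{\text{train},\text{train}} & \Sigma_{\text{train},\text{test}} \\ \Sigma_{\text{test},\text{train}} & \Sigma_{\text{test},\text{test}} \end{bmatrix} \right)
and we arrive at a joint distribution of the training and testing function values, which is still Gaussian because it is a marginal of a Gaussian.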
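The marginalization property can be checked numerically. In this NumPy sketch, marginalizing a 4-dimensional Gaussian down to its first two coordinates amounts to keeping the matching entries of the mean and the matching sub-block of the covariance. Sampling the full joint and simply ignoring the other coordinates should reproduce that sub-block up to Monte Carlo error. The mean vector and the random covariance are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Joint Gaussian over 4 "function values": the first 2 play the role of f_N,
# the last 2 stand in for the (truncated) rest, f_infinity
mu = np.array([0.0, 1.0, -1.0, 0.5])
A = rng.standard_normal((4, 4))
Sigma = A @ A.T + 4 * np.eye(4)  # random symmetric positive-definite covariance

# Marginalizing out the last two variables = keeping the corresponding sub-blocks
idx = [0, 1]
mu_N, Sigma_NN = mu[idx], Sigma[np.ix_(idx, idx)]

# Empirical check: sample the joint, then just drop the other coordinates
samples = rng.multivariate_normal(mu, Sigma, size=200_000)[:, idx]
print(np.round(samples.mean(axis=0) - mu_N, 2))
print(np.round(np.cov(samples, rowvar=False) - Sigma_NN, 1))
```

Both printed differences should be close to zero, which is the whole reason a GP can be handled with finite matrices.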
Code
Honestly, I never work with mean functions. I always assume a zero-mean function and that's it. I don't really know anyone who works with mean functions either. I've seen it used in deep Gaussian processes but I have no expertise in which mean functions to use. So, we'll follow the community standard for now: zero mean function.
import jax.numpy as jnp

def zero_mean(x):
    # zero mean: one value per input row, shape (N,)
    return jnp.zeros(x.shape[0])
The output of the mean function is a vector in \mathbb{R}^{N}, one value per input.
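The most common kernel function you will see in the literature is the Radial Basis Function (RBF). It's a universal approximator and it performs fairly well on most datasets. However, the functions it produces are very smooth (infinitely differentiable), so it may start to fail when the underlying function is rough or changes abruptly. With the parameterization used in the code below, the kernel function is defined as:
k(x, y) = \exp\left(-\gamma \|x - y\|_2^2\right)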
import jax
import jax.numpy as jnp

# Squared Euclidean Distance Formula
@jax.jit
def sqeuclidean_distance(x, y):
    return jnp.sum((x - y) ** 2)

# RBF Kernel: k(x, y) = exp(-gamma * ||x - y||^2)
@jax.jit
def rbf_kernel(params, x, y):
    return jnp.exp(-params['gamma'] * sqeuclidean_distance(x, y))
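We also have a more flexible version of the RBF with a separate length scale per input dimension, called the Automatic Relevance Determination (ARD) kernel. Matching the code below, it is defined as:
k(x, y) = \sigma_f^2 \exp\left(-\sum_{d=1}^{D} \frac{(x_d - y_d)^2}{\ell_d^2}\right)
where \sigma_f^2 is params['var_f'] and the \ell_d are the entries of params['length_scale'].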
# ARD Kernel
@jax.jit
def ard_kernel(params, x, y):
    # divide each dimension by its own length scale
    # (no 1/2 factor in the exponent; it is absorbed into the length scale)
    x = x / params['length_scale']
    y = y / params['length_scale']
    # return the ARD kernel
    return params['var_f'] * jnp.exp(-sqeuclidean_distance(x, y))
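A quick way to see how the two kernels relate: with the ARD parameterization above (no 1/2 in the exponent), setting every length scale to \ell and var_f to 1 recovers the RBF kernel with \gamma = 1/\ell^2, and a very large \ell_d makes the kernel almost insensitive to dimension d, which is the "relevance determination" part. The sketch below uses NumPy ports of rbf_kernel and ard_kernel (without jax.jit); the inputs and length scales are arbitrary.

```python
import numpy as np

def rbf_kernel(params, x, y):
    # k(x, y) = exp(-gamma * ||x - y||^2)
    return np.exp(-params["gamma"] * np.sum((x - y) ** 2))

def ard_kernel(params, x, y):
    # k(x, y) = var_f * exp(-sum_d (x_d - y_d)^2 / l_d^2), one length scale per dimension
    x, y = x / params["length_scale"], y / params["length_scale"]
    return params["var_f"] * np.exp(-np.sum((x - y) ** 2))

rng = np.random.default_rng(1)
x, y = rng.standard_normal(3), rng.standard_normal(3)

# Equal length scales and var_f = 1: ARD reduces to RBF with gamma = 1 / l^2
ell = 2.0
ard_params = {"length_scale": np.full(3, ell), "var_f": 1.0}
rbf_params = {"gamma": 1.0 / ell**2}
same_as_rbf = np.isclose(ard_kernel(ard_params, x, y), rbf_kernel(rbf_params, x, y))

# Unequal length scales: a huge l_d effectively switches dimension d off
ard_params_3 = {"length_scale": np.array([ell, ell, 1e6]), "var_f": 1.0}
y_shifted = y.copy()
y_shifted[2] += 10.0  # move y far along the third ("irrelevant") dimension
insensitive = np.isclose(ard_kernel(ard_params_3, x, y), ard_kernel(ard_params_3, x, y_shifted))
print(same_as_rbf, insensitive)  # True True
```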
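The kernel functions above compute the kernel value for a single pair of input vectors. But we need the value for every combination of inputs, i.e. the Gram matrix, so we nest two jax.vmap calls: the inner one maps over the rows of y and the outer one over the rows of x.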
import jax

# Gram Matrix: K[i, j] = func(params, x[i], y[j]), shape (N, M)
def gram(func, params, x, y):
    return jax.vmap(lambda x1: jax.vmap(lambda y1: func(params, x1, y1))(y))(x)
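The nested vmap is equivalent to a double loop over all pairs of rows. The NumPy sketch below compares that double loop with a broadcasting version of the RBF Gram matrix; gram_loop and rbf_gram_broadcast are hypothetical helpers for illustration, not functions the post defines.

```python
import numpy as np

def rbf_kernel(params, x, y):
    return np.exp(-params["gamma"] * np.sum((x - y) ** 2))

def gram_loop(func, params, X, Y):
    # Evaluate the kernel on every (x_i, y_j) pair: an (N, M) matrix
    return np.array([[func(params, xi, yj) for yj in Y] for xi in X])

def rbf_gram_broadcast(params, X, Y):
    # ||x_i - y_j||^2 for all pairs at once: (N, 1, D) - (1, M, D) -> (N, M, D) -> sum over D
    sq_dists = np.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1)
    return np.exp(-params["gamma"] * sq_dists)

rng = np.random.default_rng(0)
X, Y = rng.standard_normal((6, 2)), rng.standard_normal((4, 2))
params = {"gamma": 0.5}

K_loop = gram_loop(rbf_kernel, params, X, Y)
K_fast = rbf_gram_broadcast(params, X, Y)
print(K_loop.shape, np.allclose(K_loop, K_fast))  # (6, 4) True
```

The vmap version keeps the kernel written for a single pair, which is why it works unchanged for both the RBF and ARD kernels; the broadcasting version has to be rewritten per kernel.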
Sampling from Prior
Now for something a bit more practical. To sample from the prior, we need input data, because the GP prior is only ever evaluated at a finite set of inputs. The kernel function has already been defined with its parameters, and so has the mean function \mu. So we evaluate both at the inputs and pass the resulting mean vector and covariance matrix to a multivariate normal sampler, along with the number of functions we would like to draw from the prior.
Code
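import functools
import jax
import jax.numpy as jnp

# minimal gp_prior, consistent with how it is called throughout:
# the mean vector and the Gram matrix evaluated at the inputs x
def gp_prior(params, mu_f, cov_f, x):
    return mu_f(x), cov_f(params, x, x)

# initialize parameters
params = {
    'gamma': 10.,
    'length_scale': 1e-3,  # only used by the ARD kernel
}

n_samples = 10  # number of input points at which to evaluate the prior
test_X = X[:n_samples, :].copy()  # X: inputs of shape (N, D), assumed already loaded

# GP Prior functions (mu, sigma)
mu_f = zero_mean
cov_f = functools.partial(gram, rbf_kernel)
mu_x, cov_x = gp_prior(params, mu_f=mu_f, cov_f=cov_f, x=test_X)

# add jitter to the diagonal so the covariance is numerically positive definite
jitter = 1e-6
cov_x_ = cov_x + jitter * jnp.eye(cov_x.shape[0])

n_functions = 10  # number of random functions to draw
key = jax.random.PRNGKey(0)  # JAX random numbers boilerplate code
# one sampled function per row: shape (n_functions, n_samples)
y_samples = jax.random.multivariate_normal(key, mu_x, cov_x_, shape=(n_functions,))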
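Under the hood, drawing from \mathcal{N}(\mu, \mathbf{K}) can be done with the Cholesky factor: if \mathbf{K}=\mathbf{L}\mathbf{L}^\top and z\sim\mathcal{N}(0, \mathbf{I}), then \mu + \mathbf{L}z has covariance \mathbf{L}\mathbf{L}^\top=\mathbf{K}. jax.random.multivariate_normal uses a Cholesky factorization by default; the NumPy sketch below just spells the step out. The 50-point grid, gamma=10 and the jitter value are assumptions chosen for illustration.

```python
import numpy as np

def rbf_gram(gamma, X, Y):
    # RBF Gram matrix via broadcasting, shape (N, M)
    sq_dists = np.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1)
    return np.exp(-gamma * sq_dists)

rng = np.random.default_rng(0)
X = np.linspace(-3.0, 3.0, 50).reshape(-1, 1)  # 50 inputs on a grid

mu = np.zeros(X.shape[0])                               # zero mean function
K = rbf_gram(10.0, X, X) + 1e-6 * np.eye(X.shape[0])    # jitter keeps the Cholesky stable
L = np.linalg.cholesky(K)                               # K = L L^T, L lower triangular

n_functions = 10
z = rng.standard_normal((X.shape[0], n_functions))  # i.i.d. N(0, 1) draws
f_samples = (mu[:, None] + L @ z).T                 # one prior function per row, shape (10, 50)
```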
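Likelihood (noise model)
This comes from the assumption y=f(x)+\epsilon with \epsilon\sim\mathcal{N}(0, \sigma_n^2), which gives
p(y|f,X)=\mathcal{N}(y|f, \sigma_n^2\mathbf{I})
Alternative Notation:
- y\sim \mathcal{N}(f, \sigma_n^2\mathbf{I})
- \mathcal{N}(y|f, \sigma_n^2\mathbf{I}) = \prod_{i=1}^N\mathcal{N}(y_i|f_i, \sigma_n^2)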
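The two notations agree because a Gaussian with diagonal covariance \sigma_n^2\mathbf{I} factorizes over the observations. The NumPy sketch below checks this numerically by comparing the multivariate log density with the sum of univariate log densities; the latent values f and the noise variance 0.3 are made up.

```python
import numpy as np

def log_normal_1d(y, mean, var):
    # log N(y | mean, var), elementwise
    return -0.5 * (np.log(2 * np.pi * var) + (y - mean) ** 2 / var)

def log_mvn(y, mean, cov):
    # log N(y | mean, cov) for a full covariance matrix
    N = y.shape[0]
    diff = y - mean
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (diff @ np.linalg.solve(cov, diff) + logdet + N * np.log(2 * np.pi))

rng = np.random.default_rng(0)
f = rng.standard_normal(8)   # latent function values f_i
sigma_n2 = 0.3               # noise variance sigma_n^2
y = f + np.sqrt(sigma_n2) * rng.standard_normal(8)

joint = log_mvn(y, f, sigma_n2 * np.eye(8))           # log N(y | f, sigma_n^2 I)
factorized = np.sum(log_normal_1d(y, f, sigma_n2))    # sum_i log N(y_i | f_i, sigma_n^2)
print(np.isclose(joint, factorized))  # True
```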
Posterior
Alternative Notation:
- \mathcal{P}(f|y)\propto \mathcal{N}(y|f, \sigma_n^2\mathbf{I})\cdot \mathcal{N}(f|\mu, \mathbf{K}_{ff})
Code
This is easily the longest function we need for the GP. In my version it isn't needed for training the GP, only for making predictions at test time.
import jax
import jax.numpy as jnp
import jax.scipy.linalg

def posterior(params, prior_params, X, Y, X_new, likelihood_noise=False, return_cov=False):
    (mu_func, cov_func) = prior_params
    # ==========================
    # 1. GP PRIOR
    # ==========================
    mu_x, Kxx = gp_prior(params, mu_f=mu_func, cov_f=cov_func, x=X)
    # ===========================
    # 2. CHOLESKY FACTORIZATION
    # ===========================
    # (K + sigma_n^2 I) = L L^T and alpha = (K + sigma_n^2 I)^{-1} (Y - mu(X))
    (L, lower), alpha = cholesky_factorization(
        Kxx + (params["likelihood_noise"] + 1e-7) * jnp.eye(Kxx.shape[0]),
        Y - mu_func(X).reshape(-1, 1),
    )
    # ================================
    # 3. PREDICTIVE MEAN DISTRIBUTION
    # ================================
    # cross-covariance between new and training inputs, shape (M, N)
    KxX = cov_func(params, X_new, X)
    # mean: mu(X_new) + K_* alpha (the prior mean term vanishes for zero_mean)
    mu_y = mu_func(X_new).reshape(-1, 1) + jnp.dot(KxX, alpha)
    # =====================================
    # 4. PREDICTIVE COVARIANCE DISTRIBUTION
    # =====================================
    v = jax.scipy.linalg.cho_solve((L, lower), KxX.T)
    # kernel matrix for the new inputs, shape (M, M)
    Kxx_new = cov_func(params, X_new, X_new)
    cov_y = Kxx_new - jnp.dot(KxX, v)
    # optionally add the likelihood noise to the diagonal (predict y instead of f)
    if likelihood_noise:
        cov_y += params['likelihood_noise'] * jnp.eye(cov_y.shape[0])
    # return the variance (diagonal of the covariance) unless the full covariance is requested
    if not return_cov:
        cov_y = jnp.diag(cov_y)
    return mu_y, cov_y
Cholesky
Solving for \alpha=K^{-1}y by explicitly inverting K will often give you problems. Many times you'll get an error saying the matrix is ill-conditioned or not positive semi-definite. K should be a symmetric positive semi-definite matrix (and positive definite once we add the noise variance or a small jitter to its diagonal), so there are cheaper and more stable ways to solve this. We can use the Cholesky decomposition, which factors K into the product of a lower triangular matrix and its transpose:
\mathbf{K}=\mathbf{L}\mathbf{L}^\top
We do this because:
- triangular systems are cheap to solve: forward or back substitution costs \mathcal{O}(N^2), and we never form an explicit inverse
- solving \mathbf{K}\alpha=y then reduces to two triangular solves, \mathbf{L}z=y followed by \mathbf{L}^\top\alpha=z.
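To make the cost argument concrete, here is a NumPy sketch of the two triangular solves, written as explicit forward and back substitution loops so the \mathcal{O}(N^2) structure is visible. In practice you would call a library routine (the post uses cho_solve); forward_substitution and back_substitution are illustrative helpers only, and the test matrix is random.

```python
import numpy as np

def forward_substitution(L, b):
    # Solve L x = b for lower-triangular L in O(N^2): x_i only needs x_0..x_{i-1}
    N = b.shape[0]
    x = np.zeros(N)
    for i in range(N):
        x[i] = (b[i] - L[i, :i] @ x[:i]) / L[i, i]
    return x

def back_substitution(U, b):
    # Solve U x = b for upper-triangular U in O(N^2), working from the last row up
    N = b.shape[0]
    x = np.zeros(N)
    for i in range(N - 1, -1, -1):
        x[i] = (b[i] - U[i, i + 1:] @ x[i + 1:]) / U[i, i]
    return x

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 5))
K = A @ A.T + 5 * np.eye(5)   # symmetric positive-definite test matrix
y = rng.standard_normal(5)

L = np.linalg.cholesky(K)           # K = L L^T, L lower triangular
z = forward_substitution(L, y)      # L z = y
alpha = back_substitution(L.T, z)   # L^T alpha = z, hence K alpha = y
print(np.allclose(K @ alpha, y))    # True
```

The Cholesky factorization itself is still \mathcal{O}(N^3), but it is done once and reused for every right-hand side.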
There are two convenience functions for working with the Cholesky decomposition:
- cho_factor: calculates the decomposition K \rightarrow L
- cho_solve: solves the system of equations LL^\top \alpha=y
import jax.scipy.linalg

def cholesky_factorization(K, Y):
    # cho_factor returns (L, lower): the factor with K = LL^T and a flag marking it lower triangular
    L = jax.scipy.linalg.cho_factor(K, lower=True)
    # alpha solves LL^T alpha = Y
    alpha = jax.scipy.linalg.cho_solve(L, Y)
    return L, alpha
Note: If you want the Cholesky factor by itself, to operate on it outside of cho_factor and cho_solve, then you should call the cholesky function directly. cho_factor leaves arbitrary (inexpensive) values in the half of the matrix that the factorization doesn't use, whereas cholesky puts zeros there instead.
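The variance term also makes use of K^{-1}, so naturally we can reuse the already-factored Cholesky decomposition to calculate it.
# v = (K + sigma_n^2 I)^{-1} K_*^T, reusing the Cholesky factor
v = jax.scipy.linalg.cho_solve((L, lower), KxX.T)
# K_* (K + sigma_n^2 I)^{-1} K_*^T, subtracted from K_** to get the predictive covariance
var = jnp.dot(KxX, v)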
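Joint Probability Distribution
To make GPs useful, we want to actually make predictions. This comes from using the joint distribution of the training targets and the test function values, together with the standard formula for conditioning a multivariate Gaussian. In terms of the GP function space (with a zero mean function), we have
\begin{bmatrix} y \\ f_* \end{bmatrix} \sim \mathcal{N}\left(\mathbf{0}, \begin{bmatrix} K + \sigma^2 I & K_*^\top \\ K_* & K_{**} \end{bmatrix}\right)
Then, conditioning on the observed y, we get the predictive distribution at the test points:
p(f_*|X_*, X, y) = \mathcal{N}(\mu_*, \nu^2_*)
where:
- \mu_*=K_* (K + \sigma^2 I)^{-1}y=K_* \alpha
- \nu^2_*= K_{**} - K_*(K + \sigma^2I)^{-1}K_*^{\top}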
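The two predictive equations translate almost line for line into code. This NumPy sketch uses np.linalg.solve instead of the Cholesky route for readability; gp_predict, the RBF kernel with gamma=1.0 and the noise values are assumptions. With almost no noise, the predictive mean should interpolate the training targets, and far away from the data the prediction should fall back to the prior (mean 0, variance k(x, x)=1).

```python
import numpy as np

def rbf_gram(gamma, X, Y):
    sq_dists = np.sum((X[:, None, :] - Y[None, :, :]) ** 2, axis=-1)
    return np.exp(-gamma * sq_dists)

def gp_predict(X, y, X_star, gamma=1.0, noise_var=1e-2):
    # Zero-mean GP predictive equations, using np.linalg.solve for readability
    K = rbf_gram(gamma, X, X) + noise_var * np.eye(X.shape[0])  # K + sigma^2 I
    K_s = rbf_gram(gamma, X_star, X)                             # K_*,  shape (M, N)
    K_ss = rbf_gram(gamma, X_star, X_star)                       # K_**, shape (M, M)
    alpha = np.linalg.solve(K, y)                                # (K + sigma^2 I)^{-1} y
    mu_star = K_s @ alpha                                        # mu_*
    cov_star = K_ss - K_s @ np.linalg.solve(K, K_s.T)            # nu^2_* (full covariance)
    return mu_star, cov_star

X = np.linspace(-3.0, 3.0, 8).reshape(-1, 1)
y = np.sin(X[:, 0])

# Near-noiseless GP evaluated at the training inputs: should reproduce y
mu_star, cov_star = gp_predict(X, y, X, noise_var=1e-8)
print(np.allclose(mu_star, y, atol=1e-4))  # True

# Far from the data, the prediction reverts to the prior
mu_far, cov_far = gp_predict(X, y, np.array([[100.0]]))
```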
Marginal Log-Likelihood
The prior m(x), K has hyper-parameters \theta, so learning a \mathcal{GP} means inferring these hyper-parameters from the data.
However, we are not interested in f directly, so we marginalize it out with an integral over f. Because the marginal of a Gaussian is Gaussian, this integral has a closed form.
Note: We typically use the \log likelihood instead of the likelihood itself for computational reasons: a product of many densities quickly underflows, whereas the corresponding sum of logs stays well behaved. The \log function is monotonic, so it doesn't move the location of the extreme points of the function. Furthermore, we typically minimize the -\log likelihood rather than maximize the \log likelihood, simply because optimizers are conventionally written as minimizers.
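A small NumPy illustration of the underflow problem: the product of 2000 standard normal densities is far below the smallest representable double and evaluates to exactly 0.0, while the sum of the log densities is an ordinary finite number. The data are random draws used only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
y = rng.standard_normal(2000)  # 2000 observations from N(0, 1)

densities = np.exp(-0.5 * y**2) / np.sqrt(2 * np.pi)  # N(y_i | 0, 1), each at most ~0.4
likelihood = np.prod(densities)                        # product of 2000 small numbers
log_likelihood = np.sum(np.log(densities))             # same quantity on the log scale

print(likelihood)      # 0.0 (underflow)
print(log_likelihood)  # a finite negative number
```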
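One way to train these functions is to find a point estimate of the hyper-parameters: maximize the marginal likelihood (type-II maximum likelihood) or, if we also place a prior on \theta, take its Maximum A Posteriori (MAP) estimate:
\theta^*_{\text{ML-II}} = \arg\max_\theta \log p(y|X,\theta), \quad \theta^*_{\text{MAP}} = \arg\max_\theta \left[\log p(y|X,\theta) + \log p(\theta)\right]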
Marginal Likelihood (Evidence)
p(y|X,\theta)=\int_f p(y|f,X)\,p(f|X,\theta)\,df
where:
- p(y|f,X)=\mathcal{N}(y|f, \sigma_n^2\mathbf{I})
- p(f|X, \theta)=\mathcal{N}(f|m_\theta, K_\theta)
All we're doing here is writing out each term explicitly, which we can do because every quantity involved is Gaussian.
The product of two Gaussian densities is (up to normalization) another Gaussian, and integrating f out of it leaves a Gaussian whose covariance is the sum of the prior and noise covariances:
p(y|X,\theta)=\mathcal{N}(y|m_\theta, K_\theta + \sigma_n^2\mathbf{I})
Proof
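Using the Gaussian identity
\int \mathcal{N}(y|f, \sigma_n^2\mathbf{I})\,\mathcal{N}(f|m_\theta, K_\theta)\,df = \mathcal{N}(y|m_\theta, K_\theta+\sigma_n^2\mathbf{I})
we can use the same reasoning to combine the prior and the likelihood to get the posterior.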
Source:
- Alternative Derivation for Log Likelihood - blog
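The key identity can be sanity-checked in one dimension: integrating \mathcal{N}(y|f, \sigma_n^2)\,\mathcal{N}(f|m, k) over f numerically should match \mathcal{N}(y|m, k+\sigma_n^2), i.e. the variances add. The NumPy sketch below uses a simple Riemann sum on a dense grid; the values of m, k, \sigma_n^2 and y are arbitrary.

```python
import numpy as np

def normal_pdf(x, mean, var):
    return np.exp(-0.5 * (x - mean) ** 2 / var) / np.sqrt(2 * np.pi * var)

m, k = 0.5, 1.3   # prior: f ~ N(m, k)
sigma_n2 = 0.2    # likelihood: y | f ~ N(f, sigma_n^2)
y_obs = 1.7

# Numerically integrate f out on a dense grid (Riemann sum)
f_grid = np.linspace(-15.0, 15.0, 200_001)
df = f_grid[1] - f_grid[0]
integrand = normal_pdf(y_obs, f_grid, sigma_n2) * normal_pdf(f_grid, m, k)
evidence_numeric = np.sum(integrand) * df

# Closed form: the variances add
evidence_closed = normal_pdf(y_obs, m, k + sigma_n2)
print(np.isclose(evidence_numeric, evidence_closed))  # True
```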
Marginal Log-Likelihood
Proof of Marginal Log-Likelihood
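Now we need a cost function that lets us find the hyperparameters that best fit our data. Taking the log of the Gaussian marginal likelihood gives
\log p(y|X,\theta) = -\frac{1}{2}(y-m_\theta)^\top(K_\theta+\sigma_n^2\mathbf{I})^{-1}(y-m_\theta) - \frac{1}{2}\log\left|K_\theta+\sigma_n^2\mathbf{I}\right| - \frac{N}{2}\log 2\pi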
Inverting N\times N matrices is the worst part about GPs in general: it costs \mathcal{O}(N^3). There are many techniques to handle this, but for a basic implementation it can become a problem. Furthermore, numerically the kernel matrix often fails to be positive definite, which makes a direct inverse unstable. One way to make this more efficient and more stable is to compute the Cholesky decomposition and then solve the problem with it.
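Cholesky Components
Let \mathbf{L}=\text{cholesky}(\mathbf{K}+\sigma_n^2\mathbf{I}) and \alpha=\mathbf{L}^\top\backslash(\mathbf{L}\backslash y). With a zero mean function, we can write the log likelihood in terms of the Cholesky decomposition:
\log p(y|X,\theta) = -\frac{1}{2}y^\top\alpha - \sum_{i=1}^{N}\log \mathbf{L}_{ii} - \frac{N}{2}\log 2\pi
This gives us a computational complexity of \mathcal{O}(N + N^2 + N^3)=\mathcal{O}(N^3): \mathcal{O}(N) for the sum over the diagonal, \mathcal{O}(N^2) for the triangular solves and \mathcal{O}(N^3) for the factorization itself.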
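Both routes to the negative log marginal likelihood should agree. This NumPy sketch computes it once with an explicit inverse and determinant (only sensible for tiny N) and once through the Cholesky factor, using y^\top\mathbf{K}^{-1}y=\|\mathbf{L}^{-1}y\|^2 and \frac{1}{2}\log|\mathbf{K}|=\sum_i\log\mathbf{L}_{ii}. The toy inputs, the unit RBF kernel and the noise variance 0.1 are assumptions.

```python
import numpy as np

def nll_direct(K, y):
    # -log N(y | 0, K) with an explicit inverse and determinant (fine for tiny N only)
    N = y.shape[0]
    return 0.5 * (y @ np.linalg.inv(K) @ y + np.log(np.linalg.det(K)) + N * np.log(2 * np.pi))

def nll_cholesky(K, y):
    # Same quantity through L = cholesky(K):
    #   y^T K^{-1} y = ||L^{-1} y||^2   and   0.5 * log|K| = sum(log diag(L))
    N = y.shape[0]
    L = np.linalg.cholesky(K)   # O(N^3), done once
    z = np.linalg.solve(L, y)   # L z = y (a dedicated triangular solver would exploit the structure)
    return 0.5 * (z @ z) + np.sum(np.log(np.diag(L))) + 0.5 * N * np.log(2 * np.pi)

rng = np.random.default_rng(0)
X = rng.uniform(-2.0, 2.0, size=(10, 1))
K = np.exp(-np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)) + 0.1 * np.eye(10)
y = rng.standard_normal(10)
print(np.isclose(nll_direct(K, y), nll_cholesky(K, y)))  # True
```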
Code
I will demonstrate two ways to do this:
- using the equations above
- refactoring this to use the built-in jax.scipy.stats.multivariate_normal.logpdf function
import jax.numpy as jnp

def nll_scratch(gp_priors, params, X, Y):
    (mu_func, cov_func) = gp_priors
    # ==========================
    # 1. GP PRIOR
    # ==========================
    mu_x, Kxx = gp_prior(params, mu_f=mu_func, cov_f=cov_func, x=X)
    # ===========================
    # 2. CHOLESKY FACTORIZATION
    # ===========================
    # subtract the prior mean (a no-op for zero_mean); Y has shape (N, K)
    Y = Y - mu_x.reshape(-1, 1)
    (L, lower), alpha = cholesky_factorization(
        Kxx + (params['likelihood_noise'] + 1e-5) * jnp.eye(Kxx.shape[0]), Y
    )
    # ===========================
    # 3. Marginal Log-Likelihood
    # ===========================
    log_likelihood = -0.5 * jnp.einsum("ik,ik->k", Y, alpha)  # per-output y^T alpha
    log_likelihood -= jnp.sum(jnp.log(jnp.diag(L)))
    log_likelihood -= (Kxx.shape[0] / 2) * jnp.log(2 * jnp.pi)
    # negative log-likelihood, summed over outputs
    return -log_likelihood.sum()
import jax.numpy as jnp
import jax.scipy.stats

def marginal_likelihood(prior_params, params, Xtrain, Ytrain):
    # unpack params
    (mu_func, cov_func) = prior_params
    # ==========================
    # 1. GP Prior, mu(), cov(,)
    # ==========================
    mu_x = mu_func(Xtrain)
    Kxx = cov_func(params, Xtrain, Xtrain)
    # ===========================
    # 2. GP Likelihood
    # ===========================
    K_gp = Kxx + (params['likelihood_noise'] + 1e-6) * jnp.eye(Kxx.shape[0])
    # ===========================
    # 3. Marginal Log-Likelihood
    # ===========================
    # get log probability; Ytrain.T has shape (K, N), one row per output
    log_prob = jax.scipy.stats.multivariate_normal.logpdf(x=Ytrain.T, mean=mu_x, cov=K_gp)
    # sum over outputs and return the negative mll
    return -log_prob.sum()
source - Dai, GPSS 2018
Training
Code
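We often have problems when it comes to using optimizers. A lot of the time they just don't seem to converge, and the gradients don't seem to change no matter what happens. Our hyperparameters (gamma, the likelihood noise) also have to stay positive. One trick is to let the optimizer work on an unconstrained, transformed version of the parameters, and then pass them through a softplus function, which maps any real number to a positive one, before the model uses them.
JAX has a built-in softplus function (jax.nn.softplus), so we'll just use that.
import jax

def saturate(params):
    # map unconstrained values to positive ones with softplus
    return {ikey: jax.nn.softplus(ivalue) for (ikey, ivalue) in params.items()}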
import functools
import logging

import jax
import tqdm
from jax.example_libraries import optimizers

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# get_data: the notebook's data-loading helper (not shown here)
X, y, Xtest, ytest = get_data(50)

# PRIOR FUNCTIONS (mean, covariance)
mu_f = zero_mean
cov_f = functools.partial(gram, rbf_kernel)
gp_priors = (mu_f, cov_f)

# Kernel, Likelihood parameters: these are the unconstrained values the optimizer sees;
# the model always uses softplus(value), which is positive
params = {
    'gamma': 2.0,
    # 'length_scale': 1.0,
    # 'var_f': 1.0,
    'likelihood_noise': 1.,
}

# LOSS FUNCTION: saturate inside the loss so the gradients flow through the softplus
def mll_loss(params, X, y):
    return marginal_likelihood(gp_priors, saturate(params), X, y)

# TRAINING PARAMETERS
n_epochs = 500
learning_rate = 1e-2

# initialize optimizer and its state
opt_init, opt_update, get_params = optimizers.rmsprop(step_size=learning_rate)
opt_state = opt_init(params)

# STEP FUNCTION
@jax.jit
def step(i, opt_state, X, y):
    params = get_params(opt_state)
    # loss and its gradient in a single pass
    loss, grads = jax.value_and_grad(mll_loss)(params, X, y)
    # update optimizer state
    opt_state = opt_update(i, grads, opt_state)
    return opt_state, loss

losses = list()
with tqdm.trange(n_epochs) as bar:
    for i in bar:
        # 1 step - optimize function
        opt_state, value = step(i, opt_state, X, y)
        # report the constrained (softplus) values
        params = get_params(opt_state)
        postfix = {ikey: f"{float(jax.nn.softplus(ivalue)):.2f}" for ikey, ivalue in params.items()}
        # save loss values
        losses.append(float(value))
        # update progress bar
        postfix["Loss"] = f"{losses[-1]:.2f}"
        bar.set_postfix(postfix)

# saturate params: the final, positive hyperparameters
params = saturate(get_params(opt_state))
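The softplus trick only works if the model always sees softplus(raw) and the optimizer only ever touches raw. This NumPy sketch shows both directions of the transform: inverse_softplus (a hypothetical helper, not part of the post) turns chosen positive starting values into unconstrained ones, and softplus maps any update, however large, back to a positive value. The starting values mirror the params dictionary above; the step of -5.0 is arbitrary.

```python
import numpy as np

def softplus(x):
    # log(1 + exp(x)), computed stably; maps any real number to a positive one
    return np.logaddexp(0.0, x)

def inverse_softplus(y):
    # x such that softplus(x) = y, for y > 0: log(exp(y) - 1) = y + log(1 - exp(-y))
    return y + np.log(-np.expm1(-y))

# Pick positive starting values, store their unconstrained versions,
# let the optimizer update those freely, and map back with softplus when the model needs them
init = {"gamma": 2.0, "likelihood_noise": 1.0}
raw = {k: inverse_softplus(v) for k, v in init.items()}
raw_after_step = {k: v - 5.0 for k, v in raw.items()}  # an arbitrary large "gradient step"

print({k: float(softplus(v)) for k, v in raw.items()})        # recovers init, up to rounding
print(all(softplus(v) > 0 for v in raw_after_step.values()))  # True: still positive
```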
Resources
- Surrogates: GP Modeling, Design, and Optimization for the Applied Sciences - Gramacy - Online Book