JAX Playground
My starting notebook, where I install all of the necessary libraries and load some easy 1D/2D regression data to play around with.
#@title Install Packages
!pip install jax jaxlib
!pip install "git+https://github.com/google/objax.git"
!pip install "git+https://github.com/deepmind/chex.git"
!pip install "git+https://github.com/deepmind/dm-haiku"
!pip install "git+https://github.com/Information-Fusion-Lab-Umass/NuX"
!pip install "git+https://github.com/pyro-ppl/numpyro.git#egg=numpyro"
!pip uninstall tensorflow -y -q
!pip install -Uq tfp-nightly[jax] > /dev/null
#@title Load Packages
# TYPE HINTS
from typing import Tuple, Optional, Dict, Callable, Union
# JAX SETTINGS
import jax
import jax.numpy as np
import jax.random as random
import objax
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
tfpk = tfp.math.psd_kernels
# NUMPY SETTINGS
import numpy as onp
onp.set_printoptions(precision=3, suppress=True)
# MATPLOTLIB Settings
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# SEABORN SETTINGS
import seaborn as sns
sns.set_context(context='talk',font_scale=0.7)
# PANDAS SETTINGS
import pandas as pd
pd.set_option("display.max_rows", 120)
pd.set_option("display.max_columns", 120)
# LOGGING SETTINGS
import sys
import logging
logging.basicConfig(
level=logging.INFO,
stream=sys.stdout,
format='%(asctime)s:%(levelname)s:%(message)s'
)
logger = logging.getLogger()
#logger.setLevel(logging.INFO)
%load_ext autoreload
%autoreload 2
#@title Data
def get_data(
N: int = 30,
input_noise: float = 0.15,
output_noise: float = 0.15,
N_test: int = 400,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
onp.random.seed(0)
X = np.linspace(-1, 1, N)
Y = X + 0.2 * np.power(X, 3.0) + 0.5 * np.power(0.5 + X, 2.0) * np.sin(4.0 * X)
Y += output_noise * onp.random.randn(N)
Y -= np.mean(Y)
Y /= np.std(Y)
X += input_noise * onp.random.randn(N)
assert X.shape == (N,)
assert Y.shape == (N,)
X_test = np.linspace(-1.2, 1.2, N_test)
return X[:, None], Y[:, None], X_test[:, None]
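A minimal data cell, assuming the usual unpacking of `get_data` into the `X`, `y`, `X_test` referenced by the later cells:
#@title Generate Data
# assumption: the later cells expect `X`, `y` (train) and `X_test` (test grid) from `get_data`
X, y, X_test = get_data(N=30)

fig, ax = plt.subplots()
ax.scatter(X, y, color="red", label="train data")
ax.legend()
plt.show()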
from functools import partial
def covariance_matrix(
func: Callable,
x: np.ndarray,
y: np.ndarray,
) -> np.ndarray:
"""Computes the covariance matrix.

    Given a callable kernel (or distance) function, we can use
    `jax.vmap` to calculate the gram matrix as the function applied
    to each pair of points.

    Parameters
    ----------
    func : Callable
        a callable kernel (or distance) function applied to a pair of points
    x : jax.numpy.ndarray
        input dataset (n_samples, n_features)
    y : jax.numpy.ndarray
        other input dataset (n_samples, n_features)
    Returns
    -------
    mat : jax.numpy.ndarray
        the gram matrix (n_samples, n_samples)
    Notes
    -----
    There is little difference between this function
    and `gram`
    See Also
    --------
    jax.kernels.gram
    Examples
    --------
    >>> covariance_matrix(partial(rbf_kernel, gamma=1.0), X, Y)
"""
mapx1 = jax.vmap(lambda x, y: func(x=x, y=y), in_axes=(0, None), out_axes=0)
mapx2 = jax.vmap(lambda x, y: mapx1(x, y), in_axes=(None, 0), out_axes=1)
return mapx2(x, y)
def rbf_kernel(gamma: float, x: np.ndarray, y: np.ndarray) -> np.ndarray:
"""Radial Basis Function (RBF) Kernel.
The most popular kernel in all of kernel methods.
.. math::
k(\mathbf{x,y}) = \\
\\exp \left( - \\gamma\\
||\\mathbf{x} - \\mathbf{y}||^2_2\\
\\right)
Parameters
----------
params : Dict
the parameters needed for the kernel
x : jax.numpy.ndarray
input dataset (n_samples, n_features)
y : jax.numpy.ndarray
other input dataset (n_samples, n_features)
Returns
-------
kernel_mat : jax.numpy.ndarray
the kernel matrix (n_samples, n_samples)
References
----------
.. [1] David Duvenaud, *Kernel Cookbook*
"""
return np.exp(- gamma * sqeuclidean_distance(x, y))
def ard_kernel(x: np.ndarray, y: np.ndarray, length_scale: np.ndarray, amplitude: np.ndarray) -> np.ndarray:
    r"""Automatic Relevance Determination (ARD) Kernel.

    An exponentiated quadratic kernel with a trainable length scale
    and amplitude (signal variance).

    .. math::

        k(\mathbf{x}, \mathbf{y}) =
        \sigma^2 \exp \left( - \left|\left| \frac{\mathbf{x}}{\ell}
        - \frac{\mathbf{y}}{\ell} \right|\right|^2_2 \right)

    Parameters
    ----------
    x : jax.numpy.ndarray
        input dataset (n_samples, n_features)
    y : jax.numpy.ndarray
        other input dataset (n_samples, n_features)
    length_scale : jax.numpy.ndarray
        the length scale(s), :math:`\ell`
    amplitude : jax.numpy.ndarray
        the amplitude (signal variance), :math:`\sigma^2`
    Returns
    -------
    kernel_value : jax.numpy.ndarray
        the kernel evaluated at the pair (x, y)
    References
    ----------
    .. [1] David Duvenaud, *Kernel Cookbook*
    """
x = x / length_scale
y = y / length_scale
# return the ard kernel
return amplitude * np.exp(-sqeuclidean_distance(x, y))
def sqeuclidean_distance(x: np.ndarray, y: np.ndarray) -> np.ndarray:
return np.sum((x - y) ** 2)
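A quick sanity check of the `vmap`-based gram matrix (a minimal sketch; the five demo points `x_demo` and the values `gamma=1.0`, `length_scale=1.0`, `amplitude=1.0` are arbitrary choices):
# sanity check (sketch): both kernels should yield a (5, 5) gram matrix
x_demo = np.linspace(-1, 1, 5)[:, None]
K_rbf = covariance_matrix(partial(rbf_kernel, gamma=1.0), x_demo, x_demo)
K_ard = covariance_matrix(partial(ard_kernel, length_scale=1.0, amplitude=1.0), x_demo, x_demo)
assert K_rbf.shape == (5, 5) and K_ard.shape == (5, 5)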
class RBFKernel(objax.Module):
def __init__(self):
self.gamma = objax.TrainVar(np.array([0.1]))
def __call__(self, X: np.ndarray, Y: np.ndarray)-> np.ndarray:
kernel_func = partial(rbf_kernel, gamma=self.gamma.value)
return covariance_matrix(kernel_func, X, Y).squeeze()
class ARDKernel(objax.Module):
def __init__(self):
self.length_scale = objax.TrainVar(np.array([0.1]))
self.amplitude = objax.TrainVar(np.array([1.]))
def __call__(self, X: np.ndarray, Y: np.ndarray)-> np.ndarray:
kernel_func = partial(
ard_kernel,
length_scale=jax.nn.softplus(self.length_scale.value),
amplitude=jax.nn.softplus(self.amplitude.value)
)
return covariance_matrix(kernel_func, X, Y).squeeze()
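The same check through the objax modules (a sketch; `x_check` is an arbitrary demo array, and the trailing `.squeeze()` in the modules removes the extra axis introduced by the array-valued `TrainVar`s):
# sketch: the objax kernel modules should also return an (n, n) gram matrix
x_check = np.linspace(-1, 1, 5)[:, None]
assert RBFKernel()(x_check, x_check).shape == (5, 5)
assert ARDKernel()(x_check, x_check).shape == (5, 5)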
class ZeroMean(objax.Module):
def __init__(self):
pass
def __call__(self, X: np.ndarray) -> np.ndarray:
return np.zeros(X.shape[-1], dtype=X.dtype)
class LinearMean(objax.Module):
def __init__(self, input_dim, output_dim):
self.w = objax.TrainVar(objax.random.normal((input_dim, output_dim)))
self.b = objax.TrainVar(np.zeros(output_dim))
def __call__(self, X: np.ndarray) -> np.ndarray:
        return np.dot(X, self.w.value) + self.b.value
class GaussianLikelihood(objax.Module):
def __init__(self):
self.noise = objax.TrainVar(np.array([0.1]))
def __call__(self, X: np.ndarray) -> np.ndarray:
return np.zeros(X.shape[-1], dtype=X.dtype)
class ExactGP(objax.Module):
def __init__(self, input_dim, output_dim, jitter):
# MEAN FUNCTION
self.mean = ZeroMean()
# KERNEL Function
self.kernel = ARDKernel()
# noise level
self.noise = objax.TrainVar(np.array([0.1]))
# jitter (make it correctly conditioned)
self.jitter = jitter
def forward(self, X: np.ndarray) -> np.ndarray:
# mean function
mu = self.mean(X)
# kernel function
cov = self.kernel(X, X)
# noise model
cov += jax.nn.softplus(self.noise.value) * np.eye(X.shape[0])
# jitter
cov += self.jitter * np.eye(X.shape[0])
# calculate cholesky
cov_chol = np.linalg.cholesky(cov)
# gaussian process likelihood
return tfd.MultivariateNormalTriL(loc=mu, scale_tril=cov_chol)
def predict(self, X: np.ndarray) -> np.ndarray:
pass
def sample(self, n_samples: int, key: None) -> np.ndarray:
pass
gp_model = ExactGP(X.shape[0], 1, 1e-5)
dist = gp_model.forward(X)
gp_model.vars()
plt.imshow(dist.covariance())
key = random.PRNGKey(0)
samples = dist.sample(10, key)
plt.plot(samples.T)
# Settings
lr = 0.01 # learning rate
batch = 256
epochs = 50
gp_model = ExactGP(X.shape[0], 1, 1e-5)
def loss(X, label):
dist = gp_model.forward(X)
return - dist.log_prob(label).mean()
opt = objax.optimizer.SGD(gp_model.vars())
gv = objax.GradValues(loss, gp_model.vars())
def train_op(x, label):
g, v = gv(x, label) # returns gradients, loss
opt(lr, g)
return v
# This line is optional: it is compiling the code to make it faster.
train_op = objax.Jit(train_op, gv.vars() + opt.vars())
losses = []
for epoch in range(epochs):
# Train
    loss_value = train_op(X, y.squeeze())
    losses.append(loss_value)
gp_model.noise.value, jax.nn.softplus(gp_model.noise.value)
plt.plot(losses)
Posterior
from typing import Tuple, Optional, Callable
def cholesky_factorization(K: np.ndarray, Y: np.ndarray) -> Tuple[Tuple[np.ndarray, bool], np.ndarray]:
    """Cholesky factorization of the kernel matrix and the corresponding weights."""
    # factorize the (noisy) kernel matrix; cho_factor returns (L, lower)
    L = jax.scipy.linalg.cho_factor(K, lower=True)
    # weights, alpha = K^{-1} (Y - mu), via the cholesky solve
    weights = jax.scipy.linalg.cho_solve(L, Y)
return L, weights
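A small check that the Cholesky solve recovers the direct solve (a sketch on an arbitrary 2x2 SPD matrix; `K_demo`, `y_demo` are made-up values):
# sketch: cho_solve should match a direct solve, K^{-1} y
K_demo = np.array([[2.0, 0.5], [0.5, 1.0]])
y_demo = np.array([[1.0], [2.0]])
L_demo, w_demo = cholesky_factorization(K_demo, y_demo)
assert np.allclose(w_demo, np.linalg.solve(K_demo, y_demo), atol=1e-5)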
def get_factorizations(
X: np.ndarray,
Y: np.ndarray,
likelihood_noise: float,
mean_f: Callable,
kernel: Callable,
) -> Tuple[Tuple[np.ndarray, bool], np.ndarray]:
"""Cholesky Factorization"""
# ==========================
# 1. GP PRIOR
# ==========================
mu_x = mean_f(X)
Kxx = kernel(X, X)
# ===========================
# 2. CHOLESKY FACTORIZATION
# ===========================
print(mu_x)
print(Y.reshape(-1, 1).shape, mu_x.reshape(-1, 1).shape)
L, alpha = cholesky_factorization(
Kxx + likelihood_noise * np.eye(Kxx.shape[0]),
Y.reshape(-1, 1) - mu_x.reshape(-1, 1),
)
# ================================
# 4. PREDICTIVE MEAN DISTRIBUTION
# ================================
return L, alpha
def posterior(
Xnew, X, y,
likelihood_noise,
mean_f,
kernel
):
    # cholesky factorization of the (noisy) training kernel matrix
    L, alpha = get_factorizations(
        X, y,
        likelihood_noise,
        mean_f,
        kernel
    )
    # ================================
    # 3. PREDICTIVE MEAN DISTRIBUTION
    # ================================
    # cross-covariance between the test and training inputs
    K_Xx = kernel(Xnew, X)
    # Calculate the Mean
    mu_y = np.dot(K_Xx, alpha)
    # =====================================
    # 4. PREDICTIVE COVARIANCE DISTRIBUTION
    # =====================================
    v = jax.scipy.linalg.cho_solve(L, K_Xx.T)
    # Calculate kernel matrix for the test inputs
    K_xx = kernel(Xnew, Xnew)
    cov_y = K_xx - np.dot(K_Xx, v)
    return mu_y, cov_y
mu, cov = posterior(
X, X, y.squeeze(),
jax.nn.softplus(gp_model.noise.value),
gp_model.mean,
gp_model.kernel
)
(1.96 * np.sqrt(np.diag(cov))).shape, mu.shape
plt.plot(X, mu)
plt.plot(X, mu.squeeze() + 1.96 * np.sqrt(np.diag(cov) + jax.nn.softplus(gp_model.noise.value)))
plt.plot(X, mu.squeeze() - 1.96 * np.sqrt(np.diag(cov) + jax.nn.softplus(gp_model.noise.value)))
plt.show()
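The same posterior evaluated on the denser `X_test` grid (a minimal sketch using the functions above; the names `mu_test`, `cov_test`, `std_test` and the `fill_between` styling are illustrative choices):
mu_test, cov_test = posterior(
    X_test, X, y.squeeze(),
    jax.nn.softplus(gp_model.noise.value),
    gp_model.mean,
    gp_model.kernel
)
# 95% predictive interval (includes the likelihood noise)
std_test = np.sqrt(np.diag(cov_test) + jax.nn.softplus(gp_model.noise.value))
plt.scatter(X, y, color="red", label="train data")
plt.plot(X_test.squeeze(), mu_test.squeeze(), label="predictive mean")
plt.fill_between(
    X_test.squeeze(),
    mu_test.squeeze() - 1.96 * std_test,
    mu_test.squeeze() + 1.96 * std_test,
    alpha=0.3,
)
plt.legend()
plt.show()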
dist
loss(X, y.squeeze())
#@title Distribution Data
from scipy.stats import beta
a, b = 3.0, 10.0
data_dist = beta(a, b)
x_samples = data_dist.rvs(size=1_000, random_state=123)
plt.hist(x_samples, bins=100);
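A minimal sketch overlaying the true density on the histogram using the TFP-on-JAX `tfd.Beta` loaded above (the names `x_grid` and `beta_dist` are illustrative; `concentration1`/`concentration0` correspond to `a`/`b`):
# sketch: overlay the true Beta(a, b) density using tfd (TFP-on-JAX)
x_grid = np.linspace(0.01, 0.99, 200)
beta_dist = tfd.Beta(concentration1=a, concentration0=b)
plt.hist(x_samples, bins=100, density=True)
plt.plot(x_grid, np.exp(beta_dist.log_prob(x_grid)))
plt.show()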