Skip to content

Jax PlayGround

My starting notebook, where I install all of the necessary libraries and load some simple 1D/2D regression data to experiment with.

#@title Install Packages
!pip install jax jaxlib
!pip install "git+https://github.com/google/objax.git"
!pip install "git+https://github.com/deepmind/chex.git"
!pip install "git+https://github.com/deepmind/dm-haiku"
!pip install "git+https://github.com/Information-Fusion-Lab-Umass/NuX"
!pip install "git+https://github.com/pyro-ppl/numpyro.git#egg=numpyro"
!pip uninstall tensorflow -y -q
!pip install -Uq tfp-nightly[jax] > /dev/null
Requirement already satisfied: jax in /usr/local/lib/python3.6/dist-packages (0.1.75)
Requirement already satisfied: jaxlib in /usr/local/lib/python3.6/dist-packages (0.1.52)
Requirement already satisfied: absl-py in /usr/local/lib/python3.6/dist-packages (from jax) (0.10.0)
Requirement already satisfied: opt-einsum in /usr/local/lib/python3.6/dist-packages (from jax) (3.3.0)
Requirement already satisfied: numpy>=1.12 in /usr/local/lib/python3.6/dist-packages (from jax) (1.18.5)
Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from jaxlib) (1.4.1)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from absl-py->jax) (1.15.0)
Collecting git+https://github.com/google/objax.git
  Cloning https://github.com/google/objax.git to /tmp/pip-req-build-cmqfg5f3
  Running command git clone -q https://github.com/google/objax.git /tmp/pip-req-build-cmqfg5f3
Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from objax==1.0.2) (1.4.1)
Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.6/dist-packages (from objax==1.0.2) (1.18.5)
Requirement already satisfied: pillow in /usr/local/lib/python3.6/dist-packages (from objax==1.0.2) (7.0.0)
Requirement already satisfied: jaxlib in /usr/local/lib/python3.6/dist-packages (from objax==1.0.2) (0.1.52)
Requirement already satisfied: jax in /usr/local/lib/python3.6/dist-packages (from objax==1.0.2) (0.1.75)
Requirement already satisfied: tensorboard>=2.3.0 in /usr/local/lib/python3.6/dist-packages (from objax==1.0.2) (2.3.0)
Collecting parameterized
  Downloading https://files.pythonhosted.org/packages/ba/6b/73dfed0ab5299070cf98451af50130989901f50de41fe85d605437a0210f/parameterized-0.7.4-py2.py3-none-any.whl
Requirement already satisfied: absl-py in /usr/local/lib/python3.6/dist-packages (from jaxlib->objax==1.0.2) (0.10.0)
Requirement already satisfied: opt-einsum in /usr/local/lib/python3.6/dist-packages (from jax->objax==1.0.2) (3.3.0)
Requirement already satisfied: setuptools>=41.0.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.3.0->objax==1.0.2) (50.3.0)
Requirement already satisfied: grpcio>=1.24.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.3.0->objax==1.0.2) (1.32.0)
Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.3.0->objax==1.0.2) (1.0.1)
Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.3.0->objax==1.0.2) (3.2.2)
Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.3.0->objax==1.0.2) (0.4.1)
Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.3.0->objax==1.0.2) (2.23.0)
Requirement already satisfied: wheel>=0.26; python_version >= "3" in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.3.0->objax==1.0.2) (0.35.1)
Requirement already satisfied: google-auth<2,>=1.6.3 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.3.0->objax==1.0.2) (1.17.2)
Requirement already satisfied: protobuf>=3.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.3.0->objax==1.0.2) (3.12.4)
Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.3.0->objax==1.0.2) (1.7.0)
Requirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from tensorboard>=2.3.0->objax==1.0.2) (1.15.0)
Requirement already satisfied: importlib-metadata; python_version < "3.8" in /usr/local/lib/python3.6/dist-packages (from markdown>=2.6.8->tensorboard>=2.3.0->objax==1.0.2) (2.0.0)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.3.0->objax==1.0.2) (1.3.0)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=2.3.0->objax==1.0.2) (2.10)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=2.3.0->objax==1.0.2) (3.0.4)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=2.3.0->objax==1.0.2) (2020.6.20)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3,>=2.21.0->tensorboard>=2.3.0->objax==1.0.2) (1.24.3)
Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard>=2.3.0->objax==1.0.2) (4.1.1)
Requirement already satisfied: rsa<5,>=3.1.4; python_version >= "3" in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard>=2.3.0->objax==1.0.2) (4.6)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.6/dist-packages (from google-auth<2,>=1.6.3->tensorboard>=2.3.0->objax==1.0.2) (0.2.8)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < "3.8"->markdown>=2.6.8->tensorboard>=2.3.0->objax==1.0.2) (3.2.0)
Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard>=2.3.0->objax==1.0.2) (3.1.0)
Requirement already satisfied: pyasn1>=0.1.3 in /usr/local/lib/python3.6/dist-packages (from rsa<5,>=3.1.4; python_version >= "3"->google-auth<2,>=1.6.3->tensorboard>=2.3.0->objax==1.0.2) (0.4.8)
Building wheels for collected packages: objax
  Building wheel for objax (setup.py) ... done
  Created wheel for objax: filename=objax-1.0.2-cp36-none-any.whl size=64182 sha256=83c24197ad3d62b42c8897a1c8c3bcb4d3c9684c159bfb500f67b5aac82c0753
  Stored in directory: /tmp/pip-ephem-wheel-cache-wy9fypuj/wheels/ff/75/37/8672991ae92977b2180136f69557c03ced92f87c74eff31761
Successfully built objax
Installing collected packages: parameterized, objax
Successfully installed objax-1.0.2 parameterized-0.7.4
Collecting git+https://github.com/deepmind/chex.git
  Cloning https://github.com/deepmind/chex.git to /tmp/pip-req-build-oui3l10s
  Running command git clone -q https://github.com/deepmind/chex.git /tmp/pip-req-build-oui3l10s
Requirement already satisfied: absl-py>=0.9.0 in /usr/local/lib/python3.6/dist-packages (from chex==0.0.2) (0.10.0)
Requirement already satisfied: jax>=0.1.55 in /usr/local/lib/python3.6/dist-packages (from chex==0.0.2) (0.1.75)
Requirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.6/dist-packages (from chex==0.0.2) (0.1.52)
Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.6/dist-packages (from chex==0.0.2) (1.18.5)
Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.6/dist-packages (from chex==0.0.2) (0.11.1)
Requirement already satisfied: dataclasses>=0.7 in /usr/local/lib/python3.6/dist-packages (from chex==0.0.2) (0.7)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from absl-py>=0.9.0->chex==0.0.2) (1.15.0)
Requirement already satisfied: opt-einsum in /usr/local/lib/python3.6/dist-packages (from jax>=0.1.55->chex==0.0.2) (3.3.0)
Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from jaxlib>=0.1.37->chex==0.0.2) (1.4.1)
Building wheels for collected packages: chex
  Building wheel for chex (setup.py) ... done
  Created wheel for chex: filename=chex-0.0.2-cp36-none-any.whl size=44002 sha256=30674f7c9349eb65b8a34cb8528347bafb3dd29c4fb3f70ae868aa71bd168926
  Stored in directory: /tmp/pip-ephem-wheel-cache-f79rv8h3/wheels/36/04/94/19c6b9a94d01685be55d12c1ada626f07828e0e4486dcb81ea
Successfully built chex
Installing collected packages: chex
Successfully installed chex-0.0.2
Collecting git+https://github.com/deepmind/dm-haiku
  Cloning https://github.com/deepmind/dm-haiku to /tmp/pip-req-build-h3p9xhwo
  Running command git clone -q https://github.com/deepmind/dm-haiku /tmp/pip-req-build-h3p9xhwo
Requirement already satisfied: absl-py>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from dm-haiku==0.0.2) (0.10.0)
Requirement already satisfied: numpy>=1.18.0 in /usr/local/lib/python3.6/dist-packages (from dm-haiku==0.0.2) (1.18.5)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from absl-py>=0.7.1->dm-haiku==0.0.2) (1.15.0)
Building wheels for collected packages: dm-haiku
  Building wheel for dm-haiku (setup.py) ... done
  Created wheel for dm-haiku: filename=dm_haiku-0.0.2-cp36-none-any.whl size=293301 sha256=9576abccac877acd3875444d00f07df6a8d9f61067f382b32805324c2abb0c3c
  Stored in directory: /tmp/pip-ephem-wheel-cache-xnh_0n7r/wheels/97/0f/e9/17f34e377f8d4060fa88a7e82bee5d8afbf7972384768a5499
Successfully built dm-haiku
Installing collected packages: dm-haiku
Successfully installed dm-haiku-0.0.2
Collecting git+https://github.com/Information-Fusion-Lab-Umass/NuX
  Cloning https://github.com/Information-Fusion-Lab-Umass/NuX to /tmp/pip-req-build-9cddfy42
  Running command git clone -q https://github.com/Information-Fusion-Lab-Umass/NuX /tmp/pip-req-build-9cddfy42
Requirement already satisfied: numpy>=1.12 in /usr/local/lib/python3.6/dist-packages (from nux==1.0.2) (1.18.5)
Requirement already satisfied: jax in /usr/local/lib/python3.6/dist-packages (from nux==1.0.2) (0.1.75)
Requirement already satisfied: jaxlib in /usr/local/lib/python3.6/dist-packages (from nux==1.0.2) (0.1.52)
Requirement already satisfied: opt-einsum in /usr/local/lib/python3.6/dist-packages (from jax->nux==1.0.2) (3.3.0)
Requirement already satisfied: absl-py in /usr/local/lib/python3.6/dist-packages (from jax->nux==1.0.2) (0.10.0)
Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from jaxlib->nux==1.0.2) (1.4.1)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from absl-py->jax->nux==1.0.2) (1.15.0)
Building wheels for collected packages: nux
  Building wheel for nux (setup.py) ... done
  Created wheel for nux: filename=nux-1.0.2-cp36-none-any.whl size=48811 sha256=aa131a1f61296fb04af02a29c9fa79a7a1e007b1165d9a7184eb7c48b0188a82
  Stored in directory: /tmp/pip-ephem-wheel-cache-pd8_7ba1/wheels/66/78/5c/a78e73a116ea5b926e1a923369e0de3ebcb26f731b37321a23
Successfully built nux
Installing collected packages: nux
Successfully installed nux-1.0.2
Collecting numpyro
  Cloning https://github.com/pyro-ppl/numpyro.git to /tmp/pip-install-sw3d2b2p/numpyro
  Running command git clone -q https://github.com/pyro-ppl/numpyro.git /tmp/pip-install-sw3d2b2p/numpyro
Collecting jax>=0.2
  Downloading https://files.pythonhosted.org/packages/e1/d9/9bd335976d3b61f705c2e9c35da2c6e030f9cd9ffd3e111feb99d8d169a7/jax-0.2.0.tar.gz (454kB)
     |████████████████████████████████| 460kB 2.8MB/s 
Collecting jaxlib>=0.1.55
  Downloading https://files.pythonhosted.org/packages/a0/e2/7e2c7e5b2b2b06c0868f8408f5ed016f8ee83540381cfe43d96bf1e8463b/jaxlib-0.1.55-cp36-none-manylinux2010_x86_64.whl (31.9MB)
     |████████████████████████████████| 31.9MB 142kB/s 
Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from numpyro) (4.41.1)
Requirement already satisfied: numpy>=1.12 in /usr/local/lib/python3.6/dist-packages (from jax>=0.2->numpyro) (1.18.5)
Requirement already satisfied: absl-py in /usr/local/lib/python3.6/dist-packages (from jax>=0.2->numpyro) (0.10.0)
Requirement already satisfied: opt_einsum in /usr/local/lib/python3.6/dist-packages (from jax>=0.2->numpyro) (3.3.0)
Requirement already satisfied: scipy in /usr/local/lib/python3.6/dist-packages (from jaxlib>=0.1.55->numpyro) (1.4.1)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from absl-py->jax>=0.2->numpyro) (1.15.0)
Building wheels for collected packages: numpyro, jax
  Building wheel for numpyro (setup.py) ... done
  Created wheel for numpyro: filename=numpyro-0.3.0-cp36-none-any.whl size=175165 sha256=d124ed6a83a31136cd17459e6bfa6fec7d50bcbc3a6f2fc5815ca76821f50e5f
  Stored in directory: /tmp/pip-ephem-wheel-cache-0gst43qz/wheels/d3/66/a6/667201f3fa85a83c93cc19efb1c1c5869a7b53f450a4031b52
  Building wheel for jax (setup.py) ... done
  Created wheel for jax: filename=jax-0.2.0-cp36-none-any.whl size=522279 sha256=5381086a9e5a03dc3d075de4a3fae5acf2aa73d2238be2d02116a2df59202135
  Stored in directory: /root/.cache/pip/wheels/99/f1/91/e9c21aca3142a6d2e5e760162fd65a1430438b7630a0b75591
Successfully built numpyro jax
Installing collected packages: jax, jaxlib, numpyro
  Found existing installation: jax 0.1.75
    Uninstalling jax-0.1.75:
      Successfully uninstalled jax-0.1.75
  Found existing installation: jaxlib 0.1.52
    Uninstalling jaxlib-0.1.52:
      Successfully uninstalled jaxlib-0.1.52
Successfully installed jax-0.2.0 jaxlib-0.1.55 numpyro-0.3.0
#@title Load Packages
# TYPE HINTS
from typing import Tuple, Optional, Dict, Callable, Union

# JAX SETTINGS
import jax
import jax.numpy as np
import jax.random as random
import objax

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
tfpk = tfp.math.psd_kernels

# NUMPY SETTINGS
import numpy as onp
onp.set_printoptions(precision=3, suppress=True)

# MATPLOTLIB Settings
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# SEABORN SETTINGS
import seaborn as sns
sns.set_context(context='talk',font_scale=0.7)

# PANDAS SETTINGS
import pandas as pd
pd.set_option("display.max_rows", 120)
pd.set_option("display.max_columns", 120)

# LOGGING SETTINGS
import sys
import logging
logging.basicConfig(
    level=logging.INFO, 
    stream=sys.stdout,
    format='%(asctime)s:%(levelname)s:%(message)s'
)
logger = logging.getLogger()
#logger.setLevel(logging.INFO)

%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
#@title Data
def get_data(
    N: int = 30,
    input_noise: float = 0.15,
    output_noise: float = 0.15,
    N_test: int = 400,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Generate a noisy 1D regression problem.

    Targets follow y = x + 0.2 x^3 + 0.5 (0.5 + x)^2 sin(4x), then get
    output noise added and are standardized; inputs get noise added after
    the clean targets are computed. Uses a fixed NumPy seed (0) so the
    data is reproducible.

    Parameters
    ----------
    N : int
        number of training points on [-1, 1]
    input_noise : float
        std of Gaussian noise added to the training inputs
    output_noise : float
        std of Gaussian noise added to the outputs (before standardization)
    N_test : int
        number of noise-free test points on [-1.2, 1.2]

    Returns
    -------
    X, Y, X_test : jax.numpy.ndarray
        column vectors of shape (N, 1), (N, 1) and (N_test, 1)
    """
    onp.random.seed(0)  # fixed seed for reproducibility
    X = np.linspace(-1, 1, N)
    Y = X + 0.2 * np.power(X, 3.0) + 0.5 * np.power(0.5 + X, 2.0) * np.sin(4.0 * X)
    Y += output_noise * onp.random.randn(N)
    # standardize the outputs
    Y -= np.mean(Y)
    Y /= np.std(Y)

    # corrupt the inputs only after the clean targets were computed
    X += input_noise * onp.random.randn(N)

    assert X.shape == (N,)
    assert Y.shape == (N,)

    X_test = np.linspace(-1.2, 1.2, N_test)

    # BUG FIX: the annotation previously promised a 4-tuple ending in None;
    # the function returns exactly three arrays.
    return X[:, None], Y[:, None], X_test[:, None]
X, y, Xtest = get_data(100, 0.0, 0.05, 100)
/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')

Kernel Functions

from functools import partial

def covariance_matrix(
    func: Callable,
    x: np.ndarray,
    y: np.ndarray,
) -> np.ndarray:
    """Evaluate a pairwise kernel/distance function on all point pairs.

    Applies `func` to every pair (x[i], y[j]) via two nested `jax.vmap`s,
    producing the Gram ("covariance") matrix.

    Parameters
    ----------
    func : Callable
        binary function accepting keyword arguments `x` and `y`
        (a kernel or distance between two single points)
    x : jax.numpy.ndarray
        first input dataset (n_samples_x, n_features)
    y : jax.numpy.ndarray
        second input dataset (n_samples_y, n_features)

    Returns
    -------
    mat : jax.numpy.ndarray
        matrix with mat[i, j] = func(x=x[i], y=y[j])

    Examples
    --------
    >>> covariance_matrix(partial(rbf_kernel, gamma=1.0), X, Y)
    """
    # outer vmap walks the rows of `x` (axis 0 of the result); the inner
    # vmap walks the rows of `y` (axis 1 of the result)
    row_fn = jax.vmap(lambda xi: jax.vmap(lambda yj: func(x=xi, y=yj))(y))
    return row_fn(x)


def rbf_kernel(gamma: float, x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Radial Basis Function (RBF) kernel value for a single pair of points.

    .. math::
        k(\\mathbf{x,y}) = \\exp\\left(-\\gamma\\,
            ||\\mathbf{x} - \\mathbf{y}||^2_2\\right)

    Parameters
    ----------
    gamma : float
        inverse bandwidth (larger gamma -> narrower kernel)
    x : jax.numpy.ndarray
        a single input point (n_features,)
    y : jax.numpy.ndarray
        another single input point (n_features,)

    Returns
    -------
    k : jax.numpy.ndarray
        scalar kernel value for the pair; build full kernel matrices by
        mapping this function over datasets with `covariance_matrix`

    References
    ----------
    .. [1] David Duvenaud, *Kernel Cookbook*
    """
    return np.exp(- gamma * sqeuclidean_distance(x, y))


def ard_kernel(x: np.ndarray, y: np.ndarray, length_scale, amplitude) -> np.ndarray:
    """Automatic Relevance Determination (ARD) kernel for a single pair.

    A squared-exponential kernel with a per-dimension length scale and an
    output amplitude:

    .. math::
        k(\\mathbf{x,y}) = \\sigma^2 \\exp\\left(-
            ||\\mathbf{x}/\\ell - \\mathbf{y}/\\ell||^2_2\\right)

    Parameters
    ----------
    x : jax.numpy.ndarray
        a single input point (n_features,)
    y : jax.numpy.ndarray
        another single input point (n_features,)
    length_scale : array-like
        per-dimension length scale(s) used to rescale the inputs
    amplitude : array-like
        output scale multiplying the exponential

    Returns
    -------
    k : jax.numpy.ndarray
        scalar kernel value for the pair; build full kernel matrices by
        mapping this function over datasets with `covariance_matrix`

    References
    ----------
    .. [1] David Duvenaud, *Kernel Cookbook*
    """
    # rescale each dimension by its length scale
    x = x / length_scale
    y = y / length_scale

    # return the ard kernel: amplitude * exp(-||x/l - y/l||^2)
    return amplitude * np.exp(-sqeuclidean_distance(x, y))

def sqeuclidean_distance(x: np.array, y: np.array) -> float:
    """Squared Euclidean distance ||x - y||^2 between two points."""
    diff = x - y
    return np.sum(diff * diff)


class RBFKernel(objax.Module):
    """RBF kernel module with a trainable bandwidth parameter `gamma`."""

    def __init__(self):
        # raw bandwidth parameter (no positivity transform applied here)
        self.gamma = objax.TrainVar(np.array([0.1]))

    def __call__(self, X: np.ndarray, Y: np.ndarray) -> np.ndarray:
        # bind the current gamma once, then evaluate all (X, Y) pairs
        gamma = self.gamma.value

        def pairwise(x, y):
            return rbf_kernel(gamma=gamma, x=x, y=y)

        return covariance_matrix(pairwise, X, Y).squeeze()

class ARDKernel(objax.Module):
    """ARD kernel module with trainable length scale and amplitude.

    Both parameters are stored unconstrained and passed through softplus
    at call time to keep them positive.
    """

    def __init__(self):
        self.length_scale = objax.TrainVar(np.array([0.1]))
        self.amplitude = objax.TrainVar(np.array([1.]))

    def __call__(self, X: np.ndarray, Y: np.ndarray) -> np.ndarray:
        # constrain the raw parameters to be positive, bind them once
        length_scale = jax.nn.softplus(self.length_scale.value)
        amplitude = jax.nn.softplus(self.amplitude.value)

        def pairwise(x, y):
            return ard_kernel(x=x, y=y, length_scale=length_scale, amplitude=amplitude)

        return covariance_matrix(pairwise, X, Y).squeeze()

class ZeroMean(objax.Module):
    """Constant zero mean function: m(X) = 0 for every sample."""

    def __init__(self):
        pass

    def __call__(self, X: np.ndarray) -> np.ndarray:
        # BUG FIX: return one zero per *sample* (X.shape[0]). The original
        # used X.shape[-1] (the feature dimension), which yields a length-1
        # vector for (N, 1) inputs and only worked downstream via
        # broadcasting (see the (100, 1) vs (1, 1) debug print in the
        # posterior section of this notebook).
        return np.zeros(X.shape[0], dtype=X.dtype)

class LinearMean(objax.Module):
    """Linear mean function: m(X) = X @ w + b."""

    def __init__(self, input_dim, output_dim):
        # weight matrix (input_dim, output_dim) and bias (output_dim,)
        self.w = objax.TrainVar(objax.random.normal((input_dim, output_dim)))
        self.b = objax.TrainVar(np.zeros(output_dim))

    def __call__(self, X: np.ndarray) -> np.ndarray:
        # BUG FIX: X has shape (n_samples, input_dim) per the conventions
        # used elsewhere in this file, so the product is X @ w, not X.T @ w
        # (the transpose made the inner dimensions mismatch).
        return np.dot(X, self.w.value) + self.b.value

class GaussianLikelihood(objax.Module):
    """Gaussian observation likelihood with a trainable noise parameter.

    NOTE(review): `__call__` currently returns a zero vector and never
    uses `self.noise` — this looks like a copy of `ZeroMean.__call__` /
    a placeholder. Confirm the intended behavior before relying on it;
    this class is not used elsewhere in the visible notebook.
    """

    def __init__(self):
        # unconstrained noise parameter; nothing in this class applies a
        # positivity transform to it
        self.noise = objax.TrainVar(np.array([0.1]))

    def __call__(self, X: np.ndarray) -> np.ndarray:
        # returns zeros of length X.shape[-1]; `self.noise` is unused here
        return np.zeros(X.shape[-1], dtype=X.dtype)

class ExactGP(objax.Module):
    """Exact Gaussian-process prior with a zero mean and ARD kernel.

    `forward` builds the multivariate-normal prior over the function
    values at the given inputs (including observation noise and jitter);
    `predict` and `sample` are unimplemented stubs.
    """

    def __init__(self, input_dim, output_dim, jitter):
        # NOTE(review): input_dim and output_dim are accepted but unused.

        # MEAN FUNCTION
        self.mean = ZeroMean()

        # KERNEL Function
        self.kernel = ARDKernel()

        # noise level (stored unconstrained; softplus is applied in `forward`)
        self.noise = objax.TrainVar(np.array([0.1]))

        # jitter (make it correctly conditioned)
        self.jitter = jitter

    def forward(self, X: np.ndarray) -> "tfd.MultivariateNormalTriL":
        """Return the GP prior distribution over f(X) plus noise.

        Note: the return annotation was corrected — this returns a TFP
        distribution object, not an array.
        """

        # mean function
        mu = self.mean(X)

        # kernel function
        cov = self.kernel(X, X)

        # noise model (softplus keeps the added variance positive)
        cov += jax.nn.softplus(self.noise.value) * np.eye(X.shape[0])

        # jitter, for numerical stability of the Cholesky factorization
        cov += self.jitter * np.eye(X.shape[0])

        # calculate cholesky
        cov_chol = np.linalg.cholesky(cov)

        # gaussian process likelihood
        return tfd.MultivariateNormalTriL(loc=mu, scale_tril=cov_chol)

    def predict(self, X: np.ndarray) -> np.ndarray:
        # TODO: posterior prediction is not implemented — stub.
        pass

    def sample(self, n_samples: int, key: None) -> np.ndarray:
        # TODO: sampling is not implemented — stub. NOTE(review): `key` is
        # presumably a JAX PRNGKey; the `None` annotation looks wrong.
        pass

gp_model = ExactGP(X.shape[0], 1, 1e-5)

dist = gp_model.forward(X)
gp_model.vars()
{'(ExactGP).kernel(ARDKernel).amplitude': <objax.variable.TrainVar at 0x7f9f94237f28>,
 '(ExactGP).kernel(ARDKernel).length_scale': <objax.variable.TrainVar at 0x7f9f94237e10>,
 '(ExactGP).noise': <objax.variable.TrainVar at 0x7f9f94237f98>}
plt.imshow(dist.covariance())
<matplotlib.image.AxesImage at 0x7f9f8ff3a4a8>
key = random.PRNGKey(0)
samples = dist.sample(10, key)

plt.plot(samples.T)
# Settings
lr = 0.01  # learning rate
batch = 256  # NOTE(review): unused below — full-batch training is done
epochs = 50
gp_model = ExactGP(X.shape[0], 1, 1e-5)

def loss(X, label):
    """Negative mean log-likelihood of `label` under the GP prior at X."""
    dist = gp_model.forward(X)
    return - dist.log_prob(label).mean()
opt = objax.optimizer.SGD(gp_model.vars())

gv = objax.GradValues(loss, gp_model.vars())

def train_op(x, label):
    """Run one SGD step and return the loss value(s) from GradValues."""
    g, v = gv(x, label)  # returns gradients, loss
    opt(lr, g)
    return v

# This line is optional: it is compiling the code to make it faster.
train_op = objax.Jit(train_op, gv.vars() + opt.vars())
losses = []
for epoch in range(epochs):

    # BUG FIX: the loop previously rebound the name `loss`, shadowing the
    # loss *function* defined above; any later cell calling `loss(...)`
    # would then fail. Use a distinct name for the per-epoch value.
    loss_value = train_op(X, y.squeeze())

    losses.append(loss_value)
gp_model.noise.value, jax.nn.softplus(gp_model.noise.value)
(DeviceArray([-5.328], dtype=float32), DeviceArray([0.005], dtype=float32))
plt.plot(losses)
[<matplotlib.lines.Line2D at 0x7f9fb6288048>]

Posterior

from typing import Tuple, Optional, Callable

def cholesky_factorization(K: np.ndarray, Y: np.ndarray) -> Tuple[np.ndarray, bool]:
    """Solve K w = Y through a Cholesky factorization.

    Returns the `cho_factor` output (the factor plus its lower-triangular
    flag) so callers can reuse it with `cho_solve`, together with the
    solution weights.
    """
    # factor K = L L^T (lower triangular)
    factorization = jax.scipy.linalg.cho_factor(K, lower=True)

    # back-substitute to obtain the weights solving K w = Y
    solution = jax.scipy.linalg.cho_solve(factorization, Y)

    return factorization, solution

def get_factorizations(
    X: np.ndarray,
    Y: np.ndarray,
    likelihood_noise: float,
    mean_f: Callable,
    kernel: Callable,
) -> Tuple[Tuple[np.ndarray, bool], np.ndarray]:
    """Factorize the training covariance and solve for the GP weights.

    Parameters
    ----------
    X : jax.numpy.ndarray
        training inputs (n_samples, n_features)
    Y : jax.numpy.ndarray
        training targets; reshaped to a column vector internally
    likelihood_noise : float
        observation-noise variance added to the kernel diagonal
    mean_f : Callable
        mean function evaluated at X
    kernel : Callable
        kernel function producing the (n_samples, n_samples) Gram matrix

    Returns
    -------
    L : tuple
        `cho_factor` output for (Kxx + noise * I), reusable with `cho_solve`
    alpha : jax.numpy.ndarray
        representer weights solving (Kxx + noise * I) alpha = Y - mu_x
    """

    # ==========================
    # 1. GP PRIOR
    # ==========================
    mu_x = mean_f(X)
    Kxx = kernel(X, X)

    # ===========================
    # 2. CHOLESKY FACTORIZATION
    # ===========================
    # BUG FIX: removed leftover debug print() statements that spammed the
    # mean vector and shapes on every call.
    L, alpha = cholesky_factorization(
        Kxx + likelihood_noise * np.eye(Kxx.shape[0]),
        Y.reshape(-1, 1) - mu_x.reshape(-1, 1),
    )

    return L, alpha

def posterior(
    Xnew, X, y,
    likelihood_noise,
    mean_f,
    kernel
    ):
    """GP posterior predictive mean and covariance at `Xnew`.

    BUG FIX: the cross- and test-covariance evaluations previously went
    through the global `gp_model.kernel`, silently ignoring the `kernel`
    argument; they now use the argument, so the function works with any
    kernel passed in.

    NOTE(review): mean_f(Xnew) is not added to the predictive mean — fine
    for a zero mean function, revisit if non-zero means are used.
    """
    # factorize the training covariance and solve for the weights
    L, alpha = get_factorizations(
        X, y,
        likelihood_noise,
        mean_f,
        kernel
    )

    # cross-covariance between test and training inputs
    K_Xx = kernel(Xnew, X)

    # Calculate the Mean
    mu_y = np.dot(K_Xx, alpha)

    # =====================================
    # 5. PREDICTIVE COVARIANCE DISTRIBUTION
    # =====================================
    v = jax.scipy.linalg.cho_solve(L, K_Xx.T)

    # Calculate kernel matrix for the test inputs
    K_xx = kernel(Xnew, Xnew)

    cov_y = K_xx - np.dot(K_Xx, v)
    return mu_y, cov_y
mu, cov = posterior(
    X, X, y.squeeze(),
    jax.nn.softplus(gp_model.noise.value),
    gp_model.mean,
    gp_model.kernel
)
[0.]
(100, 1) (1, 1)
DeviceArray([[-1.585],
             [-1.575],
             [-1.562],
             [-1.546],
             [-1.527]], dtype=float32)
(1.96 * np.sqrt(np.diag(cov))).shape, mu.shape
((100,), (100, 1))
plt.plot(X, mu)
plt.plot(X, mu.squeeze() + 1.96 * np.sqrt(np.diag(cov)  + jax.nn.softplus(gp_model.noise.value)))
plt.plot(X, mu.squeeze() - 1.96 * np.sqrt(np.diag(cov)  + jax.nn.softplus(gp_model.noise.value)))
plt.show()
dist
loss(dist, y.squeeze())
DeviceArray(-0.267, dtype=float32)
#@title Distribution Data
from scipy.stats import beta

# Beta(3, 10) source distribution for the 1D density samples
a, b = 3.0, 10.0
data_dist = beta(a, b)


# draw 1,000 samples with a fixed random state for reproducibility
x_samples = data_dist.rvs(size=1_000, random_state=123)

plt.hist(x_samples, bins=100);