Monte Carlo¶
NumPyro¶
Probabilistic programming with NumPy, powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
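The helpers below assume a NumPyro model with sample sites named X (inputs) and Y (observations). As a minimal sketch of such a model, here is an assumed Gaussian linear regression; the site names coefs and sigma are illustrative, not from the original notes:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(X, Y=None):
    # priors over the regression weights and the observation noise
    coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(X.shape[-1]), 1.0))
    sigma = numpyro.sample("sigma", dist.Exponential(1.0))
    mean = jnp.dot(X, coefs)
    # likelihood; passing Y=None turns the site into a predictive draw
    numpyro.sample("Y", dist.Normal(mean, sigma), obs=Y)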
Predict¶
from jax import random
from numpyro import handlers

def predict(model, rng_key, samples, X):
    # condition the model on one set of posterior samples and a fresh RNG key
    model = handlers.substitute(handlers.seed(model, rng_key), samples)
    # trace the model forward with Y unobserved so the Y site gets sampled
    model_trace = handlers.trace(model).get_trace(X=X, Y=None)
    return model_trace["Y"]["value"]
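A single call returns one predictive draw. To build a full posterior predictive, this helper is usually vmapped over the posterior samples and a batch of RNG keys; a sketch, assuming samples comes from mcmc.get_samples() and that n_samples, rng_key, model, and X are already defined:
from jax import random, vmap

# one RNG key per posterior sample; vmap maps over the leading axis of both
keys = random.split(rng_key, n_samples)
predictions = vmap(lambda sample, key: predict(model, key, sample, X))(samples, keys)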
Sample¶
from jax import random
from numpyro.infer import MCMC, NUTS

def sample(
    model,
    n_samples: int,
    n_warmup: int,
    n_chains: int,
    seed: int,
    chain_method: str = "parallel",
    summary: bool = True,
    **kwargs,
):
    # generate a reproducible random key
    rng_key = random.PRNGKey(seed)
    # build a NUTS kernel around the model
    kernel = NUTS(model)
    # configure and run MCMC; kwargs are forwarded to the model
    mcmc = MCMC(
        kernel,
        num_warmup=n_warmup,
        num_samples=n_samples,
        num_chains=n_chains,
        chain_method=chain_method,
    )
    mcmc.run(rng_key, **kwargs)
    if summary:
        mcmc.print_summary()
    return mcmc
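A usage sketch tying sample and predict together, assuming the example model above; the synthetic data and all shapes here are made up:
import jax.numpy as jnp
from jax import random

# purely illustrative synthetic data
X = random.normal(random.PRNGKey(1), (100, 3))
Y = jnp.dot(X, jnp.array([1.0, -2.0, 0.5])) + 0.1 * random.normal(random.PRNGKey(2), (100,))

mcmc = sample(model, n_samples=1000, n_warmup=500, n_chains=1, seed=0, X=X, Y=Y)
posterior = mcmc.get_samples()

# one posterior-predictive draw, conditioned on the first posterior sample
first_sample = {name: values[0] for name, values in posterior.items()}
y_pred = predict(model, random.PRNGKey(3), first_sample, X)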
MCX¶
A library that compiles probabilistic programs for performant inference on CPU and GPU.
import jax
from jax import numpy as np
import mcx
import mcx.distributions as dist
x_data = np.array([2.3, 8.2, 1.8])
y_data = np.array([1.7, 7., 3.1])
@mcx.model
def linear_regression(x, lmbda=1.):
    scale @ dist.Exponential(lmbda)
    coefs @ dist.Normal(np.zeros(np.shape(x)[-1]))
    y = np.dot(x, coefs)
    predictions @ dist.Normal(y, scale)
    return predictions
rng_key = jax.random.PRNGKey(0)
# Sample the model forward, conditioning on the value of `x`
mcx.sample_forward(
    rng_key,
    linear_regression,
    x=x_data,
    num_samples=10_000
)
# Sample from the posterior distribution using HMC
kernel = mcx.HMC(
    step_size=0.01,
    num_integration_steps=100,
    inverse_mass_matrix=np.array([1., 1.]),
)
observations = {'x': x_data, 'predictions': y_data, 'lmbda': 3.}
sampler = mcx.sample(
    rng_key,
    linear_regression,
    kernel,
    **observations
)
trace = sampler.run()