Here are my favourite JAX packages.
Tutorials¶
Advanced Scientific Machine Learning. This is a set of application-based tutorials. I think there is value in learning JAX from an application-based standpoint, and I personally like the scientific machine learning perspective. Some highlights include tutorials on functional programming, differentiable programming, and optimization.
Bayesian Modelling and Probabilistic Programming with Numpyro, and Deep Generative Surrogates for Epidemiology. This is easily the best course for learning probabilistic programming through the lens of JAX and numpyro. I find it very useful when thinking probabilistically. Some highlights include the Bayesian workflow, Gaussian processes, and even ODEs!
JAX 101. This set of tutorials is very exhaustive in its coverage of JAX features. The tutorials are very neural-network focused, but they cover most of the reasons someone would want to use JAX. Some highlights include JAX essentials like JIT compilation, automatic vectorization, and automatic differentiation. There are also some more interesting tutorials on PyTrees, stateful computations, and even parallel computing.
Autodidax: JAX from Scratch. This is more for hardcore devs who are very interested in understanding the internals of JAX. It takes you step by step through some of the inner workings in an interesting way. I think it's worth a look just for awareness. However, unless you plan to develop your own packages, it may not be necessary.
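For a taste of the JAX essentials covered in JAX 101, here is a minimal sketch that combines automatic differentiation, automatic vectorization, and JIT compilation. The `loss` function and the toy data are made up for illustration and do not come from any of the tutorials.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    """Squared error of a linear model on a single (x, y) example (illustrative)."""
    return (jnp.dot(w, x) - y) ** 2


# Automatic differentiation: gradient of loss with respect to its first argument, w.
grad_loss = jax.grad(loss)

# Automatic vectorization: map over a batch of (x, y) pairs while sharing w.
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# JIT compilation: trace the function once and run the compiled version.
fast_batched_grad = jax.jit(batched_grad)

# Toy data, assumed purely for this example.
w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([1.0, 0.0, 3.0])

# One gradient per example, so the result has shape (3, 2).
per_example_grads = fast_batched_grad(w, xs, ys)
print(per_example_grads)
```

The three transformations compose freely, which is the main design idea the tutorials build on: write the function for a single example, then let `jax.grad`, `jax.vmap`, and `jax.jit` handle derivatives, batching, and compilation.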
Special Data Structures¶
Symbolic Math¶
sympy2jax
Numerical Methods¶
Linear Algebra¶
CoLA, Lab
matfree - Matrix-Free linear algebra in JAX
opt-einsum - optimized einsum (numpy, JAX, TF, PyTorch, Dask, CuPy, Sparse)
Convolutions¶
Integration¶
Interpolation¶
Optimization¶
Optimistix
LineaX
Optax
JAXopt
varz - Simple, multi-backend constrained (L-BFGS) and unconstrained optimization (Adam).
Kernels¶
mlkernels - Kernel Matrices (JAX, TF, PyTorch, Julia).
Differentiation¶
FiniteDiffX, FiniteVolX, SpectralDiffX
ODE Solvers¶
Diffrax
probdiffeq - probabilistic solvers for differential equations
Ordinary Differential Equations¶
Partial Differential Equations¶
Basis Functions¶
PCA/SVD/POD/EOF¶
Fourier¶
Orthogonal¶
Wavelet¶
Spherical Harmonics¶
SphericalHarmonics - spherical harmonics (numpy, JAX, PyTorch, TF)
Neural Networks¶
Probabilistic¶
Probabilistic Programming Languages¶
blackjax
numpyro
tfp.substrate.jax