Here are my favourite JAX packages.
Tutorials
Special Data Structures
- coordinax
- jaxdf
- tree-math
- xarray_jax
- quax
- jaxtyping
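Several of the packages above (tree-math, for instance) work with JAX pytrees, which let a custom container pass through `jax.jit` and `jax.tree_util` unchanged. A minimal sketch of registering such a container with `jax.tree_util.register_pytree_node_class`, assuming a hypothetical `Polar` class that does not come from any listed package:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Polar:
    """Hypothetical container: a radius/angle pair stored as JAX arrays."""

    def __init__(self, r, theta):
        self.r = r
        self.theta = theta

    def tree_flatten(self):
        # Children are the array leaves; there is no static auxiliary data.
        return (self.r, self.theta), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild the container from its leaves.
        return cls(*children)


@jax.jit
def to_cartesian(p: Polar) -> jax.Array:
    # jit traces through the pytree and sees two array leaves.
    return jnp.stack([p.r * jnp.cos(p.theta), p.r * jnp.sin(p.theta)])


p = Polar(jnp.array(2.0), jnp.array(0.0))
xy = to_cartesian(p)  # Array([2., 0.])

# tree_map applies a function leaf by leaf and rebuilds a Polar.
doubled = jax.tree_util.tree_map(lambda leaf: 2 * leaf, p)
```

Because `Polar` is a registered pytree, the same object can also be differentiated with `jax.grad` or batched with `jax.vmap` without further changes.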
Symbolic Math
- sympy2jax
Numerical Methods
Linear Algebra
- CoLA, Lab
- autoray
- einx
- einops
- matfree - matrix-free linear algebra in JAX
- opt-einsum - optimized einsum (NumPy, JAX, TF, PyTorch, Dask, CuPy, Sparse)
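Matrix-free linear algebra, as in matfree, means a solver only needs a function computing v ↦ Av, never the dense matrix. JAX's own `jax.scipy.sparse.linalg.cg` already accepts such a function. A minimal plain-JAX sketch (not matfree's API), assuming a hypothetical operator A = 2I + L with L the 1-D discrete Laplacian:

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

n = 50


def matvec(v):
    # Applies A = 2*I + L, where L is the 1-D discrete Laplacian
    # (2 on the diagonal, -1 off the diagonal, zero boundary), without forming A.
    left = jnp.concatenate([jnp.zeros(1), v[:-1]])
    right = jnp.concatenate([v[1:], jnp.zeros(1)])
    return 2.0 * v + (2.0 * v - left - right)


b = jnp.ones(n)
# A is symmetric positive definite, which conjugate gradients requires.
x, _ = cg(matvec, b, tol=1e-6, maxiter=200)
residual = jnp.linalg.norm(matvec(x) - b)

# Sanity check for small n only: stack A @ e_i for every basis vector
# (this yields A transposed, which equals A because A is symmetric).
A_dense = jax.vmap(matvec)(jnp.eye(n))
x_dense = jnp.linalg.solve(A_dense, b)
```

The dense check is only there to validate the sketch; the point of the matrix-free approach is that `matvec` scales to operators far too large to materialize.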
Convolutions
Integration
Interpolation
Optimization
- Optimistix
- Lineax
- Optax
- JAXopt
- varz - simple, multi-backend constrained (L-BFGS) and unconstrained (Adam) optimization
- OTT-JAX
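All of these optimizers rely on gradients that JAX computes automatically. A minimal sketch using only JAX itself, namely `jax.scipy.optimize.minimize` with `method="BFGS"` on a hypothetical least-squares problem; Optimistix, Optax, and JAXopt offer richer solver families, and their APIs differ from this one:

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

# Hypothetical toy data, chosen so that y = X @ [1, 2] exactly.
X = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([1.0, 2.0, 3.0])


def loss(w):
    # Sum of squared residuals; minimize gets its gradient via JAX autodiff.
    r = X @ w - y
    return jnp.sum(r**2)


result = minimize(loss, jnp.zeros(2), method="BFGS")
w_hat = result.x  # close to [1., 2.]
```

No hand-written gradient is needed: `loss` is ordinary JAX code, so the solver can differentiate it directly.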
Kernels
- mlkernels - kernel matrices (JAX, TF, PyTorch, Julia)
- KernelBiome
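Kernel libraries such as mlkernels build Gram matrices K[i, j] = k(x_i, x_j). A minimal plain-JAX sketch for the squared-exponential (RBF) kernel using nested `jax.vmap`; the function names are hypothetical and this is not any listed package's API:

```python
import jax
import jax.numpy as jnp


def rbf(x, y, lengthscale=1.0):
    # Squared-exponential (RBF) kernel between two feature vectors.
    return jnp.exp(-jnp.sum((x - y) ** 2) / (2.0 * lengthscale**2))


def kernel_matrix(xs, ys, lengthscale=1.0):
    # Outer vmap runs over rows of xs, inner vmap over rows of ys,
    # giving K[i, j] = rbf(xs[i], ys[j]).
    row = lambda x: jax.vmap(lambda y: rbf(x, y, lengthscale))(ys)
    return jax.vmap(row)(xs)


xs = jnp.array([[0.0], [1.0], [2.0]])
K = kernel_matrix(xs, xs)  # shape (3, 3), ones on the diagonal
```

Writing the kernel for a single pair of points and letting `vmap` handle batching keeps the kernel definition simple and still compiles to vectorized code.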
Differentiation
- FiniteDiffX, FiniteVolX, SpectralDiffX
- jax-fem
- probfindiff
- LapJax
- RBF-FDax
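Packages such as FiniteDiffX and probfindiff build finite-difference approximations on grids, whereas JAX's autodiff differentiates the program itself without truncation error. A plain-JAX sketch comparing a second-order central-difference stencil with `jax.grad` on sin(x); it does not use either package's API:

```python
import jax
import jax.numpy as jnp

dx = 0.01
x = jnp.arange(0.0, 2.0 * jnp.pi, dx)
f = jnp.sin(x)

# Second-order central difference on interior points: (f[i+1] - f[i-1]) / (2 dx).
df_fd = (f[2:] - f[:-2]) / (2.0 * dx)

# Derivative at the same interior points via automatic differentiation.
df_ad = jax.vmap(jax.grad(jnp.sin))(x[1:-1])

# Truncation error of the stencil is O(dx**2); float32 rounding adds a little more.
max_err = jnp.max(jnp.abs(df_fd - df_ad))
```

Finite differences remain useful when only sampled data is available, which is the setting the grid-based packages target.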
ODE Solvers
- Diffrax
- probdiffeq - probabilistic solvers for differential equations
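Diffrax provides adaptive-step, differentiable ODE solvers. To show what such a solver does in its simplest form, here is a fixed-step classical RK4 integrator written with `jax.lax.scan`; it is a plain-JAX sketch, not Diffrax's API, and it has no step-size control. The test problem dy/dt = -y is a hypothetical example:

```python
import jax
import jax.numpy as jnp


def rk4_step(f, y, t, dt):
    # One classical fourth-order Runge-Kutta step for dy/dt = f(t, y).
    k1 = f(t, y)
    k2 = f(t + dt / 2, y + dt / 2 * k1)
    k3 = f(t + dt / 2, y + dt / 2 * k2)
    k4 = f(t + dt, y + dt * k3)
    return y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def solve(f, y0, t0, t1, num_steps):
    # Fixed-step integration; lax.scan keeps the loop compilable and differentiable.
    dt = (t1 - t0) / num_steps

    def body(y, i):
        t = t0 + i * dt
        y_next = rk4_step(f, y, t, dt)
        return y_next, y_next

    _, ys = jax.lax.scan(body, y0, jnp.arange(num_steps))
    return ys  # states at t0 + dt, ..., t1


decay = lambda t, y: -y  # exact solution y(t) = y0 * exp(-t)
ys = solve(decay, jnp.array(1.0), 0.0, 1.0, 100)
```

Because the integrator is ordinary JAX code, `jax.grad` can differentiate the final state with respect to `y0` directly through the solver loop.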
Ordinary Differential Equations
Partial Differential Equations
Basis Functions
PCA/SVD/POD/EOF
Fourier
Orthogonal
Wavelet
Spherical Harmonics
- SphericalHarmonics - spherical harmonics (NumPy, JAX, PyTorch, TF)
- JAX implementation
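In the zonal case (order m = 0), spherical harmonics reduce to Legendre polynomials, Y_l^0(θ, φ) = sqrt((2l + 1) / (4π)) P_l(cos θ). A hedged plain-JAX sketch of only that special case, using Bonnet's recurrence (n + 1) P_{n+1}(x) = (2n + 1) x P_n(x) - n P_{n-1}(x); general orders m ≠ 0 need associated Legendre functions and are not covered. The function names are hypothetical:

```python
import jax
import jax.numpy as jnp


def legendre(l: int, x):
    # Bonnet recurrence starting from P_0 = 1 and P_1 = x; l is a static Python int.
    p_prev, p = jnp.ones_like(x), x
    if l == 0:
        return p_prev
    for n in range(1, l):
        p_prev, p = p, ((2 * n + 1) * x * p - n * p_prev) / (n + 1)
    return p


def zonal_harmonic(l: int, theta):
    # Real, orthonormal Y_l^0(theta); it does not depend on the azimuth phi.
    return jnp.sqrt((2 * l + 1) / (4 * jnp.pi)) * legendre(l, jnp.cos(theta))


theta = jnp.linspace(0.0, jnp.pi, 5)
# l must be static because it controls the Python loop length.
y2 = jax.jit(zonal_harmonic, static_argnums=0)(2, theta)
```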
Neural Networks
Probabilistic
Probabilistic Programming Languages
- BlackJAX
- numpyro
- numpyro-ext
- tfp.substrates.jax
- bayeux
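BlackJAX and NumPyro provide MCMC samplers built on JAX's functional PRNG and `lax` control flow. A minimal random-walk Metropolis sketch, assuming a hypothetical target N(2, 1) and a fixed step size; it is not the API of any listed package and omits tuning and convergence diagnostics:

```python
import jax
import jax.numpy as jnp


def log_density(x):
    # Unnormalized log-density of a normal distribution with mean 2, std 1.
    return -0.5 * (x - 2.0) ** 2


def rw_metropolis(key, x0, num_samples, step_size=1.0):
    def step(x, key):
        key_prop, key_accept = jax.random.split(key)
        proposal = x + step_size * jax.random.normal(key_prop)
        # Accept with probability min(1, p(proposal) / p(x)), computed in log space.
        log_ratio = log_density(proposal) - log_density(x)
        accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
        x_next = jnp.where(accept, proposal, x)
        return x_next, x_next

    # One independent key per step, consumed by lax.scan.
    keys = jax.random.split(key, num_samples)
    _, samples = jax.lax.scan(step, x0, keys)
    return samples


samples = rw_metropolis(jax.random.PRNGKey(0), jnp.array(0.0), 20_000)
estimate = samples[1_000:].mean()  # discard burn-in; should be near 2
```

Explicit key splitting makes the chain reproducible for a given seed, and `lax.scan` lets the whole sampler compile as a single XLA loop.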