# Awesome JAX
JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators like GPUs and TPUs.
This is a curated list of awesome JAX libraries, projects, and other resources. Contributions are welcome!
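JAX's core workflow is easiest to see in code. Below is a minimal sketch (all names, shapes, and constants are illustrative): `jax.grad` provides the automatic differentiation and `jax.jit` triggers XLA compilation, both applied to plain NumPy-style code.

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    # A linear model written against the NumPy-like jax.numpy API.
    w, b = params
    return jnp.dot(x, w) + b

def loss(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

# grad turns the loss into a function returning its gradient w.r.t. params;
# jit compiles the result with XLA for CPU/GPU/TPU.
grad_fn = jax.jit(jax.grad(loss))

params = (jnp.ones(3), jnp.float32(0.0))
x, y = jnp.ones((8, 3)), jnp.zeros(8)
print(grad_fn(params, x, y))
```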
## Contents

- [Libraries](#libraries)
- [New Libraries](#new-libraries)
- [Models and Projects](#models-and-projects)
- [Videos](#videos)
- [Papers](#papers)
- [Tutorials and Blog Posts](#tutorials-and-blog-posts)
- [Contributing](#contributing)
## Libraries
- Neural Network Libraries
    - Flax - Centered on flexibility and clarity.
    - Haiku - Focused on simplicity, created by the authors of Sonnet at DeepMind.
    - Objax - Has an object oriented design similar to PyTorch.
    - Elegy - A High Level API for Deep Learning in JAX. Supports Flax, Haiku, and Optax.
    - Trax - "Batteries included" deep learning library focused on providing solutions for common workloads.
    - Jraph - Lightweight graph neural network library.
    - Neural Tangents - High-level API for specifying neural networks of both finite and infinite width.
    - HuggingFace - Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax).
    - Equinox - Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX.
- NumPyro - Probabilistic programming based on the Pyro library.
- Chex - Utilities to write and test reliable JAX code.
- Optax - Gradient processing and optimization library (a minimal usage sketch appears after this list).
- RLax - Library for implementing reinforcement learning agents.
- JAX, M.D. - Accelerated, differentiable molecular dynamics.
- Coax - Turn RL papers into code, the easy way.
- Distrax - Reimplementation of TensorFlow Probability, containing probability distributions and bijectors.
- cvxpylayers - Construct differentiable convex optimization layers.
- TensorLy - Tensor learning made simple.
- NetKet - Machine Learning toolbox for Quantum Physics.
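As referenced in the Optax entry above, here is a minimal sketch of a typical Optax training step (the loss, data, and hyperparameters are illustrative; the `adam`/`init`/`update`/`apply_updates` pattern follows Optax's documented API):

```python
import jax
import jax.numpy as jnp
import optax

def loss(params, x, y):
    return jnp.mean((x @ params["w"] + params["b"] - y) ** 2)

params = {"w": jnp.zeros(3), "b": jnp.zeros(())}
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

@jax.jit
def train_step(params, opt_state, x, y):
    grads = jax.grad(loss)(params, x, y)
    # Optax transforms raw gradients into parameter updates...
    updates, opt_state = optimizer.update(grads, opt_state, params)
    # ...which are applied with a generic pytree helper.
    return optax.apply_updates(params, updates), opt_state

x, y = jnp.ones((8, 3)), jnp.zeros(8)
params, opt_state = train_step(params, opt_state, x, y)
```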
## New Libraries
This section contains libraries that are well-made and useful, but have not necessarily been battle-tested by a large userbase yet.
- Neural Network Libraries
    - FedJAX - Federated learning in JAX, built on Optax and Haiku.
    - Equivariant MLP - Construct equivariant neural network layers.
    - jax-resnet - Implementations and checkpoints for ResNet variants in Flax.
    - Parallax - Immutable Torch Modules for JAX.
- jax-unirep - Library implementing the UniRep model for protein machine learning applications.
- jax-flows - Normalizing flows in JAX.
- sklearn-jax-kernels - scikit-learn kernel matrices using JAX.
- jax-cosmo - Differentiable cosmology library.
- efax - Exponential Families in JAX.
- mpi4jax - Combine MPI operations with your JAX code on CPUs and GPUs.
- imax - Image augmentations and transformations.
- FlaxVision - Flax version of TorchVision.
- Oryx - Probabilistic programming language based on program transformations.
- Optimal Transport Tools - Toolbox that bundles utilities to solve optimal transport problems.
- delta PV - A photovoltaic simulator with automatic differentiation.
- jaxlie - Lie theory library for rigid body transformations and optimization.
- BRAX - Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments.
- flaxmodels - Pretrained models for JAX/Flax.
- CR.Sparse - XLA-accelerated algorithms for sparse representations and compressive sensing.
- exojax - Automatic differentiable spectrum modeling of exoplanets/brown dwarfs compatible with JAX.
- JAXopt - Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.
- PIX - Image processing library in JAX, for JAX.
- bayex - Bayesian Optimization powered by JAX.
- JaxDF - Framework for differentiable simulators with arbitrary discretizations.
- tree-math - Convert functions that operate on arrays into functions that operate on PyTrees (a small sketch appears after this list).
- jax-models - Implementations of research papers originally without code or code written with frameworks other than JAX.
- PGMax - Framework for building discrete Probabilistic Graphical Models (PGMs) and running inference on them via JAX.
- EvoJAX - Hardware-Accelerated Neuroevolution.
- evosax - JAX-Based Evolution Strategies.
- SymJAX - Symbolic CPU/GPU/TPU programming.
- mcx - Express & compile probabilistic programs for performant inference.
- Einshape - DSL-based reshaping library for JAX and other frameworks.
- ALX - Open-source library for distributed matrix factorization using Alternating Least Squares; more info in "ALX: Large Scale Matrix Factorization on TPUs".
- Diffrax - Numerical differential equation solvers in JAX.
- tinygp - The tiniest of Gaussian process libraries in JAX.
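As referenced in the tree-math entry above, the idea is to let whole PyTrees participate in ordinary arithmetic. A sketch, assuming tree-math's `Vector` wrapper and `.tree` accessor as described in its README (treat the exact names as assumptions):

```python
import jax
import jax.numpy as jnp
import tree_math as tm  # assumption: the tree-math package

params = {"w": jnp.ones(3), "b": jnp.zeros(())}
grads = {"w": jnp.full((3,), 0.1), "b": jnp.array(0.5)}

# Plain JAX: every arithmetic step over a pytree needs an explicit tree_map.
sgd_plain = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

# tree-math: wrap the pytrees once and write ordinary vector math.
sgd_tm = (tm.Vector(params) - 0.01 * tm.Vector(grads)).tree
```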
## Models and Projects

### JAX

### Flax

### Haiku

### Trax

- Reformer - Implementation of the Reformer (efficient transformer) architecture.
## Videos
- NeurIPS 2020: JAX Ecosystem Meetup - JAX, its use at DeepMind, and discussion between engineers, scientists, and JAX core team.
- Introduction to JAX - Simple neural network from scratch in JAX (a sketch in the same spirit appears after this list).
- JAX: Accelerated Machine Learning Research | SciPy 2020 | VanderPlas - JAX's core design, how it's powering new research, and how you can start using it.
- Bayesian Programming with JAX + NumPyro — Andy Kitchen - Introduction to Bayesian modelling using NumPyro.
- JAX: Accelerated machine-learning research via composable function transformations in Python | NeurIPS 2019 | Skye Wanderman-Milne - JAX intro presentation in Program Transformations for Machine Learning workshop.
- JAX on Cloud TPUs | NeurIPS 2020 | Skye Wanderman-Milne and James Bradbury - Presentation of TPU host access with demo.
- Deep Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond | NeurIPS 2020 - Tutorial created by Zico Kolter, David Duvenaud, and Matt Johnson with Colab notebooks available in Deep Implicit Layers.
- Solving y=mx+b with Jax on a TPU Pod slice - Mat Kelcey - A four-part YouTube tutorial series with Colab notebooks that starts with JAX fundamentals and moves up to training with a data parallel approach on a v3-32 TPU Pod slice.
- JAX, Flax & Transformers 🤗 - 3 days of talks around JAX / Flax, Transformers, large-scale language modeling and other great topics.
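The "Introduction to JAX" entry above builds a neural network from scratch; the following is a minimal sketch in the same spirit (not the video's notebook, and all sizes and constants are illustrative):

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes=(2, 16, 1)):
    # One (weights, biases) pair per layer, initialized with JAX's PRNG.
    keys = jax.random.split(key, len(sizes) - 1)
    return [(0.1 * jax.random.normal(k, (m, n)), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

def forward(params, x):
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b

def loss(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)

@jax.jit
def step(params, x, y, lr=0.1):
    # Gradient descent on the pytree of parameters.
    grads = jax.grad(loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

key = jax.random.PRNGKey(0)
params = init_params(key)
x = jax.random.normal(key, (32, 2))
y = jnp.sum(x, axis=1, keepdims=True)  # toy regression target
for _ in range(100):
    params = step(params, x, y)
```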
## Papers
This section contains papers focused on JAX (e.g. JAX-based library whitepapers, research on JAX, etc.). Papers implemented in JAX are listed in the Models and Projects section.
- Compiling machine learning programs via high-level tracing. Roy Frostig, Matthew James Johnson, Chris Leary. MLSys 2018. - White paper describing an early version of JAX, detailing how computation is traced and compiled (a small tracing illustration appears after this list).
- JAX, M.D.: A Framework for Differentiable Physics. Samuel S. Schoenholz, Ekin D. Cubuk. NeurIPS 2020. - Introduces JAX, M.D., a differentiable physics library which includes simulation environments, interaction potentials, neural networks, and more.
- Enabling Fast Differentially Private SGD via Just-in-Time Compilation and Vectorization. Pranav Subramani, Nicholas Vadivelu, Gautam Kamath. arXiv 2020. - Uses JAX's JIT and VMAP to achieve faster differentially private SGD than existing libraries (a sketch of the per-example gradient trick appears after this list).
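The first paper above describes how JAX traces Python functions into an intermediate representation before compiling; this mechanism can be observed directly with `jax.make_jaxpr`. A small illustration:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2

# Tracing f with an abstract value records a jaxpr, the intermediate
# representation that is subsequently lowered to XLA and compiled.
print(jax.make_jaxpr(f)(jnp.float32(1.0)))
```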
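The differentially private SGD paper above relies on computing per-example gradients efficiently; in JAX that is a composition of `grad`, `vmap`, and `jit`. A hedged sketch (the loss and clipping constant are illustrative, not the paper's code):

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    # Scalar loss for a single example.
    return (jnp.dot(x, params) - y) ** 2

# vmap maps grad over the batch axes of x and y (params is shared),
# producing one gradient per example in a single vectorized pass.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

params = jnp.zeros(3)
x, y = jnp.ones((8, 3)), jnp.zeros(8)
grads = per_example_grads(params, x, y)  # shape (8, 3)

# DP-SGD clips each per-example gradient before noisy aggregation.
norms = jnp.linalg.norm(grads, axis=1, keepdims=True)
clipped = grads * jnp.minimum(1.0, 1.0 / jnp.maximum(norms, 1e-12))
```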
## Tutorials and Blog Posts
## Contributing
Contributions welcome! Read the contribution guidelines first.