Awesome JAX
JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators like GPUs and TPUs.
This is a curated list of awesome JAX libraries, projects, and other resources. Contributions are welcome!
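The sketch below makes the one-line description above concrete using only core JAX. A NumPy-style loss is written with jax.numpy, differentiated with jax.grad, and compiled with jax.jit. The linear model and the toy arrays are hypothetical and serve only as illustration.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Written like NumPy, but with jax.numpy so JAX can trace it.
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# jax.grad differentiates with respect to the first argument (w);
# jax.jit compiles the resulting function with XLA for the default backend.
grad_loss = jax.jit(jax.grad(loss))

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])  # toy inputs (hypothetical)
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)
g = grad_loss(w, x, y)
```

When w is zero, the gradient of this mean squared error reduces to -xᵀy. For these arrays that is [-7, -10].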
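Contents
Libraries, New Libraries, Models and Projects, Videos, Papers, Tutorials and Blog Posts, Books, Contributing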
Libraries
 Neural Network Libraries

Flax  Centered on flexibility and clarity.

Haiku  Focused on simplicity, created by the authors of Sonnet at DeepMind.

Objax  Has an object-oriented design similar to PyTorch.

Elegy  A High-Level API for Deep Learning in JAX. Supports Flax, Haiku, and Optax.

Trax  "Batteries included" deep learning library focused on providing solutions for common workloads.

Jraph  Lightweight graph neural network library.

Neural Tangents  High-level API for specifying neural networks of both finite and infinite width.

HuggingFace  Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax).

Equinox  Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX.

Scenic  A JAX library for computer vision research and beyond.

NumPyro  Probabilistic programming based on the Pyro library.

Chex  Utilities to write and test reliable JAX code.

Optax  Gradient processing and optimization library.

RLax  Library for implementing reinforcement learning agents.

JAX, M.D.  Accelerated, differentiable molecular dynamics.

Coax  Turn RL papers into code, the easy way.

Distrax  Reimplementation of TensorFlow Probability, containing probability distributions and bijectors.

cvxpylayers  Construct differentiable convex optimization layers.

TensorLy  Tensor learning made simple.

NetKet  Machine Learning toolbox for Quantum Physics.

Fortuna  AWS library for Uncertainty Quantification in Deep Learning.

BlackJAX  Library of samplers for JAX.
New Libraries
This section contains libraries that are well-made and useful, but have not necessarily been battle-tested by a large user base yet.
 Neural Network Libraries

FedJAX  Federated learning in JAX, built on Optax and Haiku.

Equivariant MLP  Construct equivariant neural network layers.

jax-resnet  Implementations and checkpoints for ResNet variants in Flax.

Parallax  Immutable Torch Modules for JAX.

jax-unirep  Library implementing the UniRep model for protein machine learning applications.

jax-flows  Normalizing flows in JAX.

sklearn-jax-kernels  scikit-learn kernel matrices using JAX.

jax-cosmo  Differentiable cosmology library.

efax  Exponential Families in JAX.

mpi4jax  Combine MPI operations with your JAX code on CPUs and GPUs.

imax  Image augmentations and transformations.

FlaxVision  Flax version of TorchVision.

Oryx  Probabilistic programming language based on program transformations.

Optimal Transport Tools  Toolbox that bundles utilities to solve optimal transport problems.

delta PV  A photovoltaic simulator with automatic differentiation.

jaxlie  Lie theory library for rigid body transformations and optimization.

BRAX  Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments.

flaxmodels  Pretrained models for JAX/Flax.

CR.Sparse  XLA accelerated algorithms for sparse representations and compressive sensing.

exojax  Automatically differentiable spectrum modeling of exoplanets/brown dwarfs, compatible with JAX.

JAXopt  Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.

PIX  PIX is an image processing library in JAX, for JAX.

bayex  Bayesian Optimization powered by JAX.

JaxDF  Framework for differentiable simulators with arbitrary discretizations.

tree-math  Convert functions that operate on arrays into functions that operate on PyTrees.

jax-models  Implementations of research papers originally without code or code written with frameworks other than JAX.

PGMax  A framework for building discrete Probabilistic Graphical Models (PGMs) and running inference on them via JAX.

EvoJAX  Hardware-Accelerated Neuroevolution

evosax  JAX-Based Evolution Strategies

SymJAX  Symbolic CPU/GPU/TPU programming.

mcx  Express & compile probabilistic programs for performant inference.

Einshape  DSL-based reshaping library for JAX and other frameworks.

ALX  Open-source library for distributed matrix factorization using Alternating Least Squares; more info in ALX: Large Scale Matrix Factorization on TPUs.

Diffrax  Numerical differential equation solvers in JAX.

tinygp  The tiniest of Gaussian process libraries in JAX.

gymnax  Reinforcement Learning Environments with the well-known gym API.

Mctx  Monte Carlo tree search algorithms in native JAX.

KFAC-JAX  Second-Order Optimization with Approximate Curvature for NNs.

TF2JAX  Convert functions/graphs to JAX functions.

jwave  A library for differentiable acoustic simulations

GPJax  Gaussian processes in JAX.

Jumanji  A Suite of Industry-Driven Hardware-Accelerated RL Environments written in JAX.

Eqxvision  Equinox version of Torchvision.

JAXFit  Accelerated curve fitting library for nonlinear least-squares problems (see arXiv paper).

econpizza  Solve macroeconomic models with heterogeneous agents using JAX.

SPU  A domain-specific compiler and runtime suite to run JAX code with MPC (Secure Multi-Party Computation).

jax-tqdm  Add a tqdm progress bar to JAX scans and loops.

safejax  Serialize JAX, Flax, Haiku, or Objax model params with 🤗 safetensors.

Kernex  Differentiable stencil decorators in JAX.

MaxText  A simple, performant and scalable JAX LLM written in pure Python/JAX and targeting Google Cloud TPUs.

Pax  A JAX-based machine learning framework for training large-scale models.

Praxis  The layer library for Pax, with the goal of being usable by other JAX-based ML projects.

purejaxrl  Vectorisable, end-to-end RL algorithms in JAX.

Lorax  Automatically apply LoRA to JAX models (Flax, Haiku, etc.)

SCICO  Scientific computational imaging in JAX.

BrainPy  Brain Dynamics Programming in Python.

OTT-JAX  Optimal transport tools in JAX.

QDax  Quality Diversity optimization in JAX.

JAX Toolbox  Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries such as T5x, Paxml, and Transformer Engine.

Pgx  Vectorized board game environments for RL with an AlphaZero example.
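Libraries such as EvoJAX and evosax rely on the fact that population-based methods map naturally onto jax.vmap. The following is a hypothetical, minimal Gaussian evolution strategy written in plain JAX, and it does not use either library's API. It samples perturbations around a mean, scores them in parallel, and moves the mean along a Monte Carlo estimate of the smoothed fitness gradient. The fitness function and all constants are made up for illustration.

```python
import jax
import jax.numpy as jnp

TARGET = jnp.array([1.0, -2.0, 0.5])  # hypothetical optimum
POP_SIZE = 64    # candidates per generation (assumed)
SIGMA = 0.1      # std. dev. of the Gaussian perturbations (assumed)
LR = 0.05        # step size for the mean (assumed)

def fitness(theta):
    # Hypothetical objective to maximise: negative squared distance to TARGET.
    return -jnp.sum((theta - TARGET) ** 2)

@jax.jit
def es_step(mean, key):
    # One perturbation vector per population member.
    eps = jax.random.normal(key, (POP_SIZE, mean.shape[0]))
    # Evaluate the whole population in parallel with vmap.
    scores = jax.vmap(fitness)(mean + SIGMA * eps)
    # Subtracting the mean score is a simple baseline that reduces variance.
    advantages = scores - scores.mean()
    # Monte Carlo estimate of the gradient of the Gaussian-smoothed fitness.
    grad_est = eps.T @ advantages / (POP_SIZE * SIGMA)
    return mean + LR * grad_est  # gradient ascent on the mean

mean = jnp.zeros(3)
key = jax.random.PRNGKey(0)
for _ in range(200):
    key, subkey = jax.random.split(key)
    mean = es_step(mean, subkey)
```

Because fitness is evaluated through vmap and the whole generation is jitted, increasing POP_SIZE changes array shapes rather than adding Python loop iterations.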
Models and Projects
JAX
Flax
Haiku
Trax

Reformer  Implementation of the Reformer (efficient transformer) architecture.
NumPyro
Videos

NeurIPS 2020: JAX Ecosystem Meetup  JAX, its use at DeepMind, and discussion between engineers, scientists, and JAX core team.

Introduction to JAX  Simple neural network from scratch in JAX.

JAX: Accelerated Machine Learning Research  SciPy 2020  VanderPlas  JAX's core design, how it's powering new research, and how you can start using it.

Bayesian Programming with JAX + NumPyro — Andy Kitchen  Introduction to Bayesian modelling using NumPyro.

JAX: Accelerated machine-learning research via composable function transformations in Python  NeurIPS 2019  Skye Wanderman-Milne  JAX intro presentation in the Program Transformations for Machine Learning workshop.

JAX on Cloud TPUs  NeurIPS 2020  Skye Wanderman-Milne and James Bradbury  Presentation of TPU host access with demo.

Deep Implicit Layers  Neural ODEs, Deep Equilibrium Models, and Beyond  NeurIPS 2020  Tutorial created by Zico Kolter, David Duvenaud, and Matt Johnson with Colab notebooks available in Deep Implicit Layers.

Solving y=mx+b with Jax on a TPU Pod slice  Mat Kelcey  A four-part YouTube tutorial series with Colab notebooks that starts with JAX fundamentals and moves up to training with a data-parallel approach on a v3-32 TPU Pod slice.

JAX, Flax & Transformers 🤗  3 days of talks around JAX / Flax, Transformers, large-scale language modeling and other great topics.
Papers
This section contains papers focused on JAX (e.g. JAX-based library whitepapers, research on JAX, etc.). Papers implemented in JAX are listed in the Models/Projects section.

Compiling machine learning programs via high-level tracing. Roy Frostig, Matthew James Johnson, Chris Leary. MLSys 2018.  White paper describing an early version of JAX, detailing how computation is traced and compiled.

JAX, M.D.: A Framework for Differentiable Physics. Samuel S. Schoenholz, Ekin D. Cubuk. NeurIPS 2020.  Introduces JAX, M.D., a differentiable physics library which includes simulation environments, interaction potentials, neural networks, and more.

Enabling Fast Differentially Private SGD via Just-in-Time Compilation and Vectorization. Pranav Subramani, Nicholas Vadivelu, Gautam Kamath. arXiv 2020.  Uses JAX's JIT and VMAP to achieve faster differentially private SGD than existing libraries.
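The DP-SGD paper above relies on two JAX transformations: vmap computes per-example gradients, and jit compiles the whole step. The sketch below shows that pattern with a toy linear model and made-up hyperparameters. It follows the standard DP-SGD recipe (clip each example's gradient, sum, add Gaussian noise scaled to the clipping bound). It is not the paper's code and performs no privacy accounting.

```python
import jax
import jax.numpy as jnp

CLIP_NORM = 1.0         # per-example L2 bound C (assumed)
NOISE_MULTIPLIER = 1.1  # noise std is NOISE_MULTIPLIER * C (assumed)
LR = 0.1                # learning rate (assumed)

def loss(w, x, y):
    # Squared error of a hypothetical linear model on a single example.
    return (jnp.dot(x, w) - y) ** 2

# grad w.r.t. w, vectorised over the batch axis of x and y (w is shared).
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))

@jax.jit
def dp_sgd_step(w, x, y, key):
    grads = per_example_grads(w, x, y)  # shape (batch, dim)
    norms = jnp.linalg.norm(grads, axis=1, keepdims=True)
    # Rescale only the gradients whose norm exceeds CLIP_NORM.
    clipped = grads / jnp.maximum(1.0, norms / CLIP_NORM)
    # Gaussian noise calibrated to the clipping bound, added once per batch.
    noise = NOISE_MULTIPLIER * CLIP_NORM * jax.random.normal(key, w.shape)
    noisy_grad = (clipped.sum(axis=0) + noise) / x.shape[0]
    return w - LR * noisy_grad, clipped

x = jnp.array([[1.0, 2.0], [3.0, 4.0], [0.1, 0.0]])
y = jnp.array([1.0, 0.0, 0.0])
w = jnp.ones(2)
new_w, clipped = dp_sgd_step(w, x, y, jax.random.PRNGKey(0))
```

The clipped per-example gradients are returned only so they can be inspected. In this toy batch the first two gradients are scaled down to norm 1, while the small third one passes through unchanged.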
Tutorials and Blog Posts

Using JAX to accelerate our research by David Budden and Matteo Hessel  Describes the state of JAX and the JAX ecosystem at DeepMind.

Getting started with JAX (MLPs, CNNs & RNNs) by Robert Lange  Neural network building blocks from scratch with the basic JAX operators.

Learn JAX: From Linear Regression to Neural Networks by Rito Ghosh  A gentle introduction to JAX, using it to implement linear regression, logistic regression, and neural network models and to apply them to real-world problems.

Tutorial: image classification with JAX and Flax Linen by 8bitmp3  Learn how to create a simple convolutional network with the Linen API by Flax and train it to recognize handwritten digits.

Plugging Into JAX by Nick Doiron  Compares Flax, Haiku, and Objax on the Kaggle flower classification challenge.

Meta-Learning in 50 Lines of JAX by Eric Jang  Introduction to both JAX and meta-learning.

Normalizing Flows in 100 Lines of JAX by Eric Jang  Concise implementation of RealNVP.

Differentiable Path Tracing on the GPU/TPU by Eric Jang  Tutorial on implementing path tracing.

Ensemble networks by Mat Kelcey  Ensemble nets are a method of representing an ensemble of models as one single logical model.

Out of distribution (OOD) detection by Mat Kelcey  Implements different methods for OOD detection.

Understanding Autodiff with JAX by Srihari Radhakrishna  Understand how autodiff works using JAX.

From PyTorch to JAX: towards neural net frameworks that purify stateful code by Sabrina J. Mielke  Showcases how to go from a PyTorch-like style of coding to a more functional style of coding.

Extending JAX with custom C++ and CUDA code by Dan Foreman-Mackey  Tutorial demonstrating the infrastructure required to provide custom ops in JAX.

Evolving Neural Networks in JAX by Robert Tjarko Lange  Explores how JAX can power the next generation of scalable neuroevolution algorithms.

Exploring hyperparameter meta-loss landscapes with JAX by Luke Metz  Demonstrates how to use JAX to perform inner-loss optimization with SGD and Momentum, outer-loss optimization with gradients, and outer-loss optimization using evolutionary strategies.

Deterministic ADVI in JAX by Martin Ingram  Walkthrough of implementing automatic differentiation variational inference (ADVI) easily and cleanly with JAX.

Evolved channel selection by Mat Kelcey  Trains a classification model robust to different combinations of input channels at different resolutions, then uses a genetic algorithm to decide the best combination for a particular loss.

Introduction to JAX by Kevin Murphy  Colab that introduces various aspects of the language and applies them to simple ML problems.

Writing an MCMC sampler in JAX by Jeremie Coullon  Tutorial on the different ways to write an MCMC sampler in JAX along with speed benchmarks.

How to add a progress bar to JAX scans and loops by Jeremie Coullon  Tutorial on how to add a progress bar to compiled loops in JAX using the host_callback module.

Get started with JAX by Aleksa Gordić  A series of notebooks and videos going from zero JAX knowledge to building neural networks in Haiku.

Writing a Training Loop in JAX + FLAX by Saurav Maheshkar and Soumik Rakshit  A tutorial on writing a simple end-to-end training and evaluation pipeline in JAX, Flax and Optax.

Implementing NeRF in JAX by Soumik Rakshit and Saurav Maheshkar  A tutorial on 3D volumetric rendering of scenes represented by Neural Radiance Fields in JAX.

Deep Learning tutorials with JAX+Flax by Phillip Lippe  A series of notebooks explaining various deep learning concepts, from basics (e.g. intro to JAX/Flax, activation functions) to recent advances (e.g., Vision Transformers, SimCLR), with translations to PyTorch.

Achieving 4000x Speedups with PureJaxRL  A blog post on how JAX can massively speed up RL training through vectorisation.
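The MCMC tutorial by Jeremie Coullon listed above compares several ways of writing a sampler. One common pattern expresses a random-walk Metropolis step as a pure function of (state, key) and drives it with jax.lax.scan, so that the whole chain compiles into a single XLA loop. The sketch below assumes a one-dimensional standard normal target and a fixed step size, both hypothetical. It is not code taken from the tutorial.

```python
from functools import partial

import jax
import jax.numpy as jnp

STEP_SIZE = 1.0  # proposal scale (assumed)

def log_prob(x):
    # Hypothetical target: unnormalised log-density of a standard normal.
    return -0.5 * x ** 2

def rwm_step(x, key):
    key_prop, key_acc = jax.random.split(key)
    # Symmetric Gaussian proposal around the current state.
    proposal = x + STEP_SIZE * jax.random.normal(key_prop)
    # Metropolis acceptance: accept with probability min(1, p(x') / p(x)).
    log_ratio = log_prob(proposal) - log_prob(x)
    accept = jnp.log(jax.random.uniform(key_acc)) < log_ratio
    x_new = jnp.where(accept, proposal, x)
    return x_new, x_new  # (next carry, value recorded for this step)

@partial(jax.jit, static_argnames="num_steps")
def sample(key, x0, num_steps):
    # One PRNG key per step; scan consumes them along the leading axis.
    keys = jax.random.split(key, num_steps)
    _, samples = jax.lax.scan(rwm_step, x0, keys)
    return samples

samples = sample(jax.random.PRNGKey(0), jnp.zeros(()), num_steps=20_000)
```

num_steps is marked static because it determines the shape of the key array, and jit needs that shape at trace time.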
Books

JAX in Action  A hands-on guide to using JAX for deep learning and other mathematically intensive applications.
Contributing
Contributions welcome! Read the contribution guidelines first.