Awesome JAX
JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators like GPUs and TPUs.
This is a curated list of awesome JAX libraries, projects, and other resources. Contributions are welcome!
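The description above can be illustrated with a minimal sketch that uses only core JAX (`jax.numpy`, `jax.grad`, `jax.jit`). The loss function and toy data are illustrative assumptions and are not taken from any library in this list.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Mean squared error of a linear model, written with NumPy-style array ops.
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# jax.grad differentiates with respect to the first argument (w);
# jax.jit traces the resulting function once and compiles it with XLA.
grad_loss = jax.jit(jax.grad(loss))

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = grad_loss(w, x, y)
print(g)
```

The transformations compose: the same `loss` could also be wrapped in `jax.vmap` or differentiated twice without changing its definition.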
Libraries
 Neural Network Libraries

Flax  Centered on flexibility and clarity.

Haiku  Focused on simplicity, created by the authors of Sonnet at DeepMind.

Objax  Has an object oriented design similar to PyTorch.

Elegy  A high-level API for deep learning in JAX. Supports Flax, Haiku, and Optax.

Trax  "Batteries included" deep learning library focused on providing solutions for common workloads.

Jraph  Lightweight graph neural network library.

Neural Tangents  High-level API for specifying neural networks of both finite and infinite width.

HuggingFace  Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax).

Equinox  Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX.

NumPyro  Probabilistic programming based on the Pyro library.

Chex  Utilities to write and test reliable JAX code.

Optax  Gradient processing and optimization library.

RLax  Library for implementing reinforcement learning agents.

JAX, M.D.  Accelerated, differentiable molecular dynamics.

Coax  Turn RL papers into code, the easy way.

Distrax  Reimplementation of TensorFlow Probability, containing probability distributions and bijectors.

cvxpylayers  Construct differentiable convex optimization layers.

TensorLy  Tensor learning made simple.

NetKet  Machine Learning toolbox for Quantum Physics.
New Libraries
This section contains libraries that are well-made and useful, but have not necessarily been battle-tested by a large userbase yet.
 Neural Network Libraries

FedJAX  Federated learning in JAX, built on Optax and Haiku.

Equivariant MLP  Construct equivariant neural network layers.

jax-resnet  Implementations and checkpoints for ResNet variants in Flax.

Parallax  Immutable Torch Modules for JAX.

jax-unirep  Library implementing the UniRep model for protein machine learning applications.

jax-flows  Normalizing flows in JAX.

sklearn-jax-kernels  scikit-learn kernel matrices using JAX.

jax-cosmo  Differentiable cosmology library.

efax  Exponential Families in JAX.

mpi4jax  Combine MPI operations with your JAX code on CPUs and GPUs.

imax  Image augmentations and transformations.

FlaxVision  Flax version of TorchVision.

Oryx  Probabilistic programming language based on program transformations.

Optimal Transport Tools  Toolbox that bundles utilities to solve optimal transport problems.

delta PV  A photovoltaic simulator with automatic differentiation.

jaxlie  Lie theory library for rigid body transformations and optimization.

BRAX  Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments.

flaxmodels  Pretrained models for JAX/Flax.

CR.Sparse  XLA accelerated algorithms for sparse representations and compressive sensing.

exojax  Automatic differentiable spectrum modeling of exoplanets/brown dwarfs compatible with JAX.

JAXopt  Hardware-accelerated (GPU/TPU), batchable, and differentiable optimizers in JAX.

PIX  PIX is an image processing library in JAX, for JAX.

bayex  Bayesian Optimization powered by JAX.

JaxDF  Framework for differentiable simulators with arbitrary discretizations.

tree-math  Convert functions that operate on arrays into functions that operate on PyTrees.

jax-models  Implementations of research papers originally without code or code written with frameworks other than JAX.

PGMax  A framework for building discrete Probabilistic Graphical Models (PGMs) and running inference on them via JAX.

EvoJAX  Hardware-Accelerated Neuroevolution.

evosax  JAX-Based Evolution Strategies.

SymJAX  Symbolic CPU/GPU/TPU programming.

mcx  Express & compile probabilistic programs for performant inference.

Einshape  DSL-based reshaping library for JAX and other frameworks.

ALX  Open-source library for distributed matrix factorization using Alternating Least Squares; more info in ALX: Large Scale Matrix Factorization on TPUs.

Diffrax  Numerical differential equation solvers in JAX.

tinygp  The tiniest of Gaussian process libraries in JAX.
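The idea behind tree-math, lifting a function written for arrays so it works on whole PyTrees, can be approximated in core JAX with `jax.tree_util.tree_map`. This sketch is not tree-math's own API; `axpy` and `tree_axpy` are hypothetical helper names chosen for illustration, and the parameter layout is an assumption.

```python
import jax
import jax.numpy as jnp

# A PyTree: any nesting of dicts, lists, and tuples with arrays at the leaves.
params = {"dense": {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}, "scale": jnp.array(3.0)}
grads = jax.tree_util.tree_map(jnp.ones_like, params)

def axpy(a, x, y):
    # Hypothetical array-level function: a * x + y.
    return a * x + y

def tree_axpy(a, xs, ys):
    # Lift axpy to PyTrees by applying it to each pair of matching leaves.
    return jax.tree_util.tree_map(lambda x, y: axpy(a, x, y), xs, ys)

# One plain gradient-descent update applied across the whole structure.
new_params = tree_axpy(-0.1, grads, params)
print(new_params)
```

Libraries such as tree-math aim to hide this explicit mapping so that array-style code reads naturally on nested parameters.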
Models and Projects
JAX
Flax
Haiku
Trax

Reformer  Implementation of the Reformer (efficient transformer) architecture.
Videos

NeurIPS 2020: JAX Ecosystem Meetup  JAX, its use at DeepMind, and discussion between engineers, scientists, and JAX core team.

Introduction to JAX  Simple neural network from scratch in JAX.

JAX: Accelerated Machine Learning Research  SciPy 2020  VanderPlas  JAX's core design, how it's powering new research, and how you can start using it.

Bayesian Programming with JAX + NumPyro — Andy Kitchen  Introduction to Bayesian modelling using NumPyro.

JAX: Accelerated machine-learning research via composable function transformations in Python  NeurIPS 2019  Skye Wanderman-Milne  JAX intro presentation in the Program Transformations for Machine Learning workshop.

JAX on Cloud TPUs  NeurIPS 2020  Skye Wanderman-Milne and James Bradbury  Presentation of TPU host access with demo.

Deep Implicit Layers  Neural ODEs, Deep Equilibrium Models, and Beyond  NeurIPS 2020  Tutorial created by Zico Kolter, David Duvenaud, and Matt Johnson with Colab notebooks available in Deep Implicit Layers.

Solving y=mx+b with Jax on a TPU Pod slice  Mat Kelcey  A four-part YouTube tutorial series with Colab notebooks that starts with JAX fundamentals and moves up to training with a data-parallel approach on a v3-32 TPU Pod slice.

JAX, Flax & Transformers 🤗  3 days of talks around JAX / Flax, Transformers, large-scale language modeling and other great topics.
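As a rough single-device sketch of the problem the Mat Kelcey series starts from, the following fits y = mx + b by gradient descent using `jax.grad` and `jax.jit`. This is not code from the series: the noise-free synthetic data, the learning rate, and the step count are assumptions, and the data-parallel TPU Pod part is left out entirely.

```python
import jax
import jax.numpy as jnp

def predict(params, x):
    m, b = params
    return m * x + b

def mse(params, x, y):
    # Mean squared error between the line's predictions and the targets.
    return jnp.mean((predict(params, x) - y) ** 2)

@jax.jit
def step(params, x, y, lr=0.1):
    # grads has the same (m, b) tuple structure as params.
    grads = jax.grad(mse)(params, x, y)
    return tuple(p - lr * g for p, g in zip(params, grads))

# Synthetic, noise-free data from the line y = 3x + 0.5 (an assumption).
x = jnp.linspace(-1.0, 1.0, 100)
y = 3.0 * x + 0.5

params = (jnp.array(0.0), jnp.array(0.0))
for _ in range(500):
    params = step(params, x, y)

m, b = params
print(float(m), float(b))
```

Because `step` is compiled once and reused, only the first call pays the tracing and compilation cost.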
Papers
This section contains papers focused on JAX (e.g. JAX-based library whitepapers, research on JAX, etc.). Papers implemented in JAX are listed in the Models/Projects section.

Compiling machine learning programs via high-level tracing. Roy Frostig, Matthew James Johnson, Chris Leary. MLSys 2018.  White paper describing an early version of JAX, detailing how computation is traced and compiled.

JAX, M.D.: A Framework for Differentiable Physics. Samuel S. Schoenholz, Ekin D. Cubuk. NeurIPS 2020.  Introduces JAX, M.D., a differentiable physics library which includes simulation environments, interaction potentials, neural networks, and more.

Enabling Fast Differentially Private SGD via Just-in-Time Compilation and Vectorization. Pranav Subramani, Nicholas Vadivelu, Gautam Kamath. arXiv 2020.  Uses JAX's JIT and VMAP to achieve faster differentially private SGD than existing libraries.
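The core pattern behind fast DP-SGD in JAX (per-example gradients via `vmap`, a compiled update step via `jit`) can be sketched as follows. This is a minimal illustration assuming a linear model with squared error. The function names and the `clip_norm` / `noise_multiplier` values are hypothetical, the code is not the paper's implementation, and it performs no privacy accounting.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error for a single example (x is a feature vector, y a scalar).
    return (jnp.dot(x, w) - y) ** 2

# vmap over the batch axis of (x, y) while sharing w: one gradient per example.
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))

@jax.jit
def dp_sgd_step(w, xs, ys, key, lr=0.1, clip_norm=1.0, noise_multiplier=1.1):
    g = per_example_grads(w, xs, ys)                       # shape (batch, dim)
    norms = jnp.linalg.norm(g, axis=1, keepdims=True)
    # Rescale each per-example gradient so its L2 norm is at most clip_norm.
    g = g * jnp.minimum(1.0, clip_norm / (norms + 1e-12))
    # Gaussian noise scaled to the clipping bound, added to the summed gradient.
    noise = noise_multiplier * clip_norm * jax.random.normal(key, w.shape)
    noisy_mean = (g.sum(axis=0) + noise) / xs.shape[0]
    return w - lr * noisy_mean

xs = jnp.array([[1.0, 0.0], [0.0, 2.0], [3.0, 4.0]])
ys = jnp.array([1.0, 0.0, 2.0])
w = jnp.zeros(2)
key = jax.random.PRNGKey(0)

new_w = dp_sgd_step(w, xs, ys, key)
print(new_w)
```

Computing all per-example gradients in one vectorized call, instead of looping over the batch in Python, is what lets the clipping step stay cheap once compiled.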
Tutorials and Blog Posts
Contributing
Contributions welcome! Read the contribution guidelines first.