Machine Learning > JAX
Automatic differentiation and XLA compilation brought together for high-performance machine learning research.
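Both pillars of that tagline show up in a few lines: `jax.grad` differentiates an ordinary Python function and `jax.jit` compiles it with XLA. A minimal sketch (`f` is a hypothetical example function, not from any library above):

```python
import jax

# f(x) = x^2 + 3x, so df/dx = 2x + 3
def f(x):
    return x ** 2 + 3.0 * x

# Differentiate f, then compile the gradient function with XLA.
grad_f = jax.jit(jax.grad(f))

print(grad_f(2.0))  # → 7.0
```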
Contents
Libraries
Neural Network Libraries
An evolution of Flax by the same team.
Focused on simplicity, created by the authors of Sonnet at DeepMind.
Has an object-oriented design similar to PyTorch.
A high-level API for deep learning in JAX. Supports Flax, Haiku, and Optax.
Lightweight graph neural network library.
High-level API for specifying neural networks of both finite and infinite width.
Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX.
A Jax Library for Computer Vision Research and Beyond.
Prioritizes legibility, visualization, and easy editing of neural network models with composable tools and a simple mental model.
Legible, Scalable, Reproducible Foundation Models with Named Tensors and JAX.
LLMs made easy: Pre-training, finetuning, evaluating and serving LLMs in JAX/Flax.
Probabilistic programming based on the Pyro library.
Utilities to write and test reliable JAX code.
Gradient processing and optimization library.
Library for implementing reinforcement learning agents.
Accelerated, differentiable molecular dynamics.
Turn RL papers into code, the easy way.
Reimplementation of TensorFlow Probability, containing probability distributions and bijectors.
Construct differentiable convex optimization layers.
Tensor learning made simple.
Machine Learning toolbox for Quantum Physics.
AWS library for Uncertainty Quantification in Deep Learning.
Library of samplers for JAX.
Probabilistic state space models.
New Libraries
Federated learning in JAX, built on Optax and Haiku.
Construct equivariant neural network layers.
Implementations and checkpoints for ResNet variants in Flax.
JAX/Flax port of the RAFT optical flow estimator.
Immutable Torch Modules for JAX.
Root finding, minimisation, fixed points, and least squares.
Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.
Library implementing the UniRep model for protein machine learning applications.
Distributions and normalizing flows built as equinox modules.
Framework and library for building and training diffusion models in multi-node, multi-device distributed settings (TPUs).
Normalizing flows in JAX.
scikit-learn kernel matrices using JAX.
Differentiable cosmology library.
Exponential Families in JAX.
Combine MPI operations with your JAX code on CPUs and GPUs.
Image augmentations and transformations.
Flax version of TorchVision.
Probabilistic programming language based on program transformations.
Toolbox that bundles utilities to solve optimal transport problems.
A photovoltaic simulator with automatic differentiation.
Lie theory library for rigid body transformations and optimization.
Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments.
Pretrained models for JAX/Flax.
XLA accelerated algorithms for sparse representations and compressive sensing.
Automatically differentiable spectrum modeling of exoplanets/brown dwarfs, compatible with JAX.
PIX is an image processing library in JAX, for JAX.
Bayesian Optimization powered by JAX.
Framework for differentiable simulators with arbitrary discretizations.
Convert functions that operate on arrays into functions that operate on PyTrees.
Implementations of research papers originally without code or code written with frameworks other than JAX.
A framework for building discrete Probabilistic Graphical Models (PGMs) and running inference on them via JAX.
Hardware-accelerated neuroevolution.
JAX-based evolution strategies.
Symbolic CPU/GPU/TPU programming.
Express & compile probabilistic programs for performant inference.
DSL-based reshaping library for JAX and other frameworks.
Numerical differential equation solvers in JAX.
The tiniest of Gaussian process libraries in JAX.
Reinforcement Learning Environments with the well-known gym API.
Monte Carlo tree search algorithms in native JAX.
Second Order Optimization with Approximate Curvature for NNs.
Convert functions/graphs to JAX functions.
A library for differentiable acoustic simulations.
Gaussian processes in JAX.
A Suite of Industry-Driven Hardware-Accelerated RL Environments written in JAX.
Equinox version of Torchvision.
Accelerated curve fitting library for nonlinear least-squares problems (see arXiv paper).
Solve macroeconomic models with heterogeneous agents using JAX.
A domain-specific compiler and runtime suite to run JAX code with MPC(Secure Multi-Party Computation).
Add a tqdm progress bar to JAX scans and loops.
Serialize JAX, Flax, Haiku, or Objax model params with safetensors.
Differentiable stencil decorators in JAX.
A simple, performant and scalable JAX LLM written in pure Python/JAX and targeting Google Cloud TPUs.
A JAX-based machine learning framework for training large scale models.
The layer library for Pax with a goal to be usable by other JAX-based ML projects.
Vectorisable, end-to-end RL algorithms in JAX.
Automatically apply LoRA to JAX models (Flax, Haiku, etc.).
Scientific computational imaging in JAX.
Spiking Neural Networks in JAX for machine learning on neuromorphic hardware.
Optimal transport tools in JAX.
Quality Diversity optimization in JAX.
Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries such as T5x, Paxml, and Transformer Engine.
Vectorized board game environments for RL with an AlphaZero example.
EasyDeL is an open-source library for faster, more optimized training and serving of models (Llama, MPT, Mixtral, Falcon, etc.) in JAX.
A Differentiable Massively Parallel Lattice Boltzmann Library in Python for Physics-Based Machine Learning.
High-performance and differentiable simulations of quantum systems with JAX.
Agent-Based modelling framework in JAX.
Vectorized calculation of optical properties in thin-film structures using JAX. A Swiss Army knife for thin-film optics research.
Algorithms for finding coresets to compress large datasets while retaining their statistical properties.
A reimplementation of MiniGrid, a reinforcement learning environment, in JAX.
Finite-difference time-domain electromagnetic simulations in JAX.
Differentiable Ray Tracing toolbox for Radio Propagation powered by the JAX ecosystem.
Plasma physics simulations using a PIC (Particle-in-Cell) method to self-consistently solve for electron and ion dynamics in electromagnetic fields.
A FlashAttention implementation for JAX with support for efficient document mask computation and context parallelism.
Differentiable (magneto)hydrodynamics for astrophysics in JAX.
Fluid-structure interaction simulations using Immersed Boundary-Lattice Boltzmann Method.
High-performance tomographic reconstruction.
torchax is a library for JAX to interoperate with model code written in PyTorch.
Brain Dynamics Programming Ecosystem
Brain Dynamics Programming in Python.
Physical units and unit-aware mathematical system in JAX.
Dendritic Modeling in JAX.
State-based Transformation System for Program Compilation and Augmentation.
Leveraging Taichi Lang to customize brain dynamics operators.
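Many of the libraries above (optimizers, module systems, pytree utilities) build on JAX's pytree abstraction: nested containers of arrays that JAX transformations traverse transparently. A minimal sketch of the core idea, using only `jax.tree_util` (the parameter dictionary is a hypothetical example):

```python
import jax
import jax.numpy as jnp

# A pytree: any nesting of dicts/lists/tuples with array leaves.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}

# tree_map applies a function to every leaf, preserving the structure.
scaled = jax.tree_util.tree_map(lambda x: 0.5 * x, params)

print(scaled["w"][0, 0])  # → 0.5
```

This is the mechanism that lets, e.g., an optimizer update every parameter array in a model without knowing the model's structure.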
Models and Projects
Haiku
Code used for the paper Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies.
Implementation of the inference pipeline of AlphaFold v2.0, presented in Highly accurate protein structure prediction with AlphaFold.
Baseline code to reproduce results in WikiGraphs: A Wikipedia Text - Knowledge Graph Paired Dataset.
Normalizing flows with JAX.
JAX implementation of the paper Auction learning as a two-player game.
JAX
Official implementation of Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains.
Approximate inference for Markov (i.e., temporal) Gaussian processes using iterated Kalman filtering and smoothing.
Nested sampling in JAX.
Flax
Collection of LLMs implemented in JAX & Flax.
Flax implementation of DeepSeek-R1 1.5B distilled reasoning LLM.
Official implementation of Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields.
Implementation of NeuS: Learning Neural Implicit Surfaces by Volume Rendering for Multi-view Reconstruction.
Implementation of Big Transfer (BiT): General Visual Representation Learning.
Implementations of reinforcement learning algorithms.
Implementation of Pay Attention to MLPs.
Minimal implementation of MLP-Mixer: An all-MLP Architecture for Vision.
Official implementation of Aggregating Nested Transformers.
Official implementation of Cross-Modal Contrastive Learning for Text-to-Image Generation.
Official implementation of An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.
Port of mseitzer/pytorch-fid to Flax.
A JAX + Flax implementation of Combinatorial Optimization with Physics-Inspired Graph Neural Networks.
Flax implementation of DETR: End-to-end Object Detection with Transformers using Sinkhorn solver and parallel bipartite matching.
Tutorials and Blog Posts
Learn how to create a simple convolutional network with the Linen API by Flax and train it to recognize handwritten digits.
Tutorial demonstrating the infrastructure required to provide custom ops in JAX.
A series of notebooks and videos going from zero JAX knowledge to building neural networks in Haiku.
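As a dependency-free warm-up before the framework tutorials above, a single training step of a tiny classifier can be written in pure JAX: define the forward pass and loss as plain functions, take gradients with `jax.grad`, and apply an SGD update over the parameter pytree. Everything here (the two-layer MLP, the toy inputs) is a hypothetical example:

```python
import jax
import jax.numpy as jnp

def init(key):
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (4, 8)) * 0.1,
        "w2": jax.random.normal(k2, (8, 3)) * 0.1,
    }

def forward(params, x):
    h = jax.nn.relu(x @ params["w1"])
    return h @ params["w2"]

def loss(params, x, y):
    # Cross-entropy against one-hot labels.
    logits = forward(params, x)
    return -jnp.mean(jnp.sum(jax.nn.log_softmax(logits) * y, axis=-1))

key = jax.random.PRNGKey(0)
params = init(key)

# Toy batch: 5 identical inputs, all labeled class 0.
x = jnp.ones((5, 4))
y = jax.nn.one_hot(jnp.zeros(5, dtype=jnp.int32), 3)

# One SGD step: gradient of the loss, applied leaf-by-leaf over the pytree.
grads = jax.grad(loss)(params, x, y)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```

Libraries such as Flax and Haiku package exactly this pattern (parameter initialization, forward pass, pytree updates) behind a module abstraction.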