JAX
Automatic differentiation and XLA compilation brought together for high-performance machine learning research.
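A minimal sketch of those two ideas in action; the toy `loss` function here is illustrative, not part of any library listed below:

```python
import jax
import jax.numpy as jnp

# A toy scalar-valued function; any pure NumPy-style function works.
def loss(w, x):
    return jnp.sum((x @ w) ** 2)

grad_loss = jax.grad(loss)      # automatic differentiation w.r.t. w
fast_grad = jax.jit(grad_loss)  # XLA-compile the gradient function

w = jnp.ones((3,))
x = jnp.ones((4, 3))
print(fast_grad(w, x))          # compiled on first call, cached afterwards
```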
Libraries
Neural Network Libraries
A high-level API for deep learning in JAX. Supports Flax, Haiku, and Optax (a minimal Flax/Optax sketch follows this list).
High-level API for specifying neural networks of both finite and infinite width.
Legible, Scalable, Reproducible Foundation Models with Named Tensors and JAX.
LLMs made easy: pre-training, fine-tuning, evaluating, and serving LLMs in JAX/Flax.
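As referenced in the first item, a minimal sketch of the Flax/Optax workflow that such high-level APIs wrap; the model and hyperparameters are illustrative:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

# A tiny illustrative model.
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(32)(x))
        return nn.Dense(1)(x)

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
tx = optax.adam(1e-3)
opt_state = tx.init(params)

def loss_fn(params, x, y):
    pred = model.apply(params, x)
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(params, opt_state, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    updates, opt_state = tx.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state
```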
New Libraries
Hardware-accelerated (GPU/TPU), batchable, and differentiable optimizers in JAX.
Library implementing the UniRep model for protein machine learning applications.
Framework and library for building and training diffusion models in multi-node, multi-device distributed settings (TPUs).
Toolbox that bundles utilities to solve optimal transport problems.
Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments.
XLA accelerated algorithms for sparse representations and compressive sensing.
Auto-differentiable spectrum modeling of exoplanets/brown dwarfs, compatible with JAX.
Convert functions that operate on arrays into functions that operate on PyTrees (a minimal sketch of the idea follows this list).
Implementations of research papers originally without code or code written with frameworks other than JAX.
A framework for building discrete probabilistic graphical models (PGMs) and running inference on them via JAX.
A suite of industry-driven, hardware-accelerated RL environments written in JAX.
Accelerated curve fitting library for nonlinear least-squares problems (see arXiv paper).
A domain-specific compiler and runtime suite to run JAX code with MPC (Secure Multi-Party Computation).
A simple, performant, and scalable JAX LLM written in pure Python/JAX, targeting Google Cloud TPUs.
The layer library for Pax, intended to be usable by other JAX-based ML projects.
Spiking Neural Networks in JAX for machine learning on neuromorphic hardware.
Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries such as T5x, Paxml, and Transformer Engine.
An open-source library that makes training faster and more optimized, with options for training and serving models (Llama, MPT, Mixtral, Falcon, etc.) in JAX.
A Differentiable Massively Parallel Lattice Boltzmann Library in Python for Physics-Based Machine Learning.
High-performance and differentiable simulations of quantum systems with JAX.
Vectorized calculation of optical properties in thin-film structures using JAX; a Swiss Army knife for thin-film optics research.
Algorithms for finding coresets to compress large datasets while retaining their statistical properties.
A reimplementation of MiniGrid, a reinforcement learning environment, in JAX.
Differentiable Ray Tracing toolbox for Radio Propagation powered by the JAX ecosystem.
Plasma physics simulations using a PIC (Particle-in-Cell) method to self-consistently solve for electron and ion dynamics in electromagnetic fields.
A FlashAttention implementation for JAX with support for efficient document mask computation and context parallelism.
Fluid-structure interaction simulations using the Immersed Boundary-Lattice Boltzmann Method.
torchax is a library for JAX that interoperates with model code written in PyTorch.
Brain dynamics programming ecosystem.
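As referenced in the PyTree item above, a minimal sketch of that idea using only core JAX; `treeify` is a hypothetical name, not the library's actual API:

```python
import jax
import jax.numpy as jnp

# Hypothetical helper: lift an array -> array function to one that
# applies to every leaf of a PyTree, preserving the tree structure.
def treeify(fn):
    def wrapped(tree, *args, **kwargs):
        return jax.tree_util.tree_map(lambda leaf: fn(leaf, *args, **kwargs), tree)
    return wrapped

tree_tanh = treeify(jnp.tanh)
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
print(tree_tanh(params))  # same dict structure, tanh applied leaf-wise
```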
Models and Projects
Haiku
Code used for the paper Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies.
Implementation of the inference pipeline of AlphaFold v2.0, presented in Highly accurate protein structure prediction with AlphaFold.
JAX
Official implementation of Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains (the feature mapping is sketched below).
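The core of that paper is a random Fourier feature mapping gamma(v) = [cos(2*pi*Bv), sin(2*pi*Bv)]; a minimal sketch follows, where the function name, bandwidth, and feature count are illustrative rather than taken from the official code:

```python
import jax
import jax.numpy as jnp

# Random Fourier feature mapping: project coordinates through a fixed
# random matrix B, then take sine and cosine components.
def fourier_features(v, B):
    proj = 2.0 * jnp.pi * v @ B.T
    return jnp.concatenate([jnp.cos(proj), jnp.sin(proj)], axis=-1)

key = jax.random.PRNGKey(0)
sigma = 10.0                                   # bandwidth (hyperparameter)
B = sigma * jax.random.normal(key, (256, 2))   # 2-D coords -> 512 features
coords = jnp.array([[0.25, 0.75]])
print(fourier_features(coords, B).shape)       # (1, 512)
```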
Flax
Official implementation of Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields.
Implementation of NeuS: Learning Neural Implicit Surfaces by Volume Rendering for Multi-view Reconstruction.
Implementation of Big Transfer (BiT): General Visual Representation Learning.
Official implementation of Cross-Modal Contrastive Learning for Text-to-Image Generation.
Official implementation of An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale (the patch embedding is sketched after this list).
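As referenced in the last item, a sketch of the 16x16 patch flattening the paper's title describes; `patchify` is a hypothetical helper, and the learned linear projection to token embeddings is omitted:

```python
import jax.numpy as jnp

# Split an (H, W, C) image into non-overlapping 16x16 patches and
# flatten each patch into a vector; a learned Dense layer would then
# map each vector to a transformer token embedding.
def patchify(img, patch=16):
    H, W, C = img.shape
    x = img.reshape(H // patch, patch, W // patch, patch, C)
    x = x.transpose(0, 2, 1, 3, 4)              # (H/p, W/p, p, p, C)
    return x.reshape(-1, patch * patch * C)     # (num_patches, p*p*C)

img = jnp.zeros((224, 224, 3))
print(patchify(img).shape)                      # (196, 768)
```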
Tutorials and Blog Posts
Learn how to create a simple convolutional network with Flax's Linen API and train it to recognize handwritten digits; a sketch of such a model follows.
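A minimal sketch, in the spirit of that tutorial, of what such a Linen model can look like; layer sizes here are illustrative:

```python
import flax.linen as nn

# A small convolutional classifier for 28x28x1 digit images.
class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):                      # x: (batch, 28, 28, 1)
        x = nn.relu(nn.Conv(features=32, kernel_size=(3, 3))(x))
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.relu(nn.Conv(features=64, kernel_size=(3, 3))(x))
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))         # flatten to (batch, features)
        x = nn.relu(nn.Dense(features=256)(x))
        return nn.Dense(features=10)(x)         # logits for the 10 digits
```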