
JAX

Automatic differentiation and XLA compilation brought together for high-performance machine learning research.
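The two ingredients named above can be seen in a minimal sketch. `jax.grad` differentiates a small loss function, and `jax.jit` compiles the resulting gradient function with XLA. The linear-model loss and the toy data are illustrative assumptions. They are not taken from any library listed here.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Mean squared error of a linear model: predictions are x @ w.
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# jax.grad builds a function returning d(loss)/dw (gradient w.r.t. the first argument);
# jax.jit traces that function and compiles it with XLA on first call.
grad_loss = jax.jit(jax.grad(loss))

# Toy data (illustrative only).
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = grad_loss(w, x, y)  # gradient evaluated at w = 0
```

Composing `jax.jit` with `jax.grad` is a common JAX pattern. The gradient is derived by automatic differentiation, and the whole computation runs as a single compiled XLA program.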

Collection 2.1k stars GitHub

Libraries

New Libraries

FedJAX 273 updated 2mo ago

Federated learning in JAX, built on Optax and Haiku.

Equivariant MLP

Construct equivariant neural network layers.

jax-resnet 119 updated 3y ago

Implementations and checkpoints for ResNet variants in Flax.

jax-raft

JAX/Flax port of the RAFT optical flow estimator.

Parallax 153 updated 5y ago

Immutable Torch Modules for JAX.

Optimistix 560 updated 16d ago

Root finding, minimisation, fixed points, and least squares.

JAXopt 1.0k updated 3mo ago

Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX.

jax-unirep 108 updated 1y ago

Library implementing the UniRep model for protein machine learning applications.

flowjax 223 updated 1mo ago

Distributions and normalizing flows built as equinox modules.

flaxdiff 41 updated 19d ago

Framework and library for building and training diffusion models in multi-node, multi-device distributed settings (TPUs).

jax-flows

Normalizing flows in JAX.

sklearn-jax-kernels 47 updated 5y ago

scikit-learn kernel matrices using JAX.

jax-cosmo 229 updated 9mo ago

Differentiable cosmology library.

efax 76 updated 3d ago

Exponential Families in JAX.

mpi4jax 520 updated 6d ago

Combine MPI operations with your JAX code on CPUs and GPUs.

imax 42 updated 1y ago

Image augmentations and transformations.

FlaxVision 45 updated 8mo ago

Flax version of TorchVision.

Oryx 4.4k updated 21d ago

Probabilistic programming language based on program transformations.

Optimal Transport Tools 213 (archived)

Toolbox that bundles utilities to solve optimal transport problems.

delta PV 65 updated 5mo ago

A photovoltaic simulator with automatic differentiation.

jaxlie 326 updated 11mo ago

Lie theory library for rigid body transformations and optimization.

BRAX 3.1k updated 9d ago

Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments.

flaxmodels

Pretrained models for JAX/Flax.

CR.Sparse 97 updated 2y ago

XLA accelerated algorithms for sparse representations and compressive sensing.

exojax 69 updated 2d ago

Automatically differentiable spectrum modeling of exoplanets and brown dwarfs, compatible with JAX.

PIX 434 updated 1y ago

PIX is an image processing library in JAX, for JAX.

bayex 105 updated 11mo ago

Bayesian Optimization powered by JAX.

JaxDF 134 updated yesterday

Framework for differentiable simulators with arbitrary discretizations.

tree-math 210 (archived)

Convert functions that operate on arrays into functions that operate on PyTrees.

jax-models 161 updated 3y ago

Implementations of research papers originally without code or code written with frameworks other than JAX.

PGMax 65 (archived)

A framework for building discrete Probabilistic Graphical Models (PGMs) and running inference on them via JAX.

EvoJAX 944 (archived)

Hardware-Accelerated Neuroevolution

evosax 740 updated 6mo ago

JAX-Based Evolution Strategies

SymJAX 130 updated 2y ago

Symbolic CPU/GPU/TPU programming.

mcx 330 updated 2y ago

Express & compile probabilistic programs for performant inference.

Einshape

DSL-based reshaping library for JAX and other frameworks.

Diffrax 1.9k updated 1mo ago

Numerical differential equation solvers in JAX.

tinygp 337 updated 10d ago

The tiniest of Gaussian process libraries in JAX.

gymnax 873 updated 9mo ago

Reinforcement Learning Environments with the well-known gym API.

Mctx 2.6k updated 6mo ago

Monte Carlo tree search algorithms in native JAX.

KFAC-JAX 317 updated 9d ago

Second-order optimization with approximate curvature for neural networks.

TF2JAX 120 updated 8d ago

Convert functions/graphs to JAX functions.

jwave 194 updated 3d ago

A library for differentiable acoustic simulations.

GPJax 603 updated 8d ago

Gaussian processes in JAX.

Jumanji 821 updated 16d ago

A Suite of Industry-Driven Hardware-Accelerated RL Environments written in JAX.

Eqxvision 111 updated 1y ago

Equinox version of Torchvision.

JAXFit 61 updated 2y ago

Accelerated curve fitting library for nonlinear least-squares problems (see arXiv paper).

econpizza 112 updated 4mo ago

Solve macroeconomic models with heterogeneous agents using JAX.

SPU 318 updated 5d ago

A domain-specific compiler and runtime suite to run JAX code with MPC (secure multi-party computation).

jax-tqdm 125 updated 10mo ago

Add a tqdm progress bar to JAX scans and loops.

safejax 47 updated 1y ago

Serialize JAX, Flax, Haiku, or Objax model params with safetensors.

Kernex 71 updated 4mo ago

Differentiable stencil decorators in JAX.

MaxText 2.2k updated 2d ago

A simple, performant and scalable JAX LLM written in pure Python/JAX and targeting Google Cloud TPUs.

Pax 550 updated 8d ago

A JAX-based machine learning framework for training large-scale models.

Praxis 195 updated 15d ago

The layer library for Pax with a goal to be usable by other JAX-based ML projects.

purejaxrl 1.0k updated 1y ago

Vectorisable, end-to-end RL algorithms in JAX.

Lorax 145 updated 2y ago

Automatically apply LoRA to JAX models (Flax, Haiku, etc.)

SCICO 154 updated yesterday

Scientific computational imaging in JAX.

Spyx 133 updated 2mo ago

Spiking Neural Networks in JAX for machine learning on neuromorphic hardware.

OTT-JAX 720 updated 4d ago

Optimal transport tools in JAX.

QDax 346 updated 4mo ago

Quality-diversity optimization in JAX.

JAX Toolbox 392 updated 2d ago

Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries such as T5x, Paxml, and Transformer Engine.

Pgx 595 updated 1y ago

Vectorized board game environments for RL with an AlphaZero example.

EasyDeL 345 updated 2d ago

EasyDeL is an open-source library that makes training faster and more optimized, with options for training and serving models (Llama, MPT, Mixtral, Falcon, etc.) in JAX.

XLB 448 updated 16d ago

A Differentiable Massively Parallel Lattice Boltzmann Library in Python for Physics-Based Machine Learning.

dynamiqs 272 updated 2d ago

High-performance and differentiable simulations of quantum systems with JAX.

foragax 5 updated 1y ago

Agent-based modelling framework in JAX.

tmmax 30 updated 1mo ago

Vectorized calculation of optical properties in thin-film structures using JAX; a Swiss Army knife for thin-film optics research.

Coreax 37 updated 2d ago

Algorithms for finding coresets to compress large datasets while retaining their statistical properties.

NAVIX 163 updated 5mo ago

A reimplementation of MiniGrid, a reinforcement learning environment, in JAX.

FDTDX 257 updated yesterday

Finite-Difference Time-Domain Electromagnetic Simulations in JAX

DiffeRT 53 updated 3d ago

Differentiable Ray Tracing toolbox for Radio Propagation powered by the JAX ecosystem.

JAX-in-Cell 22 updated 2d ago

Plasma physics simulations using a PIC (Particle-in-Cell) method to self-consistently solve for electron and ion dynamics in electromagnetic fields.

kvax 160 updated 4mo ago

A FlashAttention implementation for JAX with support for efficient document mask computation and context parallelism.

astronomix 57 updated 4d ago

Differentiable (magneto)hydrodynamics for astrophysics in JAX.

vivsim 33 updated 3d ago

Fluid-structure interaction simulations using Immersed Boundary-Lattice Boltzmann Method.

MBIRJAX 20 updated 2d ago

High-performance tomographic reconstruction.

torchax 200 updated 8d ago

torchax is a library for JAX to interoperate with model code written in PyTorch.
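The tree-math entry above describes turning array functions into functions over PyTrees. The idea underneath can be sketched with core JAX alone. `jax.tree_util.tree_map` applies one array-level function leaf by leaf across PyTrees that share the same structure. This is not tree-math's own API. The parameter names and the plain SGD update are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# A PyTree: nested containers (here a dict) whose leaves are arrays.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([0.1, -0.2]), "b": jnp.array(1.0)}

def sgd_step(p, g, lr=0.1):
    # Hypothetical update rule: apply the same array-level SGD update to
    # every pair of matching leaves in the two PyTrees.
    return jax.tree_util.tree_map(lambda pi, gi: pi - lr * gi, p, g)

new_params = sgd_step(params, grads)
```

Because the update is written once for a single array and then mapped over the tree, the same function works for any parameter structure: a dict, a list of layers, or nested tuples.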

Models and Projects