A curated list of awesome JAX libraries, projects, and other resources. Inspired by Awesome TensorFlow.
JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators like GPUs and TPUs. More info here.
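The sketch below is not part of the list itself; it is an illustrative example with an assumed toy squared-error loss, showing how `grad`, `jit`, and `vmap` compose on top of that NumPy-like API:

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # simple squared-error loss written with the NumPy-like API
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

grad_fn = jax.jit(jax.grad(loss))               # differentiate, then compile with XLA
batched = jax.vmap(loss, in_axes=(None, 0, 0))  # vectorize over a batch dimension

w = jnp.ones(3)
x = jnp.ones((8, 3))
y = jnp.zeros(8)
print(grad_fn(w, x, y).shape)   # (3,)
print(batched(w, x, y).shape)   # (8,)
```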
- Neural Network Libraries
- Flax - a flexible library with the largest user base of all JAX NN libraries.
- Haiku - focused on simplicity, created by the authors of Sonnet at DeepMind.
- Objax - has an object oriented design similar to PyTorch.
- Elegy - implements the Keras API with some improvements.
- RLax - library for implementing reinforcement learning agents.
- Trax - a "batteries included" deep learning library focused on providing solutions for common workloads.
- Jraph - a lightweight graph neural network library.
- NumPyro - probabilistic programming based on the Pyro library.
- Chex - utilities to write and test reliable JAX code.
- Optax - a gradient processing and optimization library (see the usage sketch after this list).
- JAX, M.D. - accelerated, differentiable molecular dynamics.
- Coax - turn RL papers into code, the easy way.
- SymJAX - symbolic CPU/GPU/TPU programming.
- mcx - Express & compile probabilistic programs for performant inference.
This section contains libraries that are well-made and useful, but have not necessarily been battle-tested by a large userbase yet.
- Neural Network Libraries
- jax-unirep - library implementing the UniRep model for protein machine learning applications.
- jax-flows - Normalizing flows in JAX.
- sklearn-jax-kernels - scikit-learn kernel matrices using JAX.
- jax-cosmo - a differentiable cosmology library.
- efax - Exponential Families in JAX.
- mpi4jax - Combine MPI operations with your JAX code on CPUs and GPUs.
- Reformer - an implementation of the Reformer (efficient transformer) architecture.
- Vision Transformer - official implementation in Flax of An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale.
- Fourier Feature Networks - official implementation of Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains.
- Flax Models - collection of open-sourced Flax models.
- JaxNeRF - implementation of NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis with multi-device GPU/TPU support.
- Big Transfer (BiT) - implementation of Big Transfer (BiT): General Visual Representation Learning.
- NuX - Normalizing flows with JAX.
- kalman-jax - Approximate inference for Markov (i.e., temporal) Gaussian processes using iterated Kalman filtering and smoothing.
- GPJax - Gaussian processes in JAX.
- jaxns - Nested sampling in JAX.
- Introduction to JAX - a simple neural network from scratch in JAX.
- JAX: Accelerated Machine Learning Research | SciPy 2020 | VanderPlas - JAX’s core design, how it’s powering new research, and how you can start using it.
- Bayesian Programming with JAX + NumPyro — Andy Kitchen - introduction to Bayesian modelling using NumPyro.
- JAX: Accelerated machine-learning research via composable function transformations in Python | NeurIPS 2019 | Skye Wanderman-Milne - JAX intro presentation in Program Transformations for Machine Learning workshop.
- JAX on Cloud TPUs | NeurIPS 2020 | Skye Wanderman-Milne and James Bradbury - presentation of TPU host access with demo.
- Deep Implicit Layers - Neural ODEs, Deep Equilibrium Models, and Beyond | NeurIPS 2020 - tutorial created by Zico Kolter, David Duvenaud, and Matt Johnson with Colab notebooks available in Deep Implicit Layers.
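For a flavour of the Neural ODE part of that tutorial, a hedged sketch using `jax.experimental.ode.odeint` with assumed toy linear dynamics (a neural ODE would replace them with a network) might look like:

```python
import jax.numpy as jnp
from jax.experimental.ode import odeint

def dynamics(y, t, theta):
    # toy linear dynamics dy/dt = theta * y; a neural ODE would put a network here
    return theta * y

y0 = jnp.array([1.0])
ts = jnp.linspace(0.0, 1.0, 5)
ys = odeint(dynamics, y0, ts, 0.5)   # solution at each time in ts, shape (5, 1)
```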
This section contains papers focused on JAX (e.g. JAX-based library whitepapers, research on JAX, etc). Papers implemented in JAX are listed in the Models/Projects section.
- Compiling machine learning programs via high-level tracing. Roy Frostig, Matthew James Johnson, Chris Leary. MLSys 2018. - this white paper describes an early version of JAX, detailing how computation is traced and compiled.
- JAX, M.D.: A Framework for Differentiable Physics. Samuel S. Schoenholz, Ekin D. Cubuk. NeurIPS 2020. - introduces JAX, M.D., a differentiable physics library which includes simulation environments, interaction potentials, neural networks, and more.
- Enabling Fast Differentially Private SGD via Just-in-Time Compilation and Vectorization. Pranav Subramani, Nicholas Vadivelu, Gautam Kamath. arXiv 2020. - uses JAX's jit and vmap to achieve faster differentially private SGD than existing libraries.
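The per-example gradient pattern that paper builds on can be sketched roughly as follows; the loss function and the clipping threshold are illustrative assumptions:

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    # per-example squared error (a scalar for a single example)
    return (jnp.dot(x, params) - y) ** 2

# vmap over grad gives one gradient per example, which can then be clipped and noised
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))

params = jnp.zeros(3)
x = jnp.ones((8, 3))
y = jnp.zeros(8)
grads = per_example_grads(params, x, y)   # shape (8, 3): one gradient per example

# clip each example's gradient to an L2 norm of at most 1.0 before averaging
norms = jnp.linalg.norm(grads, axis=1, keepdims=True)
clipped = grads * jnp.minimum(1.0, 1.0 / jnp.maximum(norms, 1e-12))
mean_grad = clipped.mean(axis=0)
```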
- Using JAX to accelerate our research by David Budden and Matteo Hessel - describes the state of JAX and the JAX ecosystem at DeepMind.
- Getting started with JAX (MLPs, CNNs & RNNs) by Robert Lange - neural network building blocks from scratch with the basic JAX operators.
- Tutorial: image classification with JAX and Flax Linen by 8bitmp3 - learn how to create a simple convolutional network with the Linen API by Flax and train it to recognize handwritten digits.
- Plugging Into JAX by Nick Doiron - compares Flax, Haiku, and Objax on the Kaggle flower classification challenge.
- Meta-Learning in 50 Lines of JAX by Eric Jang - intro to both JAX and Meta-Learning.
- Normalizing Flows in 100 Lines of JAX by Eric Jang - concise implementation of RealNVP.
- Differentiable Path Tracing on the GPU/TPU by Eric Jang - tutorial on implementing path tracing.
- Ensemble networks by Mat Kelcey - ensemble nets are a method of representing an ensemble of models as one single logical model.
- Out of distribution (OOD) detection by Mat Kelcey - implements different methods for OOD detection.
- Understanding Autodiff with JAX by Srihari Radhakrishna - understand how autodiff works using JAX.
- From PyTorch to JAX: towards neural net frameworks that purify stateful code by Sabrina J. Mielke - showcases how to go from a PyTorch-like style of coding to a more functional style of coding.
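A tiny sketch of the "purify stateful code" idea from that last post, with assumed shapes: state is passed in and returned explicitly instead of being mutated on an object:

```python
import jax.numpy as jnp

def init_state(dim):
    return {"w": jnp.zeros(dim), "count": 0}

def update(state, x):
    # pure function: returns a new state instead of mutating the old one
    return {"w": state["w"] + x, "count": state["count"] + 1}

state = init_state(3)
for x in jnp.ones((4, 3)):
    state = update(state, x)
print(state["count"], state["w"])   # 4 [4. 4. 4.]
```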
Contributions welcome! Read the contribution guidelines first.