JAX is an open-source Python library from Google that brings together Autograd and XLA for high-performance numerical computing and machine learning (ML) research. By compiling array programs for hardware accelerators, JAX can run many numerical workloads significantly faster than unaccelerated Python and NumPy code, providing a foundation for high-performance scientific computing. JAX can automatically differentiate native Python and NumPy functions, and it uses XLA to compile and run NumPy programs on GPUs and TPUs. JAX also lets users just-in-time compile their own Python functions into XLA-optimized kernels using a one-function API, jit. Compilation and automatic differentiation can be composed arbitrarily, so users can express sophisticated algorithms and get maximal performance without leaving Python. Google describes JAX as a research project, not an official Google product.
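As a minimal sketch of how compilation and automatic differentiation compose, the example below differentiates a small NumPy-style loss function with `jax.grad` and compiles the resulting gradient function with `jax.jit`. The function `loss`, the linear model, and the sample data are assumptions chosen for illustration; they do not come from any particular JAX project.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Hypothetical example: mean squared error of a linear model,
    # written in ordinary NumPy style using jax.numpy.
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)


# jax.grad builds a function that returns d(loss)/dw (the first argument);
# jax.jit then compiles that gradient function with XLA.
grad_loss = jax.jit(jax.grad(loss))

# Illustrative inputs (assumed values, not real data).
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = grad_loss(w, x, y)
print(g)  # gradient with respect to w; equals [-7, -10] for these inputs
```

The two transformations can be applied in either order here, `jax.jit(jax.grad(loss))` or `jax.grad(jax.jit(loss))`, which is what composing them arbitrarily means in practice. The same code runs on CPU, GPU, or TPU, depending on which back end is available.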
With a familiar NumPy-style API, automatic differentiation and optimization, and code that executes on multiple back ends, JAX has use cases across a number of fields, including the following:
- Large language models
- Drug discovery
- Physics ML
- Reinforcement learning
- Neural graphics
The initial open-source release of JAX was in December 2018. Soon after release, it began to be used by multiple research groups for a range of advanced use cases, including the following:
- Studying spectra of neural networks
- Probabilistic programming and Monte Carlo methods
- Scientific applications in physics and biology
In December 2020, DeepMind announced that it was using JAX to accelerate its research by enabling rapid experimentation with novel algorithms and architectures, and stated that JAX underpinned many of its recent publications. In January 2021, Google introduced Flax, a neural network library for JAX. In August 2022, Google open-sourced Rax, a Python library for learning to rank (LTR) in the JAX ecosystem. In May 2023, Google AI introduced JaxPruner, an open-source JAX-based pruning and sparse training library for ML research. Also in May 2023, at its I/O developer conference, Google announced the PaLM 2 large language model; among the technical details released, Google stated that PaLM 2 was built on top of JAX. Reports in July 2023 stated that Apple was developing its own AI-powered chatbot, built with JAX and running on Google Cloud.