JAX is an open-source Python library from Google that brings together Autograd and XLA for high-performance numerical computing and machine learning (ML) research.
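The sketch below shows how the two halves named above combine in practice, assuming a standard JAX install. `jax.grad` provides Autograd-style automatic differentiation, and `jax.jit` compiles the resulting function with XLA. The linear-model loss and the data values are illustrative assumptions, not taken from the source.

```python
import jax
import jax.numpy as jnp

# Illustrative loss (an assumption for this sketch): mean squared error of a linear model.
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# Autograd side: jax.grad returns a function computing d(loss)/dw (first argument by default).
grad_loss = jax.grad(loss)

# XLA side: jax.jit traces the gradient function and compiles it with XLA.
fast_grad_loss = jax.jit(grad_loss)

# Hypothetical example data.
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = fast_grad_loss(w, x, y)  # gradient is [-7.0, -10.0] for these inputs

# Transformations compose on plain Python functions, e.g. d/dt t**3 at t = 2.0 is 12.0.
cube_grad = jax.jit(jax.grad(lambda t: t ** 3))
c = cube_grad(2.0)
```

Because `jax.grad` and `jax.jit` are ordinary function transformations, they can be stacked in either order; wrapping the gradient in `jit` is the common pattern, since the compiled function is reused on every call.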