Hi, my name is Jake VanderPlas.
I am a Software Engineer at Google Research, and today I want to tell you about JAX, which is a Python package for accelerating your machine learning research.
So at a glance, what is JAX? Basically, JAX provides a lightweight API for array-based computing that’s very similar to NumPy.
So if you’re writing NumPy code in Python already, JAX will probably feel very familiar.
On top of that, what it adds is a set of composable function transformations for doing things like automatic differentiation, just-in-time compilation, and then automated vectorization and parallelization of your code.
And then finally, what you can do with this code once you write it is execute it on CPU or GPU or TPU without any changes to your code.
So it ends up being a very powerful system for building models from scratch and exploring different machine learning problems.
On top of this, JAX is really fast.
In the recent MLPerf competition, which pitted different software and hardware systems against each other on common deep learning algorithms and problems, JAX outperformed other systems on some of these common problems.
You can look at these results and see how JAX stacks up against some of your other favorite systems.
So let’s step back and motivate JAX a little bit.
Thinking about this, how might you implement a performant and scalable deep neural network from scratch in Python? Usually, Python programmers would start with something like NumPy, because it’s this familiar, array-based data processing library that’s been used literally for decades in the Python community.
And if you were trying to create a deep learning system in NumPy, you might start with a predict method.
Here, this is a feedforward neural network that does a sequence of dot products and activation functions to transform the inputs into some sort of outputs that can be learned.
The next thing you need once you have this model defined is a loss function, and this is the thing that’ll give you your metric that you’re trying to optimize in order to fit the best machine learning model.
So here we’re using a mean squared error loss.
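As a rough sketch of the NumPy starting point described here, the predict and loss functions might look like the following. The parameter layout (a list of weight and bias pairs), the tanh activation, and the toy data are assumptions for illustration, not details given in the talk.

```python
import numpy as np

def predict(params, inputs):
    # params is assumed to be a list of (weights, biases) pairs, one per layer.
    for W, b in params:
        outputs = np.dot(inputs, W) + b   # dot product for this layer
        inputs = np.tanh(outputs)         # assumed activation function
    return outputs                        # output of the last layer, before activation

def mse_loss(params, batch):
    # Mean squared error between predictions and targets.
    inputs, targets = batch
    preds = predict(params, inputs)
    return np.mean((preds - targets) ** 2)

# Hypothetical toy problem: 3 input features, one hidden layer of 4 units, 1 output.
rng = np.random.default_rng(0)
params = [(rng.normal(size=(3, 4)), np.zeros(4)),
          (rng.normal(size=(4, 1)), np.zeros(1))]
batch = (rng.normal(size=(8, 3)), rng.normal(size=(8, 1)))
loss = mse_loss(params, batch)
```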
And now what’s missing with this deep learning in NumPy? Deep learning takes a lot of computation.
We’d like to run it on accelerated hardware.
So you want to run this model on GPU and TPU, and that’s a little bit difficult with classic NumPy.
The next thing you might want to do is use automatic differentiation, which would let you fit this loss function very efficiently without having to do numerical differentiation along the way.
Next thing you might want to do is add compilation so you can fuse together these operations and make them much more efficient.
Finally, if you’re working with large datasets, it’s nice to be able to parallelize this operation across multiple cores or multiple machines.
So let’s take a look at what JAX can do to fill in these missing pieces.
The first thing you can do is replace the numpy import with jax.numpy.
And this has the same API as the classic NumPy in many cases, but it will allow you to do some of these things that were missing in the first pass.
So, for example, JAX, automatically via this XLA backend, will target CPUs, GPUs, and TPUs for fast computation of your models and your algorithms.
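To make the import swap concrete, here is a small sketch using jax.numpy in place of numpy. The array values are arbitrary; jax.devices() simply lists whatever backend devices JAX found on the machine running the code.

```python
import jax
import jax.numpy as jnp   # used in place of `import numpy as np`

# The same array-style calls as NumPy, now dispatched through XLA.
x = jnp.arange(6.0).reshape(2, 3)
y = jnp.dot(x, x.T)

# The devices JAX will target by default (CPU, GPU, or TPU, depending on the install).
devices = jax.devices()
print(devices)
```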
On top of that, JAX provides the set of composable transformations, one of which is the grad transform, which can take a loss function like mse_loss and convert it into a Python function that computes the gradient.
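A minimal sketch of the grad transform applied to a loss like mse_loss follows. The network definition, parameter shapes, and random data are assumptions made so the example runs on its own.

```python
import jax
import jax.numpy as jnp

def predict(params, inputs):
    # params is assumed to be a list of (weights, biases) pairs.
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)
    return outputs

def mse_loss(params, batch):
    inputs, targets = batch
    preds = predict(params, inputs)
    return jnp.mean((preds - targets) ** 2)

# Hypothetical toy parameters and data.
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
params = [(jax.random.normal(k1, (3, 4)), jnp.zeros(4)),
          (jax.random.normal(k2, (4, 1)), jnp.zeros(1))]
batch = (jax.random.normal(k3, (8, 3)), jax.random.normal(k4, (8, 1)))

# grad returns a new Python function that computes the gradient of the loss
# with respect to its first argument (params).
grad_fn = jax.grad(mse_loss)
grads = grad_fn(params, batch)
```

The result has the same nested structure as params, so each weight and bias gets a gradient of matching shape.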
Once you have this gradient function, you might want to apply it across multiple pieces of data, and in JAX, you no longer have to rewrite your prediction and your loss functions to handle this batch data.
If you pass it through the vmap transform, this’ll automatically vectorize your code so you can use the same code across multiple batches.
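As a sketch of what vmap does, the function below is written for a single example and then vectorized over a batch. The single-layer function predict_one and the shapes are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def predict_one(W, b, x):
    # Written for ONE example: x has shape (3,), the result has shape (4,).
    return jnp.tanh(jnp.dot(x, W) + b)

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(k1, (3, 4))
b = jnp.zeros(4)
xs = jax.random.normal(k2, (8, 3))   # a batch of 8 examples

# Map over axis 0 of xs only; W and b are shared across the batch (in_axes=None).
batched_predict = jax.vmap(predict_one, in_axes=(None, None, 0))
batched_out = batched_predict(W, b, xs)

# Reference: an explicit Python loop over the batch.
looped_out = jnp.stack([predict_one(W, b, x) for x in xs])
```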
If you want to compile this, you can use the jit transform, which stands for just-in-time compilation.
And this will fuse operations together using the XLA compiler to make your code sometimes much, much faster than it was originally.
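A small sketch of the jit transform: the function below chains several elementwise operations that the compiler can fuse. The SELU-style function and its constants are only an illustrative choice, and any speedup depends on the function, the input size, and the hardware.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lam=1.05):
    # Several elementwise operations in a row, a good candidate for fusion.
    return lam * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# jit returns a compiled version of the same function.
selu_jit = jax.jit(selu)

x = jnp.linspace(-3.0, 3.0, 1000)
eager_out = selu(x)
jit_out = selu_jit(x)   # the first call traces and compiles; later calls reuse the result
```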
And finally, if you want to parallelize your code, there’s a transform that’s very similar to vmap called pmap.
And if you run pmap through your code, this will be able to natively target multiple cores in your system or a cluster of TPUs or GPUs that you have access to.
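A minimal pmap sketch, assuming nothing about the hardware: the leading axis of the input is sized to the number of local devices, so it also runs on a machine with a single CPU device. The per-device function is an arbitrary example.

```python
import jax
import jax.numpy as jnp

# pmap expects the mapped axis to be no larger than the number of devices,
# so the leading axis is sized from the devices that are actually available.
n_devices = jax.local_device_count()
xs = jnp.arange(n_devices * 4.0).reshape(n_devices, 4)

def sum_of_squares(x):
    # Runs once per device, each on its own slice of xs.
    return jnp.sum(x ** 2)

parallel_out = jax.pmap(sum_of_squares)(xs)
```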
So this ends up being a very powerful system to build up these fast computations without much extra code.
So the key ideas here are these: in JAX, Python code is traced to an intermediate representation, and JAX knows how to transform this intermediate representation. I’ll tell you a little bit about this in a moment. The same intermediate representation enables domain-specific compilation via XLA, so you can target different backends.
It has this familiar user-facing API based on NumPy and SciPy.
So if you’ve been coding in the Python data space for a while, JAX should feel fairly familiar.
And on top of it, it’s this powerful set of transforms, grad, jit, vmap, pmap, and others that let you do things with your code that you weren’t able to do before.
So I want to step back a bit now and talk about how JAX works, because it’s interesting to use a powerful black box like this, but I think it’s even more fun if you know what’s going on under the hood.
And the gist of how JAX works is that it traces Python functions.
So just as a thought experiment, let’s take a look at this function: def f(x): return x + 2.
What does this function do? You know, it may seem obvious, but Python is such a dynamic language that this function could literally do anything.
Say x is an instance of this EspressoMaker object that overloads the add function, so that when you tell it the number of espressos you want, it’ll “ssh” to your espresso maker and make those automatically.
This is a bit silly, but it just drives home the point that we don’t really know what Python functions do unless we know exactly what’s being passed to them.
And JAX takes advantage of this Python dynamism in order to figure out what’s going on in a function and to come up with a representation that can be transformed in the way that we saw earlier.
So it does this by calling the function on what’s called a tracer value.
So an example of this is a shaped array tracer value, and you can see what happens here is that the add function not only returns a shaped array with the result, but also records the computation.
This isn’t exactly what’s happening in the JAX code, but it’s the basic gist of what’s happening under the hood.
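The idea can be sketched in plain Python with a toy tracer class. ToyShapedArray is a hypothetical stand-in invented for this illustration; it is not JAX's actual tracer implementation.

```python
class ToyShapedArray:
    """Hypothetical tracer: carries a shape and dtype, but no actual data."""

    def __init__(self, shape, dtype, trace):
        self.shape = shape
        self.dtype = dtype
        self.trace = trace   # shared list that records every operation seen

    def __add__(self, other):
        # Record the computation, then return a new abstract value for the result.
        self.trace.append(("add", self, other))
        return ToyShapedArray(self.shape, self.dtype, self.trace)

def f(x):
    return x + 2

trace = []
x = ToyShapedArray((3,), "float32", trace)
out = f(x)   # calling f on the tracer records what f does to its input
```

After the call, trace holds one recorded add operation, and out is another abstract value with the same shape and dtype.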
So, how does this work? So let’s say you have a function like this that computes the base two log of a number X.
How does JAX trace this and understand what’s going on? Well, the first thing is that all of JAX’s operations are based on operations in lax, which is the set of primitive operations that mirrors XLA.
And now when we want to compute the log base 2 of some array, the first thing we do is put in a ShapedArray value in place of this X, and once that ShapedArray value is in there, we can step through the function and see what operations are done on this.
So we see the log of X and we record b = log a.
We see the log of 2 and we record c = log 2.0.
We see the division operation between these two and we record d = div b c, and then we return this d value.
And now what’s left here is what’s known as a jaxpr, short for a JAX expression or JAX representation of the function.
And this encodes everything that this function does in terms of its numerical processing of the inputs and how they lead to the outputs.
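The trace walked through above can be inspected with jax.make_jaxpr. The function below is a sketch of the log base 2 example; the exact text that gets printed varies between JAX versions, so no expected output is shown.

```python
import jax
import jax.numpy as jnp

def log2(x):
    ln_x = jnp.log(x)     # recorded as a log operation on the input
    ln_2 = jnp.log(2.0)   # log of the constant 2.0
    return ln_x / ln_2    # recorded as a div operation

# make_jaxpr traces the function with an abstract value and returns the jaxpr.
jaxpr = jax.make_jaxpr(log2)(3.0)
print(jaxpr)
```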
And this is a well-defined intermediate representation that lets you do a number of things.
JAX knows how to automatically differentiate this, knows how to vectorize and parallelize this, and it knows how to just-in-time compile this by passing it to XLA.
And so the result is you have this nice pipeline where you’re writing Python code on the left and JAX is tracing it, turning it into an intermediate representation, transforming that intermediate representation in the way you specify in your Python code, and eventually jit compiling it to HLO, which stands for high level optimized code, which is what XLA reads in, and XLA can take this HLO, compile it, and send it to CPUs, GPUs or TPUs.
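The last stages of this pipeline can be sketched with the lowering methods on a jitted function. This assumes a reasonably recent JAX release where lower, as_text, and compile are available; the dialect of the printed compiler input differs between versions.

```python
import jax
import jax.numpy as jnp

def log2(x):
    return jnp.log(x) / jnp.log(2.0)

x = jnp.float32(3.0)

# Trace and lower the function for this input type, without running it yet.
lowered = jax.jit(log2).lower(x)
compiler_input = lowered.as_text()   # textual form of what is handed to XLA

# Compile with XLA for the default backend, then run the compiled executable.
compiled = lowered.compile()
result = compiled(x)
```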
And all that you need to do as the user is write this Python code on the left hand side.
Everything else is under the hood and kind of happens automatically.
So this has been really powerful for use across Google and outside Google as well.
These are just a couple of examples of applications that JAX has powered.
On the top left, we have protein folding.
This is the AlphaFold from DeepMind.
The current version of AlphaFold runs on JAX.
I won’t say much more about it because there’s another talk in the session that dives into it.
But we’ve also seen JAX used for robotic control, used for physical simulations and other simulations where you need to run lots of computations on accelerated chips.
And it ends up being an incredibly powerful system for exploring and building these kinds of models.
If you’d like to get started with JAX, you can go to our website.
There’s a very nice getting started guide and a JAX 101 tutorial to help familiarize you with JAX, and if you want to dive in a little more, there’s a whole ecosystem of tools built around JAX.
There are higher-level deep learning libraries.
There are optimization libraries for doing physical modeling and other applications, there are libraries for probabilistic programming and graph neural networks, and many, many more that I don’t have time to highlight here.
With that, I’d like to thank you for listening to this talk.
If you want to learn more about JAX, take a look at the documentation at jax.readthedocs.io.
And we’d love to see what you build with this tool.