In this tutorial, we will explore how to develop a Neural Network (NN) with JAX, and what better model to choose than the Transformer? As JAX grows in popularity, more and more developer teams are starting to experiment with it and incorporate it into their projects. Although it lacks the maturity of TensorFlow or PyTorch, it provides some great features for building and training Deep Learning models.

For a solid understanding of JAX basics, check my previous article if you haven’t already. You can also find the full code in our GitHub repository.

One of the common problems people have when starting with JAX is the choice of a framework. The people at Google and DeepMind seem to be very busy and have already released a plethora of frameworks on top of JAX. Here is a list of the most famous ones:

  • Haiku: Haiku is the go-to framework for Deep Learning and it’s used by many Google and DeepMind internal teams. It provides some simple, composable abstractions for machine learning research as well as ready-to-use modules and layers.

  • Optax: Optax is a gradient processing and optimization library that contains out-of-the-box optimizers and related mathematical operations.

  • RLax: RLax is a reinforcement learning framework with many RL subcomponents and operations.

  • Chex: Chex is a library of utilities for testing and debugging JAX code.

  • Jraph: Jraph is a Graph Neural Networks library in JAX.

  • Flax: Flax is another neural network library with a variety of ready-to-use modules, optimizers, and utilities. It’s most likely the closest thing we have to an all-in-one JAX framework.

  • Objax: Objax is a third ML library that focuses on object-oriented programming and code readability. Once again, it contains the most popular modules, activation functions, losses, and optimizers, as well as a handful of pre-trained models.

  • Trax: Trax is an end-to-end library for deep learning that focuses on Transformers.

  • JAXline: JAXline is a supervised-learning library that is used for distributed JAX training and evaluation.

  • ACME: ACME is another research framework for reinforcement learning.

  • JAX-MD: JAX-MD is a niche framework that deals with molecular dynamics.

  • JAXChem: JAXChem is another niche library that focuses on chemical modeling.

Of course, the question is which one do I choose?

To be honest I’m not sure.

But if I were you and I wanted to learn JAX, I’d start with the most popular ones. Haiku and Flax seem to be used a lot inside Google/DeepMind and have the most active GitHub communities. For this article, I will start with the first one and see if I’ll need another one down the road.

So are you ready to build a Transformer with JAX and Haiku? By the way, I assume that you have a solid understanding of transformers. If you don’t, please consult our articles on attention and transformers.

Let’s start with the self-attention block.

The self-attention block

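First, we need to import JAX and Haiku, along with NumPy and the typing helpers used in the type hints:

from typing import Any, Mapping, Optional

import jax
import jax.numpy as jnp
import haiku as hk
import numpy as np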

Luckily for us, Haiku has a built-in MultiHeadAttention block that can be extended to build a masked self-attention block. Our block accepts the query, key, and value as well as the mask, and returns the output as a JAX array. You can see that the code looks very similar to standard PyTorch or TensorFlow code. All we do is build the causal mask using np.tril(), which zeroes all elements of the array above the main diagonal, multiply it with our mask, and pass everything into the hk.MultiHeadAttention module.

class SelfAttention(hk.MultiHeadAttention):
    """Self attention with a causal mask applied."""

    def __call__(
            self,
            query: jnp.ndarray,
            key: Optional[jnp.ndarray] = None,
            value: Optional[jnp.ndarray] = None,
            mask: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        # In self-attention, key and value default to the query.
        key = key if key is not None else query
        value = value if value is not None else query

        # Lower-triangular mask: position i may only attend to positions <= i.
        seq_len = query.shape[1]
        causal_mask = np.tril(np.ones((seq_len, seq_len)))
        mask = mask * causal_mask if mask is not None else causal_mask

        return super().__call__(query, key, value, mask)
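To see what the causal mask does, here is a small NumPy-only sketch. The sequence length and the uniform logits are made up for illustration, and the way the mask is applied (replacing masked logits with a large negative number before the softmax) is a common approach, shown here as an assumption rather than as Haiku's exact internals.

```python
import numpy as np

seq_len = 4
# Lower-triangular matrix: row i has ones in columns 0..i and zeros after.
causal_mask = np.tril(np.ones((seq_len, seq_len)))

# Uniform attention logits, purely for illustration.
logits = np.zeros((seq_len, seq_len))

# Masked positions get a very negative logit, so the softmax gives them ~0 weight.
masked_logits = np.where(causal_mask > 0, logits, -1e30)
shifted = masked_logits - masked_logits.max(axis=-1, keepdims=True)
weights = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)

print(causal_mask)
print(weights.round(2))
```

Each row of weights only spreads attention over the current and earlier positions, which is exactly the "no peeking at the future" behaviour we want from a language model.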

This snippet allows me to introduce the first key principle of Haiku. All modules should be a subclass of hk.Module. This means that they should implement __init__ and __call__, alongside any other methods. In a sense, it’s the same architecture as PyTorch modules, where we implement an __init__ and a forward.

To make that crystal clear, let’s build a simple 2-layer multilayer perceptron (MLP) as an hk.Module, which conveniently will be used in the Transformer below.

The linear layer

A simple 2-layer MLP will look like this. Once again, you can notice how familiar it looks.

class DenseBlock(hk.Module):
    """A 2-layer MLP"""

    def __init__(self,
                 init_scale: float,
                 widening_factor: int = 4,
                 name: Optional[str] = None):
        # Every hk.Module must call the parent constructor.
        super().__init__(name=name)
        self._init_scale = init_scale
        self._widening_factor = widening_factor

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        hiddens = x.shape[-1]
        initializer = hk.initializers.VarianceScaling(self._init_scale)
        # Widen, apply the non-linearity, then project back to the input size.
        x = hk.Linear(self._widening_factor * hiddens, w_init=initializer)(x)
        x = jax.nn.gelu(x)
        return hk.Linear(hiddens, w_init=initializer)(x)

A few things to notice here:

  • Haiku provides us with a set of weight initializers under hk.initializers, where we can find the most common approaches.

  • It also has many popular layers and modules built in, such as hk.Linear. For the complete list, take a peek at the official documentation.

  • Activation functions are not provided because JAX already has a subpackage called jax.nn, where we can find activation functions such as relu or softmax.
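As a quick illustration of the last point, the sketch below calls a few activation functions from jax.nn on a made-up input vector. Nothing here is specific to Haiku.

```python
import jax
import jax.numpy as jnp

x = jnp.array([-2.0, 0.0, 2.0])

relu_out = jax.nn.relu(x)    # negative inputs are clipped to zero
gelu_out = jax.nn.gelu(x)    # smooth alternative, used in the DenseBlock above
probs = jax.nn.softmax(x)    # non-negative values that sum to 1

print(relu_out, gelu_out, probs)
```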

The normalization layer

Layer normalization is another integral block of the transformer architecture, which we can also find in the common modules inside Haiku.

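def layer_norm(x: jnp.ndarray, name: Optional[str] = None) -> jnp.ndarray:
    """Apply a unique LayerNorm to x with default settings."""
    # hk.LayerNorm requires create_scale and create_offset to be set explicitly.
    return hk.LayerNorm(axis=-1,
                        create_scale=True,
                        create_offset=True,
                        name=name)(x)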
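For intuition, here is roughly what layer normalization computes over the last axis, written by hand in jax.numpy. This is a sketch of the math, not Haiku's implementation: the function name layer_norm_by_hand is hypothetical, and the learnable scale and offset are passed in as plain arrays.

```python
import jax.numpy as jnp

def layer_norm_by_hand(x, scale, offset, eps=1e-5):
    # Normalize each feature vector to zero mean and unit variance,
    # then apply a learnable scale and offset.
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return scale * (x - mean) / jnp.sqrt(var + eps) + offset

x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [10.0, 20.0, 30.0, 40.0]])
y = layer_norm_by_hand(x, scale=jnp.ones(4), offset=jnp.zeros(4))
print(y)
```

Both rows end up (almost) identical after normalization, even though the second one is ten times larger, which is why the layer helps keep activations in a stable range.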




The transformer

And now for the good stuff. Below you can find a very simplistic Transformer, which makes use of our predefined modules. Inside __init__, we define the basic variables such as the number of layers, attention heads, and the dropout rate. Inside __call__, we compose a list of blocks using a for loop.

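As you can see, each block includes:

  • a normalization layer followed by the causal self-attention block

  • a second normalization layer followed by the 2-layer dense block

  • a dropout layer after each of the two sub-blocks

  • two skip connections (h = h + h_attn and h = h + h_dense)

In the end, we also add a final normalization layer.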

class Transformer(hk.Module):
    """A transformer stack."""

    def __init__(self,
                 num_heads: int,
                 num_layers: int,
                 dropout_rate: float,
                 name: Optional[str] = None):
        super().__init__(name=name)
        self._num_layers = num_layers
        self._num_heads = num_heads
        self._dropout_rate = dropout_rate

    def __call__(self,
                 h: jnp.ndarray,
                 mask: Optional[jnp.ndarray],
                 is_training: bool) -> jnp.ndarray:
        """Connects the transformer.

        Args:
            h: Inputs, [B, T, H].
            mask: Padding mask, [B, T].
            is_training: Whether we're training or not.

        Returns:
            Array of shape [B, T, H].
        """
        init_scale = 2. / self._num_layers
        dropout_rate = self._dropout_rate if is_training else 0.
        if mask is not None:
            # [B, T] -> [B, 1, 1, T], so it broadcasts over heads and query positions.
            mask = mask[:, None, None, :]

        for i in range(self._num_layers):
            # Attention sub-block: normalize, attend, dropout, skip connection.
            h_norm = layer_norm(h, name=f'h{i}_ln_1')
            h_attn = SelfAttention(
                num_heads=self._num_heads,
                key_size=64,  # assumed per-head size
                w_init_scale=init_scale,
                name=f'h{i}_attn')(h_norm, mask=mask)
            h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
            h = h + h_attn
            # Dense sub-block: normalize, MLP, dropout, skip connection.
            h_norm = layer_norm(h, name=f'h{i}_ln_2')
            h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
            h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
            h = h + h_dense
        h = layer_norm(h, name='ln_f')
        return h
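One detail that is easy to miss is how the [B, T] padding mask ends up combined with the [T, T] causal mask. The NumPy sketch below, with made-up shapes and values, shows the broadcasting that takes place once the padding mask is reshaped with mask[:, None, None, :].

```python
import numpy as np

batch, seq_len = 2, 4
# 1 marks a real token, 0 marks padding; the second sequence has two pads.
padding_mask = np.array([[1, 1, 1, 1],
                         [1, 1, 0, 0]])
causal_mask = np.tril(np.ones((seq_len, seq_len)))

# [B, T] -> [B, 1, 1, T], which then broadcasts against [T, T].
combined = padding_mask[:, None, None, :] * causal_mask
print(combined.shape)  # (2, 1, 4, 4)
print(combined[1, 0])
```

The extra axis of size 1 is later broadcast over the attention heads, so every head sees the same mask.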

I think that by now you have realized that building a Neural Network with JAX is dead simple.

The embeddings layer

For completeness, let’s also include the embeddings layer. It is good to know that Haiku also provides an embedding layer, hk.Embed, which maps the tokens of our input sentence to embedding vectors. The token embeddings are then added to the positional embeddings, which produces the final input.

def embeddings(data: Mapping[str, jnp.ndarray], vocab_size: int):
    # Note: d_model (the embedding size) must be defined in the enclosing scope.
    tokens = data['obs']
    input_mask = jnp.greater(tokens, 0)  # token id 0 is treated as padding
    seq_length = tokens.shape[1]

    # Embed the input tokens and the positions.
    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
    token_embs = token_embedding_map(tokens)
    positional_embeddings = hk.get_parameter(
        'pos_embs', [seq_length, d_model], init=embed_init)
    input_embeddings = token_embs + positional_embeddings
    return input_embeddings, input_mask
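To check the shapes involved, here is a sketch of the same computation with plain arrays instead of Haiku parameters. The sizes, the token ids, and the randomly drawn tables are all made up for illustration.

```python
import jax
import jax.numpy as jnp

vocab_size, d_model = 10, 8          # hypothetical sizes
tokens = jnp.array([[5, 3, 7, 0],    # token id 0 is treated as padding
                    [2, 9, 0, 0]])

key_tok, key_pos = jax.random.split(jax.random.PRNGKey(0))
embedding_table = 0.02 * jax.random.normal(key_tok, (vocab_size, d_model))
positional = 0.02 * jax.random.normal(key_pos, (tokens.shape[1], d_model))

token_embs = embedding_table[tokens]          # lookup: [B, T] -> [B, T, D]
input_embeddings = token_embs + positional    # [T, D] broadcasts over the batch
input_mask = jnp.greater(tokens, 0)

print(input_embeddings.shape, input_mask)
```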

hk.get_parameter(param_name, ...) is used to access the trainable parameters of a module. But you may ask, why not just use object properties as we do in PyTorch? This is where the second key principle of Haiku comes into play. We use this API so that we can convert the code into a pure function using hk.transform. This is not very simple to grasp, but I will try to make it as clear as possible.

Why pure functions?

The power of JAX lies in its function transformations: the ability to vectorize a function with vmap, automatic parallelization with pmap, and just-in-time compilation with jit. The caveat here is that in order to transform a function, it needs to be pure.

A pure function is a function that has the following properties:

  • The function return values are identical for identical arguments (no variation with local static variables, non-local variables, mutable reference arguments, or input streams).

  • The function application has no side effects (no mutation of local static variables, non-local variables, mutable reference arguments, or input/output streams).


Source: Scala pure functions by O’Reilly

In practice, this means that for a pure function:

  • the same inputs always produce the same result

  • all the input data is passed through the function arguments, and all the results are output through the function’s return values
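To see why purity matters, consider the sketch below. The function impure_add and the state dictionary are hypothetical. The function reads and mutates state that lives outside its arguments, so it behaves differently when called eagerly and when compiled with jit.

```python
import jax

state = {'calls': 0}

def impure_add(x):
    state['calls'] += 1          # side effect: mutates outside state
    return x + state['calls']    # the result depends on hidden state

# Called eagerly, the same input gives different outputs.
eager_first = impure_add(1.0)
eager_second = impure_add(1.0)

state['calls'] = 0
jitted = jax.jit(impure_add)
jit_first = jitted(1.0)    # traced once: the side effect runs and the value 1 is baked in
jit_second = jitted(1.0)   # compiled code is reused: the side effect does not run again

print(eager_first, eager_second)   # 2.0 3.0
print(jit_first, jit_second, state['calls'])
```

The jitted version silently freezes the hidden state at trace time, which is exactly the kind of surprise that pure functions rule out.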

Haiku provides a function transformation, called hk.transform, that turns functions with object-oriented, functionally “impure” modules into pure functions that can be used with JAX. To see that in practice, let’s continue with the training of our Transformer model.

The forward pass

A typical forward pass includes:

  1. Take the input and compute the input embeddings

  2. Run them through the Transformer’s blocks

  3. Return the output

The aforementioned steps can be easily composed with JAX as follows:

def build_forward_fn(vocab_size: int, d_model: int, num_heads: int,
                     num_layers: int, dropout_rate: float):
    """Create the model's forward pass."""

    def forward_fn(data: Mapping[str, jnp.ndarray],
                   is_training: bool = True) -> jnp.ndarray:
        """Forward pass."""
        input_embeddings, input_mask = embeddings(data, vocab_size)

        # Run the transformer over the inputs.
        transformer = Transformer(
            num_heads=num_heads, num_layers=num_layers, dropout_rate=dropout_rate)
        output_embeddings = transformer(input_embeddings, input_mask, is_training)

        # Project back to the vocabulary size to obtain the logits.
        return hk.Linear(vocab_size)(output_embeddings)

    return forward_fn

Although the code is straightforward, its structure might seem a bit odd. The actual forward pass is executed through the forward_fn function. However, we wrap this with the build_forward_fn function which returns the forward_fn. What the heck?

Down the road, we will need to transform the forward_fn function into a pure function using hk.transform so that we can take advantage of automatic differentiation, parallelization etc.

This will be accomplished by:

forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
                              num_layers, dropout_rate)
forward_fn = hk.transform(forward_fn)

That’s why instead of simply defining a function, we wrap and return the function itself, or a callable to be more precise. The wrapper captures the hyperparameters, so the returned callable only takes the data, and it can then be passed into hk.transform and become a pure function. If this is clear, let’s continue with our loss function.

The loss function

The loss function is our well-known cross-entropy function with the difference that we are also taking the mask into consideration. Once again, JAX provides one_hot and log_softmax functionalities.

def lm_loss_fn(forward_fn,
               vocab_size: int,
               params,
               rng,
               data: Mapping[str, jnp.ndarray],
               is_training: bool = True) -> jnp.ndarray:
    """Compute the loss on data wrt params."""
    logits = forward_fn(params, rng, data, is_training)
    targets = jax.nn.one_hot(data['target'], vocab_size)
    assert logits.shape == targets.shape

    # Padding positions (token id 0) do not contribute to the loss.
    mask = jnp.greater(data['obs'], 0)
    loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
    loss = jnp.sum(loss * mask) / jnp.sum(mask)
    return loss
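A tiny worked example helps to verify the masking. The token ids and the all-zero logits below are made up. With uniform predictions over 4 classes, every non-padded position should cost log(4), and the padded position should be ignored.

```python
import jax
import jax.numpy as jnp

vocab_size = 4
obs = jnp.array([[2, 3, 0]])       # the last position is padding (token id 0)
target = jnp.array([[3, 1, 0]])
logits = jnp.zeros((1, 3, vocab_size))   # uniform predictions

targets = jax.nn.one_hot(target, vocab_size)
mask = jnp.greater(obs, 0)

# Per-token cross-entropy, then an average over the non-padded tokens only.
per_token = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
loss = jnp.sum(per_token * mask) / jnp.sum(mask)

print(loss, jnp.log(4.0))
```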

If you are still with me, take a sip of coffee because things are going to get serious from now on. It’s time to build our training loop.

The training loop

Because neither JAX nor Haiku has built-in optimization functionality, we will make use of another library called Optax. As mentioned in the beginning, Optax is the go-to package for gradient processing.

First here are some things you need to know about Optax:

The key abstraction of Optax is the GradientTransformation. A transformation is defined by two functions, init and update. The init function initializes the state, and the update function transforms the gradients based on the state and, optionally, the current value of the parameters.

state = init(params)

grads, state = update(grads, state, params=None)

One more thing to know before we see the code is the functools.partial function from Python’s standard library. The functools package deals with higher-order functions and operations on callable objects.

A function is called a Higher Order function if it contains other functions as a parameter or returns a function as an output.

The partial, which can also be used as a decorator, returns a new function based on an original one, but with some of its arguments fixed. If, for example, f multiplies two values x and y, the partial will create a new function where x is fixed and equal to 2:

from functools import partial

def f(x, y):
    return x * y

# g(y) is equivalent to f(2, y), so g(3) returns 6
g = partial(f, 2)


After this short detour, let’s proceed. To declutter our main function, we will extract the gradient update into its own class.

First of all the GradientUpdater accepts the model, the loss function, and an optimizer.

  1. The model will be a pure forward_fn function transformed by hk.transform

forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
                              num_layers, dropout_rate)
forward_fn = hk.transform(forward_fn)

  2. The loss function will be the result of a partial with a fixed forward_fn and vocab_size

loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

  3. The optimizer is a set of optimization transformations that will run sequentially (operations can be combined using optax.chain)

optimizer = optax.chain(
    optax.adam(learning_rate, b1=0.9, b2=0.99))

The Gradient updater will be initialized as follows:

updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)

and will look like this:

import functools
import optax

class GradientUpdater:
    """A stateless abstraction around an init_fn/update_fn pair.

    This extracts some common boilerplate from the training loop.
    """

    def __init__(self, net_init, loss_fn,
                 optimizer: optax.GradientTransformation):
        self._net_init = net_init
        self._loss_fn = loss_fn
        self._opt = optimizer

    @functools.partial(jax.jit, static_argnums=0)
    def init(self, master_rng, data):
        """Initializes state of the updater."""
        out_rng, init_rng = jax.random.split(master_rng)
        params = self._net_init(init_rng, data)
        opt_state = self._opt.init(params)
        # The keys match what update() reads from the state below.
        out = dict(
            step=np.array(0),
            rng=out_rng,
            opt_state=opt_state,
            params=params,
        )
        return out

    @functools.partial(jax.jit, static_argnums=0)
    def update(self, state: Mapping[str, Any], data: Mapping[str, jnp.ndarray]):
        """Updates the state using some data and returns metrics."""
        rng, new_rng = jax.random.split(state['rng'])
        params = state['params']
        loss, g = jax.value_and_grad(self._loss_fn)(params, rng, data)

        updates, opt_state = self._opt.update(g, state['opt_state'])
        params = optax.apply_updates(params, updates)

        new_state = {
            'step': state['step'] + 1,
            'rng': new_rng,
            'opt_state': opt_state,
            'params': params,
        }
        metrics = {
            'step': state['step'],
            'loss': loss,
        }
        return new_state, metrics

Inside init, we initialize our optimizer with self._opt.init(params) and we declare the state of the optimization. The state will be a dictionary with:

  • the current step

  • the random number generator key

  • the optimizer state

  • the trainable parameters

The update function will update both the state of the optimizer as well as the trainable parameters. In the end, it will return the new state.

updates, opt_state = self._opt.update(g, state['opt_state'])

params = optax.apply_updates(params, updates)

Two more things to notice here:

  • jax.value_and_grad() is a special function that takes a function and returns a new one, which computes both the value of the original function and its gradients

  • Both init and update are annotated with @functools.partial(jax.jit, static_argnums=0), which will trigger the just-in-time compiler and compile them with XLA at runtime. Note that if we hadn’t transformed forward_fn into a pure function, this wouldn’t be possible.
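Both points can be tried out in isolation. In the sketch below, squared_error and the Scaler class are hypothetical. The first part shows what jax.value_and_grad returns, and the second shows why static_argnums=0 is needed when we jit a method: self is not an array, so we tell JAX to treat it as a compile-time constant.

```python
import functools

import jax
import jax.numpy as jnp

def squared_error(w, x):
    return jnp.sum((w * x) ** 2)

w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])

# Returns the loss value and its gradient w.r.t. the first argument (w).
value, grad = jax.value_and_grad(squared_error)(w, x)

class Scaler:
    def __init__(self, factor):
        self.factor = factor

    @functools.partial(jax.jit, static_argnums=0)   # self is static, so it is not traced
    def scale(self, x):
        return self.factor * x

scaled = Scaler(3.0).scale(x)

print(value, grad, scaled)
```

Here the value is 9 + 64 = 73 and the gradient is 2 * w * x**2, which gives [18, 64].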

Finally, we are ready to build the entire training loop, which combines all the ideas and code mentioned so far.

import logging
import time

def main():
    # load is the dataset helper from the full code; sequence_length is assumed here.
    train_dataset, vocab_size = load(batch_size, sequence_length)

    # Set up the model, the loss, and the updater.
    forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
                                  num_layers, dropout_rate)
    forward_fn = hk.transform(forward_fn)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

    optimizer = optax.chain(
        optax.adam(learning_rate, b1=0.9, b2=0.99))

    updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)

    # Initialize the parameters.
    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(428)
    data = next(train_dataset)
    state = updater.init(rng, data)

    logging.info('Starting train loop...')
    prev_time = time.time()
    for step in range(MAX_STEPS):
        data = next(train_dataset)
        state, metrics = updater.update(state, data)

Notice how we incorporate the GradientUpdater. It’s just two lines of code:

  • state = updater.init(rng, data)

  • state, metrics = updater.update(state, data)
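To get a feel for the same state-in, state-out loop without the full Transformer, here is a toy sketch that fits a single weight so that w * x matches y = 3x. It is not the article's model: the data is made up, and a hand-written SGD step with a fixed learning rate of 0.1 stands in for the Optax optimizer.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, data):
    pred = params['w'] * data['x']
    return jnp.mean((pred - data['y']) ** 2)

@jax.jit
def update(state, data):
    """Pure update: takes the old state, returns the new state and metrics."""
    loss, grads = jax.value_and_grad(loss_fn)(state['params'], data)
    # Plain SGD step, standing in for optimizer.update + apply_updates.
    params = jax.tree_util.tree_map(
        lambda p, g: p - 0.1 * g, state['params'], grads)
    new_state = {'step': state['step'] + 1, 'params': params}
    return new_state, {'loss': loss}

data = {'x': jnp.array([1.0, 2.0, 3.0]), 'y': jnp.array([3.0, 6.0, 9.0])}
state = {'step': jnp.array(0), 'params': {'w': jnp.array(0.0)}}

for _ in range(50):
    state, metrics = update(state, data)

print(state['step'], state['params']['w'], metrics['loss'])
```

Just like in the real loop, all mutable quantities live in the state dictionary, which is what allows update to be jitted.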

And that’s it. I hope that by now you have a clearer understanding of JAX and its capabilities.


The code presented is heavily inspired by the official examples of the Haiku framework. It has been modified to fit the needs of this article. For the complete list of examples, check the official repository.


In this article, we saw how one can develop and train a vanilla Transformer in JAX using Haiku. Although the code isn’t necessarily hard to grasp, it still lacks the readability of PyTorch or TensorFlow. I highly recommend playing around with it, discovering the strengths and weaknesses of JAX, and seeing if it’d be a good fit for your next project. In my experience, JAX is very strong for research applications that require high performance but quite immature for real-life projects. Let us know what you think in our Discord channel.
