In this tutorial, we will explore how to develop a Neural Network (NN) with JAX, and what better model to choose than the Transformer? As JAX grows in popularity, more and more developer teams are starting to experiment with it and incorporate it into their projects. Although it lacks the maturity of TensorFlow or PyTorch, it provides some great features for building and training Deep Learning models.
For a solid understanding of JAX basics, check my previous article if you haven’t already. You can also find the full code in our GitHub repository.
One of the common problems people face when starting with JAX is the choice of a framework. The teams at Google and DeepMind seem to be very busy and have already released a plethora of frameworks on top of JAX. Here is a list of the most famous ones:

Haiku: Haiku is the go-to framework for Deep Learning and it’s used by many Google and DeepMind internal teams. It provides some simple, composable abstractions for machine learning research as well as ready-to-use modules and layers.

Optax: Optax is a gradient processing and optimization library that contains out-of-the-box optimizers and related mathematical operations.

RLax: RLax is a reinforcement learning framework with many RL subcomponents and operations.

Chex: Chex is a library of utilities for testing and debugging JAX code.

Jraph: Jraph is a Graph Neural Networks library in JAX.

Flax: Flax is another neural network library with a variety of ready-to-use modules, optimizers, and utilities. It’s most likely the closest we have to an all-in-one JAX framework.

Objax: Objax is a third ML library that focuses on object-oriented programming and code readability. Once again, it contains the most popular modules, activation functions, losses, and optimizers, as well as a handful of pretrained models.

Trax: Trax is an end-to-end library for deep learning that focuses on Transformers.

JAXline: JAXline is a supervised-learning library that is used for distributed JAX training and evaluation.

ACME: ACME is another research framework for reinforcement learning.

JAX-MD: JAX-MD is a niche framework that deals with molecular dynamics.

Jaxchem: JAXChem is another niche library that focuses on chemical modeling.
Of course, the question is: which one do I choose?
To be honest, I’m not sure.
But if I were you and wanted to learn JAX, I’d start with the most popular ones. Haiku and Flax seem to be used a lot inside Google/DeepMind and have the most active GitHub communities. For this article, I will start with the first one, Haiku, and see if I’ll need another one down the road.
So are you ready to build a Transformer with JAX and Haiku? By the way, I assume that you have a solid understanding of transformers. If you don’t, please check out our articles on attention and transformers.
Let’s start with the self-attention block.
The self-attention block
First, we need to import JAX and Haiku (plus the other packages used throughout this article):
```python
import functools
from typing import Any, Mapping, Optional

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
```
Luckily for us, Haiku has a built-in `MultiHeadAttention` block that can be extended to build a masked self-attention block. Our block accepts the query, key, and value, as well as the mask, and returns the output as a JAX array. You can see that the code looks very similar to standard PyTorch or TensorFlow code. All we do is build the causal mask using `np.tril()`, which zeroes out all elements above the k-th diagonal (the main diagonal by default), multiply it with our mask, and pass everything into the `hk.MultiHeadAttention` module.
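```python
class SelfAttention(hk.MultiHeadAttention):
    """Self attention with a causal mask applied."""

    def __call__(
        self,
        query: jnp.ndarray,
        key: Optional[jnp.ndarray] = None,
        value: Optional[jnp.ndarray] = None,
        mask: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        # Without an explicit key/value, this is *self*-attention over the query.
        key = key if key is not None else query
        value = value if value is not None else query

        # Lower-triangular mask: position i may only attend to positions <= i.
        # Shape [1, 1, T, T] so it matches the rank of the attention logits [B, H, T, T].
        seq_len = query.shape[1]
        causal_mask = np.tril(np.ones((1, 1, seq_len, seq_len)))
        # Combine with the (broadcastable) padding mask, if one was given.
        mask = mask * causal_mask if mask is not None else causal_mask

        return super().__call__(query, key, value, mask)
```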
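To see what the mask arithmetic produces, here is a small NumPy sketch with a made-up batch (the token ids and the convention that id 0 means padding are illustrative assumptions). It reshapes a padding mask to `[B, 1, 1, T]`, as the Transformer below does, and multiplies it with a lower-triangular causal mask, giving a `[B, 1, T, T]` mask that broadcasts over the attention heads.

```python
import numpy as np

# Hypothetical batch: 2 sequences of length 4; token id 0 marks padding.
tokens = np.array([[5, 7, 2, 0],
                   [3, 1, 0, 0]])
B, T = tokens.shape

# Padding mask [B, T] -> [B, 1, 1, T]: broadcasts over heads and query positions.
padding_mask = (tokens > 0)[:, None, None, :]

# Causal mask [T, T]: query position i may attend to key positions j <= i.
causal_mask = np.tril(np.ones((T, T)))

# Element-wise product, the same combination SelfAttention performs.
mask = padding_mask * causal_mask
print(mask.shape)  # (2, 1, 4, 4)
print(mask[1, 0])  # rows = query positions, columns = key positions
```

Each row of `mask[1, 0]` is the causal row for that position with the padded key columns (the last two) switched off.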
This snippet allows me to introduce the first key principle of Haiku: all modules should be subclasses of `hk.Module`. In practice, this means that they implement `__init__` and `__call__`, alongside any other methods they need. In a sense, it’s the same structure as PyTorch modules, where we implement an `__init__` and a `forward`.
To make that crystal clear, let’s build a simple 2-layer Multilayer Perceptron as an `hk.Module`, which will conveniently be used in the Transformer below.
The linear layer
A simple 2-layer MLP looks like this. Once again, notice how familiar it looks.
```python
class DenseBlock(hk.Module):
    """A 2-layer MLP."""

    def __init__(self,
                 init_scale: float,
                 widening_factor: int = 4,
                 name: Optional[str] = None):
        super().__init__(name=name)
        self._init_scale = init_scale
        self._widening_factor = widening_factor

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        hiddens = x.shape[-1]  # feature (model) dimension
        initializer = hk.initializers.VarianceScaling(self._init_scale)
        # Expand to widening_factor * hiddens, apply GELU, project back to hiddens.
        x = hk.Linear(self._widening_factor * hiddens, w_init=initializer)(x)
        x = jax.nn.gelu(x)
        return hk.Linear(hiddens, w_init=initializer)(x)
```
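For intuition about the shapes, here is roughly the same computation written in plain JAX with explicit parameters instead of Haiku modules. This is a sketch: the parameter names and the plain normal initialization with variance `init_scale / fan_in` are assumptions that only approximate `hk.initializers.VarianceScaling`, which by default samples from a truncated normal.

```python
import jax
import jax.numpy as jnp

def init_dense_block(rng, hiddens, init_scale=1.0, widening_factor=4):
    """Hypothetical explicit-parameter version of DenseBlock's weights."""
    k1, k2 = jax.random.split(rng)
    wide = widening_factor * hiddens
    # Normal init with variance init_scale / fan_in (approximates fan-in variance scaling).
    w1 = jax.random.normal(k1, (hiddens, wide)) * jnp.sqrt(init_scale / hiddens)
    w2 = jax.random.normal(k2, (wide, hiddens)) * jnp.sqrt(init_scale / wide)
    # Biases start at zero, like hk.Linear's default.
    return {"w1": w1, "b1": jnp.zeros(wide), "w2": w2, "b2": jnp.zeros(hiddens)}

def dense_block(params, x):
    # [..., H] -> [..., 4H] -> GELU -> [..., H]
    h = jax.nn.gelu(x @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]

x = jnp.ones((2, 5, 8))  # [B, T, H]
params = init_dense_block(jax.random.PRNGKey(0), hiddens=8)
y = dense_block(params, x)
print(params["w1"].shape, y.shape)  # (8, 32) (2, 5, 8)
```

The output keeps the input's shape, which is what lets the Transformer add it back to the residual stream.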
A few things to notice here:

Haiku provides us with a set of weight initializers under `hk.initializers`, where we can find the most common approaches.

It also has many popular layers and modules built in, such as `hk.Linear`. For the complete list, take a peek at the official documentation.

Activation functions are not provided because JAX already has a subpackage called `jax.nn`, where we can find activation functions such as `relu` or `softmax`.
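A quick look at a few of those `jax.nn` functions on a tiny input (the values are arbitrary):

```python
import jax
import jax.numpy as jnp

x = jnp.array([-2.0, 0.0, 3.0])
print(jax.nn.relu(x))     # element-wise max(x, 0)
print(jax.nn.gelu(x))     # smooth ReLU-like curve, used in DenseBlock
print(jax.nn.softmax(x))  # non-negative values that sum to 1
```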
The normalization layer
Layer normalization is another integral block of the transformer architecture, which we can also find in the common modules inside Haiku.
```python
def layer_norm(x: jnp.ndarray, name: Optional[str] = None) -> jnp.ndarray:
    """Apply a unique LayerNorm to x with default settings."""
    # Normalize over the last (feature) axis, with a learned scale and offset.
    return hk.LayerNorm(axis=-1,
                        create_scale=True,
                        create_offset=True,
                        name=name)(x)
```
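What this layer computes can be sketched in plain JAX: normalize each position's feature vector over the last axis, then apply a learned scale and offset (which start at ones and zeros). The `eps=1e-5` below is an assumption meant to mirror Haiku's default; this is an illustration, not Haiku's implementation.

```python
import jax.numpy as jnp

def layer_norm_sketch(x, scale, offset, eps=1e-5):
    # Statistics are computed per position, over the feature axis (axis=-1).
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    normed = (x - mean) / jnp.sqrt(var + eps)
    # Learned per-feature parameters.
    return normed * scale + offset

x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [10.0, 10.0, 10.0, 14.0]])
y = layer_norm_sketch(x, scale=jnp.ones(4), offset=jnp.zeros(4))
print(y.mean(axis=-1))  # close to 0 for every row
print(y.std(axis=-1))   # close to 1 for every row
```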
The transformer
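And now for the good stuff. Below you can find a very simple Transformer, which makes use of our predefined modules. Inside `__init__`, we define the basic variables such as the number of layers, attention heads, and the dropout rate. Inside `__call__`, we compose a list of blocks using a `for` loop.
As you can see, each block includes:

A normalization layer followed by a self-attention block

A dropout layer and a skip connection (h = h + h_attn)

A second normalization layer followed by a dense block

A second dropout layer and skip connection (h = h + h_dense)

In the end, we also add a final normalization layer.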
```python
class Transformer(hk.Module):
    """A transformer stack."""

    def __init__(self,
                 num_heads: int,
                 num_layers: int,
                 dropout_rate: float,
                 name: Optional[str] = None):
        super().__init__(name=name)
        self._num_layers = num_layers
        self._num_heads = num_heads
        self._dropout_rate = dropout_rate

    def __call__(self,
                 h: jnp.ndarray,
                 mask: Optional[jnp.ndarray],
                 is_training: bool) -> jnp.ndarray:
        """Connects the transformer.

        Args:
          h: Inputs, [B, T, H].
          mask: Padding mask, [B, T].
          is_training: Whether we're training or not.

        Returns:
          Array of shape [B, T, H].
        """
        init_scale = 2. / self._num_layers
        # Dropout is disabled at evaluation time.
        dropout_rate = self._dropout_rate if is_training else 0.
        if mask is not None:
            # [B, T] -> [B, 1, 1, T] so it broadcasts over heads and query positions.
            mask = mask[:, None, None, :]

        for i in range(self._num_layers):
            # Attention sub-block: pre-LayerNorm, causal self-attention, dropout, residual.
            h_norm = layer_norm(h, name=f'h{i}_ln_1')
            h_attn = SelfAttention(
                num_heads=self._num_heads,
                key_size=64,
                model_size=h.shape[-1],  # project back to H so the residual sum works
                w_init=hk.initializers.VarianceScaling(init_scale),
                name=f'h{i}_attn')(h_norm, mask=mask)
            h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
            h = h + h_attn

            # MLP sub-block: pre-LayerNorm, DenseBlock, dropout, residual.
            h_norm = layer_norm(h, name=f'h{i}_ln_2')
            h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm)
            h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
            h = h + h_dense

        h = layer_norm(h, name='ln_f')
        return h
```
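The block sets the dropout rate to 0 outside of training. Assuming `hk.dropout` follows the standard "inverted dropout" scheme (keep each element with probability `1 - rate` and rescale survivors by `1 / (1 - rate)`), a rate of 0 makes it the identity. Here is a hand-written sketch of that scheme in plain JAX; `inverted_dropout` is a hypothetical helper, not Haiku's function.

```python
import jax
import jax.numpy as jnp

def inverted_dropout(rng, rate, x):
    """Zero each element with probability `rate` and rescale the survivors."""
    keep_rate = 1.0 - rate
    keep = jax.random.bernoulli(rng, keep_rate, shape=x.shape)
    # Rescaling keeps the expected value of every element unchanged.
    return jnp.where(keep, x / keep_rate, 0.0)

rng = jax.random.PRNGKey(0)
x = jnp.ones((1000,))
train_out = inverted_dropout(rng, 0.1, x)  # training: about 10% zeros
eval_out = inverted_dropout(rng, 0.0, x)   # evaluation: identity
print(float(train_out.mean()))             # roughly 1.0
```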
I think that by now you have realized that building a Neural Network with JAX is dead simple.
The embeddings layer
For completeness, let’s also include the embeddings layer. Haiku provides an embedding module, `hk.Embed`, which maps the token ids of our input sentence to token embeddings. These are then added to learned positional embeddings, which produces the final input.
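```python
def embeddings(data: Mapping[str, jnp.ndarray], vocab_size: int):
    # d_model (the embedding width) is not an argument here: like the other
    # hyperparameters used in main() below, it is assumed to be defined at module level.
    tokens = data['obs']
    input_mask = jnp.greater(tokens, 0)  # token id 0 is treated as padding
    seq_length = tokens.shape[1]
    embed_init = hk.initializers.TruncatedNormal(stddev=0.02)
    # Token embeddings: a learned [vocab_size, d_model] lookup table.
    token_embedding_map = hk.Embed(vocab_size, d_model, w_init=embed_init)
    token_embs = token_embedding_map(tokens)
    # Learned positional embeddings [seq_length, d_model], broadcast over the batch.
    positional_embeddings = hk.get_parameter(
        'pos_embs', [seq_length, d_model], init=embed_init)
    input_embeddings = token_embs + positional_embeddings
    return input_embeddings, input_mask
```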
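Under the hood, a token embedding is just a row lookup in a `[vocab_size, d_model]` table, and the positional table `[seq_length, d_model]` broadcasts over the batch. Here is a plain-JAX sketch with made-up sizes and random tables (not Haiku's code), to make the shapes explicit:

```python
import jax
import jax.numpy as jnp

vocab_size, d_model, seq_length = 10, 4, 3  # illustrative sizes
k_tok, k_pos = jax.random.split(jax.random.PRNGKey(0))
# Small random tables standing in for the learned parameters.
token_table = 0.02 * jax.random.normal(k_tok, (vocab_size, d_model))
pos_table = 0.02 * jax.random.normal(k_pos, (seq_length, d_model))

tokens = jnp.array([[4, 2, 0],
                    [7, 0, 0]])            # [B, T], 0 = padding
token_embs = token_table[tokens]           # gather rows -> [B, T, d_model]
input_embeddings = token_embs + pos_table  # [T, d_model] broadcasts over B
input_mask = jnp.greater(tokens, 0)        # True for real tokens
print(input_embeddings.shape)              # (2, 3, 4)
```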
`hk.get_parameter(param_name, ...)` is used to access (and, on first use, create) the trainable parameters of a module. But you may ask: why not just use object attributes, as we do in PyTorch? This is where the second key principle of Haiku comes into play. We use this API so that we can convert the code into a pure function using `hk.transform`. This is not very simple to grasp, but I will try to make it as clear as possible.
Why pure functions?
The power of JAX lies in its function transformations: vectorizing a function with `vmap`, automatic parallelization across devices with `pmap`, and just-in-time compilation with `jit`. The caveat is that, in order to transform a function, it needs to be pure.
A pure function is a function that has the following properties:

The function return values are identical for identical arguments (no variation with local static variables, nonlocal variables, mutable reference arguments, or input streams).

The function application has no side effects (no mutation of local static variables, nonlocal variables, mutable reference arguments, or input/output streams).
Source: Scala pure functions by O’Reilly
This practically means that a pure function will always:

return the same result if invoked with the same inputs

receive all of its input data through its arguments and output all of its results through its return values
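Why does purity matter in practice? `jax.jit` traces a function once per input shape and dtype and then reuses the compiled version, so side effects and global state are only seen during tracing. The small demonstration below (the names are arbitrary) shows both effects:

```python
import jax
import jax.numpy as jnp

trace_log = []
scale = 2.0

def impure(x):
    trace_log.append("traced")  # side effect: runs only while JAX traces the function
    return x * scale            # reads a global: baked in as a constant at trace time

jitted = jax.jit(impure)
first = jitted(jnp.array(1.0))
scale = 3.0                      # mutating the global...
second = jitted(jnp.array(1.0))  # ...is ignored: the cached trace is reused
print(first, second, len(trace_log))
```

A pure version would receive `scale` as an argument, so every call sees the value it is given.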
Haiku provides a function transformation, called `hk.transform`, that turns functions with object-oriented, functionally “impure” modules into pure functions that can be used with JAX. To see that in practice, let’s continue with the training of our Transformer model.
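`hk.transform` returns an object holding two pure functions: `init(rng, *args)`, which returns the parameters, and `apply(params, rng, *args)`, which returns the output. The sketch below hand-writes such a pair for a single linear layer in plain JAX, only to show the contract; it is not how Haiku implements `hk.transform`, and the parameter names are made up.

```python
import jax
import jax.numpy as jnp

def init(rng, x):
    """Create parameters from an RNG key and an example input."""
    w = 0.02 * jax.random.normal(rng, (x.shape[-1], 2))
    return {"linear": {"w": w, "b": jnp.zeros(2)}}

def apply(params, rng, x):
    """Pure forward pass: everything it needs arrives through its arguments."""
    del rng  # unused here; Haiku passes an rng to apply for things like dropout
    return x @ params["linear"]["w"] + params["linear"]["b"]

x = jnp.ones((3, 5))
params = init(jax.random.PRNGKey(42), x)
# Because apply is pure, JAX transformations compose freely with it.
out = jax.jit(apply)(params, None, x)
grads = jax.grad(lambda p: jnp.sum(apply(p, None, x)))(params)
print(out.shape, grads["linear"]["w"].shape)  # (3, 2) (5, 2)
```

Parameters live in a plain nested dictionary that is passed in explicitly, which is exactly what lets `jit` and `grad` treat them like any other input.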
The forward pass
A typical forward pass includes:

Take the input and compute the input embeddings

Run them through the Transformer’s blocks

Return the output
The aforementioned steps can be easily composed with Haiku as follows:
```python
def build_forward_fn(vocab_size: int, d_model: int, num_heads: int,
                     num_layers: int, dropout_rate: float):
    """Create the model's forward pass."""

    def forward_fn(data: Mapping[str, jnp.ndarray],
                   is_training: bool = True) -> jnp.ndarray:
        """Forward pass."""
        input_embeddings, input_mask = embeddings(data, vocab_size)
        transformer = Transformer(
            num_heads=num_heads, num_layers=num_layers, dropout_rate=dropout_rate)
        output_embeddings = transformer(input_embeddings, input_mask, is_training)
        # Project every position back to vocabulary logits.
        return hk.Linear(vocab_size)(output_embeddings)

    return forward_fn
```
Although the code is straightforward, its structure might seem a bit odd. The actual forward pass is executed by the `forward_fn` function. However, we wrap it in the `build_forward_fn` function, which returns `forward_fn`. What the heck?
Down the road, we will need to transform the `forward_fn` function into a pure function using `hk.transform` so that we can take advantage of automatic differentiation, parallelization, etc.
This will be accomplished by:
```python
forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
                              num_layers, dropout_rate)
forward_fn = hk.transform(forward_fn)
```
That’s why, instead of simply defining a function, we wrap it and return the function itself, or a callable to be more precise. This callable can then be passed to `hk.transform`, which turns it into a pair of pure functions, `forward_fn.init` and `forward_fn.apply`. If this is clear, let’s continue with our loss function.
The loss function
The loss function is our well-known cross-entropy, with the difference that we also take the mask into consideration, so that padded positions don’t contribute to the loss. Once again, JAX provides `one_hot` and `log_softmax` in `jax.nn`.
```python
def lm_loss_fn(forward_fn,
               vocab_size: int,
               params,
               rng,
               data: Mapping[str, jnp.ndarray],
               is_training: bool = True) -> jnp.ndarray:
    """Compute the loss on data wrt params."""
    logits = forward_fn(params, rng, data, is_training)
    targets = jax.nn.one_hot(data['target'], vocab_size)
    assert logits.shape == targets.shape

    mask = jnp.greater(data['obs'], 0)
    # Negative log-likelihood of the target token at every position.
    loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
    # Average over the non-padded positions only.
    loss = jnp.sum(loss * mask) / jnp.sum(mask)
    return loss
```
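To sanity-check the masking logic, here is the same computation applied to hand-made logits instead of a model's output (the batch below is invented). With all-zero logits the prediction is uniform, so every unmasked position contributes `log(vocab_size)`, and padded positions are ignored.

```python
import jax
import jax.numpy as jnp

vocab_size = 5
obs = jnp.array([[3, 1, 0],
                 [2, 4, 1]])            # 0 = padding
target = jnp.array([[1, 2, 0],
                    [4, 1, 3]])
logits = jnp.zeros((2, 3, vocab_size))  # uniform predictions

targets = jax.nn.one_hot(target, vocab_size)
mask = jnp.greater(obs, 0)
# Per-position negative log-likelihood of the target token.
loss = -jnp.sum(targets * jax.nn.log_softmax(logits), axis=-1)
# Average only over non-padded positions.
loss = jnp.sum(loss * mask) / jnp.sum(mask)
print(float(loss), float(jnp.log(vocab_size)))  # both ≈ 1.609
```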
If you are still with me, take a sip of coffee because things are going to get serious from now on. It’s time to build our training loop.
The training loop
Because neither JAX nor Haiku offers a full-featured optimization library built in, we will make use of another library, called Optax. As mentioned at the beginning, Optax is the go-to package for gradient processing.
First, here are some things you need to know about Optax:
The key abstraction of Optax is the `GradientTransformation`. A transformation is defined by two functions, `init` and `update`. `init` initializes the state, and `update` transforms the gradients with respect to the state and (optionally) the current value of the parameters:
```python
state = init(params)
updates, state = update(grads, state, params=None)
```
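To make the `init`/`update` contract concrete, here is a toy gradient transformation written from scratch with the same two-function shape: plain SGD with a step counter as its state. It only mimics the interface; `ToyTransformation` and `sgd_sketch` are hypothetical names, not part of Optax.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp

class ToyTransformation(NamedTuple):
    init: Callable    # params -> state
    update: Callable  # (grads, state, params=None) -> (updates, new_state)

def sgd_sketch(learning_rate: float) -> ToyTransformation:
    def init(params):
        # This toy state only counts steps; Adam would also keep moment estimates.
        return {"count": jnp.zeros([], jnp.int32)}

    def update(grads, state, params=None):
        del params  # plain SGD does not need the current parameters
        updates = jax.tree_util.tree_map(lambda g: -learning_rate * g, grads)
        return updates, {"count": state["count"] + 1}

    return ToyTransformation(init, update)

params = {"w": jnp.array([1.0, 2.0])}
grads = {"w": jnp.array([0.5, -0.5])}
opt = sgd_sketch(0.1)
state = opt.init(params)
updates, state = opt.update(grads, state, params)
# Applying the updates is an element-wise sum over the parameter tree.
new_params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates)
print(new_params["w"])
```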
One more thing to know before we see the code is Python’s built-in `functools.partial` function. The `functools` module deals with higher-order functions and operations on callable objects.
A function is called a higher-order function if it takes other functions as parameters or returns a function as its output.
`partial`, which can also be used inside a decorator, returns a new function based on an original one, but with some of its arguments fixed. For example, if `f` multiplies two values `x` and `y`, `partial(f, 2)` creates a new function where `x` is fixed and equal to 2:
```python
from functools import partial

def f(x, y):
    return x * y

g = partial(f, 2)  # x is fixed to 2
print(g(4))        # 8
```
After this short detour, let’s proceed. To declutter our `main` function, we will extract the gradient update into its own class.
First of all, the `GradientUpdater` accepts the model, the loss function, and an optimizer.

The model will be the pure `forward_fn` function transformed by `hk.transform`:
```python
forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
                              num_layers, dropout_rate)
forward_fn = hk.transform(forward_fn)
```

The loss function will be the result of a `partial` with fixed `forward_fn.apply` and `vocab_size`:
```python
loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)
```

The optimizer is a set of optimization transformations that will run sequentially (operations can be combined using `optax.chain`):
```python
optimizer = optax.chain(
    optax.clip_by_global_norm(grad_clip_value),
    optax.adam(learning_rate, b1=0.9, b2=0.99))
```
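`optax.chain` feeds the output of each transformation into the next, so here the gradients are first clipped and then handed to Adam. The clipping step can be sketched in a few lines of JAX: compute a single norm over all gradient leaves and, if it exceeds the threshold, rescale every leaf by the same factor. This is a hand-written illustration of the idea (`clip_by_global_norm_sketch` is a hypothetical name), not Optax's source.

```python
import jax
import jax.numpy as jnp

def clip_by_global_norm_sketch(grads, max_norm):
    leaves = jax.tree_util.tree_leaves(grads)
    # One L2 norm over *all* parameters, not one norm per tensor.
    global_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g)) for g in leaves))
    # Shrink only when the norm exceeds max_norm; otherwise leave gradients untouched.
    factor = jnp.minimum(1.0, max_norm / global_norm)
    return jax.tree_util.tree_map(lambda g: g * factor, grads)

grads = {"w": jnp.array([3.0, 0.0]), "b": jnp.array([4.0])}  # global norm = 5
clipped = clip_by_global_norm_sketch(grads, max_norm=1.0)
print(clipped)
```

Because every leaf is scaled by the same factor, the direction of the overall gradient is preserved.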
The `GradientUpdater` will be initialized as follows:
```python
updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)
```
and will look like this:
```python
class GradientUpdater:
    """A stateless abstraction around an init_fn/update_fn pair.

    This extracts some common boilerplate from the training loop.
    """

    def __init__(self, net_init, loss_fn,
                 optimizer: optax.GradientTransformation):
        self._net_init = net_init
        self._loss_fn = loss_fn
        self._opt = optimizer

    # static_argnums=0 marks `self` as static, so jit only traces the array inputs.
    @functools.partial(jax.jit, static_argnums=0)
    def init(self, master_rng, data):
        """Initializes state of the updater."""
        out_rng, init_rng = jax.random.split(master_rng)
        params = self._net_init(init_rng, data)
        opt_state = self._opt.init(params)
        out = dict(
            step=np.array(0),
            rng=out_rng,
            opt_state=opt_state,
            params=params,
        )
        return out

    @functools.partial(jax.jit, static_argnums=0)
    def update(self, state: Mapping[str, Any], data: Mapping[str, jnp.ndarray]):
        """Updates the state using some data and returns metrics."""
        # Fresh randomness (e.g. for dropout) at every step.
        rng, new_rng = jax.random.split(state['rng'])
        params = state['params']
        loss, g = jax.value_and_grad(self._loss_fn)(params, rng, data)

        # Transform the raw gradients (clipping + Adam), then apply them.
        updates, opt_state = self._opt.update(g, state['opt_state'])
        params = optax.apply_updates(params, updates)

        new_state = {
            'step': state['step'] + 1,
            'rng': new_rng,
            'opt_state': opt_state,
            'params': params,
        }
        metrics = {
            'step': state['step'],
            'loss': loss,
        }
        return new_state, metrics
```
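Inside `init`, we initialize the model parameters with `self._net_init` and the optimizer with `self._opt.init(params)`, and we declare the state of the optimization. The state is a dictionary with:

step: the current training step

rng: the PRNG key from which each update derives fresh randomness

opt_state: the optimizer state

params: the trainable parameters

The `update` function updates both the state of the optimizer and the trainable parameters. In the end, it returns the new state along with some metrics (the step and the loss).
```python
updates, opt_state = self._opt.update(g, state['opt_state'])
params = optax.apply_updates(params, updates)
```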
Two more things to notice here:

`jax.value_and_grad(f)` returns a new function that evaluates `f` and also returns its gradient with respect to the first argument (here, `params`).

Both `init` and `update` are decorated with `@functools.partial(jax.jit, static_argnums=0)`, which triggers the just-in-time compiler and compiles them with XLA the first time they are called. Note that if we hadn’t transformed `forward_fn` into a pure function, this wouldn’t be possible.
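A minimal example of `jax.value_and_grad`: it wraps a scalar-valued function and returns a new function that computes both the value and the gradient with respect to the first argument in a single call.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

value, grad = jax.value_and_grad(f)(jnp.array([1.0, 2.0, 3.0]))
print(value, grad)  # 14.0 [2. 4. 6.]
```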
Finally, we are ready to build the entire training loop, which combines all the ideas and code mentioned so far.
```python
import logging

# batch_size, sequence_length, d_model, num_heads, num_layers, dropout_rate,
# grad_clip_value, learning_rate, MAX_STEPS and the `load` dataset helper are
# assumed to be defined elsewhere (see the full code); they are not shown here.

def main():
    train_dataset, vocab_size = load(batch_size, sequence_length)

    forward_fn = build_forward_fn(vocab_size, d_model, num_heads,
                                  num_layers, dropout_rate)
    forward_fn = hk.transform(forward_fn)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

    optimizer = optax.chain(
        optax.clip_by_global_norm(grad_clip_value),
        optax.adam(learning_rate, b1=0.9, b2=0.99))

    updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)

    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(428)
    data = next(train_dataset)
    state = updater.init(rng, data)

    logging.info('Starting train loop...')
    for step in range(MAX_STEPS):
        data = next(train_dataset)
        state, metrics = updater.update(state, data)
```
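`main()` relies on a `load(batch_size, sequence_length)` helper that is not shown in this article. For a quick smoke test, a stand-in only needs to yield dictionaries with `'obs'` and `'target'` arrays of shape `[batch_size, sequence_length]`, where the target is the input shifted by one token. The random-token generator below, `load_synthetic`, is a hypothetical replacement for illustration, not the article's dataset.

```python
from typing import Dict, Iterator, Tuple

import numpy as np

def load_synthetic(batch_size: int, sequence_length: int, vocab_size: int = 32,
                   seed: int = 0) -> Tuple[Iterator[Dict[str, np.ndarray]], int]:
    """Hypothetical replacement for `load`: random tokens, next-token targets."""
    rng = np.random.default_rng(seed)

    def batches():
        while True:
            # Ids start at 1 because 0 is treated as padding by the embeddings and loss.
            seq = rng.integers(1, vocab_size, size=(batch_size, sequence_length + 1))
            yield {"obs": seq[:, :-1], "target": seq[:, 1:]}

    return batches(), vocab_size

train_dataset, vocab_size = load_synthetic(batch_size=4, sequence_length=8)
batch = next(train_dataset)
print(batch["obs"].shape, batch["target"].shape)  # (4, 8) (4, 8)
```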
Notice how we incorporate the `GradientUpdater`. It’s just two lines of code:
```python
state = updater.init(rng, data)
state, metrics = updater.update(state, data)
```
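To see this two-line pattern run end to end without Haiku, Optax, or a real dataset, here is a hypothetical `ToyUpdater` that keeps the structure of `GradientUpdater` (a state dictionary and `jit`-compiled `init`/`update` methods with `self` marked static) but swaps in a linear-regression loss and plain SGD. It illustrates the pattern only; it is not part of the article's code.

```python
import functools

import jax
import jax.numpy as jnp

class ToyUpdater:
    """Hypothetical stand-in for GradientUpdater: linear regression + plain SGD."""

    def __init__(self, learning_rate: float):
        self._lr = learning_rate

    @staticmethod
    def _loss_fn(params, x, y):
        pred = x @ params["w"] + params["b"]
        return jnp.mean((pred - y) ** 2)

    @functools.partial(jax.jit, static_argnums=0)
    def init(self, rng, x):
        params = {"w": jax.random.normal(rng, (x.shape[-1],)), "b": jnp.zeros(())}
        return {"step": jnp.array(0), "params": params}

    @functools.partial(jax.jit, static_argnums=0)
    def update(self, state, x, y):
        loss, grads = jax.value_and_grad(self._loss_fn)(state["params"], x, y)
        params = jax.tree_util.tree_map(lambda p, g: p - self._lr * g,
                                        state["params"], grads)
        new_state = {"step": state["step"] + 1, "params": params}
        return new_state, {"step": state["step"], "loss": loss}

x = jax.random.normal(jax.random.PRNGKey(1), (64, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3  # noiseless linear target
updater = ToyUpdater(learning_rate=0.1)
state = updater.init(jax.random.PRNGKey(0), x)
losses = []
for _ in range(100):
    state, metrics = updater.update(state, x, y)
    losses.append(float(metrics["loss"]))
```

As in the article's loop, `init` runs once and `update` is called repeatedly; both are compiled on their first call and reused afterwards.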
And that’s it. I hope that by now you have a clearer understanding of JAX and its capabilities.
Acknowledgments
The code presented is heavily inspired by the official examples of the Haiku framework. It has been modified to fit the needs of this article. For the complete list of examples, check the official repository.
Conclusion
In this article, we saw how to develop and train a vanilla Transformer in JAX using Haiku. Although the code isn’t necessarily hard to grasp, it still lacks the readability of PyTorch or TensorFlow. I highly recommend playing around with it, discovering the strengths and weaknesses of JAX, and seeing whether it would be a good fit for your next project. In my experience, JAX is very strong for research applications that require high performance but quite immature for real-life projects. Let us know what you think in our Discord channel.