The Transformer Architecture
We have previously compared CNNs, RNNs, and self-attention.
Notably, self-attention
enjoys both parallel computation and
the shortest maximum path length.
Therefore, it is appealing to design deep architectures
by using self-attention.
Unlike earlier self-attention models
that still rely on RNNs for input representations :cite:Cheng.Dong.Lapata.2016,Lin.Feng.Santos.ea.2017,Paulus.Xiong.Socher.2017,
the Transformer model
is solely based on attention mechanisms
without any convolutional or recurrent layer :cite:Vaswani.Shazeer.Parmar.ea.2017.
Though originally proposed
for sequence-to-sequence learning on text data,
Transformers have become
pervasive in a wide range of
modern deep learning applications,
such as language, vision, speech, and reinforcement learning.
%%tab mxnet
from d2l import mxnet as d2l
import math
from mxnet import autograd, init, np, npx
from mxnet.gluon import nn
import pandas as pd
%%tab pytorch
from d2l import torch as d2l
import math
import pandas as pd
import torch
from torch import nn
%%tab tensorflow
from d2l import tensorflow as d2l
import numpy as np
import pandas as pd
import tensorflow as tf
%%tab jax
from d2l import jax as d2l
from flax import linen as nn
from jax import numpy as jnp
import jax
import math
import pandas as pd
As an instance of the encoder--decoder architecture,
the overall architecture of
the Transformer
is presented in :numref:fig_transformer.
As we can see,
the Transformer is composed of an encoder and a decoder.
In contrast to
Bahdanau attention
for sequence-to-sequence learning
in :numref:fig_s2s_attention_details,
the input (source) and output (target)
sequence embeddings
are added with positional encoding
before being fed into
the encoder and the decoder
that stack modules based on self-attention.
Now we provide an overview of the
Transformer architecture in :numref:fig_transformer.
At a high level,
the Transformer encoder is a stack of multiple identical layers,
where each layer
has two sublayers (either is denoted as \(\textrm{sublayer}\)).
The first
is a multi-head self-attention pooling
and the second is a positionwise feed-forward network.
Specifically, in the encoder self-attention,
queries, keys, and values are all from the
outputs of the previous encoder layer.
Inspired by the ResNet design of :numref:sec_resnet,
a residual connection is employed
around both sublayers.
In the Transformer,
for any input \(\mathbf{x} \in \mathbb{R}^d\) at any position of the sequence,
we require that \(\textrm{sublayer}(\mathbf{x}) \in \mathbb{R}^d\) so that
the residual connection \(\mathbf{x} + \textrm{sublayer}(\mathbf{x}) \in \mathbb{R}^d\) is feasible.
This addition from the residual connection is immediately
followed by layer normalization :cite:Ba.Kiros.Hinton.2016.
As a result, the Transformer encoder outputs a \(d\)-dimensional vector representation
for each position of the input sequence.
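The "add & norm" step described above can be sketched in plain Python for a single position. This is only an illustration under simplifying assumptions: `toy_sublayer` is a hypothetical stand-in for attention or the feed-forward network, and the learnable scale and shift of layer normalization are omitted.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize one d-dimensional vector to zero mean and unit variance."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def add_norm(x, sublayer):
    """Residual connection x + sublayer(x), followed by layer normalization."""
    y = sublayer(x)
    # The residual addition is only feasible if the sublayer preserves d
    assert len(y) == len(x), "sublayer must map R^d to R^d"
    return layer_norm([a + b for a, b in zip(x, y)])


def toy_sublayer(x):
    """Hypothetical sublayer for illustration only: halves every feature."""
    return [0.5 * v for v in x]


x = [1.0, 2.0, 3.0, 4.0]
out = add_norm(x, toy_sublayer)
print(len(out), [round(v, 3) for v in out])
```

The output has the same dimension as the input, which is what allows identical layers to be stacked.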
The Transformer decoder is also a stack of multiple identical layers with residual connections and layer normalizations. As well as the two sublayers described in the encoder, the decoder inserts a third sublayer, known as the encoder--decoder attention, between these two. In the encoder--decoder attention, queries are from the outputs of the decoder's self-attention sublayer, and the keys and values are from the Transformer encoder outputs. In the decoder self-attention, queries, keys, and values are all from the outputs of the previous decoder layer. However, each position in the decoder is allowed only to attend to all positions in the decoder up to that position. This masked attention preserves the autoregressive property, ensuring that the prediction only depends on those output tokens that have been generated.
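To make the masking concrete, here is a minimal sketch of how restricting each query to earlier positions shapes the attention weights. It assumes toy scores of zero and sets masked weights to exactly zero; the helper name `masked_softmax` and its simplified signature are chosen for this illustration.

```python
import math


def masked_softmax(scores, valid_len):
    """Softmax over the first valid_len scores; later positions get weight 0."""
    exps = [math.exp(s) if j < valid_len else 0.0
            for j, s in enumerate(scores)]
    total = sum(exps)
    return [e / total for e in exps]


num_steps = 4
# Toy attention scores: all zeros, so allowed positions share weight equally
scores = [[0.0] * num_steps for _ in range(num_steps)]
# The query at (0-based) position i may attend to positions 0, ..., i only
weights = [masked_softmax(row, valid_len=i + 1)
           for i, row in enumerate(scores)]
for row in weights:
    print([round(w, 2) for w in row])
```

The printed matrix is lower triangular: no query places weight on a position that comes after it, so the prediction at each step depends only on tokens that have already been generated.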
We have already described and implemented
multi-head attention based on scaled dot products
in :numref:sec_multihead-attention
and positional encoding in :numref:subsec_positional-encoding.
In the following, we will implement
the rest of the Transformer model.
Positionwise Feed-Forward Networks
The positionwise feed-forward network transforms
the representation at all the sequence positions
using the same MLP.
This is why we call it positionwise.
In the implementation below,
the input X
with shape
(batch size, number of time steps or sequence length in tokens,
number of hidden units or feature dimension)
will be transformed by a two-layer MLP into
an output tensor of shape
(batch size, number of time steps, ffn_num_outputs).
%%tab mxnet
class PositionWiseFFN(nn.Block):  #@save
    """The positionwise feed-forward network."""
    def __init__(self, ffn_num_hiddens, ffn_num_outputs):
        super().__init__()
        # ReLU is applied between the two dense layers
        self.dense1 = nn.Dense(ffn_num_hiddens, flatten=False,
                               activation='relu')
        self.dense2 = nn.Dense(ffn_num_outputs, flatten=False)

    def forward(self, X):
        return self.dense2(self.dense1(X))
%%tab pytorch
class PositionWiseFFN(nn.Module):  #@save
    """The positionwise feed-forward network."""
    def __init__(self, ffn_num_hiddens, ffn_num_outputs):
        super().__init__()
        self.dense1 = nn.LazyLinear(ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.LazyLinear(ffn_num_outputs)

    def forward(self, X):
        return self.dense2(self.relu(self.dense1(X)))
%%tab tensorflow
class PositionWiseFFN(tf.keras.layers.Layer):  #@save
    """The positionwise feed-forward network."""
    def __init__(self, ffn_num_hiddens, ffn_num_outputs):
        super().__init__()
        self.dense1 = tf.keras.layers.Dense(ffn_num_hiddens)
        self.relu = tf.keras.layers.ReLU()
        self.dense2 = tf.keras.layers.Dense(ffn_num_outputs)

    def call(self, X):
        return self.dense2(self.relu(self.dense1(X)))
%%tab jax
class PositionWiseFFN(nn.Module):  #@save
    """The positionwise feed-forward network."""
    ffn_num_hiddens: int
    ffn_num_outputs: int

    def setup(self):
        self.dense1 = nn.Dense(self.ffn_num_hiddens)
        self.dense2 = nn.Dense(self.ffn_num_outputs)

    def __call__(self, X):
        return self.dense2(nn.relu(self.dense1(X)))
The following example shows that the innermost dimension of a tensor changes to the number of outputs in the positionwise feed-forward network. Since the same MLP transforms at all the positions, when the inputs at all these positions are the same, their outputs are also identical.
%%tab mxnet
ffn = PositionWiseFFN(4, 8)
ffn.initialize()
ffn(np.ones((2, 3, 4)))[0]
%%tab pytorch
ffn = PositionWiseFFN(4, 8)
ffn(d2l.ones((2, 3, 4)))[0]
%%tab tensorflow
ffn = PositionWiseFFN(4, 8)
ffn(tf.ones((2, 3, 4)))[0]
%%tab jax
ffn = PositionWiseFFN(4, 8)
ffn.init_with_output(d2l.get_key(), jnp.ones((2, 3, 4)))[0][0]
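The same behavior can be checked without any deep learning framework. The sketch below applies one two-layer MLP independently to every position of a single sequence; the weights `W1`, `b1`, `W2`, `b2` are hypothetical fixed values chosen for illustration, not learned parameters.

```python
def relu(v):
    return [max(0.0, a) for a in v]


def linear(v, W, b):
    """Compute W v + b for a single vector v."""
    return [sum(w * a for w, a in zip(row, v)) + bias
            for row, bias in zip(W, b)]


def positionwise_ffn(X, W1, b1, W2, b2):
    """Apply the same two-layer MLP independently to every position."""
    return [linear(relu(linear(x, W1, b1)), W2, b2) for x in X]


# Hypothetical weights: 2 input features -> 3 hidden units -> 4 outputs
W1 = [[1.0, -1.0], [0.5, 0.5], [-1.0, 2.0]]
b1 = [0.0, 0.0, 0.0]
W2 = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]
b2 = [0.0, 0.0, 0.0, 0.0]

X = [[1.0, 1.0]] * 3  # one sequence with three identical positions
Y = positionwise_ffn(X, W1, b1, W2, b2)
print(Y)
```

Only the innermost dimension changes (from 2 to 4), and identical inputs at different positions yield identical outputs because no information flows between positions in this sublayer.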
Residual Connection and Layer Normalization
Now let's focus on the "add & norm" component in :numref:fig_transformer.
As we described at the beginning of this section,
this is a residual connection immediately
followed by layer normalization.
Both are key to effective deep architectures.
In :numref:sec_batch_norm,
we explained how batch normalization
recenters and rescales across the examples within
a minibatch.
As discussed in :numref:subsec_layer-normalization-in-bn,
layer normalization is the same as batch normalization
except that the former
normalizes across the feature dimension,
thus enjoying benefits of scale independence and batch size independence.
Despite its pervasive applications
in computer vision,
batch normalization
is usually empirically
less effective than layer normalization
in natural language processing
tasks, where the inputs are often
variable-length sequences.
The following code snippet compares the normalization across different dimensions by layer normalization and batch normalization.
%%tab mxnet
ln = nn.LayerNorm()
ln.initialize()
bn = nn.BatchNorm()
bn.initialize()
X = d2l.tensor([[1, 2], [2, 3]])
# Compute mean and variance from X in the training mode
with autograd.record():
    print('layer norm:', ln(X), '\nbatch norm:', bn(X))
%%tab pytorch
ln = nn.LayerNorm(2)
bn = nn.LazyBatchNorm1d()
X = d2l.tensor([[1, 2], [2, 3]], dtype=torch.float32)
# Compute mean and variance from X in the training mode
print('layer norm:', ln(X), '\nbatch norm:', bn(X))
%%tab tensorflow
ln = tf.keras.layers.LayerNormalization()
bn = tf.keras.layers.BatchNormalization()
X = tf.constant([[1, 2], [2, 3]], dtype=tf.float32)
print('layer norm:', ln(X), '\nbatch norm:', bn(X, training=True))
%%tab jax
ln = nn.LayerNorm()
bn = nn.BatchNorm()
X = d2l.tensor([[1, 2], [2, 3]], dtype=d2l.float32)
# Compute mean and variance from X in the training mode
print('layer norm:', ln.init_with_output(d2l.get_key(), X)[0],
      '\nbatch norm:', bn.init_with_output(d2l.get_key(), X,
                                           use_running_average=False)[0])
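The difference between the two normalizations on this tiny input can also be traced by hand. The following framework-free sketch assumes training-mode batch statistics and omits the learnable scale and shift parameters; the helper `normalize` is introduced here for illustration.

```python
import math


def normalize(values, eps=1e-5):
    """Shift and scale a list of numbers to zero mean and unit variance."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


X = [[1.0, 2.0], [2.0, 3.0]]  # 2 examples (rows), 2 features (columns)

# Layer normalization: statistics over the features of each example (row)
layer_norm = [normalize(row) for row in X]

# Batch normalization: statistics over the examples for each feature (column)
columns = [normalize(list(col)) for col in zip(*X)]
batch_norm = [list(row) for row in zip(*columns)]

print('layer norm:', [[round(v, 3) for v in row] for row in layer_norm])
print('batch norm:', [[round(v, 3) for v in row] for row in batch_norm])
```

Layer normalization maps each row to approximately (-1, 1) regardless of the other example, whereas batch normalization maps each column to approximately (-1, 1), so its result depends on which examples share the minibatch.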
Now we can implement the AddNorm class
using a residual connection followed by layer normalization.
Dropout is also applied for regularization.
%%tab mxnet
class AddNorm(nn.Block):  #@save
    """The residual connection followed by layer normalization."""
    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm()

    def forward(self, X, Y):
        return self.ln(self.dropout(Y) + X)
%%tab pytorch
class AddNorm(nn.Module):  #@save
    """The residual connection followed by layer normalization."""
    def __init__(self, norm_shape, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(norm_shape)

    def forward(self, X, Y):
        return self.ln(self.dropout(Y) + X)
%%tab tensorflow
class AddNorm(tf.keras.layers.Layer):  #@save
    """The residual connection followed by layer normalization."""
    def __init__(self, norm_shape, dropout):
        super().__init__()
        self.dropout = tf.keras.layers.Dropout(dropout)
        self.ln = tf.keras.layers.LayerNormalization(norm_shape)

    def call(self, X, Y, **kwargs):
        return self.ln(self.dropout(Y, **kwargs) + X)
%%tab jax
class AddNorm(nn.Module):  #@save
    """The residual connection followed by layer normalization."""
    dropout: float

    # Submodules are created inline, so the method must be marked compact
    @nn.compact
    def __call__(self, X, Y, training=False):
        return nn.LayerNorm()(
            nn.Dropout(self.dropout)(Y, deterministic=not training) + X)
The residual connection requires that the two inputs are of the same shape so that the output tensor also has the same shape after the addition operation.
%%tab mxnet
add_norm = AddNorm(0.5)
add_norm.initialize()
shape = (2, 3, 4)
d2l.check_shape(add_norm(d2l.ones(shape), d2l.ones(shape)), shape)
%%tab pytorch
add_norm = AddNorm(4, 0.5)
shape = (2, 3, 4)
d2l.check_shape(add_norm(d2l.ones(shape), d2l.ones(shape)), shape)
%%tab tensorflow
# Normalized_shape is: [i for i in range(len(input.shape))][1:]
add_norm = AddNorm([1, 2], 0.5)
shape = (2, 3, 4)
d2l.check_shape(add_norm(tf.ones(shape), tf.ones(shape), training=False),
                shape)
%%tab jax
add_norm = AddNorm(0.5)
shape = (2, 3, 4)
output, _ = add_norm.init_with_output(d2l.get_key(), d2l.ones(shape),
                                      d2l.ones(shape))
d2l.check_shape(output, shape)
With all the essential components to assemble
the Transformer encoder,
let's start by
implementing a single layer within the encoder.
The following TransformerEncoderBlock class
contains two sublayers: multi-head self-attention and positionwise feed-forward networks,
where a residual connection followed by layer normalization is employed
around both sublayers.
%%tab mxnet
class TransformerEncoderBlock(nn.Block):  #@save
    """The Transformer encoder block."""
    def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout,
                 use_bias=False):
        super().__init__()
        self.attention = d2l.MultiHeadAttention(
            num_hiddens, num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(dropout)
        self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(dropout)

    def forward(self, X, valid_lens):
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))
%%tab pytorch
class TransformerEncoderBlock(nn.Module):  #@save
    """The Transformer encoder block."""
    def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout,
                 use_bias=False):
        super().__init__()
        self.attention = d2l.MultiHeadAttention(num_hiddens, num_heads,
                                                dropout, use_bias)
        self.addnorm1 = AddNorm(num_hiddens, dropout)
        self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(num_hiddens, dropout)

    def forward(self, X, valid_lens):
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))
%%tab tensorflow
class TransformerEncoderBlock(tf.keras.layers.Layer):  #@save
    """The Transformer encoder block."""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_hiddens, num_heads, dropout, bias=False):
        super().__init__()
        self.attention = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout,
            bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def call(self, X, valid_lens, **kwargs):
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens, **kwargs),
                          **kwargs)
        return self.addnorm2(Y, self.ffn(Y), **kwargs)
%%tab jax
class TransformerEncoderBlock(nn.Module):  #@save
    """The Transformer encoder block."""
    num_hiddens: int
    ffn_num_hiddens: int
    num_heads: int
    dropout: float
    use_bias: bool = False

    def setup(self):
        self.attention = d2l.MultiHeadAttention(self.num_hiddens, self.num_heads,
                                                self.dropout, self.use_bias)
        self.addnorm1 = AddNorm(self.dropout)
        self.ffn = PositionWiseFFN(self.ffn_num_hiddens, self.num_hiddens)
        self.addnorm2 = AddNorm(self.dropout)

    def __call__(self, X, valid_lens, training=False):
        output, attention_weights = self.attention(X, X, X, valid_lens,
                                                   training=training)
        Y = self.addnorm1(X, output, training=training)
        return self.addnorm2(Y, self.ffn(Y), training=training), attention_weights
As we can see, no layer in the Transformer encoder changes the shape of its input.
%%tab mxnet
X = d2l.ones((2, 100, 24))
valid_lens = d2l.tensor([3, 2])
encoder_blk = TransformerEncoderBlock(24, 48, 8, 0.5)
encoder_blk.initialize()
d2l.check_shape(encoder_blk(X, valid_lens), X.shape)
%%tab pytorch
X = d2l.ones((2, 100, 24))
valid_lens = d2l.tensor([3, 2])
encoder_blk = TransformerEncoderBlock(24, 48, 8, 0.5)
d2l.check_shape(encoder_blk(X, valid_lens), X.shape)
%%tab tensorflow
X = tf.ones((2, 100, 24))
valid_lens = tf.constant([3, 2])
norm_shape = [i for i in range(len(X.shape))][1:]
encoder_blk = TransformerEncoderBlock(24, 24, 24, 24, norm_shape, 48, 8, 0.5)
d2l.check_shape(encoder_blk(X, valid_lens, training=False), X.shape)
%%tab jax
X = jnp.ones((2, 100, 24))
valid_lens = jnp.array([3, 2])
encoder_blk = TransformerEncoderBlock(24, 48, 8, 0.5)
(output, _), _ = encoder_blk.init_with_output(d2l.get_key(), X, valid_lens,
                                              training=False)
d2l.check_shape(output, X.shape)
In the following Transformer encoder implementation,
we stack num_blks
instances of the above TransformerEncoderBlock class.
Since we use the fixed positional encoding
whose values are always between \(-1\) and \(1\),
we multiply values of the learnable input embeddings
by the square root of the embedding dimension
to rescale before summing up the input embedding and the positional encoding.
%%tab mxnet
class TransformerEncoder(d2l.Encoder):  #@save
    """The Transformer encoder."""
    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens,
                 num_heads, num_blks, dropout, use_bias=False):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for _ in range(num_blks):
            self.blks.add(TransformerEncoderBlock(
                num_hiddens, ffn_num_hiddens, num_heads, dropout, use_bias))
        # Initialize parameters so that the encoder can be called directly
        self.initialize()

    def forward(self, X, valid_lens):
        # Since positional encoding values are between -1 and 1, the embedding
        # values are multiplied by the square root of the embedding dimension
        # to rescale before they are summed up
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[
                i] = blk.attention.attention.attention_weights
        return X
%%tab pytorch
class TransformerEncoder(d2l.Encoder):  #@save
    """The Transformer encoder."""
    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens,
                 num_heads, num_blks, dropout, use_bias=False):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_blks):
            self.blks.add_module("block"+str(i), TransformerEncoderBlock(
                num_hiddens, ffn_num_hiddens, num_heads, dropout, use_bias))

    def forward(self, X, valid_lens):
        # Since positional encoding values are between -1 and 1, the embedding
        # values are multiplied by the square root of the embedding dimension
        # to rescale before they are summed up
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[
                i] = blk.attention.attention.attention_weights
        return X
%%tab tensorflow
class TransformerEncoder(d2l.Encoder):  #@save
    """The Transformer encoder."""
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_hiddens, num_heads,
                 num_blks, dropout, bias=False):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.embedding = tf.keras.layers.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = [TransformerEncoderBlock(
            key_size, query_size, value_size, num_hiddens, norm_shape,
            ffn_num_hiddens, num_heads, dropout, bias) for _ in range(
            num_blks)]

    def call(self, X, valid_lens, **kwargs):
        # Since positional encoding values are between -1 and 1, the embedding
        # values are multiplied by the square root of the embedding dimension
        # to rescale before they are summed up
        X = self.pos_encoding(self.embedding(X) * tf.math.sqrt(
            tf.cast(self.num_hiddens, dtype=tf.float32)), **kwargs)
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens, **kwargs)
            self.attention_weights[
                i] = blk.attention.attention.attention_weights
        return X
%%tab jax
class TransformerEncoder(d2l.Encoder):  #@save
    """The Transformer encoder."""
    vocab_size: int
    num_hiddens: int
    ffn_num_hiddens: int
    num_heads: int
    num_blks: int
    dropout: float
    use_bias: bool = False

    def setup(self):
        self.embedding = nn.Embed(self.vocab_size, self.num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(self.num_hiddens, self.dropout)
        self.blks = [TransformerEncoderBlock(self.num_hiddens,
                                             self.ffn_num_hiddens,
                                             self.num_heads,
                                             self.dropout, self.use_bias)
                     for _ in range(self.num_blks)]

    def __call__(self, X, valid_lens, training=False):
        # Since positional encoding values are between -1 and 1, the embedding
        # values are multiplied by the square root of the embedding dimension
        # to rescale before they are summed up
        X = self.embedding(X) * math.sqrt(self.num_hiddens)
        X = self.pos_encoding(X, training=training)
        attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X, attention_w = blk(X, valid_lens, training=training)
            attention_weights[i] = attention_w
        # Flax sow API is used to capture intermediate variables
        self.sow('intermediates', 'enc_attention_weights', attention_weights)
        return X
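The rescaling of the embeddings before adding the positional encoding can be illustrated with a small numerical sketch. It assumes a hypothetical embedding vector whose entries have magnitude $1/\sqrt{d}$ (actual initial scales depend on the framework), and `positional_encoding` here is a standalone helper that follows the sinusoidal scheme with sine in even and cosine in odd dimensions.

```python
import math


def positional_encoding(pos, d):
    """Sinusoidal encoding of one position: sin in even, cos in odd dims."""
    return [math.sin(pos / 10000 ** (j / d)) if j % 2 == 0
            else math.cos(pos / 10000 ** ((j - 1) / d))
            for j in range(d)]


d = 16
# Hypothetical embedding with entries of magnitude 1 / sqrt(d) = 0.25
embedding = [(-1) ** j / math.sqrt(d) for j in range(d)]
scaled = [e * math.sqrt(d) for e in embedding]
pe = positional_encoding(3, d)
summed = [s + p for s, p in zip(scaled, pe)]

print('max |embedding|:', max(abs(e) for e in embedding))
print('max |scaled|:   ', max(abs(s) for s in scaled))
print('max |pos. enc.|:', round(max(abs(p) for p in pe), 3))
```

Under this assumption, multiplying by $\sqrt{d}$ brings the embedding entries to the same order of magnitude as the positional encoding, so neither term dominates the sum.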
Below we specify hyperparameters to create a two-layer Transformer encoder.
The shape of the Transformer encoder output
is (batch size, number of time steps, num_hiddens).
%%tab mxnet
encoder = TransformerEncoder(200, 24, 48, 8, 2, 0.5)
d2l.check_shape(encoder(np.ones((2, 100)), valid_lens), (2, 100, 24))
%%tab pytorch
encoder = TransformerEncoder(200, 24, 48, 8, 2, 0.5)
d2l.check_shape(encoder(d2l.ones((2, 100), dtype=torch.long), valid_lens),
(2, 100, 24))
%%tab tensorflow
encoder = TransformerEncoder(200, 24, 24, 24, 24, [1, 2], 48, 8, 2, 0.5)
d2l.check_shape(encoder(tf.ones((2, 100)), valid_lens, training=False),
(2, 100, 24))
%%tab jax
encoder = TransformerEncoder(200, 24, 48, 8, 2, 0.5)
d2l.check_shape(encoder.init_with_output(d2l.get_key(),
                                         jnp.ones((2, 100), dtype=jnp.int32),
                                         valid_lens)[0],
                (2, 100, 24))
As shown in :numref:fig_transformer,
the Transformer decoder
is composed of multiple identical layers.
Each layer is implemented in the following TransformerDecoderBlock class,
which contains three sublayers:
decoder self-attention,
encoder--decoder attention,
and positionwise feed-forward networks.
These sublayers employ
a residual connection around them
followed by layer normalization.
As we described earlier in this section,
in the masked multi-head decoder self-attention
(the first sublayer),
queries, keys, and values
all come from the outputs of the previous decoder layer.
When training sequence-to-sequence models,
tokens at all the positions (time steps)
of the output sequence
are known.
However, during prediction
the output sequence is generated token by token;
at any decoder time step
only the generated tokens
can be used in the decoder self-attention.
To preserve autoregression in the decoder,
its masked self-attention
specifies dec_valid_lens
so that
any query
only attends to
all positions in the decoder
up to the query position.
%%tab mxnet
class TransformerDecoderBlock(nn.Block):
    # The i-th block in the Transformer decoder
    def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout, i):
        super().__init__()
        self.i = i
        self.attention1 = d2l.MultiHeadAttention(num_hiddens, num_heads,
                                                 dropout)
        self.addnorm1 = AddNorm(dropout)
        self.attention2 = d2l.MultiHeadAttention(num_hiddens, num_heads,
                                                 dropout)
        self.addnorm2 = AddNorm(dropout)
        self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
        self.addnorm3 = AddNorm(dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        # During training, all the tokens of any output sequence are processed
        # at the same time, so state[2][self.i] is None as initialized. When
        # decoding any output sequence token by token during prediction,
        # state[2][self.i] contains representations of the decoded output at
        # the i-th block up to the current time step
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = np.concatenate((state[2][self.i], X), axis=1)
        state[2][self.i] = key_values

        if autograd.is_training():
            batch_size, num_steps, _ = X.shape
            # Shape of dec_valid_lens: (batch_size, num_steps), where every
            # row is [1, 2, ..., num_steps]
            dec_valid_lens = np.tile(np.arange(1, num_steps + 1, ctx=X.ctx),
                                     (batch_size, 1))
        else:
            dec_valid_lens = None
        # Self-attention
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # Encoder-decoder attention. Shape of enc_outputs:
        # (batch_size, num_steps, num_hiddens)
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state
%%tab pytorch
class TransformerDecoderBlock(nn.Module):
    # The i-th block in the Transformer decoder
    def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout, i):
        super().__init__()
        self.i = i
        self.attention1 = d2l.MultiHeadAttention(num_hiddens, num_heads,
                                                 dropout)
        self.addnorm1 = AddNorm(num_hiddens, dropout)
        self.attention2 = d2l.MultiHeadAttention(num_hiddens, num_heads,
                                                 dropout)
        self.addnorm2 = AddNorm(num_hiddens, dropout)
        self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
        self.addnorm3 = AddNorm(num_hiddens, dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        # During training, all the tokens of any output sequence are processed
        # at the same time, so state[2][self.i] is None as initialized. When
        # decoding any output sequence token by token during prediction,
        # state[2][self.i] contains representations of the decoded output at
        # the i-th block up to the current time step
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((state[2][self.i], X), dim=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            # Shape of dec_valid_lens: (batch_size, num_steps), where every
            # row is [1, 2, ..., num_steps]
            dec_valid_lens = torch.arange(
                1, num_steps + 1, device=X.device).repeat(batch_size, 1)
        else:
            dec_valid_lens = None
        # Self-attention
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # Encoder-decoder attention. Shape of enc_outputs:
        # (batch_size, num_steps, num_hiddens)
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state
%%tab tensorflow
class TransformerDecoderBlock(tf.keras.layers.Layer):
    # The i-th block in the Transformer decoder
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_hiddens, num_heads, dropout, i):
        super().__init__()
        self.i = i
        self.attention1 = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def call(self, X, state, **kwargs):
        enc_outputs, enc_valid_lens = state[0], state[1]
        # During training, all the tokens of any output sequence are processed
        # at the same time, so state[2][self.i] is None as initialized. When
        # decoding any output sequence token by token during prediction,
        # state[2][self.i] contains representations of the decoded output at
        # the i-th block up to the current time step
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = tf.concat((state[2][self.i], X), axis=1)
        state[2][self.i] = key_values
        if kwargs["training"]:
            batch_size, num_steps, _ = X.shape
            # Shape of dec_valid_lens: (batch_size, num_steps), where every
            # row is [1, 2, ..., num_steps]
            dec_valid_lens = tf.repeat(
                tf.reshape(tf.range(1, num_steps + 1),
                           shape=(-1, num_steps)), repeats=batch_size, axis=0)
        else:
            dec_valid_lens = None
        # Self-attention
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens,
                             **kwargs)
        Y = self.addnorm1(X, X2, **kwargs)
        # Encoder-decoder attention. Shape of enc_outputs:
        # (batch_size, num_steps, num_hiddens)
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens,
                             **kwargs)
        Z = self.addnorm2(Y, Y2, **kwargs)
        return self.addnorm3(Z, self.ffn(Z), **kwargs), state
%%tab jax
class TransformerDecoderBlock(nn.Module):
    # The i-th block in the Transformer decoder
    num_hiddens: int
    ffn_num_hiddens: int
    num_heads: int
    dropout: float
    i: int

    def setup(self):
        self.attention1 = d2l.MultiHeadAttention(self.num_hiddens,
                                                 self.num_heads,
                                                 self.dropout)
        self.addnorm1 = AddNorm(self.dropout)
        self.attention2 = d2l.MultiHeadAttention(self.num_hiddens,
                                                 self.num_heads,
                                                 self.dropout)
        self.addnorm2 = AddNorm(self.dropout)
        self.ffn = PositionWiseFFN(self.ffn_num_hiddens, self.num_hiddens)
        self.addnorm3 = AddNorm(self.dropout)

    def __call__(self, X, state, training=False):
        enc_outputs, enc_valid_lens = state[0], state[1]
        # During training, all the tokens of any output sequence are processed
        # at the same time, so state[2][self.i] is None as initialized. When
        # decoding any output sequence token by token during prediction,
        # state[2][self.i] contains representations of the decoded output at
        # the i-th block up to the current time step
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = jnp.concatenate((state[2][self.i], X), axis=1)
        state[2][self.i] = key_values
        if training:
            batch_size, num_steps, _ = X.shape
            # Shape of dec_valid_lens: (batch_size, num_steps), where every
            # row is [1, 2, ..., num_steps]
            dec_valid_lens = jnp.tile(jnp.arange(1, num_steps + 1),
                                      (batch_size, 1))
        else:
            dec_valid_lens = None
        # Self-attention
        X2, attention_w1 = self.attention1(X, key_values, key_values,
                                           dec_valid_lens, training=training)
        Y = self.addnorm1(X, X2, training=training)
        # Encoder-decoder attention. Shape of enc_outputs:
        # (batch_size, num_steps, num_hiddens)
        Y2, attention_w2 = self.attention2(Y, enc_outputs, enc_outputs,
                                           enc_valid_lens, training=training)
        Z = self.addnorm2(Y, Y2, training=training)
        return self.addnorm3(Z, self.ffn(Z), training=training), state, attention_w1, attention_w2
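The two regimes handled by the decoder block, masked parallel processing during training and cached token-by-token decoding during prediction, can be sketched with plain lists. The helper names `make_dec_valid_lens` and `update_cache` are hypothetical and exist only for this illustration; strings stand in for token representations.

```python
def make_dec_valid_lens(batch_size, num_steps):
    """Training-time valid lengths: every row is [1, 2, ..., num_steps]."""
    return [list(range(1, num_steps + 1)) for _ in range(batch_size)]


def update_cache(cache, x_new):
    """Prediction-time cache: append new representations along the time axis."""
    return x_new if cache is None else cache + x_new


# Training: the query at time step t (1-based) sees only the first t keys
print(make_dec_valid_lens(batch_size=2, num_steps=4))

# Prediction: one token arrives per step and the keys and values grow
cache = None
for step in range(3):
    x_new = [f"token_{step}"]  # placeholder for one time step of one sequence
    cache = update_cache(cache, x_new)
    print(step, cache)
```

During prediction no mask is needed, because the cache only ever contains tokens that have already been generated.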
To facilitate scaled dot product operations
in the encoder--decoder attention
and addition operations in the residual connections,
the feature dimension (num_hiddens) of the decoder is
the same as that of the encoder.
%%tab mxnet
decoder_blk = TransformerDecoderBlock(24, 48, 8, 0.5, 0)
decoder_blk.initialize()
X = np.ones((2, 100, 24))
state = [encoder_blk(X, valid_lens), valid_lens, [None]]
d2l.check_shape(decoder_blk(X, state)[0], X.shape)
%%tab pytorch
decoder_blk = TransformerDecoderBlock(24, 48, 8, 0.5, 0)
X = d2l.ones((2, 100, 24))
state = [encoder_blk(X, valid_lens), valid_lens, [None]]
d2l.check_shape(decoder_blk(X, state)[0], X.shape)
%%tab tensorflow
decoder_blk = TransformerDecoderBlock(24, 24, 24, 24, [1, 2], 48, 8, 0.5, 0)
X = tf.ones((2, 100, 24))
state = [encoder_blk(X, valid_lens), valid_lens, [None]]
d2l.check_shape(decoder_blk(X, state, training=False)[0], X.shape)
%%tab jax
decoder_blk = TransformerDecoderBlock(24, 48, 8, 0.5, 0)
X = d2l.ones((2, 100, 24))
state = [encoder_blk.init_with_output(d2l.get_key(), X, valid_lens)[0][0],
valid_lens, [None]]
d2l.check_shape(decoder_blk.init_with_output(d2l.get_key(), X, state)[0][0],
                X.shape)
Now we construct the entire Transformer decoder
composed of num_blks
instances of TransformerDecoderBlock.
In the end,
a fully connected layer computes the prediction
for all the vocab_size
possible output tokens.
Both of the decoder self-attention weights
and the encoder--decoder attention weights
are stored for later visualization.
%%tab mxnet
class TransformerDecoder(d2l.AttentionDecoder):
    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                 num_blks, dropout):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.num_blks = num_blks
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_blks):
            self.blks.add(TransformerDecoderBlock(
                num_hiddens, ffn_num_hiddens, num_heads, dropout, i))
        self.dense = nn.Dense(vocab_size, flatten=False)
        # Initialize parameters so that the decoder can be called directly
        self.initialize()

    def init_state(self, enc_outputs, enc_valid_lens):
        return [enc_outputs, enc_valid_lens, [None] * self.num_blks]

    def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            # Decoder self-attention weights
            self._attention_weights[0][
                i] = blk.attention1.attention.attention_weights
            # Encoder-decoder attention weights
            self._attention_weights[1][
                i] = blk.attention2.attention.attention_weights
        return self.dense(X), state

    @property
    def attention_weights(self):
        return self._attention_weights
%%tab pytorch
class TransformerDecoder(d2l.AttentionDecoder):
    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
                 num_blks, dropout):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.num_blks = num_blks
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_blks):
            self.blks.add_module("block"+str(i), TransformerDecoderBlock(
                num_hiddens, ffn_num_hiddens, num_heads, dropout, i))
        self.dense = nn.LazyLinear(vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens):
        return [enc_outputs, enc_valid_lens, [None] * self.num_blks]

    def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            # Decoder self-attention weights
            self._attention_weights[0][
                i] = blk.attention1.attention.attention_weights
            # Encoder-decoder attention weights
            self._attention_weights[1][
                i] = blk.attention2.attention.attention_weights
        return self.dense(X), state

    @property
    def attention_weights(self):
        return self._attention_weights
%%tab tensorflow
class TransformerDecoder(d2l.AttentionDecoder):
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_hiddens, num_heads,
                 num_blks, dropout):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.num_blks = num_blks
        self.embedding = tf.keras.layers.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = [TransformerDecoderBlock(
            key_size, query_size, value_size, num_hiddens, norm_shape,
            ffn_num_hiddens, num_heads, dropout, i)
                     for i in range(num_blks)]
        self.dense = tf.keras.layers.Dense(vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens):
        return [enc_outputs, enc_valid_lens, [None] * self.num_blks]

    def call(self, X, state, **kwargs):
        X = self.pos_encoding(self.embedding(X) * tf.math.sqrt(
            tf.cast(self.num_hiddens, dtype=tf.float32)), **kwargs)
        # 2 attention layers in decoder
        self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state, **kwargs)
            # Decoder self-attention weights
            self._attention_weights[0][i] = (
                blk.attention1.attention.attention_weights)
            # Encoder-decoder attention weights
            self._attention_weights[1][i] = (
                blk.attention2.attention.attention_weights)
        return self.dense(X), state

    @property
    def attention_weights(self):
        return self._attention_weights
%%tab jax
class TransformerDecoder(nn.Module):
    vocab_size: int
    num_hiddens: int
    ffn_num_hiddens: int
    num_heads: int
    num_blks: int
    dropout: float

    def setup(self):
        self.embedding = nn.Embed(self.vocab_size, self.num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(self.num_hiddens,
                                                   self.dropout)
        self.blks = [TransformerDecoderBlock(self.num_hiddens,
                                             self.ffn_num_hiddens,
                                             self.num_heads, self.dropout, i)
                     for i in range(self.num_blks)]
        self.dense = nn.Dense(self.vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens):
        return [enc_outputs, enc_valid_lens, [None] * self.num_blks]

    def __call__(self, X, state, training=False):
        X = self.embedding(X) * jnp.sqrt(jnp.float32(self.num_hiddens))
        X = self.pos_encoding(X, training=training)
        attention_weights = [[None] * len(self.blks) for _ in range(2)]
        for i, blk in enumerate(self.blks):
            X, state, attention_w1, attention_w2 = blk(X, state,
                                                       training=training)
            # Decoder self-attention weights
            attention_weights[0][i] = attention_w1
            # Encoder-decoder attention weights
            attention_weights[1][i] = attention_w2
        # Flax sow API is used to capture intermediate variables
        self.sow('intermediates', 'dec_attention_weights', attention_weights)
        return self.dense(X), state
Let's instantiate an encoder--decoder model
by following the Transformer architecture.
Here we specify that
both the Transformer encoder and the Transformer decoder
have two layers using 4-head attention.
As in :numref:sec_seq2seq_training,
we train the Transformer model
for sequence-to-sequence learning on the English--French machine translation dataset.
%%tab all
data = d2l.MTFraEng(batch_size=128)
num_hiddens, num_blks, dropout = 256, 2, 0.2
ffn_num_hiddens, num_heads = 64, 4
if tab.selected('tensorflow'):
    key_size, query_size, value_size = 256, 256, 256
    norm_shape = [2]
if tab.selected('pytorch', 'mxnet', 'jax'):
    encoder = TransformerEncoder(
        len(data.src_vocab), num_hiddens, ffn_num_hiddens, num_heads,
        num_blks, dropout)
    decoder = TransformerDecoder(
        len(data.tgt_vocab), num_hiddens, ffn_num_hiddens, num_heads,
        num_blks, dropout)
if tab.selected('mxnet', 'pytorch'):
    model = d2l.Seq2Seq(encoder, decoder, tgt_pad=data.tgt_vocab['<pad>'],
                        lr=0.001)
    trainer = d2l.Trainer(max_epochs=30, gradient_clip_val=1, num_gpus=1)
if tab.selected('jax'):
    model = d2l.Seq2Seq(encoder, decoder, tgt_pad=data.tgt_vocab['<pad>'],
                        lr=0.001, training=True)
    trainer = d2l.Trainer(max_epochs=30, gradient_clip_val=1, num_gpus=1)
if tab.selected('tensorflow'):
    with d2l.try_gpu():
        encoder = TransformerEncoder(
            len(data.src_vocab), key_size, query_size, value_size, num_hiddens,
            norm_shape, ffn_num_hiddens, num_heads, num_blks, dropout)
        decoder = TransformerDecoder(
            len(data.tgt_vocab), key_size, query_size, value_size, num_hiddens,
            norm_shape, ffn_num_hiddens, num_heads, num_blks, dropout)
        model = d2l.Seq2Seq(encoder, decoder, tgt_pad=data.tgt_vocab['<pad>'],
                            lr=0.001)
    trainer = d2l.Trainer(max_epochs=30, gradient_clip_val=1)
trainer.fit(model, data)
After training, we use the Transformer model to translate a few English sentences into French and compute their BLEU scores.
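As a reminder of what the BLEU score measures, here is a minimal sketch in plain Python. It assumes the formulation used for sequence-to-sequence evaluation earlier in the book: a brevity penalty multiplied by clipped $n$-gram precisions $p_n$, each raised to the power $1/2^n$. The name `bleu_sketch` is hypothetical and not part of the `d2l` package; the code cell that follows relies on `d2l.bleu` instead.

```python
import collections
import math

def bleu_sketch(pred_seq, label_seq, k):
    """BLEU with n-grams up to length k (illustrative sketch)."""
    pred_tokens, label_tokens = pred_seq.split(' '), label_seq.split(' ')
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    # Brevity penalty: predictions shorter than the label are penalized
    score = math.exp(min(0, 1 - len_label / len_pred))
    for n in range(1, min(k, len_pred) + 1):
        # Count the n-grams available in the label
        label_subs = collections.Counter(
            ' '.join(label_tokens[i:i + n])
            for i in range(len_label - n + 1))
        # A predicted n-gram matches at most as often as it occurs in the
        # label (clipping)
        num_matches = 0
        for i in range(len_pred - n + 1):
            ngram = ' '.join(pred_tokens[i:i + n])
            if label_subs[ngram] > 0:
                num_matches += 1
                label_subs[ngram] -= 1
        # For a fixed precision p_n, p_n ** (0.5 ** n) grows with n, so
        # matching longer n-grams is rewarded more
        score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))
    return score

print(bleu_sketch('il est calme .', 'il est calme .', k=2))
print(bleu_sketch('il est .', 'il est calme .', k=2))
```

A perfect match scores 1, while a prediction that drops a token is penalized both by the brevity term and by the reduced bigram precision.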
%%tab all
engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
if tab.selected('pytorch', 'mxnet', 'tensorflow'):
    preds, _ = model.predict_step(
        data.build(engs, fras), d2l.try_gpu(), data.num_steps)
if tab.selected('jax'):
    preds, _ = model.predict_step(
        trainer.state.params, data.build(engs, fras), data.num_steps)
for en, fr, p in zip(engs, fras, preds):
    translation = []
    for token in data.tgt_vocab.to_tokens(p):
        # Stop collecting tokens at the end-of-sequence marker
        if token == '<eos>':
            break
        translation.append(token)
    print(f'{en} => {translation}, bleu,'
          f'{d2l.bleu(" ".join(translation), fr, k=2):.3f}')
Let's visualize the Transformer attention weights when translating the final English sentence into French.
The shape of the encoder self-attention weights
is (number of encoder layers, number of attention heads, num_steps
or number of queries, num_steps
or number of key-value pairs).
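The shape bookkeeping can be checked in isolation. The following is a minimal NumPy sketch, assuming each encoder block stores its weights as an array of shape (batch size × number of heads, `num_steps`, `num_steps`) with a batch size of 1; the random arrays and the value `num_steps = 9` are stand-ins chosen for illustration, not outputs of the trained model.

```python
import numpy as np

num_blks, num_heads, num_steps = 2, 4, 9  # num_steps is an assumed value

# Stand-ins for the per-block weights: one array per encoder block, each of
# shape (batch_size * num_heads, num_steps, num_steps) with batch_size = 1
rng = np.random.default_rng(0)
per_block = [rng.random((num_heads, num_steps, num_steps))
             for _ in range(num_blks)]

# Concatenate along axis 0, then split that axis into (block, head)
stacked = np.concatenate(per_block, axis=0)
weights = stacked.reshape(num_blks, num_heads, -1, num_steps)

print(stacked.shape)  # (8, 9, 9)
print(weights.shape)  # (2, 4, 9, 9)
# Block b, head h of the reshaped array is slice b * num_heads + h of `stacked`
assert np.array_equal(weights[1, 2], stacked[1 * num_heads + 2])
```

Because the blocks are concatenated in order, the reshape only regroups the leading axis and leaves every (query, key) matrix untouched.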
%%tab pytorch, mxnet, tensorflow
_, dec_attention_weights = model.predict_step(
    data.build([engs[-1]], [fras[-1]]), d2l.try_gpu(), data.num_steps, True)
enc_attention_weights = d2l.concat(model.encoder.attention_weights, 0)
shape = (num_blks, num_heads, -1, data.num_steps)
enc_attention_weights = d2l.reshape(enc_attention_weights, shape)
d2l.check_shape(enc_attention_weights,
                (num_blks, num_heads, data.num_steps, data.num_steps))
%%tab jax
_, (dec_attention_weights, enc_attention_weights) = model.predict_step(
    trainer.state.params, data.build([engs[-1]], [fras[-1]]),
    data.num_steps, True)
enc_attention_weights = d2l.concat(enc_attention_weights, 0)
shape = (num_blks, num_heads, -1, data.num_steps)
enc_attention_weights = d2l.reshape(enc_attention_weights, shape)
d2l.check_shape(enc_attention_weights,
                (num_blks, num_heads, data.num_steps, data.num_steps))
In the encoder self-attention, both queries and keys come from the same input sequence. Since padding tokens do not carry meaning, the valid length specified for the input sequence ensures that no query attends to positions of padding tokens. In the following, two layers of multi-head attention weights are presented row by row. Each head independently attends based on a separate representation subspace of queries, keys, and values.
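To see why padding positions end up with zero weight, consider the following minimal NumPy sketch of a softmax that ignores keys beyond the valid length. The function name `masked_softmax_sketch`, the random scores, and the valid length of 3 are assumptions made for illustration; the model itself relies on the masked softmax defined earlier in the book.

```python
import numpy as np

def masked_softmax_sketch(scores, valid_len):
    """Softmax over the last axis, ignoring key positions >= valid_len."""
    scores = np.array(scores, dtype=np.float64)
    key_positions = np.arange(scores.shape[-1])
    # Padding positions get a very negative score, so exp() maps them to 0
    scores[..., key_positions >= valid_len] = -1e6
    # Subtract the row maximum for numerical stability
    scores = scores - scores.max(axis=-1, keepdims=True)
    exp = np.exp(scores)
    return exp / exp.sum(axis=-1, keepdims=True)

# 5 query positions x 5 key positions; only the first 3 keys are real tokens
rng = np.random.default_rng(0)
weights = masked_softmax_sketch(rng.normal(size=(5, 5)), valid_len=3)
print(weights.round(3))
```

Every row still sums to one, but the last two columns (the padding keys) are exactly zero, which is the pattern to look for in the heatmaps.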
%%tab mxnet, tensorflow, jax
d2l.show_heatmaps(
    enc_attention_weights, xlabel='Key positions', ylabel='Query positions',
    titles=['Head %d' % i for i in range(1, 5)], figsize=(7, 3.5))
%%tab pytorch
d2l.show_heatmaps(
    enc_attention_weights.cpu(), xlabel='Key positions',
    ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],
    figsize=(7, 3.5))
To visualize the decoder self-attention weights and the encoder--decoder attention weights, we need more data manipulations. For example, we fill the masked attention weights with zero. Note that the decoder self-attention weights and the encoder--decoder attention weights both have the same queries: the beginning-of-sequence token followed by the output tokens and possibly end-of-sequence tokens.
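The zero filling can be illustrated on a toy example. This is a minimal sketch, assuming that during step-by-step decoding the query at step $t$ only sees the keys generated so far, so that the recorded self-attention rows have different lengths. The helper `pad_rows_with_zero` and the numbers are hypothetical.

```python
import numpy as np

def pad_rows_with_zero(rows, width):
    """Right-pad every row with zeros up to a common width."""
    out = np.zeros((len(rows), width))
    for r, row in enumerate(rows):
        out[r, :len(row)] = row
    return out

# Hypothetical weights of one head over three decoding steps: the query at
# step t sees t + 1 keys, so the rows are ragged
ragged = [[1.0], [0.4, 0.6], [0.2, 0.3, 0.5]]
filled = pad_rows_with_zero(ragged, width=4)
print(filled)
```

After padding, all rows share one width and can be stacked into a single tensor and reshaped for plotting; the added zeros correspond to key positions that the query could not attend to.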
%%tab mxnet
dec_attention_weights_2d = [d2l.tensor(head[0]).tolist()
                            for step in dec_attention_weights
                            for attn in step for blk in attn for head in blk]
# Rows differ in length; pad them and fill the missing entries with zero
dec_attention_weights_filled = d2l.tensor(
    pd.DataFrame(dec_attention_weights_2d).fillna(0.0).values)
dec_attention_weights = d2l.reshape(dec_attention_weights_filled, (
    -1, 2, num_blks, num_heads, data.num_steps))
dec_self_attention_weights, dec_inter_attention_weights = \
    dec_attention_weights.transpose(1, 2, 3, 0, 4)
%%tab pytorch
dec_attention_weights_2d = [head[0].tolist()
                            for step in dec_attention_weights
                            for attn in step for blk in attn for head in blk]
# Rows differ in length; pad them and fill the missing entries with zero
dec_attention_weights_filled = d2l.tensor(
    pd.DataFrame(dec_attention_weights_2d).fillna(0.0).values)
shape = (-1, 2, num_blks, num_heads, data.num_steps)
dec_attention_weights = d2l.reshape(dec_attention_weights_filled, shape)
dec_self_attention_weights, dec_inter_attention_weights = \
    dec_attention_weights.permute(1, 2, 3, 0, 4)
%%tab tensorflow
dec_attention_weights_2d = [head[0] for step in dec_attention_weights
                            for attn in step
                            for blk in attn for head in blk]
# Rows differ in length; pad them and fill the missing entries with zero
dec_attention_weights_filled = tf.convert_to_tensor(
    np.asarray(pd.DataFrame(dec_attention_weights_2d).fillna(
        0.0).values).astype(np.float32))
dec_attention_weights = tf.reshape(dec_attention_weights_filled, shape=(
    -1, 2, num_blks, num_heads, data.num_steps))
dec_self_attention_weights, dec_inter_attention_weights = tf.transpose(
    dec_attention_weights, perm=(1, 2, 3, 0, 4))
%%tab jax
dec_attention_weights_2d = [head[0].tolist() for step in dec_attention_weights
                            for attn in step
                            for blk in attn for head in blk]
# Rows differ in length; pad them and fill the missing entries with zero
dec_attention_weights_filled = d2l.tensor(
    pd.DataFrame(dec_attention_weights_2d).fillna(0.0).values)
dec_attention_weights = dec_attention_weights_filled.reshape(
    (-1, 2, num_blks, num_heads, data.num_steps))
dec_self_attention_weights, dec_inter_attention_weights = \
    dec_attention_weights.transpose(1, 2, 3, 0, 4)
%%tab all
d2l.check_shape(dec_self_attention_weights,
                (num_blks, num_heads, data.num_steps, data.num_steps))
d2l.check_shape(dec_inter_attention_weights,
                (num_blks, num_heads, data.num_steps, data.num_steps))
Because of the autoregressive property of the decoder self-attention, no query attends to key--value pairs after the query position.
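The resulting pattern is lower triangular. The following minimal NumPy sketch reproduces it with an explicit triangular mask; this is one common way to express the constraint and is an assumption for illustration, since the implementation in this section may enforce it differently (for example, through per-position valid lengths). The scores are random stand-ins.

```python
import numpy as np

num_steps = 4
rng = np.random.default_rng(0)
scores = rng.normal(size=(num_steps, num_steps))  # rows: queries, cols: keys

# Query i may attend to key j only if j <= i (lower-triangular mask)
allowed = np.tril(np.ones((num_steps, num_steps), dtype=bool))
masked = np.where(allowed, scores, -1e6)
# Subtract the row maximum for numerical stability, then normalize
masked -= masked.max(axis=-1, keepdims=True)
weights = np.exp(masked)
weights /= weights.sum(axis=-1, keepdims=True)

print(weights.round(3))
```

The first query can only attend to itself, so its row is a single one followed by zeros, and every entry above the diagonal is zero.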
%%tab all
d2l.show_heatmaps(
    dec_self_attention_weights[:, :, :, :],
    xlabel='Key positions', ylabel='Query positions',
    titles=['Head %d' % i for i in range(1, 5)], figsize=(7, 3.5))
Similar to the case in the encoder self-attention, via the specified valid length of the input sequence, no query from the output sequence attends to those padding tokens from the input sequence.
%%tab all
d2l.show_heatmaps(
    dec_inter_attention_weights, xlabel='Key positions',
    ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],
    figsize=(7, 3.5))
Although the Transformer architecture was originally proposed for sequence-to-sequence learning, as we will discover later in the book, either the Transformer encoder or the Transformer decoder is often individually used for different deep learning tasks.
The Transformer is an instance of the encoder--decoder architecture, though either the encoder or the decoder can be used individually in practice. In the Transformer architecture, multi-head self-attention is used for representing the input sequence and the output sequence, though the decoder has to preserve the autoregressive property via a masked version. Both the residual connections and the layer normalization in the Transformer are important for training a very deep model. The positionwise feed-forward network in the Transformer model transforms the representation at all the sequence positions using the same MLP.
- Train a deeper Transformer in the experiments. How does it affect the training speed and the translation performance?
- Is it a good idea to replace scaled dot product attention with additive attention in the Transformer? Why?
- For language modeling, should we use the Transformer encoder, decoder, or both? How would you design this method?
- What challenges can Transformers face if input sequences are very long? Why?
- How would you improve the computational and memory efficiency of Transformers? Hint: you may refer to the survey paper by :citet:
