```python
%load_ext d2lbook.tab
tab.interact_select('mxnet', 'pytorch', 'tensorflow', 'jax')
```
# Multi-Head Attention
:label:`sec_multihead-attention`
In practice, given the same set of queries, keys, and values we may want our model to combine knowledge from different behaviors of the same attention mechanism, such as capturing dependencies of various ranges (e.g., shorter-range vs. longer-range) within a sequence. Thus, it may be beneficial to allow our attention mechanism to jointly use different representation subspaces of queries, keys, and values.
To this end, instead of performing a single attention pooling, queries, keys, and values can be transformed with \(h\) independently learned linear projections. Then these \(h\) projected queries, keys, and values are fed into attention pooling in parallel. In the end, the \(h\) attention-pooling outputs are concatenated and transformed with another learned linear projection to produce the final output. This design is called *multi-head attention*, where each of the \(h\) attention-pooling outputs is a *head* :cite:`Vaswani.Shazeer.Parmar.ea.2017`. Using fully connected layers to perform learnable linear transformations, :numref:`fig_multi-head-attention` describes multi-head attention.
*(Figure: multi-head attention, where the outputs of \(h\) heads are concatenated and then linearly transformed.)*
:label:`fig_multi-head-attention`
```python
%%tab mxnet
from d2l import mxnet as d2l
import math
from mxnet import autograd, np, npx
from mxnet.gluon import nn
npx.set_np()
```
```python
%%tab pytorch
from d2l import torch as d2l
import math
import torch
from torch import nn
```
```python
%%tab tensorflow
from d2l import tensorflow as d2l
import tensorflow as tf
```
```python
%%tab jax
from d2l import jax as d2l
from flax import linen as nn
from jax import numpy as jnp
import jax
```
## Model
Before providing the implementation of multi-head attention, let's formalize this model mathematically. Given a query \(\mathbf{q} \in \mathbb{R}^{d_q}\), a key \(\mathbf{k} \in \mathbb{R}^{d_k}\), and a value \(\mathbf{v} \in \mathbb{R}^{d_v}\), each attention head \(\mathbf{h}_i\) (\(i = 1, \ldots, h\)) is computed as

$$\mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k, \mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v},$$

where \(\mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q}\), \(\mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k}\), and \(\mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v}\) are learnable parameters and \(f\) is attention pooling, such as the additive attention or scaled dot product attention in :numref:`sec_attention-scoring-functions`.
The multi-head attention output is another linear transformation via learnable parameters \(\mathbf W_o\in\mathbb R^{p_o\times h p_v}\) of the concatenation of the \(h\) heads:

$$\mathbf W_o \begin{bmatrix}\mathbf h_1\\\vdots\\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}.$$

Based on this design, each head may attend to different parts of the input, so functions more sophisticated than a simple weighted average can be expressed.
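To make the formula concrete, here is a minimal numeric sketch of a single head \(\mathbf{h}_i\) with toy dimensions, assuming scaled dot product attention as the pooling function \(f\); all tensors, names, and sizes below are purely illustrative and not part of the book's code.

```python
import math
import torch

torch.manual_seed(0)
d_q = d_k = d_v = 4      # input dimensions (toy values)
p_q = p_k = p_v = 2      # per-head projection dimensions (toy values)

q = torch.randn(d_q)     # one query
K = torch.randn(3, d_k)  # three keys
V = torch.randn(3, d_v)  # the corresponding three values

W_q = torch.randn(p_q, d_q)  # plays the role of W_i^{(q)}
W_k = torch.randn(p_k, d_k)  # plays the role of W_i^{(k)}
W_v = torch.randn(p_v, d_v)  # plays the role of W_i^{(v)}

# Project the query, keys, and values into this head's subspace
q_i, K_i, V_i = W_q @ q, K @ W_k.T, V @ W_v.T
# f = scaled dot product attention over the three key-value pairs
weights = torch.softmax(K_i @ q_i / math.sqrt(p_k), dim=0)
h_i = weights @ V_i          # head output h_i, with shape (p_v,)
print(h_i.shape)             # torch.Size([2])
```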
## Implementation
In our implementation, we choose the scaled dot product attention for each head of the multi-head attention. To avoid significant growth of computational cost and parametrization cost, we set \(p_q = p_k = p_v = p_o / h\). Note that the \(h\) heads can be computed in parallel if we set the number of outputs of the linear transformations for the query, key, and value to \(p_q h = p_k h = p_v h = p_o\). In the following implementation, \(p_o\) is specified via the argument `num_hiddens`.
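To see why this choice lets all heads be computed by a single fully connected layer, consider the following sketch (a hypothetical check, not part of the book's code): one fused linear projection with \(p_q h = p_o\) outputs, split head-wise, gives exactly the same result as \(h\) separate per-head projections whose weights are the corresponding row slices of the fused weight matrix.

```python
import torch
from torch import nn

torch.manual_seed(0)
d, num_hiddens, num_heads = 8, 12, 3         # toy sizes; head dim p_q = 4
head_dim = num_hiddens // num_heads

# One fused projection computing all heads at once ...
fused = nn.Linear(d, num_hiddens, bias=False)
# ... versus num_heads separate per-head projections sharing its weights
per_head = [nn.Linear(d, head_dim, bias=False) for _ in range(num_heads)]
for i, layer in enumerate(per_head):
    layer.weight.data = fused.weight.data[i * head_dim:(i + 1) * head_dim]

X = torch.randn(2, 5, d)                     # (batch, seq, d)
fused_out = fused(X).reshape(2, 5, num_heads, head_dim)
separate_out = torch.stack([layer(X) for layer in per_head], dim=2)
print(torch.allclose(fused_out, separate_out))  # True
```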
```python
%%tab mxnet
class MultiHeadAttention(d2l.Module):  #@save
    """Multi-head attention."""
    def __init__(self, num_hiddens, num_heads, dropout, use_bias=False,
                 **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_k = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_v = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)
        self.W_o = nn.Dense(num_hiddens, use_bias=use_bias, flatten=False)

    def forward(self, queries, keys, values, valid_lens):
        # Shape of queries, keys, or values:
        # (batch_size, no. of queries or key-value pairs, num_hiddens)
        # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
        # After transposing, shape of output queries, keys, or values:
        # (batch_size * num_heads, no. of queries or key-value pairs,
        # num_hiddens / num_heads)
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for num_heads
            # times, then copy the next item, and so on
            valid_lens = valid_lens.repeat(self.num_heads, axis=0)

        # Shape of output: (batch_size * num_heads, no. of queries,
        # num_hiddens / num_heads)
        output = self.attention(queries, keys, values, valid_lens)
        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)
```
```python
%%tab pytorch
class MultiHeadAttention(d2l.Module):  #@save
    """Multi-head attention."""
    def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_k = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_v = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_o = nn.LazyLinear(num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # Shape of queries, keys, or values:
        # (batch_size, no. of queries or key-value pairs, num_hiddens)
        # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
        # After transposing, shape of output queries, keys, or values:
        # (batch_size * num_heads, no. of queries or key-value pairs,
        # num_hiddens / num_heads)
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for num_heads
            # times, then copy the next item, and so on
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # Shape of output: (batch_size * num_heads, no. of queries,
        # num_hiddens / num_heads)
        output = self.attention(queries, keys, values, valid_lens)
        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)
```
```python
%%tab tensorflow
class MultiHeadAttention(d2l.Module):  #@save
    """Multi-head attention."""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
        self.W_k = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
        self.W_v = tf.keras.layers.Dense(num_hiddens, use_bias=bias)
        self.W_o = tf.keras.layers.Dense(num_hiddens, use_bias=bias)

    def call(self, queries, keys, values, valid_lens, **kwargs):
        # Shape of queries, keys, or values:
        # (batch_size, no. of queries or key-value pairs, num_hiddens)
        # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
        # After transposing, shape of output queries, keys, or values:
        # (batch_size * num_heads, no. of queries or key-value pairs,
        # num_hiddens / num_heads)
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for num_heads
            # times, then copy the next item, and so on
            valid_lens = tf.repeat(valid_lens, repeats=self.num_heads, axis=0)

        # Shape of output: (batch_size * num_heads, no. of queries,
        # num_hiddens / num_heads)
        output = self.attention(queries, keys, values, valid_lens, **kwargs)
        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat)
```
```python
%%tab jax
class MultiHeadAttention(nn.Module):  #@save
    num_hiddens: int
    num_heads: int
    dropout: float
    bias: bool = False

    def setup(self):
        self.attention = d2l.DotProductAttention(self.dropout)
        self.W_q = nn.Dense(self.num_hiddens, use_bias=self.bias)
        self.W_k = nn.Dense(self.num_hiddens, use_bias=self.bias)
        self.W_v = nn.Dense(self.num_hiddens, use_bias=self.bias)
        self.W_o = nn.Dense(self.num_hiddens, use_bias=self.bias)

    @nn.compact
    def __call__(self, queries, keys, values, valid_lens, training=False):
        # Shape of queries, keys, or values:
        # (batch_size, no. of queries or key-value pairs, num_hiddens)
        # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
        # After transposing, shape of output queries, keys, or values:
        # (batch_size * num_heads, no. of queries or key-value pairs,
        # num_hiddens / num_heads)
        queries = self.transpose_qkv(self.W_q(queries))
        keys = self.transpose_qkv(self.W_k(keys))
        values = self.transpose_qkv(self.W_v(values))

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for num_heads
            # times, then copy the next item, and so on
            valid_lens = jnp.repeat(valid_lens, self.num_heads, axis=0)

        # Shape of output: (batch_size * num_heads, no. of queries,
        # num_hiddens / num_heads)
        output, attention_weights = self.attention(
            queries, keys, values, valid_lens, training=training)
        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = self.transpose_output(output)
        return self.W_o(output_concat), attention_weights
```
To allow for parallel computation of multiple heads, the above `MultiHeadAttention` class uses the two transposition methods defined below. Specifically, the `transpose_output` method reverses the operation of the `transpose_qkv` method.
```python
%%tab mxnet
@d2l.add_to_class(MultiHeadAttention)  #@save
def transpose_qkv(self, X):
    """Transposition for parallel computation of multiple attention heads."""
    # Shape of input X: (batch_size, no. of queries or key-value pairs,
    # num_hiddens). Shape of output X: (batch_size, no. of queries or
    # key-value pairs, num_heads, num_hiddens / num_heads)
    X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
    # Shape of output X: (batch_size, num_heads, no. of queries or key-value
    # pairs, num_hiddens / num_heads)
    X = X.transpose(0, 2, 1, 3)
    # Shape of output: (batch_size * num_heads, no. of queries or key-value
    # pairs, num_hiddens / num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])

@d2l.add_to_class(MultiHeadAttention)  #@save
def transpose_output(self, X):
    """Reverse the operation of transpose_qkv."""
    X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
    X = X.transpose(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
```
```python
%%tab pytorch
@d2l.add_to_class(MultiHeadAttention)  #@save
def transpose_qkv(self, X):
    """Transposition for parallel computation of multiple attention heads."""
    # Shape of input X: (batch_size, no. of queries or key-value pairs,
    # num_hiddens). Shape of output X: (batch_size, no. of queries or
    # key-value pairs, num_heads, num_hiddens / num_heads)
    X = X.reshape(X.shape[0], X.shape[1], self.num_heads, -1)
    # Shape of output X: (batch_size, num_heads, no. of queries or key-value
    # pairs, num_hiddens / num_heads)
    X = X.permute(0, 2, 1, 3)
    # Shape of output: (batch_size * num_heads, no. of queries or key-value
    # pairs, num_hiddens / num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])

@d2l.add_to_class(MultiHeadAttention)  #@save
def transpose_output(self, X):
    """Reverse the operation of transpose_qkv."""
    X = X.reshape(-1, self.num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
```
```python
%%tab tensorflow
@d2l.add_to_class(MultiHeadAttention)  #@save
def transpose_qkv(self, X):
    """Transposition for parallel computation of multiple attention heads."""
    # Shape of input X: (batch_size, no. of queries or key-value pairs,
    # num_hiddens). Shape of output X: (batch_size, no. of queries or
    # key-value pairs, num_heads, num_hiddens / num_heads)
    X = tf.reshape(X, shape=(X.shape[0], X.shape[1], self.num_heads, -1))
    # Shape of output X: (batch_size, num_heads, no. of queries or key-value
    # pairs, num_hiddens / num_heads)
    X = tf.transpose(X, perm=(0, 2, 1, 3))
    # Shape of output: (batch_size * num_heads, no. of queries or key-value
    # pairs, num_hiddens / num_heads)
    return tf.reshape(X, shape=(-1, X.shape[2], X.shape[3]))

@d2l.add_to_class(MultiHeadAttention)  #@save
def transpose_output(self, X):
    """Reverse the operation of transpose_qkv."""
    X = tf.reshape(X, shape=(-1, self.num_heads, X.shape[1], X.shape[2]))
    X = tf.transpose(X, perm=(0, 2, 1, 3))
    return tf.reshape(X, shape=(X.shape[0], X.shape[1], -1))
```
```python
%%tab jax
@d2l.add_to_class(MultiHeadAttention)  #@save
def transpose_qkv(self, X):
    """Transposition for parallel computation of multiple attention heads."""
    # Shape of input X: (batch_size, no. of queries or key-value pairs,
    # num_hiddens). Shape of output X: (batch_size, no. of queries or
    # key-value pairs, num_heads, num_hiddens / num_heads)
    X = X.reshape((X.shape[0], X.shape[1], self.num_heads, -1))
    # Shape of output X: (batch_size, num_heads, no. of queries or key-value
    # pairs, num_hiddens / num_heads)
    X = jnp.transpose(X, (0, 2, 1, 3))
    # Shape of output: (batch_size * num_heads, no. of queries or key-value
    # pairs, num_hiddens / num_heads)
    return X.reshape((-1, X.shape[2], X.shape[3]))

@d2l.add_to_class(MultiHeadAttention)  #@save
def transpose_output(self, X):
    """Reverse the operation of transpose_qkv."""
    X = X.reshape((-1, self.num_heads, X.shape[1], X.shape[2]))
    X = jnp.transpose(X, (0, 2, 1, 3))
    return X.reshape((X.shape[0], X.shape[1], -1))
```
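As a quick sanity check (a sketch assuming the PyTorch tab above has been run; the underscore-prefixed names are only for illustration), `transpose_output` should exactly invert `transpose_qkv`, recovering the original `(batch_size, no. of queries, num_hiddens)` tensor.

```python
# Round-trip check: transpose_output(transpose_qkv(X)) == X
_attention = MultiHeadAttention(num_hiddens=100, num_heads=5, dropout=0.5)
_X = torch.randn(2, 4, 100)           # (batch_size, no. of queries, num_hiddens)
_round_trip = _attention.transpose_output(_attention.transpose_qkv(_X))
assert torch.allclose(_round_trip, _X)  # the two methods are inverses
```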
Let's test our implemented `MultiHeadAttention` class using a toy example where keys and values are the same. As a result, the shape of the multi-head attention output is (`batch_size`, `num_queries`, `num_hiddens`).
```python
%%tab pytorch
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = d2l.tensor([3, 2])
X = d2l.ones((batch_size, num_queries, num_hiddens))
Y = d2l.ones((batch_size, num_kvpairs, num_hiddens))
d2l.check_shape(attention(X, Y, Y, valid_lens),
                (batch_size, num_queries, num_hiddens))
```
```python
%%tab mxnet
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
attention.initialize()
```
```python
%%tab jax
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_heads, 0.5)
```
```python
%%tab tensorflow
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
```
```python
%%tab mxnet
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = d2l.tensor([3, 2])
X = d2l.ones((batch_size, num_queries, num_hiddens))
Y = d2l.ones((batch_size, num_kvpairs, num_hiddens))
d2l.check_shape(attention(X, Y, Y, valid_lens),
                (batch_size, num_queries, num_hiddens))
```
```python
%%tab tensorflow
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = d2l.tensor([3, 2])
X = tf.ones((batch_size, num_queries, num_hiddens))
Y = tf.ones((batch_size, num_kvpairs, num_hiddens))
d2l.check_shape(attention(X, Y, Y, valid_lens, training=False),
                (batch_size, num_queries, num_hiddens))
```
```python
%%tab jax
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = d2l.tensor([3, 2])
X = d2l.ones((batch_size, num_queries, num_hiddens))
Y = d2l.ones((batch_size, num_kvpairs, num_hiddens))
d2l.check_shape(attention.init_with_output(d2l.get_key(), X, Y, Y, valid_lens,
                                           training=False)[0][0],
                (batch_size, num_queries, num_hiddens))
```
## Summary

Multi-head attention combines knowledge of the same attention pooling via different representation subspaces of queries, keys, and values. To compute the multiple heads of multi-head attention in parallel, proper tensor manipulation is needed.
## Exercises
- Visualize attention weights of multiple heads in this experiment.
- Suppose that we have a trained model based on multi-head attention and we want to prune less important attention heads to increase the prediction speed. How can we design experiments to measure the importance of an attention head?
:begin_tab:mxnet
Discussions
:end_tab:
:begin_tab:pytorch
Discussions
:end_tab:
:begin_tab:tensorflow
Discussions
:end_tab:
:begin_tab:jax
Discussions
:end_tab: