# 多头注意力
🏷 sec_multihead-attention
在实践中,当给定相同的查询、键和值的集合时,
我们希望模型可以基于相同的注意力机制学习到不同的行为,
然后将不同的行为作为知识组合起来,
捕获序列内各种范围的依赖关系
(例如,短距离依赖和长距离依赖关系)。
因此,允许注意力机制组合使用查询、键和值的不同
子空间表示(representation subspaces)可能是有益的。
为此,与其只使用单独一个注意力汇聚,
我们可以用独立学习得到的 组不同的
线性投影(linear projections)来变换查询、键和值。
然后,这 组变换后的查询、键和值将并行地送到注意力汇聚中。
最后,将这 个注意力汇聚的输出拼接在一起,
并且通过另一个可以学习的线性投影进行变换,
以产生最终输出。
这种设计被称为多头注意力(multihead attention)
:cite: Vaswani.Shazeer.Parmar.ea.2017
。
对于 个注意力汇聚输出,每一个注意力汇聚都被称作一个头(head)。
:numref: fig_multi-head-attention
展示了使用全连接层来实现可学习的线性变换的多头注意力。
🏷 fig_multi-head-attention
# 模型
在实现多头注意力之前,让我们用数学语言将这个模型形式化地描述出来。
给定查询、
键 和
值,
每个注意力头()的计算方法为:
其中,可学习的参数包括
、
和
,
以及代表注意力汇聚的函数。
可以是 :numref: sec_attention-scoring-functions
中的
加性注意力和缩放点积注意力。
多头注意力的输出需要经过另一个线性转换,
它对应着 个头连结后的结果,因此其可学习参数是
:
基于这种设计,每个头都可能会关注输入的不同部分,
可以表示比简单加权平均值更复杂的函数。
import math | |
import torch | |
from torch import nn | |
from d2l import torch as d2l |
# 实现
在实现过程中通常 [选择缩放点积注意力作为每一个注意力头]。
为了避免计算代价和参数代价的大幅增长,
我们设定。
值得注意的是,如果将查询、键和值的线性变换的输出数量设置为
,
则可以并行计算 个头。
在下面的实现中, 是通过参数 num_hiddens
指定的。
#@save | |
class MultiHeadAttention(nn.Module): | |
"""多头注意力""" | |
def __init__(self, key_size, query_size, value_size, num_hiddens, | |
num_heads, dropout, bias=False, **kwargs): | |
super(MultiHeadAttention, self).__init__(**kwargs) | |
self.num_heads = num_heads | |
self.attention = d2l.DotProductAttention(dropout) | |
self.W_q = nn.Linear(query_size, num_hiddens, bias=bias) | |
self.W_k = nn.Linear(key_size, num_hiddens, bias=bias) | |
self.W_v = nn.Linear(value_size, num_hiddens, bias=bias) | |
self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias) | |
def forward(self, queries, keys, values, valid_lens): | |
# queries,keys,values 的形状: | |
# (batch_size,查询或者 “键-值” 对的个数,num_hiddens) | |
# valid_lens 的形状: | |
# (batch_size,) 或 (batch_size,查询的个数) | |
# 经过变换后,输出的 queries,keys,values 的形状: | |
# (batch_size*num_heads,查询或者 “键-值” 对的个数, | |
# num_hiddens/num_heads) | |
queries = transpose_qkv(self.W_q(queries), self.num_heads) | |
keys = transpose_qkv(self.W_k(keys), self.num_heads) | |
values = transpose_qkv(self.W_v(values), self.num_heads) | |
if valid_lens is not None: | |
# 在轴 0,将第一项(标量或者矢量)复制 num_heads 次, | |
# 然后如此复制第二项,然后诸如此类。 | |
valid_lens = torch.repeat_interleave( | |
valid_lens, repeats=self.num_heads, dim=0) | |
# output 的形状:(batch_size*num_heads,查询的个数, | |
# num_hiddens/num_heads) | |
output = self.attention(queries, keys, values, valid_lens) | |
# output_concat 的形状:(batch_size,查询的个数,num_hiddens) | |
output_concat = transpose_output(output, self.num_heads) | |
return self.W_o(output_concat) |
为了能够 [使多个头并行计算],
上面的 MultiHeadAttention
类将使用下面定义的两个转置函数。
具体来说, transpose_output
函数反转了 transpose_qkv
函数的操作。
#@save | |
def transpose_qkv(X, num_heads): | |
"""为了多注意力头的并行计算而变换形状""" | |
# 输入 X 的形状:(batch_size,查询或者 “键-值” 对的个数,num_hiddens) | |
# 输出 X 的形状:(batch_size,查询或者 “键-值” 对的个数,num_heads, | |
# num_hiddens/num_heads) | |
X = X.reshape(X.shape[0], X.shape[1], num_heads, -1) | |
# 输出 X 的形状:(batch_size,num_heads,查询或者 “键-值” 对的个数, | |
# num_hiddens/num_heads) | |
X = X.permute(0, 2, 1, 3) | |
# 最终输出的形状:(batch_size*num_heads, 查询或者 “键-值” 对的个数, | |
# num_hiddens/num_heads) | |
return X.reshape(-1, X.shape[2], X.shape[3]) | |
#@save | |
def transpose_output(X, num_heads): | |
"""逆转transpose_qkv函数的操作""" | |
X = X.reshape(-1, num_heads, X.shape[1], X.shape[2]) | |
X = X.permute(0, 2, 1, 3) | |
return X.reshape(X.shape[0], X.shape[1], -1) |
下面使用键和值相同的小例子来 [测试] 我们编写的 MultiHeadAttention
类。
多头注意力输出的形状是( batch_size
, num_queries
, num_hiddens
)。
num_hiddens, num_heads = 100, 5 | |
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, | |
num_hiddens, num_heads, 0.5) | |
attention.eval() |
MultiHeadAttention(
(attention): DotProductAttention(
(dropout): Dropout(p=0.5, inplace=False)
)
(W_q): Linear(in_features=100, out_features=100, bias=False)
(W_k): Linear(in_features=100, out_features=100, bias=False)
(W_v): Linear(in_features=100, out_features=100, bias=False)
(W_o): Linear(in_features=100, out_features=100, bias=False)
)
batch_size, num_queries = 2, 4 | |
num_kvpairs, valid_lens = 6, torch.tensor([3, 2]) | |
X = torch.ones((batch_size, num_queries, num_hiddens)) | |
Y = torch.ones((batch_size, num_kvpairs, num_hiddens)) | |
attention(X, Y, Y, valid_lens).shape |
torch.Size([2, 4, 100])
# 小结
- 多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。
- 基于适当的张量操作,可以实现多头注意力的并行计算。
# 练习
- 分别可视化这个实验中的多个头的注意力权重。
- 假设有一个完成训练的基于多头注意力的模型,现在希望修剪最不重要的注意力头以提高预测速度。如何设计实验来衡量注意力头的重要性呢?
Discussions