多头注意力
多头注意力即是在注意力的基础上,用独立学习得到的$h$组不同的 线性投影(linear projections)来变换查询、键和值。 然后,这$h$组变换后的查询、键和值将并行地送到注意力汇聚中。 最后,将这$h$个注意力汇聚的输出拼接在一起, 并且通过另一个可以学习的线性投影进行变换, 以产生最终输出。
对于$h$个注意力汇聚输出,每一个注意力汇聚都被称作一个头(head)。
模型
在实现多头注意力之前,让我们用数学语言将这个模型形式化地描述出来。 给定查询$\mathbf{q} \in \mathbb{R}^{d_q}$、 键$\mathbf{k} \in \mathbb{R}^{d_k}$和 值$\mathbf{v} \in \mathbb{R}^{d_v}$, 每个注意力头$\mathbf{h}_i$($i = 1, \ldots, h$)的计算方法为:
$$\mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v},$$
其中,可学习的参数包括 $\mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q}$、 $\mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k}$和 $\mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v}$, 以及代表注意力汇聚的函数$f$。
加性注意力和缩放点积注意力。 多头注意力的输出需要经过另一个线性转换, 它对应着$h$个头连结后的结果,因此其可学习参数是 $\mathbf W_o\in\mathbb R^{p_o\times h p_v}$:
$$\mathbf W_o \begin{bmatrix}\mathbf h_1\\vdots\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}.$$
基于这种设计,每个头都可能会关注输入的不同部分, 可以表示比简单加权平均值更复杂的函数。
小结
- 多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。
- 基于适当的张量操作,可以实现多头注意力的并行计算。