目录

多头注意力

目录

多头注意力即是在注意力的基础上,用独立学习得到的$h$组不同的 线性投影(linear projections)来变换查询、键和值。 然后,这$h$组变换后的查询、键和值将并行地送到注意力汇聚中。 最后,将这$h$个注意力汇聚的输出拼接在一起, 并且通过另一个可以学习的线性投影进行变换, 以产生最终输出。

对于$h$个注意力汇聚输出,每一个注意力汇聚都被称作一个(head)。 ./1.png

模型

在实现多头注意力之前,让我们用数学语言将这个模型形式化地描述出来。 给定查询$\mathbf{q} \in \mathbb{R}^{d_q}$、 键$\mathbf{k} \in \mathbb{R}^{d_k}$和 值$\mathbf{v} \in \mathbb{R}^{d_v}$, 每个注意力头$\mathbf{h}_i$($i = 1, \ldots, h$)的计算方法为:

$$\mathbf{h}_i = f(\mathbf W_i^{(q)}\mathbf q, \mathbf W_i^{(k)}\mathbf k,\mathbf W_i^{(v)}\mathbf v) \in \mathbb R^{p_v},$$

其中,可学习的参数包括 $\mathbf W_i^{(q)}\in\mathbb R^{p_q\times d_q}$、 $\mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k}$和 $\mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v}$, 以及代表注意力汇聚的函数$f$。

加性注意力和缩放点积注意力。 多头注意力的输出需要经过另一个线性转换, 它对应着$h$个头连结后的结果,因此其可学习参数是 $\mathbf W_o\in\mathbb R^{p_o\times h p_v}$:

$$\mathbf W_o \begin{bmatrix}\mathbf h_1\\vdots\\mathbf h_h\end{bmatrix} \in \mathbb{R}^{p_o}.$$

基于这种设计,每个头都可能会关注输入的不同部分, 可以表示比简单加权平均值更复杂的函数。

小结

  • 多头注意力融合了来自于多个注意力汇聚的不同知识,这些知识的不同来源于相同的查询、键和值的不同的子空间表示。
  • 基于适当的张量操作,可以实现多头注意力的并行计算。