MHA
复制本地路径 | 在线编辑
Standard Attention
传统的 Attention 如下,就是 QKV 的矩阵乘法,其中 \(Q = W_Q X\),其他类似。
\[
Attention(Q, k, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V
\]
MHA
本质上是把原来的 Attention 计算拆分成多个小份的 Attention 计算,然后并行计算,最后再合回原来的维度。
$$
\begin{aligned}
head_i = Attention(Q_i, K_i, V_i) \
H = concat(head_1, head_2, ..., head_h) \
MHA(Q, K, V) = H W^O
\end{aligned}
$$
之所以这样,是希望能够理解输入不同部分之间的关系。因此切分成不同的头,每个头独立学习不同关系,比如第一个头学习的是主语和宾语之间的关系,第二个头学习的是谓语和宾语之间的关系,以此类推。
MQA 和 GQA
放到 LLM 中去讲了,个人笔记链接,因为这个涉及到 KV Cache,都是为了解决实际应用中显存占用过高的问题,是一个工程化的问题。