Skip to content

FlashAttention 的简单理解

复制本地路径 | 在线编辑

参考文章:https://zhuanlan.zhihu.com/p/668888063

从 softmax 到 online-softmax

首先,softmax 的计算公式是:

\[ \operatorname{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}} \]

工程上,为了防止溢出,通常会减去一个最大值,确保指数项不会溢出:

\[ \operatorname{safe\text{-}softmax}(x)_i = \frac{e^{x_i - m}}{\sum_{j=1}^{n} e^{x_j - m}} \]
\[ m = \max(x_1, x_2, \ldots, x_n) \]

如下是伪代码:

下面介绍 online-softmax,它的目标是把上面的 3 个 for 循环变为 2 个 for 循环,方法如下:

构造 \(d'_i\)。其中,\(m_i = \max(x_1, \ldots, x_i)\),定义如下:

\[ d'_i \leftarrow d'_{i-1} e^{m_{i-1} - m_i} + e^{x_i - m_i} \]

根据定义,可以得到:

\[ d_N = d'_N = \sum_{j=1}^{N} e^{x_j - m_N} \]

并且因此可以递推地计算:

\[ d'_i = d'_{i-1} e^{m_{i-1} - m_i} + e^{x_i - m_i} \]
推导过程

所以,它的好处是可以变成两个 for 循环:

online-softmax 的好处(重点)

但是,变成两个 for 循环的好处是什么呢?仔细想想,其实时间复杂度根本就没有变。跑三个 for 循环计算次数是 3*N,现在变成两个 for 循环,但第一个循环中计算两次,所以计算次数还是 3*N,甚至细究起来,online-softmax 的计算量还更大一些。

理解这一点就是核心之处了。无论是 online-softmax 还是 FlashAttention,都不是从时间复杂度上进行优化,而是尽可能提高缓存友好性,减少程序在 SRAM 和更慢存储介质之间的数据交换。

三个循环的情况

我们先来看三个循环的情况。假设 SRAM 只能存储四个向量:

第一步,计算 \(m_i\),每次之后要进行交换,即:

  1. 计算 \(m_1\) 时,\(x_1\)\(m_0\) 在 SRAM 中,此时 SRAM 中有 \(m_0, x_1, m_1\)

  2. \(m_0\)\(x_1\) 从 SRAM 送出去(否则下一步最后会有五个向量)

  3. 计算 \(m_2\) 时,\(x_2\)\(m_1\) 在 SRAM 中,此时 SRAM 中有 \(m_1, x_2, m_2\)

  4. 同理,把 \(m_1\)\(x_2\) 从 SRAM 送出去

第二步,计算 \(d_i\),每次之后要进行交换,即:

  1. 计算 \(d_1\) 时,\(d_0\)\(m_N\)\(x_1\) 在 SRAM 中,此时 SRAM 中有 \(m_N, x_1, d_0, d_1\)

  2. \(x_1\)\(d_0\) 从 SRAM 送出去

  3. 计算 \(d_2\) 时,\(d_1\)\(m_N\)\(x_2\) 在 SRAM 中,此时 SRAM 中有 \(m_N, x_2, d_1, d_2\)

  4. 同理,把 \(x_2\)\(d_1\) 从 SRAM 送出去

第三步,计算 \(a_i\),每次之后要进行交换,即:

  1. 计算 \(a_1\) 时,\(d_N\)\(m_N\)\(x_1\) 在 SRAM 中,此时 SRAM 中有 \(m_N, x_1, d_N, a_1\)

  2. \(x_1\)\(a_1\) 从 SRAM 送出去

  3. 计算 \(a_2\) 时,\(d_N\)\(m_N\)\(x_2\) 在 SRAM 中,此时 SRAM 中有 \(m_N, x_2, d_N, a_2\)

  4. 同理,把 \(x_2\)\(a_2\) 从 SRAM 送出去

综上所述,每一轮里,基本都是从 SRAM 中读取,计算完之后再送出去,缓存友好性很差,甚至可以说基本没有。

两个循环的情况

对于第一个 for 循环,即:

  1. 计算 \(m_1\)\(d_1\) 时,\(m_0\)\(x_1\) 在 SRAM 中,此时 SRAM 中有 \(m_0, x_1, m_1, d_1\)

  2. \(m_0\)\(x_1\) 从 SRAM 送出去

  3. 计算 \(m_2\)\(d_2\) 时,\(m_1\)\(x_2\) 在 SRAM 中,此时 SRAM 中有 \(m_1, x_2, m_2, d_2\)

  4. \(m_1\)\(x_2\) 从 SRAM 送出去

对于第二个 for 循环,和上面三个循环中的最后一个是一样的。

所以总体来看,在两个循环的情况下,我们也是每次都要从 SRAM 中读取,计算完之后再送出去。但由于只需要两个循环,所以和 SRAM 交换的次数会少很多。

关键点就在于第一个循环做了合并,从而把 \(m_i\)\(d_i\) 的计算合并到一起,减少了和 SRAM 的交换次数。

FlashAttention V1

从这一小节开始,进入到 FlashAttention 部分。既然可以优化成两个 for 循环(2-pass),我们还能不能优化成 1-pass?遗憾的是,对于 safe-softmax,并不存在这样的 1-pass 算法。但是,Attention 的目标并不是求 softmax,而是求最终的输出 \(O\)

\[ \operatorname{Attention}(Q, K, V) = \operatorname{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V \]

原理和推导

原始计算 Attention 的情况,其实就是多了红框中的部分,其他部分就是上面介绍的 online-softmax 流程。

下面介绍 FlashAttention 的做法。它有点类似 online-softmax,也是定义一个新的变量:

\[ o'_i := \sum_{j=1}^{i} \frac{e^{x_j - m_i}}{d'_i} \, V[j,:] \]

由定义可知:

\[ o'_N = o_N := \sum_{j=1}^{N} \frac{e^{x_j - m_N}}{d'_N} \, V[j,:] \]

并且可以推导出 \(o'_i\)\(o'_{i-1}\) 的关系:

所以,最终可以合并为一个循环:

总结

可以看出,FlashAttention 的优化和 online-softmax 的优化思路很相似。通过一个新的定义,成功把 \(a_i\) 这一步隐去了,因此可以合并为一个循环。

FlashAttention V2

这一部分会很简略,甚至很多地方可能是错的。

V1 的分块操作

显然,正常情况下,对于向量肯定不是一个一个地操作,而是进行分块处理:

上面的 \(b\) 就是分块大小,下面是 FlashAttention V1 的伪代码:

我没有太细究这一部分,意会即可。具体可以看参考文章,里面还有各种细节分析,比如分块大小对速度的具体影响公式等等。(不过我感觉太细节了。)

V2 的优化

FlashAttention V2 其实没有大刀阔斧的优化,更多是在 V1 的基础上做一些细节优化。

减少非 matmul 的冗余计算,增加 Tensor Cores 运算比例

这个不了解,没细看。

增加 seqlen 维度的并行

FA1 中的伪代码是先 load K/V,再 load Q;在 FA2 中,算法调换了循环顺序,先 load Q,再 load K、V。

它的好处就是调整完之后,\(O_i\) 可以直接在外层计算,不像 FA1 中那样需要在内层计算。具体细节其实还有很多,比如上面的操作只会出现在 forward 中,backward 中没有调整顺序。具体细节看参考文章。

更好的 Warp Partitioning 策略,避免 Split-K

这一部分也没有细看,就放一张图吧。具体细节还是看参考文章。

总结

这是一个非常扎实且有效的工作。现在,FlashAttention 也是计算 Attention 的主要实现方式之一。其中的思想很有过去计算机时代的特点:空间很宝贵,因此要尽量提升缓存一致性。平时更普遍的思路往往是空间无所谓,优化计算量才是关键;过去很多节省空间的算法,也基本被遗弃在历史长河中了。但是在显存场景里,还是有很多空间利用的 trick,比较有 old-school 的感觉。

Comments

本文阅读 Loading 本站访问 Loading 访客 Loading