FlashAttention:从 1 到 3

本文将首先从公式出发,讲解 3-pass safe softmax、 2-pass online softmax 以及 1-pass FlashAttention 的原理。然后结合论文介绍 FlashAttention 1/2/3。

从 saft softmax 到 FlashAttention1

Self-Attention 的计算公式如下(其中,\(O,Q,K,V\) 都是 \((N,d)\) 形状的矩阵,\(N\) 是 sequence length,\(d\) 是 head dimension,为了简单起见,忽略了缩放因子 \(\frac{1}{\sqrt {d}}\)):

\[ O=\mathrm{softmax} (QK^T)V \tag{1} \]

标准的计算流程分为几步:

\[ \begin{align} X & =Q K^{T} \tag{2}\\ A & =\operatorname{softmax}(X) \tag{3}\\ O & =A V \tag{4} \end{align} \]

其中,\(X\) 是 softmax 前的临时值,\(A\) 是注意力分数,\(O\) 是输出。

在 FlashAttention 中,不需要在全局内存中生成 X 和 A,而是将公式 1 的计算融合到一个单一的 cuda kernel 中。

在像矩阵乘法这样的经典算法中,可以使用分块(tiling)的方法,确保片上内存可以容纳下参与计算的数据。分块的前提是,运算满足结合律,即整个矩阵乘法的结果可以分解为许多分块矩阵乘法的和。

但是在 Self-Attention 中,softmax 看上去并不满足结合律,因此无法像矩阵乘法那样分块运算,有什么方法可以让 softmax 也满足结合律吗?

首先来回顾一下 softmax 是如何计算的,下面是标准 softmax 的计算公式:

\[ \mathrm{soft}\max \left( \left\{ x_1,x_2,\dots ,x_N \right\} \right) =\left\{ \frac{e^{x_i}}{\sum\nolimits_{j=1}^N{e^{x_j}}} \right\} _{i=1}^{N} \tag{5} \]

由于 \(e^x\) 非常容易溢出,因此实际应用中,一般会通过一个等价变换变成 safe softmax:

\[ \frac{e^{x_i}}{\sum\nolimits_{j=1}^N{e^{x_j}}}=\frac{e^{x_i-m}}{\sum\nolimits_{j=1}^N{e^{x_j-m}}} \tag{6} \]

其中 \(m=\max _{j=1}^{N}\left( x_j \right)\)

对 safe softmax 的计算包括三次遍历,一般也叫 3-pass safe softmax

\(m_i\) 表示 \(\max _{j=1}^{i}\left( x_j \right)\)\(m_0=-\infty\)\(l_i\) 表示 \(\sum\nolimits_{j=1}^N{e^{x_j-m_N}}\)\(l_0=0\)\(a_i\) 表示最终结果

第一次 pass 得到最大值 \(m_N\)

\[ m_i=\max \left( m_{i-1},x_i \right) \tag{7} \]

第二次 pass 得到分母项 \(l_N\)

\[ l_i=l_{i-1}+e^{x_i-m_N} \tag{8} \]

第三次 pass 得到 \(a_i\)

\[ a_i=\frac{e^{x_i-m_N}}{l_N} \tag{9} \]

看上去公式 8 依赖于 \(m_N\),公式 9 依赖于 \(m_N\)\(l_N\),似乎没有办法减少 pass 的次数?其实是可以的,这个技巧是用 \(l_{i}^{\prime}=\sum\nolimits_{j=1}^i{e^{x_j-m_i}}\) 替换 \(l_i\),它们的最终结果 \(l_{N}^{\prime}\)\(l_N\) 是相同的,因此公式 9 可以简单替换一下写成 \(a_i=\frac{e^{x_i-m_N}}{l_{N}^{\prime}}\)

\(l_{i}^{\prime}\) 的更新方法是:

\[ \begin{align} l_{i}^{\prime}&=\sum_{j=1}^i{e^{x_j-m_i}}\\ &=\left( \sum_{j=1}^{i-1}{e^{x_j-m_i}} \right) +e^{x_i-m_i}\\ &=\left( \sum_{j=1}^{i-1}{e^{x_j-m_{i-1}}} \right) e^{m_{i-1}-m_i}+e^{x_i-m_i}\\ &=l_{i-1}^{\prime}e^{m_{i-1}-m_i}+e^{x_i-m_i}\\ \end{align} \tag{10} \]

可以看出 \(l_{i}^{\prime}\) 的更新就只依赖于 \(m_{i-1}\)\(m_i\) 而不是依赖于 \(m_N\)

这样就可以将公式 7 和公式 8 的两次 pass 合并为一次 pass,这种算法也被称为 2-pass online softmax

第一次 pass 得到 \(m_N\)\(l_N\)

\[ \begin{align} m_i&=\max \left( m_{i-1},x_i \right)\\ l_{i}^{\prime}&=l_{i-1}^{\prime}e^{m_{i-1}-m_i}+e^{x_i-m_i}\\ \end{align} \]

第二次 pass 得到 \(a_i\)

\[ a_i=\frac{e^{x_i-m_N}}{l_{N}^{\prime}} \]

第一次 pass 中,用 \(m_i\)\(l_{i-1}^{\prime}\) 更新的行为在论文中叫做 rescale,FA2 论文中还会提到这一点。

但是这依然需要两次 pass,是否有方法能够实现一次 pass 的 softmax?这实际上是不可能的,但是在 Self-Attention 中,我们的目标并不是注意力分数 \(A\),而是输出 \(O\),而 \(O\) 的一次 pass 是存在的。

下面首先介绍一下 Multi-pass Self-Attention。由于所有行的计算是独立的,为了简单只展示第 k 行的计算。

\(Q[k,:]\) 表示 \(Q\) 的第 k 行,\(K^T[:,i]\) 表示 \(K^T\) 的第 i 列,\(O[k,:]\) 表示 \(O\) 的第 k 行,\(V[i,:]\) 表示 \(V\) 的第 i 行,\(o_i=\sum\nolimits_{j=1}^i{a_jV\left[ j,: \right]}\) 是一个用来存 \(A[k,:i]\times V[:,:i]\) 部分结果的行向量。

第一次 pass,得到 \(m_N\)\(l_N\)

\[ \begin{align} x_i&=Q[k,:]K^T[:,i]\\ m_i&=\max \left( m_{i-1},x_i \right)\\ l_{i}^{\prime}&=l_{i-1}^{\prime}e^{m_{i-1}-m_i}+e^{x_i-m_i}\\ \end{align} \]

第二次 pass,得到 \(\boldsymbol{o}_N\)

\[ \begin{align} a_i&=\frac{e^{x_i-m_N}}{l_{N}^{\prime}} \tag{11}\\ \boldsymbol{o}_i&=\boldsymbol{o}_{i-1}+a_iV[i,:] \tag{12}\\ \end{align} \]

最后将 \(\boldsymbol{o}_N\) 写入到 \(O\)

\[ O[k,:]=\boldsymbol{o}_N \]

将公式 11 代入进公式 12,可以得到:

\[ \boldsymbol{o}_i=\sum_{j=1}^i{\frac{e^{x_j-m_N}}{l_{N}^{\prime}}V[j,:]} \tag{13} \]

下面是 Multi-pass Self-Attention 的示意图。

\(\boldsymbol{o}_N\) 的结果依赖于 \(m_N\)\(l_{N}^{\prime}\),这和之前在 3-pass safe softmax 中遇到的问题很相似,因此解决技巧也是相似的,使用 \(\boldsymbol{o}^{\prime}\) 替换 \(\boldsymbol{o}\)

\[ \boldsymbol{o}_{i}^{\prime}=\left( \sum_{j=1}^i{\frac{e^{x_j-m_i}}{l_{i}^{\prime}}V[j,:]} \right) \]

\(\boldsymbol{o}_N^{\prime}\)\(\boldsymbol{o}_N\) 显然相同,\(\boldsymbol{o}_i^{\prime}\) 的更新公式是:

\[ \begin{align} \boldsymbol{o}_{i}^{\prime}&=\left( \sum_{j=1}^i{\frac{e^{x_j-m_i}}{l_{i}^{\prime}}V[j,:]} \right)\\ &=\left( \sum_{j=1}^{i-1}{\frac{e^{x_j-m_i}}{l_{i}^{\prime}}V[j,:]} \right) +\frac{e^{x_i-m_i}}{l_{i}^{\prime}}V[i,:]\\ &=\left( \sum_{j=1}^{i-1}{\frac{e^{x_j-m_{i-1}}}{l_{i-1}^{\prime}}\frac{e^{x_j-m_i}}{e^{x_j-m_{i-1}}}\frac{l_{i-1}^{\prime}}{l_{i}^{\prime}}V[j,:]} \right) +\frac{e^{x_i-m_i}}{l_{i}^{\prime}}V[i,:]\\ &=\left( \sum_{j=1}^{i-1}{\frac{e^{x_j-m_{i-1}}}{l_{i-1}^{\prime}}V[j,:]} \right) \frac{l_{i-1}^{\prime}}{l_{i}^{\prime}}e^{m_{i-1}-m_i}+\frac{e^{x_i-m_i}}{l_{i}^{\prime}}V[i,:]\\ &=\boldsymbol{o}_{i-1}^{\prime} \frac{l_{i-1}^{\prime}e^{m_{i-1}-m_i}}{l_{i}^{\prime}}+\frac{e^{x_i-m_i}}{l_{i}^{\prime}}V[i,:]\\ \end{align} \tag{14} \]

\(\boldsymbol{o}_i^{\prime}\) 依赖于 \(m_i,m_{i-1},l_{i}^{\prime},l_{i-1}^{\prime}\) 而不是 \(m_N,l_{N}^{\prime}\)。这样所有的计算只需要一次 pass 即可完成,这就是 FlashAttention

\[ \begin{align} x_i&=Q[k,:]K^T[:,i]\\ m_i&=\max \left( m_{i-1},x_i \right)\\ l_{i}^{\prime}&=l_{i-1}^{\prime}e^{m_{i-1}-m_i}+e^{x_i-m_i}\\ \boldsymbol{o}_{i}^{\prime}&=\boldsymbol{o}_{i-1}^{\prime}\frac{l_{i-1}^{\prime}e^{m_{i-1}-m_i}}{l_{i}^{\prime}}+\frac{e^{x_i-m_i}}{l_{i}^{\prime}}V[i,:]\\ \end{align} \tag{15} \]

\[ O[k,:]=\boldsymbol{o}_N^{\prime} \]

\(x_i,m_i,l_{i}^{\prime},\boldsymbol{o}_i^{\prime}\) 的存储占用较小,可以放入 GPU 的共享内存中。并且由于该算法中的所有操作都是满足结合律的,因此与分块(tiling)兼容。

FlashAttention(Tiling)

\(K^T\) 分成多个块,假设 \(b\) 表示 tile 的 block size,\(tiles\) 表示行方向有多少个 tile,\(N=b\times tiles\)\(\boldsymbol{x}_i\) 表示存 \(Q[k]K^T\) 结果的第 i 个 tile \([(i-1)b:ib]\) 的向量,\(m_i^{local}\) 表示 \(x_i\) 中的局部最大值。

\(i\gets 1,tiles\)

\[ \begin{align} \boldsymbol{x}_i&=Q[k,:]K^T[:,\left( i-1 \right) b:ib]\\ m_{i}^{local}&=\underset{j=1}{\overset{b}{\max}}\left( \boldsymbol{x}_i \right)\\ m_i&=\max \left( m_{i-1},m_{i}^{local} \right)\\ l_{i}^{\prime}&=l_{i-1}^{\prime}e^{m_{i-1}-m_i}+\sum_{j=1}^b{e^{\boldsymbol{x}_i\left[ j \right] -m_i}}\\ \boldsymbol{o}_{i}^{\prime}&=\boldsymbol{o}_{i-1}^{\prime}\frac{d_{i-1}^{\prime}e^{m_{i-1}-m_i}}{l_{i}^{\prime}}+\sum_{j=1}^b{\frac{e^{\boldsymbol{x}_i\left[ j \right] -m_i}}{l_{i}^{\prime}}V[j+\left( i-1 \right) b,:]}\\ \end{align} \tag{16} \]

\[ O[k,:]=\boldsymbol{o}_{N/b}^{\prime} \]

上图说明了 FlashAttention 在硬件上的计算方式。蓝色的块表示驻留在 SRAM 中的 tile,而红色的块对应于第 \(k\) 行。L 表示 sequence length,可以相当大(例如16k),d 通常在 Transformers 中较小(例如GPT3的128),b 是可以控制的块大小。值得注意的是,整体 SRAM 内存占用仅依赖于 b 和 d,而与 N 无关。因此,该算法能够扩展到较长的 sequence length 而不会遇到内存问题。在计算过程中,从左到右遍历 \(K^T\)\(A\) ,从上到下遍历 \(V\),并相应地更新 \(m,l,O\) 的状态。

以上证明来自于 《From Online Softmax to FlashAttention》。理解了上述过程,FlashAttention1 论文中的相应证明也就好理解了。万字长文详解FlashAttention v1/v2 - 知乎 这篇文章中详细介绍了论文中的证明。

上图是 FlashAttention1 论文中的算法伪代码。对 \(Q,O\) 沿着行方向分成 \(T_r\) 块,每一块的大小是 \(B_r\times d\),对 \(K,V\) 沿着行方向分为 \(T_c\) 块,每一块的大小为 \(B_c\times d\)

伪代码中带波浪线的是局部值,\(\tilde{m}_{i j}\) 即为下图中的 \(m_{i j}\),其它类似。第 11 行更新 \(l^{new}_i\) 看上去和公式 16 中的不太对应,这是因为在第 10 行,用 \(\tilde{m}_{i j}\) 这个局部值计算得到 \(\tilde{P}_{i j}\),再得到 \(\tilde{l}_{i j}\),所以第 11 要用 \(m^{new}_i\) 去更新 \(\tilde{l}_{i j}\)

下图是伪代码的示意图,虚线上的数字对应伪代码中的行数。[[极简 FlashAttention CUDA 实现]] 这个仓库的代码和为代码能对应上,结合伪代码、代码、示意图,应该就能看懂 FlashAttention1 的过程了。

这里分块参数 \(B_c,B_r\) 的选取和 SRAM 的大小 M 有关,\(B_c=\lceil \frac{M}{4d} \rceil ,B_r=\min \left( \lceil \frac{M}{4d} \rceil ,d \right)\)

这样选择的原因是为了尽量用满 SRAM:

\[ \begin{align} SRAM\left( Q_i \right) &=B_r\times d=\min \left( \lceil \frac{M}{4d} \rceil ,d \right) \times d<\left. \lceil \frac{M}{4} \rceil \right.\\ SRAM\left( O_i \right) &=B_r\times d=\min \left( \lceil \frac{M}{4d} \rceil ,d \right) \times d<\lceil \frac{M}{4} \rceil\\ SRAM\left( K_j,V_j \right) &=2\times B_c\times d=2\times \lceil \frac{M}{4d} \rceil \times d<\lceil \frac{M}{2} \rceil\\ \end{align} \]

FlashAttention 在前向计算时不保留中间结果,但是反向计算又需要这些中间结果,FlashAttention backward pass 的处理方式是重计算,和 forward 一样,将 Q、K、V 分块读取到 SRAM 中,并计算得到当前块的中间结果。对比 Self-Attention 和 FlashAttention 的伪代码:

可以发现,在 FlashAttention backward pass 中,减少了对 \(S,dS,P,dP\) 这些中间结果的全局内存读写。虽然这样增加了计算量的开销,但是减少 IO 开销的收益更大。

IO 分析:

对于 Self-Attention 从算法 3 可知,第一行读 \(Q,K\)\(2Nd\),写 S 是 \(N^2\),第二行读 \(S\)\(N^2\),写 \(P\)\(N^2\),第三行读 \(P\)\(N^2\),读 \(V\)\(Nd\),写 \(O\)\(Nd\)。忽略常数,IO 复杂度是 \(O(Nd+N^2)\)

对于 FA1,内循环读 \(Q\)\(Nd\),外循环决定了内循环的次数,即 \(T_c=\lceil \frac{4dN}{M} \rceil\) 次,忽略常数,IO 复杂度是 \(O(N^2d^2M^{-1})\)

由于 M 远大于 d,所以 FA1 IO 复杂度远小于 Self-Attention。

FlashAttention2

FlashAttention2 的核心思路与 1 相同,但是做了一些工程上的优化。这包括:

减少了非矩阵乘法运算,增加 Tensor Core 使用比例。

在 Tensor Core 的帮助下,理论上时间复杂度更高的矩阵乘可能实际运行起来会那些复杂度低的操作更快更快。比如这个把 FFT 改成矩阵乘的工作 [2311.05908] FlashFFTConv: Efficient Convolutions for Long Sequences with Tensor Cores

在 FlashAttention 1 的 forward pass 中:

而在 FlashAttention 2 中:

区别是,在 2 中,不是迭代计算都执行 rescale 操作,而是在最后执行一次 rescale 操作,这样就可以减少除法运算。以及保留了部分数据用于 backpass。

从公式 15 复制一份过来,然后修改成公式 17:

\[ \begin{align} x_i&=Q[k,:]K^T[:,i]\\ m_i&=\max \left( m_{i-1},x_i \right)\\ l_{i}^{\prime}&=l_{i-1}^{\prime}e^{m_{i-1}-m_i}+e^{x_i-m_i}\\ \boldsymbol{o}_{i}^{\prime}&=\boldsymbol{o}_{i-1}^{\prime}e^{m_{i-1}-m_i}+e^{x_i-m_i}V[i,:]\\ \end{align} \tag{17} \]

\[ O[k,:]=\frac{\boldsymbol{o}_N^{\prime}}{l_{N}^{\prime}} \]

调整了内外循环,Q 为外层循环,KV 为内层循环,减少 HBM 读写。

在 FA1 中,先在外层循环 \(K,V\),然后内层再循环 \(Q\),但是内层循环每个 block 只会用到 \(Q\) 的一部分,这个在下面的示意图上可以明显地看出来。也就是 FA 1 只在 batch_size 和 head 上做并行,这样当 batch_size 比较小的时候,可能会导致启动的 block 太少,占不满 GPU 的 SM 计算资源。

FA2 调整了循环顺序, 外层是 \(Q\) 循环,这样就可以在 seq len 上做并行,提高计算资源的使用率。

调整 warp 划分策略。

FA1 中,在计算 QK^T 的结果与 V 的乘积时,需要跨 warp 做 reduction,这样 warp 之间就需要通信,所以 FA2 中调整了 warp 的划分策略,每个 warp 间的计算是无关的,减少同步开销。

FA2 的代码实现可以参考 [[FlashAttention CUTLASS 实现]]

FlashAttention3

FlashAttention 3 针对 Hopper 架构做了针对性的优化,包括 warp-specialized 的生产者-消费者流水线、将 softmax 计算与矩阵乘的异步 WGMMA 计算重叠,FP8 低精度加速。