FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
2026-02-14 · 推理引擎 · 论文精读
论文信息
- 作者: Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
- 机构: Stanford University, University at Buffalo
- 发表: NeurIPS 2022
- 链接: arXiv:2205.14135
一句话总结
FlashAttention 提出了一种 IO 感知(IO-Aware) 的精确注意力算法,通过 分块计算(Tiling) 和 核函数融合(Kernel Fusion) 避免在 GPU 高带宽内存(HBM)中实体化巨大的注意力矩阵,将注意力计算的内存复杂度从 (O(N^2)) 降至 (O(N)),同时在墙钟时间上比标准注意力快 2-4 倍。
Introduction:为什么需要 FlashAttention?
1. Transformer 的核心瓶颈:Self-Attention
Transformer 已成为 NLP、CV、语音等领域的基础架构。然而,自注意力机制(Self-Attention)在序列长度 (N) 上具有 (O(N^2)) 的时间和空间复杂度,这成为 Transformer 处理长序列的根本瓶颈。
标准注意力的计算流程:
在标准实现中,这个过程需要:
- 计算 (S = QK^\top),生成一个 (N \times N) 的注意力分数矩阵
- 对 (S) 施加 softmax,得到 (P = \text{softmax}(S))
- 计算输出 (O = PV)
问题在于:中间矩阵 (S) 和 (P) 都是 (N \times N) 的,当序列长度 (N = 8192) 时,仅这两个矩阵就需要约 512 MB 显存(fp32)。这不仅占用大量内存,更关键的是需要频繁在 GPU 的不同存储层级之间搬运数据。
2. GPU 内存层次结构:被忽视的瓶颈
论文指出,现有工作几乎都以 FLOP 数(浮点运算次数) 作为优化目标,但这忽略了一个关键事实:现代 GPU 的计算速度远超内存读写速度。
以 A100 GPU 为例:
| 存储层级 | 容量 | 带宽 |
|---|---|---|
| SRAM(片上缓存,每个 SM) | ~20 MB(总计) | ~19 TB/s |
| HBM(高带宽内存) | 40/80 GB | ~2 TB/s |
可以看到,SRAM 的带宽是 HBM 的近 10 倍,但容量却小得多。标准注意力实现的问题在于:
标准 Attention 的内存访问模式:
HBM (慢) SRAM (快) 计算单元
┌──────────┐ 读取 Q,K ┌──────────┐ 矩阵乘 ┌──────────┐
│ Q, K, V │ ──────────► │ │ ──────────► │ S=QK^T │
│ │ │ │ │ │
│ S (N²) │ ◄────────── │ 结果 S │ ◄────────── │ │
│ │ 写回 S │ │ │ │
│ │ 读取 S │ │ softmax │ │
│ P (N²) │ ◄────────── │ 结果 P │ ◄────────── │ P=sm(S) │
│ │ 写回 P │ │ │ │
│ │ 读取 P,V │ │ 矩阵乘 │ │
│ O │ ◄────────── │ 结果 O │ ◄────────── │ O=PV │
└──────────┘ 写回 O └──────────┘ └──────────┘
问题:S 和 P 各 N² 大小,反复在 HBM ↔ SRAM 之间搬运!核心洞察:注意力计算其实是 内存带宽受限(Memory-bound) 的操作,而非计算受限。瓶颈不是"算不过来",而是"数据搬不过来"。
3. 现有的近似注意力方法的困境
为了突破 (O(N^2)) 的限制,研究社区提出了大量的 近似注意力(Approximate Attention) 方法,包括:
- 稀疏注意力(Sparse Attention):只计算部分位置对的注意力(如 Longformer、BigBird)
- 低秩近似(Low-rank Approximation):用低秩矩阵近似完整注意力矩阵(如 Linformer、Performer)
- 线性注意力(Linear Attention):通过核方法将 softmax 近似为可分解形式,实现线性复杂度
然而,论文指出这些方法存在两个共性问题:
- 精度损失:近似方法在长序列上经常出现质量退化,尤其是在需要精确建模长距离依赖的任务中
- 墙钟时间并未真正加速:虽然 FLOP 数降低了,但由于这些方法往往引入了更多的内存访问开销(如稀疏索引、额外的矩阵变换),在实际 GPU 上跑起来并没有标准注意力快。论文中的实验表明,很多近似方法在序列长度达到 512-2048 之前甚至比标准注意力更慢
一个反直觉的事实
减少 FLOP ≠ 减少运行时间。在 GPU 上,如果一个算法 FLOP 更少但内存访问更多,它完全可能比 FLOP 更多但内存访问模式更优的算法更慢。这就是 FlashAttention 的出发点。
4. FlashAttention 的核心思路
FlashAttention 不走近似路线,而是从 IO 复杂度(IO Complexity) 的角度重新审视标准注意力,通过优化内存访问模式来实现加速,同时保持结果的 数值精确性。
核心策略包括两点:
(1)分块计算(Tiling):将 Q、K、V 分成小块,每次只加载一小块到 SRAM 中进行计算,避免实体化完整的 (N \times N) 注意力矩阵。
(2)在线 Softmax(Online Softmax):传统 softmax 需要先遍历整行求最大值和求和,再做归一化——这要求整行数据同时在内存中。FlashAttention 采用了 Milakov & Gimelshein (2018) 提出的在线 softmax 技巧,在分块流式处理的过程中 增量更新 softmax 统计量(running max 和 running sum),无需回头修正。
FlashAttention 的内存访问模式:
HBM (慢) SRAM (快) 计算单元
┌──────────┐ 读取 Q块, ┌──────────┐ 一次性 ┌──────────┐
│ Q, K, V │ ──────────► │ Q块,K块, │ ──────────► │ 分块计算 │
│ │ K块, V块 │ V块 │ │ S块→P块→ │
│ │ │ │ 融合计算 │ O块累加 │
│ O │ ◄────────── │ O块 │ ◄────────── │ │
└──────────┘ 只写最终O └──────────┘ └──────────┘
优势:
- 中间矩阵 S、P 从不写回 HBM
- HBM 读写量从 O(N²) 降至 O(N²d²M⁻¹)(M 为 SRAM 大小)
- 结果与标准注意力完全一致(精确算法)5. 论文的主要贡献
论文总结了以下关键贡献:
FlashAttention 算法:一种 IO 感知的精确注意力实现,通过 Tiling 和在线 Softmax 将 HBM 访问量减少为 (O(N^2 d^2 M^{-1})),其中 (d) 是头维度、(M) 是 SRAM 大小。论文还证明了在所有精确注意力算法中,这是 HBM 访问次数的渐近最优下界
Kernel 融合的扩展:将 FlashAttention 扩展到支持常用的注意力变体,包括 带 Mask 的注意力(如因果掩码)和 Dropout,这些操作都在同一个 CUDA Kernel 中完成,避免了额外的内存读写
长序列建模的实际收益:基于 FlashAttention 的高效实现,论文展示了在多个基准任务上的显著提升:
- GPT-2 训练速度提升至标准 HuggingFace 实现的 3 倍
- 支持的序列长度从 1K-2K 拓展到 4K-16K,使 Transformer 首次在长文档分类(如 MIMIC-III)和长序列生成任务上取得 SOTA 表现
- Path-X(16K 序列长度的合成任务)上首次达到 超过随机水平的准确率
IO 复杂度的理论分析:论文给出了精确注意力的 HBM 访问下界证明,并分析了常见近似/稀疏注意力的 IO 复杂度,为后续注意力优化研究提供了理论基础
为什么叫"Flash"?
Flash 一语双关:既指速度极快(如闪存 Flash Memory),也暗示了算法的核心思想——像闪存一样 感知和优化 IO 访问模式,让数据在正确的存储层级被高效处理。
标准注意力实现:Algorithm 0
在深入 FlashAttention 之前,我们必须先彻底理解"标准实现到底做了什么、代价几何"。论文将其命名为 Algorithm 0,作为后续优化的基线。
形式化定义
给定输入矩阵 (Q, K, V \in \mathbb{R}^{N \times d})((N) 为序列长度,(d) 为头维度),标准注意力的计算分为三步:
其中 softmax 按行应用。
Algorithm 0 的执行流程
论文给出的标准实现伪代码如下:
Algorithm 0: Standard Attention Implementation
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
输入: Q, K, V ∈ R^{N×d},存储在 HBM 中
Step 1: 从 HBM 分块加载 Q, K,计算 S = QK^T,将 S 写回 HBM
Step 2: 从 HBM 读取 S,计算 P = softmax(S),将 P 写回 HBM
Step 3: 从 HBM 分块加载 P 和 V,计算 O = PV,将 O 写回 HBM
返回: O逐步拆解 HBM 访问量
我们来精确计算每一步对 HBM 的读写量:
| 步骤 | 操作 | HBM 读取 | HBM 写入 | 说明 |
|---|---|---|---|---|
| Step 1 | (S = QK^\top) | (Q): (Nd) + (K): (Nd) | (S): (N^2) | 矩阵乘法,结果必须写回 HBM |
| Step 2 | (P = \text{softmax}(S)) | (S): (N^2) | (P): (N^2) | 逐元素操作,但 S 太大放不进 SRAM |
| Step 3 | (O = PV) | (P): (N^2) + (V): (Nd) | (O): (Nd) | 矩阵乘法 |
| 合计 | — | (3Nd + 2N^2) | (Nd + 2N^2) | 总 HBM 访问 = (4Nd + 4N^2) |
关键观察:当 (N \gg d) 时(GPT-2: (N = 1024, d = 64)),HBM 访问量被 (4N^2) 项主导,即 (O(N^2))。
而中间矩阵 (S) 和 (P) 各占 (N^2) 的空间——它们的唯一作用是作为"中转站":被写入 HBM 后马上又被读出来。这种"写了就读、读了就扔"的模式是巨大的浪费。
为什么标准实现是 Memory-bound?
我们可以算一笔帐,用 算术强度(Arithmetic Intensity) 来判断操作是计算密集还是内存密集:
对于 Step 2(softmax):
- FLOPs:(O(N^2))(每个元素做 exp、加法、除法)
- HBM 访问:读 (N^2) + 写 (N^2) = (2N^2) 个元素 = (8N^2) bytes(fp32)
- 算术强度 ≈ (O(1))
A100 的算术强度平衡点约为 (\frac{312 \text{ TFLOPS}}{2 \text{ TB/s}} = 156) FLOP/byte。softmax 的算术强度远低于这个值,因此它是一个典型的 内存带宽受限 操作——GPU 的计算单元大部分时间在"等数据"。
即使矩阵乘法也受影响
Step 1 和 Step 3 的矩阵乘法本身是计算密集型操作,但因为 (S) 和 (P) 必须经过 HBM 这个"中转站",整个流水线的实际吞吐被 Step 2 的内存瓶颈拖慢了。Fuse 不掉 softmax,前后的 matmul 也快不起来。
Masking 和 Dropout 雪上加霜
论文特别指出,实际应用中注意力矩阵上还要叠加额外的逐元素操作:
- Masking:因果注意力需要对 (S) 施加下三角掩码,即 (S_{ij} = -\infty) for (j > i)
- Dropout:训练时对 (P) 施加随机置零
每增加一个逐元素操作,就多一轮 (N^2) 的 HBM 读写。虽然社区已尝试将 masking 和 softmax 融合进同一个 kernel,但 只要 (S) 和 (P) 还存在于 HBM 中,根本问题就没有解决。
动手试一试
下面的 C++ 代码完整模拟了 Algorithm 0 的三步流程,并 精确追踪每一步的 HBM 读写量。运行后你可以直观看到:中间矩阵 (S) 和 (P) 如何主导了内存访问开销。
小结
标准注意力实现的核心问题可以用一句话概括:
中间矩阵 (S) 和 (P) 是 (N^2) 大小的"一次性中转站",它们被写入 HBM 后马上被读出,读出后再也不用——但它们主导了整个算法的内存访问开销。
FlashAttention 的解法正是从这里出发:如果我们能在 SRAM 中分块完成 (S \to P \to O) 的全部计算,让 (S) 和 (P) 永远不触碰 HBM,那么 HBM 的访问量就可以从 (O(N^2)) 大幅下降。但这要求我们解决一个技术难题:softmax 是全局操作,如何在只看到一小块数据的情况下正确计算? 这就是下一节 FlashAttention 算法的核心挑战。
FlashAttention 算法:Tiling + Online Softmax
上一节我们看到,标准注意力的瓶颈在于中间矩阵 (S, P \in \mathbb{R}^{N \times N}) 必须在 HBM 中实体化。FlashAttention 的目标很明确:在不实体化 (S) 和 (P) 的前提下,精确计算 (\text{softmax}(QK^\top)V)。
这里我们只讨论前向传播(Forward Pass),反向传播的细节见论文 Appendix B。
核心挑战:Softmax 是"全局操作"
分块计算矩阵乘法很简单——把大矩阵切成小块逐块相乘再累加就行。但 softmax 不同:
分母是对 整行 求和,这意味着要计算第 (i) 行的 softmax,你需要知道 (S) 的整行 (N) 个值。如果我们把 (K) 分成多个块,每次只算出 (S) 的一部分列,怎么做 softmax?
FlashAttention 的答案是:在线 Softmax(Online Softmax)——一种增量式的分块 softmax 算法。
分块 Softmax 的数学推导
单块的 Softmax 统计量
对于一个向量 (x \in \mathbb{R}^{B}),数值稳定的 softmax 需要三个统计量:
最终 (\text{softmax}(x) = \frac{f(x)}{\ell(x)})。
两块合并:关键递推公式
现在假设我们有两个分块 (x^{(1)}, x^{(2)} \in \mathbb{R}^{B}),要计算拼接向量 (x = [x^{(1)}, x^{(2)}] \in \mathbb{R}^{2B}) 的 softmax。核心递推公式为:
直觉:当新块的最大值更大时,旧块的 exp 值需要"缩小"(乘以 (e^{m_{\text{old}} - m_{\text{new}}} < 1));当旧块最大值更大时,新块被缩小。这个缩放因子保证了数值的正确性。
关键洞察
这个递推公式意味着:我们可以 逐块处理 K 的列方向,每处理一块就更新 (m) 和 (\ell),最终得到的 softmax 结果与一次性处理整行 完全一致——没有任何近似!
从 Softmax 到 Attention 输出的增量更新
分块 softmax 解决了归一化的问题,但注意力的最终输出是 (O = PV),即 softmax 的结果还要和 (V) 做矩阵乘法。我们不能先算完所有 softmax 再乘 (V)——那又回到了实体化 (P) 的老路。
FlashAttention 的做法是 边算 softmax 边累加 (O)。当处理第 (j) 个 K/V 块时:
- 计算当前块的局部注意力分数 (\tilde{S}_{ij} = Q_i K_j^\top)
- 计算局部统计量 (\tilde{m}{ij}, \tilde{P}, \tilde{\ell}_{ij})
- 更新全局统计量 (m_i^{\text{new}}, \ell_i^{\text{new}})
- 修正并累加输出:
这个公式的含义是:
- 旧输出修正:之前累加的 (O_i) 是基于旧的最大值 (m_i) 计算的,现在最大值更新为 (m_i^{\text{new}}),需要乘以修正因子 (e^{m_i - m_i^{\text{new}}})
- 新块贡献:当前块的 softmax 值乘以 (V_j),同样调整到新的最大值尺度
- 重新归一化:除以新的 (\ell_i^{\text{new}}) 确保概率和为 1
Algorithm 1:FlashAttention 完整算法
将上述思路系统化,就得到了论文中的 Algorithm 1:
Algorithm 1: FlashAttention
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
输入: Q, K, V ∈ R^{N×d} (在 HBM 中), 片上 SRAM 大小 M
1: 设置块大小 Bc = ⌈M/(4d)⌉, Br = min(⌈M/(4d)⌉, d)
2: 在 HBM 中初始化 O = 0^{N×d}, ℓ = 0^N, m = (-∞)^N
3: 将 Q 分为 Tr = ⌈N/Br⌉ 个块: Q₁, ..., Q_Tr (每块 Br×d)
将 K, V 分为 Tc = ⌈N/Bc⌉ 个块: K₁,...,K_Tc 和 V₁,...,V_Tc (每块 Bc×d)
4: 将 O 分为 Tr 个块, ℓ 和 m 也分为 Tr 个块
5: for j = 1 to Tc do ← 外层循环: 遍历 K/V 块
6: 从 HBM 加载 Kⱼ, Vⱼ 到 SRAM
7: for i = 1 to Tr do ← 内层循环: 遍历 Q 块
8: 从 HBM 加载 Qᵢ, Oᵢ, ℓᵢ, mᵢ 到 SRAM
9: 在 SRAM 中计算 Sᵢⱼ = Qᵢ Kⱼᵀ ∈ R^{Br×Bc}
10: 在 SRAM 中计算:
m̃ᵢⱼ = rowmax(Sᵢⱼ) ∈ R^{Br}
P̃ᵢⱼ = exp(Sᵢⱼ - m̃ᵢⱼ) ∈ R^{Br×Bc}
ℓ̃ᵢⱼ = rowsum(P̃ᵢⱼ) ∈ R^{Br}
11: 在 SRAM 中计算:
mᵢⁿᵉʷ = max(mᵢ, m̃ᵢⱼ)
ℓᵢⁿᵉʷ = e^{mᵢ - mᵢⁿᵉʷ} · ℓᵢ + e^{m̃ᵢⱼ - mᵢⁿᵉʷ} · ℓ̃ᵢⱼ
12: 写回 HBM:
Oᵢ ← diag(ℓᵢⁿᵉʷ)⁻¹ (diag(ℓᵢ)·e^{mᵢ-mᵢⁿᵉʷ}·Oᵢ + e^{m̃ᵢⱼ-mᵢⁿᵉʷ}·P̃ᵢⱼ·Vⱼ)
13: 写回 HBM: ℓᵢ ← ℓᵢⁿᵉʷ, mᵢ ← mᵢⁿᵉʷ
14: end for
15: end for
16: 返回 O循环结构的直觉理解
算法采用 双层循环,可以用下面的图来理解:
K₁ K₂ K₃ K₄ (Tc 个 K/V 块, 外层循环 j)
┌─────┬─────┬─────┬─────┐
Q₁ │ S₁₁ │ S₁₂ │ S₁₃ │ S₁₄ │ ← 第 i=1 轮: 逐块更新 O₁
├─────┼─────┼─────┼─────┤
Q₂ │ S₂₁ │ S₂₂ │ S₂₃ │ S₂₄ │ ← 第 i=2 轮: 逐块更新 O₂
├─────┼─────┼─────┼─────┤
Q₃ │ S₃₁ │ S₃₂ │ S₃₃ │ S₃₄ │ ← 第 i=3 轮: 逐块更新 O₃
└─────┴─────┴─────┴─────┘
每个小块 Sᵢⱼ 的大小只有 Br × Bc, 完全放得进 SRAM!
整个 N×N 的 S 矩阵从未被完整构造出来.外层遍历 K/V 块(列方向),内层遍历 Q 块(行方向)。对于每个 ((i, j)) 组合:
- 计算一个小块 (S_{ij} \in \mathbb{R}^{B_r \times B_c})(在 SRAM 中,不写回 HBM)
- 更新第 (i) 行的 softmax 统计量和输出 (O_i)
块大小的选择
Algorithm 1 第 1 行给出了块大小的设定:
这是为了保证 每次循环需要的数据都能放进 SRAM。一次内层迭代需要同时在 SRAM 中保存:
- (K_j): (B_c \times d)
- (V_j): (B_c \times d)
- (Q_i): (B_r \times d)
- (O_i): (B_r \times d)
- (S_{ij}): (B_r \times B_c)(中间计算结果)
总计约 (2B_c d + 2B_r d + B_r B_c) 个浮点数,需要不超过 (M) 个元素。
重计算(Recomputation):反向传播的优化
标准反向传播需要保存 (S, P \in \mathbb{R}^{N \times N}) 用于梯度计算。FlashAttention 的策略是:只保存 (O)、(m)、(\ell),在反向传播时重新计算 (S) 和 (P)。
| 方案 | 需要保存 | 额外内存 | 代价 |
|---|---|---|---|
| 标准反向传播 | (S, P \in \mathbb{R}^{N \times N}) | (O(N^2)) | 无 |
| 梯度检查点 | 不保存,全部重算 | (O(N)) | 速度慢(重复计算) |
| FlashAttention | (O, m, \ell) | (O(N)) | 反而更快(减少 HBM 访问) |
这看似矛盾——重新计算不是增加了 FLOP 吗?确实,FlashAttention 的总 FLOP 略多于标准方法。但由于重计算发生在 SRAM 中(带宽 ~19 TB/s),而标准方法需要从 HBM 中读取 (S) 和 (P)(带宽 ~2 TB/s),减少的 HBM 访问远比多出的 FLOP 划算。
反直觉的结论
更多的 FLOP + 更少的 HBM 访问 = 更快的实际速度。这再次印证了 FlashAttention 的核心哲学:在 memory-bound 场景下,优化 IO 比优化计算更重要。
Kernel 融合:一个 CUDA Kernel 搞定一切
Tiling 使得 FlashAttention 可以将所有计算步骤融合到 一个 CUDA Kernel 中:
一个 FlashAttention Kernel 内部的完整流程:
┌─────────────────────────────────────────────┐
│ CUDA Kernel (一次启动, 一次执行) │
│ │
│ 1. 从 HBM 加载 Q块, K块, V块 → SRAM │
│ 2. SRAM 中: S块 = Q块 × K块ᵀ │
│ 3. SRAM 中: Masking (可选, 因果掩码) │
│ 4. SRAM 中: 在线 Softmax (更新 m, ℓ) │
│ 5. SRAM 中: Dropout (可选, 训练时) │
│ 6. SRAM 中: O块 累加 = P̃块 × V块 │
│ 7. 写回 O块, m, ℓ → HBM │
│ │
│ 中间结果 S块, P̃块 全程留在 SRAM, 从不触碰 HBM │
└─────────────────────────────────────────────┘相比标准实现需要多个 Kernel(matmul → mask → softmax → dropout → matmul),每个 Kernel 之间都要经过 HBM 中转,FlashAttention 只有 一次 HBM 读入 + 一次 HBM 写出。
正确性与复杂度(Theorem 1)
论文给出了严格证明(详见 Appendix C):
Theorem 1: Algorithm 1 返回 (O = \text{softmax}(QK^\top)V),所需 FLOPs 为 (O(N^2 d)),额外内存(输入输出之外)为 (O(N))。
- 精确性:输出与标准注意力 完全一致(不是近似),因为在线 softmax 的递推公式是数学恒等式
- FLOPs:与标准注意力相同,都是 (O(N^2 d))——FlashAttention 没有减少计算量
- 额外内存:只需要 (O(N)) 存储 (m) 和 (\ell)(而非标准方法的 (O(N^2)) 存储 (S) 和 (P))
动手试一试
下面的 C++ 代码完整实现了 Algorithm 1,你可以对比上一节的 Algorithm 0,观察两者的输出 完全一致,但 HBM 访问量大幅下降。
小结
FlashAttention 的 Algorithm 1 通过三个关键技术实现了 HBM 访问量的大幅下降:
| 技术 | 解决的问题 | 效果 |
|---|---|---|
| Tiling(分块) | (S, P) 太大放不进 SRAM | 分成 (B_r \times B_c) 的小块,每块在 SRAM 中完成全部计算 |
| Online Softmax(在线 Softmax) | Softmax 需要全局归一化 | 通过维护 (m, \ell) 统计量增量更新,结果精确一致 |
| Recomputation(重计算) | 反向传播需要 (S, P) | 只保存 (O, m, \ell),反向时重新计算,减少 (O(N^2) \to O(N)) 额外内存 |
最终效果:FLOPs 不变(甚至略多),但 HBM 访问从 (O(N^2)) 降至 (O(N^2 d^2 M^{-1}))——在 memory-bound 场景下,这直接转化为 2-4 倍的墙钟加速。
分析:FlashAttention 的 IO 复杂度
上一节我们展示了 FlashAttention 算法的具体实现,但一个关键问题尚未严格回答:FlashAttention 的 HBM 访问量到底是多少?这个量是最优的吗?
论文在 Section 3.2 给出了两个核心理论结果:
- Theorem 2(上界):Algorithm 1 的 HBM 访问量为 (O(N^2 d^2 M^{-1}))
- Proposition 3(下界):任何精确注意力算法的 HBM 访问量不低于 (\Omega(N^2 d^2 M^{-1}))
两者匹配,说明 FlashAttention 在 HBM 访问次数上是 渐近最优 的。
Theorem 2:HBM 访问量的上界
Theorem 2: 设 (N) 为序列长度,(d) 为头维度,(M) 为 SRAM 大小(以元素计),且 (d \leq M \leq Nd)。Algorithm 1(FlashAttention)的 HBM 访问量为 (O(N^2 d^2 M^{-1}))。
证明方法很直接:逐一统计 Algorithm 1 每一步的 HBM 读写量,然后求和。
逐步推导
回顾 Algorithm 1 的块大小设定:
对应的块数为:
现在统计每一层循环的 HBM 访问量:
外层循环(遍历 K/V 块,共 (T_c) 次):
- 每次加载 (K_j):(B_c \times d) 个元素
- 每次加载 (V_j):(B_c \times d) 个元素
- 小计:(2 B_c d) 次读取
内层循环(遍历 Q 块,每个外层迭代执行 (T_r) 次):
- 读取 (Q_i, O_i):(2 B_r d) 次读取
- 读取 (\ell_i, m_i):(2 B_r) 次读取
- 写回 (O_i):(B_r d) 次写入
- 写回 (\ell_i, m_i):(2 B_r) 次写入
- 小计:(3 B_r d + 4 B_r) 次访问
总 HBM 访问量:
主导项为内层循环部分:(T_c \cdot T_r \cdot 3 B_r d)。接下来分两种情况讨论:
情况 1:(M \leq 4d^2)(SRAM 较小)
此时 (B_r = B_c = \frac{M}{4d}),两个块大小相同。
主导项:
情况 2:(M > 4d^2)(SRAM 较大)
此时 (B_r = d),(B_c = \frac{M}{4d})。
主导项:
两种情况均给出 (\Theta!\left(\frac{N^2 d^2}{M}\right))。
直觉理解:这个公式在说什么?
(O(N^2 d^2 M^{-1})) 这个表达式可以拆解理解:
| 因子 | 含义 |
|---|---|
| (N^2) | 注意力矩阵的"逻辑大小"——我们必须计算所有 (N^2) 个 token-pair 的交互 |
| (d^2) | 每对 token 的交互需要 (d) 维向量的内积;分块时每块还需要独立加载 (d) 维数据 |
| (M^{-1}) | SRAM 越大,每次能处理的数据块越大,需要的"轮次"越少 |
我们可以检查两个边界条件是否合理:
| 条件 | (M) 的值 | HBM 访问量 | 含义 |
|---|---|---|---|
| SRAM 极小 | (M = d^2) | (N^2) | 退化为标准注意力,和 Algorithm 0 一样 |
| SRAM 极大 | (M = Nd) | (Nd) | 只需读一遍输入,无需分块 |
这正好覆盖了从"完全放不下"到"完全放得下"的整个 SRAM 容量谱。
Proposition 3:下界——FlashAttention 是最优的
Proposition 3: 设 (N \geq d),(M) 满足 (d \leq M \leq Nd)。则任何计算精确注意力的算法,其 HBM 访问量不低于 (\Omega(N^2 d^2 M^{-1}))。
这意味着 FlashAttention 的 HBM 访问量不仅仅是"足够好"——它已经达到了 理论最优下界,不可能再进一步减少(在渐近意义上)。
证明思路
下界证明基于 矩阵乘法的 IO 复杂度下界(参考 Hong & Kung 1981 的"红蓝石子博弈"框架):
规约到矩阵乘法:注意力计算必须(显式或隐式地)完成 (S = QK^\top) 这一步——这是一个 ((N \times d) \times (d \times N)) 的矩阵乘法
矩阵乘法的 IO 下界:将 (m \times k) 矩阵与 (k \times n) 矩阵相乘,在 SRAM 大小为 (M) 的两级存储模型下,HBM 访问量的下界为:
- 应用到注意力:对于 (S = QK^\top),(m = N, k = d, n = N):
- 注意力比纯矩阵乘法更难:但注意力不仅要计算 (S),还要对 (S) 做 softmax 再乘 (V)。softmax 是逐行操作,要求同一行的所有元素在同一时刻可达。论文通过更精细的分析证明,这个额外约束将下界提升至 (\Omega(N^2 d^2 M^{-1}))
为什么下界比纯矩阵乘法更高?
纯矩阵乘法允许任意顺序计算输出元素,但 softmax 引入了 行内全局依赖:你必须看完 (S) 的整行才能做归一化。这迫使算法在 K/V 块的方向上"多扫几遍",每多扫一遍就多一轮 HBM 读写。这个额外约束使 IO 下界从 (\frac{N^2d}{\sqrt{M}}) 提升到 (\frac{N^2d^2}{M})。
上下界匹配的意义
将 Theorem 2 和 Proposition 3 放在一起:
上下界完全匹配:FlashAttention 是渐近最优的精确注意力算法。
这是一个非常强的理论保证——它告诉我们:
- 不要再找更好的:在这个计算模型下,没有任何精确注意力算法能比 FlashAttention 做更少的 HBM 访问
- 要想进一步加速,必须换赛道:只能通过增加 SRAM(硬件改进)、使用近似注意力(允许误差)、或者改变注意力模式(如稀疏)来突破
- 硬件设计的指导意义:(M^{-1}) 的依赖关系表明,增加 SRAM 容量对注意力计算有直接的性能收益
动手试一试
下面的代码让你直观感受 SRAM 大小 (M) 如何影响 HBM 访问量,并验证理论公式 (O(N^2 d^2 / M)) 的准确性。
与近似注意力方法的 IO 对比
论文还比较了 FlashAttention 与常见近似注意力方法的 IO 复杂度。虽然近似方法的 FLOPs 更少,但它们的 IO 模式未必更优:
| 方法 | FLOPs | HBM 访问量 | 精确? |
|---|---|---|---|
| 标准注意力 (Algo 0) | (O(N^2 d)) | (O(N^2 + Nd)) | 是 |
| FlashAttention (Algo 1) | (O(N^2 d)) | (O(N^2 d^2 M^{-1})) | 是 |
| 稀疏注意力 (如 Longformer) | (O(N \cdot S \cdot d)) | (O(N \cdot S + Nd)) | 否 |
| 线性注意力 (如 Performer) | (O(Nd^2)) | (O(Nd)) | 否 |
其中 (S) 是稀疏注意力中每个 token 关注的邻域大小。
一个关键观察
当 (M) 足够大(即 (d^2 / M) 足够小)时,FlashAttention 的 HBM 访问量 (N^2 d^2 / M) 可以远小于标准注意力的 (N^2),甚至逼近线性注意力的 (Nd) 量级——同时保持精确计算。这就是 FlashAttention 在实践中能与近似方法竞争甚至胜出的理论基础。
小结
| 理论结果 | 结论 |
|---|---|
| Theorem 2(上界) | FlashAttention 的 HBM 访问量 (\leq O(N^2 d^2 M^{-1})) |
| Proposition 3(下界) | 任何精确注意力算法 (\geq \Omega(N^2 d^2 M^{-1})) |
| 合在一起 | FlashAttention 是渐近最优的 |
核心启示:
- 算法已到极限:在精确注意力的框架内,FlashAttention 的 IO 效率无法被本质性地超越
- SRAM 是关键资源:(M^{-1}) 依赖意味着片上 SRAM 每扩大一倍,HBM 访问量减半。这为硬件设计(更大的 shared memory)和软件优化(更精细的 SRAM 管理)指明了方向
- FLOPs 不是全部:FlashAttention 和标准注意力有相同的 FLOPs,但 IO 效率天差地别。在 memory-bound 场景下,优化 IO 比优化计算更重要
扩展:块稀疏 FlashAttention(Block-Sparse FlashAttention)
上两节我们证明了 FlashAttention 在 精确注意力 的框架下已达到 IO 最优。但在很多实际场景中,我们并不需要每个 token 都关注所有其他 token——稀疏注意力(Sparse Attention) 是一种广泛使用的提升效率的手段。
论文的一个重要贡献是:FlashAttention 的分块架构天然适合与块稀疏模式结合,只需 极少的修改 就能获得 IO 和 FLOPs 的双重收益。
什么是块稀疏注意力?
标准注意力计算完整的 (N \times N) 注意力矩阵。块稀疏注意力则引入一个 块级掩码矩阵:
其中 (T_r = \lceil N / B_r \rceil),(T_c = \lceil N / B_c \rceil)。当 (\tilde{M}{ij} = 0) 时,表示 Q 的第 (i) 块和 K 的第 (j) 块之间的注意力被完全跳过——对应的 (S) 块既不计算,也不参与 softmax。
块稀疏掩码示例 (N=16, Br=Bc=4, 所以 Tr=Tc=4):
K₁ K₂ K₃ K₄
┌────┬────┬────┬────┐
Q₁ │ 1 │ 1 │ 0 │ 0 │ → Q₁ 只关注 K₁, K₂
├────┼────┼────┼────┤
Q₂ │ 1 │ 1 │ 1 │ 0 │ → Q₂ 关注 K₁, K₂, K₃
├────┼────┼────┼────┤
Q₃ │ 0 │ 1 │ 1 │ 1 │ → Q₃ 关注 K₂, K₃, K₄
├────┼────┼────┼────┤
Q₄ │ 0 │ 0 │ 1 │ 1 │ → Q₄ 只关注 K₃, K₄
└────┴────┴────┴────┘
1 = 计算该块 0 = 跳过该块
非零块数: 10/16 = 62.5% → 稀疏度 s = 0.625从 Algorithm 1 到 Block-Sparse FlashAttention
由于 FlashAttention 已经将注意力按 ((B_r, B_c)) 大小分块处理,加入块稀疏支持只需 一行修改:在内层循环中,检查掩码 (\tilde{M}_{ij}),如果为 0 则直接跳过该块。
Algorithm 5: Block-Sparse FlashAttention (基于 Algorithm 1 的修改)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
输入: Q, K, V ∈ R^{N×d}, 块稀疏掩码 M~ ∈ {0,1}^{Tr×Tc}
... (其余初始化与 Algorithm 1 相同) ...
5: for j = 1 to Tc do
6: 从 HBM 加载 Kⱼ, Vⱼ 到 SRAM
7: for i = 1 to Tr do
7.5: if M~[i][j] = 0 then continue ← 唯一的新增行!
8: 从 HBM 加载 Qᵢ, Oᵢ, ℓᵢ, mᵢ 到 SRAM
9-13: (与 Algorithm 1 完全相同)
14: end for
15: end for
返回: O这个"一行修改"看似简单,背后的深意是:FlashAttention 的分块粒度恰好与块稀疏掩码的粒度对齐。不需要任何索引重排、scatter/gather 操作或额外的数据结构——只需跳过零块即可。
常见的稀疏注意力模式
块稀疏框架可以统一表达多种流行的稀疏注意力模式。定义稀疏度 (s) 为非零块的比例:(s = \frac{|{(i,j) : \tilde{M}_{ij} = 1}|}{T_r \times T_c})。
1. 滑动窗口注意力(Local / Sliding Window)
每个 token 只关注前后 (w) 个位置范围内的 token:
滑动窗口 (w=2 个块):
┌─┬─┬─┬─┬─┬─┬─┬─┐
│█│█│█│ │ │ │ │ │ 稀疏度: s = (2w+1)/Tc
│█│█│█│█│ │ │ │ │ 当 w << N 时, s ≈ 2w/N
│█│█│█│█│█│ │ │ │
│ │█│█│█│█│█│ │ │ 典型应用:
│ │ │█│█│█│█│█│ │ - Longformer
│ │ │ │█│█│█│█│█│ - Mistral
│ │ │ │ │█│█│█│█│ - 局部上下文建模
│ │ │ │ │ │█│█│█│
└─┴─┴─┴─┴─┴─┴─┴─┘2. 因果注意力(Causal Attention)
自回归模型的标准模式——每个 token 只能看到自己及之前的 token:
因果掩码 (下三角):
┌─┬─┬─┬─┬─┬─┬─┬─┐
│█│ │ │ │ │ │ │ │ 稀疏度: s ≈ 0.5
│█│█│ │ │ │ │ │ │
│█│█│█│ │ │ │ │ │ 典型应用:
│█│█│█│█│ │ │ │ │ - GPT 系列
│█│█│█│█│█│ │ │ │ - LLaMA
│█│█│█│█│█│█│ │ │ - 所有自回归 LLM
│█│█│█│█│█│█│█│ │
│█│█│█│█│█│█│█│█│
└─┴─┴─┴─┴─┴─┴─┴─┘3. 全局 + 滑动窗口(Global + Local)
少数特殊 token(如 [CLS]、句首)关注所有位置,其余 token 使用滑动窗口:
全局(前 2 行) + 滑动窗口:
┌─┬─┬─┬─┬─┬─┬─┬─┐
│█│█│█│█│█│█│█│█│ ← 全局 token
│█│█│█│█│█│█│█│█│ ← 全局 token
│█│█│█│█│█│ │ │ │
│█│█│█│█│█│█│ │ │ 典型应用:
│█│█│ │█│█│█│█│ │ - BigBird
│█│█│ │ │█│█│█│█│ - Longformer
│█│█│ │ │ │█│█│█│ - ETC
│█│█│ │ │ │ │█│█│
└─┴─┴─┴─┴─┴─┴─┴─┘4. 步幅稀疏(Strided / Dilated)
以固定间隔采样关注位置,捕捉长距离依赖:
步幅稀疏 (stride=2):
┌─┬─┬─┬─┬─┬─┬─┬─┐
│█│ │█│ │█│ │█│ │ 稀疏度: s = 1/stride
│ │█│ │█│ │█│ │█│
│█│ │█│ │█│ │█│ │ 典型应用:
│ │█│ │█│ │█│ │█│ - Sparse Transformer
│█│ │█│ │█│ │█│ │ - 长序列音频建模
│ │█│ │█│ │█│ │█│
│█│ │█│ │█│ │█│ │
│ │█│ │█│ │█│ │█│
└─┴─┴─┴─┴─┴─┴─┴─┘IO 复杂度分析(Proposition 4)
Proposition 4: 设块稀疏掩码的非零块比例为 (s)((0 < s \leq 1))。Block-Sparse FlashAttention 的 HBM 访问量为:
证明直觉:在 Algorithm 1 中,内层循环的迭代次数从 (T_r \times T_c) 减少为 (s \cdot T_r \times T_c)(只处理非零块)。每次迭代的 HBM 访问量不变,因此总量线性地乘以稀疏度 (s)。
与各方案的完整对比:
| 方法 | FLOPs | HBM 访问量 | 精确? | 适用场景 |
|---|---|---|---|---|
| 标准注意力 | (O(N^2 d)) | (O(N^2)) | 是 | 基线 |
| FlashAttention | (O(N^2 d)) | (O(N^2 d^2 / M)) | 是 | 通用加速 |
| Block-Sparse Flash | (O(N^2 d \cdot s)) | (O(N^2 d^2 s / M)) | 在稀疏模式内精确 | 已知稀疏模式 |
| 近似注意力 | 各异 | 各异 | 否 | 特定任务 |
Block-Sparse 的双重收益
注意 Block-Sparse FlashAttention 同时获得了 FLOPs 和 IO 两方面的节省——两者都乘以稀疏度 (s)。这与标准稀疏注意力不同:后者虽然减少了 FLOPs,但由于稀疏索引和 scatter/gather 操作,IO 反而可能更差。FlashAttention 的分块架构让稀疏模式的收益被"干净地"兑现。
为什么 FlashAttention 特别适合块稀疏?
之所以 FlashAttention + 块稀疏如此自然,根本原因是 粒度对齐:
传统稀疏注意力: Block-Sparse FlashAttention:
逐元素稀疏 → 需要稀疏索引 块级稀疏 → 只需 skip 整个块
┌─────────────────┐ ┌─────────────────┐
│ 1 0 1 0 0 1 0 0 │ │ ██ │ │ │ │
│ 0 1 0 0 1 0 0 1 │ │────┼────┼────┼────│
│ 1 0 1 0 0 0 1 0 │ vs │ ██ │ ██ │ │ │
│ 0 0 0 1 0 1 0 0 │ │────┼────┼────┼────│
│ 0 1 0 0 1 0 0 1 │ │ │ ██ │ ██ │ │
│ 1 0 0 1 0 1 0 0 │ │────┼────┼────┼────│
│ 0 0 1 0 0 0 1 0 │ │ │ │ ██ │ ██ │
│ 0 1 0 0 1 0 0 1 │ └─────────────────┘
└─────────────────┘
- 需要 CSR/COO 等稀疏格式 - 掩码只有 Tr×Tc 个 bit
- scatter/gather 内存不连续 - 整块 skip, 内存访问连续
- GPU 利用率低 (warp divergence) - 无 warp divergence
- 额外索引开销抵消稀疏收益 - 稀疏收益被干净兑现动手试一试
下面的代码对比标准 FlashAttention 和 Block-Sparse FlashAttention 在不同稀疏模式下的 HBM 访问量和 FLOPs。
实际影响
Block-Sparse FlashAttention 的实践意义体现在几个方面:
1. 统一框架
之前的稀疏注意力实现各自为政——Longformer 有自己的 CUDA kernel,BigBird 有自己的,Sparse Transformer 又是另一套。Block-Sparse FlashAttention 提供了一个统一的高效实现:只需指定不同的块掩码 (\tilde{M}),就能支持任意块稀疏模式。
2. IO 高效的稀疏注意力
传统稀疏注意力的一个尴尬是:虽然 FLOPs 减少了,但由于稀疏索引和不连续内存访问,实际墙钟时间未必更快(论文实验中很多近似方法在 (N < 2048) 时反而比稠密注意力慢)。Block-Sparse FlashAttention 消除了这个问题——稀疏收益被干净地兑现为墙钟加速。
3. 可组合性
块稀疏掩码可以自由组合。例如:"因果 + 滑动窗口"就是两个掩码的逐元素 AND。这使得用户可以根据任务需求灵活定制注意力模式,而无需为每种组合编写新的 kernel。
小结
| 维度 | 标准 FlashAttention | Block-Sparse FlashAttention |
|---|---|---|
| FLOPs | (O(N^2 d)) | (O(N^2 d \cdot s)) |
| HBM 访问 | (O(N^2 d^2 / M)) | (O(N^2 d^2 s / M)) |
| 代码改动 | 基线 | 内层循环加一个 if |
| 支持的模式 | 全注意力 | 任意块稀疏模式 |
| 精确性 | 精确 | 在非零块内精确 |
核心观点:FlashAttention 的分块架构让块稀疏变成了"免费午餐"——几乎零代价地将稀疏模式的理论收益转化为实际加速。这也解释了为什么 FlashAttention 迅速成为几乎所有 Transformer 推理框架的底层注意力实现。
后续章节持续更新中
实验结果分析、FlashAttention-2 的改进等内容将在后续更新中补充。