Skip to content
/

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

2026-02-14  ·  推理引擎  ·  论文精读

论文信息

  • 作者: Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
  • 机构: Stanford University, University at Buffalo
  • 发表: NeurIPS 2022
  • 链接: arXiv:2205.14135

一句话总结

FlashAttention 提出了一种 IO 感知(IO-Aware) 的精确注意力算法,通过 分块计算(Tiling)核函数融合(Kernel Fusion) 避免在 GPU 高带宽内存(HBM)中实体化巨大的注意力矩阵,将注意力计算的内存复杂度从 (O(N^2)) 降至 (O(N)),同时在墙钟时间上比标准注意力快 2-4 倍


Introduction:为什么需要 FlashAttention?

1. Transformer 的核心瓶颈:Self-Attention

Transformer 已成为 NLP、CV、语音等领域的基础架构。然而,自注意力机制(Self-Attention)在序列长度 (N) 上具有 (O(N^2)) 的时间和空间复杂度,这成为 Transformer 处理长序列的根本瓶颈。

标准注意力的计算流程

Attention(Q,K,V)=softmax(QKd)V

在标准实现中,这个过程需要:

  1. 计算 (S = QK^\top),生成一个 (N \times N) 的注意力分数矩阵
  2. 对 (S) 施加 softmax,得到 (P = \text{softmax}(S))
  3. 计算输出 (O = PV)

问题在于:中间矩阵 (S) 和 (P) 都是 (N \times N) 的,当序列长度 (N = 8192) 时,仅这两个矩阵就需要约 512 MB 显存(fp32)。这不仅占用大量内存,更关键的是需要频繁在 GPU 的不同存储层级之间搬运数据。

2. GPU 内存层次结构:被忽视的瓶颈

论文指出,现有工作几乎都以 FLOP 数(浮点运算次数) 作为优化目标,但这忽略了一个关键事实:现代 GPU 的计算速度远超内存读写速度

以 A100 GPU 为例:

存储层级容量带宽
SRAM(片上缓存,每个 SM)~20 MB(总计)~19 TB/s
HBM(高带宽内存)40/80 GB~2 TB/s

可以看到,SRAM 的带宽是 HBM 的近 10 倍,但容量却小得多。标准注意力实现的问题在于:

标准 Attention 的内存访问模式:

  HBM (慢)                    SRAM (快)                  计算单元
┌──────────┐   读取 Q,K    ┌──────────┐   矩阵乘    ┌──────────┐
│  Q, K, V │ ──────────►  │          │ ──────────► │  S=QK^T  │
│          │              │          │             │          │
│  S (N²)  │ ◄────────── │  结果 S  │ ◄────────── │          │
│          │   写回 S     │          │             │          │
│          │   读取 S     │          │   softmax   │          │
│  P (N²)  │ ◄────────── │  结果 P  │ ◄────────── │  P=sm(S) │
│          │   写回 P     │          │             │          │
│          │   读取 P,V   │          │   矩阵乘    │          │
│  O       │ ◄────────── │  结果 O  │ ◄────────── │  O=PV    │
└──────────┘   写回 O     └──────────┘             └──────────┘

问题:S 和 P 各 N² 大小,反复在 HBM ↔ SRAM 之间搬运!

核心洞察:注意力计算其实是 内存带宽受限(Memory-bound) 的操作,而非计算受限。瓶颈不是"算不过来",而是"数据搬不过来"。

3. 现有的近似注意力方法的困境

为了突破 (O(N^2)) 的限制,研究社区提出了大量的 近似注意力(Approximate Attention) 方法,包括:

  • 稀疏注意力(Sparse Attention):只计算部分位置对的注意力(如 Longformer、BigBird)
  • 低秩近似(Low-rank Approximation):用低秩矩阵近似完整注意力矩阵(如 Linformer、Performer)
  • 线性注意力(Linear Attention):通过核方法将 softmax 近似为可分解形式,实现线性复杂度

然而,论文指出这些方法存在两个共性问题:

  1. 精度损失:近似方法在长序列上经常出现质量退化,尤其是在需要精确建模长距离依赖的任务中
  2. 墙钟时间并未真正加速:虽然 FLOP 数降低了,但由于这些方法往往引入了更多的内存访问开销(如稀疏索引、额外的矩阵变换),在实际 GPU 上跑起来并没有标准注意力快。论文中的实验表明,很多近似方法在序列长度达到 512-2048 之前甚至比标准注意力更慢

一个反直觉的事实

减少 FLOP ≠ 减少运行时间。在 GPU 上,如果一个算法 FLOP 更少但内存访问更多,它完全可能比 FLOP 更多但内存访问模式更优的算法更慢。这就是 FlashAttention 的出发点。

4. FlashAttention 的核心思路

FlashAttention 不走近似路线,而是从 IO 复杂度(IO Complexity) 的角度重新审视标准注意力,通过优化内存访问模式来实现加速,同时保持结果的 数值精确性

核心策略包括两点:

(1)分块计算(Tiling):将 Q、K、V 分成小块,每次只加载一小块到 SRAM 中进行计算,避免实体化完整的 (N \times N) 注意力矩阵。

(2)在线 Softmax(Online Softmax):传统 softmax 需要先遍历整行求最大值和求和,再做归一化——这要求整行数据同时在内存中。FlashAttention 采用了 Milakov & Gimelshein (2018) 提出的在线 softmax 技巧,在分块流式处理的过程中 增量更新 softmax 统计量(running max 和 running sum),无需回头修正。

FlashAttention 的内存访问模式:

  HBM (慢)                    SRAM (快)                  计算单元
┌──────────┐   读取 Q块,   ┌──────────┐   一次性     ┌──────────┐
│  Q, K, V │ ──────────►  │ Q块,K块, │ ──────────► │ 分块计算  │
│          │   K块, V块    │  V块     │             │ S块→P块→ │
│          │              │          │   融合计算   │  O块累加  │
│  O       │ ◄────────── │  O块     │ ◄────────── │          │
└──────────┘   只写最终O   └──────────┘             └──────────┘

优势:
  - 中间矩阵 S、P 从不写回 HBM
  - HBM 读写量从 O(N²) 降至 O(N²d²M⁻¹)(M 为 SRAM 大小)
  - 结果与标准注意力完全一致(精确算法)

5. 论文的主要贡献

论文总结了以下关键贡献:

  1. FlashAttention 算法:一种 IO 感知的精确注意力实现,通过 Tiling 和在线 Softmax 将 HBM 访问量减少为 (O(N^2 d^2 M^{-1})),其中 (d) 是头维度、(M) 是 SRAM 大小。论文还证明了在所有精确注意力算法中,这是 HBM 访问次数的渐近最优下界

  2. Kernel 融合的扩展:将 FlashAttention 扩展到支持常用的注意力变体,包括 带 Mask 的注意力(如因果掩码)和 Dropout,这些操作都在同一个 CUDA Kernel 中完成,避免了额外的内存读写

  3. 长序列建模的实际收益:基于 FlashAttention 的高效实现,论文展示了在多个基准任务上的显著提升:

    • GPT-2 训练速度提升至标准 HuggingFace 实现的 3 倍
    • 支持的序列长度从 1K-2K 拓展到 4K-16K,使 Transformer 首次在长文档分类(如 MIMIC-III)和长序列生成任务上取得 SOTA 表现
    • Path-X(16K 序列长度的合成任务)上首次达到 超过随机水平的准确率
  4. IO 复杂度的理论分析:论文给出了精确注意力的 HBM 访问下界证明,并分析了常见近似/稀疏注意力的 IO 复杂度,为后续注意力优化研究提供了理论基础

为什么叫"Flash"?

Flash 一语双关:既指速度极快(如闪存 Flash Memory),也暗示了算法的核心思想——像闪存一样 感知和优化 IO 访问模式,让数据在正确的存储层级被高效处理。


标准注意力实现:Algorithm 0

在深入 FlashAttention 之前,我们必须先彻底理解"标准实现到底做了什么、代价几何"。论文将其命名为 Algorithm 0,作为后续优化的基线。

形式化定义

给定输入矩阵 (Q, K, V \in \mathbb{R}^{N \times d})((N) 为序列长度,(d) 为头维度),标准注意力的计算分为三步:

S=QKRN×N,P=softmax(S)RN×N,O=PVRN×d

其中 softmax 按行应用。

Algorithm 0 的执行流程

论文给出的标准实现伪代码如下:

Algorithm 0: Standard Attention Implementation
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
输入: Q, K, V ∈ R^{N×d},存储在 HBM 中

Step 1: 从 HBM 分块加载 Q, K,计算 S = QK^T,将 S 写回 HBM
Step 2: 从 HBM 读取 S,计算 P = softmax(S),将 P 写回 HBM
Step 3: 从 HBM 分块加载 P 和 V,计算 O = PV,将 O 写回 HBM

返回: O

逐步拆解 HBM 访问量

我们来精确计算每一步对 HBM 的读写量:

步骤操作HBM 读取HBM 写入说明
Step 1(S = QK^\top)(Q): (Nd) + (K): (Nd)(S): (N^2)矩阵乘法,结果必须写回 HBM
Step 2(P = \text{softmax}(S))(S): (N^2)(P): (N^2)逐元素操作,但 S 太大放不进 SRAM
Step 3(O = PV)(P): (N^2) + (V): (Nd)(O): (Nd)矩阵乘法
合计(3Nd + 2N^2)(Nd + 2N^2)总 HBM 访问 = (4Nd + 4N^2)

关键观察:当 (N \gg d) 时(GPT-2: (N = 1024, d = 64)),HBM 访问量被 (4N^2) 项主导,即 (O(N^2))

而中间矩阵 (S) 和 (P) 各占 (N^2) 的空间——它们的唯一作用是作为"中转站":被写入 HBM 后马上又被读出来。这种"写了就读、读了就扔"的模式是巨大的浪费。

为什么标准实现是 Memory-bound?

我们可以算一笔帐,用 算术强度(Arithmetic Intensity) 来判断操作是计算密集还是内存密集:

算术强度=FLOPsBytes accessed

对于 Step 2(softmax):

  • FLOPs:(O(N^2))(每个元素做 exp、加法、除法)
  • HBM 访问:读 (N^2) + 写 (N^2) = (2N^2) 个元素 = (8N^2) bytes(fp32)
  • 算术强度 ≈ (O(1))

A100 的算术强度平衡点约为 (\frac{312 \text{ TFLOPS}}{2 \text{ TB/s}} = 156) FLOP/byte。softmax 的算术强度远低于这个值,因此它是一个典型的 内存带宽受限 操作——GPU 的计算单元大部分时间在"等数据"。

即使矩阵乘法也受影响

Step 1 和 Step 3 的矩阵乘法本身是计算密集型操作,但因为 (S) 和 (P) 必须经过 HBM 这个"中转站",整个流水线的实际吞吐被 Step 2 的内存瓶颈拖慢了。Fuse 不掉 softmax,前后的 matmul 也快不起来。

Masking 和 Dropout 雪上加霜

论文特别指出,实际应用中注意力矩阵上还要叠加额外的逐元素操作:

  • Masking:因果注意力需要对 (S) 施加下三角掩码,即 (S_{ij} = -\infty) for (j > i)
  • Dropout:训练时对 (P) 施加随机置零

每增加一个逐元素操作,就多一轮 (N^2) 的 HBM 读写。虽然社区已尝试将 masking 和 softmax 融合进同一个 kernel,但 只要 (S) 和 (P) 还存在于 HBM 中,根本问题就没有解决。

动手试一试

下面的 C++ 代码完整模拟了 Algorithm 0 的三步流程,并 精确追踪每一步的 HBM 读写量。运行后你可以直观看到:中间矩阵 (S) 和 (P) 如何主导了内存访问开销。

C++Algorithm 0: 标准注意力实现 — 逐步追踪 HBM 读写

小结

标准注意力实现的核心问题可以用一句话概括:

中间矩阵 (S) 和 (P) 是 (N^2) 大小的"一次性中转站",它们被写入 HBM 后马上被读出,读出后再也不用——但它们主导了整个算法的内存访问开销。

FlashAttention 的解法正是从这里出发:如果我们能在 SRAM 中分块完成 (S \to P \to O) 的全部计算,让 (S) 和 (P) 永远不触碰 HBM,那么 HBM 的访问量就可以从 (O(N^2)) 大幅下降。但这要求我们解决一个技术难题:softmax 是全局操作,如何在只看到一小块数据的情况下正确计算? 这就是下一节 FlashAttention 算法的核心挑战。


FlashAttention 算法:Tiling + Online Softmax

上一节我们看到,标准注意力的瓶颈在于中间矩阵 (S, P \in \mathbb{R}^{N \times N}) 必须在 HBM 中实体化。FlashAttention 的目标很明确:在不实体化 (S) 和 (P) 的前提下,精确计算 (\text{softmax}(QK^\top)V)

这里我们只讨论前向传播(Forward Pass),反向传播的细节见论文 Appendix B。

核心挑战:Softmax 是"全局操作"

分块计算矩阵乘法很简单——把大矩阵切成小块逐块相乘再累加就行。但 softmax 不同:

softmax(Si,:)=eSi,jk=1NeSi,k

分母是对 整行 求和,这意味着要计算第 (i) 行的 softmax,你需要知道 (S) 的整行 (N) 个值。如果我们把 (K) 分成多个块,每次只算出 (S) 的一部分列,怎么做 softmax?

FlashAttention 的答案是:在线 Softmax(Online Softmax)——一种增量式的分块 softmax 算法。

分块 Softmax 的数学推导

单块的 Softmax 统计量

对于一个向量 (x \in \mathbb{R}^{B}),数值稳定的 softmax 需要三个统计量:

m(x):=maxixi,f(x):=[ex1m(x)exBm(x)],(x):=if(x)i

最终 (\text{softmax}(x) = \frac{f(x)}{\ell(x)})。

两块合并:关键递推公式

现在假设我们有两个分块 (x^{(1)}, x^{(2)} \in \mathbb{R}^{B}),要计算拼接向量 (x = [x^{(1)}, x^{(2)}] \in \mathbb{R}^{2B}) 的 softmax。核心递推公式为:

m(x)=max(m(x(1)),m(x(2)))f(x)=[em(x(1))m(x)f(x(1)),em(x(2))m(x)f(x(2))](x)=em(x(1))m(x)(x(1))+em(x(2))m(x)(x(2))

直觉:当新块的最大值更大时,旧块的 exp 值需要"缩小"(乘以 (e^{m_{\text{old}} - m_{\text{new}}} < 1));当旧块最大值更大时,新块被缩小。这个缩放因子保证了数值的正确性。

关键洞察

这个递推公式意味着:我们可以 逐块处理 K 的列方向,每处理一块就更新 (m) 和 (\ell),最终得到的 softmax 结果与一次性处理整行 完全一致——没有任何近似!

从 Softmax 到 Attention 输出的增量更新

分块 softmax 解决了归一化的问题,但注意力的最终输出是 (O = PV),即 softmax 的结果还要和 (V) 做矩阵乘法。我们不能先算完所有 softmax 再乘 (V)——那又回到了实体化 (P) 的老路。

FlashAttention 的做法是 边算 softmax 边累加 (O)。当处理第 (j) 个 K/V 块时:

  1. 计算当前块的局部注意力分数 (\tilde{S}_{ij} = Q_i K_j^\top)
  2. 计算局部统计量 (\tilde{m}{ij}, \tilde{P}, \tilde{\ell}_{ij})
  3. 更新全局统计量 (m_i^{\text{new}}, \ell_i^{\text{new}})
  4. 修正并累加输出
Oidiag(inew)1(diag(i)emiminewOi旧输出修正+em~ijminewP~ijVj新块贡献)

这个公式的含义是:

  • 旧输出修正:之前累加的 (O_i) 是基于旧的最大值 (m_i) 计算的,现在最大值更新为 (m_i^{\text{new}}),需要乘以修正因子 (e^{m_i - m_i^{\text{new}}})
  • 新块贡献:当前块的 softmax 值乘以 (V_j),同样调整到新的最大值尺度
  • 重新归一化:除以新的 (\ell_i^{\text{new}}) 确保概率和为 1

Algorithm 1:FlashAttention 完整算法

将上述思路系统化,就得到了论文中的 Algorithm 1:

Algorithm 1: FlashAttention
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
输入: Q, K, V ∈ R^{N×d} (在 HBM 中), 片上 SRAM 大小 M

1: 设置块大小 Bc = ⌈M/(4d)⌉, Br = min(⌈M/(4d)⌉, d)
2: 在 HBM 中初始化 O = 0^{N×d}, ℓ = 0^N, m = (-∞)^N

3: 将 Q 分为 Tr = ⌈N/Br⌉ 个块: Q₁, ..., Q_Tr  (每块 Br×d)
   将 K, V 分为 Tc = ⌈N/Bc⌉ 个块: K₁,...,K_Tc 和 V₁,...,V_Tc  (每块 Bc×d)
4: 将 O 分为 Tr 个块, ℓ 和 m 也分为 Tr 个块

5: for j = 1 to Tc do                    ← 外层循环: 遍历 K/V 块
6:     从 HBM 加载 Kⱼ, Vⱼ 到 SRAM
7:     for i = 1 to Tr do                ← 内层循环: 遍历 Q 块
8:         从 HBM 加载 Qᵢ, Oᵢ, ℓᵢ, mᵢ 到 SRAM
9:         在 SRAM 中计算 Sᵢⱼ = Qᵢ Kⱼᵀ ∈ R^{Br×Bc}
10:        在 SRAM 中计算:
             m̃ᵢⱼ = rowmax(Sᵢⱼ)          ∈ R^{Br}
             P̃ᵢⱼ = exp(Sᵢⱼ - m̃ᵢⱼ)      ∈ R^{Br×Bc}
             ℓ̃ᵢⱼ = rowsum(P̃ᵢⱼ)          ∈ R^{Br}
11:        在 SRAM 中计算:
             mᵢⁿᵉʷ = max(mᵢ, m̃ᵢⱼ)
             ℓᵢⁿᵉʷ = e^{mᵢ - mᵢⁿᵉʷ} · ℓᵢ + e^{m̃ᵢⱼ - mᵢⁿᵉʷ} · ℓ̃ᵢⱼ
12:        写回 HBM:
             Oᵢ ← diag(ℓᵢⁿᵉʷ)⁻¹ (diag(ℓᵢ)·e^{mᵢ-mᵢⁿᵉʷ}·Oᵢ + e^{m̃ᵢⱼ-mᵢⁿᵉʷ}·P̃ᵢⱼ·Vⱼ)
13:        写回 HBM: ℓᵢ ← ℓᵢⁿᵉʷ, mᵢ ← mᵢⁿᵉʷ
14:    end for
15: end for
16: 返回 O

循环结构的直觉理解

算法采用 双层循环,可以用下面的图来理解:

        K₁    K₂    K₃    K₄     (Tc 个 K/V 块, 外层循环 j)
      ┌─────┬─────┬─────┬─────┐
  Q₁  │ S₁₁ │ S₁₂ │ S₁₃ │ S₁₄ │  ← 第 i=1 轮: 逐块更新 O₁
      ├─────┼─────┼─────┼─────┤
  Q₂  │ S₂₁ │ S₂₂ │ S₂₃ │ S₂₄ │  ← 第 i=2 轮: 逐块更新 O₂
      ├─────┼─────┼─────┼─────┤
  Q₃  │ S₃₁ │ S₃₂ │ S₃₃ │ S₃₄ │  ← 第 i=3 轮: 逐块更新 O₃
      └─────┴─────┴─────┴─────┘

每个小块 Sᵢⱼ 的大小只有 Br × Bc, 完全放得进 SRAM!
整个 N×N 的 S 矩阵从未被完整构造出来.

外层遍历 K/V 块(列方向),内层遍历 Q 块(行方向)。对于每个 ((i, j)) 组合:

  • 计算一个小块 (S_{ij} \in \mathbb{R}^{B_r \times B_c})(在 SRAM 中,不写回 HBM
  • 更新第 (i) 行的 softmax 统计量和输出 (O_i)

块大小的选择

Algorithm 1 第 1 行给出了块大小的设定:

Bc=M4d,Br=min(M4d,d)

这是为了保证 每次循环需要的数据都能放进 SRAM。一次内层迭代需要同时在 SRAM 中保存:

  • (K_j): (B_c \times d)
  • (V_j): (B_c \times d)
  • (Q_i): (B_r \times d)
  • (O_i): (B_r \times d)
  • (S_{ij}): (B_r \times B_c)(中间计算结果)

总计约 (2B_c d + 2B_r d + B_r B_c) 个浮点数,需要不超过 (M) 个元素。

重计算(Recomputation):反向传播的优化

标准反向传播需要保存 (S, P \in \mathbb{R}^{N \times N}) 用于梯度计算。FlashAttention 的策略是:只保存 (O)、(m)、(\ell),在反向传播时重新计算 (S) 和 (P)

方案需要保存额外内存代价
标准反向传播(S, P \in \mathbb{R}^{N \times N})(O(N^2))
梯度检查点不保存,全部重算(O(N))速度慢(重复计算)
FlashAttention(O, m, \ell)(O(N))反而更快(减少 HBM 访问)

这看似矛盾——重新计算不是增加了 FLOP 吗?确实,FlashAttention 的总 FLOP 略多于标准方法。但由于重计算发生在 SRAM 中(带宽 ~19 TB/s),而标准方法需要从 HBM 中读取 (S) 和 (P)(带宽 ~2 TB/s),减少的 HBM 访问远比多出的 FLOP 划算

反直觉的结论

更多的 FLOP + 更少的 HBM 访问 = 更快的实际速度。这再次印证了 FlashAttention 的核心哲学:在 memory-bound 场景下,优化 IO 比优化计算更重要。

Kernel 融合:一个 CUDA Kernel 搞定一切

Tiling 使得 FlashAttention 可以将所有计算步骤融合到 一个 CUDA Kernel 中:

一个 FlashAttention Kernel 内部的完整流程:

┌─────────────────────────────────────────────┐
│  CUDA Kernel (一次启动, 一次执行)            │
│                                             │
│  1. 从 HBM 加载 Q块, K块, V块 → SRAM       │
│  2. SRAM 中: S块 = Q块 × K块ᵀ              │
│  3. SRAM 中: Masking (可选, 因果掩码)        │
│  4. SRAM 中: 在线 Softmax (更新 m, ℓ)       │
│  5. SRAM 中: Dropout (可选, 训练时)          │
│  6. SRAM 中: O块 累加 = P̃块 × V块           │
│  7. 写回 O块, m, ℓ → HBM                    │
│                                             │
│  中间结果 S块, P̃块 全程留在 SRAM, 从不触碰 HBM │
└─────────────────────────────────────────────┘

相比标准实现需要多个 Kernel(matmul → mask → softmax → dropout → matmul),每个 Kernel 之间都要经过 HBM 中转,FlashAttention 只有 一次 HBM 读入 + 一次 HBM 写出

正确性与复杂度(Theorem 1)

论文给出了严格证明(详见 Appendix C):

Theorem 1: Algorithm 1 返回 (O = \text{softmax}(QK^\top)V),所需 FLOPs 为 (O(N^2 d)),额外内存(输入输出之外)为 (O(N))。

  • 精确性:输出与标准注意力 完全一致(不是近似),因为在线 softmax 的递推公式是数学恒等式
  • FLOPs:与标准注意力相同,都是 (O(N^2 d))——FlashAttention 没有减少计算量
  • 额外内存:只需要 (O(N)) 存储 (m) 和 (\ell)(而非标准方法的 (O(N^2)) 存储 (S) 和 (P))

动手试一试

下面的 C++ 代码完整实现了 Algorithm 1,你可以对比上一节的 Algorithm 0,观察两者的输出 完全一致,但 HBM 访问量大幅下降。

C++Algorithm 1: FlashAttention — 分块 Tiling + 在线 Softmax

小结

FlashAttention 的 Algorithm 1 通过三个关键技术实现了 HBM 访问量的大幅下降:

技术解决的问题效果
Tiling(分块)(S, P) 太大放不进 SRAM分成 (B_r \times B_c) 的小块,每块在 SRAM 中完成全部计算
Online Softmax(在线 Softmax)Softmax 需要全局归一化通过维护 (m, \ell) 统计量增量更新,结果精确一致
Recomputation(重计算)反向传播需要 (S, P)只保存 (O, m, \ell),反向时重新计算,减少 (O(N^2) \to O(N)) 额外内存

最终效果:FLOPs 不变(甚至略多),但 HBM 访问从 (O(N^2)) 降至 (O(N^2 d^2 M^{-1}))——在 memory-bound 场景下,这直接转化为 2-4 倍的墙钟加速。


分析:FlashAttention 的 IO 复杂度

上一节我们展示了 FlashAttention 算法的具体实现,但一个关键问题尚未严格回答:FlashAttention 的 HBM 访问量到底是多少?这个量是最优的吗?

论文在 Section 3.2 给出了两个核心理论结果:

  1. Theorem 2(上界):Algorithm 1 的 HBM 访问量为 (O(N^2 d^2 M^{-1}))
  2. Proposition 3(下界):任何精确注意力算法的 HBM 访问量不低于 (\Omega(N^2 d^2 M^{-1}))

两者匹配,说明 FlashAttention 在 HBM 访问次数上是 渐近最优 的。

Theorem 2:HBM 访问量的上界

Theorem 2: 设 (N) 为序列长度,(d) 为头维度,(M) 为 SRAM 大小(以元素计),且 (d \leq M \leq Nd)。Algorithm 1(FlashAttention)的 HBM 访问量为 (O(N^2 d^2 M^{-1}))。

证明方法很直接:逐一统计 Algorithm 1 每一步的 HBM 读写量,然后求和

逐步推导

回顾 Algorithm 1 的块大小设定:

Bc=M4d,Br=min(M4d,d)

对应的块数为:

Tc=NBc,Tr=NBr

现在统计每一层循环的 HBM 访问量:

外层循环(遍历 K/V 块,共 (T_c) 次):

  • 每次加载 (K_j):(B_c \times d) 个元素
  • 每次加载 (V_j):(B_c \times d) 个元素
  • 小计:(2 B_c d) 次读取

内层循环(遍历 Q 块,每个外层迭代执行 (T_r) 次):

  • 读取 (Q_i, O_i):(2 B_r d) 次读取
  • 读取 (\ell_i, m_i):(2 B_r) 次读取
  • 写回 (O_i):(B_r d) 次写入
  • 写回 (\ell_i, m_i):(2 B_r) 次写入
  • 小计:(3 B_r d + 4 B_r) 次访问

总 HBM 访问量

Total=Tc(2Bcd+Tr(3Brd+4Br))

主导项为内层循环部分:(T_c \cdot T_r \cdot 3 B_r d)。接下来分两种情况讨论:

情况 1:(M \leq 4d^2)(SRAM 较小)

此时 (B_r = B_c = \frac{M}{4d}),两个块大小相同。

Tc=4NdM,Tr=4NdM

主导项:

TcTr3Brd=4NdM4NdM3M4dd=16N2d2M23M4=12N2d2M

情况 2:(M > 4d^2)(SRAM 较大)

此时 (B_r = d),(B_c = \frac{M}{4d})。

Tc=4NdM,Tr=Nd

主导项:

TcTr3Brd=4NdMNd3d2=12N2d2M

两种情况均给出 (\Theta!\left(\frac{N^2 d^2}{M}\right))。

直觉理解:这个公式在说什么?

(O(N^2 d^2 M^{-1})) 这个表达式可以拆解理解:

因子含义
(N^2)注意力矩阵的"逻辑大小"——我们必须计算所有 (N^2) 个 token-pair 的交互
(d^2)每对 token 的交互需要 (d) 维向量的内积;分块时每块还需要独立加载 (d) 维数据
(M^{-1})SRAM 越大,每次能处理的数据块越大,需要的"轮次"越少

我们可以检查两个边界条件是否合理:

条件(M) 的值HBM 访问量含义
SRAM 极小(M = d^2)(N^2)退化为标准注意力,和 Algorithm 0 一样
SRAM 极大(M = Nd)(Nd)只需读一遍输入,无需分块

这正好覆盖了从"完全放不下"到"完全放得下"的整个 SRAM 容量谱。

Proposition 3:下界——FlashAttention 是最优的

Proposition 3: 设 (N \geq d),(M) 满足 (d \leq M \leq Nd)。则任何计算精确注意力的算法,其 HBM 访问量不低于 (\Omega(N^2 d^2 M^{-1}))。

这意味着 FlashAttention 的 HBM 访问量不仅仅是"足够好"——它已经达到了 理论最优下界,不可能再进一步减少(在渐近意义上)。

证明思路

下界证明基于 矩阵乘法的 IO 复杂度下界(参考 Hong & Kung 1981 的"红蓝石子博弈"框架):

  1. 规约到矩阵乘法:注意力计算必须(显式或隐式地)完成 (S = QK^\top) 这一步——这是一个 ((N \times d) \times (d \times N)) 的矩阵乘法

  2. 矩阵乘法的 IO 下界:将 (m \times k) 矩阵与 (k \times n) 矩阵相乘,在 SRAM 大小为 (M) 的两级存储模型下,HBM 访问量的下界为:

Ω(mknM)
  1. 应用到注意力:对于 (S = QK^\top),(m = N, k = d, n = N):
Ω(N2dM)
  1. 注意力比纯矩阵乘法更难:但注意力不仅要计算 (S),还要对 (S) 做 softmax 再乘 (V)。softmax 是逐行操作,要求同一行的所有元素在同一时刻可达。论文通过更精细的分析证明,这个额外约束将下界提升至 (\Omega(N^2 d^2 M^{-1}))

为什么下界比纯矩阵乘法更高?

纯矩阵乘法允许任意顺序计算输出元素,但 softmax 引入了 行内全局依赖:你必须看完 (S) 的整行才能做归一化。这迫使算法在 K/V 块的方向上"多扫几遍",每多扫一遍就多一轮 HBM 读写。这个额外约束使 IO 下界从 (\frac{N^2d}{\sqrt{M}}) 提升到 (\frac{N^2d^2}{M})。

上下界匹配的意义

将 Theorem 2 和 Proposition 3 放在一起:

Ω(N2d2M)下界 (Prop. 3)FlashAttention 的 HBM 访问量O(N2d2M)上界 (Thm. 2)

上下界完全匹配:FlashAttention 是渐近最优的精确注意力算法

这是一个非常强的理论保证——它告诉我们:

  1. 不要再找更好的:在这个计算模型下,没有任何精确注意力算法能比 FlashAttention 做更少的 HBM 访问
  2. 要想进一步加速,必须换赛道:只能通过增加 SRAM(硬件改进)、使用近似注意力(允许误差)、或者改变注意力模式(如稀疏)来突破
  3. 硬件设计的指导意义:(M^{-1}) 的依赖关系表明,增加 SRAM 容量对注意力计算有直接的性能收益

动手试一试

下面的代码让你直观感受 SRAM 大小 (M) 如何影响 HBM 访问量,并验证理论公式 (O(N^2 d^2 / M)) 的准确性。

C++IO 复杂度分析: HBM 访问量 vs SRAM 大小 M

与近似注意力方法的 IO 对比

论文还比较了 FlashAttention 与常见近似注意力方法的 IO 复杂度。虽然近似方法的 FLOPs 更少,但它们的 IO 模式未必更优:

方法FLOPsHBM 访问量精确?
标准注意力 (Algo 0)(O(N^2 d))(O(N^2 + Nd))
FlashAttention (Algo 1)(O(N^2 d))(O(N^2 d^2 M^{-1}))
稀疏注意力 (如 Longformer)(O(N \cdot S \cdot d))(O(N \cdot S + Nd))
线性注意力 (如 Performer)(O(Nd^2))(O(Nd))

其中 (S) 是稀疏注意力中每个 token 关注的邻域大小。

一个关键观察

当 (M) 足够大(即 (d^2 / M) 足够小)时,FlashAttention 的 HBM 访问量 (N^2 d^2 / M) 可以远小于标准注意力的 (N^2),甚至逼近线性注意力的 (Nd) 量级——同时保持精确计算。这就是 FlashAttention 在实践中能与近似方法竞争甚至胜出的理论基础。

小结

理论结果结论
Theorem 2(上界)FlashAttention 的 HBM 访问量 (\leq O(N^2 d^2 M^{-1}))
Proposition 3(下界)任何精确注意力算法 (\geq \Omega(N^2 d^2 M^{-1}))
合在一起FlashAttention 是渐近最优的

核心启示:

  1. 算法已到极限:在精确注意力的框架内,FlashAttention 的 IO 效率无法被本质性地超越
  2. SRAM 是关键资源:(M^{-1}) 依赖意味着片上 SRAM 每扩大一倍,HBM 访问量减半。这为硬件设计(更大的 shared memory)和软件优化(更精细的 SRAM 管理)指明了方向
  3. FLOPs 不是全部:FlashAttention 和标准注意力有相同的 FLOPs,但 IO 效率天差地别。在 memory-bound 场景下,优化 IO 比优化计算更重要

扩展:块稀疏 FlashAttention(Block-Sparse FlashAttention)

上两节我们证明了 FlashAttention 在 精确注意力 的框架下已达到 IO 最优。但在很多实际场景中,我们并不需要每个 token 都关注所有其他 token——稀疏注意力(Sparse Attention) 是一种广泛使用的提升效率的手段。

论文的一个重要贡献是:FlashAttention 的分块架构天然适合与块稀疏模式结合,只需 极少的修改 就能获得 IO 和 FLOPs 的双重收益。

什么是块稀疏注意力?

标准注意力计算完整的 (N \times N) 注意力矩阵。块稀疏注意力则引入一个 块级掩码矩阵

M~{0,1}Tr×Tc

其中 (T_r = \lceil N / B_r \rceil),(T_c = \lceil N / B_c \rceil)。当 (\tilde{M}{ij} = 0) 时,表示 Q 的第 (i) 块和 K 的第 (j) 块之间的注意力被完全跳过——对应的 (S) 块既不计算,也不参与 softmax。

块稀疏掩码示例 (N=16, Br=Bc=4, 所以 Tr=Tc=4):

        K₁   K₂   K₃   K₄
      ┌────┬────┬────┬────┐
  Q₁  │ 1  │ 1  │ 0  │ 0  │  → Q₁ 只关注 K₁, K₂
      ├────┼────┼────┼────┤
  Q₂  │ 1  │ 1  │ 1  │ 0  │  → Q₂ 关注 K₁, K₂, K₃
      ├────┼────┼────┼────┤
  Q₃  │ 0  │ 1  │ 1  │ 1  │  → Q₃ 关注 K₂, K₃, K₄
      ├────┼────┼────┼────┤
  Q₄  │ 0  │ 0  │ 1  │ 1  │  → Q₄ 只关注 K₃, K₄
      └────┴────┴────┴────┘

  1 = 计算该块     0 = 跳过该块
  非零块数: 10/16 = 62.5%  → 稀疏度 s = 0.625

从 Algorithm 1 到 Block-Sparse FlashAttention

由于 FlashAttention 已经将注意力按 ((B_r, B_c)) 大小分块处理,加入块稀疏支持只需 一行修改:在内层循环中,检查掩码 (\tilde{M}_{ij}),如果为 0 则直接跳过该块。

Algorithm 5: Block-Sparse FlashAttention (基于 Algorithm 1 的修改)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
输入: Q, K, V ∈ R^{N×d}, 块稀疏掩码 M~ ∈ {0,1}^{Tr×Tc}

... (其余初始化与 Algorithm 1 相同) ...

5: for j = 1 to Tc do
6:     从 HBM 加载 Kⱼ, Vⱼ 到 SRAM
7:     for i = 1 to Tr do
7.5:       if M~[i][j] = 0 then continue    ← 唯一的新增行!
8:         从 HBM 加载 Qᵢ, Oᵢ, ℓᵢ, mᵢ 到 SRAM
9-13:     (与 Algorithm 1 完全相同)
14:    end for
15: end for

返回: O

这个"一行修改"看似简单,背后的深意是:FlashAttention 的分块粒度恰好与块稀疏掩码的粒度对齐。不需要任何索引重排、scatter/gather 操作或额外的数据结构——只需跳过零块即可。

常见的稀疏注意力模式

块稀疏框架可以统一表达多种流行的稀疏注意力模式。定义稀疏度 (s) 为非零块的比例:(s = \frac{|{(i,j) : \tilde{M}_{ij} = 1}|}{T_r \times T_c})。

1. 滑动窗口注意力(Local / Sliding Window)

每个 token 只关注前后 (w) 个位置范围内的 token:

M~ij=1[|ij|w/Bc]
滑动窗口 (w=2 个块):

  ┌─┬─┬─┬─┬─┬─┬─┬─┐
  │█│█│█│ │ │ │ │ │   稀疏度: s = (2w+1)/Tc
  │█│█│█│█│ │ │ │ │   当 w << N 时, s ≈ 2w/N
  │█│█│█│█│█│ │ │ │
  │ │█│█│█│█│█│ │ │   典型应用:
  │ │ │█│█│█│█│█│ │   - Longformer
  │ │ │ │█│█│█│█│█│   - Mistral
  │ │ │ │ │█│█│█│█│   - 局部上下文建模
  │ │ │ │ │ │█│█│█│
  └─┴─┴─┴─┴─┴─┴─┴─┘

2. 因果注意力(Causal Attention)

自回归模型的标准模式——每个 token 只能看到自己及之前的 token:

M~ij=1[ij]
因果掩码 (下三角):

  ┌─┬─┬─┬─┬─┬─┬─┬─┐
  │█│ │ │ │ │ │ │ │   稀疏度: s ≈ 0.5
  │█│█│ │ │ │ │ │ │
  │█│█│█│ │ │ │ │ │   典型应用:
  │█│█│█│█│ │ │ │ │   - GPT 系列
  │█│█│█│█│█│ │ │ │   - LLaMA
  │█│█│█│█│█│█│ │ │   - 所有自回归 LLM
  │█│█│█│█│█│█│█│ │
  │█│█│█│█│█│█│█│█│
  └─┴─┴─┴─┴─┴─┴─┴─┘

3. 全局 + 滑动窗口(Global + Local)

少数特殊 token(如 [CLS]、句首)关注所有位置,其余 token 使用滑动窗口:

全局(前 2 行) + 滑动窗口:

  ┌─┬─┬─┬─┬─┬─┬─┬─┐
  │█│█│█│█│█│█│█│█│   ← 全局 token
  │█│█│█│█│█│█│█│█│   ← 全局 token
  │█│█│█│█│█│ │ │ │
  │█│█│█│█│█│█│ │ │   典型应用:
  │█│█│ │█│█│█│█│ │   - BigBird
  │█│█│ │ │█│█│█│█│   - Longformer
  │█│█│ │ │ │█│█│█│   - ETC
  │█│█│ │ │ │ │█│█│
  └─┴─┴─┴─┴─┴─┴─┴─┘

4. 步幅稀疏(Strided / Dilated)

以固定间隔采样关注位置,捕捉长距离依赖:

步幅稀疏 (stride=2):

  ┌─┬─┬─┬─┬─┬─┬─┬─┐
  │█│ │█│ │█│ │█│ │   稀疏度: s = 1/stride
  │ │█│ │█│ │█│ │█│
  │█│ │█│ │█│ │█│ │   典型应用:
  │ │█│ │█│ │█│ │█│   - Sparse Transformer
  │█│ │█│ │█│ │█│ │   - 长序列音频建模
  │ │█│ │█│ │█│ │█│
  │█│ │█│ │█│ │█│ │
  │ │█│ │█│ │█│ │█│
  └─┴─┴─┴─┴─┴─┴─┴─┘

IO 复杂度分析(Proposition 4)

Proposition 4: 设块稀疏掩码的非零块比例为 (s)((0 < s \leq 1))。Block-Sparse FlashAttention 的 HBM 访问量为:

O(N2d2sM)

证明直觉:在 Algorithm 1 中,内层循环的迭代次数从 (T_r \times T_c) 减少为 (s \cdot T_r \times T_c)(只处理非零块)。每次迭代的 HBM 访问量不变,因此总量线性地乘以稀疏度 (s)。

与各方案的完整对比:

方法FLOPsHBM 访问量精确?适用场景
标准注意力(O(N^2 d))(O(N^2))基线
FlashAttention(O(N^2 d))(O(N^2 d^2 / M))通用加速
Block-Sparse Flash(O(N^2 d \cdot s))(O(N^2 d^2 s / M))在稀疏模式内精确已知稀疏模式
近似注意力各异各异特定任务

Block-Sparse 的双重收益

注意 Block-Sparse FlashAttention 同时获得了 FLOPs 和 IO 两方面的节省——两者都乘以稀疏度 (s)。这与标准稀疏注意力不同:后者虽然减少了 FLOPs,但由于稀疏索引和 scatter/gather 操作,IO 反而可能更差。FlashAttention 的分块架构让稀疏模式的收益被"干净地"兑现。

为什么 FlashAttention 特别适合块稀疏?

之所以 FlashAttention + 块稀疏如此自然,根本原因是 粒度对齐

传统稀疏注意力:                    Block-Sparse FlashAttention:

  逐元素稀疏 → 需要稀疏索引          块级稀疏 → 只需 skip 整个块
  ┌─────────────────┐               ┌─────────────────┐
  │ 1 0 1 0 0 1 0 0 │               │ ██ │    │    │    │
  │ 0 1 0 0 1 0 0 1 │               │────┼────┼────┼────│
  │ 1 0 1 0 0 0 1 0 │   vs          │ ██ │ ██ │    │    │
  │ 0 0 0 1 0 1 0 0 │               │────┼────┼────┼────│
  │ 0 1 0 0 1 0 0 1 │               │    │ ██ │ ██ │    │
  │ 1 0 0 1 0 1 0 0 │               │────┼────┼────┼────│
  │ 0 0 1 0 0 0 1 0 │               │    │    │ ██ │ ██ │
  │ 0 1 0 0 1 0 0 1 │               └─────────────────┘
  └─────────────────┘
  - 需要 CSR/COO 等稀疏格式          - 掩码只有 Tr×Tc 个 bit
  - scatter/gather 内存不连续         - 整块 skip, 内存访问连续
  - GPU 利用率低 (warp divergence)    - 无 warp divergence
  - 额外索引开销抵消稀疏收益          - 稀疏收益被干净兑现

动手试一试

下面的代码对比标准 FlashAttention 和 Block-Sparse FlashAttention 在不同稀疏模式下的 HBM 访问量和 FLOPs。

C++Block-Sparse FlashAttention: 稀疏模式对比

实际影响

Block-Sparse FlashAttention 的实践意义体现在几个方面:

1. 统一框架

之前的稀疏注意力实现各自为政——Longformer 有自己的 CUDA kernel,BigBird 有自己的,Sparse Transformer 又是另一套。Block-Sparse FlashAttention 提供了一个统一的高效实现:只需指定不同的块掩码 (\tilde{M}),就能支持任意块稀疏模式。

2. IO 高效的稀疏注意力

传统稀疏注意力的一个尴尬是:虽然 FLOPs 减少了,但由于稀疏索引和不连续内存访问,实际墙钟时间未必更快(论文实验中很多近似方法在 (N < 2048) 时反而比稠密注意力慢)。Block-Sparse FlashAttention 消除了这个问题——稀疏收益被干净地兑现为墙钟加速。

3. 可组合性

块稀疏掩码可以自由组合。例如:"因果 + 滑动窗口"就是两个掩码的逐元素 AND。这使得用户可以根据任务需求灵活定制注意力模式,而无需为每种组合编写新的 kernel。

小结

维度标准 FlashAttentionBlock-Sparse FlashAttention
FLOPs(O(N^2 d))(O(N^2 d \cdot s))
HBM 访问(O(N^2 d^2 / M))(O(N^2 d^2 s / M))
代码改动基线内层循环加一个 if
支持的模式全注意力任意块稀疏模式
精确性精确在非零块内精确

核心观点:FlashAttention 的分块架构让块稀疏变成了"免费午餐"——几乎零代价地将稀疏模式的理论收益转化为实际加速。这也解释了为什么 FlashAttention 迅速成为几乎所有 Transformer 推理框架的底层注意力实现。


后续章节持续更新中

实验结果分析、FlashAttention-2 的改进等内容将在后续更新中补充。

推荐阅读

Released under the MIT License.