Skip to content
/

Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism

2026-02-15  ·  分布式训练  ·  论文精读

论文信息

  • 作者: Mohammad Shoeybi, Mostofa Patwary, Raul Puri, Patrick LeGresley, Jared Casper, Bryan Catanzaro
  • 机构: NVIDIA
  • 发表: arXiv 2019 (后续被广泛引用)
  • 链接: arXiv:1909.08053

一句话总结

Megatron-LM 提出了一套针对 Transformer 架构的 层内张量并行(Intra-layer Tensor Parallelism) 方案,通过对 MLP 层、Self-Attention 层、Embedding 层和交叉熵损失的精心切分,仅需 每层 2 次 AllReduce(前向 + 反向各 1 次)即可实现高效的模型并行,在 512 张 V100 上训练了当时最大的 83 亿参数 Transformer 语言模型,达到 76% 的弱扩展效率


Introduction:为什么需要 Megatron-LM?

1. 大模型的崛起与单卡内存瓶颈

2018-2019 年,预训练语言模型的参数量快速增长:

模型发布时间参数量关键突破
BERT-Large2018.10340M双向预训练
GPT-22019.021.5B大规模单向 LM
XLNet2019.06340M排列语言模型
Megatron-LM2019.098.3B高效张量并行

GPT-2 (1.5B) 的参数在 fp16 下占 3 GB,加上优化器状态和激活值,单张 V100 (32 GB) 已经非常紧张。想训练 8B+ 的模型,必须将模型切分到多张 GPU 上。

2. 数据并行的局限

ZeRO 系列 解决了数据并行中的内存冗余问题。但在 ZeRO 发表之前(2019 年),标准数据并行的方式是每张卡存完整模型,受限于单卡显存。

即使有了 ZeRO,数据并行仍有一个根本限制:它不能减少单层的计算/内存开销。如果模型的一个 Transformer 层就超过了单卡显存,数据并行无能为力——你需要将 层内的参数切分到多张 GPU 上。

3. 现有模型并行方法的问题

2019 年已有一些模型并行的实现,但都存在明显缺陷:

流水线并行(Pipeline Parallelism)

  • 将不同层放到不同 GPU
  • 简单直观,但存在严重的 流水线气泡(Pipeline Bubble)——当前方 GPU 在算前向传播时,后方 GPU 在空等
  • 典型效率只有 50-60%

朴素张量并行(Naive Tensor Parallelism)

  • 将每层的参数矩阵简单切分
  • 需要大量的 点对点通信(Send/Recv),通信模式复杂
  • 难以在现有深度学习框架中高效实现

Mesh-TensorFlow

  • Google 提出的 TPU 上的张量并行框架
  • 需要重新定义整个计算图,与 PyTorch/TensorFlow 的编程模型不兼容
  • 在 GPU 集群上表现不佳

4. Megatron-LM 的核心思想

Megatron-LM 的创新在于:针对 Transformer 的具体结构,设计精巧的张量切分方式,使得每层只需要 2 次 AllReduce 操作

核心原则:

  1. 利用矩阵乘法的可分解性:(Y = XA) 中,如果按列切分 (A = [A_1, A_2]),则 (Y = [XA_1, XA_2])——每张卡可以独立计算一部分
  2. 最小化同步点:精心设计切分方式,让需要同步的位置恰好在层的边界,每层只需 1 次前向 AllReduce + 1 次反向 AllReduce
  3. 不修改模型代码的语义:只需在现有 PyTorch 代码中插入少量通信原语

张量并行 vs 数据并行 vs 流水线并行

  • 数据并行(DP):每卡存完整模型,切分数据 → 通信梯度
  • 张量并行(TP):每卡存 一层的一部分,共同计算一个 batch → 通信中间激活值
  • 流水线并行(PP):每卡存 不同的层,流水线执行 → 通信层间激活值

Megatron-LM 专注于 TP,后续论文(Megatron-LM v2, 2021)将三者结合形成 3D 并行

5. 论文的主要贡献

  1. Transformer 专用的张量并行方案:针对 MLP 层和 Self-Attention 层设计了简洁高效的切分策略,每层仅需 2 次 AllReduce

  2. Embedding 层和 Cross-Entropy 的并行化:对输入/输出 Embedding 和交叉熵损失函数也进行了并行切分,避免在巨大词表上的冗余计算

  3. 高效工程实现:在 PyTorch 中用简洁的自定义算子(fg)实现,无需修改框架底层

  4. 规模化验证:成功训练了 8.3B 参数的 GPT-2 和 3.9B 参数的 BERT,在 512 张 V100 上达到 76% 的弱扩展效率

  5. 下游任务 SOTA:8.3B GPT-2 在 WikiText-103 上达到困惑度 10.8(当时 SOTA),3.9B BERT 在多项 NLU 基准上超越 RoBERTa


预备知识:Transformer 的计算结构

在深入 Megatron-LM 的切分方案之前,我们需要回顾 Transformer 的核心计算模块。

标准 Transformer 层

一个标准的 Transformer 层由两个子模块组成:

Transformer 层的计算流程:

输入 X ∈ R^{b×s×h}  (batch × sequence × hidden)


┌───────────────────────┐
│  Multi-Head Attention  │
│                       │
│  Q = XW_Q             │  W_Q, W_K, W_V ∈ R^{h×h}
│  K = XW_K             │
│  V = XW_V             │
│  Attn = softmax(QK^T/√d)V  │
│  Y = Attn · W_O       │  W_O ∈ R^{h×h}
│                       │
│  输出 = LayerNorm(X + Y)  │  ← 残差连接 + 归一化
└───────────┬───────────┘


┌───────────────────────┐
│  Feed-Forward (MLP)    │
│                       │
│  H = GeLU(X · W_1)    │  W_1 ∈ R^{h×4h}  (扩展 4 倍)
│  Y = H · W_2          │  W_2 ∈ R^{4h×h}  (缩回)
│                       │
│  输出 = LayerNorm(X + Y)  │  ← 残差连接 + 归一化
└───────────┬───────────┘


        输出 X' ∈ R^{b×s×h}

参数量分析

对于隐藏维度 (h) 的 Transformer 层:

模块参数矩阵形状参数量
Q 投影(W_Q)(h \times h)(h^2)
K 投影(W_K)(h \times h)(h^2)
V 投影(W_V)(h \times h)(h^2)
输出投影(W_O)(h \times h)(h^2)
MLP 第一层(W_1)(h \times 4h)(4h^2)
MLP 第二层(W_2)(4h \times h)(4h^2)
每层合计(12h^2)

对于 8.3B 参数的模型((h = 3072, L = 72) 层):每层约 (12 \times 3072^2 \approx 113M) 参数。

核心观察:矩阵乘法可以按列/行切分

这是理解 Megatron-LM 的数学基础。对于矩阵乘法 (Y = XA):

按列切分(Column Parallelism)

A=[A1,A2]Y=X[A1,A2]=[XA1,XA2]=[Y1,Y2]

每张卡拿到 (A) 的一部分列,用完整的 (X) 做乘法,得到 (Y) 的一部分列。无需通信就能独立计算。

按行切分(Row Parallelism)

A=[A1A2],X=[X1,X2]Y=X1A1+X2A2

每张卡拿到 (A) 的一部分行和 (X) 的对应列,做乘法后需要 AllReduce 求和


MLP 层的张量并行

MLP 层是 Megatron-LM 张量并行方案的核心。论文的切分方式简洁而精妙。

MLP 的计算结构

标准 MLP 由两个线性变换 + GeLU 激活组成:

Y=GeLU(XW1)W2

其中 (W_1 \in \mathbb{R}^{h \times 4h})(扩展),(W_2 \in \mathbb{R}^{4h \times h})(收缩)。

切分策略:列并行 + 行并行

Megatron-LM 对两个权重矩阵采用互补的切分方式:

  • (W_1):按列切分(Column Parallel) → 每张卡得到 (W_1) 的一部分列
  • (W_2):按行切分(Row Parallel) → 每张卡得到 (W_2) 的一部分行
MLP 层的张量并行 (2 张 GPU):

           GPU 0                        GPU 1
     ┌──────────────┐             ┌──────────────┐
     │              │             │              │
X ──→│ W_1 的左半列  │         X ──→│ W_1 的右半列  │
     │ h × 2h       │             │ h × 2h       │
     │      ↓       │             │      ↓       │
     │  GeLU(XW_1₁) │             │  GeLU(XW_1₂) │
     │  [b,s,2h]    │             │  [b,s,2h]    │
     │      ↓       │             │      ↓       │
     │ W_2 的上半行  │             │ W_2 的下半行  │
     │ 2h × h       │             │ 2h × h       │
     │      ↓       │             │      ↓       │
     │    Y₁        │             │    Y₂        │
     └──────┬───────┘             └──────┬───────┘
            │                            │
            └──────── AllReduce ─────────┘
                    Y = Y₁ + Y₂

为什么这样切分是对的?

让我们用数学严格验证。设 (t = 2) 张 GPU,将 (W_1) 按列分为 (W_1 = [W_{1,1}, W_{1,2}]),将 (W_2) 按行分为 (W_2 = \begin{bmatrix} W_{2,1} \ W_{2,2} \end{bmatrix})。

Step 1:每张卡独立计算第一个线性层 + 激活函数:

Hi=GeLU(XW1,i),i=1,2

关键问题:GeLU 是非线性函数,(\text{GeLU}(XW_1)) 能否拆成 ([\text{GeLU}(XW_{1,1}), \text{GeLU}(XW_{1,2})])?

答案是可以的! 因为列切分意味着 (XW_1 = [XW_{1,1}, XW_{1,2}]),GeLU 是逐元素操作,所以:

GeLU([XW1,1,XW1,2])=[GeLU(XW1,1),GeLU(XW1,2)]

这正是列切分与非线性激活函数兼容的原因。

如果按行切分 (W_1) 呢?

如果 (W_1) 按行切分,则 (XW_1 = X_1 W_{1,1} + X_2 W_{1,2}),在 GeLU 之前需要先 AllReduce 求和。因为 (\text{GeLU}(a+b) \neq \text{GeLU}(a) + \text{GeLU}(b))——非线性函数不能分配到加法上。这就多了一次通信!所以 第一层必须按列切分

Step 2:每张卡独立计算第二个线性层:

Yi=HiW2,i,i=1,2

注意 (H_i \in \mathbb{R}^{b \times s \times 2h}) 和 (W_{2,i} \in \mathbb{R}^{2h \times h}),维度刚好匹配。

Step 3:AllReduce 求和得到最终结果:

Y=Y1+Y2=H1W2,1+H2W2,2=[H1,H2][W2,1W2,2]=GeLU(XW1)W2

结果与单卡计算完全一致!

通信分析

MLP 层的前向传播只需要 1 次 AllReduce(在 (W_2) 之后),反向传播也只需要 1 次 AllReduce(将梯度传回 (W_1) 之前)。

方向AllReduce 次数数据量位置
前向1(b \times s \times h)(W_2) 输出后
反向1(b \times s \times h)(W_1) 梯度前

Self-Attention 层的张量并行

多头注意力的天然可并行性

多头注意力(Multi-Head Attention)有一个天然的并行结构:各个头之间是完全独立的计算

标准 Multi-Head Attention:

MultiHead(X)=Concat(head1,,headk)WO

其中 (\text{head}_i = \text{Attention}(XW_Q^i, XW_K^i, XW_V^i))。

Megatron-LM 的策略是:将注意力头均匀分配到各 GPU

切分方式

设有 (t) 张 GPU,(k) 个注意力头(要求 (k) 能被 (t) 整除),每张卡分到 (k/t) 个头:

Self-Attention 的张量并行 (t=2, k=16 个头):

              GPU 0                         GPU 1
        ┌───────────────┐            ┌───────────────┐
        │ 头 0-7 (8个头) │            │ 头 8-15(8个头) │
    X ──→ W_Q₁,W_K₁,W_V₁│        X ──→ W_Q₂,W_K₂,W_V₂│
        │     ↓         │            │     ↓         │
        │ Q₁,K₁,V₁     │            │ Q₂,K₂,V₂     │
        │     ↓         │            │     ↓         │
        │ Attention₁    │            │ Attention₂    │
        │ (8 个头的输出)  │            │ (8 个头的输出)  │
        │     ↓         │            │     ↓         │
        │   × W_O₁      │            │   × W_O₂      │
        │ (行并行切分)    │            │ (行并行切分)    │
        │     ↓         │            │     ↓         │
        │    Y₁         │            │    Y₂         │
        └──────┬────────┘            └──────┬────────┘
               │                            │
               └──────── AllReduce ─────────┘
                       Y = Y₁ + Y₂

实现细节

投影矩阵 (W_Q, W_K, W_V) 的切分方式与 MLP 的 (W_1) 相同——按列切分

WQ=[WQ(1),WQ(2)],WK=[WK(1),WK(2)],WV=[WV(1),WV(2)]

每张卡的 (W_Q^{(i)} \in \mathbb{R}^{h \times (h/t)}),对应 (k/t) 个头的投影。

输出投影矩阵 (W_O) 按行切分(与 MLP 的 (W_2) 相同),每张卡计算部分结果后 AllReduce 求和。

整个 Self-Attention 层也只需要 1 次前向 AllReduce + 1 次反向 AllReduce。

与 MLP 的统一视角

模块第一组权重切分方式第二组权重切分方式通信
MLP(W_1) (h×4h)列并行(W_2) (4h×h)行并行1 AllReduce
Attention(W_{QKV}) (h×3h)列并行(W_O) (h×h)行并行1 AllReduce

两者遵循完全相同的模式:第一个矩阵列切分 → 各卡独立计算 → 第二个矩阵行切分 → AllReduce 求和


通信原语 (f) 和 (g)

论文定义了两个简洁的通信原语来封装张量并行的通信逻辑。

定义

f:前向 = Identity(恒等),反向 = AllReduceg:前向 = AllReduce,反向 = Identity(恒等)

含义

  • (f) 放在层的输入端:前向传播时,每张卡直接使用完整的输入(因为上一层的 (g) 已经 AllReduce 过了);反向传播时,需要 AllReduce 收集梯度
  • (g) 放在层的输出端:前向传播时,AllReduce 汇总各卡的部分结果;反向传播时,梯度直接分发给各卡

在 Transformer 层中的应用

一个完整 Transformer 层的通信模式:

输入 X (所有 GPU 相同)

    f (前向: Identity, 反向: AllReduce)

┌───┴───────────────────────────┐
│ Self-Attention (列并行 → 行并行) │
└───┬───────────────────────────┘

    g (前向: AllReduce, 反向: Identity)

    + (残差连接)

    LayerNorm

    f (前向: Identity, 反向: AllReduce)

┌───┴───────────────────────────┐
│ MLP (列并行 → 行并行)           │
└───┬───────────────────────────┘

    g (前向: AllReduce, 反向: Identity)

    + (残差连接)

    LayerNorm

输出 X' (所有 GPU 相同)

总通信: 前向 2 次 AllReduce (两个 g)
       反向 2 次 AllReduce (两个 f)

PyTorch 实现

论文的 (f) 和 (g) 在 PyTorch 中只需几行自定义 autograd.Function

python
class f(torch.autograd.Function):
    """输入端通信原语: 前向=Identity, 反向=AllReduce"""
    @staticmethod
    def forward(ctx, x):
        return x  # 前向不通信

    @staticmethod
    def backward(ctx, grad):
        # 反向 AllReduce: 收集所有卡的梯度
        torch.distributed.all_reduce(grad)
        return grad


class g(torch.autograd.Function):
    """输出端通信原语: 前向=AllReduce, 反向=Identity"""
    @staticmethod
    def forward(ctx, x):
        # 前向 AllReduce: 汇总所有卡的部分结果
        torch.distributed.all_reduce(x)
        return x

    @staticmethod
    def backward(ctx, grad):
        return grad  # 反向不通信

优雅的设计

通过 (f) 和 (g),Megatron-LM 将所有通信逻辑封装成两个 即插即用 的原语。在现有 PyTorch 模型代码中,只需在每个并行子模块的输入端插入 f、输出端插入 g,就完成了张量并行的改造——无需修改任何计算逻辑。


Embedding 层的并行化

输入 Embedding

输入 Embedding 矩阵 (E \in \mathbb{R}^{V \times h})((V) 为词表大小,(h) 为隐藏维度)可能非常大。以 GPT-2 为例:(V = 50257, h = 3072),则 (E) 有约 154M 参数。

Megatron-LM 按行切分 Embedding 矩阵(即按词表维度切分):

E=[E1E2Et],EiR(V/t)×h

每张 GPU 负责词表中 (V/t) 个 token 的 Embedding 查找。对于不在本卡词表范围内的 token,输出全零向量。最后通过 AllReduce 求和得到完整结果。

Embedding 并行 (t=2, V=50000):

输入 token_ids: [103, 25001, 7, 30000]

GPU 0 (词表 0-24999):                GPU 1 (词表 25000-49999):
  103    → E[103]                      103    → [0, 0, ..., 0]
  25001  → [0, 0, ..., 0]             25001  → E[25001]
  7      → E[7]                        7      → [0, 0, ..., 0]
  30000  → [0, 0, ..., 0]             30000  → E[30000]

         AllReduce (求和)

  结果: [E[103], E[25001], E[7], E[30000]]  ← 完整的 Embedding 输出

输出 Embedding(语言模型头)

语言模型的输出层需要将隐藏状态投影到词表维度:(\text{logits} = X W_{\text{out}}^\top),其中 (W_{\text{out}} \in \mathbb{R}^{V \times h})。

为了避免内存浪费,Megatron-LM 与输入 Embedding 共享权重(tied weights),即 (W_{\text{out}} = E)。输出层采用同样的切分方式,每张卡计算 (V/t) 个 token 的 logits。


Cross-Entropy 损失的并行化

问题:Softmax 需要完整的 logits

语言模型的 Cross-Entropy 损失需要对整个词表做 Softmax:

loss=logezyj=1Vezj

其中 (z_j) 是第 (j) 个 token 的 logit。Softmax 的分母是 对所有 (V) 个 logit 求和,这要求所有 logit 在同一个设备上。

如果先 AllGather 所有 logit 到每张卡((b \times s \times V) 的张量),当 (V) 很大时(如 50K),这个张量会非常大——比模型参数还大

Megatron-LM 的解决方案

论文的做法是 在分布式下直接计算 Cross-Entropy,避免 AllGather 完整 logit:

Step 1:每张卡在本地 logit 上计算局部最大值和局部 exp 之和

mi=maxjlocalzj,si=jlocalezjmi

Step 2:AllReduce 求全局最大值和全局 exp 之和

m=maximi,s=isiemim

Step 3:每张卡在本地计算梯度,只对自己负责的 logit 部分计算

这种方式只需要 AllReduce 2 个标量(最大值和 exp 和),比 AllGather 整个 logit 张量高效得多。

Cross-Entropy 并行计算:

GPU 0 (logits 0-24999)         GPU 1 (logits 25000-49999)
  local_max₀ = max(z₀..z₂₄₉₉₉)    local_max₁ = max(z₂₅₀₀₀..z₄₉₉₉₉)
  local_sum₀ = Σ exp(z-local_max₀)  local_sum₁ = Σ exp(z-local_max₁)
         │                                  │
         └──── AllReduce max,sum ───────────┘
         │                                  │
  global_max = max(local_max₀, local_max₁)
  global_sum = sum₀·e^{m₀-m} + sum₁·e^{m₁-m}
         │                                  │
  loss = -z_target + global_max + log(global_sum)
  ↓ 各卡只对自己的 logit 分片计算梯度

通信量: 2 个标量 (max 和 sum), 而非 b×s×V 的完整 logit!

与 FlashAttention 的联系

这种"在线 Softmax"的思路与 FlashAttention 中的在线 Softmax 技巧如出一辙——都是将全局的 max 和 sum 通过增量更新的方式分布式计算,避免实体化完整的 Softmax 输入。


完整 Transformer 的张量并行

将所有模块组合起来,我们可以看到 Megatron-LM 如何并行化一个完整的 GPT 模型:

GPT 模型的完整张量并行 (t=2 张 GPU):

                    GPU 0              GPU 1
                    ┌────┐             ┌────┐
  Input Token IDs ──→ E₁ │             │ E₂ ←── Input Token IDs
                    └──┬─┘             └──┬─┘
                       └── AllReduce ─────┘

                    ┌─────────┴─────────┐
                    │   + Position Emb  │
                    │   Dropout         │
                    └─────────┬─────────┘

         ┌────────────────────┴────────────────────┐
         │              × L 层 Transformer           │
         │                                          │
         │  ┌──── f ────┐          ┌──── f ────┐   │
         │  │ Attn 头0-7 │          │ Attn 头8-15│   │
         │  │ → W_O₁    │          │ → W_O₂    │   │
         │  └──── g ────┘          └──── g ────┘   │
         │        │ AllReduce            │          │
         │        └──────────┬───────────┘          │
         │                   │ + Residual + LN      │
         │  ┌──── f ────┐   │     ┌──── f ────┐   │
         │  │ MLP W₁₁   │   │     │ MLP W₁₂   │   │
         │  │ GeLU       │   │     │ GeLU       │   │
         │  │ W₂₁       │   │     │ W₂₂       │   │
         │  └──── g ────┘   │     └──── g ────┘   │
         │        │ AllReduce│           │          │
         │        └──────────┬───────────┘          │
         │                   │ + Residual + LN      │
         └───────────────────┴──────────────────────┘

                    ┌─────────┴─────────┐
                    │  Final LayerNorm  │
                    └─────────┬─────────┘

                    ┌──── 分片 logit 计算 ────┐
                    │ logit₁   │   logit₂     │
                    └──── 并行 CrossEntropy ───┘

                            Loss

每层通信量汇总

模块前向 AllReduce反向 AllReduce数据形状
Self-Attention1 次1 次(b \times s \times h)
MLP1 次1 次(b \times s \times h)
每层合计2 次2 次

对于 (L) 层的模型,总共 (4L) 次 AllReduce。每次 AllReduce 的数据量为 (b \times s \times h) 个元素。


通信效率的深入分析

AllReduce 的成本模型

对于 (t) 张 GPU 之间的 AllReduce,使用 Ring AllReduce 算法的通信量为:

通信量=2(t1)tn2n (当 t 较大时)

其中 (n) 是要归约的元素数。

Megatron-LM 的总通信量

每步训练的前向传播通信量:

Cforward=2L×2bsh=4Lbsh 元素

反向传播类似,总通信量约 (8Lbsh) 元素。

与数据并行的对比

并行策略通信内容通信量通信频率
数据并行梯度(模型参数)(2 \times 12Lh^2)每步 1 次
张量并行激活值(8Lbsh)每步 (4L) 次

关键区别:

  • 数据并行:通信量与 模型参数量 成正比,通信次数少(每步 1 次 AllReduce)
  • 张量并行:通信量与 激活值大小 成正比,通信次数多(每层 4 次 AllReduce)

当 batch size (b) 和序列长度 (s) 较大时,张量并行的通信量可能超过数据并行。这就是为什么 张量并行适合节点内(NVLink 高带宽)数据并行适合跨节点(以太网/InfiniBand)

V100 DGX-2 节点内 8 张 GPU 通过 NVSwitch 互连,提供每对 GPU 之间 300 GB/s 的双向带宽,比 PCIe 3.0 (32 GB/s) 快约 10 倍。

通信带宽与并行策略的匹配:

节点内 (NVLink/NVSwitch):
  ┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐
  │GPU0 │GPU1 │GPU2 │GPU3 │GPU4 │GPU5 │GPU6 │GPU7 │
  │←──────── 300 GB/s 全互连 (NVSwitch) ────────→│
  └─────┴─────┴─────┴─────┴─────┴─────┴─────┴─────┘
  → 张量并行 (TP): 频繁通信, 需要高带宽 ✓

跨节点 (InfiniBand):
  ┌──────────┐     100 Gb/s     ┌──────────┐
  │  节点 0   │ ←───────────→  │  节点 1   │
  └──────────┘                  └──────────┘
  → 数据并行 (DP): 通信少, 带宽要求低 ✓

混合精度训练

Megatron-LM 使用 混合精度训练 来加速计算和减少内存:

策略

操作精度原因
前向/反向矩阵乘法fp16利用 V100 Tensor Core (125 TFLOPS)
权重存储fp16 (计算用) + fp32 (主副本)fp16 计算快,fp32 保证更新精度
激活值fp16减少内存和通信带宽
损失缩放(Loss Scaling)动态防止 fp16 下溢
AllReducefp16减少通信量
优化器状态fp32数值稳定性

动态 Loss Scaling

fp16 的可表示范围有限(最小正数约 (6 \times 10^{-8})),小梯度值可能被截断为 0(下溢)。动态 Loss Scaling 的策略:

  1. 用一个 scale factor (S) 乘以 loss(初始值通常为 (2^{16}))
  2. 反向传播的梯度也被放大 (S) 倍(链式法则自动实现)
  3. 优化器更新前将梯度除以 (S)
  4. 如果出现 inf/NaN,跳过这步更新并减小 (S)
  5. 如果连续多步没有 inf/NaN,增大 (S)

激活值内存优化

激活值检查点(Activation Checkpointing)

对于大模型,激活值的内存消耗可能远超参数本身。Megatron-LM 使用选择性的 激活值检查点 策略:

  • 只在每个 Transformer 层的 输入处 保存检查点
  • 反向传播时从检查点重新计算该层内部的激活值
  • 牺牲约 33% 的计算时间(该层要算两遍),换来大幅内存节省

张量并行对激活值内存的影响

张量并行不仅切分了参数,还 自然地切分了部分激活值

激活值切分情况大小(每卡)
Attention 的 QKV切分到各卡(b \times s \times (h/t))
MLP 中间层切分到各卡(b \times s \times (4h/t))
层间激活值所有卡相同(b \times s \times h)

MLP 中间层((4h) 维度)是激活值的大头,它被 (t) 等分后每卡只存 (1/t),显著减少了激活值内存。


实验结果与关键发现

实验设置

  • 硬件:最多 32 台 DGX-2H(每台 16 张 V100 32GB),共 512 张 GPU
  • 互连:节点内 NVSwitch (300 GB/s),节点间 InfiniBand (8× 100 Gb/s)
  • 模型:GPT-2 和 BERT 架构,参数量从 355M 到 8.3B

模型配置

模型参数量层数 (L)隐藏维度 (h)注意力头数TP 度
GPT-2 355M355M241024161
GPT-2 2.5B2.5B541920242
GPT-2 4.2B4.2B722304244
GPT-2 8.3B8.3B723072248
BERT 3.9B3.9B482560328

扩展效率

单节点内的强扩展(Strong Scaling)

GPU 数量 (TP 度)8.3B 模型 TFLOPS/GPU相对效率
1OOM
237.5
435.695%
832.386%

在 8 卡 NVLink 互连下,效率保持在 86% 以上。从 2 卡到 8 卡仅下降 14%,说明 NVLink 带宽足以支撑张量并行的通信需求。

多节点的弱扩展(Weak Scaling)

节点数 × GPU数模型参数量总 TFLOPS效率
1 × 8 (8)1.2B236100% (基线)
2 × 16 (32)2.5B45095%
4 × 32 (64)4.2B86091%
8 × 64 (128)8.3B159084%
32 × 512 (512)8.3B553076%

512 GPU 的弱扩展效率达到 76%,这在 2019 年是非常出色的。

语言模型质量

GPT-2 8.3B 在 WikiText-103 上的困惑度(Perplexity)

模型参数量PPL
GPT-2 (OpenAI)1.5B17.48
Megatron GPT-28.3B10.81

参数量从 1.5B 提升到 8.3B,困惑度从 17.48 下降到 10.81——证明了更大的模型确实带来更好的语言建模能力。

BERT 3.9B 在下游任务上

任务RoBERTa-Large (355M)Megatron BERT 3.9B
RACE-h83.2%89.5%
MNLI90.2%91.4%
QQP92.2%92.6%

在多项 NLU 基准上超越当时的 SOTA(RoBERTa)。

效率随 TP 度的下降

虽然 86% 的单节点效率很好,但值得注意的是:当 TP 扩展到 跨节点 时,效率会急剧下降(因为跨节点带宽远低于 NVLink)。这就是为什么 Megatron-LM 强调 TP 只在节点内使用,跨节点用数据并行。这个经验法则后来成为大模型训练的标准实践。


与 ZeRO 的互补关系

Megatron-LM 的张量并行和 ZeRO 的数据并行是 正交互补 的两种技术:

维度Megatron-LM (TP)ZeRO (DP)
切分对象层内的参数矩阵跨所有层的模型状态
通信内容激活值((bsh))梯度/参数((\Psi))
通信频率每层 4 次每步 1-2 次
通信带宽需求(需要 NVLink)(AllGather/RS 效率高)
适合场景节点内(高带宽)跨节点(低带宽也可)
解决的问题单层参数超过单卡总模型状态超过单卡

3D 并行:最佳实践

在后续工作(Megatron-LM v2, 2021)中,NVIDIA 将 TP + DP + PP 组合成 3D 并行,这成为 GPT-3、PaLM 等超大模型训练的标准范式:

3D 并行架构示例 (64 GPUs):

张量并行 (TP=8): 节点内 NVLink
  ┌─GPU0─GPU1─GPU2─GPU3─GPU4─GPU5─GPU6─GPU7─┐ ← 1 个 TP 组
  └─────────────────────────────────────────┘

流水线并行 (PP=4): 跨节点
  TP组₀ (层1-6) → TP组₁ (层7-12) → TP组₂ (层13-18) → TP组₃ (层19-24)
  节点0           节点1             节点2             节点3

数据并行 (DP=2): 跨 PP 阶段的副本
  TP组₀ᵃ → TP组₁ᵃ → TP组₂ᵃ → TP组₃ᵃ   ← 数据分片 A
  TP组₀ᵇ → TP组₁ᵇ → TP组₂ᵇ → TP组₃ᵇ   ← 数据分片 B

总 GPU = TP × PP × DP = 8 × 4 × 2 = 64

每种并行度的选择经验:

  • TP:等于节点内 GPU 数(通常 8),利用 NVLink
  • PP:根据模型层数和节点数调整,1-8 之间
  • DP:用剩余的 GPU 数做数据并行,越大越好(吞吐越高)

关键技术点汇总

下面的代码模拟 Megatron-LM 张量并行的通信量和内存分布:

C++Megatron-LM 张量并行通信量与内存分析

总结与启示

Megatron-LM 的核心贡献

  1. 精巧的张量切分方案:列并行 + 行并行的组合,使 Transformer 每层只需 2 次 AllReduce,通信最小化

  2. (f) 和 (g) 通信原语:将张量并行的通信逻辑封装为即插即用的算子,不侵入模型计算代码

  3. 端到端的并行化:不仅是 Attention 和 MLP,还包括 Embedding 层和 Cross-Entropy 损失的并行化,消除所有冗余

  4. 工程落地:证明了在 NVLink 互连下,张量并行可以达到 86% 的单节点效率,是大模型训练的核心技术

深层设计原则

Megatron-LM 的成功体现了两个重要的系统设计原则:

原则 1:利用计算图的结构特性

Megatron-LM 没有设计通用的张量并行框架,而是 针对 Transformer 的具体结构 手工设计切分方案。正是因为了解 MLP 的"扩展-收缩"模式和多头注意力的"独立头"结构,才找到了只需 2 次 AllReduce 的最优方案。

原则 2:匹配通信模式与硬件拓扑

张量并行的高频通信天然匹配 NVLink 的高带宽低延迟特性。论文明确建议 TP 只在节点内使用——这不是妥协,而是 将正确的并行策略放在正确的硬件层级

从 Megatron-LM 到 ZeRO:互补的全景

两条技术路线解决了大模型训练的两个不同维度:

  • Megatron-LM(张量并行):解决 "一层放不下单卡" 的问题 → 切分层内参数
  • ZeRO(数据并行优化):解决 "所有层的状态放不下单卡" 的问题 → 切分跨层状态

两者组合(TP + ZeRO-DP),再加上流水线并行(PP),就形成了 3D 并行——当今训练 GPT-4、PaLM、LLaMA 等万亿参数模型的标准方法。


参考文献

  1. Shoeybi, M., Patwary, M., Puri, R., LeGresley, P., Casper, J., & Catanzaro, B. (2019). Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism. arXiv:1909.08053

  2. Narayanan, D., Shoeybi, M., Casper, J., LeGresley, P., Patwary, M., Korthikanti, V., ... & Catanzaro, B. (2021). Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM. SC 2021. arXiv:2104.04473

  3. Rajbhandari, S., Rasley, J., Ruwase, O., & He, Y. (2020). ZeRO: Memory Optimizations Toward Training Trillion Parameter Models. SC 2020. arXiv:1910.02054

  4. Huang, Y., Cheng, Y., Bapna, A., Firat, O., Chen, D., Chen, M., ... & Wu, Y. (2019). GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism. NeurIPS 2019. arXiv:1811.06965

  5. Shazeer, N., Cheng, Y., Parmar, N., Tran, D., Vaswani, A., Koanantakool, P., ... & Hawkins, J. (2018). Mesh-TensorFlow: Deep Learning for Supercomputers. NeurIPS 2018. arXiv:1811.02084

  6. Smith, S., et al. (2022). Using DeepSpeed and Megatron to Train Megatron-Turing NLG 530B, A Large-Scale Generative Language Model. arXiv:2201.11990

Released under the MIT License.