大模型训练
概述
训练拥有数十亿乃至万亿参数的大语言模型,需要跨越数百甚至数千GPU的分布式训练技术。本章系统介绍数据并行、模型并行、流水线并行等核心策略,以及混合精度训练、梯度检查点等关键工程技术。
graph TD
A[大模型训练] --> B[数据并行]
A --> C[模型并行]
A --> D[流水线并行]
A --> E[混合精度]
A --> F[内存优化]
B --> B1[DDP]
B --> B2[FSDP/ZeRO]
C --> C1[张量并行]
C --> C2[序列并行]
D --> D1[GPipe]
D --> D2[1F1B]
D --> D3[Interleaved]
E --> E1[FP16]
E --> E2[BF16]
E --> E3[FP8]
F --> F1[梯度检查点]
F --> F2[Offloading]
F --> F3[激活压缩]
1. GPU内存分析
1.1 训练时的内存组成
对于一个参数量为 \(\Phi\) 的模型(FP16训练):
| 组成部分 | 大小 | 说明 |
|---|---|---|
| 模型参数 | \(2\Phi\) | FP16: 2字节/参数 |
| 梯度 | \(2\Phi\) | FP16 |
| 优化器状态(Adam) | \(12\Phi\) | FP32参数副本(4) + 一阶动量(4) + 二阶动量(4) |
| 总计 | \(16\Phi\) |
示例:7B模型 → \(16 \times 7 = 112\) GB(仅参数+优化器)
额外还有激活值(与batch size和序列长度相关)、临时缓冲区等。
1.2 通信带宽需求
分布式训练的通信模式:
| 操作 | 通信量 | 场景 |
|---|---|---|
| AllReduce | \(2\Phi\) | 数据并行梯度同步 |
| AllGather | \(\Phi\) | FSDP参数收集 |
| ReduceScatter | \(\Phi\) | FSDP梯度分发 |
| P2P Send/Recv | 激活值 | 流水线并行 |
2. 数据并行(Data Parallelism)
2.1 分布式数据并行(DDP)
PyTorch DDP:每个GPU持有模型完整副本,数据不同。
流程:
- 每个GPU处理不同的mini-batch
- 前向传播独立进行
- 反向传播时AllReduce同步梯度
- 每个GPU用相同梯度更新参数
通信:每步一次AllReduce,通信量 \(2\Phi\)
DDP优化:梯度桶(Gradient Bucketing)— 将梯度分桶,反向传播时重叠通信与计算。
2.2 ZeRO(Zero Redundancy Optimizer)
核心洞察:DDP中每个GPU存储完整的优化器状态、梯度和参数 → 大量冗余。
ZeRO三个阶段:
| 阶段 | 分片内容 | 每GPU内存 | 通信量 |
|---|---|---|---|
| Stage 1 | 优化器状态 | \(4\Phi + 12\Phi/N\) | \(2\Phi\) |
| Stage 2 | + 梯度 | \(2\Phi + 14\Phi/N\) | \(2\Phi\) |
| Stage 3 | + 参数 | \(16\Phi/N\) | \(3\Phi\) |
其中 \(N\) 是GPU数量。
Stage 3 内存需求随GPU数量线性下降,但通信量增加50%。
2.3 FSDP(Fully Sharded Data Parallel)
PyTorch的FSDP是ZeRO Stage 3的原生实现。
核心操作:
前向传播:AllGather收集完整参数 → 计算 → 释放非本地参数
反向传播:AllGather收集参数 → 计算梯度 → ReduceScatter分发梯度
FSDP配置要点:
sharding_strategy:FULL_SHARD(Stage 3)、SHARD_GRAD_OP(Stage 2)auto_wrap_policy:按Transformer层包装mixed_precision:混合精度策略cpu_offload:将参数卸载到CPU
3. 模型并行(Model Parallelism)
3.1 张量并行(Tensor Parallelism)
Megatron-LM的张量并行将单个算子分片到多个GPU。
MLP层的张量并行:
将 \(A\) 按列分片,\(B\) 按行分片:
- GPU 1: \(Y_1 = \text{GeLU}(XA_1) B_1\)
- GPU 2: \(Y_2 = \text{GeLU}(XA_2) B_2\)
- AllReduce: \(Y = Y_1 + Y_2\)
注意力层的张量并行:
各注意力头天然可以分配到不同GPU。
通信:每个Transformer层需要2次AllReduce(MLP和注意力各一次)。
3.2 序列并行(Sequence Parallelism)
思想:在张量并行的基础上,对LayerNorm和Dropout沿序列维度分片。
- 这些操作不涉及参数,但激活值很大
- 将AllReduce替换为AllGather + ReduceScatter
- 进一步减少激活值内存
3.3 上下文并行(Context Parallelism)
处理超长序列时,沿序列维度分片注意力计算:
- 每个GPU处理序列的一部分
- 通过Ring Attention或其他方式交换KV
4. 流水线并行(Pipeline Parallelism)
4.1 基本思想
将模型按层分到不同GPU:
- GPU 0: 层 1-8
- GPU 1: 层 9-16
- GPU 2: 层 17-24
- GPU 3: 层 25-32
4.2 GPipe
将mini-batch分成多个micro-batch,流水线式执行:
时间 →
GPU 0: |F1|F2|F3|F4| | | | |B4|B3|B2|B1|
GPU 1: |F1|F2|F3|F4| | | |B4|B3|B2|B1|
GPU 2: |F1|F2|F3|F4| | |B4|B3|B2|B1|
GPU 3: |F1|F2|F3|F4| |B4|B3|B2|B1|
气泡率(Pipeline Bubble):
其中 \(p\) 是流水线阶段数,\(m\) 是micro-batch数。\(m \gg p\) 时气泡率降低。
4.3 1F1B(One Forward One Backward)
交替执行前向和反向,减少内存峰值:
GPU 0: |F1|F2|F3|F4|B1|B2|B3|B4|
GPU 1: |F1|F2|F3|B1|F4|B2|B3|B4|
GPU 2: |F1|F2|B1|F3|B2|F4|B3|B4|
GPU 3: |F1|B1|F2|B2|F3|B3|F4|B4|
优点:同一时刻只需缓存较少的激活值。
4.4 Interleaved Pipeline
将模型层交错分配:
- GPU 0: 层 1-2, 9-10, 17-18, 25-26
- 更小的气泡率,但更多的通信
5. 3D并行(Megatron-LM)
结合三种并行:
graph TD
A[3D并行] --> B[数据并行 DP]
A --> C[张量并行 TP]
A --> D[流水线并行 PP]
B --> B1[跨节点]
C --> C1[节点内 NVLink]
D --> D1[跨节点]
style C1 fill:#f9f,stroke:#333
典型配置(以LLaMA 65B为例):
- TP=8(单节点8卡NVLink)
- PP=8(8个节点)
- DP=16(剩余GPU用于数据并行)
- 总计:8 × 8 × 16 = 1024 GPU
放置原则:
- 张量并行:需要高带宽 → 节点内(NVLink 600GB/s)
- 流水线并行:通信较少 → 可跨节点
- 数据并行:通信可与计算重叠 → 跨节点
6. 混合精度训练
6.1 数据类型对比
| 类型 | 位数 | 指数位 | 尾数位 | 范围 | 精度 |
|---|---|---|---|---|---|
| FP32 | 32 | 8 | 23 | \(\pm 3.4 \times 10^{38}\) | 高 |
| FP16 | 16 | 5 | 10 | \(\pm 65504\) | 低 |
| BF16 | 16 | 8 | 7 | \(\pm 3.4 \times 10^{38}\) | 中低 |
| FP8 E4M3 | 8 | 4 | 3 | \(\pm 448\) | 很低 |
| FP8 E5M2 | 8 | 5 | 2 | \(\pm 57344\) | 极低 |
6.2 混合精度训练策略
AMP(Automatic Mixed Precision):
- 前向传播:FP16/BF16(节省内存和计算)
- 损失缩放:FP16需要Loss Scaling防止下溢
- 主权重:FP32(用于参数更新)
- 优化器状态:FP32
6.3 BF16 vs FP16
| 特性 | FP16 | BF16 |
|---|---|---|
| 数值范围 | 小(需要loss scaling) | 与FP32相同 |
| 精度 | 较高 | 较低 |
| Loss Scaling | 必须 | 通常不需要 |
| 硬件支持 | A100+ | A100+, H100 |
| 推荐 | 图像任务 | LLM训练首选 |
7. 梯度检查点(Gradient Checkpointing)
7.1 原理
问题:反向传播需要所有层的激活值 → 内存与层数成正比。
解决:只保存部分层的激活值,反向传播时重新计算。
时间-内存权衡:
| 策略 | 内存 | 计算 |
|---|---|---|
| 全部保存 | \(O(L)\) | \(1\times\) |
| 全部重算 | \(O(1)\) | \(2\times\) |
| 检查点(每\(\sqrt{L}\)层) | \(O(\sqrt{L})\) | \(\sim 1.33\times\) |
7.2 选择性检查点
不是所有层都需要重算,优先检查点:
- 注意力层(激活值大:\(O(n^2)\))
- 保留线性层激活值(重算代价高但内存占用小)
8. DeepSpeed
8.1 DeepSpeed生态
| 组件 | 功能 |
|---|---|
| ZeRO-1/2/3 | 优化器/梯度/参数分片 |
| ZeRO-Offload | CPU/NVMe offloading |
| ZeRO-Infinity | 支持万亿参数 |
| DeepSpeed-MoE | MoE训练支持 |
| DeepSpeed-Chat | RLHF训练框架 |
8.2 配置示例
{
"bf16": {"enabled": true},
"zero_optimization": {
"stage": 3,
"offload_param": {"device": "cpu"},
"offload_optimizer": {"device": "cpu"},
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_bucket_size": 5e8
},
"gradient_accumulation_steps": 4,
"gradient_clipping": 1.0,
"train_micro_batch_size_per_gpu": 2
}
9. NCCL与通信
9.1 集合通信原语
| 原语 | 描述 | 用途 |
|---|---|---|
| AllReduce | 所有设备求和并广播 | DDP梯度同步 |
| AllGather | 收集所有设备的数据 | FSDP参数收集 |
| ReduceScatter | 求和并分片分发 | FSDP梯度分发 |
| Broadcast | 一对多广播 | 参数初始化 |
| P2P Send/Recv | 点对点通信 | 流水线并行 |
9.2 通信拓扑
- Ring AllReduce:\(2(N-1)/N \times \Phi\) 通信量
- Tree AllReduce:\(2\log_2(N) \times \Phi\) 通信量
- NVLink:节点内高带宽(600 GB/s per GPU, H100)
- InfiniBand:节点间高带宽(400 Gbps HDR)
10. 训练稳定性
10.1 常见问题
| 问题 | 症状 | 解决方案 |
|---|---|---|
| 损失发散 | loss突然飙升 | 降低学习率、梯度裁剪 |
| 损失尖刺 | 偶发的loss spike | 跳过异常batch、数据清洗 |
| 梯度消失/爆炸 | 训练停滞 | Pre-LN、梯度裁剪 |
| 数值溢出 | NaN/Inf | BF16、loss scaling |
10.2 训练超参数建议
- 学习率:cosine schedule, warmup 2000步
- 梯度裁剪:max_norm = 1.0
- 权重衰减:0.1
- Batch Size:从小到大逐步增加
- Adam \(\beta\):\(\beta_1=0.9, \beta_2=0.95\)
11. 总结
graph LR
A[1B模型] -->|单GPU + AMP| B[训练]
C[7B模型] -->|FSDP/ZeRO-3 + BF16| D[训练]
E[70B模型] -->|3D并行 TP+PP+DP| F[训练]
G[>400B模型] -->|3D并行 + Expert并行| H[训练]
| 模型规模 | 推荐方案 | GPU需求 |
|---|---|---|
| <1B | 单GPU + AMP | 1× A100 80GB |
| 1-7B | FSDP Stage 2/3 | 4-8× A100 |
| 7-70B | TP + FSDP | 16-128× A100/H100 |
| 70B-400B | 3D并行 | 256-2048× H100 |
| >400B | 3D并行 + Expert并行 | 2048+× H100 |
参考文献
- Rajbhandari et al., "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models," SC 2020
- Shoeybi et al., "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism," 2019
- Narayanan et al., "Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM," SC 2021
- Micikevicius et al., "Mixed Precision Training," ICLR 2018
- Zhao et al., "PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel," VLDB 2023