Large-Scale Model Training
Overview
Training large language models with billions or even trillions of parameters requires distributed training techniques that span hundreds to thousands of GPUs. This chapter systematically covers data parallelism, model parallelism, and pipeline parallelism, as well as key engineering techniques such as mixed precision training and gradient checkpointing.
```mermaid
graph TD
A[Large Model Training] --> B[Data Parallelism]
A --> C[Model Parallelism]
A --> D[Pipeline Parallelism]
A --> E[Mixed Precision]
A --> F[Memory Optimization]
B --> B1[DDP]
B --> B2[FSDP/ZeRO]
C --> C1[Tensor Parallelism]
C --> C2[Sequence Parallelism]
D --> D1[GPipe]
D --> D2[1F1B]
D --> D3[Interleaved]
E --> E1[FP16]
E --> E2[BF16]
E --> E3[FP8]
F --> F1[Gradient Checkpointing]
F --> F2[Offloading]
F --> F3[Activation Compression]
```
1. GPU Memory Analysis
1.1 Memory Breakdown During Training
For a model with \(\Phi\) parameters (FP16 training):
| Component | Size | Description |
|---|---|---|
| Model parameters | \(2\Phi\) | FP16: 2 bytes/param |
| Gradients | \(2\Phi\) | FP16 |
| Optimizer states (Adam) | \(12\Phi\) | FP32 param copy(4) + first moment(4) + second moment(4) |
| Total | \(16\Phi\) | Parameters + gradients + optimizer states |
Example: 7B model → \(16 \times 7 = 112\) GB (parameters + optimizer only)
Additional memory is needed for activations (dependent on batch size and sequence length), temporary buffers, etc.
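To make the \(16\Phi\) rule concrete, here is a small illustrative Python helper (our own, not from any library) that reproduces the breakdown in the table:

```python
def training_memory_gb(num_params: float) -> dict:
    """Per-component training memory (GB) for mixed-precision Adam,
    following the 2 + 2 + 12 bytes/param breakdown above (activations excluded)."""
    GB = 1e9
    return {
        "fp16_params": 2 * num_params / GB,
        "fp16_grads": 2 * num_params / GB,
        "fp32_optimizer_states": 12 * num_params / GB,
        "total": 16 * num_params / GB,
    }

print(training_memory_gb(7e9))  # 7B model -> total 112.0 GB, as in the example above
```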
1.2 Communication Bandwidth Requirements
Communication patterns in distributed training:
| Operation | Communication Volume | Scenario |
|---|---|---|
| AllReduce | \(2\Phi\) | Data parallel gradient sync |
| AllGather | \(\Phi\) | FSDP parameter gathering |
| ReduceScatter | \(\Phi\) | FSDP gradient distribution |
| P2P Send/Recv | Activations | Pipeline parallelism |
2. Data Parallelism
2.1 Distributed Data Parallel (DDP)
PyTorch DDP: Each GPU holds a complete model replica with different data.
Workflow:
- Each GPU processes a different mini-batch
- Forward pass proceeds independently
- AllReduce synchronizes gradients during backward pass
- Each GPU updates parameters with identical gradients
Communication: One AllReduce per step, volume \(2\Phi\)
DDP Optimization: gradient bucketing groups gradients into buckets so that AllReduce communication overlaps with the backward computation.
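A minimal DDP training loop, as a sketch: it assumes a torchrun launch (which sets LOCAL_RANK) and treats build_model, loader, and loss_fn as placeholders for user code.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = DDP(build_model().cuda(local_rank), device_ids=[local_rank])
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for inputs, targets in loader:
    inputs, targets = inputs.cuda(local_rank), targets.cuda(local_rank)
    optimizer.zero_grad(set_to_none=True)
    loss = loss_fn(model(inputs), targets)   # forward runs independently on each rank
    loss.backward()                          # bucketed AllReduce overlaps with backward
    optimizer.step()                         # every rank applies the same averaged gradients
```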
2.2 ZeRO (Zero Redundancy Optimizer)
Core Insight: In DDP, each GPU stores complete optimizer states, gradients, and parameters → massive redundancy.
Three ZeRO Stages:
| Stage | Sharded Content | Per-GPU Memory | Communication |
|---|---|---|---|
| Stage 1 | Optimizer states | \(4\Phi + 12\Phi/N\) | \(2\Phi\) |
| Stage 2 | + Gradients | \(2\Phi + 14\Phi/N\) | \(2\Phi\) |
| Stage 3 | + Parameters | \(16\Phi/N\) | \(3\Phi\) |
where \(N\) is the number of GPUs.
Stage 3 memory scales inversely with GPU count, but communication increases by 50%.
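A small illustrative helper (our own, not part of DeepSpeed) that evaluates the per-GPU memory formulas from the table:

```python
def zero_memory_per_gpu_gb(num_params: float, n_gpus: int, stage: int) -> float:
    """Per-GPU memory (GB) for params + grads + optimizer states per ZeRO stage
    (activations excluded)."""
    GB = 1e9
    if stage == 1:    # shard optimizer states only
        bytes_per_gpu = 4 * num_params + 12 * num_params / n_gpus
    elif stage == 2:  # + shard gradients
        bytes_per_gpu = 2 * num_params + 14 * num_params / n_gpus
    elif stage == 3:  # + shard parameters
        bytes_per_gpu = 16 * num_params / n_gpus
    else:
        raise ValueError("stage must be 1, 2, or 3")
    return bytes_per_gpu / GB

# 7B model on 8 GPUs: Stage 1 ~38.5 GB, Stage 2 ~26.25 GB, Stage 3 ~14 GB per GPU
for s in (1, 2, 3):
    print(s, zero_memory_per_gpu_gb(7e9, 8, s))
```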
2.3 FSDP (Fully Sharded Data Parallel)
PyTorch FSDP is the native implementation of ZeRO Stage 3.
Core Operations:
Forward: AllGather full params → compute → release non-local params
Backward: AllGather params → compute gradients → ReduceScatter gradients
Key FSDP Configuration:
- sharding_strategy: FULL_SHARD (Stage 3) or SHARD_GRAD_OP (Stage 2)
- auto_wrap_policy: wrap by Transformer layers
- mixed_precision: mixed precision policy
- cpu_offload: offload parameters to CPU
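A hedged FSDP wrapping sketch using the options above; it assumes the process group is already initialized and uses MyTransformerBlock as a placeholder for the model's Transformer layer class.

```python
import functools
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    MixedPrecision,
    CPUOffload,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# MyTransformerBlock is a placeholder for the model's Transformer layer class.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={MyTransformerBlock},
)

model = FSDP(
    model,                                           # assumes model is already built
    sharding_strategy=ShardingStrategy.FULL_SHARD,   # ZeRO Stage 3; SHARD_GRAD_OP ~ Stage 2
    auto_wrap_policy=wrap_policy,                    # shard at Transformer-layer granularity
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
    cpu_offload=CPUOffload(offload_params=False),    # set True to offload params to CPU
    device_id=torch.cuda.current_device(),
)
```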
3. Model Parallelism
3.1 Tensor Parallelism
Megatron-LM tensor parallelism shards individual operators across multiple GPUs.
MLP Layer Tensor Parallelism:
For an MLP block \(Y = \text{GeLU}(XA)\,B\), shard \(A\) column-wise and \(B\) row-wise, so each GPU holds one shard of each (a code sketch follows this list):
- GPU 1: \(Y_1 = \text{GeLU}(XA_1) B_1\)
- GPU 2: \(Y_2 = \text{GeLU}(XA_2) B_2\)
- AllReduce: \(Y = Y_1 + Y_2\)
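A minimal two-GPU sketch of this sharding in PyTorch, illustrative only (not Megatron-LM's actual implementation); it assumes a world size of 2 and an initialized NCCL process group.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

class TwoWayParallelMLP(nn.Module):
    """Megatron-style MLP shard: A split column-wise, B split row-wise.
    Each rank holds A_i (d_model x d_ff/2) and B_i (d_ff/2 x d_model)."""
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        assert d_ff % 2 == 0
        self.A_i = nn.Linear(d_model, d_ff // 2, bias=False)  # column shard of A
        self.B_i = nn.Linear(d_ff // 2, d_model, bias=False)  # row shard of B

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y_i = self.B_i(F.gelu(self.A_i(x)))          # local partial result Y_i
        dist.all_reduce(y_i, op=dist.ReduceOp.SUM)   # Y = Y_1 + Y_2 (one AllReduce)
        return y_i
```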
Attention Layer Tensor Parallelism:
Attention heads are naturally distributable across GPUs.
Communication: 2 AllReduce operations per Transformer layer in the forward pass (one for the MLP block, one for attention), plus 2 more in the backward pass.
3.2 Sequence Parallelism
Idea: Building on tensor parallelism, shard LayerNorm and Dropout along the sequence dimension.
- These operations have no parameters but large activations
- Replace AllReduce with AllGather + ReduceScatter
- Further reduces activation memory
3.3 Context Parallelism
For ultra-long sequences, shard attention computation along the sequence dimension:
- Each GPU processes a portion of the sequence
- Exchange KV through Ring Attention or similar mechanisms
4. Pipeline Parallelism
4.1 Basic Idea
Partition model layers across GPUs:
- GPU 0: Layers 1-8
- GPU 1: Layers 9-16
- GPU 2: Layers 17-24
- GPU 3: Layers 25-32
4.2 GPipe
Split mini-batches into micro-batches for pipelined execution:
```
Time →
GPU 0: |F1|F2|F3|F4|  |  |  |  |  |  |B4|B3|B2|B1|
GPU 1:    |F1|F2|F3|F4|  |  |  |  |B4|B3|B2|B1|
GPU 2:       |F1|F2|F3|F4|  |  |B4|B3|B2|B1|
GPU 3:          |F1|F2|F3|F4|B4|B3|B2|B1|
```
Pipeline Bubble:
\[
\text{bubble fraction} = \frac{p-1}{m+p-1}
\]
where \(p\) is the number of pipeline stages and \(m\) is the number of micro-batches. The bubble shrinks when \(m \gg p\); for example, \(p=4\) and \(m=16\) give a bubble fraction of \(3/19 \approx 16\%\).
4.3 1F1B (One Forward One Backward)
Alternate forward and backward passes to reduce peak memory:
```
GPU 0: |F1|F2|F3|F4|B1|B2|B3|B4|
GPU 1:    |F1|F2|F3|B1|F4|B2|B3|B4|
GPU 2:       |F1|F2|B1|F3|B2|F4|B3|B4|
GPU 3:          |F1|B1|F2|B2|F3|B3|F4|B4|
```
Advantage: each stage keeps at most \(p\) micro-batches of activations in flight (instead of all \(m\), as in GPipe), lowering peak memory.
4.4 Interleaved Pipeline
Assign each GPU multiple non-contiguous chunks of layers (virtual pipeline stages):
- GPU 0: Layers 1-2, 9-10, 17-18, 25-26
- Smaller bubble ratio but more communication
5. 3D Parallelism (Megatron-LM)
Combining Three Parallelism Strategies:
```mermaid
graph TD
A[3D Parallelism] --> B[Data Parallel DP]
A --> C[Tensor Parallel TP]
A --> D[Pipeline Parallel PP]
B --> B1[Inter-node]
C --> C1[Intra-node NVLink]
D --> D1[Inter-node]
style C1 fill:#f9f,stroke:#333
```
Typical Configuration (e.g., LLaMA 65B):
- TP=8 (single node, 8 GPUs with NVLink)
- PP=8 (8 nodes)
- DP=16 (remaining GPUs for data parallelism)
- Total: 8 x 8 x 16 = 1024 GPUs
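A trivial sanity check for such a layout (an illustrative helper, not part of Megatron-LM):

```python
def data_parallel_degree(world_size: int, tp: int, pp: int) -> int:
    """DP degree implied by a 3D layout: world_size = TP x PP x DP."""
    assert world_size % (tp * pp) == 0, "TP x PP must divide the world size"
    return world_size // (tp * pp)

print(data_parallel_degree(1024, tp=8, pp=8))  # -> 16, matching the configuration above
```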
Placement Principles:
- Tensor parallelism: Requires high bandwidth → intra-node (NVLink: 600 GB/s per GPU on A100, 900 GB/s on H100)
- Pipeline parallelism: Less communication → can go inter-node
- Data parallelism: Communication overlaps with compute → inter-node
6. Mixed Precision Training
6.1 Data Type Comparison
| Type | Bits | Exponent | Mantissa | Range | Precision |
|---|---|---|---|---|---|
| FP32 | 32 | 8 | 23 | \(\pm 3.4 \times 10^{38}\) | High |
| FP16 | 16 | 5 | 10 | \(\pm 65504\) | Low |
| BF16 | 16 | 8 | 7 | \(\pm 3.4 \times 10^{38}\) | Medium-low |
| FP8 E4M3 | 8 | 4 | 3 | \(\pm 448\) | Very low |
| FP8 E5M2 | 8 | 5 | 2 | \(\pm 57344\) | Extremely low |
6.2 Mixed Precision Training Strategy
AMP (Automatic Mixed Precision):
- Forward pass: FP16/BF16 (saves memory and compute)
- Loss scaling: required with FP16 to prevent gradient underflow
- Master weights: FP32 (for parameter updates)
- Optimizer states: FP32
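A minimal FP16 AMP step with loss scaling, as a sketch; model, loader, optimizer, and loss_fn are placeholders. With BF16 the GradScaler can usually be dropped.

```python
import torch

scaler = torch.cuda.amp.GradScaler()          # needed for FP16; usually omitted with BF16

for inputs, targets in loader:                # loader/model/optimizer are placeholders
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = loss_fn(model(inputs), targets)   # forward math runs in FP16
    scaler.scale(loss).backward()             # scaled loss avoids gradient underflow
    scaler.unscale_(optimizer)                # unscale so clipping sees true gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)                    # skipped automatically on Inf/NaN gradients
    scaler.update()                           # adjusts the loss scale dynamically
```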
6.3 BF16 vs FP16
| Feature | FP16 | BF16 |
|---|---|---|
| Numerical range | Small (requires loss scaling) | Same as FP32 |
| Precision | Higher | Lower |
| Loss Scaling | Required | Usually not needed |
| Hardware support | V100 and later (Volta Tensor Cores) | A100/H100 (Ampere and later), TPUs |
| Recommendation | Vision tasks | Preferred for LLM training |
7. Gradient Checkpointing
7.1 Principle
Problem: Backward pass requires activations from all layers → memory proportional to number of layers.
Solution: Save activations at selected layers only; recompute during backward pass.
Time-Memory Trade-off:
| Strategy | Memory | Compute |
|---|---|---|
| Save all | \(O(L)\) | \(1\times\) |
| Recompute all | \(O(1)\) | \(2\times\) |
| Checkpoint (every \(\sqrt{L}\) layers) | \(O(\sqrt{L})\) | \(\sim 1.33\times\) |
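A minimal sketch of layer-wise checkpointing with torch.utils.checkpoint; blocks and x are placeholders for a stack of Transformer layers and its input.

```python
import torch
from torch.utils.checkpoint import checkpoint

def forward_with_checkpointing(blocks, x):
    """Store only block-boundary activations; everything inside each block
    is recomputed during the backward pass."""
    for block in blocks:
        x = checkpoint(block, x, use_reentrant=False)
    return x
```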
7.2 Selective Checkpointing
Not all activations are equally worth recomputing. Priority for recomputation:
- Attention activations (large, \(O(n^2)\) in sequence length, cheap to recompute)
- Keep linear (GEMM) layer activations (expensive to recompute but small memory footprint)
8. DeepSpeed
8.1 DeepSpeed Ecosystem
| Component | Function |
|---|---|
| ZeRO-1/2/3 | Optimizer/gradient/parameter sharding |
| ZeRO-Offload | Offload optimizer states/gradients to CPU |
| ZeRO-Infinity | CPU + NVMe offloading for trillion-parameter scale |
| DeepSpeed-MoE | MoE training support |
| DeepSpeed-Chat | RLHF training framework |
8.2 Configuration Example
```json
{
"bf16": {"enabled": true},
"zero_optimization": {
"stage": 3,
"offload_param": {"device": "cpu"},
"offload_optimizer": {"device": "cpu"},
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_bucket_size": 5e8
},
"gradient_accumulation_steps": 4,
"gradient_clipping": 1.0,
"train_micro_batch_size_per_gpu": 2
}
```
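A hedged sketch of driving training with this configuration; it assumes the JSON above is saved as ds_config.json and that model, loader, and loss_fn are defined elsewhere.

```python
import deepspeed

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,                        # model is a placeholder
    model_parameters=model.parameters(),
    config="ds_config.json",
)

for inputs, targets in loader:          # loader is a placeholder
    loss = loss_fn(model_engine(inputs), targets)
    model_engine.backward(loss)         # handles ZeRO communication and loss scaling
    model_engine.step()                 # optimizer step, gradient clipping, LR schedule
```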
9. NCCL and Communication
9.1 Collective Communication Primitives
| Primitive | Description | Usage |
|---|---|---|
| AllReduce | Sum across all devices and broadcast | DDP gradient sync |
| AllGather | Gather data from all devices | FSDP parameter gathering |
| ReduceScatter | Sum and distribute shards | FSDP gradient distribution |
| Broadcast | One-to-many broadcast | Parameter initialization |
| P2P Send/Recv | Point-to-point communication | Pipeline parallelism |
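These primitives are exposed in PyTorch through torch.distributed; a small illustrative sketch, assuming an initialized NCCL process group (e.g., via torchrun):

```python
import torch
import torch.distributed as dist

rank, world = dist.get_rank(), dist.get_world_size()

# AllReduce: every rank ends up with the sum over all ranks.
x = torch.full((4,), float(rank), device="cuda")
dist.all_reduce(x, op=dist.ReduceOp.SUM)

# AllGather: every rank collects one shard from each rank.
shard = torch.full((2,), float(rank), device="cuda")
gathered = [torch.empty_like(shard) for _ in range(world)]
dist.all_gather(gathered, shard)

# ReduceScatter: sums the per-rank input lists, each rank keeps one reduced shard.
inputs = [torch.ones(2, device="cuda") for _ in range(world)]
out = torch.empty(2, device="cuda")
dist.reduce_scatter(out, inputs, op=dist.ReduceOp.SUM)
```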
9.2 Communication Topology
- Ring AllReduce: \(2(N-1)/N \times \Phi\) communication
- Tree AllReduce: \(O(\log_2 N)\) latency steps; NCCL's double binary trees keep per-GPU volume close to \(2\Phi\)
- NVLink: High intra-node bandwidth (900 GB/s per GPU on H100, 600 GB/s on A100)
- InfiniBand: High inter-node bandwidth (e.g., 400 Gbps NDR)
10. Training Stability
10.1 Common Issues
| Issue | Symptom | Solution |
|---|---|---|
| Loss divergence | Loss rises and does not recover | Lower learning rate, gradient clipping |
| Loss spikes | Occasional transient spikes | Skip anomalous batches, data cleaning |
| Vanishing/exploding gradients | Training stalls | Pre-LN, gradient clipping |
| Numerical overflow | NaN/Inf | BF16, loss scaling |
10.2 Training Hyperparameter Recommendations
- Learning rate: Cosine schedule, warmup 2000 steps
- Gradient clipping: max_norm = 1.0
- Weight decay: 0.1
- Batch size: Gradually increase from small to large
- Adam \(\beta\): \(\beta_1=0.9, \beta_2=0.95\) (see the sketch below)
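A minimal optimizer/scheduler setup reflecting these recommendations (illustrative only; model, the base learning rate, and total_steps are placeholders):

```python
import math
import torch
from torch.optim.lr_scheduler import LambdaLR

optimizer = torch.optim.AdamW(
    model.parameters(),           # model is a placeholder
    lr=3e-4,                      # base LR is a placeholder value
    betas=(0.9, 0.95),
    weight_decay=0.1,
)

warmup_steps, total_steps = 2000, 100_000   # total_steps is a placeholder

def lr_lambda(step: int) -> float:
    if step < warmup_steps:                                   # linear warmup
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))         # cosine decay

scheduler = LambdaLR(optimizer, lr_lambda)

# Each step: clip gradients before the optimizer update.
# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```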
11. Summary
```mermaid
graph LR
A[1B Model] -->|Single GPU + AMP| B[Train]
C[7B Model] -->|FSDP/ZeRO-3 + BF16| D[Train]
E[70B Model] -->|3D Parallel TP+PP+DP| F[Train]
G[">400B Model"] -->|3D Parallel + Expert Parallel| H[Train]
```
| Model Scale | Recommended Approach | GPU Requirements |
|---|---|---|
| <1B | Single GPU + AMP | 1x A100 80GB |
| 1-7B | FSDP Stage 2/3 | 4-8x A100 |
| 7-70B | TP + FSDP | 16-128x A100/H100 |
| 70B-400B | 3D Parallelism | 256-2048x H100 |
| >400B | 3D Parallel + Expert Parallel | 2048+ H100 |
References
- Rajbhandari et al., "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models," SC 2020
- Shoeybi et al., "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism," 2019
- Narayanan et al., "Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM," SC 2021
- Micikevicius et al., "Mixed Precision Training," ICLR 2018
- Zhao et al., "PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel," VLDB 2023