
Large-Scale Model Training

Overview

Training large language models with billions or even trillions of parameters requires distributed training techniques spanning hundreds to thousands of GPUs. This chapter systematically covers data parallelism, model parallelism, and pipeline parallelism, as well as key engineering techniques such as mixed precision training and gradient checkpointing.

graph TD
    A[Large Model Training] --> B[Data Parallelism]
    A --> C[Model Parallelism]
    A --> D[Pipeline Parallelism]
    A --> E[Mixed Precision]
    A --> F[Memory Optimization]

    B --> B1[DDP]
    B --> B2[FSDP/ZeRO]

    C --> C1[Tensor Parallelism]
    C --> C2[Sequence Parallelism]

    D --> D1[GPipe]
    D --> D2[1F1B]
    D --> D3[Interleaved]

    E --> E1[FP16]
    E --> E2[BF16]
    E --> E3[FP8]

    F --> F1[Gradient Checkpointing]
    F --> F2[Offloading]
    F --> F3[Activation Compression]

1. GPU Memory Analysis

1.1 Memory Breakdown During Training

For a model with \(\Phi\) parameters (FP16 training):

| Component | Size | Description |
|---|---|---|
| Model parameters | \(2\Phi\) | FP16: 2 bytes/param |
| Gradients | \(2\Phi\) | FP16 |
| Optimizer states (Adam) | \(12\Phi\) | FP32 param copy (4) + first moment (4) + second moment (4) |
| Total | \(16\Phi\) | |

Example: 7B model → \(16 \times 7 = 112\) GB (parameters, gradients, and optimizer states only)

Additional memory is needed for activations (dependent on batch size and sequence length), temporary buffers, etc.
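
A back-of-the-envelope helper (hypothetical, just the arithmetic above) for estimating this static footprint, optionally divided across ZeRO/FSDP shards:

```python
# Static training-memory estimate: 16 bytes/param for mixed-precision Adam
# (2 FP16 params + 2 FP16 grads + 12 bytes of optimizer state).
# Activations and temporary buffers are NOT included.

def training_memory_gb(num_params: float, dp_shards: int = 1) -> float:
    """Bytes held per GPU, optionally sharded across ZeRO-3 / FSDP ranks."""
    return 16 * num_params / dp_shards / 1e9

print(training_memory_gb(7e9))      # ~112 GB on a single GPU
print(training_memory_gb(7e9, 8))   # ~14 GB per GPU with ZeRO-3 over 8 GPUs
```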

1.2 Communication Bandwidth Requirements

Communication patterns in distributed training:

| Operation | Communication Volume | Scenario |
|---|---|---|
| AllReduce | \(2\Phi\) | Data-parallel gradient sync |
| AllGather | \(\Phi\) | FSDP parameter gathering |
| ReduceScatter | \(\Phi\) | FSDP gradient distribution |
| P2P Send/Recv | Activations | Pipeline parallelism |

2. Data Parallelism

2.1 Distributed Data Parallel (DDP)

PyTorch DDP: Each GPU holds a complete model replica with different data.

Pipeline:

  1. Each GPU processes a different mini-batch
  2. Forward pass proceeds independently
  3. AllReduce synchronizes gradients during backward pass
  4. Each GPU updates parameters with identical gradients

Communication: One AllReduce per step, volume \(2\Phi\)

DDP Optimization: Gradient Bucketing — group gradients into buckets, overlapping communication with backward computation.
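
A minimal single-node sketch of this pipeline, assuming a launch via torchrun with one process per GPU; the tiny Linear model and random data are placeholders for a real model and dataloader:

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")                      # one process per GPU (torchrun)
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = torch.nn.Linear(1024, 1024).cuda(local_rank) # full replica on every rank
model = DDP(model, device_ids=[local_rank])          # gradient bucketing + overlapped AllReduce
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for _ in range(10):                                  # each rank would see different data
    x = torch.randn(8, 1024, device=local_rank)
    loss = model(x).pow(2).mean()
    loss.backward()                                  # gradient AllReduce overlaps with backward
    optimizer.step()                                 # identical update on every rank
    optimizer.zero_grad()

dist.destroy_process_group()
```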

2.2 ZeRO (Zero Redundancy Optimizer)

Core Insight: In DDP, each GPU stores complete optimizer states, gradients, and parameters → massive redundancy.

Three ZeRO Stages:

| Stage | Sharded Content | Per-GPU Memory | Communication |
|---|---|---|---|
| Stage 1 | Optimizer states | \(4\Phi + 12\Phi/N\) | \(2\Phi\) |
| Stage 2 | + Gradients | \(2\Phi + 14\Phi/N\) | \(2\Phi\) |
| Stage 3 | + Parameters | \(16\Phi/N\) | \(3\Phi\) |

where \(N\) is the number of GPUs.

Stage 3 memory scales inversely with GPU count, but communication increases by 50%.

2.3 FSDP (Fully Sharded Data Parallel)

PyTorch FSDP is the native implementation of ZeRO Stage 3.

Core Operations:

Forward: AllGather full params → compute → release non-local params
Backward: AllGather params → compute gradients → ReduceScatter gradients

Key FSDP Configuration:

  • sharding_strategy: FULL_SHARD (Stage 3), SHARD_GRAD_OP (Stage 2)
  • auto_wrap_policy: Wrap by Transformer layers
  • mixed_precision: Mixed precision policy
  • cpu_offload: Offload parameters to CPU
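
A minimal wrapping sketch using these options; it assumes the process group is initialized via torchrun, and uses PyTorch's built-in TransformerEncoderLayer as the unit matched by the auto-wrap policy:

```python
import functools
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.nn import TransformerEncoder, TransformerEncoderLayer

dist.init_process_group("nccl")                          # one process per GPU (torchrun)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model = TransformerEncoder(TransformerEncoderLayer(d_model=1024, nhead=16),
                           num_layers=24).cuda()

model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,       # ZeRO Stage 3 equivalent
    auto_wrap_policy=functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerEncoderLayer},  # one FSDP unit per layer
    ),
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16,
                                   reduce_dtype=torch.bfloat16),
)
# Forward/backward now AllGather each wrapped layer's parameters on demand,
# free them after use, and ReduceScatter its gradients.
```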

3. Model Parallelism

3.1 Tensor Parallelism

Megatron-LM tensor parallelism shards individual operators across multiple GPUs.

MLP Layer Tensor Parallelism:

\[ Y = \text{GeLU}(XA) \cdot B \]

Shard \(A\) column-wise and \(B\) row-wise:

\[ A = [A_1 | A_2], \quad B = \begin{bmatrix} B_1 \\ B_2 \end{bmatrix} \]
  • GPU 1: \(Y_1 = \text{GeLU}(XA_1) B_1\)
  • GPU 2: \(Y_2 = \text{GeLU}(XA_2) B_2\)
  • AllReduce: \(Y = Y_1 + Y_2\)

Attention Layer Tensor Parallelism:

Attention heads are naturally distributable across GPUs.

Communication: 2 AllReduce operations per Transformer layer in the forward pass (one for the MLP block, one for attention), plus 2 more in the backward pass.
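
A toy 2-way illustration of the MLP sharding above, assuming a 2-GPU process group is already initialized; the function name and shard shapes are illustrative, not Megatron-LM's API:

```python
import torch
import torch.distributed as dist

def tp_mlp_forward(x, A_shard, B_shard):
    """x: [batch, d_model]; A_shard: [d_model, d_ff/2] (column slice of A);
    B_shard: [d_ff/2, d_model] (row slice of B)."""
    h = torch.nn.functional.gelu(x @ A_shard)          # each rank computes GeLU(X A_i)
    y_partial = h @ B_shard                            # partial product with B_i
    dist.all_reduce(y_partial, op=dist.ReduceOp.SUM)   # Y = Y_1 + Y_2 across ranks
    return y_partial
```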

3.2 Sequence Parallelism

Idea: Building on tensor parallelism, shard LayerNorm and Dropout along the sequence dimension.

  • These operations have no parameters but large activations
  • Replace AllReduce with AllGather + ReduceScatter
  • Further reduces activation memory

3.3 Context Parallelism

For ultra-long sequences, shard attention computation along the sequence dimension:

  • Each GPU processes a portion of the sequence
  • Exchange KV through Ring Attention or similar mechanisms

4. Pipeline Parallelism

4.1 Basic Idea

Partition model layers across GPUs:

  • GPU 0: Layers 1-8
  • GPU 1: Layers 9-16
  • GPU 2: Layers 17-24
  • GPU 3: Layers 25-32

4.2 GPipe

Split mini-batches into micro-batches for pipelined execution:

Time →
GPU 0: |F1|F2|F3|F4|  |  |  |  |B4|B3|B2|B1|
GPU 1:    |F1|F2|F3|F4|  |  |  |B4|B3|B2|B1|
GPU 2:       |F1|F2|F3|F4|  |  |B4|B3|B2|B1|
GPU 3:          |F1|F2|F3|F4|  |B4|B3|B2|B1|

Pipeline Bubble:

\[ \text{Bubble fraction} = \frac{p-1}{m} \]

where \(p\) is the number of pipeline stages and \(m\) is the number of micro-batches. This is the idle time relative to ideal compute time, so the bubble shrinks when \(m \gg p\).
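
A quick sanity check of the bubble formula (a hypothetical helper, just the arithmetic):

```python
# Idle (bubble) time relative to ideal compute time for p pipeline stages
# and m micro-batches.
def bubble_fraction(p: int, m: int) -> float:
    return (p - 1) / m

print(bubble_fraction(4, 4))    # 0.75   -> 75% overhead when m == p
print(bubble_fraction(4, 32))   # ~0.094 -> overhead shrinks as m >> p
```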

4.3 1F1B (One Forward One Backward)

Alternate forward and backward passes to reduce peak memory:

GPU 0: |F1|F2|F3|F4|B1|B2|B3|B4|
GPU 1:    |F1|F2|F3|B1|F4|B2|B3|B4|
GPU 2:       |F1|F2|B1|F3|B2|F4|B3|B4|
GPU 3:          |F1|B1|F2|B2|F3|B3|F4|B4|

Advantage: Fewer activations need to be cached at any given time.

4.4 Interleaved Pipeline

Interleave model layer assignment:

  • GPU 0: Layers 1-2, 9-10, 17-18, 25-26
  • Smaller bubble ratio but more communication

5. 3D Parallelism (Megatron-LM)

Combining Three Parallelism Strategies:

\[ \text{Total GPUs} = \text{DP} \times \text{TP} \times \text{PP} \]
graph TD
    A[3D Parallelism] --> B[Data Parallel DP]
    A --> C[Tensor Parallel TP]
    A --> D[Pipeline Parallel PP]

    B --> B1[Inter-node]
    C --> C1[Intra-node NVLink]
    D --> D1[Inter-node]

    style C1 fill:#f9f,stroke:#333

Typical Configuration (e.g., LLaMA 65B):

  • TP=8 (single node, 8 GPUs with NVLink)
  • PP=8 (8 nodes)
  • DP=16 (remaining GPUs for data parallelism)
  • Total: 8 x 8 x 16 = 1024 GPUs

Placement Principles:

  • Tensor parallelism: Requires high bandwidth → intra-node (NVLink 600GB/s)
  • Pipeline parallelism: Less communication → can go inter-node
  • Data parallelism: Communication overlaps with compute → inter-node
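
A small sizing helper (illustrative only) that recovers the data-parallel degree from the cluster size and the chosen TP/PP degrees:

```python
# Given the total GPU count and the TP/PP degrees, the DP degree is whatever remains.
def parallelism_layout(world_size: int, tp: int, pp: int) -> dict:
    assert world_size % (tp * pp) == 0, "TP * PP must divide the GPU count"
    return {"TP": tp, "PP": pp, "DP": world_size // (tp * pp)}

print(parallelism_layout(1024, tp=8, pp=8))   # {'TP': 8, 'PP': 8, 'DP': 16}
```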

6. Mixed Precision Training

6.1 Data Type Comparison

| Type | Bits | Exponent | Mantissa | Range | Precision |
|---|---|---|---|---|---|
| FP32 | 32 | 8 | 23 | \(\pm 3.4 \times 10^{38}\) | High |
| FP16 | 16 | 5 | 10 | \(\pm 65504\) | Medium |
| BF16 | 16 | 8 | 7 | \(\pm 3.4 \times 10^{38}\) | Low |
| FP8 E4M3 | 8 | 4 | 3 | \(\pm 448\) | Very low |
| FP8 E5M2 | 8 | 5 | 2 | \(\pm 57344\) | Extremely low |

6.2 Mixed Precision Training Strategy

AMP (Automatic Mixed Precision):

  1. Forward pass: FP16/BF16 (saves memory and compute)
  2. Loss scaling: FP16 requires Loss Scaling to prevent underflow
  3. Master weights: FP32 (for parameter updates)
  4. Optimizer states: FP32

\[ \theta_{t+1} = \theta_t^{\text{FP32}} - \eta \cdot \text{Adam}\big(g_t^{\text{FP16} \to \text{FP32}}\big) \]
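
A minimal single-GPU AMP sketch using PyTorch's autocast and GradScaler (with BF16, the GradScaler can usually be dropped); the toy Linear model stands in for a real network:

```python
import torch

model = torch.nn.Linear(1024, 1024).cuda()           # weights stay FP32 (master copy)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scaler = torch.cuda.amp.GradScaler()                  # dynamic loss scaling for FP16

for _ in range(10):
    x = torch.randn(8, 1024, device="cuda")
    with torch.autocast("cuda", dtype=torch.float16):  # FP16 compute in forward
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()                      # scale loss to avoid gradient underflow
    scaler.step(optimizer)                             # unscale grads, then FP32 update
    scaler.update()                                    # adjust the scale factor dynamically
    optimizer.zero_grad()
```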

6.3 BF16 vs FP16

| Feature | FP16 | BF16 |
|---|---|---|
| Numerical range | Small (requires loss scaling) | Same as FP32 |
| Precision | Higher (10 mantissa bits) | Lower (7 mantissa bits) |
| Loss scaling | Required | Usually not needed |
| Hardware support | V100 and later | A100 and later, TPUs |
| Recommendation | Vision tasks | Preferred for LLM training |

7. Gradient Checkpointing

7.1 Principle

Problem: Backward pass requires activations from all layers → memory proportional to number of layers.

Solution: Save activations at selected layers only; recompute during backward pass.

Time-Memory Trade-off:

| Strategy | Memory | Compute |
|---|---|---|
| Save all activations | \(O(L)\) | \(1\times\) |
| Recompute all | \(O(1)\) | \(2\times\) |
| Checkpoint every \(\sqrt{L}\) layers | \(O(\sqrt{L})\) | \(\sim 1.33\times\) |
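
A minimal sketch with torch.utils.checkpoint: the checkpointed block's internal activations are discarded after the forward pass and recomputed during backward (the toy MLP block is a placeholder):

```python
import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096), torch.nn.GELU(), torch.nn.Linear(4096, 1024)
).cuda()

x = torch.randn(8, 1024, device="cuda", requires_grad=True)
y = checkpoint(block, x, use_reentrant=False)   # forward without storing intermediates
y.sum().backward()                              # block is re-run here to compute gradients
```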

7.2 Selective Checkpointing

Not all layers need recomputation. Priority for checkpointing (recomputation):

  • Attention activations (large: \(O(n^2)\) in sequence length, but cheap to recompute)
  • Keep linear (GEMM) layer activations: expensive to recompute relative to their memory footprint

8. DeepSpeed

8.1 DeepSpeed Ecosystem

| Component | Function |
|---|---|
| ZeRO-1/2/3 | Optimizer state / gradient / parameter sharding |
| ZeRO-Offload | CPU/NVMe offloading |
| ZeRO-Infinity | Trillion-parameter support |
| DeepSpeed-MoE | MoE training support |
| DeepSpeed-Chat | RLHF training framework |

8.2 Configuration Example

{
  "bf16": {"enabled": true},
  "zero_optimization": {
    "stage": 3,
    "offload_param": {"device": "cpu"},
    "offload_optimizer": {"device": "cpu"},
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_bucket_size": 5e8
  },
  "gradient_accumulation_steps": 4,
  "gradient_clipping": 1.0,
  "train_micro_batch_size_per_gpu": 2
}
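
A sketch of how such a config is typically consumed, assuming the JSON above is saved as ds_config.json and the script is started with the deepspeed launcher; the toy Linear model is a placeholder:

```python
import deepspeed
import torch

model = torch.nn.Linear(1024, 1024)
engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="ds_config.json",           # ZeRO-3 + CPU offload settings from above
)

x = torch.randn(2, 1024, device=engine.device)
loss = engine(x).pow(2).mean()
engine.backward(loss)                  # DeepSpeed handles loss scaling and sharded grads
engine.step()
```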

9. NCCL and Communication

9.1 Collective Communication Primitives

| Primitive | Description | Usage |
|---|---|---|
| AllReduce | Sum across all devices and broadcast the result | DDP gradient sync |
| AllGather | Gather data from all devices | FSDP parameter gathering |
| ReduceScatter | Sum and distribute shards | FSDP gradient distribution |
| Broadcast | One-to-many broadcast | Parameter initialization |
| P2P Send/Recv | Point-to-point communication | Pipeline parallelism |
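
A toy demonstration of the primitives most relevant to DDP and FSDP, using torch.distributed with the NCCL backend (assumes a torchrun launch with one process per GPU):

```python
import os
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
world = dist.get_world_size()
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# AllReduce: every rank ends up with the sum across ranks (DDP gradient sync).
t = torch.ones(4, device="cuda") * rank
dist.all_reduce(t, op=dist.ReduceOp.SUM)

# AllGather: every rank ends up with every rank's shard (FSDP parameter gathering).
shard = torch.ones(4, device="cuda") * rank
full = [torch.empty_like(shard) for _ in range(world)]
dist.all_gather(full, shard)

# ReduceScatter: sum across ranks, each rank keeps only its own shard (FSDP gradients).
inputs = [torch.ones(4, device="cuda") * rank for _ in range(world)]
out = torch.empty(4, device="cuda")
dist.reduce_scatter(out, inputs, op=dist.ReduceOp.SUM)

dist.destroy_process_group()
```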

9.2 Communication Topology

  • Ring AllReduce: \(2(N-1)/N \times \Phi\) communication
  • Tree AllReduce: \(2\log_2(N) \times \Phi\) communication
  • NVLink: High intra-node bandwidth (600 GB/s per GPU on A100, 900 GB/s on H100)
  • InfiniBand: High inter-node bandwidth (e.g., 400 Gbps NDR)

10. Training Stability

10.1 Common Issues

| Issue | Symptom | Solution |
|---|---|---|
| Loss divergence | Loss spikes and does not recover | Reduce learning rate, gradient clipping |
| Loss spikes | Occasional, transient loss spikes | Skip anomalous batches, data cleaning |
| Vanishing/exploding gradients | Training stalls | Pre-LN, gradient clipping |
| Numerical overflow | NaN/Inf values | BF16, loss scaling |

10.2 Training Hyperparameter Recommendations

  • Learning rate: Cosine schedule, warmup 2000 steps
  • Gradient clipping: max_norm = 1.0
  • Weight decay: 0.1
  • Batch size: Gradually increase from small to large
  • Adam \(\beta\): \(\beta_1=0.9, \beta_2=0.95\)
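
A sketch wiring these recommendations into PyTorch; the warmup-plus-cosine lambda is written inline for self-containment rather than taken from any particular library, and the toy model/step count are placeholders:

```python
import math
import torch

model = torch.nn.Linear(1024, 1024)
optimizer = torch.optim.AdamW(
    model.parameters(), lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1
)

warmup_steps, total_steps = 2000, 100_000

def lr_lambda(step: int) -> float:
    if step < warmup_steps:                                # linear warmup
        return step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))        # cosine decay to 0

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# One training step: backward, clip, update, advance the schedule.
x = torch.randn(8, 1024)
model(x).pow(2).mean().backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
```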

11. Summary

graph LR
    A[1B Model] -->|Single GPU + AMP| B[Train]
    C[7B Model] -->|FSDP/ZeRO-3 + BF16| D[Train]
    E[70B Model] -->|3D Parallel TP+PP+DP| F[Train]
    G[>400B Model] -->|3D Parallel + Expert Parallel| H[Train]

| Model Scale | Recommended Approach | GPU Requirements |
|---|---|---|
| <1B | Single GPU + AMP | 1× A100 80GB |
| 1-7B | FSDP Stage 2/3 | 4-8× A100 |
| 7-70B | TP + FSDP | 16-128× A100/H100 |
| 70B-400B | 3D Parallelism | 256-2048× H100 |
| >400B | 3D Parallelism + Expert Parallelism | 2048+ H100 |

References

  • Rajbhandari et al., "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models," SC 2020
  • Shoeybi et al., "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism," 2019
  • Narayanan et al., "Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM," SC 2021
  • Micikevicius et al., "Mixed Precision Training," ICLR 2018
  • Zhao et al., "PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel," VLDB 2023
