
AI Infrastructure

AI infrastructure spans the full technology stack from hardware to software, from training to inference. Understanding this infrastructure is a prerequisite for large-scale AI engineering.


Compute and Clusters

GPU Fundamentals

Modern AI training and inference rely heavily on GPUs. Key concepts include:

CUDA Cores vs Tensor Cores:

  • CUDA Cores: General-purpose parallel compute cores that perform one floating-point operation (FP32) per clock cycle, suitable for general parallel computation
  • Tensor Cores: Specialized accelerator units designed for matrix multiplication, each performing a \(4 \times 4\) matrix fused multiply-add (FMA) per clock cycle; the workhorses of AI training

VRAM (Video Memory):

VRAM is the primary bottleneck for GPU-based training. During training, VRAM is occupied by the following:

| Component | Description | Approximate Share |
| --- | --- | --- |
| Model Parameters | Weights (FP16/FP32) | Small |
| Gradients | Same size as the parameters | Small |
| Optimizer States | Adam keeps two state tensors (first and second moment estimates) | Large |
| Activations | Intermediate layer outputs needed for backpropagation | Largest during training |
| KV Cache | Attention cache during inference | Largest during inference |

Common GPU Comparison:

| GPU | VRAM | FP16 Throughput | Typical Use Case |
| --- | --- | --- | --- |
| RTX 4090 | 24 GB | 330 TFLOPS | Personal research / small model training |
| A100 80GB | 80 GB | 312 TFLOPS | Standard training / inference |
| H100 SXM | 80 GB | 990 TFLOPS | Large-scale training |
| H200 | 141 GB | 990 TFLOPS | Large model training / long-context inference |

Multi-GPU Communication

CUDA:

  • NVIDIA's parallel computing platform and programming model
  • PyTorch provides a high-level interface via torch.cuda
  • The CUDA_VISIBLE_DEVICES environment variable controls which GPUs are visible to the process (see the snippet below)
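
A quick way to see what the current process can actually use, assuming a CUDA-enabled PyTorch build (the reported device count already reflects any CUDA_VISIBLE_DEVICES filtering):

# Minimal device inventory with torch.cuda (assumes a CUDA build of PyTorch)
import torch

print(torch.cuda.is_available())        # Is the CUDA runtime usable?
print(torch.cuda.device_count())        # GPUs visible after CUDA_VISIBLE_DEVICES filtering
for i in range(torch.cuda.device_count()):
    print(i, torch.cuda.get_device_name(i))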

NCCL (NVIDIA Collective Communications Library):

  • A multi-GPU communication library developed by NVIDIA
  • Supports collective operations such as AllReduce, AllGather, and Broadcast
  • Automatically detects the topology and selects the fastest communication path (NVLink over PCIe within a node; InfiniBand or Ethernet across nodes)
  • The default backend for GPU training in PyTorch's torch.distributed (a minimal AllReduce example follows)
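
For example, a minimal AllReduce over the NCCL backend, assuming the script is launched with torchrun so that RANK, WORLD_SIZE, and LOCAL_RANK are set:

# Minimal AllReduce over NCCL via torch.distributed (launch with torchrun)
import os
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Each rank contributes one tensor; after all_reduce every rank holds the sum
x = torch.ones(4, device="cuda") * dist.get_rank()
dist.all_reduce(x, op=dist.ReduceOp.SUM)
print(f"rank {dist.get_rank()}: {x}")

dist.destroy_process_group()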

NVLink vs PCIe:

  • NVLink: Direct GPU-to-GPU interconnect with 600 GB/s (A100) to 900 GB/s (H100) of aggregate bandwidth; essential for training large models
  • PCIe Gen5: Approximately 64 GB/s per direction for an x16 link; often the communication bottleneck

Cluster Management

Slurm:

  • The most widely used job scheduler in HPC (High-Performance Computing)
  • Supports resource allocation, queue management, and job scheduling
  • The standard choice for academia and research institutions
# Submit a training job
sbatch --gres=gpu:4 --nodes=2 --ntasks-per-node=4 train.sh

# View the job queue
squeue -u $USER

# Request resources interactively
srun --gres=gpu:1 --pty bash

Kubernetes (K8s):

  • A container orchestration platform; the industry standard
  • Supports GPU scheduling via the NVIDIA GPU Operator
  • Advantages: elastic scaling, service discovery, rolling updates

Cloud vs On-Prem:

| Dimension | Cloud | On-Premises |
| --- | --- | --- |
| Upfront Cost | Low (pay-as-you-go) | High (hardware procurement) |
| Long-term Cost | Potentially higher | More cost-effective at high GPU utilization |
| Elasticity | Scale up/down on demand | Requires advance planning |
| Data Security | Requires additional measures | Data stays on-site |
| Best For | Experimentation, elastic workloads | Continuous training, compliance requirements |

Storage and Data Pipelines

Storage Hierarchy

Fast ←──────────────────────────────→ Slow
Small capacity                        Large capacity

GPU VRAM → NVMe SSD → SATA SSD → NAS/NFS → Object Storage (S3)
 ~2TB/s    ~7GB/s     ~0.5GB/s   ~1-10Gb/s  ~hundreds of MB/s

Storage Choices for Training:

  • Training datasets: High-capacity storage (NAS, S3), with data loading pipelines for prefetching
  • Checkpoints: High-speed SSDs for frequent reads and writes
  • Logs and metrics: Object storage or databases

Data Loading Bottlenecks

Data loading can become a bottleneck during training, leaving GPUs idle and waiting ("starving").

Common Optimization Techniques:

  • Multi-process loading: The num_workers parameter in PyTorch's DataLoader
  • Prefetching: Asynchronously loading the next batch while the GPU is computing
  • Memory-Mapped Files: Avoiding repeated disk reads
  • Optimized data formats: Using sequential-read-friendly formats such as WebDataset or TFRecord
  • Local SSD caching: Caching remote data on a local NVMe drive
# PyTorch DataLoader optimization example
from torch.utils.data import DataLoader

train_loader = DataLoader(
    dataset,
    batch_size=64,
    num_workers=8,          # Multi-process loading
    pin_memory=True,        # Pin memory to accelerate GPU transfers
    prefetch_factor=2,      # Each worker prefetches 2 batches
    persistent_workers=True # Don't recreate workers each epoch
)

Checkpoint Management

Checkpoint management is critical when training large models:

  • Storage size: A checkpoint for a 7B-parameter model with FP32 weights plus Adam's two FP32 state tensors is roughly 12 bytes per parameter, i.e. on the order of 80–90 GB
  • Save frequency: Typically saved every N steps or at fixed time intervals
  • Retention policy: Keep only the latest N checkpoints plus key milestone checkpoints (see the sketch after this list)
  • Asynchronous saving: Use background threads to save checkpoints without blocking training
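
A minimal sketch of periodic saving with a simple retention policy, assuming a standard PyTorch training loop (the directory name and KEEP_LAST value are illustrative):

# Checkpoint saving with rotation: keep only the newest KEEP_LAST files
import os
import torch

CKPT_DIR = "checkpoints"   # illustrative path
KEEP_LAST = 3

def save_checkpoint(model, optimizer, step):
    os.makedirs(CKPT_DIR, exist_ok=True)
    path = os.path.join(CKPT_DIR, f"step_{step:08d}.pt")
    torch.save({"model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "step": step}, path)
    # Retention policy: remove everything except the most recent KEEP_LAST checkpoints
    ckpts = sorted(f for f in os.listdir(CKPT_DIR) if f.endswith(".pt"))
    for old in ckpts[:-KEEP_LAST]:
        os.remove(os.path.join(CKPT_DIR, old))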

Training Infrastructure

Distributed Training

Data Parallel (DP) vs Distributed Data Parallel (DDP)

Data Parallel (DP):

  • PyTorch's early, simple implementation (torch.nn.DataParallel)
  • Single-process, multi-threaded — limited by the Python GIL
  • Communication pattern: GPU0 collects all gradients -> updates parameters -> broadcasts back to other GPUs
  • Problem: GPU0 becomes a communication and memory bottleneck; no longer recommended

Distributed Data Parallel (DDP):

  • Multi-process (one process per GPU), no GIL limitation
  • Communication pattern: AllReduce (gradient averaging), no single point of bottleneck
  • Gradient synchronization overlaps with backpropagation (Overlap Communication with Computation)
  • The current standard approach
# Basic DDP usage (launched with torchrun, which sets LOCAL_RANK)
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
model = DDP(model.to(local_rank), device_ids=[local_rank])

Launch command:

torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
    --master_addr="192.168.1.1" --master_port=29500 train.py

FSDP (Fully Sharded Data Parallel)

The limitation of DDP is that every GPU must hold a full copy of the model parameters, gradients, and optimizer states. When a model is too large to fit on a single GPU, DDP is insufficient.

FSDP addresses this by sharding parameters, gradients, and optimizer states across multiple GPUs:

  • Forward pass: Uses AllGather to collect the full parameter tensor for each needed layer
  • After computation: Releases non-local parameter shards, retaining only the local shard
  • Backward pass: Gathers parameters -> computes gradients -> ReduceScatter gradients -> releases

Result: Per-GPU memory consumption drops from \(O(N)\) to \(O(N/P)\), where \(P\) is the number of GPUs.
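
As a rough sketch of what this looks like in user code with PyTorch's built-in FSDP (assuming a torchrun launch and a toy model; real setups usually also specify an auto-wrap policy and mixed-precision settings):

# Minimal FSDP usage: parameters, gradients, and optimizer states are sharded
import os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = torch.nn.Sequential(torch.nn.Linear(1024, 4096), torch.nn.Linear(4096, 1024))
model = FSDP(model.to(local_rank))                           # shards the wrapped module
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)   # holds states for the local shard only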

DeepSpeed ZeRO

Microsoft's DeepSpeed library provides ZeRO (Zero Redundancy Optimizer), which follows a similar philosophy to FSDP but with finer granularity:

| Stage | What Is Sharded | Memory Savings |
| --- | --- | --- |
| ZeRO-1 | Optimizer states only | ~4x |
| ZeRO-2 | Optimizer states + gradients | ~8x |
| ZeRO-3 | Optimizer states + gradients + parameters | Scales linearly with the number of GPUs |

ZeRO-3 is functionally equivalent to FSDP, but DeepSpeed provides additional engineering optimizations (e.g., ZeRO-Infinity can offload to CPU/NVMe).

DeepSpeed configuration example (ZeRO-2), saved as e.g. ds_config.json:
{
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu"},
        "contiguous_gradients": true,
        "overlap_comm": true
    },
    "fp16": {"enabled": true},
    "train_batch_size": 64
}
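
A minimal sketch of wiring this configuration into training with deepspeed.initialize (the filename ds_config.json and the placeholder model are assumptions; in practice the script is launched with the deepspeed launcher or torchrun):

# Minimal DeepSpeed training setup using the ZeRO-2 config above
import deepspeed
import torch

model = torch.nn.Linear(1024, 1024)   # placeholder model
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="ds_config.json",          # assumed filename for the JSON above
)

# Typical step: the engine handles ZeRO partitioning, FP16 loss scaling, and the update
# loss = model_engine(batch)
# model_engine.backward(loss)
# model_engine.step()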

Memory Optimization

Gradient Checkpointing

Standard training stores activations from all intermediate layers for backpropagation. Gradient checkpointing saves only a subset of layer activations and recomputes the rest during the backward pass:

  • Memory savings: Reduced from \(O(L)\) to \(O(\sqrt{L})\), where \(L\) is the number of layers
  • Cost: Approximately 30% additional compute time (due to recomputation)
  • Trade-off: Trading compute for memory — nearly always used when training large models
from torch.utils.checkpoint import checkpoint

# Enable gradient checkpointing for a specific layer
output = checkpoint(self.heavy_layer, input, use_reentrant=False)

Mixed Precision (AMP)

Mixed precision training accelerates computation using lower-precision formats without sacrificing model accuracy:

  • Forward and backward passes: Use FP16 or BF16 (reduces memory, accelerates Tensor Core computation)
  • Parameter updates: Maintain FP32 master weights (preserves numerical precision)
  • Loss Scaling: Scales up the loss to prevent FP16 gradient underflow
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()
for data, target in train_loader:
    optimizer.zero_grad()
    with autocast(dtype=torch.float16):   # FP16 autocast; GradScaler applies loss scaling
        output = model(data)
        loss = criterion(output, target)
    scaler.scale(loss).backward()   # scale the loss to avoid FP16 gradient underflow
    scaler.step(optimizer)          # unscale gradients, then update parameters
    scaler.update()                 # adjust the scale factor for the next iteration
# With BF16 (autocast(dtype=torch.bfloat16)), the GradScaler is unnecessary

BF16 vs FP16:

  • FP16: 5-bit exponent, 10-bit mantissa; higher precision but a narrow dynamic range, so gradients easily overflow or underflow; requires Loss Scaling
  • BF16: 8-bit exponent, 7-bit mantissa; same dynamic range as FP32; no Loss Scaling needed; more stable
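
The difference is easy to inspect with torch.finfo, which reports the numeric limits of each dtype:

# Comparing the two 16-bit formats' range and precision
import torch

print(torch.finfo(torch.float16))    # max ~65504, ~3 decimal digits of precision
print(torch.finfo(torch.bfloat16))   # max ~3.4e38 (FP32-like range), ~2 decimal digits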

Offloading

When GPU memory is insufficient, some data can be offloaded to CPU memory or even NVMe storage:

  • CPU Offloading: Moves optimizer states or inactive parameters to CPU memory
  • NVMe Offloading: Further offloads data to SSDs (DeepSpeed ZeRO-Infinity)
  • Cost: Introduces CPU-GPU data transfer latency

Fault Tolerance and Recovery

Checkpoint Strategies

Large-scale training runs may last weeks, making hardware failures inevitable:

  • Periodic saving: Save a full checkpoint every N steps
  • Asynchronous saving: Use separate threads/processes for saving without blocking training (see the sketch after this list)
  • Distributed saving: Each rank saves its own shard, speeding up the save process
  • Incremental saving: Save only the differences from the previous checkpoint (saves storage)
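
A minimal sketch of thread-based asynchronous saving (the function name and its arguments are illustrative): snapshot the weights on the main thread, then hand the slow disk write to a background thread. For very large models, production systems typically combine this with distributed (per-rank sharded) saving.

# Asynchronous checkpoint saving: the training loop is not blocked by disk I/O
import threading
import torch

def async_save(model, step, path):
    # Snapshot weights to CPU on the main thread (cheap compared to the disk write);
    # a full checkpoint would snapshot optimizer states the same way
    state = {"model": {k: v.detach().cpu() for k, v in model.state_dict().items()},
             "step": step}
    # Hand the slow torch.save to a background thread
    t = threading.Thread(target=torch.save, args=(state, path), daemon=True)
    t.start()
    return t   # join() before the next save or before exiting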

Elastic Training

Allows dynamic addition or removal of nodes during training:

  • PyTorch Elastic (TorchElastic): Automatically restarts after node failures; supports dynamic scaling
  • Use case: Preemptible instances (Spot Instances) that may be reclaimed at any time
  • Mechanism: Uses torchrun's --rdzv_backend for node discovery and coordination
# Elastic training: minimum 2 nodes, maximum 4 nodes
torchrun --nnodes=2:4 --nproc_per_node=4 \
    --rdzv_backend=c10d --rdzv_endpoint=host:port train.py

Inference Infrastructure

Serving Frameworks

| Framework | Developer | Key Features |
| --- | --- | --- |
| Triton Inference Server | NVIDIA | General-purpose inference server supporting multiple backends (TensorRT, PyTorch, ONNX, etc.) |
| TGI (Text Generation Inference) | HuggingFace | Optimized for Transformers models, Docker-based deployment |
| vLLM | UC Berkeley | PagedAttention, high throughput |
| TensorRT-LLM | NVIDIA | Maximum performance; requires compilation-based optimization |

For more details on vLLM, see vLLM and KV Cache.

Batching Strategies

| Strategy | Description | Throughput |
| --- | --- | --- |
| No Batching | Process requests one at a time | Lowest |
| Static Batching | Wait for a full batch before processing | Moderate |
| Dynamic Batching | Collect requests within a time window (see the sketch below) | Higher |
| Continuous Batching | Reschedule requests at every decoding iteration | Highest |
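
As a rough illustration of dynamic batching (not taken from any particular serving framework; request_queue and run_model are placeholders), the server collects requests until the batch is full or a time window expires, then runs a single forward pass:

# Dynamic batching sketch: batch requests within a time window
import queue
import time

def dynamic_batcher(request_queue, run_model, max_batch_size=8, max_wait_ms=10):
    while True:
        batch = [request_queue.get()]                # block until the first request arrives
        deadline = time.monotonic() + max_wait_ms / 1000
        while len(batch) < max_batch_size:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            try:
                batch.append(request_queue.get(timeout=remaining))
            except queue.Empty:
                break
        run_model(batch)                             # one forward pass for the whole batch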

Auto-Scaling

Dynamically adjusts the number of inference instances based on load:

  • Metric-based: GPU utilization, request queue length, P99 latency
  • Kubernetes HPA: Horizontal Pod Autoscaler — automatically adjusts the number of Pods based on metrics
  • Warm-up: New instances need to load the model, typically requiring tens of seconds to several minutes
  • Scale-down protection: Prevents oscillation from frequent scaling events (by setting a cooldown period)
