AI Infrastructure
AI infrastructure spans the full technology stack from hardware to software, from training to inference. Understanding this infrastructure is a prerequisite for large-scale AI engineering.
Compute and Clusters
GPU Fundamentals
Modern AI training and inference rely heavily on GPUs. Key concepts include:
CUDA Cores vs Tensor Cores:
- CUDA Cores: General-purpose parallel compute cores that perform one floating-point operation (FP32) per clock cycle, suitable for general parallel computation
- Tensor Cores: Specialized accelerator units designed for matrix multiplication, each performing a small matrix fused multiply-add (e.g., a \(4 \times 4\) FMA, \(D = A \cdot B + C\)) per clock cycle; these are the workhorses of AI training
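To make the distinction concrete, here is a rough micro-benchmark sketch (assuming a CUDA-capable GPU and a recent PyTorch; there is no warmup, so treat the timings as indicative only). The FP16 matmul is eligible for Tensor Cores, while the FP32 one runs on CUDA cores (or TF32 on Ampere and later):

# Rough micro-benchmark: the FP16 matmul is eligible for Tensor Cores
import torch

a32 = torch.randn(4096, 4096, device="cuda")  # FP32 path (may use TF32 on Ampere+)
a16 = a32.half()                               # FP16 path (Tensor Cores)

for x in (a32, a16):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    x @ x
    end.record()
    torch.cuda.synchronize()
    print(x.dtype, f"{start.elapsed_time(end):.2f} ms")  # no warmup: indicative only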
VRAM (Video Memory):
VRAM is the primary bottleneck for GPU-based training. During training it is occupied by the following (sizes assume mixed-precision training with Adam; the KV cache row applies to inference):
| Component | Description | Typical Size |
|---|---|---|
| Model Parameters | Weights (FP16/BF16) | 2 bytes/param |
| Gradients | Same shape as the parameters | 2 bytes/param |
| Optimizer States | Adam keeps FP32 master weights plus two moment tensors (mean and variance) | 12 bytes/param; the largest fixed cost |
| Activations | Intermediate layer outputs retained for backpropagation | Scales with batch size and sequence length; often the largest share overall |
| KV Cache | Attention key/value cache (inference only) | Scales with batch size and context length; dominates inference memory |
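As a back-of-the-envelope worked example (a sketch, not a profiler reading; activations come on top and depend on batch size and sequence length), mixed-precision Adam training needs about 16 bytes per parameter:

# Rough static VRAM estimate for mixed-precision Adam training
params = 7e9                   # hypothetical 7B-parameter model
bytes_per_param = 2 + 2 + 12   # FP16 weights + FP16 grads + FP32 master/m/v
print(f"~{params * bytes_per_param / 1024**3:.0f} GB before activations")  # ~104 GB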
Common GPU Comparison:
| GPU | VRAM | FP16 Tensor Throughput (dense) | Typical Use Case |
|---|---|---|---|
| RTX 4090 | 24 GB | ~165 TFLOPS | Personal research / small model training |
| A100 80GB | 80 GB | 312 TFLOPS | Standard training / inference |
| H100 SXM | 80 GB | 990 TFLOPS | Large-scale training |
| H200 | 141 GB | 990 TFLOPS | Large model training / long-context inference |
Multi-GPU Communication
CUDA:
- NVIDIA's parallel computing platform and programming model
- PyTorch provides a high-level interface via `torch.cuda`
- The `CUDA_VISIBLE_DEVICES` environment variable controls which GPUs are visible to a process
NCCL (NVIDIA Collective Communications Library):
- A multi-GPU communication library developed by NVIDIA
- Supports collective operations such as AllReduce, AllGather, and Broadcast
- Automatically selects the optimal communication path (NVLink > PCIe > InfiniBand)
- Used as the default backend for PyTorch's `torch.distributed` on NVIDIA GPUs
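A minimal sketch of a collective operation through the NCCL backend (assuming a multi-GPU node; launched with something like `torchrun --nproc_per_node=4 allreduce_demo.py`, where the script name is illustrative):

# Minimal AllReduce over NCCL: every rank ends up with the sum across ranks
import os
import torch
import torch.distributed as dist

dist.init_process_group("nccl")              # NCCL backend for GPU tensors
local_rank = int(os.environ["LOCAL_RANK"])   # set by torchrun
torch.cuda.set_device(local_rank)

t = torch.ones(1, device="cuda") * dist.get_rank()
dist.all_reduce(t, op=dist.ReduceOp.SUM)     # NCCL picks the NVLink/PCIe/IB path
print(f"rank {dist.get_rank()}: {t.item()}")
dist.destroy_process_group()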
NVLink vs PCIe:
- NVLink: Direct GPU-to-GPU interconnect; roughly 600 GB/s of aggregate bandwidth per GPU on A100 (NVLink 3) and 900 GB/s on H100 (NVLink 4); essential for training large models
- PCIe Gen 5 x16: Approximately 64 GB/s per direction; often the communication bottleneck when NVLink is unavailable
Cluster Management
Slurm:
- The most widely used job scheduler in HPC (High-Performance Computing)
- Supports resource allocation, queue management, and job scheduling
- The standard choice for academia and research institutions
# Submit a training job
sbatch --gres=gpu:4 --nodes=2 --ntasks-per-node=4 train.sh
# View the job queue
squeue -u $USER
# Request resources interactively
srun --gres=gpu:1 --pty bash
Kubernetes (K8s):
- A container orchestration platform; the industry standard
- Supports GPU scheduling via the NVIDIA GPU Operator
- Advantages: elastic scaling, service discovery, rolling updates
Cloud vs On-Prem:
| Dimension | Cloud | On-Premises |
|---|---|---|
| Upfront Cost | Low (pay-as-you-go) | High (hardware procurement) |
| Long-term Cost | Potentially higher | More cost-effective at high GPU utilization |
| Elasticity | Scale up/down on demand | Requires advance planning |
| Data Security | Requires additional measures | Data stays on-site |
| Best For | Experimentation, elastic workloads | Continuous training, compliance requirements |
Storage and Data Pipelines
Storage Hierarchy
Fast ←──────────────────────────────→ Slow
Small capacity Large capacity
GPU VRAM → NVMe SSD → SATA SSD → NAS/NFS → Object Storage (S3)
~2TB/s ~7GB/s ~0.5GB/s ~1-10Gb/s ~hundreds of MB/s
Storage Choices for Training:
- Training datasets: High-capacity storage (NAS, S3), with data loading pipelines for prefetching
- Checkpoints: High-speed SSDs for frequent reads and writes
- Logs and metrics: Object storage or databases
Data Loading Bottlenecks
Data loading can become a bottleneck during training, leaving GPUs idle and waiting ("starving"):
Common Optimization Techniques:
- Multi-process loading: the `num_workers` parameter in PyTorch's `DataLoader`
- Prefetching: asynchronously loading the next batch while the GPU is computing
- Memory-Mapped Files: Avoiding repeated disk reads
- Optimized data formats: Using sequential-read-friendly formats such as WebDataset or TFRecord
- Local SSD caching: Caching remote data on a local NVMe drive
# PyTorch DataLoader optimization example
from torch.utils.data import DataLoader

train_loader = DataLoader(
    dataset,
    batch_size=64,
    num_workers=8,             # Multi-process loading
    pin_memory=True,           # Pin host memory to accelerate GPU transfers
    prefetch_factor=2,         # Each worker prefetches 2 batches
    persistent_workers=True,   # Don't recreate workers each epoch
)
Checkpoint Management
Checkpoint management is critical when training large models:
- Storage size: A checkpoint for a 7B-parameter model in FP32 is roughly 28 GB of weights plus 56 GB of Adam states (mean + variance), about 84 GB in total
- Save frequency: Typically saved every N steps or at fixed time intervals
- Retention policy: Keep only the latest N checkpoints plus key milestone checkpoints
- Asynchronous saving: Use background threads to save checkpoints without blocking training (see the sketch below)
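A minimal sketch of the asynchronous pattern (an illustration using a background thread; the function name is hypothetical, and a robust version would also snapshot optimizer state to CPU before handing it to the writer):

# Asynchronous checkpointing sketch: snapshot to CPU, then write off-thread
import threading
import torch

def save_async(model, optimizer, step, path):
    # Snapshot the weights to CPU on the main thread so subsequent
    # training steps cannot mutate what the writer thread serializes.
    snapshot = {
        "step": step,
        "model": {k: v.detach().cpu() for k, v in model.state_dict().items()},
        "optimizer": optimizer.state_dict(),  # NOTE: deep-copy to CPU in a robust version
    }
    writer = threading.Thread(target=torch.save, args=(snapshot, path), daemon=True)
    writer.start()
    return writer  # join() this handle before exiting the program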
Training Infrastructure
Distributed Training
Data Parallel (DP) vs Distributed Data Parallel (DDP)
Data Parallel (DP):
- PyTorch's early, simple implementation (`torch.nn.DataParallel`)
- Single-process, multi-threaded; limited by the Python GIL
- Communication pattern: GPU0 gathers all gradients → updates parameters → broadcasts back to the other GPUs
- Problem: GPU0 becomes a communication and memory bottleneck; no longer recommended
Distributed Data Parallel (DDP):
- Multi-process (one process per GPU), no GIL limitation
- Communication pattern: AllReduce (gradient averaging), no single point of bottleneck
- Gradient synchronization overlaps with backpropagation, hiding communication behind computation
- The current standard approach
# Basic DDP usage (launched with torchrun, which sets LOCAL_RANK)
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
model = DDP(model.to(local_rank), device_ids=[local_rank])
Launch command:
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
--master_addr="192.168.1.1" --master_port=29500 train.py
FSDP (Fully Sharded Data Parallel)
The limitation of DDP is that every GPU must hold a full copy of the model parameters, gradients, and optimizer states. When a model is too large to fit on a single GPU, DDP is insufficient.
FSDP addresses this by sharding parameters, gradients, and optimizer states across multiple GPUs:
- Forward pass: Uses AllGather to collect the full parameter tensor for each needed layer
- After computation: Releases non-local parameter shards, retaining only the local shard
- Backward pass: gathers parameters → computes gradients → ReduceScatter on gradients → releases non-local shards
Result: Per-GPU memory consumption drops from \(O(N)\) to \(O(N/P)\), where \(P\) is the number of GPUs.
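A minimal FSDP sketch (assuming the process group is already initialized as in the DDP example above; the size-based wrap policy and its one-million-parameter threshold are illustrative choices, not recommendations):

# Shard parameters, gradients, and optimizer states across ranks with FSDP
import functools
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

policy = functools.partial(size_based_auto_wrap_policy, min_num_params=1_000_000)
model = FSDP(model.cuda(), auto_wrap_policy=policy)  # wraps and shards at module granularity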
DeepSpeed ZeRO
Microsoft's DeepSpeed library provides ZeRO (Zero Redundancy Optimizer), which follows a similar philosophy to FSDP but with finer granularity:
| Stage | What Is Sharded | Memory Savings |
|---|---|---|
| ZeRO-1 | Optimizer states only | ~4x |
| ZeRO-2 | Optimizer states + gradients | ~8x |
| ZeRO-3 | Optimizer states + gradients + parameters | Linear scaling |
ZeRO-3 is functionally equivalent to FSDP, but DeepSpeed provides additional engineering optimizations (e.g., ZeRO-Infinity can offload to CPU/NVMe).
// DeepSpeed configuration example (ZeRO-2 with CPU optimizer offload)
{
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": {"device": "cpu"},
    "contiguous_gradients": true,
    "overlap_comm": true
  },
  "fp16": {"enabled": true},
  "train_batch_size": 64
}
Memory Optimization
Gradient Checkpointing
Standard training stores activations from all intermediate layers for backpropagation. Gradient checkpointing saves only a subset of layer activations and recomputes the rest during the backward pass:
- Memory savings: Reduced from \(O(L)\) to \(O(\sqrt{L})\), where \(L\) is the number of layers
- Cost: Approximately 30% additional compute time (roughly one extra forward pass for recomputation)
- Trade-off: Trading compute for memory — nearly always used when training large models
from torch.utils.checkpoint import checkpoint
# Enable gradient checkpointing for a specific layer
output = checkpoint(self.heavy_layer, input, use_reentrant=False)
Mixed Precision (AMP)
Mixed precision training accelerates computation using lower-precision formats without sacrificing model accuracy:
- Forward and backward passes: Use FP16 or BF16 (reduces memory, accelerates Tensor Core computation)
- Parameter updates: Maintain FP32 master weights (preserves numerical precision)
- Loss Scaling: Scales up the loss to prevent FP16 gradient underflow
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()  # loss scaling is needed for FP16 (BF16 does not need it)
for data, target in train_loader:
    optimizer.zero_grad()
    with autocast(dtype=torch.float16):
        output = model(data)
        loss = criterion(output, target)
    scaler.scale(loss).backward()  # scale the loss to prevent FP16 gradient underflow
    scaler.step(optimizer)         # unscale gradients, then take the optimizer step
    scaler.update()
BF16 vs FP16:
- FP16: 5-bit exponent, 10-bit mantissa; more mantissa precision but a narrow dynamic range, so it is prone to overflow/underflow and requires loss scaling
- BF16: 8-bit exponent, 7-bit mantissa; same dynamic range as FP32; no loss scaling needed; more stable, at the cost of some precision
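For comparison, the BF16 variant of the same loop drops the scaler entirely (a sketch assuming Ampere-or-newer hardware with BF16 support):

# BF16 loop: FP32-like dynamic range means no GradScaler is required
import torch
from torch.cuda.amp import autocast

for data, target in train_loader:
    optimizer.zero_grad()
    with autocast(dtype=torch.bfloat16):
        output = model(data)
        loss = criterion(output, target)
    loss.backward()     # gradients survive without loss scaling in BF16
    optimizer.step()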
Offloading
When GPU memory is insufficient, some data can be offloaded to CPU memory or even NVMe storage:
- CPU Offloading: Moves optimizer states or inactive parameters to CPU memory
- NVMe Offloading: Further offloads data to SSDs (DeepSpeed ZeRO-Infinity)
- Cost: Introduces CPU-GPU data transfer latency
Fault Tolerance and Recovery
Checkpoint Strategies
Large-scale training runs may last weeks, making hardware failures inevitable:
- Periodic saving: Save a full checkpoint every N steps
- Asynchronous saving: Use separate threads/processes for saving without blocking training
- Distributed saving: Each rank saves its own shard, speeding up the save process (see the sketch after this list)
- Incremental saving: Save only the differences from the previous checkpoint (saves storage)
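A sketch of the distributed-saving pattern using `torch.distributed.checkpoint` (assuming PyTorch 2.2 or newer; the checkpoint path is illustrative):

# Sharded checkpointing: each rank writes and reads only its own shard
import torch.distributed.checkpoint as dcp

state = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
dcp.save(state, checkpoint_id="ckpt/step_1000")   # ranks write shards in parallel
dcp.load(state, checkpoint_id="ckpt/step_1000")   # restores the dict in place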
Elastic Training
Allows dynamic addition or removal of nodes during training:
- PyTorch Elastic (TorchElastic): Automatically restarts after node failures; supports dynamic scaling
- Use case: Preemptible instances (Spot Instances) that may be reclaimed at any time
- Mechanism: uses `torchrun`'s `--rdzv_backend` for node discovery and coordination (rendezvous)
# Elastic training: minimum 2 nodes, maximum 4 nodes
torchrun --nnodes=2:4 --nproc_per_node=4 \
--rdzv_backend=c10d --rdzv_endpoint=host:port train.py
Inference Infrastructure
Serving Frameworks
| Framework | Developer | Key Features |
|---|---|---|
| Triton Inference Server | NVIDIA | General-purpose inference server supporting multiple backends (TensorRT, PyTorch, ONNX, etc.) |
| TGI (Text Generation Inference) | HuggingFace | Optimized for Transformers models; Docker-based deployment |
| vLLM | UC Berkeley | PagedAttention, high throughput |
| TensorRT-LLM | NVIDIA | Maximum performance, requires compilation-based optimization |
For more details on vLLM, see vLLM and KV Cache.
Batching Strategies
| Strategy | Description | Throughput |
|---|---|---|
| No Batching | Process requests one at a time | Lowest |
| Static Batching | Wait for a full batch before processing | Moderate |
| Dynamic Batching | Use a time window to collect requests | Higher |
| Continuous Batching | Dynamically schedule at every iteration | Highest |
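To illustrate the dynamic-batching row, here is a toy scheduler that collects requests for up to a 10 ms window (the `request_queue` and the constants are hypothetical, not any framework's API):

# Toy dynamic batching: wait up to WINDOW_S for more requests, then dispatch
import queue
import time

MAX_BATCH, WINDOW_S = 32, 0.010
request_queue = queue.Queue()       # filled by a serving frontend (hypothetical)

def next_batch():
    batch = [request_queue.get()]   # block until at least one request arrives
    deadline = time.monotonic() + WINDOW_S
    while len(batch) < MAX_BATCH:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        try:
            batch.append(request_queue.get(timeout=remaining))
        except queue.Empty:
            break
    return batch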
Auto-Scaling
Dynamically adjusts the number of inference instances based on load:
- Metric-based: GPU utilization, request queue length, P99 latency
- Kubernetes HPA: Horizontal Pod Autoscaler — automatically adjusts the number of Pods based on metrics
- Warm-up: New instances need to load the model, typically requiring tens of seconds to several minutes
- Scale-down protection: Prevents oscillation from frequent scaling events (by setting a cooldown period)
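The core of Kubernetes HPA is a simple proportional rule, desiredReplicas = ceil(currentReplicas × currentMetric / targetMetric); a one-function sketch:

# Kubernetes HPA scaling rule
import math

def desired_replicas(current_replicas, current_metric, target_metric):
    return math.ceil(current_replicas * current_metric / target_metric)

# e.g., 4 pods at 90% GPU utilization with a 60% target -> 6 pods
print(desired_replicas(4, current_metric=0.90, target_metric=0.60))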
References
- Rajbhandari et al., "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models", SC 2020
- Zhao et al., "PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel", VLDB 2023
- DeepSpeed Documentation
- PyTorch Distributed Overview