PyTorch Under the Hood
Dynamic Computation Graph
Tape-based
The tape-based mechanism of Autograd.
PyTorch's automatic differentiation system (Autograd) works like a tape recorder:
- Process: When you execute the forward pass, PyTorch silently records every operation you perform (addition, multiplication, reshape) onto a "tape" in sequential order.
- Backward: When you call `.backward()`, it is like pressing "rewind and play": the system traverses the recorded tape in reverse, computing gradients via the chain rule.
- Key property: Once the tape has been played back, it is destroyed (freeing memory). On the next forward pass, a fresh tape is used for a new recording. This is precisely why PyTorch naturally supports Python control flow such as `if`/`else` and loops: the recorded content can be entirely different each time.
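The tape mechanism can be sketched in plain Python. This is a toy scalar autograd, not PyTorch's actual implementation; the `Tape` and `Var` names are invented for illustration:

```python
class Tape:
    """Records backward functions during the forward pass, to be replayed in reverse."""
    def __init__(self):
        self.ops = []

class Var:
    def __init__(self, value, tape):
        self.value, self.grad, self.tape = value, 0.0, tape

    def __mul__(self, other):
        out = Var(self.value * other.value, self.tape)
        def backward():
            # chain rule: d(out)/d(self) = other.value, and vice versa
            self.grad += other.value * out.grad
            other.grad += self.value * out.grad
        self.tape.ops.append(backward)
        return out

    def __add__(self, other):
        out = Var(self.value + other.value, self.tape)
        def backward():
            self.grad += out.grad
            other.grad += out.grad
        self.tape.ops.append(backward)
        return out

def backward(var):
    var.grad = 1.0
    for op in reversed(var.tape.ops):  # play the tape in reverse
        op()
    var.tape.ops.clear()               # the tape is destroyed after playback

tape = Tape()
x = Var(2.0, tape)
y = x * x + Var(3.0, tape) * x   # y = x^2 + 3x
backward(y)
print(x.grad)  # 7.0  (dy/dx = 2x + 3 at x = 2)
```

Because the tape is rebuilt on every forward pass, nothing stops the forward code from taking a different branch next time: the recording simply comes out different.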
Storage/View Separation
This is at the core of how PyTorch handles data efficiently, and it is also the most common source of pitfalls for beginners (e.g., modifying a view inadvertently changes the original data, or encountering contiguous errors).
- Principle: A `Tensor` is internally split into two parts in memory:
  - Storage: A one-dimensional, contiguous array of numbers. This is where the actual data lives.
  - Metadata (View): Records the shape (Size), stride (Stride), and offset (Offset). This determines how the tensor appears to be structured.
- Example: When you call `.view()` or `.reshape()` to change a tensor's shape, PyTorch does not move any data in memory! It simply creates a new "metadata header" that points to the same Storage, just with a different interpretation (e.g., reading "2x3" as "3x2").
- Why this matters:
  - Extremely fast: Reshaping costs virtually no time because no data is copied.
  - Pitfall: If you modify a value in the new tensor, the original tensor's value changes too (because they share the same storage). This also explains why you sometimes need to call `.contiguous()`: it forces PyTorch to actually copy the data into a new, independent storage.
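A minimal sketch of the Storage/View split, using a plain Python list as the "storage" (the `View` class is invented for illustration; PyTorch's real metadata lives in C++):

```python
storage = [1, 2, 3, 4, 5, 6]   # one flat 1D buffer, shared by both "tensors"

class View:
    def __init__(self, storage, shape, stride, offset=0):
        self.storage, self.shape, self.stride, self.offset = storage, shape, stride, offset

    def _pos(self, i, j):
        # metadata decides where logical index (i, j) lands in the flat buffer
        return self.offset + i * self.stride[0] + j * self.stride[1]

    def __getitem__(self, idx):
        return self.storage[self._pos(*idx)]

    def __setitem__(self, idx, value):
        self.storage[self._pos(*idx)] = value

a = View(storage, shape=(2, 3), stride=(3, 1))  # the buffer read as 2x3
b = View(storage, shape=(3, 2), stride=(1, 3))  # the same buffer read as the transpose
b[0, 1] = 99                   # write through the "view"...
print(a[1, 0])                 # 99 -- the original changed too: same storage
```

Writing through either view mutates the one shared buffer, which is exactly the pitfall described above.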
Imperative Programming
PyTorch does not attempt to invent a new "graph language." Instead, it builds directly on top of the Python interpreter:
- PyTorch objects are Python objects.
- PyTorch stack traces are Python stack traces.
- It leverages Python's reference counting mechanism for memory management.
Explicit Device Management
This is a notable difference in design philosophy between PyTorch and other frameworks (such as Keras).
- Principle: "If you don't move it, neither will I." By default, all data resides on the CPU. PyTorch will never automatically transfer data to the GPU unless you explicitly write code to do so.
- In code:
.to('cuda')or.cuda(). - Design philosophy: Although this approach requires more code (you must manage the
deviceyourself), it gives developers absolute control over hardware behavior. You know exactly when each piece of GPU memory is allocated, avoiding the performance pitfalls that come with "black-box" automation.
Autograd in Depth
Computation Graph
During forward propagation, PyTorch's Autograd builds a Directed Acyclic Graph (DAG) in which:
- Leaf nodes: User-created tensors such as model parameters (`requires_grad=True`)
- Intermediate nodes: Intermediate results of forward computation
- Edges: Record each operation's `grad_fn` (e.g., `AddBackward0`, `MulBackward0`)
```python
x = torch.tensor([2.0], requires_grad=True)  # leaf node
y = x ** 2 + 3 * x   # intermediate node, y.grad_fn = <AddBackward0>
y.backward()         # backpropagation traverses the DAG in reverse
print(x.grad)        # tensor([7.]) = 2*2 + 3
```
Gradient Accumulation
By default, calling .backward() accumulates gradients (rather than replacing them). This design choice enables the following pattern:
```python
# Simulate a large batch: actual batch=32, effective batch=128
optimizer.zero_grad()
for i in range(4):
    loss = model(batch[i]) / 4  # divide by accumulation steps
    loss.backward()             # gradients accumulate
optimizer.step()                # single update
```
Common pitfall: Forgetting to call optimizer.zero_grad() causes gradients to accumulate indefinitely, leading to training divergence.
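The pattern works because gradients are linear: accumulating the gradients of `chunk_loss / n_chunks` over equal chunks reproduces the full-batch gradient. A quick check in plain Python, using a one-parameter linear model with MSE (all names are illustrative):

```python
# Accumulating grad(chunk_loss / 4) over 4 chunks of 32 reproduces
# the gradient of the full 128-sample mean loss.
w = 0.5
xs = [float(i) for i in range(1, 129)]   # 128 samples
ys = [3.0 * x for x in xs]

def grad_mse(w, xs, ys):
    # d/dw of mean((w*x - y)^2) = mean(2*x*(w*x - y))
    return sum(2 * x * (w * x - y) for x, y in zip(xs, ys)) / len(xs)

full_batch = grad_mse(w, xs, ys)

accum, n_chunks, chunk = 0.0, 4, 32
for c in range(n_chunks):
    cx = xs[c * chunk:(c + 1) * chunk]
    cy = ys[c * chunk:(c + 1) * chunk]
    accum += grad_mse(w, cx, cy) / n_chunks   # like loss / 4 then .backward()

print(abs(accum - full_batch) < 1e-9 * abs(full_batch))  # True
```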
detach() and torch.no_grad()
Both prevent gradient propagation, but they serve different purposes:
| Method | Effect | Typical Use Case |
|---|---|---|
| `tensor.detach()` | Detaches a tensor from the computation graph; returns a new tensor sharing the same data | Freezing target networks (e.g., the DQN target network) |
| `torch.no_grad()` | Context manager; no computation graph is built for operations inside | Inference and evaluation |
| `tensor.requires_grad_(False)` | In-place modification; stops gradient tracking for this tensor | Freezing specific layer parameters |
```python
# detach: sever the gradient flow, but the value remains correct
target = model_target(state).detach()

# no_grad: save memory during inference
with torch.no_grad():
    predictions = model(test_data)
```
Tensor Memory Model in Depth
The Underlying Relationship Between Storage and View
Every Tensor consists of two parts:
- `torch.Storage`: The underlying contiguous 1D array (the actual data)
- Metadata: `shape`, `stride`, `offset` (determines how the Storage is interpreted)
```python
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = a.T  # transpose

# They share the same Storage
print(a.data_ptr() == b.data_ptr())  # True

# But their strides differ
print(a.stride())  # (3, 1) -- row-major
print(b.stride())  # (1, 3) -- column-major
```
The Essence of contiguous
A tensor is contiguous if and only if its physical memory layout matches its logical index order.
```python
a = torch.randn(3, 4)
b = a.T
print(b.is_contiguous())  # False

# Some operations require a contiguous layout
c = b.contiguous()  # triggers an actual data copy

# Alternatively, use reshape (which copies automatically when non-contiguous)
d = b.reshape(-1)
```
When you need `.contiguous()`:
- Before calling `.view()` (`.reshape()` handles this automatically)
- When passing data to certain CUDA kernels
- When serializing or saving tensors
Intuition Behind stride
stride tells you how many elements to skip in memory to move to the next element along each dimension:
```python
x = torch.randn(2, 3, 4)
print(x.stride())  # (12, 4, 1)
# Meaning: skip 12 elements for dim=0, 4 for dim=1, 1 for dim=2
```
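The stride arithmetic can be reproduced in a few lines of plain Python (the helper names are invented; PyTorch computes the same values internally for contiguous tensors):

```python
def contiguous_strides(shape):
    # Row-major (C-order) strides in elements, matching what PyTorch
    # reports for a freshly created contiguous tensor.
    strides = [1] * len(shape)
    for d in range(len(shape) - 2, -1, -1):
        strides[d] = strides[d + 1] * shape[d + 1]
    return tuple(strides)

def flat_offset(index, strides):
    # Position of logical element `index` in the flat storage.
    return sum(i * s for i, s in zip(index, strides))

print(contiguous_strides((2, 3, 4)))       # (12, 4, 1)
print(flat_offset((1, 2, 3), (12, 4, 1)))  # 23 = 1*12 + 2*4 + 3*1
```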
Dynamic vs. Static Computation Graphs
| Feature | PyTorch (Dynamic) | TensorFlow 1.x (Static) |
|---|---|---|
| Graph construction | Built on-the-fly during each forward pass | Defined first, then executed |
| Python control flow | Natively supports if/else/for | Requires tf.cond/tf.while_loop |
| Debugging | Standard Python debugger (pdb) | Difficult; requires tf.debugging |
| Dynamic shapes | Natively supported | Requires explicit handling |
| Performance optimization | JIT (torch.compile) | Compiler optimizations (XLA) |
| Deployment | TorchScript / ONNX | SavedModel / TF Lite |
torch.compile (PyTorch 2.0+):
PyTorch 2.0 introduced torch.compile(), which obtains compiler optimizations while preserving the flexibility of dynamic graphs:
```python
model = torch.compile(model)  # automatic optimization
# The first call triggers compilation (slower); subsequent calls execute quickly
output = model(input)
```
Under the hood, it uses TorchDynamo (to capture the computation graph) + TorchInductor (to generate optimized kernels).
Essential Operations
Broadcasting
PyTorch's broadcasting rules are identical to NumPy's: dimensions are aligned from right to left; each dimension must either be equal or one of them must be 1.
```python
a = torch.randn(3, 1)  # shape: (3, 1)
b = torch.randn(1, 4)  # shape: (1, 4)
c = a + b              # shape: (3, 4) -- auto-broadcast
```
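The rule can be written down directly as a shape function (a plain-Python sketch; `broadcast_shape` is an invented helper, not a PyTorch API):

```python
from itertools import zip_longest

def broadcast_shape(s1, s2):
    # Align dimensions from the right; each pair must match or contain a 1.
    result = []
    for d1, d2 in zip_longest(reversed(s1), reversed(s2), fillvalue=1):
        if d1 != d2 and 1 not in (d1, d2):
            raise ValueError(f"incompatible dims {d1} and {d2}")
        result.append(max(d1, d2))
    return tuple(reversed(result))

print(broadcast_shape((3, 1), (1, 4)))    # (3, 4)
print(broadcast_shape((2, 3, 4), (4,)))   # (2, 3, 4) -- missing dims count as 1
```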
einsum: Einstein Summation
torch.einsum is a universal tensor operation tool that expresses arbitrary tensor contractions via subscript notation:
```python
# Matrix multiplication
C = torch.einsum('ij,jk->ik', A, B)

# Batched matrix multiplication
C = torch.einsum('bij,bjk->bik', A, B)

# Attention scores: (batch, heads, seq_q, dim) x (batch, heads, seq_k, dim)
#   -> (batch, heads, seq_q, seq_k)
scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)

# Trace
t = torch.einsum('ii->', A)

# Outer product
outer = torch.einsum('i,j->ij', a, b)
```
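What the subscript notation means, spelled out for `'ij,jk->ik'` in explicit loops: a repeated index (here `j`) is summed over, while the free indices (`i`, `k`) survive into the output:

```python
# einsum('ij,jk->ik', A, B) written as explicit loops (plain Python, 2x2 case)
A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]

C = [[sum(A[i][j] * B[j][k] for j in range(2)) for k in range(2)]
     for i in range(2)]
print(C)  # [[19, 22], [43, 50]]
```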
Advanced Indexing
```python
x = torch.randn(5, 3)

# Boolean indexing
mask = x > 0
positives = x[mask]  # 1D tensor

# Fancy indexing (integer array indexing)
indices = torch.tensor([0, 2, 4])
selected = x[indices]  # shape: (3, 3)

# gather: pick values along a specified dimension by index
src = torch.randn(3, 4)
idx = torch.tensor([[0, 1, 2, 0], [1, 2, 2, 1]])  # entries must be < src.size(0) = 3
out = torch.gather(src, 0, idx)
```
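The semantics of gather along `dim=0` can be written out in plain Python as `out[i][j] = src[idx[i][j]][j]` (the `gather_dim0` helper is illustrative):

```python
def gather_dim0(src, idx):
    # torch.gather(src, 0, idx) semantics: out[i][j] = src[idx[i][j]][j]
    return [[src[idx[i][j]][j] for j in range(len(idx[0]))]
            for i in range(len(idx))]

src = [[10, 11, 12, 13],
       [20, 21, 22, 23],
       [30, 31, 32, 33]]
idx = [[0, 1, 2, 0],
       [1, 2, 0, 1]]   # every entry picks a row of src, so must be < 3
out = gather_dim0(src, idx)
print(out)  # [[10, 21, 32, 13], [20, 31, 12, 23]]
```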
GPU Memory Management
Memory Allocation Mechanism
PyTorch uses a CUDA Caching Allocator to manage GPU memory:
- Freed memory is not immediately returned to CUDA; instead, it is cached in a memory pool for reuse
- `torch.cuda.memory_allocated()` -- currently allocated memory
- `torch.cuda.memory_reserved()` -- total reserved memory (including cache)
- `torch.cuda.empty_cache()` -- releases the cache (but not memory in active use)
```python
# Monitor GPU memory usage
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"Reserved:  {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
```
Components of GPU Memory Usage
For a typical training step, GPU memory is consumed primarily by:
| Component | Estimated Share | Notes |
|---|---|---|
| Model parameters | ~25% | Weight matrices |
| Gradients | ~25% | Same size as parameters |
| Optimizer states | ~25-50% | Adam requires m and v, consuming 2x parameter size |
| Activations (intermediate results) | Variable | Grows with batch size and sequence length |
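Putting the table to work: a back-of-the-envelope estimate for a hypothetical 7B-parameter model trained in FP32 with Adam (illustrative arithmetic only; activations and the allocator cache come on top):

```python
# Static training-state memory for a 7B-parameter model in FP32 with Adam
n_params = 7_000_000_000
bytes_per_fp32 = 4

params     = n_params * bytes_per_fp32      # weights
grads      = n_params * bytes_per_fp32      # same size as parameters
adam_state = 2 * n_params * bytes_per_fp32  # m and v: 2x parameter size

total_gb = (params + grads + adam_state) / 1024**3
print(f"{total_gb:.0f} GB before activations")  # ~104 GB
```

The activation term is the one knob you control at runtime, which is why the memory-saving techniques below all target it.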
Gradient Checkpointing
Problem: When training large models, storing all intermediate activations for backpropagation consumes enormous amounts of memory.
Solution: Save only a subset of activations (checkpoints) and recompute the rest during backpropagation. This trades compute time for memory.
```python
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.heavy_block1 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU())  # example block
        self.heavy_block2 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU())

    def forward(self, x):
        # Use checkpointing for memory-heavy blocks
        x = checkpoint(self.heavy_block1, x, use_reentrant=False)
        x = checkpoint(self.heavy_block2, x, use_reentrant=False)
        return x
```
Effect: Memory drops from \(O(L)\) (where \(L\) is the number of layers) to \(O(\sqrt{L})\), at the cost of approximately 33% additional computation.
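The \(O(\sqrt{L})\) figure comes from checkpointing every \(\sqrt{L}\)-th layer: you hold about \(\sqrt{L}\) checkpoints, plus at most one segment of \(\sqrt{L}\) recomputed activations during backprop. A quick sanity check, assuming uniform layers:

```python
import math

# Peak activation count with and without sqrt(L) checkpointing
L = 100
segment = int(math.sqrt(L))              # checkpoint every 10 layers

peak_no_ckpt   = L                       # all 100 activations held at once
peak_with_ckpt = L // segment + segment  # 10 checkpoints + 10 recomputed = 20

print(peak_no_ckpt, peak_with_ckpt)      # 100 20
```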
Mixed Precision Training
Using FP16/BF16 instead of FP32 can halve memory usage and accelerate computation:
```python
scaler = torch.amp.GradScaler()  # prevents FP16 underflow
for data, target in loader:
    optimizer.zero_grad()
    with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
        output = model(data)
        loss = criterion(output, target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```
BF16 vs. FP16:
- FP16: Higher precision but smaller dynamic range; requires loss scaling
- BF16: Same dynamic range as FP32; no loss scaling needed, but slightly lower precision
- On modern GPUs (A100/H100), BF16 is recommended
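The range-versus-precision trade-off can be demonstrated with the standard library alone: `struct`'s `'e'` format is IEEE half precision, and BF16 behaves like FP32 with the low 16 mantissa bits dropped (the helper functions below are illustrative, not a real BF16 implementation):

```python
import struct

def to_fp16(x):
    # Round-trip through IEEE half precision ('e' struct format)
    return struct.unpack('<e', struct.pack('<e', x))[0]

def to_bf16(x):
    # bfloat16 = float32 with the low 16 mantissa bits zeroed out
    (bits,) = struct.unpack('<I', struct.pack('<f', x))
    return struct.unpack('<f', struct.pack('<I', bits & 0xFFFF0000))[0]

print(to_fp16(1e-8))    # 0.0 -- underflow: FP16's smallest subnormal is ~6e-8
print(to_bf16(1e-8))    # nonzero (~1e-8): BF16 shares FP32's exponent range
print(to_bf16(1.0001))  # 1.0 -- but BF16 keeps only 7 mantissa bits
```

The first result is exactly the underflow that `GradScaler` exists to prevent; the second shows why BF16 can skip loss scaling entirely.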