PyTorch Under the Hood

Dynamic Computation Graph

Tape-based

The tape-based mechanism of Autograd.

PyTorch's automatic differentiation system (Autograd) works like a tape recorder:

  • Process: When you execute the forward pass, PyTorch silently records every operation you perform (addition, multiplication, reshape) onto a "tape" in sequential order.
  • Backward: When you call .backward(), it is like pressing "rewind and play" — the system traverses the recorded tape in reverse, computing gradients via the chain rule.
  • Key property: Once the tape has been played back, it is destroyed (freeing memory). On the next iteration (the next forward pass), a fresh tape is recorded. This is precisely why PyTorch naturally supports Python control flow such as if/else and loops: the recorded content can be entirely different each time (see the sketch after this list).
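
A minimal sketch of the "tape is destroyed" behavior (the tensors here are illustrative): a second .backward() on the same graph only works if you explicitly ask PyTorch to keep the tape with retain_graph=True.

import torch

x = torch.tensor([2.0], requires_grad=True)
y = (x ** 2).sum()

y.backward(retain_graph=True)  # first playback; keep the tape alive
y.backward()                   # works only because the tape was retained above
print(x.grad)                  # tensor([8.]) -- gradients from both playbacks accumulate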

Storage/View Separation

This is at the core of how PyTorch handles data efficiently, and it is also the most common source of pitfalls for beginners (e.g., modifying a view inadvertently changes the original data, or encountering contiguous errors).

  • Principle: A Tensor is internally split into two parts in memory:
    1. Storage: A one-dimensional, contiguous array of numbers. This is where the actual data lives.
    2. Metadata (View): Records the shape (Size), stride (Stride), and offset (Offset). This determines how the tensor appears to be structured.
  • Example: When you call .view() (or .reshape() on a contiguous tensor) to change a tensor's shape, PyTorch does not move any data in memory! It simply creates a new "metadata header" that points to the same Storage, just with a different interpretation (e.g., reading a "2x3" layout as "3x2").
  • Why this matters:
    • Extremely fast: Reshaping costs virtually no time because no data is copied.
    • Pitfall: If you modify a value in the new tensor, the original tensor's value changes too (because they share the same storage). This also explains why you sometimes need to call .contiguous(): it forces PyTorch to actually copy the data into a new, independent storage (see the sketch after this list).
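
A small sketch of the shared-storage pitfall and the effect of .contiguous() (tensor names are illustrative):

import torch

a = torch.zeros(2, 3)
v = a.view(6)      # new metadata, same Storage
v[0] = 42.0        # in-place write through the view
print(a[0, 0])     # tensor(42.) -- the original changed too

b = a.t().contiguous()  # the transpose is non-contiguous, so this copies into a new Storage
b[0, 0] = -1.0
print(a[0, 0])          # still tensor(42.) -- the copy is independent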

Imperative Programming

PyTorch does not attempt to invent a new "graph language." Instead, it builds directly on top of the Python interpreter:

  • PyTorch objects are Python objects.
  • PyTorch stack traces are Python stack traces.
  • It leverages Python's reference counting mechanism for memory management.

Explicit Device Management

This is a notable difference in design philosophy between PyTorch and other frameworks (such as Keras).

  • Principle: "If you don't move it, neither will I." By default, all data resides on the CPU. PyTorch will never automatically transfer data to the GPU unless you explicitly write code to do so.
  • In code: .to('cuda') or .cuda() (see the sketch after this list).
  • Design philosophy: Although this approach requires more code (you must manage the device yourself), it gives developers absolute control over hardware behavior. You know exactly when each piece of GPU memory is allocated, avoiding the performance pitfalls that come with "black-box" automation.
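
A minimal sketch of the explicit pattern (the model and tensors are placeholders):

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = nn.Linear(10, 2).to(device)  # parameters moved once, explicitly
x = torch.randn(4, 10)               # created on the CPU by default
out = model(x.to(device))            # inputs must be moved explicitly too
print(out.device)                    # cuda:0 if a GPU is available, otherwise cpu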

Autograd in Depth

Computation Graph

During forward propagation, PyTorch's Autograd builds a Directed Acyclic Graph (DAG) in which:

  • Leaf nodes: User-created tensors such as model parameters (requires_grad=True)
  • Intermediate nodes: Intermediate results of forward computation
  • Edges: Record each operation's grad_fn (e.g., AddBackward0, MulBackward0)

import torch

x = torch.tensor([2.0], requires_grad=True)  # leaf node
y = x ** 2 + 3 * x   # intermediate node, y.grad_fn = <AddBackward0>
y.backward()         # backpropagation traverses the DAG in reverse
print(x.grad)        # tensor([7.]) = 2*2 + 3

Gradient Accumulation

By default, calling .backward() accumulates gradients (rather than replacing them). This design choice enables the following pattern:

# Simulate a large batch: per-step micro-batch = 32, effective batch = 128
optimizer.zero_grad()
for i in range(4):
    output = model(batch[i])                   # batch[i]: the i-th micro-batch
    loss = criterion(output, targets[i]) / 4   # divide by the number of accumulation steps
    loss.backward()                            # gradients accumulate across the 4 steps
optimizer.step()                               # single parameter update

Common pitfall: Forgetting to call optimizer.zero_grad() causes gradients to accumulate indefinitely, leading to training divergence.

detach() and torch.no_grad()

Both prevent gradient propagation, but they serve different purposes:

| Method | Effect | Typical Use Case |
| --- | --- | --- |
| tensor.detach() | Detaches a tensor from the computation graph; returns a new tensor sharing the same data | Freezing target networks (e.g., the DQN target network) |
| torch.no_grad() | Context manager; no computation graph is built for operations inside | Inference and evaluation |
| tensor.requires_grad_(False) | In-place modification; stops gradient tracking for this tensor | Freezing specific layers' parameters |

# detach: sever the gradient flow, but the value remains correct
target = model_target(state).detach()

# no_grad: save memory during inference
with torch.no_grad():
    predictions = model(test_data)
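
For the third option in the table, a short sketch of freezing specific parameters (the model is illustrative):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2))

# Freeze the first linear layer; its parameters stop receiving gradients
for p in model[0].parameters():
    p.requires_grad_(False)

# Pass only the trainable parameters to the optimizer
optimizer = torch.optim.Adam(p for p in model.parameters() if p.requires_grad)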

Tensor Memory Model in Depth

The Underlying Relationship Between Storage and View

Every Tensor consists of two parts:

  1. torch.Storage: The underlying contiguous 1D array (the actual data)
  2. Metadata: shape, stride, offset (determines how the Storage is interpreted)
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = a.T  # transpose

# They share the same Storage
print(a.data_ptr() == b.data_ptr())  # True

# But their strides differ
print(a.stride())  # (3, 1) -- row-major
print(b.stride())  # (1, 3) -- column-major

The Essence of contiguous

A tensor is contiguous if and only if its physical memory layout matches its logical index order.

a = torch.randn(3, 4)
b = a.T
print(b.is_contiguous())  # False

# Some operations require a contiguous layout
c = b.contiguous()  # triggers an actual data copy
# Alternatively, use reshape (which copies automatically when non-contiguous)

When you need .contiguous() (a quick demonstration follows this list):

  • Before calling .view() (.reshape() handles this automatically)
  • When passing data to certain CUDA kernels
  • When serializing or saving tensors
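
The first case above can be demonstrated directly (a minimal sketch):

import torch

a = torch.randn(3, 4)
b = a.T                      # non-contiguous view
# b.view(12)                 # would raise a RuntimeError because b is not contiguous
c = b.contiguous().view(12)  # copy first, then .view() works
d = b.reshape(12)            # .reshape() performs the copy for you when needed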

Intuition Behind stride

stride tells you how many elements to skip in memory to move to the next element along each dimension:

x = torch.randn(2, 3, 4)
print(x.stride())  # (12, 4, 1)
# Meaning: skip 12 for dim=0, skip 4 for dim=1, skip 1 for dim=2
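
One way to make this concrete is to index the flat storage by hand (a small sketch; it assumes the default storage offset of 0):

import torch

x = torch.randn(2, 3, 4)
i, j, k = 1, 2, 3
strides = x.stride()                                     # (12, 4, 1)
flat = i * strides[0] + j * strides[1] + k * strides[2]  # dot product of index and stride
print(x[i, j, k] == x.flatten()[flat])                   # tensor(True)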

Dynamic vs. Static Computation Graphs

| Feature | PyTorch (Dynamic) | TensorFlow 1.x (Static) |
| --- | --- | --- |
| Graph construction | Built on the fly during each forward pass | Defined first, then executed |
| Python control flow | Natively supports if/else/for | Requires tf.cond / tf.while_loop |
| Debugging | Standard Python debugger (pdb) | Difficult; requires tf.debugging |
| Dynamic shapes | Natively supported | Requires explicit handling |
| Performance optimization | JIT (torch.compile) | Compiler optimizations (XLA) |
| Deployment | TorchScript / ONNX | SavedModel / TF Lite |
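
Because the graph is rebuilt on every forward pass, data-dependent Python control flow just works (a sketch; the module and its logic are purely illustrative):

import torch
import torch.nn as nn

class DynamicNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x):
        # Ordinary Python control flow; a different graph is recorded on each call
        for _ in range(int(x.abs().sum()) % 3 + 1):
            x = torch.relu(self.fc(x))
        if x.mean() > 0:
            x = x * 2
        return x

out = DynamicNet()(torch.randn(4, 8))
out.sum().backward()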

torch.compile (PyTorch 2.0+):

PyTorch 2.0 introduced torch.compile(), which delivers compiler-level optimizations while preserving the flexibility of dynamic graphs:

model = torch.compile(model)  # automatic optimization
# The first call triggers compilation (slower); subsequent calls execute quickly
output = model(input)

Under the hood, it uses TorchDynamo (to capture the computation graph) + TorchInductor (to generate optimized kernels).


Essential Operations

Broadcasting

PyTorch's broadcasting rules are identical to NumPy's: dimensions are aligned from right to left; each dimension must either be equal or one of them must be 1.

a = torch.randn(3, 1)    # shape: (3, 1)
b = torch.randn(1, 4)    # shape: (1, 4)
c = a + b                 # shape: (3, 4) -- auto-broadcast
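
A slightly larger sketch of the right-to-left alignment rule (shapes are illustrative):

import torch

x = torch.randn(2, 3, 4)
scale = torch.randn(3, 1)   # aligned as (1, 3, 1) against (2, 3, 4)
y = x * scale               # shape: (2, 3, 4)

a = torch.randn(3, 2)
# x * a                     # RuntimeError: trailing dims 2 and 4 cannot broadcast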

einsum: Einstein Summation

torch.einsum is a universal tensor operation tool that expresses arbitrary tensor contractions via subscript notation:

# Matrix multiplication
C = torch.einsum('ij,jk->ik', A, B)

# Batched matrix multiplication
C = torch.einsum('bij,bjk->bik', A, B)

# Attention scores: (batch, heads, seq_q, dim) x (batch, heads, seq_k, dim) -> (batch, heads, seq_q, seq_k)
scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)

# Trace
t = torch.einsum('ii->', A)

# Outer product
outer = torch.einsum('i,j->ij', a, b)

Advanced Indexing

x = torch.randn(5, 3)

# Boolean indexing
mask = x > 0
positives = x[mask]  # 1D tensor

# Fancy indexing (integer array indexing)
indices = torch.tensor([0, 2, 4])
selected = x[indices]  # shape: (3, 3)

# scatter and gather -- advanced aggregation operations
# gather: pick values along a specified dimension by index
src = torch.randn(3, 4)
idx = torch.tensor([[0, 1, 2, 0], [1, 2, 3, 1]])
out = torch.gather(src, 0, idx)
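
gather's counterpart, scatter_, writes values at the given indices instead of reading them. A common sketch (labels and shapes are illustrative) is building one-hot vectors:

import torch

labels = torch.tensor([2, 0, 1])               # class indices, shape (3,)
one_hot = torch.zeros(3, 4)                    # 3 samples, 4 classes
one_hot.scatter_(1, labels.unsqueeze(1), 1.0)  # write 1.0 at column = label in each row
# tensor([[0., 0., 1., 0.],
#         [1., 0., 0., 0.],
#         [0., 1., 0., 0.]])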

GPU Memory Management

Memory Allocation Mechanism

PyTorch uses a CUDA Caching Allocator to manage GPU memory:

  • Freed memory is not immediately returned to CUDA; instead, it is cached in a memory pool for reuse
  • torch.cuda.memory_allocated() -- currently allocated memory
  • torch.cuda.memory_reserved() -- total reserved memory (including cache)
  • torch.cuda.empty_cache() -- releases the cache (but not memory in active use)
# Monitor GPU memory usage
print(f"Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"Reserved:  {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

Components of GPU Memory Usage

For a typical training step, GPU memory is consumed primarily by:

| Component | Estimated Share | Notes |
| --- | --- | --- |
| Model parameters | ~25% | Weight matrices |
| Gradients | ~25% | Same size as the parameters |
| Optimizer states | ~25-50% | Adam keeps m and v, consuming 2x the parameter size |
| Activations (intermediate results) | Variable | Grows with batch size and sequence length |
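
A rough back-of-the-envelope sketch of the first three rows for FP32 training with Adam (the model is illustrative; activations and allocator overhead are ignored):

import torch.nn as nn

model = nn.Linear(4096, 4096)
n_params = sum(p.numel() for p in model.parameters())

bytes_params = n_params * 4       # FP32 weights
bytes_grads  = n_params * 4       # gradients, same size as the weights
bytes_adam   = n_params * 4 * 2   # Adam's m and v states
print(f"~{(bytes_params + bytes_grads + bytes_adam) / 1024**3:.3f} GB before activations")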

Gradient Checkpointing

Problem: When training large models, storing all intermediate activations for backpropagation consumes enormous amounts of memory.

Solution: Save only a subset of activations (checkpoints) and recompute the rest during backpropagation. This trades compute time for memory.

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Illustrative sub-modules; any expensive blocks would work here
        self.heavy_block1 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU())
        self.heavy_block2 = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU())

    def forward(self, x):
        # Checkpointed blocks recompute their activations during the backward pass
        x = checkpoint(self.heavy_block1, x, use_reentrant=False)
        x = checkpoint(self.heavy_block2, x, use_reentrant=False)
        return x

Effect: Memory drops from \(O(L)\) (where \(L\) is the number of layers) to \(O(\sqrt{L})\), at the cost of approximately 33% additional computation.

Mixed Precision Training

Using FP16/BF16 instead of FP32 can halve memory usage and accelerate computation:

scaler = torch.amp.GradScaler()  # prevents FP16 underflow

for data, target in loader:
    optimizer.zero_grad()
    with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
        output = model(data)
        loss = criterion(output, target)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

BF16 vs. FP16:

  • FP16: More precision bits than BF16 but a much smaller dynamic range; requires loss scaling to avoid gradient underflow
  • BF16: Same dynamic range as FP32, so no loss scaling is needed, but slightly lower precision
  • On modern GPUs (A100/H100), BF16 is generally recommended (see the sketch after this list)
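
With BF16, the training loop above simplifies because no GradScaler is needed (a sketch reusing the names from the FP16 example; it assumes a GPU with BF16 support):

for data, target in loader:
    optimizer.zero_grad()
    with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
        output = model(data)
        loss = criterion(output, target)
    loss.backward()   # no loss scaling: BF16 shares FP32's exponent range
    optimizer.step()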
