The Computational Graph
Every time you perform an operation on a tensor that has `requires_grad=True`, PyTorch silently builds a directed acyclic graph (DAG) recording every operation. Each node in this graph stores the operation that was applied (add, multiply, matmul, and so on) along with pointers to the nodes representing its inputs. This graph is the blueprint for computing gradients: when you call `.backward()`, PyTorch walks the graph in reverse, applying the chain rule at each node to propagate gradient information from the loss all the way back to the parameters.
The key insight is that these two phases serve different purposes. The forward pass builds the graph: as each operation executes, it registers itself as a node, storing both the function that was computed and any intermediate values needed later for the derivative. The backward pass consumes the graph: it traverses the nodes in reverse topological order, computing the gradient contribution from each operation. Then, by default, it frees the intermediate values (the "saved tensors") that the graph was holding. This is why calling `.backward()` a second time on the same loss usually raises a `RuntimeError`: the values needed to compute the local derivatives are gone. Sometimes you genuinely need to backpropagate through the same graph twice, for example when two losses share part of a graph and you call `.backward()` on each separately. In that case you can pass `retain_graph=True`, which keeps those values alive at the cost of additional memory usage. Higher-order derivatives rely on the related flag `create_graph=True`, which also records the backward computation itself as a graph so that it can be differentiated again.
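The record-then-consume behavior is easy to observe directly. The following minimal sketch uses the same toy computation as the NumPy walkthrough that follows. It prints the node attached to the loss and its input nodes, shows that a second `.backward()` fails once the saved tensors are released, and shows that `retain_graph=True` allows a second call. The exact node class names printed are PyTorch internals and may differ between versions.

```python
import torch

# Same toy computation as the NumPy walkthrough: loss = (w*x + b - target)^2
w = torch.tensor(2.0, requires_grad=True)
x = torch.tensor(3.0)
b = torch.tensor(1.0, requires_grad=True)

loss = (w * x + b - 10.0) ** 2

# Every non-leaf tensor points at the graph node that produced it,
# and each node points at the nodes for its inputs.
print(loss.grad_fn)                 # the node for the final "** 2"
print(loss.grad_fn.next_functions)  # (node, index) pairs for its inputs

loss.backward()
try:
    loss.backward()  # the saved tensors were freed by the first call
    second_backward_failed = False
except RuntimeError:
    second_backward_failed = True
print("second backward raised RuntimeError:", second_backward_failed)

# With retain_graph=True the saved tensors survive, so a second call works
w.grad = None
b.grad = None
loss = (w * x + b - 10.0) ** 2
loss.backward(retain_graph=True)
loss.backward()
print(w.grad)  # tensor(-36.): the two calls accumulate (-18 + -18)
```

The last line also previews the accumulation behavior discussed later in this section: the second call adds to `.grad` rather than overwriting it.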
Let's simulate this process manually with NumPy. We'll define a simple computation — $y = wx + b$ followed by a squared-error loss — and then walk backward through it, computing each gradient by hand using the chain rule.
import numpy as np
# Simple computation: y = w*x + b, loss = (y - target)^2
w = 2.0
x = 3.0
b = 1.0
target = 10.0
# Forward pass (building the "graph" mentally)
y = w * x + b # y = 2*3 + 1 = 7
loss = (y - target)**2 # loss = (7 - 10)^2 = 9
print("Forward pass:")
print(f" w={w}, x={x}, b={b}")
print(f" y = w*x + b = {y}")
print(f" loss = (y - target)² = ({y} - {target})² = {loss}")
print()
# Backward pass (chain rule, working backwards)
dloss_dy = 2 * (y - target) # d(loss)/dy = 2(y - target) = -6
dy_dw = x # d(y)/dw = x = 3
dy_db = 1.0 # d(y)/db = 1
dloss_dw = dloss_dy * dy_dw # chain rule: -6 * 3 = -18
dloss_db = dloss_dy * dy_db # chain rule: -6 * 1 = -6
print("Backward pass (chain rule):")
print(f" d(loss)/dy = 2(y - target) = 2({y} - {target}) = {dloss_dy}")
print(f" d(y)/dw = x = {dy_dw}")
print(f" d(y)/db = 1")
print(f" d(loss)/dw = d(loss)/dy × d(y)/dw = {dloss_dy} × {dy_dw} = {dloss_dw}")
print(f" d(loss)/db = d(loss)/dy × d(y)/db = {dloss_dy} × {dy_db} = {dloss_db}")
The gradient $\partial L / \partial w = -18$ tells us two things. Because the gradient is negative, increasing $w$ by a tiny amount $\delta$ would decrease the loss. Because its magnitude is 18, the loss would fall by roughly $18\delta$, so it is quite sensitive to $w$. Similarly, $\partial L / \partial b = -6$ indicates that the loss is sensitive to $b$ as well, but only a third as much. A gradient descent step nudges both parameters in the direction that reduces the loss: $w \leftarrow w - \alpha \cdot (-18)$ and $b \leftarrow b - \alpha \cdot (-6)$, where $\alpha$ is the learning rate.
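To make the update concrete, here is a small sketch that applies one gradient descent step. It assumes a learning rate of $\alpha = 0.01$, chosen only for illustration. It then compares the new loss with the first-order prediction $L - \alpha\,(18^2 + 6^2)$:

```python
w, b, x, target = 2.0, 1.0, 3.0, 10.0
dloss_dw, dloss_db = -18.0, -6.0  # gradients from the manual backward pass
alpha = 0.01                      # illustrative learning rate

def loss_fn(w, b):
    return (w * x + b - target) ** 2

w_new = w - alpha * dloss_dw  # 2.0 + 0.18 = 2.18
b_new = b - alpha * dloss_db  # 1.0 + 0.06 = 1.06

old_loss = loss_fn(w, b)          # 9.0
new_loss = loss_fn(w_new, b_new)  # (7.6 - 10)^2, about 5.76
# First-order (linear) prediction of the new loss from the gradient alone
predicted = old_loss - alpha * (dloss_dw ** 2 + dloss_db ** 2)  # about 5.4
print(old_loss, new_loss, predicted)
```

The loss drops from 9.0 to about 5.76. The linear prediction of about 5.4 is slightly optimistic because the squared-error loss curves upward, which is one reason learning rates have to stay small.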
Now let's verify that PyTorch's autograd produces the exact same results:
import torch
w = torch.tensor(2.0, requires_grad=True)
x = torch.tensor(3.0)
b = torch.tensor(1.0, requires_grad=True)
target = torch.tensor(10.0)
y = w * x + b
loss = (y - target) ** 2
loss.backward()
print(w.grad) # tensor(-18.) ← same as our manual calculation
print(b.grad) # tensor(-6.) ← same
The numbers match exactly. Under the hood, PyTorch performed the same chain-rule decomposition we did by hand — it just did it automatically by traversing the graph it built during the forward pass.
The Chain Rule in Action
The chain rule is the mathematical backbone of backpropagation. For a composition of functions where the loss $L$ depends on a parameter $w$ through a sequence of intermediate computations, say $L = f(g(h(w)))$, the chain rule tells us we can decompose the full derivative into a product of simpler, local derivatives:

$$\frac{\partial L}{\partial w} = \frac{\partial L}{\partial g} \cdot \frac{\partial g}{\partial h} \cdot \frac{\partial h}{\partial w}$$
Each factor in this product is a local derivative — it answers one narrow question: how much does this particular layer's output change when its input changes by a small amount? The factor $\partial h / \partial w$ measures the sensitivity of the first operation's output to the parameter $w$. The factor $\partial g / \partial h$ measures how much the second operation amplifies or dampens changes coming from below. And $\partial L / \partial g$ measures how the loss responds to changes in the final intermediate value. Their product gives the end-to-end sensitivity of the loss to the parameter — exactly what we need for gradient descent.
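The decomposition is easy to check numerically. The following minimal sketch uses a hypothetical composition chosen purely for illustration: $h(w) = 3w$, $g(h) = \sin h$, and $L = f(g) = g^2$. It evaluates each local derivative at the stored forward-pass values, multiplies them, and compares the product with PyTorch's autograd and with the closed form $\frac{d}{dw}\sin^2(3w) = 3\sin(6w)$:

```python
import math
import torch

# Hypothetical composition: h(w) = 3w, g(h) = sin(h), L = f(g) = g^2
w0 = 0.7

# Forward pass: keep the intermediate values, as autograd would
h = 3.0 * w0
g = math.sin(h)
L = g ** 2

# Local derivatives, each evaluated at the stored forward values
dh_dw = 3.0          # d(3w)/dw
dg_dh = math.cos(h)  # d(sin h)/dh
dL_dg = 2.0 * g      # d(g^2)/dg

# Chain rule: multiply the local factors
manual_grad = dL_dg * dg_dh * dh_dw

# Compare against autograd (float64 so the results agree to many digits)
w = torch.tensor(w0, dtype=torch.float64, requires_grad=True)
loss = torch.sin(3.0 * w) ** 2
loss.backward()
print(manual_grad, w.grad.item(), 3.0 * math.sin(6.0 * w0))
```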
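Why is this decomposition so important? Computing the full derivative directly, by symbolically differentiating the entire network as one monolithic function, would produce unmanageably large expressions for models with millions or billions of parameters. The chain rule lets us break the problem into a sequence of simple local derivatives, each computable from values already available during the forward pass. Working backward from the loss has a further advantage: every parameter reuses the upstream factors already computed for the layers above it. As a result, a single backward pass yields the gradient for all parameters at a cost within a small constant factor of the forward pass. This is the insight that makes modern deep learning feasible.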
Let's see the chain rule in action with a deeper computation graph that includes a ReLU activation, one of the most common nonlinearities in neural networks. ReLU is defined as $\text{ReLU}(x) = \max(0, x)$, and its derivative is particularly simple: 1 when the input is positive, 0 when the input is negative. At exactly zero the derivative is undefined, and PyTorch uses 0 there by convention. This piecewise nature leads to an important phenomenon we'll explore below.
import numpy as np
# Deeper graph: z = relu(w*x + b), loss = (z - target)^2
w, x, b, target = 0.5, 4.0, -1.5, 3.0
# Forward
pre_act = w * x + b # 0.5*4 + (-1.5) = 0.5
z = max(0, pre_act) # relu(0.5) = 0.5
loss = (z - target) ** 2 # (0.5 - 3)^2 = 6.25
print("Forward:")
print(f" pre_act = w*x + b = {pre_act}")
print(f" z = relu(pre_act) = {z}")
print(f" loss = (z - target)² = {loss}")
print()
# Backward (chain rule, layer by layer)
dloss_dz = 2 * (z - target) # = -5.0
dz_dpre = 1.0 if pre_act > 0 else 0.0 # relu derivative
dpre_dw = x # = 4.0
dpre_db = 1.0
dloss_dw = dloss_dz * dz_dpre * dpre_dw # -5 * 1 * 4 = -20
dloss_db = dloss_dz * dz_dpre * dpre_db # -5 * 1 * 1 = -5
print("Backward (chain rule):")
print(f" ∂loss/∂z = 2(z - target) = {dloss_dz}")
print(f" ∂z/∂pre = {'1 (pre > 0)' if pre_act > 0 else '0 (pre ≤ 0)'} = {dz_dpre}")
print(f" ∂pre/∂w = x = {dpre_dw}")
print(f" ∂loss/∂w = {dloss_dz} × {dz_dpre} × {dpre_dw} = {dloss_dw}")
print(f" ∂loss/∂b = {dloss_dz} × {dz_dpre} × {dpre_db} = {dloss_db}")
print()
# Show what happens when relu "kills" the gradient
w2, x2, b2 = 0.5, 4.0, -3.0
pre_act2 = w2 * x2 + b2 # = -1.0 (negative!)
z2 = max(0, pre_act2) # relu(-1) = 0
dz_dpre2 = 1.0 if pre_act2 > 0 else 0.0 # = 0!
print("When pre_act is negative (relu kills gradient):")
print(f" pre_act = {pre_act2}, relu = {z2}")
print(f" ∂z/∂pre = {dz_dpre2} → gradient is ZERO, nothing flows back")
print(" If this happens for every input, the neuron never updates: the 'dying ReLU' problem.")
Notice what happened in the second case: when the pre-activation value was negative ($-1.0$), ReLU clamped it to zero, and the local derivative $\partial z / \partial \text{pre}$ became zero as well. Because the chain rule multiplies the local derivatives along a path, a single zero factor kills the gradient for everything upstream of it. No gradient flows back to $w$ or $b$, so those parameters receive no learning signal from this example. If a neuron's pre-activation is negative for every training example, its gradient is permanently zero and it effectively becomes dead weight in the network. This is known as the dying ReLU problem (Lu et al., 2019). Alternatives avoid it in different ways. Leaky ReLU keeps a small nonzero slope for negative inputs. Smooth activations such as GELU have a nonzero derivative for almost all negative inputs, so some gradient still flows back.
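The contrast shows up directly in autograd. This sketch feeds one negative and one positive pre-activation through `torch.relu` and through `torch.nn.functional.leaky_relu`. The `negative_slope=0.01` value is an illustrative choice. The sketch then checks ReLU's gradient at exactly zero:

```python
import torch
import torch.nn.functional as F

# One negative pre-activation (the "dead" case) and one positive one
pre = torch.tensor([-1.0, 0.5], requires_grad=True)

torch.relu(pre).sum().backward()
relu_grad = pre.grad.clone()
print("ReLU grad:      ", relu_grad)   # tensor([0., 1.])

pre.grad = None  # reset so the next backward does not accumulate
F.leaky_relu(pre, negative_slope=0.01).sum().backward()
leaky_grad = pre.grad.clone()
print("Leaky ReLU grad:", leaky_grad)  # the negative input now gets 0.01

# At exactly zero, PyTorch's ReLU uses a derivative of 0
z = torch.tensor(0.0, requires_grad=True)
torch.relu(z).backward()
print("ReLU grad at 0: ", z.grad)      # tensor(0.)
```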
Gradient Accumulation and .zero_grad()
PyTorch accumulates gradients by default: calling `.backward()` adds to the `.grad` attribute rather than replacing it. This design choice is deliberate. It is useful when you want to accumulate gradients across multiple mini-batches (a technique called gradient accumulation) or when you have multiple loss functions that each contribute gradients to the same parameters.
But it's also one of the most common footguns in PyTorch. If you forget to call `optimizer.zero_grad()` (or `model.zero_grad()`) before each training step, gradients from previous steps pile up and the effective gradient becomes the sum of all past gradients. Training then usually misbehaves. Each update applies stale gradients on top of the current one, the steps tend to grow as the sum builds up, and the loss can oscillate or blow up.
Let's simulate both scenarios — the bug (forgetting to zero) and the intentional use (gradient accumulation for large effective batch sizes):
import numpy as np
# Simulate gradient accumulation
grad_w = 0.0 # starts at zero
# Step 1
loss_grad_1 = -18.0
grad_w += loss_grad_1 # accumulate
print(f"After step 1: grad_w = {grad_w}")
# Step 2 (forgot to zero!)
loss_grad_2 = -12.0
grad_w += loss_grad_2 # accumulates on top!
print(f"After step 2 (no zero_grad): grad_w = {grad_w} ← WRONG, should be {loss_grad_2}")
# Correct: zero first
grad_w = 0.0 # zero_grad()
grad_w += loss_grad_2
print(f"After step 2 (with zero_grad): grad_w = {grad_w} ← correct")
print()
# When accumulation is INTENTIONAL (gradient accumulation for large effective batches)
accumulation_steps = 4  # number of micro-batches combined into one optimizer step
micro_batch_grads = [-5.0, -3.0, -7.0, -1.0]  # mean gradient from each equal-sized micro-batch
grad_w = 0.0
for i, g in enumerate(micro_batch_grads):
    grad_w += g  # intentional accumulation
    print(f" Micro-batch {i+1}: grad += {g}, total = {grad_w}")
grad_w /= accumulation_steps  # average over micro-batches
print(f"Average gradient: {grad_w}")
print(f"(Same as averaging all 4 micro-batch gradients at once: {sum(micro_batch_grads)/accumulation_steps})")
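The same behavior is easy to observe with real tensors. This minimal sketch reuses the toy loss from earlier, whose gradient at $w = 2$ is $-18$. It calls `.backward()` twice without resetting, then resets with `optimizer.zero_grad()`. The SGD optimizer is only there to provide `zero_grad()`. No `optimizer.step()` is taken, so $w$ stays at 2.0 throughout.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

def backward_once():
    # Toy loss (w*3 + 1 - 10)^2, whose gradient at w = 2 is -18
    loss = (w * 3.0 + 1.0 - 10.0) ** 2
    loss.backward()
    return w.grad.item()

g1 = backward_once()   # -18.0
g2 = backward_once()   # -36.0: the second call added to the first
optimizer.zero_grad()  # reset before the next "step"
g3 = backward_once()   # -18.0 again
print(g1, g2, g3)
```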
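Gradient accumulation is particularly useful when your GPU memory can't fit the batch size you'd ideally use. For instance, if the optimal batch size is 32 but your GPU can only handle 8 samples at a time, you can run 4 forward-backward passes with micro-batches of 8, accumulating gradients, and then do a single optimizer step. The loss must be averaged over samples, and you must average the accumulated gradients, for example by dividing each micro-batch loss by 4. Under those conditions the resulting gradient matches processing all 32 samples at once, up to floating-point rounding, while peak activation memory is roughly that of a batch of 8. The main exception is layers whose output depends on the batch itself, such as BatchNorm, which compute their statistics over 8 samples instead of 32.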
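Here is a minimal sketch of that recipe. It assumes a small `nn.Linear` model, random data, and a mean-reduced `nn.MSELoss`, all chosen only for illustration. It computes the gradient once from a batch of 32 and once by accumulating over 4 micro-batches of 8, scaling each micro-batch loss by 1/4:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)   # illustrative model without batch-dependent layers
loss_fn = nn.MSELoss()    # mean over samples
data = torch.randn(32, 3)
targets = torch.randn(32, 1)

# Reference: gradient from one batch of 32
model.zero_grad()
loss_fn(model(data), targets).backward()
full_grad = model.weight.grad.clone()

# Accumulation: 4 micro-batches of 8, each loss scaled by 1/4
accumulation_steps = 4
model.zero_grad()
for xb, yb in zip(data.chunk(accumulation_steps), targets.chunk(accumulation_steps)):
    loss = loss_fn(model(xb), yb) / accumulation_steps
    loss.backward()  # adds into model.weight.grad
accum_grad = model.weight.grad.clone()

# In a real training loop, optimizer.step() and optimizer.zero_grad() go here
print(torch.allclose(full_grad, accum_grad, atol=1e-6))
```

Scaling the loss before `.backward()` is equivalent to dividing the accumulated gradient afterwards. Doing it per micro-batch keeps the gradient magnitudes comparable to the full-batch case throughout.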
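Modern PyTorch also offers `optimizer.zero_grad(set_to_none=True)`, which sets `.grad` to `None` instead of filling it with zeros. Since PyTorch 2.0, this is the default behavior of `zero_grad()`. Dropping the tensor releases its memory until the next backward pass and skips writing zeros, and the next `.backward()` simply assigns a fresh gradient. This can be marginally faster, though the difference is typically small. The trade-off is that any code reading `.grad` directly has to handle `None`.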
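The difference between the two settings is visible on a single parameter. This sketch uses a toy `SGD` optimizer purely to access `zero_grad()`:

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

(w * 3.0).backward()
optimizer.zero_grad(set_to_none=True)
grad_after_none = w.grad  # None: the gradient tensor was dropped
print(grad_after_none)

(w * 3.0).backward()      # assigns a fresh gradient of 3.0
optimizer.zero_grad(set_to_none=False)
grad_after_zero = w.grad  # the same tensor, now filled with zeros
print(grad_after_zero)

# Code that inspects gradients should guard against None
grad_magnitude = 0.0 if w.grad is None else w.grad.abs().item()
print(grad_magnitude)
```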
torch.no_grad() and Inference
During inference (or evaluation), you don't need gradients. But by default, PyTorch still builds the computational graph for every operation involving a tensor that requires gradients, which includes every operation that touches a model's parameters. It records which operations were performed and keeps alive the intermediate activations a backward pass would need, consuming both memory and compute in the process. How much extra memory this costs depends on the architecture, batch size, and sequence length rather than on a fixed ratio. For large models, however, the saved activations can be a substantial part of the total footprint.
PyTorch provides two mechanisms to skip graph construction. The first is the `torch.no_grad()` context manager, which temporarily disables gradient tracking for all operations within its scope. The second is `torch.inference_mode()`, usable as a decorator (`@torch.inference_mode()`) or as a context manager, which is stricter and slightly more efficient.
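import torch
import torch.nn as nn

model = nn.Linear(4, 2)     # small stand-in model for illustration
inputs = torch.randn(3, 4)  # example batch

# Context manager: temporarily disables gradient tracking
with torch.no_grad():
    output = model(inputs)  # no graph recorded, intermediates not kept for backward
print(output.requires_grad)  # False

# Decorator: stricter, and can be slightly faster
@torch.inference_mode()
def predict(model, inputs):
    return model(inputs)  # outputs are inference tensors that cannot silently enter training

# Why this matters: with tracking on, intermediate activations stay alive
# for a possible backward pass. Skipping that can substantially reduce
# inference memory; the exact savings depend on the model and batch size.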
The key difference between the two is subtle but important. `torch.no_grad()` only stops recording: the tensors it produces are ordinary tensors with `requires_grad=False`. They can be mixed freely into later autograd computations, where they silently act as constants. So if a value computed inside a `no_grad()` block ends up in a training loss, no gradient flows back through it and nothing warns you. (`no_grad()` also permits in-place updates to parameters that have `requires_grad=True`, which is how optimizers typically apply their updates.)
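`torch.inference_mode()` is stricter: it marks all tensors created within its scope as inference tensors. Outside inference mode, an inference tensor cannot be saved for backward by an operation that autograd is recording, and it cannot be modified in place. If inference outputs accidentally leak into training code, PyTorch therefore raises an error instead of silently producing wrong gradients. Inference mode also skips bookkeeping such as view tracking and version-counter updates, so it can be slightly faster. Together, these properties make `inference_mode()` the safer and generally preferred choice for production inference pipelines.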
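The two failure modes can be seen side by side. This sketch builds one tensor under each context manager and then tries to use it in a computation that autograd records. The `no_grad()` result is quietly treated as a constant. The `inference_mode()` result triggers a `RuntimeError`, because the multiplication would have to save it for computing `w`'s gradient.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

# no_grad(): the result is an ordinary tensor that autograd treats as a constant
with torch.no_grad():
    a = torch.ones(3) * w
print(a.requires_grad)    # False: no gradient will flow back to w through a
(a * w).sum().backward()  # allowed; a is just a constant here
print(w.grad)             # tensor(6.): only the direct use of w contributes

# inference_mode(): the result is an inference tensor
with torch.inference_mode():
    c = torch.ones(3) * w
print(c.is_inference())   # True

try:
    (c * w).sum()         # mul must save c to compute w's gradient
    leak_allowed = True
except RuntimeError as err:
    leak_allowed = False
    print("RuntimeError:", err)
```

If an inference result genuinely needs to re-enter autograd, cloning it outside inference mode produces a normal tensor.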
Custom Autograd Functions
Sometimes you need operations that PyTorch doesn't have built-in derivatives for, or you want to compute the gradient differently than the default. One example is gradient checkpointing, which trades compute for memory by recomputing activations during the backward pass instead of storing them. PyTorch's `torch.autograd.Function` lets you define custom forward and backward logic that plugs seamlessly into the autograd engine.
Here's a custom implementation of ReLU as an autograd function:
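import torch

class CustomReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # ctx stores values needed for backward
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Gradient of ReLU: 1 where input > 0, 0 elsewhere
        grad_input = grad_output * (input > 0).float()
        return grad_input

# Usage (input_tensor is an example input created for illustration):
input_tensor = torch.randn(5, requires_grad=True)
output = CustomReLU.apply(input_tensor)
output.sum().backward()
print(input_tensor.grad)  # 1.0 where input_tensor > 0, 0.0 elsewhere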
Let's break down each piece, because every line serves a specific purpose in the autograd machinery:
- `forward(ctx, input)`: computes the output of the operation (here, clamping negative values to zero). The `ctx` object is a context that acts as a bridge between forward and backward.
- `ctx.save_for_backward(input)`: stores the input tensor so it's available during the backward pass. PyTorch manages the memory lifecycle of these saved tensors: by default they are kept alive until backward completes, then released. Only tensors should go through this method, because PyTorch needs to track them for memory management and gradient computation. Plain Python values (such as a constant) can instead be stored as attributes on `ctx` directly.
- `backward(ctx, grad_output)`: receives the upstream gradient (`grad_output`, also called the "incoming" gradient from the layers above). It returns one gradient per input of `forward`, using `None` for inputs that need no gradient. This is the chain rule in action: we multiply the upstream gradient by the local derivative.
- `(input > 0).float()`: the local derivative of ReLU. This creates a binary mask that is 1.0 wherever the input was positive and 0.0 elsewhere. Multiplying the upstream gradient by this mask implements the chain rule: gradients flow through where ReLU was active and are blocked where it was not.
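The rules about `ctx` and return values become concrete with a function that has a non-tensor argument. `ScaledSquare` below is a hypothetical example op that computes `scale * x**2`. It stores the tensor with `save_for_backward`, keeps the Python float on `ctx`, and returns `None` as the gradient for `scale`. `torch.autograd.gradcheck` then compares the hand-written backward against finite differences. It works most reliably with float64 inputs.

```python
import torch

class ScaledSquare(torch.autograd.Function):
    """Hypothetical example op: computes scale * x**2 for a Python float scale."""

    @staticmethod
    def forward(ctx, x, scale):
        ctx.save_for_backward(x)  # tensors go through save_for_backward
        ctx.scale = scale         # non-tensor values can be stored on ctx directly
        return scale * x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        grad_x = grad_output * 2.0 * ctx.scale * x  # upstream gradient times local derivative
        return grad_x, None  # one entry per forward input; scale gets None

x = torch.randn(4, dtype=torch.float64, requires_grad=True)
# gradcheck compares backward() against numerical finite differences
passed = torch.autograd.gradcheck(lambda t: ScaledSquare.apply(t, 3.0), (x,))
print(passed)
```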
Custom autograd functions are also the mechanism behind several important techniques in modern deep learning. Gradient checkpointing (Chen et al., 2016) uses a custom function whose forward pass discards intermediate activations to save memory, and whose backward pass re-runs the forward computation to reconstruct them on the fly. Straight-through estimators (Bengio et al., 2013) use a custom backward that passes gradients through non-differentiable operations (like rounding or argmax) by pretending the derivative is 1. These are powerful tools, but they come with responsibility — an incorrect backward implementation will produce wrong gradients silently, and debugging gradient errors tends to be much harder than debugging forward-pass errors.
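A straight-through estimator is a short example of a deliberately non-standard backward. The sketch below uses a hypothetical `RoundSTE` function that rounds in the forward pass and passes the upstream gradient through unchanged in the backward pass. For comparison, it shows that plain `torch.round` yields a zero gradient, so no learning signal would reach `x` without the estimator.

```python
import torch

class RoundSTE(torch.autograd.Function):
    """Hypothetical straight-through rounding: round in forward, identity in backward."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Pretend d(round(x))/dx = 1 and pass the upstream gradient through unchanged
        return grad_output

x = torch.tensor([0.3, 1.6, -0.7], requires_grad=True)

RoundSTE.apply(x).sum().backward()
ste_grad = x.grad.clone()
print(ste_grad)    # tensor([1., 1., 1.])

x.grad = None
torch.round(x).sum().backward()
plain_grad = x.grad.clone()
print(plain_grad)  # tensor([0., 0., 0.]): round's derivative is 0 almost everywhere
```

Because the backward here intentionally disagrees with the true derivative, a finite-difference check such as `gradcheck` would report it as wrong. That is expected for an estimator and a reminder to test ordinary custom functions, where such a mismatch is a bug.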
Quiz
Test your understanding of PyTorch's autograd system.
Why does PyTorch accumulate gradients by default instead of overwriting them?
In the chain rule ∂L/∂w = ∂L/∂y · ∂y/∂w, what does ∂y/∂w represent?
What is the key difference between torch.no_grad() and torch.inference_mode()?
In a custom autograd Function, what does ctx.save_for_backward() do?