Why Eager Mode Leaves Performance on the Table

In eager mode, PyTorch executes each operation immediately: Python calls torch.matmul, C++ runs it, the result returns to Python, then Python calls torch.relu, and so on. This is great for debugging — you can inspect intermediate values, set breakpoints, use print statements — but it means the system can never see the bigger picture.

Consider three consecutive operations: matmul, add bias, relu. In eager mode, each launches a separate GPU kernel, each with its own memory reads and writes. But these could be fused into a single kernel that reads the inputs once, computes all three operations, and writes the output once — avoiding two round-trips to GPU memory. Eager mode can't do this because it doesn't know the next operation until Python tells it.

An analogy may help here. Eager mode is like a translator who translates one sentence, hands it to the reader, waits, then translates the next. Graph mode is like reading the entire paragraph first and delivering a polished translation. The end result is the same, but the process is far more efficient when you can see the whole picture before you start.

The following example demonstrates the fusion opportunity. We simulate three separate "eager" operations and compare them with a single fused function. The math is identical — the difference is how many times we pass through memory.

import numpy as np

# Simulate eager: 3 separate operations, 3 memory round-trips
x = np.random.randn(4, 4).astype(np.float32)
w = np.random.randn(4, 4).astype(np.float32)
b = np.random.randn(4).astype(np.float32)

# Eager: each step reads from memory and writes to memory
step1 = x @ w                          # read x,w → compute → write step1
step2 = step1 + b                      # read step1,b → compute → write step2
step3 = np.maximum(step2, 0)           # read step2 → compute → write step3

# Fused: one pass, reads x,w,b once, writes output once
def fused_linear_relu(x, w, b):
    return np.maximum(x @ w + b, 0)

fused = fused_linear_relu(x, w, b)

print("Eager (3 steps):  3 passes through memory (each writes an intermediate)")
print("Fused (1 kernel): 1 pass through memory (inputs read once, output written once)")
print(f"Results match: {np.allclose(step3, fused)}")
print()
print("On a GPU, memory bandwidth is often the bottleneck,")
print("so reducing memory round-trips can speed things up significantly.")

Graph Mode: Seeing the Whole Picture

torch.compile(model) opts into graph mode. Instead of executing operations one at a time, PyTorch first captures the entire computation as a graph — a directed acyclic graph (DAG) of math operations with no Python control flow, no print statements, no side effects.

This graph can then be optimised as a whole: fusing operations, reordering them for better memory access, choosing optimal kernel implementations. The optimised graph is compiled into efficient code that runs without returning to the Python interpreter between operations. The result is that hundreds of small Python-to-C++ round-trips collapse into a single compiled function call.
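To make "capturing the computation as a graph" concrete, here is a toy sketch (in numpy, not PyTorch internals): the program is represented as a list of nodes up front, so a trivial "compiler" can legally rewrite it — collapsing three nodes into one fused node — before anything runs. All names here are illustrative.

```python
import numpy as np

# Toy graph: each node is (output name, function of the environment, input names).
# The whole program is visible before execution — exactly what eager mode lacks.
graph = [
    ("y",   lambda env: env["x"] @ env["w"],     ["x", "w"]),
    ("z",   lambda env: env["y"] + env["b"],     ["y", "b"]),
    ("out", lambda env: np.maximum(env["z"], 0), ["z"]),
]

def run_graph(graph, inputs):
    env = dict(inputs)
    for name, fn, _ in graph:
        env[name] = fn(env)          # one "kernel launch" per node
    return env[graph[-1][0]]

# Because the graph is known ahead of time, a compiler pass can replace the
# three nodes with a single fused node that computes the same result:
fused_graph = [
    ("out",
     lambda env: np.maximum(env["x"] @ env["w"] + env["b"], 0),
     ["x", "w", "b"]),
]

rng = np.random.default_rng(0)
inputs = {"x": rng.standard_normal((4, 4)),
          "w": rng.standard_normal((4, 4)),
          "b": rng.standard_normal(4)}
print(np.allclose(run_graph(graph, inputs), run_graph(fused_graph, inputs)))
```

The rewrite is only safe because the graph contains pure math with no side effects — which is why the capture step insists on excluding prints and control flow.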

Here is a minimal example. Notice that the API is remarkably simple — a single call to torch.compile wraps the model and returns a compiled version that is functionally identical but (after the first compilation pass) substantially faster.

import torch

class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(512, 512)

    def forward(self, x):
        return torch.relu(self.linear(x))

model = SimpleModel().cuda()

# Eager mode (default): each op runs immediately
output_eager = model(torch.randn(32, 512).cuda())

# Compiled: captures graph, optimises, compiles
compiled_model = torch.compile(model)
output_compiled = compiled_model(torch.randn(32, 512).cuda())
# First call is slow (compiling), subsequent calls are fast

TorchDynamo: The Frontend Compiler

TorchDynamo is the component that captures the graph. It works by intercepting Python bytecode execution — literally watching what CPython does instruction by instruction — and recording the math operations into a graph.

When Dynamo encounters something it can't capture (a print statement, a data-dependent if/else, a call to an external library), it creates a graph break. A graph break is not an error. Instead, Dynamo:

  • Compiles and runs the graph captured so far
  • Falls back to the Python interpreter for the un-capturable code
  • Starts a new graph capture after the break

This means torch.compile is always correct — it never silently changes behaviour. But each graph break is a lost optimisation opportunity, because the compiler can't fuse operations across the break.

def forward(self, x):
    x = self.linear1(x)      # ┐
    x = torch.relu(x)        # │ Graph 1 (compiled, fused)
    x = self.linear2(x)      # ┘

    print(f"Shape: {x.shape}") # ← GRAPH BREAK (print is a side effect)

    x = self.linear3(x)      # ┐
    x = torch.sigmoid(x)     # │ Graph 2 (compiled, fused)
    return x                  # ┘

# Result: 2 compiled graphs with a Python interpreter gap between them.
# Remove the print → 1 fused graph → faster.

📌 Common graph-break triggers: print(), logging, pdb breakpoints, data-dependent control flow (if x.sum() > 0), calls to non-PyTorch libraries (e.g., scipy), and Python constructs such as list comprehensions over tensors. Use torch._dynamo.explain(model, input) to find graph breaks in your model.
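The three-step fallback behaviour above can be simulated in plain Python. This is a toy model, not Dynamo itself — capturable ops are grouped into "compiled" segments, and the un-capturable print falls back to the interpreter between them:

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0)

# (function, capturable?) — the print is the side effect forcing a break
steps = [
    (lambda x: x * 2, True),
    (relu, True),
    (lambda x: (print(f"shape: {x.shape}"), x)[1], False),  # graph break
    (lambda x: x + 1, True),
]

def run_with_breaks(steps, x):
    segment, n_graphs = [], 0
    for fn, capturable in steps:
        if capturable:
            segment.append(fn)           # keep capturing
        else:
            if segment:                  # 1. compile & run the graph so far
                for g in segment:
                    x = g(x)
                n_graphs += 1
                segment = []
            x = fn(x)                    # 2. fall back to the interpreter
            # 3. capture resumes on the next capturable step
    if segment:                          # flush the final segment
        for g in segment:
            x = g(x)
        n_graphs += 1
    return x, n_graphs

out, n = run_with_breaks(steps, np.array([-1.0, 2.0]))
print(f"{n} compiled graphs")  # 2 compiled graphs
```

Removing the un-capturable step would merge everything into a single segment — the pure-Python analogue of "remove the print → 1 fused graph".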

TorchInductor: The Backend Compiler

Once Dynamo has captured a graph of math operations, TorchInductor translates it into optimised code. For GPU targets, Inductor generates Triton kernels. For CPU targets, it generates C++ code.

The key optimisations Inductor performs:

  • Operator fusion : matmul + bias + relu becomes one kernel launch instead of three, eliminating intermediate memory writes
  • Memory planning : reuse memory buffers across operations to reduce peak allocation, so a model that would have allocated ten temporary tensors might only need three
  • Layout optimisation : choose memory layouts (channels-first vs channels-last) that are optimal for each kernel, inserting layout conversions only where necessary
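The memory-planning bullet can be sketched with a greedy liveness-based allocator (illustrative only — Inductor's real planner is more sophisticated): once a temporary's buffer is dead, it is recycled for a later temporary instead of allocating fresh memory.

```python
def plan_buffers(lifetimes):
    """lifetimes: list of (first_use, last_use) per temporary, in start order.
    Returns (buffer assignment per temporary, number of physical buffers)."""
    buffers = []      # buffers[i] = step at which buffer i becomes free
    assignment = []
    for start, end in lifetimes:
        for i, free_at in enumerate(buffers):
            if free_at <= start:         # buffer i's owner is dead: reuse it
                buffers[i] = end
                assignment.append(i)
                break
        else:                            # no dead buffer: allocate a new one
            buffers.append(end)
            assignment.append(len(buffers) - 1)
    return assignment, len(buffers)

# Five temporaries with staggered, overlapping lifetimes fit in two buffers:
lifetimes = [(0, 2), (1, 3), (2, 4), (3, 5), (4, 6)]
assignment, n_buffers = plan_buffers(lifetimes)
print(assignment, n_buffers)  # [0, 1, 0, 1, 0] 2
```

The same idea is what lets a compiled model's peak allocation drop well below the sum of its intermediate tensor sizes.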

The following shows a simplified version of what Inductor might generate for a fused linear + relu operation. The key insight is that the bias addition and relu are performed inside the matmul kernel, so the intermediate results never leave the GPU's fast registers to round-trip through slower global memory.

# What Inductor generates (simplified) for linear + relu:
@triton.jit
def fused_linear_relu_kernel(
    x_ptr, w_ptr, b_ptr, out_ptr,
    M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # Compute matmul tile
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        a = tl.load(x_ptr + ...)
        b = tl.load(w_ptr + ...)
        acc += tl.dot(a, b)

    # Fused: add bias and relu in the same kernel
    bias = tl.load(b_ptr + ...)
    acc = acc + bias
    acc = tl.maximum(acc, 0.0)  # relu fused in!

    tl.store(out_ptr + ..., acc)

# One kernel launch instead of three. One memory write instead of three.

Triton: The Kernel Compiler

Triton serves two roles: it is both a domain-specific language (DSL) for writing GPU kernels and a compiler that turns that DSL into GPU machine code.

Why does Triton exist? Writing CUDA kernels is hard. You must manually manage thread blocks, shared memory, memory coalescing, and synchronisation. Triton abstracts these away — you write at the "tile" level (blocks of data), and the Triton compiler handles thread scheduling and memory management. This is a substantial productivity gain: a fused kernel that might take hundreds of lines of CUDA C++ can often be expressed in a few dozen lines of Triton.
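"Writing at the tile level" can be illustrated with a blocked matmul in numpy, mirroring the structure of a Triton matmul kernel: the body computes one output tile, accumulating over K-blocks. In Triton you write only that per-tile body, and the compiler maps tiles onto GPU thread blocks. This sketch assumes dimensions divisible by the block sizes.

```python
import numpy as np

def blocked_matmul(x, w, BM=2, BN=2, BK=2):
    """Tile-level matmul: one (BM, BN) output tile at a time."""
    M, K = x.shape
    _, N = w.shape
    out = np.zeros((M, N), dtype=np.float32)
    for m in range(0, M, BM):              # each (m, n) pair is one "program
        for n in range(0, N, BN):          # instance" in Triton terms
            acc = np.zeros((BM, BN), dtype=np.float32)
            for k in range(0, K, BK):      # accumulate over K-blocks
                acc += x[m:m+BM, k:k+BK] @ w[k:k+BK, n:n+BN]
            out[m:m+BM, n:n+BN] = acc      # one write per output tile
    return out

x = np.random.randn(4, 4).astype(np.float32)
w = np.random.randn(4, 4).astype(np.float32)
print(np.allclose(blocked_matmul(x, w), x @ w, atol=1e-5))
```

On a GPU, thread scheduling, shared-memory staging, and coalescing for each tile are exactly the details Triton's compiler handles for you.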

The Triton compilation pipeline looks like this:

Triton DSL (@triton.jit decorated Python)
    ↓
Triton IR (intermediate representation)
    ↓
LLVM IR (general-purpose IR)
    ↓
PTX (NVIDIA portable assembly)
    ↓ ptxas (NVIDIA assembler)
SASS (GPU-specific machine code)

Contrast this with cuBLAS and cuDNN: those kernels were written in CUDA C++ ( .cu files), compiled with NVCC (producing PTX, then SASS), and shipped as pre-compiled binaries. Triton kernels, by contrast, are compiled at runtime (JIT). This makes the first invocation slower, but it lets the compiler specialise for the exact GPU architecture and input shapes actually encountered — including generating fused kernels for op combinations that no fixed set of pre-compiled binaries could anticipate.

💡 This is why the first call to a torch.compiled model is slow — Triton is compiling kernels. Subsequent calls reuse the cached compiled kernels and run at full speed. PyTorch caches compiled kernels to disk, so even across Python sessions the compilation cost is usually paid only once.
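The pay-once compilation pattern can be sketched with a toy JIT cache (illustrative only — a sleep stands in for real compilation work; the decorator name is made up):

```python
import time

_kernel_cache = {}

def jit_compile(fn):
    """Toy JIT: 'compile' on first call, reuse the cached result after."""
    def wrapper(*args):
        key = fn.__name__
        if key not in _kernel_cache:
            time.sleep(0.05)             # stand-in for real compilation cost
            _kernel_cache[key] = fn      # the "compiled kernel"
        return _kernel_cache[key](*args)
    return wrapper

@jit_compile
def add(a, b):
    return a + b

t0 = time.perf_counter(); r1 = add(1, 2); first = time.perf_counter() - t0
t0 = time.perf_counter(); r2 = add(3, 4); cached = time.perf_counter() - t0
print(f"first call {first*1e3:.1f} ms, cached call {cached*1e3:.4f} ms")
```

PyTorch's on-disk cache extends the same idea across processes: the expensive branch is skipped even after a restart.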

When to torch.compile (and When Not To)

torch.compile is not a universal "go faster" button — it involves real tradeoffs. Here is a practical guide to when it tends to help and when it tends to hurt.

Compile when:

  • Your model architecture is stable and you're in production or sustained training. The compilation cost is amortised over thousands of iterations.
  • Your operations are large enough that fusion helps — matmul, attention, feed-forward networks. These have enough arithmetic intensity to benefit from reduced memory traffic.
  • You're willing to wait for initial compilation (typically seconds to a few minutes, depending on model complexity).

Don't compile when:

  • Debugging. Graph breaks make tracing harder, and error messages from compiled code are often less clear than eager-mode errors.
  • Dynamic shapes. Each new input shape may trigger recompilation, which can make things slower overall if shapes change frequently.
  • Rapid prototyping. When you're iterating on model architecture every few minutes, the compilation overhead dominates short runs.
  • Very small models. If the Python overhead per operation is already negligible relative to compute time, fusion won't save much.
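The dynamic-shapes pitfall can be made concrete with a toy shape-keyed cache (a simplification — torch.compile also supports symbolic shapes via dynamic=True): every previously unseen input shape pays the compile cost again.

```python
import numpy as np

compiled_for_shape = {}   # shape -> "compiled" function
compile_count = 0

def run_compiled(x):
    """Toy model of shape-specialised compilation."""
    global compile_count
    key = x.shape
    if key not in compiled_for_shape:
        compile_count += 1                      # stand-in for a real recompile
        compiled_for_shape[key] = lambda t: t * 2
    return compiled_for_shape[key](x)

for n in (8, 16, 8, 32):    # 4 calls, 3 distinct shapes
    run_compiled(np.ones(n))
print(f"compiled {compile_count} times for 4 calls")  # compiled 3 times for 4 calls
```

If your batch sizes or sequence lengths vary on every call, this recompilation tax can outweigh the fusion wins entirely.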

Quiz

Test your understanding of torch.compile, TorchDynamo, TorchInductor, and the Triton compilation pipeline.

Why can't eager mode fuse matmul + bias + relu into one kernel?

What is a graph break in TorchDynamo, and is it an error?

What is Triton's role in the torch.compile pipeline?

Why is the first call to a torch.compiled model slow?