Por Qué el Modo Eager Deja Rendimiento Sobre la Mesa
En modo eager, PyTorch ejecuta cada operación inmediatamente: Python llama a
torch.matmul
, C++ la ejecuta, el resultado vuelve a Python, luego Python llama a
torch.relu
, y así sucesivamente. Esto es excelente para depuración — puedes inspeccionar valores intermedios, establecer puntos de interrupción, usar sentencias print — pero significa que el sistema nunca puede ver el panorama completo.
Considera tres operaciones consecutivas: matmul, sumar sesgo, relu. En modo eager, cada una lanza un kernel GPU separado, cada uno con sus propias lecturas y escrituras de memoria. Pero estas podrían fusionarse en un solo kernel que lee las entradas una vez, computa las tres operaciones y escribe la salida una vez — evitando dos viajes de ida y vuelta a la memoria GPU. El modo eager no puede hacer esto porque no conoce la siguiente operación hasta que Python se la indica.
Una analogía puede ayudar aquí. El modo eager es como un traductor que traduce una oración, se la entrega al lector, espera, y luego traduce la siguiente. El modo grafo es como leer el párrafo completo primero y entregar una traducción pulida. El resultado final es el mismo, pero el proceso es mucho más eficiente cuando puedes ver el panorama completo antes de empezar.
El siguiente ejemplo demuestra la oportunidad de fusión. Simulamos tres operaciones "eager" separadas y las comparamos con una sola función fusionada. La matemática es idéntica — la diferencia es cuántas veces pasamos por la memoria.
import numpy as np
# Simulate eager: 3 separate operations, 3 memory round-trips
x = np.random.randn(4, 4).astype(np.float32)
w = np.random.randn(4, 4).astype(np.float32)
b = np.random.randn(4).astype(np.float32)
# Eager: each step reads from memory and writes to memory
step1 = x @ w # read x,w → compute → write step1
step2 = step1 + b # read step1,b → compute → write step2
step3 = np.maximum(step2, 0) # read step2 → compute → write step3
# Fused: one pass, reads x,w,b once, writes output once
def fused_linear_relu(x, w, b):
return np.maximum(x @ w + b, 0)
fused = fused_linear_relu(x, w, b)
print("Eager (3 steps): 3 memory reads + 3 memory writes")
print("Fused (1 kernel): 1 memory read + 1 memory write")
print(f"Results match: {np.allclose(step3, fused)}")
print()
print("On a GPU, memory bandwidth is often the bottleneck,")
print("so reducing memory round-trips can speed things up significantly.")
Modo Grafo: Viendo el Panorama Completo
torch.compile(model)
opta por el modo grafo. En lugar de ejecutar operaciones una a la vez, PyTorch primero captura toda la computación como un grafo — un grafo acíclico dirigido (DAG) de operaciones matemáticas sin flujo de control Python, sin sentencias print, sin efectos secundarios.
Este grafo puede entonces optimizarse como un todo: fusionando operaciones, reordenándolas para mejor acceso a memoria, eligiendo implementaciones óptimas de kernels. El grafo optimizado se compila en código eficiente que se ejecuta sin volver al intérprete Python entre operaciones. El resultado es que cientos de pequeños viajes de ida y vuelta entre Python y C++ se colapsan en una sola llamada a función compilada.
Aquí hay un ejemplo mínimo. Observa que la API es notablemente simple — una sola llamada a
torch.compile
envuelve el modelo y devuelve una versión compilada que es funcionalmente idéntica pero (después del primer paso de compilación) sustancialmente más rápida.
import torch
class SimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(512, 512)
def forward(self, x):
return torch.relu(self.linear(x))
model = SimpleModel().cuda()
# Eager mode (default): each op runs immediately
output_eager = model(torch.randn(32, 512).cuda())
# Compiled: captures graph, optimises, compiles
compiled_model = torch.compile(model)
output_compiled = compiled_model(torch.randn(32, 512).cuda())
# First call is slow (compiling), subsequent calls are fast
TorchDynamo: El Compilador Frontend
TorchDynamo es el componente que captura el grafo. Funciona interceptando la ejecución del bytecode Python — literalmente observando lo que CPython hace instrucción por instrucción — y registrando las operaciones matemáticas en un grafo.
Cuando Dynamo encuentra algo que no puede capturar (una sentencia print, un
if/else
dependiente de datos, una llamada a una biblioteca externa), crea un
graph break
(quiebre de grafo). Un graph break
no
es un error. En su lugar, Dynamo:
- Compila y ejecuta el grafo capturado hasta ese punto
- Vuelve al intérprete Python para el código que no puede capturar
- Inicia una nueva captura de grafo después del quiebre
Esto significa que
torch.compile
siempre es correcto — nunca cambia el comportamiento silenciosamente. Pero cada graph break es una oportunidad de optimización perdida, porque el compilador no puede fusionar operaciones a través del quiebre.
def forward(self, x):
x = self.linear1(x) # ┐
x = torch.relu(x) # │ Graph 1 (compiled, fused)
x = self.linear2(x) # ┘
print(f"Shape: {x.shape}") # ← GRAPH BREAK (print is a side effect)
x = self.linear3(x) # ┐
x = torch.sigmoid(x) # │ Graph 2 (compiled, fused)
return x # ┘
# Result: 2 compiled graphs with a Python interpreter gap between them.
# Remove the print → 1 fused graph → faster.
TorchInductor: El Compilador Backend
Una vez que Dynamo ha capturado un grafo de operaciones matemáticas, TorchInductor lo traduce en código optimizado. Para objetivos GPU, Inductor genera kernels Triton . Para objetivos CPU, genera código C++.
Las optimizaciones clave que Inductor realiza:
- Fusión de operadores : matmul + sesgo + relu se convierte en un solo lanzamiento de kernel en lugar de tres, eliminando escrituras de memoria intermedias
- Planificación de memoria : reutilizar buffers de memoria entre operaciones para reducir la asignación máxima, de modo que un modelo que habría asignado diez tensores temporales podría necesitar solo tres
- Optimización de layout : elegir disposiciones de memoria (channels-first vs channels-last) que sean óptimas para cada kernel, insertando conversiones de layout solo donde sea necesario
A continuación se muestra una versión simplificada de lo que Inductor podría generar para una operación fusionada de linear + relu. La idea clave es que la suma del sesgo y relu se realizan dentro del kernel de matmul, de modo que los resultados intermedios nunca salen de los registros rápidos de la GPU para hacer un viaje de ida y vuelta por la memoria global más lenta.
# What Inductor generates (simplified) for linear + relu:
@triton.jit
def fused_linear_relu_kernel(
x_ptr, w_ptr, b_ptr, out_ptr,
M, N, K,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
# Compute matmul tile
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
a = tl.load(x_ptr + ...)
b = tl.load(w_ptr + ...)
acc += tl.dot(a, b)
# Fused: add bias and relu in the same kernel
bias = tl.load(b_ptr + ...)
acc = acc + bias
acc = tl.maximum(acc, 0.0) # relu fused in!
tl.store(out_ptr + ..., acc)
# One kernel launch instead of three. One memory write instead of three.
Triton: El Compilador de Kernels
Triton cumple dos roles: es tanto un lenguaje de dominio específico (DSL) para escribir kernels GPU como un compilador que convierte ese DSL en código máquina GPU.
¿Por qué existe Triton? Escribir kernels CUDA es difícil. Debes gestionar manualmente bloques de hilos, memoria compartida, coalescencia de memoria y sincronización. Triton abstrae todo esto — escribes a nivel de "tile" (bloques de datos), y el compilador Triton se encarga de la programación de hilos y la gestión de memoria. Esto es una ganancia sustancial de productividad: un kernel fusionado que podría requerir cientos de líneas de CUDA C++ a menudo puede expresarse en unas pocas docenas de líneas de Triton.
El pipeline de compilación de Triton se ve así:
Triton DSL (@triton.jit decorated Python)
↓
Triton IR (intermediate representation)
↓
LLVM IR (general-purpose IR)
↓
PTX (NVIDIA portable assembly)
↓ ptxas (NVIDIA assembler)
SASS (GPU-specific machine code)
Contrasta esto con cuBLAS y cuDNN: esos kernels fueron escritos en CUDA C++ (archivos
.cu
), compilados con NVCC (produciendo PTX, luego SASS), y distribuidos como binarios precompilados. Los kernels Triton, en cambio, se compilan en tiempo de ejecución (JIT). Esto hace que la primera invocación sea más lenta, pero permite al compilador optimizar para la arquitectura GPU específica y las formas de entrada que realmente se encuentran — algo que los binarios precompilados fundamentalmente no pueden hacer.
Cuándo Usar torch.compile (y Cuándo No)
torch.compile
no es un botón universal de "ir más rápido" — implica compromisos reales. Aquí tienes una guía práctica de cuándo tiende a ayudar y cuándo tiende a perjudicar.
Compila cuando:
- La arquitectura de tu modelo es estable y estás en producción o entrenamiento sostenido. El costo de compilación se amortiza sobre miles de iteraciones.
- Tus operaciones son lo suficientemente grandes para que la fusión ayude — matmul, atención, redes feed-forward. Estas tienen suficiente intensidad aritmética para beneficiarse del tráfico de memoria reducido.
- Estás dispuesto a esperar la compilación inicial (típicamente segundos a unos pocos minutos, dependiendo de la complejidad del modelo).
No compiles cuando:
- Estás depurando. Los graph breaks dificultan el rastreo, y los mensajes de error del código compilado suelen ser menos claros que los errores del modo eager.
- Formas dinámicas. Cada nueva forma de entrada puede desencadenar recompilación, lo que puede hacer las cosas más lentas en general si las formas cambian frecuentemente.
- Prototipado rápido. Cuando estás iterando sobre la arquitectura del modelo cada pocos minutos, la sobrecarga de compilación domina las ejecuciones cortas.
- Modelos muy pequeños. Si la sobrecarga de Python por operación ya es insignificante en relación al tiempo de cómputo, la fusión no ahorrará mucho.
Quiz
Pon a prueba tu comprensión de torch.compile, TorchDynamo, TorchInductor y el pipeline de compilación de Triton.
¿Por qué el modo eager no puede fusionar matmul + sesgo + relu en un solo kernel?
¿Qué es un graph break en TorchDynamo y es un error?
¿Cuál es el rol de Triton en el pipeline de torch.compile?
¿Por qué la primera llamada a un modelo compilado con torch.compile es lenta?