El Grafo Computacional

Cada vez que realizas una operación sobre un tensor que tiene requires_grad=True , PyTorch construye silenciosamente un grafo acíclico dirigido (DAG) que registra cada operación. Cada nodo de este grafo almacena la operación que se aplicó (suma, multiplicación, matmul, etc.) junto con punteros a los nodos que representan sus entradas. Este grafo es el plano para calcular los gradientes — cuando llamas a .backward() , PyTorch recorre el grafo en orden inverso, aplicando la regla de la cadena en cada nodo para propagar la información del gradiente desde la pérdida hasta los parámetros.

La idea clave es que estas dos fases sirven para propósitos diferentes. El paso hacia adelante (forward pass) construye el grafo: a medida que cada operación se ejecuta, se registra como un nodo, almacenando tanto la función que se calculó como cualquier valor intermedio necesario posteriormente para la derivada. El paso hacia atrás (backward pass) consume el grafo: recorre cada nodo en orden topológico inverso, calculando la contribución del gradiente de cada operación y luego — por defecto — libera el grafo por completo . Por eso llamar a .backward() una segunda vez sobre la misma pérdida produce un error: el grafo ya no existe. Si realmente necesitas retropropagar a través del mismo grafo dos veces (por ejemplo, al calcular derivadas de orden superior), puedes pasar retain_graph=True , lo que mantiene el grafo en memoria a costa de un mayor uso de memoria.

💡 Piensa en el grafo computacional como una receta que registra cada paso de cocina. El forward pass escribe la receta. El backward pass la lee en orden inverso para determinar cómo cambiar cada ingrediente afecta al plato final. Una vez que la has leído, PyTorch descarta la receta — a menos que pidas explícitamente conservarla.

Simulemos este proceso manualmente con NumPy. Definiremos un cálculo simple — $y = wx + b$ seguido de una pérdida de error cuadrático — y luego lo recorreremos hacia atrás, calculando cada gradiente a mano usando la regla de la cadena.

import numpy as np

# Simple computation: y = w*x + b, loss = (y - target)^2
w = 2.0
x = 3.0
b = 1.0
target = 10.0

# Forward pass (building the "graph" mentally)
y = w * x + b          # y = 2*3 + 1 = 7
loss = (y - target)**2  # loss = (7 - 10)^2 = 9

print("Forward pass:")
print(f"  w={w}, x={x}, b={b}")
print(f"  y = w*x + b = {y}")
print(f"  loss = (y - target)² = ({y} - {target})² = {loss}")
print()

# Backward pass (chain rule, working backwards)
dloss_dy = 2 * (y - target)      # d(loss)/dy = 2(y - target) = -6
dy_dw = x                         # d(y)/dw = x = 3
dy_db = 1.0                       # d(y)/db = 1
dloss_dw = dloss_dy * dy_dw       # chain rule: -6 * 3 = -18
dloss_db = dloss_dy * dy_db       # chain rule: -6 * 1 = -6

print("Backward pass (chain rule):")
print(f"  d(loss)/dy = 2(y - target) = 2({y} - {target}) = {dloss_dy}")
print(f"  d(y)/dw    = x = {dy_dw}")
print(f"  d(y)/db    = 1")
print(f"  d(loss)/dw = d(loss)/dy × d(y)/dw = {dloss_dy} × {dy_dw} = {dloss_dw}")
print(f"  d(loss)/db = d(loss)/dy × d(y)/db = {dloss_dy} × {dy_db} = {dloss_db}")

El gradiente $\partial L / \partial w = -18$ nos dice que aumentar $w$ un poco disminuiría la pérdida (ya que el gradiente es negativo), y la magnitud nos indica que la pérdida es bastante sensible a $w$. De manera similar, $\partial L / \partial b = -6$ indica que la pérdida también es sensible a $b$, pero menos que a $w$. Un paso de descenso de gradiente ajustaría ambos parámetros en la dirección que reduce la pérdida: $w \leftarrow w - \alpha \cdot (-18)$ y $b \leftarrow b - \alpha \cdot (-6)$ , donde $\alpha$ es la tasa de aprendizaje.

Ahora verifiquemos que el autograd de PyTorch produce exactamente los mismos resultados:

import torch

w = torch.tensor(2.0, requires_grad=True)
x = torch.tensor(3.0)
b = torch.tensor(1.0, requires_grad=True)
target = torch.tensor(10.0)

y = w * x + b
loss = (y - target) ** 2

loss.backward()

print(w.grad)  # tensor(-18.)  ← same as our manual calculation
print(b.grad)  # tensor(-6.)   ← same

Los números coinciden exactamente. Internamente, PyTorch realizó la misma descomposición con la regla de la cadena que hicimos a mano — simplemente lo hizo automáticamente recorriendo el grafo que construyó durante el forward pass.

La Regla de la Cadena en Acción

La regla de la cadena es la columna vertebral matemática de la retropropagación. Para una composición de funciones donde la pérdida $L$ depende de un parámetro $w$ a través de una secuencia de cálculos intermedios — digamos $L = f(g(h(w)))$ — la regla de la cadena nos dice que podemos descomponer la derivada completa en un producto de derivadas locales más simples:

$$\frac{\partial L}{\partial w} = \frac{\partial L}{\partial g} \cdot \frac{\partial g}{\partial h} \cdot \frac{\partial h}{\partial w}$$

Cada factor en este producto es una derivada local — responde a una pregunta específica: ¿cuánto cambia la salida de esta capa particular cuando su entrada cambia una cantidad pequeña? El factor $\partial h / \partial w$ mide la sensibilidad de la salida de la primera operación al parámetro $w$. El factor $\partial g / \partial h$ mide cuánto la segunda operación amplifica o atenúa los cambios que vienen de abajo. Y $\partial L / \partial g$ mide cómo responde la pérdida a cambios en el último valor intermedio. Su producto da la sensibilidad de extremo a extremo de la pérdida al parámetro — exactamente lo que necesitamos para el descenso de gradiente.

¿Por qué es tan importante esta descomposición? Porque calcular la derivada completa directamente — mediante diferenciación simbólica de toda la red como una función monolítica — sería intratable para modelos con millones o miles de millones de parámetros. La regla de la cadena nos permite descomponer el problema en una secuencia de derivadas locales simples, cada una calculable a partir de valores ya disponibles durante el forward pass. Esta es la idea que hace factible el aprendizaje profundo moderno.

Veamos la regla de la cadena en acción con un grafo computacional más profundo que incluye una activación ReLU — una de las no linealidades más comunes en redes neuronales. ReLU se define como $\text{ReLU}(x) = \max(0, x)$, y su derivada es particularmente simple: 1 cuando la entrada es positiva, 0 cuando es negativa. Esta naturaleza por tramos conduce a un fenómeno importante que exploraremos a continuación.

import numpy as np

# Deeper graph: z = relu(w*x + b), loss = (z - target)^2
w, x, b, target = 0.5, 4.0, -1.5, 3.0

# Forward
pre_act = w * x + b           # 0.5*4 + (-1.5) = 0.5
z = max(0, pre_act)           # relu(0.5) = 0.5
loss = (z - target) ** 2      # (0.5 - 3)^2 = 6.25

print("Forward:")
print(f"  pre_act = w*x + b = {pre_act}")
print(f"  z = relu(pre_act) = {z}")
print(f"  loss = (z - target)² = {loss}")
print()

# Backward (chain rule, layer by layer)
dloss_dz = 2 * (z - target)                  # = -5.0
dz_dpre = 1.0 if pre_act > 0 else 0.0        # relu derivative
dpre_dw = x                                    # = 4.0
dpre_db = 1.0

dloss_dw = dloss_dz * dz_dpre * dpre_dw      # -5 * 1 * 4 = -20
dloss_db = dloss_dz * dz_dpre * dpre_db      # -5 * 1 * 1 = -5

print("Backward (chain rule):")
print(f"  ∂loss/∂z    = 2(z - target) = {dloss_dz}")
print(f"  ∂z/∂pre    = {'1 (pre > 0)' if pre_act > 0 else '0 (pre ≤ 0)'} = {dz_dpre}")
print(f"  ∂pre/∂w    = x = {dpre_dw}")
print(f"  ∂loss/∂w   = {dloss_dz} × {dz_dpre} × {dpre_dw} = {dloss_dw}")
print(f"  ∂loss/∂b   = {dloss_dz} × {dz_dpre} × {dpre_db} = {dloss_db}")
print()

# Show what happens when relu "kills" the gradient
w2, x2, b2 = 0.5, 4.0, -3.0
pre_act2 = w2 * x2 + b2  # = -1.0 (negative!)
z2 = max(0, pre_act2)     # relu(-1) = 0
dz_dpre2 = 1.0 if pre_act2 > 0 else 0.0  # = 0!
print("When pre_act is negative (relu kills gradient):")
print(f"  pre_act = {pre_act2}, relu = {z2}")
print(f"  ∂z/∂pre = {dz_dpre2} → gradient is ZERO, nothing flows back")
print(f"  This is the 'dying ReLU' problem.")

Observa lo que ocurrió en el segundo caso: cuando el valor de pre-activación era negativo ($-1.0$), ReLU lo fijó en cero, y la derivada local $\partial z / \partial \text{pre}$ también se convirtió en cero. Como la regla de la cadena multiplica todas las derivadas locales entre sí, un solo cero en cualquier parte de la cadena anula el gradiente completo. Ningún gradiente fluye de vuelta hacia $w$ o $b$, así que esos parámetros no reciben señal de aprendizaje. Esto se conoce como el problema del ReLU muerto (Lu et al., 2019) — si la pre-activación de una neurona es negativa para todos los ejemplos de entrenamiento, su gradiente es permanentemente cero y se convierte efectivamente en peso muerto en la red. Alternativas como Leaky ReLU y GELU evitan esto al garantizar que la derivada nunca sea exactamente cero, incluso para entradas negativas.

Acumulación de Gradientes y .zero_grad()

PyTorch acumula gradientes por defecto: llamar a .backward() suma al atributo .grad en lugar de reemplazarlo. Esta decisión de diseño es deliberada — es útil cuando quieres acumular gradientes a través de múltiples mini-lotes (una técnica llamada acumulación de gradientes ) o cuando tienes múltiples funciones de pérdida que contribuyen gradientes a los mismos parámetros.

Pero también es uno de los errores más comunes en PyTorch. Si olvidas llamar a optimizer.zero_grad() (o model.zero_grad() ) antes de cada paso de entrenamiento, los gradientes de pasos anteriores se acumulan y el gradiente efectivo se convierte en la suma de todos los gradientes pasados. El entrenamiento típicamente diverge — las actualizaciones de parámetros crecen cada vez más a medida que los gradientes obsoletos se acumulan, y la pérdida explota u oscila descontroladamente.

💡 Un buen modelo mental: .grad es como una pizarra. PyTorch sigue escribiendo nuevos valores de gradiente encima de lo que ya está ahí. Si no borras la pizarra (zero_grad) entre pasos, terminas con un revoltijo de información vieja y nueva.

Simulemos ambos escenarios — el error (olvidar hacer zero) y el uso intencional (acumulación de gradientes para tamaños de lote efectivos más grandes):

import numpy as np

# Simulate gradient accumulation
grad_w = 0.0  # starts at zero

# Step 1
loss_grad_1 = -18.0
grad_w += loss_grad_1  # accumulate
print(f"After step 1: grad_w = {grad_w}")

# Step 2 (forgot to zero!)
loss_grad_2 = -12.0
grad_w += loss_grad_2  # accumulates on top!
print(f"After step 2 (no zero_grad): grad_w = {grad_w}  ← WRONG, should be {loss_grad_2}")

# Correct: zero first
grad_w = 0.0           # zero_grad()
grad_w += loss_grad_2
print(f"After step 2 (with zero_grad): grad_w = {grad_w}  ← correct")
print()

# When accumulation is INTENTIONAL (gradient accumulation for large effective batches)
effective_batch_size = 4
micro_batch_grads = [-5.0, -3.0, -7.0, -1.0]
grad_w = 0.0
for i, g in enumerate(micro_batch_grads):
    grad_w += g  # intentional accumulation
    print(f"  Micro-batch {i+1}: grad += {g}, total = {grad_w}")
grad_w /= effective_batch_size  # average
print(f"Average gradient: {grad_w}")
print(f"(Same as processing all 4 at once: {sum(micro_batch_grads)/4})")

La acumulación de gradientes es particularmente útil cuando la memoria de tu GPU no puede manejar el tamaño de lote que idealmente usarías. Por ejemplo, si el tamaño de lote óptimo es 32 pero tu GPU solo puede procesar 8 muestras a la vez, puedes ejecutar 4 pases forward-backward con micro-lotes de 8, acumulando gradientes, y luego hacer un solo paso del optimizador. El efecto matemático es idéntico a procesar las 32 muestras de una vez (asumiendo que promedias los gradientes acumulados), pero el consumo máximo de memoria es el de un lote de 8.

PyTorch moderno también ofrece optimizer.zero_grad(set_to_none=True) , que establece .grad en None en lugar de llenarlo con ceros. Esto es ligeramente más eficiente en memoria (no se asigna un tensor de ceros) y puede ser marginalmente más rápido, aunque la diferencia suele ser pequeña.

torch.no_grad() e Inferencia

Durante la inferencia (o evaluación), no necesitas gradientes. Pero por defecto, PyTorch sigue construyendo el grafo computacional para cada operación — almacenando activaciones intermedias que se necesitarían para el backward pass, registrando qué operaciones se realizaron, y consumiendo tanto memoria como cómputo en el proceso. Para un modelo con miles de millones de parámetros, estos valores intermedios almacenados pueden fácilmente duplicar o triplicar el consumo de memoria en comparación con lo que requiere solo el forward pass.

PyTorch proporciona dos mecanismos para omitir la construcción del grafo. El primero es el context manager torch.no_grad() , que deshabilita temporalmente el seguimiento de gradientes para todas las operaciones dentro de su ámbito. El segundo es @torch.inference_mode() , un decorador (o context manager) que es más estricto y ligeramente más eficiente.

# Context manager — temporarily disables gradient tracking
with torch.no_grad():
    output = model(input)   # no graph built, no memory for intermediates

# Decorator — stricter, slightly faster
@torch.inference_mode()
def predict(model, input):
    return model(input)     # no graph, no stale tensor issues

# Why this matters: a model with 1B parameters stores intermediate
# activations during forward pass for backward. Skipping that
# can cut memory usage by 2-3× during inference.

La diferencia clave entre ambos es sutil pero importante. torch.no_grad() deshabilita el cálculo de gradientes pero aún permite operaciones in-place sobre tensores que originalmente tenían requires_grad=True . Esto significa que puedes crear accidentalmente un tensor dentro de un bloque no_grad() , pasarlo fuera, e intentar luego retropropagar a través de él — lo que lleva a errores confusos o bugs silenciosos de corrección.

torch.inference_mode() es más estricto: marca todos los tensores creados dentro de su ámbito como tensores de inferencia , que no pueden usarse como entradas para operaciones que requieren seguimiento de gradientes. Si las salidas de inferencia se filtran accidentalmente al código de entrenamiento, PyTorch lanza un error inmediatamente en lugar de producir gradientes incorrectos silenciosamente. Esto hace que inference_mode() sea la opción más segura y generalmente preferida para pipelines de inferencia en producción.

💡 Regla general: usa torch.inference_mode() para inferencia y evaluación. Usa torch.no_grad() solo cuando necesites la flexibilidad de operar sobre tensores con seguimiento de gradientes sin registrar la operación (por ejemplo, al actualizar parámetros manualmente durante optimizadores personalizados).

Funciones Autograd Personalizadas

A veces necesitas operaciones para las que PyTorch no tiene derivadas incorporadas, o quieres calcular el gradiente de manera diferente al comportamiento por defecto — por ejemplo, usando gradient checkpointing para intercambiar cómputo por memoria al recalcular activaciones durante el backward pass en lugar de almacenarlas. torch.autograd.Function de PyTorch te permite definir lógica personalizada de forward y backward que se integra perfectamente en el motor de autograd.

Aquí hay una implementación personalizada de ReLU como función autograd:

class CustomReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # ctx stores values needed for backward
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Gradient of ReLU: 1 where input > 0, 0 elsewhere
        grad_input = grad_output * (input > 0).float()
        return grad_input

# Usage:
output = CustomReLU.apply(input_tensor)

Desglosemos cada parte, porque cada línea cumple un propósito específico en la maquinaria de autograd:

  • forward(ctx, input) — calcula la salida de la operación (aquí, fijando valores negativos a cero). El objeto ctx es un contexto que actúa como puente entre forward y backward.
  • ctx.save_for_backward(input) — almacena el tensor de entrada para que esté disponible durante el backward pass. PyTorch gestiona el ciclo de vida de memoria de estos tensores guardados: se mantienen vivos hasta que backward completa, y luego se liberan. Solo deberías guardar tensores (no objetos Python arbitrarios) a través de este método, ya que PyTorch necesita rastrearlos para la gestión de memoria y el cálculo de gradientes.
  • backward(ctx, grad_output) — recibe el gradiente ascendente ( grad_output , también llamado el gradiente "entrante" de las capas superiores) y devuelve el gradiente con respecto a cada entrada de forward . Esta es la regla de la cadena en acción: multiplicamos el gradiente ascendente por la derivada local.
  • (input > 0).float() — la derivada local de ReLU. Esto crea una máscara binaria que es 1.0 donde la entrada era positiva y 0.0 en el resto. Multiplicar el gradiente ascendente por esta máscara implementa la regla de la cadena: los gradientes fluyen donde ReLU estaba activo y se bloquean donde no lo estaba.

Las funciones autograd personalizadas son también el mecanismo detrás de varias técnicas importantes en el aprendizaje profundo moderno. Gradient checkpointing (Chen et al., 2016) usa una función personalizada cuyo forward pass descarta las activaciones intermedias para ahorrar memoria, y cuyo backward pass vuelve a ejecutar el cálculo forward para reconstruirlas sobre la marcha. Los estimadores straight-through (Bengio et al., 2013) usan un backward personalizado que pasa gradientes a través de operaciones no diferenciables (como redondeo o argmax) pretendiendo que la derivada es 1. Son herramientas poderosas, pero conllevan responsabilidad — una implementación incorrecta de backward producirá gradientes erróneos silenciosamente, y depurar errores de gradientes tiende a ser mucho más difícil que depurar errores del forward pass.

Quiz

Comprueba tu comprensión del sistema autograd de PyTorch.

¿Por qué PyTorch acumula gradientes por defecto en lugar de sobrescribirlos?

En la regla de la cadena ∂L/∂w = ∂L/∂y · ∂y/∂w, ¿qué representa ∂y/∂w?

¿Cuál es la diferencia clave entre torch.no_grad() y torch.inference_mode()?

En una Function de autograd personalizada, ¿qué hace ctx.save_for_backward()?