Funcional vs Orientado a Objetos

Los modelos de PyTorch son objetos. Heredas de nn.Module , almacenas pesos como atributos y llamas a self.linear(x) . El estado vive dentro del objeto — pesos, buffers, estadísticas acumuladas. Este diseño orientado a objetos resulta natural para la mayoría de los programadores Python y refleja cómo solemos pensar sobre las redes neuronales: un modelo es algo, y tiene parámetros.

JAX adopta un enfoque fundamentalmente distinto: funciones puras . Un modelo en JAX es una función que recibe los parámetros y las entradas como argumentos separados y devuelve la salida. No hay estado oculto — todo es explícito. El modelo no posee sus pesos; en cambio, los pesos son simplemente otro argumento que pasas.

Aquí están ambos paradigmas lado a lado:

# ── PyTorch: Object-Oriented ──────────────────────
import torch
import torch.nn as nn

class Linear(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
        self.bias = nn.Parameter(torch.zeros(out_dim))

    def forward(self, x):
        return x @ self.weight.T + self.bias

model = Linear(4, 3)
output = model(torch.randn(2, 4))  # state is inside model

# ── JAX: Functional ───────────────────────────────
import jax
import jax.numpy as jnp

def linear(params, x):
    return x @ params['weight'].T + params['bias']

params = {
    'weight': jax.random.normal(jax.random.PRNGKey(0), (3, 4)),
    'bias': jnp.zeros(3)
}
output = linear(params, jnp.ones((2, 4)))  # state is explicit

¿Por qué JAX eligió este camino? Las funciones puras son más fáciles de compilar (no hay estado oculto que rastrear), más fáciles de paralelizar (no hay estado mutable compartido) y más fáciles de razonar matemáticamente (composición de funciones). Cuando una función no tiene efectos secundarios y depende solo de sus entradas, el compilador puede reordenar, fusionar y distribuir operaciones libremente sin preocuparse por dependencias invisibles.

El precio es la verbosidad — pasar params a todas partes se vuelve tedioso, especialmente para redes profundas con decenas de capas. Por eso bibliotecas como Flax (Google, 2020) y Equinox (Kidger, 2021) añaden abstracciones tipo nn.Module encima de JAX. Te dan la ergonomía del código orientado a objetos mientras preservan la semántica funcional de JAX por debajo — lo mejor de ambos mundos, aunque con una capa adicional de abstracción que aprender.

jax.jit: Compilar por Defecto

La filosofía de diseño de JAX es que la compilación debe ser lo predeterminado , no un complemento opcional. El decorador @jax.jit traza la función, captura un grafo de cómputo a través de la representación intermedia HLO (High-Level Operations) de XLA, y lo compila en un kernel optimizado. La primera llamada paga el costo de compilación; las llamadas subsecuentes con las mismas formas de entrada ejecutan el código compilado directamente, con cero sobrecarga de Python.

Esto contrasta marcadamente con el enfoque histórico de PyTorch:

  • PyTorch : eager por defecto, torch.compile opcional (añadido en PyTorch 2.0, 2023)
  • JAX : compilado por defecto para código sensible al rendimiento, modo eager disponible para depuración

Así luce un paso de entrenamiento típico en JAX:

# JAX: compilation is the natural way to run code
@jax.jit
def train_step(params, x, y):
    def loss_fn(p):
        pred = model(p, x)
        return jnp.mean((pred - y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    params = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads)
    return params, loss

# First call: traces + compiles via XLA
# Subsequent calls: runs compiled code directly (no Python overhead)

Observa cómo jax.value_and_grad calcula tanto el valor de la pérdida como los gradientes en una sola llamada — otro patrón funcional. No hay equivalente a loss.backward() mutando una cinta; en su lugar, jax.grad es una transformación a nivel de código fuente que toma una función y devuelve una nueva función que calcula su gradiente. La diferenciación es simplemente otra transformación de funciones, componible con jit , vmap y pmap .

💡 jax.jit requiere que la función sea 'pura' — las mismas entradas deben producir las mismas salidas, sin efectos secundarios. Esta restricción es lo que hace posible la compilación: el compilador puede razonar sobre la función completa sin preocuparse por cambios de estado ocultos. Si tu función lee una variable global o muta un arreglo en su lugar, jax.jit capturará silenciosamente un valor obsoleto o lanzará un error. Esta rigurosidad es una característica, no un defecto — fuerza código que el compilador realmente puede optimizar.

XLA: Álgebra Lineal Acelerada

XLA (Accelerated Linear Algebra) es el compilador de Google para operaciones de álgebra lineal. Originalmente desarrollado para TensorFlow (circa 2017), se convirtió en el backend que hace a JAX rápido. Cuando jax.jit traza una función, el resultado es un programa HLO — un grafo de operaciones de alto nivel como "matmul", "add", "reduce" — que XLA luego optimiza y compila a código máquina específico del dispositivo.

Este es el pipeline de compilación de XLA:

Python function
    ↓ jax.jit traces it
HLO IR (High-Level Operations)
    ↓ XLA optimises (fusion, layout, scheduling)
Device-specific code
    ├── GPU: LLVM IR → PTX → SASS (similar to Triton)
    ├── TPU: TPU-specific machine code
    └── CPU: LLVM IR → x86/ARM assembly

Los pasos de optimización de XLA son donde ocurre la verdadera magia. La fusión de operadores combina múltiples operaciones elementales (suma, multiplicación, función de activación) en un solo lanzamiento de kernel, eliminando lecturas y escrituras intermedias de memoria. La optimización de layout reorganiza las disposiciones de memoria de los tensores para coincidir con las preferencias del hardware (por ejemplo, eligiendo entre almacenamiento por filas o por columnas). La planificación ordena las operaciones para maximizar la utilización del hardware y superponer cómputo con transferencias de memoria.

La diferencia clave con el enfoque de PyTorch se reduce a de dónde provienen los kernels:

  • PyTorch (eager) : usa kernels preconstruidos de cuBLAS/cuDNN, optimizados por NVIDIA durante muchos años
  • PyTorch (compilado) : usa TorchInductor + Triton, genera kernels en tiempo de compilación
  • JAX : usa XLA, genera kernels en tiempo de compilación vía LLVM

Ventaja de XLA : apunta a múltiples backends desde el mismo código fuente. El mismo código JAX se ejecuta en GPU, TPU y CPU sin modificación. El torch.compile de PyTorch actualmente apunta a GPU (vía Triton) y CPU (vía C++/OpenMP), pero no soporta TPUs de forma nativa. Para organizaciones con acceso a los pods de TPU de Google — que pueden ofrecer excelente relación precio-rendimiento para entrenamiento a gran escala — esta capacidad multi-backend es posiblemente la mayor ventaja de JAX.

Desventaja de XLA : requiere formas estáticas en tiempo de compilación. Cada dimensión del tensor debe ser conocida cuando se traza la función, y cualquier cambio de forma dispara una recompilación completa. Un bucle de entrenamiento donde el tamaño de lote o la longitud de secuencia varía de paso a paso puede terminar recompilando repetidamente, lo cual es costoso. El modo eager de PyTorch maneja formas dinámicas naturalmente — no hay kernel compilado que invalidar, así que un lote de 32 y un lote de 37 se ejecutan sin ninguna sobrecarga de compilación.

Compromisos: Cuándo Brilla Cada Uno

Ningún framework es universalmente mejor — hacen compromisos diferentes que se adaptan a distintos flujos de trabajo y restricciones. Entender dónde brilla cada uno te ayuda a elegir la herramienta correcta para el trabajo.

Fortalezas de PyTorch:

  • Depuración : el modo eager te permite usar print() , inspeccionar valores intermedios y poner breakpoints en cualquier parte de tu modelo. Lo que escribes es lo que se ejecuta — sin sorpresas de trazado.
  • Formas dinámicas : los tamaños de lote, longitudes de secuencia y estructuras de grafos pueden cambiar libremente entre iteraciones sin ninguna penalización de recompilación.
  • Ecosistema : el mayor repositorio de modelos (Hugging Face aloja decenas de miles de modelos PyTorch), la mayor cantidad de tutoriales y la adopción industrial más amplia. Si necesitas un checkpoint preentrenado, casi seguro existe en formato PyTorch.
  • Compilación gradual : comienza con modo eager para prototipar y depurar, luego añade torch.compile cuando estés listo para velocidad — sin necesidad de reescribir.

Fortalezas de JAX:

  • Soporte para TPU : compilación de primera clase para TPU vía XLA. Los pods de TPU de Google están entre el hardware más rentable para entrenamiento a gran escala, y JAX es la forma más natural de utilizarlos.
  • Transformaciones funcionales : jax.vmap (auto-batching), jax.pmap (auto-paralelismo) y jax.grad se componen limpiamente porque todo es una función pura. Puedes aplicar vmap al gradiente de una función jiteada — estas transformaciones se anidan naturalmente.
  • Reproducibilidad : estado PRNG explícito (sin semilla aleatoria global) hace que los experimentos sean reproducibles por construcción. Cada operación aleatoria requiere un jax.random.PRNGKey explícito, así que no hay estado global oculto que pueda diferir silenciosamente entre ejecuciones.
  • Velocidad de investigación para ciertas cargas : DeepMind, Google Brain (ahora Google DeepMind) y varios laboratorios de investigación líderes usan JAX extensivamente para experimentos a gran escala donde la composición funcional y el acceso a TPUs son ventajas críticas.
💡 La brecha entre los dos se está cerrando. PyTorch 2.0+ añadió torch.compile para compilación en modo grafo, y PyTorch/XLA proporciona un camino experimental para ejecutar PyTorch en TPUs. JAX añadió herramientas de depuración como jax.debug.print y jax.disable_jit para recorrer código de forma eager. Ambos frameworks están convergiendo hacia 'fácil de escribir, rápido de ejecutar' — simplemente empezaron desde extremos opuestos del espectro de diseño.

Lado a Lado: Las Pilas de Compilación Completas

Para cerrar, aquí hay un resumen visual que coloca ambos frameworks uno al lado del otro a través de cada capa de la pila — desde el frontend que escribes hasta el hardware que lo ejecuta:

              PyTorch                          JAX
              ───────                          ───
Frontend:     Python (nn.Module)               Python (pure functions)
Paradigm:     Object-oriented, stateful        Functional, stateless
Default:      Eager (immediate execution)      Compiled (jax.jit)

Compilation:  torch.compile (opt-in)           jax.jit (standard)
              │                                │
Graph capture: TorchDynamo                     JAX tracing
              │                                │
Optimiser:    TorchInductor                    XLA
              │                                │
Kernel gen:   Triton (GPU) / C++ (CPU)         LLVM (GPU/TPU/CPU)
              │                                │
GPU assembly: Triton → PTX → SASS             LLVM → PTX → SASS
              │                                │
Pre-built:    cuBLAS, cuDNN (eager path)       (none — always compiled)
              │                                │
Hardware:     NVIDIA GPUs                      NVIDIA GPUs + Google TPUs

Gradient:     Autograd (tape-based)            jax.grad (source transform)
State:        Inside model (self.weight)       Explicit (params dict)
Random:       Global seed (torch.manual_seed)  Explicit key (jax.random.PRNGKey)

Varias filas de esta tabla merecen un momento de reflexión. La fila de gradiente resalta una diferencia arquitectónica profunda: el autograd de PyTorch graba operaciones en una cinta durante el paso forward y reproduce esa cinta en reversa durante loss.backward() . El jax.grad de JAX es una transformación a nivel de código fuente — toma una función y devuelve una nueva función que calcula su derivada. No hay cinta que construir en tiempo de ejecución; la función gradiente se construye una vez y se compila.

La fila de aleatoriedad podría parecer un detalle menor, pero tiene consecuencias reales para la reproducibilidad. PyTorch usa un generador de números aleatorios global ( torch.manual_seed(42) ), lo que significa que la secuencia de números aleatorios depende del orden en que se ejecutan las operaciones — añade una capa de dropout, y cada llamada aleatoria subsecuente se desplaza. JAX evita esto requiriendo un PRNGKey explícito para cada operación aleatoria, haciendo que la aleatoriedad sea determinista e independiente del orden de ejecución.

¿Cuál deberías usar? Para la mayoría de los profesionales, el ecosistema más grande de PyTorch y su depuración más sencilla lo convierten en la opción predeterminada. Para investigación que empuja los límites de la compilación, el paralelismo o el entrenamiento a escala de TPU, JAX ofrece poderosas abstracciones funcionales. Entender ambos te ayuda a elegir la herramienta correcta — y a apreciar lo que cada framework hace bajo el capó.

Quiz

Pon a prueba tu comprensión de las diferencias filosóficas y arquitectónicas entre PyTorch y JAX.

¿Cuál es la diferencia fundamental de paradigma entre PyTorch y JAX?

¿Por qué jax.jit requiere funciones puras?

¿Qué ventaja tiene XLA sobre Triton para la generación de kernels?

¿Por qué el modo eager puede manejar formas dinámicas pero el modo compilado tiene dificultades?