Funcional vs Orientado a Objetos
Los modelos de PyTorch son objetos. Heredas de
nn.Module
, almacenas pesos como atributos y llamas a
self.linear(x)
. El estado vive dentro del objeto — pesos, buffers, estadísticas acumuladas. Este diseño orientado a objetos resulta natural para la mayoría de los programadores Python y refleja cómo solemos pensar sobre las redes neuronales: un modelo
es
algo, y
tiene
parámetros.
JAX adopta un enfoque fundamentalmente distinto: funciones puras . Un modelo en JAX es una función que recibe los parámetros y las entradas como argumentos separados y devuelve la salida. No hay estado oculto — todo es explícito. El modelo no posee sus pesos; en cambio, los pesos son simplemente otro argumento que pasas.
Aquí están ambos paradigmas lado a lado:
# ── PyTorch: Object-Oriented ──────────────────────
import torch
import torch.nn as nn
class Linear(nn.Module):
def __init__(self, in_dim, out_dim):
super().__init__()
self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
self.bias = nn.Parameter(torch.zeros(out_dim))
def forward(self, x):
return x @ self.weight.T + self.bias
model = Linear(4, 3)
output = model(torch.randn(2, 4)) # state is inside model
# ── JAX: Functional ───────────────────────────────
import jax
import jax.numpy as jnp
def linear(params, x):
return x @ params['weight'].T + params['bias']
params = {
'weight': jax.random.normal(jax.random.PRNGKey(0), (3, 4)),
'bias': jnp.zeros(3)
}
output = linear(params, jnp.ones((2, 4))) # state is explicit
¿Por qué JAX eligió este camino? Las funciones puras son más fáciles de compilar (no hay estado oculto que rastrear), más fáciles de paralelizar (no hay estado mutable compartido) y más fáciles de razonar matemáticamente (composición de funciones). Cuando una función no tiene efectos secundarios y depende solo de sus entradas, el compilador puede reordenar, fusionar y distribuir operaciones libremente sin preocuparse por dependencias invisibles.
El precio es la verbosidad — pasar
params
a todas partes se vuelve tedioso, especialmente para redes profundas con decenas de capas. Por eso bibliotecas como
Flax
(Google, 2020) y
Equinox
(Kidger, 2021) añaden abstracciones tipo
nn.Module
encima de JAX. Te dan la ergonomía del código orientado a objetos mientras preservan la semántica funcional de JAX por debajo — lo mejor de ambos mundos, aunque con una capa adicional de abstracción que aprender.
jax.jit: Compilar por Defecto
La filosofía de diseño de JAX es que la compilación debe ser lo
predeterminado
, no un complemento opcional. El decorador
@jax.jit
traza la función, captura un grafo de cómputo a través de la representación intermedia HLO (High-Level Operations) de XLA, y lo compila en un kernel optimizado. La primera llamada paga el costo de compilación; las llamadas subsecuentes con las mismas formas de entrada ejecutan el código compilado directamente, con cero sobrecarga de Python.
Esto contrasta marcadamente con el enfoque histórico de PyTorch:
-
PyTorch
: eager por defecto,
torch.compileopcional (añadido en PyTorch 2.0, 2023) - JAX : compilado por defecto para código sensible al rendimiento, modo eager disponible para depuración
Así luce un paso de entrenamiento típico en JAX:
# JAX: compilation is the natural way to run code
@jax.jit
def train_step(params, x, y):
def loss_fn(p):
pred = model(p, x)
return jnp.mean((pred - y) ** 2)
loss, grads = jax.value_and_grad(loss_fn)(params)
params = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads)
return params, loss
# First call: traces + compiles via XLA
# Subsequent calls: runs compiled code directly (no Python overhead)
Observa cómo
jax.value_and_grad
calcula tanto el valor de la pérdida como los gradientes en una sola llamada — otro patrón funcional. No hay equivalente a
loss.backward()
mutando una cinta; en su lugar,
jax.grad
es una
transformación a nivel de código fuente
que toma una función y devuelve una nueva función que calcula su gradiente. La diferenciación es simplemente otra transformación de funciones, componible con
jit
,
vmap
y
pmap
.
XLA: Álgebra Lineal Acelerada
XLA (Accelerated Linear Algebra) es el compilador de Google para operaciones de álgebra lineal. Originalmente desarrollado para TensorFlow (circa 2017), se convirtió en el backend que hace a JAX rápido. Cuando
jax.jit
traza una función, el resultado es un programa HLO — un grafo de operaciones de alto nivel como "matmul", "add", "reduce" — que XLA luego optimiza y compila a código máquina específico del dispositivo.
Este es el pipeline de compilación de XLA:
Python function
↓ jax.jit traces it
HLO IR (High-Level Operations)
↓ XLA optimises (fusion, layout, scheduling)
Device-specific code
├── GPU: LLVM IR → PTX → SASS (similar to Triton)
├── TPU: TPU-specific machine code
└── CPU: LLVM IR → x86/ARM assembly
Los pasos de optimización de XLA son donde ocurre la verdadera magia. La fusión de operadores combina múltiples operaciones elementales (suma, multiplicación, función de activación) en un solo lanzamiento de kernel, eliminando lecturas y escrituras intermedias de memoria. La optimización de layout reorganiza las disposiciones de memoria de los tensores para coincidir con las preferencias del hardware (por ejemplo, eligiendo entre almacenamiento por filas o por columnas). La planificación ordena las operaciones para maximizar la utilización del hardware y superponer cómputo con transferencias de memoria.
La diferencia clave con el enfoque de PyTorch se reduce a de dónde provienen los kernels:
- PyTorch (eager) : usa kernels preconstruidos de cuBLAS/cuDNN, optimizados por NVIDIA durante muchos años
- PyTorch (compilado) : usa TorchInductor + Triton, genera kernels en tiempo de compilación
- JAX : usa XLA, genera kernels en tiempo de compilación vía LLVM
Ventaja de XLA
: apunta a múltiples backends desde el mismo código fuente. El mismo código JAX se ejecuta en GPU, TPU y CPU sin modificación. El
torch.compile
de PyTorch actualmente apunta a GPU (vía Triton) y CPU (vía C++/OpenMP), pero no soporta TPUs de forma nativa. Para organizaciones con acceso a los pods de TPU de Google — que pueden ofrecer excelente relación precio-rendimiento para entrenamiento a gran escala — esta capacidad multi-backend es posiblemente la mayor ventaja de JAX.
Desventaja de XLA : requiere formas estáticas en tiempo de compilación. Cada dimensión del tensor debe ser conocida cuando se traza la función, y cualquier cambio de forma dispara una recompilación completa. Un bucle de entrenamiento donde el tamaño de lote o la longitud de secuencia varía de paso a paso puede terminar recompilando repetidamente, lo cual es costoso. El modo eager de PyTorch maneja formas dinámicas naturalmente — no hay kernel compilado que invalidar, así que un lote de 32 y un lote de 37 se ejecutan sin ninguna sobrecarga de compilación.
Compromisos: Cuándo Brilla Cada Uno
Ningún framework es universalmente mejor — hacen compromisos diferentes que se adaptan a distintos flujos de trabajo y restricciones. Entender dónde brilla cada uno te ayuda a elegir la herramienta correcta para el trabajo.
Fortalezas de PyTorch:
-
Depuración
: el modo eager te permite usar
print(), inspeccionar valores intermedios y poner breakpoints en cualquier parte de tu modelo. Lo que escribes es lo que se ejecuta — sin sorpresas de trazado. - Formas dinámicas : los tamaños de lote, longitudes de secuencia y estructuras de grafos pueden cambiar libremente entre iteraciones sin ninguna penalización de recompilación.
- Ecosistema : el mayor repositorio de modelos (Hugging Face aloja decenas de miles de modelos PyTorch), la mayor cantidad de tutoriales y la adopción industrial más amplia. Si necesitas un checkpoint preentrenado, casi seguro existe en formato PyTorch.
-
Compilación gradual
: comienza con modo eager para prototipar y depurar, luego añade
torch.compilecuando estés listo para velocidad — sin necesidad de reescribir.
Fortalezas de JAX:
- Soporte para TPU : compilación de primera clase para TPU vía XLA. Los pods de TPU de Google están entre el hardware más rentable para entrenamiento a gran escala, y JAX es la forma más natural de utilizarlos.
-
Transformaciones funcionales
:
jax.vmap(auto-batching),jax.pmap(auto-paralelismo) yjax.gradse componen limpiamente porque todo es una función pura. Puedes aplicar vmap al gradiente de una función jiteada — estas transformaciones se anidan naturalmente. -
Reproducibilidad
: estado PRNG explícito (sin semilla aleatoria global) hace que los experimentos sean reproducibles por construcción. Cada operación aleatoria requiere un
jax.random.PRNGKeyexplícito, así que no hay estado global oculto que pueda diferir silenciosamente entre ejecuciones. - Velocidad de investigación para ciertas cargas : DeepMind, Google Brain (ahora Google DeepMind) y varios laboratorios de investigación líderes usan JAX extensivamente para experimentos a gran escala donde la composición funcional y el acceso a TPUs son ventajas críticas.
Lado a Lado: Las Pilas de Compilación Completas
Para cerrar, aquí hay un resumen visual que coloca ambos frameworks uno al lado del otro a través de cada capa de la pila — desde el frontend que escribes hasta el hardware que lo ejecuta:
PyTorch JAX
─────── ───
Frontend: Python (nn.Module) Python (pure functions)
Paradigm: Object-oriented, stateful Functional, stateless
Default: Eager (immediate execution) Compiled (jax.jit)
Compilation: torch.compile (opt-in) jax.jit (standard)
│ │
Graph capture: TorchDynamo JAX tracing
│ │
Optimiser: TorchInductor XLA
│ │
Kernel gen: Triton (GPU) / C++ (CPU) LLVM (GPU/TPU/CPU)
│ │
GPU assembly: Triton → PTX → SASS LLVM → PTX → SASS
│ │
Pre-built: cuBLAS, cuDNN (eager path) (none — always compiled)
│ │
Hardware: NVIDIA GPUs NVIDIA GPUs + Google TPUs
Gradient: Autograd (tape-based) jax.grad (source transform)
State: Inside model (self.weight) Explicit (params dict)
Random: Global seed (torch.manual_seed) Explicit key (jax.random.PRNGKey)
Varias filas de esta tabla merecen un momento de reflexión. La fila de
gradiente
resalta una diferencia arquitectónica profunda: el autograd de PyTorch graba operaciones en una cinta durante el paso forward y reproduce esa cinta en reversa durante
loss.backward()
. El
jax.grad
de JAX es una
transformación a nivel de código fuente
— toma una función y devuelve una nueva función que calcula su derivada. No hay cinta que construir en tiempo de ejecución; la función gradiente se construye una vez y se compila.
La fila de
aleatoriedad
podría parecer un detalle menor, pero tiene consecuencias reales para la reproducibilidad. PyTorch usa un generador de números aleatorios global (
torch.manual_seed(42)
), lo que significa que la secuencia de números aleatorios depende del
orden
en que se ejecutan las operaciones — añade una capa de dropout, y cada llamada aleatoria subsecuente se desplaza. JAX evita esto requiriendo un
PRNGKey
explícito para cada operación aleatoria, haciendo que la aleatoriedad sea determinista e independiente del orden de ejecución.
¿Cuál deberías usar? Para la mayoría de los profesionales, el ecosistema más grande de PyTorch y su depuración más sencilla lo convierten en la opción predeterminada. Para investigación que empuja los límites de la compilación, el paralelismo o el entrenamiento a escala de TPU, JAX ofrece poderosas abstracciones funcionales. Entender ambos te ayuda a elegir la herramienta correcta — y a apreciar lo que cada framework hace bajo el capó.
Quiz
Pon a prueba tu comprensión de las diferencias filosóficas y arquitectónicas entre PyTorch y JAX.
¿Cuál es la diferencia fundamental de paradigma entre PyTorch y JAX?
¿Por qué jax.jit requiere funciones puras?
¿Qué ventaja tiene XLA sobre Triton para la generación de kernels?
¿Por qué el modo eager puede manejar formas dinámicas pero el modo compilado tiene dificultades?