Functional vs Object-Oriented
PyTorch models are objects. You inherit from `nn.Module`, store weights as attributes, and call `self.linear(x)`. State lives inside the object — weights, buffers, running statistics. This object-oriented design feels natural to most Python programmers and mirrors how we tend to think about neural networks: a model *is* something, and it *has* parameters.
JAX takes a fundamentally different approach: *pure functions*. A JAX model is a function that takes parameters and inputs as separate arguments and returns output. There's no hidden state — everything is explicit. The model doesn't own its weights; instead, weights are just another argument you pass in.
Here are both paradigms side by side:
# ── PyTorch: Object-Oriented ──────────────────────
import torch
import torch.nn as nn

class Linear(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
        self.bias = nn.Parameter(torch.zeros(out_dim))

    def forward(self, x):
        return x @ self.weight.T + self.bias

model = Linear(4, 3)
output = model(torch.randn(2, 4))  # state is inside model
# ── JAX: Functional ───────────────────────────────
import jax
import jax.numpy as jnp

def linear(params, x):
    return x @ params['weight'].T + params['bias']

params = {
    'weight': jax.random.normal(jax.random.PRNGKey(0), (3, 4)),
    'bias': jnp.zeros(3),
}
output = linear(params, jnp.ones((2, 4)))  # state is explicit
Why did JAX choose this? Pure functions are easier to compile (no hidden state to track), easier to parallelise (no shared mutable state), and easier to reason about mathematically (function composition). When a function has no side effects and depends only on its inputs, the compiler can freely reorder, fuse, and distribute operations without worrying about invisible dependencies.
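One concrete way to see what purity buys (and costs): side effects in a jitted function's Python body run only while JAX traces it, not on later calls. The counter dict `calls` below is our own illustrative name, not a JAX API.

```python
import jax
import jax.numpy as jnp

calls = {"python": 0}

@jax.jit
def impure_double(x):
    calls["python"] += 1      # side effect: runs only during tracing
    return 2 * x

impure_double(jnp.ones(3))    # first call: JAX traces, then compiles
impure_double(jnp.ones(3))    # same shape/dtype: cached kernel, body skipped
print(calls["python"])        # 1, even though we called the function twice
```

This is exactly why `jax.jit` demands pure functions: anything outside the returned arrays is invisible to the compiled code.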
The tradeoff is verbosity — passing `params` everywhere gets tedious, especially for deep networks with dozens of layers. This is why libraries like Flax (Google, 2020) and Equinox (Kidger, 2021) add `nn.Module`-like abstractions back on top of JAX. They give you the ergonomics of object-oriented code while preserving JAX's functional semantics under the hood — the best of both worlds, though with an additional layer of abstraction to learn.
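To illustrate the pattern these libraries use — a hand-rolled toy, not Flax's or Equinox's actual API — keep an object for ergonomics, but keep `apply` a pure function of `(params, x)` so all of JAX's transforms still work. The class name `LinearModule` is our own.

```python
import jax
import jax.numpy as jnp

class LinearModule:
    """Toy sketch of the Flax/Equinox idea: OO surface, pure core."""
    def __init__(self, in_dim, out_dim):
        self.in_dim, self.out_dim = in_dim, out_dim

    def init(self, key):
        # parameters live outside the object, in a plain dict
        return {
            'weight': jax.random.normal(key, (self.out_dim, self.in_dim)),
            'bias': jnp.zeros(self.out_dim),
        }

    def apply(self, params, x):
        # pure function of (params, x): safe to jit, grad, and vmap
        return x @ params['weight'].T + params['bias']

mod = LinearModule(4, 3)
p = mod.init(jax.random.PRNGKey(0))
out = jax.jit(mod.apply)(p, jnp.ones((2, 4)))
print(out.shape)   # (2, 3)
```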
jax.jit: Compile by Default
JAX's design philosophy is that compilation should be the *default*, not an opt-in afterthought. The `@jax.jit` decorator traces the function, captures a computation graph via XLA's HLO (High-Level Operations) intermediate representation, and compiles it into an optimised kernel. The first call pays the compilation cost; subsequent calls with the same input shapes run the compiled code directly, with almost no Python overhead.
This stands in sharp contrast to PyTorch's historical approach:
- **PyTorch**: eager by default, `torch.compile` opt-in (added in PyTorch 2.0, 2023)
- **JAX**: compiled by default for performance-sensitive code, eager mode available for debugging
Here is what a typical JAX training step looks like:
# JAX: compilation is the natural way to run code
@jax.jit
def train_step(params, x, y):
    def loss_fn(p):
        pred = linear(p, x)   # the pure linear function defined above
        return jnp.mean((pred - y) ** 2)
    loss, grads = jax.value_and_grad(loss_fn)(params)
    params = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads)
    return params, loss

# First call: traces + compiles via XLA
# Subsequent calls: runs compiled code directly (no Python overhead)
Notice how `jax.value_and_grad` computes both the loss value and the gradients in a single call — another functional pattern. There's no equivalent of `loss.backward()` mutating a tape; instead, `jax.grad` is a *source-level transformation* that takes a function and returns a new function computing its gradient. Differentiation is just another function transform, composable with `jit`, `vmap`, and `pmap`.
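A small sketch of that composability, with illustrative names (`loss`, `batched_grad`): wrapping `jax.grad` in `jax.vmap` and `jax.jit` yields a compiled, batched gradient function without changing the original code at all.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.sum((x @ w) ** 2)

# grad, vmap, and jit nest freely: each maps a function to a function
batched_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(0, None)))

ws = jnp.ones((5, 3))    # five independent weight vectors
x = jnp.eye(3)           # one shared input
grads = batched_grad(ws, x)
print(grads.shape)       # (5, 3): one gradient per weight vector
```

With `x` as the identity, the loss reduces to `sum(w ** 2)`, so each gradient entry is `2 * w`, i.e. all twos here.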
XLA: Accelerated Linear Algebra
XLA (Accelerated Linear Algebra) is Google's compiler for linear algebra operations. Originally developed for TensorFlow (circa 2017), it became the backend that makes JAX fast. When `jax.jit` traces a function, the result is an HLO program — a graph of high-level operations like "matmul", "add", "reduce" — which XLA then optimises and compiles down to device-specific machine code.
Here is XLA's compilation pipeline:
Python function
↓ jax.jit traces it
HLO IR (High-Level Operations)
↓ XLA optimises (fusion, layout, scheduling)
Device-specific code
├── GPU: LLVM IR → PTX → SASS (similar to Triton)
├── TPU: TPU-specific machine code
└── CPU: LLVM IR → x86/ARM assembly
XLA's optimisation passes are where the real magic happens. Operator fusion merges multiple elementwise operations (add, multiply, activation function) into a single kernel launch, eliminating intermediate memory reads and writes. Layout optimisation rearranges tensor memory layouts to match hardware preferences (e.g., choosing between row-major and column-major storage). Scheduling orders operations to maximise hardware utilisation and overlap computation with memory transfers.
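You can inspect the IR yourself: recent JAX versions expose ahead-of-time lowering via `jax.jit(f).lower(...)`, whose `as_text()` returns the HLO/StableHLO program handed to XLA. The function `fused` below is our own example; three elementwise operations that XLA can fuse into one kernel.

```python
import jax
import jax.numpy as jnp

def fused(x):
    # multiply, add, tanh: candidates for a single fused kernel
    return jnp.tanh(x * 2.0 + 1.0)

lowered = jax.jit(fused).lower(jnp.ones((4,)))
text = lowered.as_text()          # the IR JAX hands to XLA
print(text.splitlines()[0])       # first line of the lowered program
```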
The key difference from PyTorch's approach comes down to where kernels come from:
- **PyTorch (eager)**: uses pre-built cuBLAS/cuDNN kernels, optimised by NVIDIA over many years
- **PyTorch (compiled)**: uses TorchInductor + Triton, generates kernels at compile time
- **JAX**: uses XLA, generates kernels at compile time via LLVM
**XLA's advantage**: it targets multiple backends from the same source. The same JAX code runs on GPU, TPU, and CPU without modification. PyTorch's `torch.compile` currently targets GPU (via Triton) and CPU (via C++/OpenMP), but does not natively support TPUs. For organisations with access to Google's TPU pods — which can offer excellent price-performance for large-scale training — this multi-backend capability is arguably JAX's strongest selling point.
**XLA's disadvantage**: it requires static shapes at compile time. Every tensor dimension must be known when the function is traced, and any change in shape triggers a full recompilation. A training loop where the batch size or sequence length varies from step to step can end up recompiling repeatedly, which is costly. PyTorch's eager mode handles dynamic shapes naturally — there's no compiled kernel to invalidate, so a batch of 32 and a batch of 37 both run without any compilation overhead.
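The shape sensitivity is easy to observe with an illustrative trace counter (`traces` is our own dict, not a JAX API): the Python body re-runs once per distinct input shape, because each shape gets its own compiled program.

```python
import jax
import jax.numpy as jnp

traces = {"n": 0}

@jax.jit
def mean_sq(x):
    traces["n"] += 1          # runs once per (shape, dtype) compilation
    return jnp.mean(x ** 2)

mean_sq(jnp.ones((32, 4)))    # compiles for shape (32, 4)
mean_sq(jnp.ones((32, 4)))    # cache hit: no retrace
mean_sq(jnp.ones((37, 4)))    # new shape: full retrace + recompile
print(traces["n"])            # 2
```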
Tradeoffs: When Each Shines
Neither framework is universally better — they make different tradeoffs that suit different workflows and constraints. Understanding where each one shines helps you pick the right tool for the job.
PyTorch strengths:
- **Debugging**: eager mode lets you `print()`, inspect intermediate values, and set breakpoints anywhere in your model. What you write is what runs — no tracing surprises.
- **Dynamic shapes**: batch sizes, sequence lengths, and graph structures can change freely between iterations without any recompilation penalty.
- **Ecosystem**: the largest model zoo (Hugging Face hosts tens of thousands of PyTorch models), the most tutorials, and the widest industry adoption. If you need a pre-trained checkpoint, it almost certainly exists in PyTorch format.
- **Gradual compilation**: start with eager mode for prototyping and debugging, then add `torch.compile` when you're ready for speed — no rewrite required.
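The debugging point is worth seeing concretely: in eager mode, arbitrary Python inside `forward` runs on every call, so prints and breakpoints just work. The `Debuggable` module below is a made-up example.

```python
import torch
import torch.nn as nn

class Debuggable(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        h = self.linear(x)
        print("intermediate shape:", h.shape)  # prints on every call: no tracing
        return torch.relu(h)

model = Debuggable()
out = model(torch.randn(3, 4))                 # runs immediately, eagerly
```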
JAX strengths:
- **TPU support**: first-class TPU compilation via XLA. Google's TPU pods are among the most cost-effective hardware for large-scale training, and JAX is the most natural way to target them.
- **Functional transforms**: `jax.vmap` (auto-batching), `jax.pmap` (auto-parallelism), and `jax.grad` compose cleanly because everything is a pure function. You can vmap a gradient of a jitted function — these transforms nest naturally.
- **Reproducibility**: explicit PRNG state (no global random seed) makes experiments reproducible by construction. Every random operation requires an explicit `jax.random.PRNGKey`, so there's no hidden global state that can silently differ between runs.
- **Research velocity for certain workloads**: DeepMind, Google Brain (now Google DeepMind), and several leading research labs use JAX extensively for large-scale experiments where functional composition and TPU access are critical advantages.
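A short sketch of the explicit-PRNG discipline: split keys to derive fresh randomness instead of reusing one, and the same key always reproduces the same numbers, independent of call order.

```python
import jax

key = jax.random.PRNGKey(42)
key, k1, k2 = jax.random.split(key, 3)   # derive subkeys; never reuse a key

a = jax.random.normal(k1, (2,))
b = jax.random.normal(k2, (2,))          # independent of a

# the same key yields the same numbers, no matter what ran in between
a_again = jax.random.normal(k1, (2,))
print(bool((a == a_again).all()))        # True
```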
Side-by-Side: The Full Compilation Stacks
To wrap up, here is a visual summary that places both frameworks next to each other across every layer of the stack — from the frontend you write to the hardware that runs it:
                 PyTorch                          JAX
                 ───────                          ───
Frontend:        Python (nn.Module)               Python (pure functions)
Paradigm:        Object-oriented, stateful        Functional, stateless
Default:         Eager (immediate execution)      Compiled (jax.jit)
Compilation:     torch.compile (opt-in)           jax.jit (standard)
                       │                                │
Graph capture:   TorchDynamo                      JAX tracing
                       │                                │
Optimiser:       TorchInductor                    XLA
                       │                                │
Kernel gen:      Triton (GPU) / C++ (CPU)         LLVM (GPU/TPU/CPU)
                       │                                │
GPU assembly:    Triton → PTX → SASS              LLVM → PTX → SASS
                       │                                │
Pre-built:       cuBLAS, cuDNN (eager path)       (none — always compiled)
                       │                                │
Hardware:        NVIDIA GPUs                      NVIDIA GPUs + Google TPUs
Gradient:        Autograd (tape-based)            jax.grad (source transform)
State:           Inside model (self.weight)       Explicit (params dict)
Random:          Global seed (torch.manual_seed)  Explicit key (jax.random.PRNGKey)
Several rows in this table deserve a moment of reflection. The *gradient* row highlights a deep architectural difference: PyTorch's autograd records operations onto a tape during the forward pass and replays that tape in reverse during `loss.backward()`. JAX's `jax.grad` is a *source-level transformation* — it takes a function and returns a new function that computes its derivative. There's no tape to build at runtime; the gradient function is constructed once and compiled.
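In code, the contrast is stark: `jax.grad` maps a function to a function, with no tape in sight. A minimal illustrative example:

```python
import jax

def f(x):
    return x ** 3

df = jax.grad(f)      # a new function computing f'(x) = 3x^2, built once
print(float(df(2.0))) # 12.0
```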
The *random* row might seem like a minor detail, but it has real consequences for reproducibility. PyTorch uses a global random number generator (`torch.manual_seed(42)`), which means the sequence of random numbers depends on the *order* in which operations execute — add a dropout layer, and every subsequent random call shifts. JAX sidesteps this by requiring an explicit `PRNGKey` for every random operation, making randomness deterministic and independent of execution order.
Which one should you use? For most practitioners, PyTorch's larger ecosystem and easier debugging make it the default choice. For research pushing the boundaries of compilation, parallelism, or TPU-scale training, JAX offers powerful functional abstractions. Understanding both helps you pick the right tool — and appreciate what each framework does under the hood.
Quiz
Test your understanding of the philosophical and architectural differences between PyTorch and JAX.
What is the fundamental paradigm difference between PyTorch and JAX?
Why does jax.jit require pure functions?
What advantage does XLA have over Triton for kernel generation?
Why can eager mode handle dynamic shapes but compiled mode struggles?