Functional vs Object-Oriented

PyTorch models are objects. You inherit from nn.Module, store weights as attributes, and call self.linear(x). State lives inside the object — weights, buffers, running statistics. This object-oriented design feels natural to most Python programmers and mirrors how we tend to think about neural networks: a model is something, and it has parameters.

JAX takes a fundamentally different approach: pure functions. A JAX model is a function that takes parameters and inputs as separate arguments and returns an output. There's no hidden state — everything is explicit. The model doesn't own its weights; instead, weights are just another argument you pass in.

Here are both paradigms side by side:

# ── PyTorch: Object-Oriented ──────────────────────
import torch
import torch.nn as nn

class Linear(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
        self.bias = nn.Parameter(torch.zeros(out_dim))

    def forward(self, x):
        return x @ self.weight.T + self.bias

model = Linear(4, 3)
output = model(torch.randn(2, 4))  # state is inside model

# ── JAX: Functional ───────────────────────────────
import jax
import jax.numpy as jnp

def linear(params, x):
    return x @ params['weight'].T + params['bias']

params = {
    'weight': jax.random.normal(jax.random.PRNGKey(0), (3, 4)),
    'bias': jnp.zeros(3)
}
output = linear(params, jnp.ones((2, 4)))  # state is explicit

Why did JAX choose this? Pure functions are easier to compile (no hidden state to track), easier to parallelise (no shared mutable state), and easier to reason about mathematically (function composition). When a function has no side effects and depends only on its inputs, the compiler can freely reorder, fuse, and distribute operations without worrying about invisible dependencies.

The tradeoff is verbosity — passing params everywhere gets tedious, especially for deep networks with dozens of layers. This is why libraries like Flax (Google, 2020) and Equinox (Kidger, 2021) add nn.Module-like abstractions back on top of JAX. They give you the ergonomics of object-oriented code while preserving JAX's functional semantics under the hood — the best of both worlds, though with an additional layer of abstraction to learn.
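As an illustration of what such libraries automate (this is the underlying pattern, not Flax's or Equinox's actual API), here is a hand-rolled init/apply pair for a small MLP: an `init` function builds the params pytree, and a pure `apply` function runs the forward pass.

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes):
    """Build a params pytree for a small MLP (illustrative, not a library API)."""
    params = []
    for in_dim, out_dim in zip(sizes[:-1], sizes[1:]):
        key, wkey = jax.random.split(key)
        params.append({
            'weight': jax.random.normal(wkey, (out_dim, in_dim)) * 0.01,
            'bias': jnp.zeros(out_dim),
        })
    return params

def apply_mlp(params, x):
    """Pure forward pass: params in, activations out, no hidden state."""
    for layer in params[:-1]:
        x = jax.nn.relu(x @ layer['weight'].T + layer['bias'])
    last = params[-1]
    return x @ last['weight'].T + last['bias']

params = init_mlp(jax.random.PRNGKey(0), [4, 16, 3])
out = apply_mlp(params, jnp.ones((2, 4)))  # shape (2, 3)
```

Flax and Equinox automate exactly this bookkeeping: their object-oriented surface builds and threads the params pytree for you, while the functions JAX ultimately compiles stay pure.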

jax.jit: Compile by Default

JAX's design philosophy is that compilation should be the default, not an opt-in afterthought. The @jax.jit decorator traces the function, captures a computation graph via XLA's HLO (High-Level Operations) intermediate representation, and compiles it into an optimised kernel. The first call pays the compilation cost; subsequent calls with the same input shapes and dtypes run the compiled code directly, with minimal Python overhead.

This stands in sharp contrast to PyTorch's historical approach:

  • PyTorch: eager by default, torch.compile opt-in (added in PyTorch 2.0, 2023)
  • JAX: compiled by default for performance-sensitive code, eager mode available for debugging

Here is what a typical JAX training step looks like:

# JAX: compilation is the natural way to run code
@jax.jit
def train_step(params, x, y):
    def loss_fn(p):
        pred = model(p, x)  # `model` is any pure forward function, e.g. `linear` above
        return jnp.mean((pred - y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    params = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads)
    return params, loss

# First call: traces + compiles via XLA
# Subsequent calls: runs compiled code directly (minimal Python overhead)

Notice how jax.value_and_grad computes both the loss value and the gradients in a single call — another functional pattern. There's no equivalent of loss.backward() mutating a tape; instead, jax.grad is a function transformation: it takes a function, traces it, and returns a new function that computes its gradient. Differentiation is just another function transform, composable with jit, vmap, and pmap.
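That composability is concrete. As a sketch, per-example gradients fall out of nesting jax.vmap around jax.grad, with jax.jit on the outside:

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Scalar loss for a single example
    return jnp.sum((x * w) ** 2)

# grad -> gradient w.r.t. w; vmap -> batch over x; jit -> compile the result
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

w = jnp.array([1.0, 2.0])
xs = jnp.ones((5, 2))          # batch of 5 examples
g = per_example_grads(w, xs)   # shape (5, 2): one gradient per example
```

Here the gradient is 2 * x**2 * w per example, so with x of all ones each row is [2.0, 4.0]. Getting per-example gradients in tape-based autograd is considerably more awkward.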

💡 jax.jit requires that the function be 'pure' — same inputs must produce same outputs, no side effects. This constraint is what makes compilation possible: the compiler can reason about the entire function without worrying about hidden state changes. If your function reads a global variable or mutates an array in place, jax.jit will either silently capture a stale value or raise an error. This strictness is a feature, not a bug — it forces code that the compiler can actually optimise.
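A minimal sketch of the stale-capture pitfall: a jitted function that reads a Python global bakes its value in at trace time.

```python
import jax
import jax.numpy as jnp

scale = 2.0

@jax.jit
def impure(x):
    return x * scale  # `scale` is read from Python during tracing, then frozen

x = jnp.ones(3)
first = impure(x)    # traces with scale == 2.0

scale = 100.0        # mutate the global after tracing
second = impure(x)   # same shape -> cached compiled code, still uses 2.0
```

The fix is to pass scale as an explicit argument, which makes the dependency visible to the tracer and to readers alike.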

XLA: Accelerated Linear Algebra

XLA (Accelerated Linear Algebra) is Google's compiler for linear algebra operations. Originally developed for TensorFlow (circa 2017), it became the backend that makes JAX fast. When jax.jit traces a function, the result is an HLO program — a graph of high-level operations like "matmul", "add", "reduce" — which XLA then optimises and compiles down to device-specific machine code.

Here is XLA's compilation pipeline:

Python function
    ↓ jax.jit traces it
HLO IR (High-Level Operations)
    ↓ XLA optimises (fusion, layout, scheduling)
Device-specific code
    ├── GPU: LLVM IR → PTX → SASS (similar to Triton)
    ├── TPU: TPU-specific machine code
    └── CPU: LLVM IR → x86/ARM assembly
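You can peek at the top of this pipeline yourself: jax.make_jaxpr shows the graph of primitive operations that tracing produces before it is lowered to HLO.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x * 2.0 + 1.0)

# The jaxpr is the traced program that gets lowered to HLO for XLA
jaxpr = jax.make_jaxpr(f)(jnp.ones(4))
print(jaxpr)  # shows mul, add, and tanh primitives over a float32[4] input
```

For the next stage down, recent JAX versions expose the lowered program via jax.jit(f).lower(jnp.ones(4)).as_text().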

XLA's optimisation passes are where the real magic happens. Operator fusion merges multiple elementwise operations (add, multiply, activation function) into a single kernel launch, eliminating intermediate memory reads and writes. Layout optimisation rearranges tensor memory layouts to match hardware preferences (e.g., choosing between row-major and column-major storage). Scheduling orders operations to maximise hardware utilisation and overlap computation with memory transfers.

The key difference from PyTorch's approach comes down to where kernels come from:

  • PyTorch (eager): uses pre-built cuBLAS/cuDNN kernels, optimised by NVIDIA over many years
  • PyTorch (compiled): uses TorchInductor + Triton, generates kernels at compile time
  • JAX: uses XLA, generates kernels at compile time via LLVM

XLA's advantage: it targets multiple backends from the same source. The same JAX code runs on GPU, TPU, and CPU without modification. PyTorch's torch.compile currently targets GPU (via Triton) and CPU (via C++/OpenMP), but does not natively support TPUs. For organisations with access to Google's TPU pods — which can offer excellent price-performance for large-scale training — this multi-backend capability is arguably JAX's strongest selling point.
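Backend portability is visible at runtime. A small sketch: the same array code runs on whichever backend JAX discovers, and you can query which one was picked.

```python
import jax
import jax.numpy as jnp

# JAX selects a backend automatically, preferring tpu > gpu > cpu
backend = jax.default_backend()
devices = jax.devices()

# This identical line compiles for whichever device is present
y = jnp.ones((2, 2)) @ jnp.ones((2, 2))
print(backend, len(devices))
```

No device placement calls, no .to("cuda"), no per-backend branches in user code.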

XLA's disadvantage: it requires static shapes at compile time. Every tensor dimension must be known when the function is traced, and any change in shape triggers a full recompilation. A training loop where the batch size or sequence length varies from step to step can end up recompiling repeatedly, which is costly. PyTorch's eager mode handles dynamic shapes naturally — there's no compiled kernel to invalidate, so a batch of 32 and a batch of 37 both run without any compilation overhead.
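One way to observe this retracing, as a sketch: a Python-level side effect inside a jitted function runs only while the function is being traced, so counting it counts compilations.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def f(x):
    global trace_count
    trace_count += 1          # Python side effect: fires only during tracing
    return (x ** 2).sum()

f(jnp.ones(32))
f(jnp.ones(32))               # same shape: cached, no new trace
f(jnp.ones(37))               # new shape: fresh trace + compile
print(trace_count)            # 2
```

In a real training loop the same effect shows up as a mysterious slow step whenever an input shape changes, which is why JAX users pad batches to fixed sizes.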

Tradeoffs: When Each Shines

Neither framework is universally better — they make different tradeoffs that suit different workflows and constraints. Understanding where each one shines helps you pick the right tool for the job.

PyTorch strengths:

  • Debugging: eager mode lets you print(), inspect intermediate values, and set breakpoints anywhere in your model. What you write is what runs — no tracing surprises.
  • Dynamic shapes: batch sizes, sequence lengths, and graph structures can change freely between iterations without any recompilation penalty.
  • Ecosystem: the largest model zoo (Hugging Face hosts tens of thousands of PyTorch models), the most tutorials, and the widest industry adoption. If you need a pre-trained checkpoint, it almost certainly exists in PyTorch format.
  • Gradual compilation: start with eager mode for prototyping and debugging, then add torch.compile when you're ready for speed — no rewrite required.

JAX strengths:

  • TPU support: first-class TPU compilation via XLA. Google's TPU pods are among the most cost-effective hardware for large-scale training, and JAX is the most natural way to target them.
  • Functional transforms: jax.vmap (auto-batching), jax.pmap (auto-parallelism), and jax.grad compose cleanly because everything is a pure function. You can vmap a gradient of a jitted function — these transforms nest naturally.
  • Reproducibility: explicit PRNG state (no global random seed) makes experiments reproducible by construction. Every random operation requires an explicit jax.random.PRNGKey, so there's no hidden global state that can silently differ between runs.
  • Research velocity for certain workloads: Google DeepMind (formed from DeepMind and Google Brain) and several other leading research labs use JAX extensively for large-scale experiments where functional composition and TPU access are critical advantages.

💡 The gap between the two is narrowing. PyTorch 2.0+ added torch.compile for graph-mode compilation, and PyTorch/XLA provides an experimental path to run PyTorch on TPUs. JAX added debugging tools like jax.debug.print and jax.disable_jit for stepping through code eagerly. Both frameworks are converging toward 'easy to write, fast to run' — they just started from different ends of the design spectrum.
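Those JAX debugging escape hatches are easy to demo. A small sketch, assuming a recent JAX version with jax.debug.print and jax.disable_jit:

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    y = x * 2.0
    # jax.debug.print works inside jitted code, unlike a plain print,
    # which would only fire once during tracing
    jax.debug.print("y = {}", y)
    return y + 1.0

step(jnp.ones(2))

# For breakpoint-style debugging, run the same code eagerly:
with jax.disable_jit():
    result = step(jnp.ones(2))  # ordinary Python execution, steppable
```

disable_jit turns every transform into plain eager execution, which restores the print-and-breakpoint workflow PyTorch users expect, at the cost of speed.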

Side-by-Side: The Full Compilation Stacks

To wrap up, here is a visual summary that places both frameworks next to each other across every layer of the stack — from the frontend you write to the hardware that runs it:

              PyTorch                          JAX
              ───────                          ───
Frontend:     Python (nn.Module)               Python (pure functions)
Paradigm:     Object-oriented, stateful        Functional, stateless
Default:      Eager (immediate execution)      Compiled (jax.jit)

Compilation:  torch.compile (opt-in)           jax.jit (standard)
              │                                │
Graph capture: TorchDynamo                     JAX tracing
              │                                │
Optimiser:    TorchInductor                    XLA
              │                                │
Kernel gen:   Triton (GPU) / C++ (CPU)         LLVM (GPU/CPU) / TPU codegen
              │                                │
GPU assembly: Triton → PTX → SASS             LLVM → PTX → SASS
              │                                │
Pre-built:    cuBLAS, cuDNN (eager path)       (none — always compiled)
              │                                │
Hardware:     NVIDIA GPUs                      NVIDIA GPUs + Google TPUs

Gradient:     Autograd (tape-based)            jax.grad (function transform)
State:        Inside model (self.weight)       Explicit (params dict)
Random:       Global seed (torch.manual_seed)  Explicit key (jax.random.PRNGKey)

Several rows in this table deserve a moment of reflection. The gradient row highlights a deep architectural difference: PyTorch's autograd records operations onto a tape during the forward pass and replays that tape in reverse during loss.backward(). JAX's jax.grad is a function transformation — it takes a function and returns a new function that computes its derivative. There's no tape to build on every call; under jit, the gradient function is traced once and compiled.

The random row might seem like a minor detail, but it has real consequences for reproducibility. PyTorch uses a global random number generator (torch.manual_seed(42)), which means the sequence of random numbers depends on the order in which operations execute — add a dropout layer, and every subsequent random call shifts. JAX sidesteps this by requiring an explicit PRNGKey for every random operation, making randomness deterministic and independent of execution order.
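In practice the explicit-key discipline means splitting keys rather than reusing them, and the same key always reproduces the same draw:

```python
import jax

key = jax.random.PRNGKey(42)
# Split the root key into independent subkeys for each consumer
key, dropout_key, init_key = jax.random.split(key, 3)

# Same key -> same numbers, regardless of what ran before
a = jax.random.normal(init_key, (3,))
b = jax.random.normal(init_key, (3,))  # identical to a, by construction
```

Adding a new random operation just means splitting off one more subkey; none of the existing draws shift, which is exactly the property the global-seed model lacks.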

Which one should you use? For most practitioners, PyTorch's larger ecosystem and easier debugging make it the default choice. For research pushing the boundaries of compilation, parallelism, or TPU-scale training, JAX offers powerful functional abstractions. Understanding both helps you pick the right tool — and appreciate what each framework does under the hood.

Quiz

Test your understanding of the philosophical and architectural differences between PyTorch and JAX.

What is the fundamental paradigm difference between PyTorch and JAX?

Why does jax.jit require pure functions?

What advantage does XLA have over Triton for kernel generation?

Why can eager mode handle dynamic shapes but compiled mode struggles?