What If We Replaced the U-Net with a Transformer?

From 2020 to 2023, every major diffusion model used a U-Net as its denoiser: Stable Diffusion 1.x, 2.x, and XL all relied on convolutional U-Nets with attention layers inserted at select resolutions. The U-Net's multi-resolution structure (downsample, process, upsample with skip connections) seemed purpose-built for denoising. But during this same period, transformers had already conquered NLP and were rapidly taking over vision too. The Vision Transformer (ViT) showed that a plain transformer, with no convolutional hierarchy at all, could match or beat CNNs on image classification. A natural question followed: can a transformer be the denoiser too?

The answer came from Peebles & Xie (2023) with the Diffusion Transformer (DiT). Their result was striking: not only can a transformer replace the U-Net, it scales better. DiT showed that transformer-based denoisers follow the same scaling laws observed in language models — more compute leads to systematically better images, with no sign of plateau. This single finding redirected the field: Stable Diffusion 3, Flux, and Sora all abandoned U-Nets for transformer-based denoisers.

DiT Architecture

DiT operates in the same latent space as latent diffusion models. The input is a noisy latent $z_t \in \mathbb{R}^{h \times w \times c}$ produced by a pretrained VAE encoder (typically $h = w = 32$ and $c = 4$ for a $256 \times 256$ image). The question is how to feed this 2D spatial grid into a transformer, which expects a 1D sequence of tokens.

Patchification. DiT borrows the same trick that ViT uses for images: split the latent into non-overlapping $p \times p$ patches, flatten each patch into a vector, and project it to the model's hidden dimension $d$ via a linear layer. With a positional embedding added to each token, we get a sequence of $T = \frac{h \cdot w}{p^2}$ tokens, each of dimension $d$:

$$T = \frac{h \cdot w}{p^2}, \quad \text{each token} \in \mathbb{R}^d$$

For a $32 \times 32$ latent with patch size $p = 2$: $T = \frac{32 \times 32}{2^2} = 256$ tokens. With $p = 4$: only 64 tokens. The patch size controls a compute-quality tradeoff — smaller patches mean longer sequences (quadratic attention cost) but finer spatial resolution. DiT-XL/2 (the flagship model, patch size 2) processes 256 tokens per layer.
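
To make this concrete, here is a minimal NumPy sketch of the patchify step (illustrative, not DiT's actual code: the function name, the random projection matrix, and the hidden dimension $d = 768$ are assumptions, and positional embeddings are omitted):

import numpy as np

def patchify(z, p, W_proj):
    """Split an (h, w, c) latent into non-overlapping p x p patches and project to dim d."""
    h, w, c = z.shape
    # (h, w, c) -> (h/p, p, w/p, p, c) -> (h/p, w/p, p, p, c) -> (T, p*p*c)
    patches = z.reshape(h // p, p, w // p, p, c).transpose(0, 2, 1, 3, 4)
    return patches.reshape(-1, p * p * c) @ W_proj   # (T, d) token sequence

rng = np.random.default_rng(0)
z = rng.standard_normal((32, 32, 4))            # noisy latent from the VAE encoder
W_proj = rng.standard_normal((2 * 2 * 4, 768))  # assumed hidden dimension d = 768
tokens = patchify(z, p=2, W_proj=W_proj)
print(tokens.shape)                             # (256, 768): the 256 tokens computed above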

💡 This patchification step is identical to how ViT processes images. The only difference is that ViT patches pixel-space images, while DiT patches VAE latents. Since latents are already spatially compressed ($32 \times 32$ instead of $256 \times 256$), even patch size 2 keeps the sequence length manageable.

The transformer blocks then process these tokens with the standard self-attention + feed-forward network structure. But the denoiser must also know two things: which timestep $t$ it is denoising at (how noisy is the input?) and which class $y$ to generate (what should the output depict?). DiT introduces a specific conditioning mechanism for this.

Adaptive Layer Norm Zero (AdaLN-Zero). A naive approach would be to add the timestep embedding directly to each token (the way sinusoidal positional encodings are added in standard transformers). DiT instead modulates the normalisation itself. The timestep $t$ and class label $y$ are first embedded and summed into a single conditioning vector $c$. A small MLP then predicts six parameters per transformer block — scale ($\gamma_1, \gamma_2$), shift ($\beta_1, \beta_2$), and gate ($\alpha_1, \alpha_2$) — one triplet for the attention sub-block and one for the FFN sub-block. The scale and shift modulate the LayerNorm output:

$$\text{AdaLN}(h, c) = \gamma(c) \odot \text{LayerNorm}(h) + \beta(c)$$

Here $\gamma(c)$ and $\beta(c)$ are vectors predicted from the conditioning signal $c$, and $\odot$ is element-wise multiplication. When $\gamma = \mathbf{1}$ and $\beta = \mathbf{0}$, this reduces to standard LayerNorm. When $\gamma$ and $\beta$ deviate from these defaults, the normalisation shifts and scales differently depending on the timestep and class — the same hidden state gets processed differently at different noise levels.

The "Zero" part. AdaLN-Zero adds an additional gating parameter $\alpha$ that scales the entire residual connection:

$$h \leftarrow h + \alpha(c) \odot \text{Block}\big(\text{AdaLN}(h, c)\big)$$

The crucial detail: $\alpha$ is initialised to zero. At the start of training, $\alpha = \mathbf{0}$ for every block, so every block computes $h \leftarrow h + \mathbf{0} \cdot \text{Block}(\cdots) = h$. The entire transformer acts as the identity function — as if it had zero layers. Training then gradually "turns on" each block by learning non-zero $\alpha$ values. Why does this help? Deep networks are notoriously hard to train from random initialisation. By starting with the identity, DiT avoids the exploding/vanishing gradient problems at initialisation and ensures stable training even with 28+ transformer layers.
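
A compact NumPy sketch of one AdaLN-Zero sub-block may help (illustrative: `block_fn` stands in for the attention or FFN sub-block, and the modulation MLP is reduced to a single zero-initialised linear map; following the DiT code, the scale is parameterised as $1 + \text{scale}$ so that zero initialisation yields plain LayerNorm and an identity residual):

import numpy as np

def layer_norm(x, eps=1e-6):
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)

def adaln_zero(h, c, W_mod, block_fn):
    """One sub-block with AdaLN-Zero: scale/shift the norm, gate the residual."""
    scale, shift, gate = np.split(c @ W_mod, 3)    # three d-dim vectors from conditioning c
    x = (1.0 + scale) * layer_norm(h) + shift      # AdaLN: gamma = 1 + scale, beta = shift
    return h + gate * block_fn(x)                  # gate = 0 at init -> block is the identity

rng = np.random.default_rng(0)
d = 8
h = rng.standard_normal((4, d))                    # 4 tokens of dimension d
c = rng.standard_normal(d)                         # conditioning vector (timestep + class)
W_mod = np.zeros((d, 3 * d))                       # zero-initialised modulation weights
block = lambda x: x @ rng.standard_normal((d, d))  # stand-in for attention / FFN
print(np.allclose(adaln_zero(h, c, W_mod, block), h))  # True: the block starts as identity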

The output head. After all transformer blocks, a final AdaLN layer and a linear projection map each token back to a vector of dimension $p^2 \cdot c$. These vectors are reshaped into $p \times p$ patches and stitched together (unpatchified) to reconstruct the full spatial grid. The output has the same shape as the input: $h \times w \times c$. This is the predicted noise $\epsilon_\theta(z_t, t, y)$ (or the predicted velocity $v_\theta$, depending on the training objective).
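
The unpatchify step is just the inverse reshape; a hypothetical `unpatchify` mirroring the patchify sketch above:

import numpy as np

def unpatchify(tokens, h, w, p, c):
    """Inverse of patchification: (T, p*p*c) tokens -> (h, w, c) spatial grid."""
    grid = tokens.reshape(h // p, w // p, p, p, c).transpose(0, 2, 1, 3, 4)
    return grid.reshape(h, w, c)

tokens = np.zeros((256, 2 * 2 * 4))                    # 256 tokens, each of dimension p^2 * c
print(unpatchify(tokens, h=32, w=32, p=2, c=4).shape)  # (32, 32, 4): same shape as z_t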

import json, js  # json serialises the table; js exposes the browser window object

# Token count and self-attention cost for different latent sizes and patch sizes.
configs = [
    ("32x32 latent, p=2", 32, 32, 2),
    ("32x32 latent, p=4", 32, 32, 4),
    ("64x64 latent, p=2", 64, 64, 2),
    ("64x64 latent, p=4", 64, 64, 4),
]

rows = []
for name, h, w, p in configs:
    tokens = (h * w) // (p * p)   # T = h*w / p^2 non-overlapping patches
    attn = tokens * tokens        # attention cost grows with T^2 pairwise interactions
    rows.append([name, f"{p}x{p}", str(tokens), f"{attn:,}"])

js.window.py_table_data = json.dumps({
    "headers": ["Config", "Patches", "Tokens", "Attn Cost"],
    "rows": rows
})

print("Smaller patches = more tokens = finer detail but quadratic attention cost.")

Why Transformers Scale Better Than U-Nets

The headline result of the DiT paper is a clean scaling law: FID (Fréchet Inception Distance, lower is better) improves log-linearly with compute, measured in GFLOPs per forward pass. The authors trained four model sizes (DiT-S, DiT-B, DiT-L, DiT-XL) at several patch sizes and showed that plotting GFLOPs against FID produces a nearly straight line on a log scale. Adding compute — whether by making the model deeper and wider or by shrinking the patch size — consistently lowered FID, with no sign of saturation across the configurations tested.

Why don't U-Nets scale as gracefully? Several architectural constraints hold them back:

  • Fixed resolution hierarchy: U-Nets downsample through a fixed set of resolutions (e.g. 32 -> 16 -> 8 -> 4), then upsample back. Each resolution stage has its own set of convolutions. Adding capacity means adding more channels or more blocks at each stage, but the multi-resolution structure itself constrains how information flows.
  • Limited attention: In practice, self-attention in U-Nets is only applied at the lowest resolutions (8x8 or 16x16) because attention at full resolution is too expensive with convolution-based feature maps. Most of the network relies on local convolutions.
  • Diminishing returns: Scaling U-Nets beyond roughly 2-3 billion parameters showed diminishing improvements. The convolutional hierarchy becomes a bottleneck — more channels don't help if information can only flow through local receptive fields at high resolutions.

Transformers sidestep all of these. Every token attends to every other token at every layer — there is no forced hierarchy and no locality constraint. Information from any spatial position can influence any other position at every layer. This global attention at every layer is what makes transformers more parameter-efficient: adding parameters (deeper or wider) reliably improves the model's ability to denoise, with no architectural ceiling.

The numbers confirmed this. DiT-XL/2 (675M parameters) achieved a state-of-the-art FID of 2.27 on class-conditional ImageNet 256x256 generation, outperforming the widely used U-Net baseline ADM (~554M params, FID 4.59) while being architecturally simpler. This wasn't a marginal improvement — it was a near-halving of FID against that baseline with a straightforward architecture swap.

import json, js  # json serialises the table; js exposes the browser window object

# Model comparison: parameters, compute per forward pass, and ImageNet 256x256 FID.
models = [
    ("DiT-S/2",  "33",   "6.1",   "68.40"),
    ("DiT-B/2",  "130",  "23.0",  "43.50"),
    ("DiT-L/2",  "458",  "80.7",  "9.60"),
    ("DiT-XL/2", "675",  "118.6", "2.27"),
    ("ADM (U-Net)", "~554", "~1120", "4.59"),
]

js.window.py_table_data = json.dumps({
    "headers": ["Model", "Params (M)", "GFLOPs", "FID"],
    "rows": [list(m) for m in models]
})

print("DiT-XL/2 achieves lower FID than ADM with ~10x fewer GFLOPs per sample.")

💡 The scaling law result was the key contribution. Individual FID numbers will be beaten by future models, but the finding that transformer denoisers follow predictable, log-linear scaling — the same pattern seen in GPT-style language models — told the field that investing in larger transformer denoisers would reliably pay off.

From Class-Conditional to Text-Conditional DiT

The original DiT was class-conditional : it generated images of ImageNet classes ("golden retriever", "volcano", "espresso") by feeding a class label into AdaLN-Zero. But real text-to-image systems need open-ended text conditioning — a user types "a cat wearing a top hat on the moon" and the model must understand and render arbitrary descriptions. How do we condition a DiT on text instead of a class label?

Two main approaches emerged:

1. Cross-attention (the U-Net approach). This is how Stable Diffusion 1/2/XL conditioned on text: encode the prompt with a text encoder (CLIP or T5), then insert cross-attention layers where image tokens attend to text tokens. The image tokens form the queries, the text tokens form the keys and values. This works well, but text information only enters the network at these periodic cross-attention layers — between them, the image tokens are processed on their own.

2. Joint attention (MMDiT). Stable Diffusion 3 (Esser et al., 2024) introduced the Multimodal Diffusion Transformer (MMDiT) , which takes a fundamentally different approach: concatenate the text tokens and image tokens into a single sequence and process them through the same self-attention layers. Both modalities attend to both modalities at every layer.

In MMDiT, each modality has its own projection weights for queries, keys, and values — $W_Q^{\text{text}}, W_K^{\text{text}}, W_V^{\text{text}}$ for text tokens and $W_Q^{\text{img}}, W_K^{\text{img}}, W_V^{\text{img}}$ for image tokens — but they share the same attention computation. After projecting, the queries, keys, and values from both modalities are concatenated, so every token (text or image) attends to every other token (text and image):

$$\text{Attention} = \text{softmax}\!\left(\frac{[Q^{\text{text}}; Q^{\text{img}}]\,[K^{\text{text}}; K^{\text{img}}]^\top}{\sqrt{d_k}}\right) [V^{\text{text}}; V^{\text{img}}]$$

where $[\cdot\,;\,\cdot]$ denotes concatenation along the sequence dimension. This means text features and image features are integrated at every single layer , not just at periodic cross-attention insertion points. The cost is a longer sequence (text + image tokens), which increases the quadratic attention cost. The benefit is tighter text-image alignment — every layer can refine the relationship between what the text says and what the image shows.
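
A minimal single-head NumPy sketch of this joint attention (illustrative: real MMDiT blocks are multi-head, include output projections and the modulation machinery described earlier, and the token counts below are just examples):

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def mmdit_joint_attention(txt, img, Wt, Wi):
    """Per-modality Q/K/V projections, one shared attention over the joint sequence."""
    Qt, Kt, Vt = (txt @ W for W in Wt)       # text tokens use the text projections
    Qi, Ki, Vi = (img @ W for W in Wi)       # image tokens use the image projections
    Q = np.concatenate([Qt, Qi])             # every token becomes a query...
    K = np.concatenate([Kt, Ki])             # ...over all text + image keys and values
    V = np.concatenate([Vt, Vi])
    out = softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V
    return out[:len(txt)], out[len(txt):]    # split back into text and image streams

rng = np.random.default_rng(0)
d = 16
txt = rng.standard_normal((77, d))           # e.g. 77 text tokens from the text encoder
img = rng.standard_normal((256, d))          # 256 image tokens (32x32 latent, p = 2)
Wt = [rng.standard_normal((d, d)) for _ in range(3)]   # separate text Q/K/V weights
Wi = [rng.standard_normal((d, d)) for _ in range(3)]   # separate image Q/K/V weights
txt_out, img_out = mmdit_joint_attention(txt, img, Wt, Wi)
print(txt_out.shape, img_out.shape)          # (77, 16) (256, 16)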

Flux (Black Forest Labs, 2024) took this even further with single-stream blocks in which text and image tokens share the same projection weights entirely (no separate $W_Q^{\text{text}}$ and $W_Q^{\text{img}}$). The two modalities are treated as one unified sequence. Flux uses a hybrid architecture: its earlier blocks are MMDiT-style double-stream blocks (separate projections per modality), and its later blocks are single-stream (shared projections), progressively merging the modalities.

💡 Why separate projection weights per modality? Text and image tokens live in different representation spaces (text from a language encoder, images from a VAE). Separate projections let each modality map into a shared attention space in its own way. Single-stream blocks assume the representations have been sufficiently aligned by earlier layers.
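
For contrast, a single-stream block in the spirit of Flux's later layers concatenates the two streams first and applies one shared set of projections (again a hedged sketch, not Flux's actual code):

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def single_stream_attention(txt, img, Wq, Wk, Wv):
    """One unified sequence, one shared set of Q/K/V weights for both modalities."""
    x = np.concatenate([txt, img])
    Q, K, V = x @ Wq, x @ Wk, x @ Wv
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V

rng = np.random.default_rng(0)
d = 16
txt, img = rng.standard_normal((77, d)), rng.standard_normal((256, d))
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))
print(single_stream_attention(txt, img, Wq, Wk, Wv).shape)   # (333, 16) joint output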

The Practical Impact

DiT did not just improve a benchmark number. It changed the trajectory of the entire image and video generation field. The progression of denoiser architectures tells the story:

  • Stable Diffusion 1.x / 2.x / XL: U-Net denoiser with cross-attention for text conditioning. The architecture that brought diffusion to the mainstream.
  • Stable Diffusion 3 / 3.5: MMDiT denoiser (transformer). Switched from U-Net to transformer-based backbone, with joint text-image attention.
  • Flux: Transformer-based, inspired by DiT. Hybrid MMDiT + single-stream architecture.
  • Sora: Spacetime DiT for video. Extends DiT to video by patchifying in both space and time, so each token is a spacetime patch, enabling generation of temporally coherent video clips.

The shift happened for several reinforcing reasons. First, better scaling : the DiT paper proved that investing in larger models pays off predictably, which is exactly the signal companies need to justify training runs costing millions of dollars. Second, simpler architecture : a plain transformer with self-attention is architecturally simpler than a U-Net with its multi-resolution convolutional hierarchy, skip connections, and heterogeneous block types. Simpler architectures are easier to debug, optimise, and parallelise. Third, infrastructure reuse : the entire GPU software stack (FlashAttention, tensor parallelism, sequence parallelism, gradient checkpointing) was built for transformer-based language models. DiT-style models can directly leverage all of this without adaptation.

One important caveat: the VAE is unchanged . DiT replaced only the denoiser — the component that takes noisy latents and predicts the noise (or velocity). The VAE that compresses pixel-space images into latents and decodes latents back into pixels remains the same convolutional autoencoder used in latent diffusion. The latent space is the interface: the VAE produces it, the transformer denoises within it, and the VAE decodes the clean result back to pixels.

The DiT paper's deeper lesson is about universality . The same architecture that scales for next-token prediction in language also scales for noise prediction in images. This convergence suggests that the transformer is not specifically good at language — it is good at learning from data, and the scaling laws hold regardless of what that data represents.

Quiz

Test your understanding of the Diffusion Transformer architecture and its impact.

In DiT, what is the purpose of initialising the gating parameter $\alpha$ to zero in AdaLN-Zero?

How does MMDiT (used in Stable Diffusion 3) differ from cross-attention conditioning?

What was the key scaling result demonstrated by the DiT paper?

When the field moved from U-Net to DiT denoisers, which component of the latent diffusion pipeline stayed the same?