¿Qué pasaría si reemplazáramos la U-Net con un transformer?

De 2020 a 2023, cada modelo de difusión importante usaba una U-Net como su eliminador de ruido: Stable Diffusion 1.x, 2.x y XL todos dependían de U-Nets convolucionales con capas de atención insertadas en resoluciones selectas. La estructura multi-resolución de la U-Net (submuestreo, procesamiento, sobremuestreo con conexiones residuales) parecía diseñada a propósito para la eliminación de ruido. Pero durante este mismo período, los transformers ya habían conquistado el NLP y estaban rápidamente tomando la visión también. El Vision Transformer ( ViT ) mostró que un transformer plano, sin ninguna jerarquía convolucional, podía igualar o superar a las CNNs en clasificación de imágenes. Una pregunta natural siguió: ¿puede un transformer ser el eliminador de ruido también?

La respuesta vino de (Peebles & Xie, 2023) con el Diffusion Transformer (DiT) . Su resultado fue impactante: no solo puede un transformer reemplazar la U-Net, sino que escala mejor. DiT mostró que los eliminadores de ruido basados en transformers siguen las mismas leyes de escalado observadas en modelos de lenguaje — más cómputo lleva a imágenes sistemáticamente mejores, sin señal de meseta. Este único hallazgo redirigió el campo: Stable Diffusion 3, Flux y Sora todos abandonaron las U-Nets por eliminadores de ruido basados en transformers.

Arquitectura DiT

DiT opera en el mismo espacio latente que los modelos de difusión latente. La entrada es un latente ruidoso $z_t \in \mathbb{R}^{h \times w \times c}$ producido por un codificador VAE preentrenado (típicamente $h = w = 32$ y $c = 4$ para una imagen de 256\times256). La pregunta es cómo alimentar esta cuadrícula espacial 2D a un transformer, que espera una secuencia 1D de tokens.

Parchificación. DiT toma prestado el mismo truco que ViT usa para imágenes: dividir el latente en parches $p \times p$ no superpuestos, aplanar cada parche en un vector, y proyectarlo a la dimensión oculta $d$ del modelo vía una capa lineal. Con un embedding posicional añadido a cada token, obtenemos una secuencia de $T = \frac{h \cdot w}{p^2}$ tokens, cada uno de dimensión $d$:

$$T = \frac{h \cdot w}{p^2}, \quad \text{each token} \in \mathbb{R}^d$$

Para un latente de $32 \times 32$ con parches $p = 2$: $T = \frac{32 \times 32}{4} = 256$ tokens. Con $p = 4$: solo 64 tokens. El tamaño de parche controla un compromiso cómputo-calidad — parches más pequeños significan secuencias más largas (costo de atención cuadrático) pero mayor resolución espacial. DiT-XL/2 (el modelo insignia, tamaño de parche 2) procesa 256 tokens por capa.

💡 Este paso de parchificación es idéntico a cómo ViT procesa imágenes. La única diferencia es que ViT parchifica imágenes en espacio de píxeles, mientras que DiT parchifica latentes de VAE. Como los latentes ya están espacialmente comprimidos ($32 \times 32$ en lugar de $256 \times 256$), incluso con tamaño de parche 2 la longitud de secuencia se mantiene manejable.

Los bloques transformer luego procesan estos tokens con la estructura estándar de auto-atención + red feed-forward . Pero el eliminador de ruido también debe saber dos cosas: en qué paso de tiempo $t$ está eliminando ruido (¿qué tan ruidosa es la entrada?) y qué clase $y$ generar (¿qué debería representar la salida?). DiT introduce un mecanismo de condicionamiento específico para esto.

Adaptive Layer Norm Zero (AdaLN-Zero). Un enfoque ingenuo sería añadir el embedding del paso de tiempo directamente a cada token (como se añaden las codificaciones posicionales sinusoidales en transformers estándar ). DiT en cambio modula la normalización misma. El paso de tiempo $t$ y la etiqueta de clase $y$ primero se incrustan y suman en un único vector de condicionamiento $c$. Un pequeño MLP luego predice seis parámetros por bloque transformer — escala ($\gamma_1, \gamma_2$), desplazamiento ($\beta_1, \beta_2$) y compuerta ($\alpha_1, \alpha_2$) — un triplete para el sub-bloque de atención y otro para el sub-bloque FFN. Estos parámetros modulan la salida de LayerNorm:

$$\text{AdaLN}(h, c) = \gamma(c) \odot \text{LayerNorm}(h) + \beta(c)$$

Aquí $\gamma(c)$ y $\beta(c)$ son vectores predichos a partir de la señal de condicionamiento $c$, y $\odot$ es multiplicación elemento a elemento. Cuando $\gamma = \mathbf{1}$ y $\beta = \mathbf{0}$, esto se reduce a LayerNorm estándar. Cuando $\gamma$ y $\beta$ se desvían de estos valores predeterminados, la normalización se desplaza y escala de manera diferente dependiendo del paso de tiempo y la clase — el mismo estado oculto se procesa de manera diferente en diferentes niveles de ruido.

La parte "Zero". AdaLN-Zero añade un parámetro de compuerta adicional $\alpha$ que escala toda la conexión residual:

$$h \leftarrow h + \alpha(c) \odot \text{Block}\big(\text{AdaLN}(h, c)\big)$$

El detalle crucial: $\alpha$ se inicializa a cero . Al inicio del entrenamiento, $\alpha = \mathbf{0}$ para cada bloque, así que cada bloque calcula $h \leftarrow h + \mathbf{0} \cdot \text{Block}(\cdots) = h$. El transformer completo actúa como la función identidad — como si tuviera cero capas. El entrenamiento luego gradualmente "enciende" cada bloque aprendiendo valores no nulos de $\alpha$. ¿Por qué ayuda esto? Las redes profundas son notoriamente difíciles de entrenar desde una inicialización aleatoria. Al comenzar con la identidad, DiT evita los problemas de gradientes que explotan/desvanecen al inicio y asegura entrenamiento estable incluso con 28+ capas transformer.

La cabeza de salida. Después de todos los bloques transformer, una capa final AdaLN y una proyección lineal mapean cada token de vuelta a un vector de dimensión $p^2 \cdot c$. Estos vectores se reorganizan en parches de $p \times p$ y se ensamblan ( desparchifican ) para reconstruir la cuadrícula espacial completa. La salida tiene la misma forma que la entrada: $h \times w \times c$. Esta es la predicción de ruido $\epsilon_\theta(z_t, t, y)$ (o la velocidad predicha $v_\theta$, dependiendo del objetivo de entrenamiento).

import json, js

configs = [
    ("32x32 latent, p=2", 32, 32, 2),
    ("32x32 latent, p=4", 32, 32, 4),
    ("64x64 latent, p=2", 64, 64, 2),
    ("64x64 latent, p=4", 64, 64, 4),
]

rows = []
for name, h, w, p in configs:
    tokens = (h * w) // (p * p)
    attn = tokens * tokens
    rows.append([name, f"{p}x{p}", str(tokens), f"{attn:,}"])

js.window.py_table_data = json.dumps({
    "headers": ["Config", "Patches", "Tokens", "Attn Cost"],
    "rows": rows
})

print("Smaller patches = more tokens = finer detail but quadratic attention cost.")

Por qué los transformers escalan mejor que las U-Nets

El resultado principal del artículo DiT es una ley de escalado limpia: FID (Frechet Inception Distance, menor es mejor) mejora log-linealmente con el cómputo medido en GFLOPs. Los autores entrenaron cuatro tamaños de modelo (DiT-S, DiT-B, DiT-L, DiT-XL) a dos tamaños de parche (2 y 4) y mostraron que graficar GFLOPs contra FID produce una línea casi recta en escala logarítmica. Duplicar el presupuesto de cómputo del modelo consistía en reducir a la mitad la brecha hacia el FID perfecto.

¿Por qué las U-Nets no escalan tan elegantemente? Varias restricciones arquitectónicas las frenan:

  • Jerarquía de resolución fija: Las U-Nets submuestrean a través de un conjunto fijo de resoluciones (por ejemplo, 32 -> 16 -> 8 -> 4), luego sobremuestrean de vuelta. Cada etapa de resolución tiene su propio conjunto de convoluciones. Añadir capacidad significa añadir más canales o más bloques en cada etapa, pero la estructura multi-resolución misma restringe cómo fluye la información.
  • Atención limitada: En la práctica, la auto-atención en U-Nets solo se aplica en las resoluciones más bajas (8x8 o 16x16) porque la atención a resolución completa es demasiado costosa con mapas de características basados en convolución. La mayor parte de la red depende de convoluciones locales.
  • Retornos decrecientes: Escalar U-Nets más allá de aproximadamente 2-3 mil millones de parámetros mostró mejoras decrecientes. La jerarquía convolucional se convierte en un cuello de botella — más canales no ayudan si la información solo puede fluir a través de campos receptivos locales en las resoluciones altas.

Los transformers evitan todos estos problemas. Cada token atiende a todos los demás tokens en cada capa — no hay jerarquía forzada ni restricción de localidad. La información de cualquier posición espacial puede influir en cualquier otra posición en cada capa. Esta atención global en cada capa es lo que hace a los transformers más eficientes en parámetros: añadir parámetros (más profundo o más ancho) mejora de manera confiable la capacidad del modelo para eliminar ruido, sin techo arquitectónico.

Los números lo confirmaron. DiT-XL/2 (675M parámetros) logró un FID de vanguardia de 2.27 en generación condicional por clase de ImageNet 256x256, superando al anterior mejor modelo basado en U-Net (ADM, ~554M parámetros, FID 4.59) siendo arquitectónicamente más simple. No fue una mejora marginal — fue una reducción casi a la mitad del FID con un intercambio directo de arquitectura.

import json, js

models = [
    ("DiT-S/2",  "33",   "6.1",   "68.40"),
    ("DiT-B/2",  "130",  "23.0",  "43.50"),
    ("DiT-L/2",  "458",  "80.7",  "9.60"),
    ("DiT-XL/2", "675",  "118.6", "2.27"),
    ("ADM (U-Net)", "~554", "~1120", "4.59"),
]

js.window.py_table_data = json.dumps({
    "headers": ["Model", "Params (M)", "GFLOPs", "FID"],
    "rows": [list(m) for m in models]
})

print("DiT-XL/2 achieves lower FID than ADM with ~10x fewer GFLOPs per sample.")
💡 El resultado de la ley de escalado fue la contribución clave. Los números individuales de FID serán superados por modelos futuros, pero el hallazgo de que los eliminadores de ruido tipo transformer siguen escalado log-lineal predecible — el mismo patrón visto en modelos de lenguaje tipo GPT — le dijo al campo que invertir en eliminadores de ruido transformer más grandes rendiría de manera confiable.

De condicionamiento por clase a condicionamiento por texto en DiT

El DiT original era condicional por clase : generaba imágenes de clases de ImageNet ("golden retriever", "volcán", "espresso") alimentando una etiqueta de clase en AdaLN-Zero. Pero los sistemas reales de texto a imagen necesitan condicionamiento por texto abierto — un usuario escribe "un gato con sombrero de copa en la luna" y el modelo debe entender y renderizar descripciones arbitrarias. ¿Cómo condicionamos un DiT en texto en lugar de una etiqueta de clase?

Surgieron dos enfoques principales:

1. Atención cruzada (el enfoque U-Net). Así es como Stable Diffusion 1/2/XL condicionaba en texto: codificar el prompt con un codificador de texto (CLIP o T5), luego insertar capas de atención cruzada donde los tokens de imagen atienden a los tokens de texto. Los tokens de imagen forman las queries, los tokens de texto forman los keys y values. Esto funciona bien pero significa que la información de texto solo entra a la red a través de estas capas periódicas de atención cruzada — entre ellas, los tokens de imagen se procesan solos.

2. Atención conjunta (MMDiT). Stable Diffusion 3 (Esser et al., 2024) introdujo el Multimodal Diffusion Transformer (MMDiT) , que toma un enfoque fundamentalmente diferente: concatenar los tokens de texto y los tokens de imagen en una única secuencia y procesarlos a través de las mismas capas de auto-atención. Ambas modalidades atienden a ambas modalidades en cada capa.

En MMDiT, cada modalidad tiene sus propios pesos de proyección para queries, keys y values — $W_Q^{\text{text}}, W_K^{\text{text}}, W_V^{\text{text}}$ para tokens de texto y $W_Q^{\text{img}}, W_K^{\text{img}}, W_V^{\text{img}}$ para tokens de imagen — pero comparten el mismo cálculo de atención. Después de proyectar, los keys y values de ambas modalidades se concatenan, de modo que cada token (texto o imagen) atiende a todos los demás tokens (texto e imagen):

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q \, [K^{\text{text}}; K^{\text{img}}]^\top}{\sqrt{d_k}}\right) [V^{\text{text}}; V^{\text{img}}]$$

donde $[\cdot\,;\,\cdot]$ denota concatenación a lo largo de la dimensión de secuencia. Esto significa que las características de texto e imagen se integran en cada capa , no solo en puntos periódicos de inserción de atención cruzada. El costo es una secuencia más larga (tokens de texto + imagen), lo que aumenta el costo cuadrático de atención. El beneficio es una alineación texto-imagen más estrecha — cada capa puede refinar la relación entre lo que dice el texto y lo que muestra la imagen.

Flux (Black Forest Labs, 2024) llevó esto aún más lejos con bloques de flujo único donde los tokens de texto e imagen comparten los mismos pesos de proyección completamente (sin $W_Q^{\text{text}}$ y $W_Q^{\text{img}}$ separados). Las dos modalidades se tratan como una secuencia unificada de principio a fin. Flux usa una arquitectura híbrida: la primera mitad de sus capas son estilo MMDiT (proyecciones separadas), y la segunda mitad son de flujo único (proyecciones compartidas), fusionando progresivamente las modalidades.

💡 ¿Por qué pesos de proyección separados por modalidad? Los tokens de texto e imagen viven en diferentes espacios de representación (texto de un codificador de lenguaje, imágenes de un VAE). Proyecciones separadas permiten que cada modalidad se mapee a un espacio de atención compartido a su manera. Los bloques de flujo único asumen que las representaciones han sido suficientemente alineadas por capas anteriores.

El impacto práctico

DiT no solo mejoró un número de benchmark. Cambió la trayectoria de todo el campo de generación de imágenes y vídeo. La progresión de arquitecturas de eliminación de ruido cuenta la historia:

  • Stable Diffusion 1.x / 2.x / XL: eliminador de ruido U-Net con atención cruzada para condicionamiento de texto. La arquitectura que llevó la difusión a la corriente principal.
  • Stable Diffusion 3 / 3.5: eliminador de ruido MMDiT (transformer). Cambió de U-Net a backbone basado en transformer, con atención conjunta texto-imagen.
  • Flux: basado en transformer, inspirado en DiT. Arquitectura híbrida MMDiT + flujo único.
  • Sora: DiT espaciotemporal para vídeo. Extiende DiT a 3D tratando los cuadros de vídeo como tokens espaciales adicionales, permitiendo la generación de clips de vídeo temporalmente coherentes.

El cambio ocurrió por varias razones que se refuerzan mutuamente. Primero, mejor escalado : el artículo DiT demostró que invertir en modelos más grandes rinde de manera predecible, que es exactamente la señal que las empresas necesitan para justificar entrenamientos que cuestan millones de dólares. Segundo, arquitectura más simple : un transformer plano con auto-atención es arquitectónicamente más simple que una U-Net con su jerarquía convolucional multi-resolución, conexiones residuales y tipos de bloques heterogéneos. Arquitecturas más simples son más fáciles de depurar, optimizar y paralelizar. Tercero, reutilización de infraestructura : todo el stack de software GPU (FlashAttention, paralelismo de tensores, paralelismo de secuencia, gradient checkpointing) fue construido para modelos de lenguaje basados en transformers. Los modelos estilo DiT pueden aprovechar directamente todo esto sin adaptación.

Una advertencia importante: el VAE no cambia . DiT reemplazó solo el eliminador de ruido — el componente que toma latentes ruidosos y predice el ruido (o la velocidad). El VAE que comprime imágenes del espacio de píxeles a latentes y decodifica latentes de vuelta a píxeles sigue siendo el mismo autoencoder convolucional usado en la difusión latente. El espacio latente es la interfaz: el VAE lo produce, el transformer elimina ruido dentro de él, y el VAE decodifica el resultado limpio de vuelta a píxeles.

La lección más profunda del artículo DiT trata sobre universalidad . La misma arquitectura que escala para la predicción del siguiente token en lenguaje también escala para la predicción de ruido en imágenes. Esta convergencia sugiere que el transformer no es específicamente bueno en lenguaje — es bueno aprendiendo de datos, y las leyes de escalado se mantienen independientemente de lo que esos datos representan.

Quiz

Pon a prueba tu comprensión de la arquitectura del Diffusion Transformer y su impacto.

En DiT, ¿cuál es el propósito de inicializar el parámetro de compuerta $\alpha$ a cero en AdaLN-Zero?

¿Cómo difiere MMDiT (usado en Stable Diffusion 3) del condicionamiento por atención cruzada?

¿Cuál fue el resultado clave de escalado demostrado por el artículo DiT?

Cuando el campo pasó de eliminadores de ruido U-Net a DiT, ¿qué componente del pipeline de difusión latente se mantuvo igual?