¿Necesita cada token ver a todos los demás?
La atención completa (densa) estándar calcula una puntuación entre cada par de tokens en la secuencia. Para una secuencia de $n$ tokens, eso significa $n^2$ puntuaciones de atención, $n^2$ entradas en la matriz de atención, y $O(n^2)$ en tiempo y memoria. Cuando $n = 4{,}096$, son aproximadamente 16 millones de puntuaciones. Con $n = 128{,}000$ (un contexto de la clase de Llama 3), son más de 16 mil millones.
¿Pero es realmente útil todo ese cómputo? Empíricamente, la mayoría de los pesos de atención se concentran en un pequeño subconjunto de posiciones: tokens cercanos (contexto local) y un puñado de tokens globalmente importantes (inicio de secuencia, puntuación, anclas específicas de la tarea). La gran mayoría de las $n^2$ puntuaciones son cercanas a cero y contribuyen casi nada a la salida. Si pudiéramos omitir el cálculo de esas puntuaciones cercanas a cero, ahorraríamos la mayor parte del trabajo sin cambiar significativamente el resultado.
Esa es la idea central detrás de los patrones de atención dispersa : en lugar de dejar que cada token atienda a todos los demás, definimos un patrón que selecciona qué pares (query, key) se calculan realmente. El resto se tratan como si su puntuación fuera $-\infty$ (enmascarados antes del softmax). Los ahorros dependen del patrón: cuanto más dispersa la máscara, menos puntuaciones calculamos, y más rápida (y eficiente en memoria) se vuelve la capa.
Atención de ventana deslizante
El patrón disperso más simple es la ventana deslizante : cada token solo atiende a los $w$ tokens más recientes (incluyéndose a sí mismo). Si el token está en la posición $i$, ve las posiciones $\max(0, \, i - w + 1)$ hasta $i$. Todo lo anterior a esa ventana se enmascara.
Verifiquemos esto en los extremos. Cuando $w = n$ (la ventana iguala la secuencia completa), cada token ve a todos los demás y recuperamos la atención completa en $O(n^2)$. Cuando $w = 1$, cada token solo se ve a sí mismo y la atención degenera en una operación por punto en $O(n)$. En la práctica, modelos como Mistral 7B usan $w = 4{,}096$: cada token atiende a las últimas 4,096 posiciones, dando un costo de $O(4096 \cdot n)$ por capa, que es lineal en $n$.
Pero si cada capa solo puede ver $w$ tokens hacia atrás, ¿cómo maneja el modelo las dependencias de largo alcance? A través de la propagación de información a través de capas . Después de una capa, el token $i$ tiene información sobre los tokens $[i - w + 1, \, i]$. Después de dos capas, el token $i$ tiene información indirecta sobre los tokens $[i - 2w + 2, \, i]$, porque los tokens en su ventana ya atendieron a sus propias ventanas en la capa anterior. Después de $L$ capas, el campo receptivo efectivo es:
Para Mistral 7B ($L = 32$, $w = 4{,}096$): $32 imes 4{,}096 = 131{,}072$ tokens. Con $L = 1$, un token solo puede acceder a $w = 4{,}096$ posiciones. Con $L = 32$, información de hasta 131K tokens atrás puede influir en la salida, aunque se atenúa cada vez más con cada salto (cada capa mezcla pero también diluye señales). Por esto los modelos de ventana deslizante pueden manejar contextos mucho más allá de $w$ a costa de fidelidad reducida para tokens muy distantes.
Un beneficio práctico importante es que el KV cache está acotado. En lugar de almacenar en caché todos los $n$ pares key-value pasados por capa (lo que crece linealmente con los tokens generados), solo necesitamos almacenar las últimas $w$ entradas. Para Mistral con $w = 4{,}096$, el KV cache se fija en 4,096 entradas por capa independientemente de la longitud total de la secuencia. Esto se implementa como un buffer circular (buffer rotativo) : las nuevas entradas sobreescriben las más antiguas en la posición $i \mod w$. Sin reasignación de memoria, sin caché creciente. Para una discusión en profundidad sobre la gestión del KV cache y la inferencia con ventana deslizante, consulta nuestro artículo sobre KV cache .
# Sliding window attention: rolling buffer KV cache
w = 8 # window size
class RollingKVCache:
def __init__(self, window_size):
self.w = window_size
self.buffer = [None] * window_size
self.count = 0
def insert(self, kv_entry):
pos = self.count % self.w # circular index
evicted = self.buffer[pos]
self.buffer[pos] = kv_entry
self.count += 1
return pos, evicted
def active_entries(self):
return [e for e in self.buffer if e is not None]
cache = RollingKVCache(window_size=w)
# Simulate inserting 12 tokens into a window of size 8
for token_id in range(12):
pos, evicted = cache.insert(f"tok_{token_id}")
if evicted:
print(f"Step {token_id:2d}: insert tok_{token_id} at slot {pos}, evicted {evicted}")
else:
print(f"Step {token_id:2d}: insert tok_{token_id} at slot {pos}")
print(f"
Final buffer (window={w}): {cache.buffer}")
print(f"Tokens 0-3 were evicted; only the last {w} remain.")
La limitación es clara: no hay un camino de atención directa entre tokens distantes. Un token en la posición 50,000 no puede atender directamente a un token en la posición 0. La información debe saltar a través de representaciones intermedias, capa por capa. Cada salto atenúa la señal, así que aunque el campo receptivo teórico es $L imes w$, la influencia práctica de tokens muy distantes es débil. Patrones como los tokens globales (cubiertos a continuación) abordan esto.
Longformer y BigBird: Dispersión estructurada
Las ventanas deslizantes funcionan bien para el contexto local, pero algunas tareas demandan conexiones genuinas de largo alcance: clasificar un documento legal basándose en una cláusula enterrada a miles de tokens del token [CLS], o responder una pregunta cuya evidencia abarca múltiples párrafos. Dos artículos influyentes introdujeron patrones dispersos estructurados que combinan atención local y global.
Longformer (Beltagy et al., 2020) combina tres patrones de atención en una sola capa:
- Ventana deslizante: cada token atiende a $w$ vecinos en cada lado. Esto captura contexto sintáctico y semántico local, igual que el patrón que describimos arriba.
- Ventana deslizante dilatada: en lugar de atender a $w$ vecinos contiguos, atiende a cada segundo (o $d$-ésimo) token dentro de un rango mayor. Esto es análogo a las convoluciones dilatadas en CNNs: al saltar tokens, la ventana cubre un rango más amplio ($w imes d$ posiciones) al mismo costo computacional que una ventana estándar de tamaño $w$.
- Atención global: un pequeño conjunto de tokens designados (por ejemplo, el token [CLS], o el primer token de cada párrafo) atiende a todas las posiciones, y todas las posiciones atienden de vuelta a ellos. Si tenemos $g$ tokens globales, esto añade $O(g \cdot n)$ puntuaciones.
Los tokens globales son la idea clave. Actúan como cuellos de botella de información que crean atajos para el flujo de información de largo alcance. Sin ellos, una señal desde la posición 0 debe saltar a través de $\lceil n/w \rceil$ capas para alcanzar la posición $n$. Con un solo token global, cualquier posición puede alcanzar cualquier otra en máximo 2 saltos (posición $\rightarrow$ token global $\rightarrow$ posición objetivo). Como $g$ es pequeño (a menudo solo 1 o 2), el costo total permanece en $O(n \cdot w + g \cdot n) = O(n \cdot (w + g))$, que es lineal en $n$.
BigBird (Zaheer et al., 2020) toma un enfoque diferente para asegurar la conectividad. Combina:
- Ventana deslizante: igual que Longformer, para contexto local.
- Tokens globales: la misma idea, unos pocos tokens atienden a/desde todas las posiciones.
- Atención aleatoria: cada token adicionalmente atiende a $r$ posiciones elegidas aleatoriamente. Este es el componente novedoso. Las aristas aleatorias aseguran que el grafo de atención esté bien conectado: con alta probabilidad, dos tokens cualesquiera están separados por un camino corto (logarítmico en $n$) a través del grafo.
La contribución teórica de BigBird es demostrar que esta combinación de atención local + global + aleatoria es un aproximador universal : puede aproximar cualquier función que la atención completa pueda calcular, siempre que $g$ y $r$ se configuren apropiadamente. Las aristas aleatorias son críticas para este resultado. Desde la teoría de grafos, un grafo con conexiones locales más unas pocas aristas aleatorias de largo alcance es un grafo expansor con alta probabilidad, lo que significa que dos nodos cualesquiera están conectados por un camino corto. En términos de atención: dos tokens cualesquiera pueden intercambiar información a través de solo unos pocos saltos de atención.
Ambos modelos logran $O(n)$ complexity (con constantes que dependen del tamaño de ventana $w$, número de tokens globales $g$ y aristas aleatorias $r$). Para ver por qué, cuenta las aristas de atención por token: $w$ de la ventana deslizante, $g$ de tokens globales (constante) y $r$ de atención aleatoria (constante). El total por token es $w + g + r$, todas constantes independientes de $n$, así que el total a través de los $n$ tokens es $O(n \cdot (w + g + r)) = O(n)$.
# Visualise the three attention patterns side by side
import json, js
n = 16 # sequence length (small for visualisation)
w = 3 # window half-size
g_idxs = [0] # global token indices
r = 2 # random edges per token
import random
random.seed(42)
def make_mask(n, w, g_idxs, r):
"""Build BigBird-style attention mask: local + global + random."""
mask = [[0]*n for _ in range(n)]
local_count = 0
global_count = 0
random_count = 0
for i in range(n):
# Sliding window
for j in range(max(0, i - w), min(n, i + w + 1)):
if mask[i][j] == 0:
mask[i][j] = 1
local_count += 1
# Global tokens
for g in g_idxs:
if mask[i][g] == 0:
mask[i][g] = 1
global_count += 1
if mask[g][i] == 0:
mask[g][i] = 1
global_count += 1
# Random
candidates = [j for j in range(n) if mask[i][j] == 0]
chosen = random.sample(candidates, min(r, len(candidates)))
for j in chosen:
mask[i][j] = 2 # mark as random (for colour)
random_count += 1
return mask, local_count, global_count, random_count
mask, lc, gc, rc = make_mask(n, w, g_idxs, r)
full_attention_scores = n * n
sparse_scores = sum(1 for row in mask for v in row if v > 0)
print(f"Sequence length: {n}")
print(f"Window half-size: {w} (each token sees {2*w+1} neighbors)")
print(f"Global tokens: {g_idxs}")
print(f"Random edges per token: {r}")
print(f"")
print(f"Full attention scores: {full_attention_scores}")
print(f"BigBird sparse scores: {sparse_scores}")
print(f" - Local (window): {lc}")
print(f" - Global: {gc}")
print(f" - Random: {rc}")
print(f"Sparsity: {1 - sparse_scores/full_attention_scores:.1%} of scores skipped")
Atención dilatada y con paso
En lugar de usar el mismo patrón de atención en cada capa, algunas arquitecturas varían el patrón entre capas para construir un campo receptivo jerárquico . La idea está tomada de las convoluciones dilatadas en CNNs: las capas iniciales miran detalles locales de grano fino, mientras que las capas más profundas miran contexto más amplio con resolución más gruesa.
Concretamente, considera un modelo con 4 capas y una ventana base de $w = 4{,}096$ tokens:
- Layer 0: atiende a los 4,096 tokens más cercanos (paso 1). Contexto local denso.
- Layer 1: atiende a cada segundo token hasta 8,192 posiciones (paso 2). Rango más amplio, mitad de la densidad.
- Layer 2: atiende a cada cuarto token hasta 16,384 posiciones (paso 4). Aún más amplio, aún más disperso.
- Layer 3: atiende a cada octavo token hasta 32,768 posiciones (paso 8). Cobertura global gruesa.
Cada capa calcula exactamente $w$ puntuaciones de atención por token (el tamaño de ventana es fijo), así que el costo por capa siempre es $O(n \cdot w)$, lineal en $n$. Pero el rango efectivo se duplica con cada capa. Con paso $s$ y ventana $w$, la atención abarca $w imes s$ posiciones. En los extremos: paso $s = 1$ da la ventana local estándar ($w$ posiciones); paso $s = n/w$ cubre la secuencia completa con exactamente $w$ muestras. A medida que el paso crece, intercambiamos precisión (saltando tokens intermedios) por alcance (cubriendo más de la secuencia).
La combinación entre capas es poderosa: la capa 0 captura el orden exacto de palabras locales, la capa 1 captura estructura a nivel de párrafo, la capa 2 captura estructura a nivel de sección, y la capa 3 captura estructura a nivel de documento. Al apilar estas, el modelo obtiene detalle local denso y contexto global disperso, todo a un costo total de $O(n \cdot w)$ por capa. La ventana dilatada del artículo de Longformer es un caso especial de esta idea aplicada dentro de una sola capa.
# Dilated attention: same compute per layer, increasing reach
w = 8 # tokens attended per layer (budget)
layers = [
{"name": "Layer 0", "stride": 1},
{"name": "Layer 1", "stride": 2},
{"name": "Layer 2", "stride": 4},
{"name": "Layer 3", "stride": 8},
]
print(f"Window budget per layer: {w} tokens")
print(f"{'Layer':<10} {'Stride':<8} {'Range (positions)':<20} {'Cost (scores/token)'}")
print("-" * 60)
for layer in layers:
s = layer["stride"]
reach = w * s
print(f"{layer['name']:<10} {s:<8} {reach:<20} {w}")
total_range = w * layers[-1]["stride"]
print(f"
After {len(layers)} layers, the model covers {total_range} positions")
print(f"Total cost per token: {w * len(layers)} scores ({w} per layer x {len(layers)} layers)")
print(f"Full attention would cost: {total_range} scores per token")
Ring Attention: Distribución entre GPUs
Todos los patrones anteriores reducen el número de puntuaciones de atención calculadas. Pero para secuencias muy largas (millones de tokens), hay un problema más fundamental: la secuencia no cabe en la memoria de una sola GPU en absoluto. Incluso con atención de ventana deslizante, almacenar las activaciones para millones de tokens durante el entrenamiento excede la memoria de cualquier dispositivo individual. Necesitamos distribuir la secuencia misma a través de múltiples GPUs.
Ring Attention (Liu et al., 2023) hace exactamente esto. Se organizan $d$ GPUs en un anillo lógico. La secuencia de entrada de $n$ tokens se divide en $d$ bloques contiguos de $n/d$ tokens cada uno. Cada GPU $i$ contiene:
- Su bloque Q local: los vectores de query para sus $n/d$ tokens. Este permanece en la GPU y nunca se mueve.
- Un bloque KV: los vectores key-value para algún fragmento de la secuencia. Este rota alrededor del anillo.
El algoritmo procede en $d$ rondas. En cada ronda, cada GPU calcula la atención entre su bloque Q local y el bloque KV actual que sostiene, produciendo una salida de atención parcial. Luego, todas las GPUs simultáneamente envían su bloque KV a la siguiente GPU en el anillo y reciben el bloque KV de la GPU anterior. Después de $d$ rondas, cada bloque Q ha atendido a cada bloque KV en la secuencia.
La clave de la eficiencia es solapar comunicación con cómputo . Mientras la GPU $i$ está calculando la atención para el bloque KV actual, simultáneamente está enviando ese bloque a la GPU $i+1$ y recibiendo el siguiente bloque de la GPU $i-1$. Siempre que el tiempo de cómputo para un bloque sea mayor o igual al tiempo de transferencia, el costo de comunicación queda completamente oculto.
La memoria por GPU escala como $O(n/d)$: cada GPU solo almacena queries para $n/d$ tokens y un bloque KV de tamaño $n/d$ a la vez. Con $d$ GPUs, podemos procesar secuencias $d$ veces más largas de lo que una sola GPU podría almacenar. Así es como modelos como Gemini procesan contextos de más de 10M de tokens: distribuyendo la secuencia a través de cientos de TPUs en un anillo, con cada dispositivo manejando un fragmento manejable.
# Ring Attention simulation: d GPUs, n tokens
n = 24 # total sequence length (tokens)
d = 4 # number of GPUs
block = n // d # tokens per GPU
print(f"Sequence length: {n} tokens")
print(f"GPUs: {d}")
print(f"Block size: {block} tokens per GPU")
print(f"Memory per GPU: O({n}/{d}) = O({block}) instead of O({n})")
print()
# Simulate the ring
for round_num in range(d):
print(f"Round {round_num}:")
for gpu in range(d):
kv_source = (gpu - round_num) % d # which GPU's KV block we're computing with
q_range = f"tokens [{gpu*block}-{(gpu+1)*block - 1}]"
kv_range = f"tokens [{kv_source*block}-{(kv_source+1)*block - 1}]"
print(f" GPU {gpu}: Q={q_range} x KV={kv_range}")
if round_num < d - 1:
print(f" >> All GPUs pass KV block to next neighbor")
print()
print(f"After {d} rounds: every Q block has seen every KV block")
print(f"Result: EXACT full attention, computed in distributed fashion")
print(f"Communication: {d-1} KV block transfers per GPU (overlapped with compute)")
Elegir el patrón adecuado
Cada patrón intercambia simplicidad, ahorro de cómputo y flujo de información. La elección correcta depende de la tarea, la arquitectura del modelo y las restricciones de hardware. Aquí hay un resumen:
import json, js
rows = [
["Full attention", "O(n\u00b2)", "Exact", "Single GPU, short ctx", "GPT-2, BERT"],
["Sliding window", "O(n\u00b7w)", "Approx (local)", "Streaming / inference", "Mistral 7B (w=4096)"],
["Longformer", "O(n)", "Approx (local+global)", "Long-doc classification", "Longformer-4096"],
["BigBird", "O(n)", "Approx (local+global+random)", "Long-doc QA/NER", "BigBird-4096"],
["Dilated / strided", "O(n\u00b7w)", "Approx (hierarchical)", "Multi-scale context", "Various research"],
["Ring Attention", "O(n\u00b2/d)", "Exact (distributed)", "Training on 1M+ tokens", "Gemini, long-ctx training"],
]
js.window.py_table_data = json.dumps({
"headers": ["Pattern", "Complexity", "Accuracy", "Best For", "Example"],
"rows": rows
})
print("Complexity notes:")
print(" n = sequence length, w = window size, d = number of GPUs")
print(" Sliding window: linear in n for fixed w")
print(" Ring Attention: same total work as full attention, but memory is O(n/d) per GPU")
Algunas directrices prácticas:
- Para inferencia autorregresiva de LLM: la atención de ventana deslizante es la opción dominante. Acota el KV cache, simplifica la gestión de memoria y funciona bien para tareas conversacionales y generativas donde el contexto local importa más. Mistral y sus derivados han demostrado esto a escala.
- Para tareas de codificador en documentos largos: los patrones de Longformer y BigBird (local + tokens globales) siguen siendo opciones fuertes. Los tokens globales aseguran que las tareas de clasificación o extracción puedan recopilar señales de todo el documento.
- Para entrenar con secuencias extremadamente largas: Ring Attention es la estrategia de distribución predilecta. Permite atención exacta sobre contextos de millones de tokens distribuyendo la memoria entre GPUs. Se compone con cualquier patrón de atención local.
- Todos los patrones se componen con FlashAttention: FlashAttention maneja el cómputo de bloques de bajo nivel (mosaico, gestión de SRAM) independientemente de qué patrón disperso seleccione los bloques. Los patrones dispersos eligen qué calcular; FlashAttention optimiza cómo calcularlo.
Quiz
Pon a prueba tu comprensión de los patrones de atención eficientes.
En la atención de ventana deslizante con tamaño de ventana $w$ y $L$ capas, ¿cuál es el campo receptivo efectivo de la capa final?
¿Cuál es el propósito de los tokens globales en Longformer y BigBird?
¿Cómo se diferencia Ring Attention de los patrones de atención dispersa como ventana deslizante o BigBird?
¿Por qué BigBird incluye aristas de atención aleatorias además de atención local y global?