¿Por qué la atención estándar es tan lenta?
Cada técnica que hemos cubierto hasta ahora — la KV cache, la cuantización, el batching continuo, la decodificación especulativa — optimiza alrededor del mecanismo de atención sin cambiar cómo se calcula la atención en sí. Pero el cálculo estándar de puntuaciones de atención es en sí mismo un cuello de botella. Veamos exactamente por qué.
La atención estándar de producto punto escalado calcula:
Aquí $Q$, $K$ y $V$ son las matrices de consulta, clave y valor para una sola cabeza de atención, cada una con $s$ filas (una por token en la secuencia) y $d_k$ columnas (la dimensión de la cabeza). La primera operación, $QK^T$, multiplica una matriz $s \times d_k$ por una matriz $d_k \times s$, produciendo una matriz $s \times s$ de puntuaciones de atención brutas. Eso son $s^2$ elementos. Luego aplicamos softmax por filas para normalizar las puntuaciones en una distribución de probabilidad, produciendo otra matriz $s \times s$. Finalmente, multiplicamos por $V$ (también $s \times d_k$) para obtener la salida.
El escalamiento $s^2$ es el problema. Con $s = 2{,}048$ (un contexto corto según estándares modernos), la matriz de atención tiene alrededor de 4 millones de elementos — manejable. Pero con $s = 32{,}768$ (contexto de 32K, que LLaMA-3 y muchos modelos modernos soportan), la matriz de atención tiene $32{,}768^2 \approx 1.07 \times 10^9$ elementos — más de mil millones de entradas. En FP32, eso son aproximadamente 4 GB de memoria por cabeza, por capa . Un modelo con 32 cabezas y 32 capas necesitaría $32 \times 32 \times 4 = 4{,}096$ GB solo para almacenar las matrices de atención durante el prefill. Eso es obviamente imposible en cualquier GPU individual.
En la práctica, los frameworks usan FP16/BF16 (reduciendo la memoria a la mitad) y procesan cabezas secuencialmente o en grupos pequeños en lugar de todas a la vez, así que la memoria pico es mucho menor que el cálculo del peor caso anterior. Pero el problema fundamental permanece: la matriz de atención debe calcularse, almacenarse en algún lugar, leerse de vuelta para la normalización softmax , escribirse de nuevo como las puntuaciones normalizadas, y luego leerse una vez más para multiplicar por $V$. Cada uno de estos pasos involucra un viaje de ida y vuelta a la HBM (memoria de alto ancho de banda) — la memoria principal de la GPU. La aritmética en sí es rápida; lo lento es transportar la enorme matriz $s \times s$ de ida y vuelta a través del bus de memoria.
Cuantifiquemos el tráfico de memoria. Para una cabeza de atención, la implementación estándar realiza aproximadamente estas operaciones de HBM:
- Leer $Q$ y $K$ de la HBM ($2 \times s \times d_k$ elementos), calcular $QK^T$, escribir el resultado $s \times s$ a la HBM.
- Leer las puntuaciones $s \times s$ de la HBM, calcular softmax por filas, escribir las puntuaciones $s \times s$ normalizadas de vuelta a la HBM.
- Leer las puntuaciones $s \times s$ normalizadas y $V$ de la HBM, calcular la multiplicación de matrices, escribir la salida $s \times d_k$ a la HBM.
Son tres lecturas y tres escrituras de datos $O(s^2)$, totalizando $O(s^2)$ bytes de tráfico de HBM. El cómputo real (multiplicaciones de matrices y softmax) es también $O(s^2 d_k)$ FLOPs, pero en GPUs modernas el cómputo termina mucho antes de que se completen las transferencias de memoria. La relación de cómputo útil a tráfico de memoria — la intensidad aritmética — es demasiado baja. La GPU queda inactiva esperando datos, no esperando cálculos. Este es el cuello de botella cuadrático de los transformers: tanto la memoria como el cómputo escalan como $O(s^2)$, pero el ancho de banda de memoria, no los FLOPs, es la restricción dominante.
import json, js
seq_lengths = [512, 2048, 8192, 32768, 131072]
d_k = 128 # typical head dimension
bytes_per_elem = 2 # FP16
rows = []
for s in seq_lengths:
attn_elements = s * s
attn_mem_bytes = attn_elements * bytes_per_elem
qkv_mem_bytes = 3 * s * d_k * bytes_per_elem
# Total HBM traffic (approx): 3 reads + 3 writes of s*s, plus Q,K,V,O
hbm_traffic = 3 * 2 * attn_mem_bytes + 2 * qkv_mem_bytes
if attn_mem_bytes < 1024**2:
attn_str = f"{attn_mem_bytes / 1024:.0f} KB"
elif attn_mem_bytes < 1024**3:
attn_str = f"{attn_mem_bytes / 1024**2:.1f} MB"
else:
attn_str = f"{attn_mem_bytes / 1024**3:.1f} GB"
if hbm_traffic < 1024**3:
hbm_str = f"{hbm_traffic / 1024**2:.1f} MB"
else:
hbm_str = f"{hbm_traffic / 1024**3:.1f} GB"
rows.append([f"{s:,}", f"{attn_elements:,.0f}", attn_str, hbm_str])
js.window.py_table_data = json.dumps({
"headers": ["Seq Length", "Attention Elements", "Attention Matrix (FP16)", "Approx HBM Traffic (1 head)"],
"rows": rows
})
print("Attention matrix size and HBM traffic per head (FP16, d_k=128)")
print("At 32K context, a single head's attention matrix is ~2 GB")
print("At 128K context, it's ~32 GB — far exceeding any GPU's SRAM")
FlashAttention: Nunca materializar la matriz de atención
Si el cuello de botella es leer y escribir la matriz de atención $s \times s$ a la HBM, la solución es conceptualmente simple: no almacenarla ahí en absoluto. Esa es la idea central de FlashAttention (Dao et al., 2022) , luego refinada en FlashAttention-2 (Dao, 2023) . En lugar de calcular la matriz $s \times s$ completa en la HBM y luego leerla de vuelta repetidamente, FlashAttention calcula la atención en pequeños bloques que caben enteramente en SRAM (la memoria en el chip de la GPU, aproximadamente 20 MB en una A100). La matriz de atención completa nunca existe en la HBM — solo la salida final $O = \text{softmax}(QK^T / \sqrt{d_k}) \, V$ se escribe de vuelta.
Pero hay un obstáculo matemático. Softmax es una operación global: para normalizar la fila $i$ de la matriz de atención, necesitamos $\max_j(S_{ij})$ y $\sum_j \exp(S_{ij} - \max)$ a través de la fila completa . Si estamos procesando $K$ y $V$ en bloques (viendo solo un fragmento de cada fila a la vez), ¿cómo calculamos un softmax exacto sin ver la fila completa de una vez? FlashAttention resuelve esto con el truco de softmax en línea : mantener un máximo acumulado y una suma acumulada de exponenciales, y corregir los resultados parciales a medida que llegan nuevos bloques. Crucialmente, esto produce exactamente la misma salida que la atención estándar — no es una aproximación.
El algoritmo funciona de la siguiente manera:
- Paso 1: Dividir las entradas en bloques. Dividir $Q$ en bloques de $B_r$ filas y $K$, $V$ en bloques de $B_c$ filas. Los tamaños de bloque se eligen para que el conjunto de trabajo (un bloque de $Q$, un bloque de $K$, un bloque de $V$ y salidas parciales) quepa en la SRAM.
- Paso 2: Bucle exterior sobre bloques de Q. Cargar un bloque de filas de $Q$ en SRAM. Para este bloque, acumularemos la salida final.
- Paso 3: Bucle interior sobre bloques de K/V. Para cada bloque de $K$/$V$, cargarlo en SRAM, calcular las puntuaciones de atención parciales $S_{\text{block}} = Q_{\text{block}} K_{\text{block}}^T / \sqrt{d_k}$, actualizar el máximo y suma acumulados para el softmax en línea, calcular los pesos parciales de softmax, multiplicar por $V_{\text{block}}$, y acumular en la salida. Todo esto sucede en SRAM — sin escrituras a HBM para resultados intermedios.
- Paso 4: Escribir la salida final. Después de iterar sobre todos los bloques de $K$/$V$, la salida acumulada para este bloque de $Q$ está completa y es exacta. Escribirla a la HBM. Pasar al siguiente bloque de $Q$.
La observación crítica es que la matriz de atención $s \times s$ nunca se materializa completamente en ningún lugar. Cada bloque de puntuaciones se calcula en SRAM, se usa inmediatamente para la acumulación ponderada por softmax de $V$, y luego se descarta. Los únicos datos escritos a la HBM son la salida final $O$ (tamaño $s \times d_k$), más una pequeña cantidad de contabilidad (los valores máximos por fila y log-sum-exp, necesarios para el pase hacia atrás).
¿Qué ahorra esto? La complejidad de IO (bytes totales leídos y escritos a la HBM) cuenta la historia. Sea $M$ el tamaño de SRAM en elementos:
Analicemos la fórmula de FlashAttention con análisis de límites. El numerador $s^2 \cdot d_k$ refleja el trabajo total: aún calculamos todas las $s^2$ puntuaciones de atención (el cálculo es exacto, no aproximado), y cada una involucra productos punto de dimensión $d_k$. El denominador $M$ captura el beneficio del tiling — SRAM más grande significa bloques más grandes, lo que significa que cada bloque de filas de $Q$ puede emparejarse con más bloques de $K$/$V$ antes de necesitar recargar $Q$. La relación $d_k / M$ determina el factor de ahorro de IO.
Los casos límite son iluminadores. Si la SRAM fuera infinitamente grande ($M \to \infty$), podríamos cargar todo $Q$, $K$, $V$ en SRAM de una vez, calcular todo en el chip, y escribir solo la salida $O$. El IO sería $O(s \cdot d_k)$ — solo leer las entradas y escribir la salida, el mínimo posible. Este es el límite inferior IO-óptimo . En el otro extremo, si la SRAM fuera cero ($M \to 0$, significando que cada valor intermedio debe vivir en HBM), cada puntuación parcial necesitaría un viaje de ida y vuelta a la HBM y recuperaríamos el IO $O(s^2)$ de la atención estándar. En hardware real como la A100 ($M \approx 20$ MB, o aproximadamente $10^7$ elementos FP16) con $d_k = 128$, el factor de ahorro es $M / d_k \approx 80{,}000$. Esa es una reducción sustancial en tráfico de HBM, que es exactamente por qué FlashAttention logra aceleraciones de 2-4 veces en tiempo real a pesar de realizar el mismo número de FLOPs.
Más allá de la velocidad, FlashAttention también resuelve el problema de memoria. La atención estándar asigna $O(s^2)$ de memoria para la matriz de atención. FlashAttention solo necesita $O(s)$ de memoria adicional (para la salida más las estadísticas por fila). Con contexto de 32K, esa es la diferencia entre asignar miles de millones de elementos por cabeza versus unas pocas decenas de miles.
El código a continuación demuestra la idea clave conceptualmente: procesar la atención en bloques con un softmax acumulado (en línea), comparado con el enfoque estándar de matriz completa. Ambos producen la misma salida.
import math
# Simulate tiled attention with online softmax vs standard attention
# Small example: s=8 tokens, d_k=4, block_size=2
s, d_k, B = 8, 4, 2
# Deterministic "random" Q, K, V using simple formula
def make_matrix(rows, cols, seed):
return [[math.sin(seed + i * cols + j) * 0.5
for j in range(cols)] for i in range(rows)]
Q = make_matrix(s, d_k, 1.0)
K = make_matrix(s, d_k, 2.0)
V = make_matrix(s, d_k, 3.0)
def dot(a, b):
return sum(x * y for x, y in zip(a, b))
def matmul(A, B_T):
# A[m][k] x B_T[n][k] -> C[m][n]
return [[dot(A[i], B_T[j]) for j in range(len(B_T))] for i in range(len(A))]
# ── Standard attention (full s x s matrix) ──
scale = 1.0 / math.sqrt(d_k)
S = [[dot(Q[i], K[j]) * scale for j in range(s)] for i in range(s)]
# Softmax each row
O_standard = []
for i in range(s):
row_max = max(S[i])
exps = [math.exp(S[i][j] - row_max) for j in range(s)]
row_sum = sum(exps)
weights = [e / row_sum for e in exps]
out = [sum(weights[j] * V[j][d] for j in range(s)) for d in range(d_k)]
O_standard.append(out)
# ── Tiled attention with online softmax (FlashAttention-style) ──
O_tiled = [[0.0] * d_k for _ in range(s)]
row_max_all = [-float('inf')] * s
row_sum_all = [0.0] * s
for q_start in range(0, s, B):
q_end = min(q_start + B, s)
# Reset accumulators for this Q block
local_max = [-float('inf')] * (q_end - q_start)
local_sum = [0.0] * (q_end - q_start)
local_out = [[0.0] * d_k for _ in range(q_end - q_start)]
for k_start in range(0, s, B):
k_end = min(k_start + B, s)
for qi in range(q_end - q_start):
i = q_start + qi
scores = [dot(Q[i], K[j]) * scale for j in range(k_start, k_end)]
block_max = max(scores)
# Online softmax update
old_max = local_max[qi]
new_max = max(old_max, block_max)
# Rescale previous accumulator
correction = math.exp(old_max - new_max) if old_max != -float('inf') else 0.0
local_sum[qi] = local_sum[qi] * correction
for d in range(d_k):
local_out[qi][d] *= correction
# Add new block contribution
for idx, j in enumerate(range(k_start, k_end)):
w = math.exp(scores[idx] - new_max)
local_sum[qi] += w
for d in range(d_k):
local_out[qi][d] += w * V[j][d]
local_max[qi] = new_max
# Normalise and write output
for qi in range(q_end - q_start):
for d in range(d_k):
O_tiled[q_start + qi][d] = local_out[qi][d] / local_sum[qi]
# Compare outputs
max_diff = max(abs(O_standard[i][d] - O_tiled[i][d])
for i in range(s) for d in range(d_k))
print(f"Sequence length: {s}, Head dim: {d_k}, Block size: {B}")
print(f"Standard attention: full {s}x{s} = {s*s} element matrix in memory")
print(f"Tiled attention: {B}x{B} = {B*B} element blocks (never stores full matrix)")
print(f"Max absolute difference: {max_diff:.2e}")
print(f"Outputs match: {max_diff < 1e-10}")
FlashDecoding: Optimizando la fase de decodificación
FlashAttention fue diseñado para la fase de prefill , donde la matriz de consulta $Q$ tiene muchas filas (una por token del prompt). La estrategia de tiling paraleliza sobre bloques de $Q$: cada bloque de hilos de GPU maneja un subconjunto diferente de filas de consulta, y hay suficiente trabajo para llenar los multiprocesadores de streaming de la GPU. Pero ¿qué sucede durante la decodificación ? Como discutimos en el artículo 1, la fase de decodificación genera un token a la vez, así que $Q$ tiene exactamente una fila. Solo hay un "bloque" de $Q$ a procesar, lo que significa que todo el bucle exterior de FlashAttention se ejecuta en un solo bloque de hilos. En una GPU con 108 multiprocesadores de streaming (como la A100), 107 de ellos quedan inactivos.
Este problema de subutilización empeora a medida que crece la longitud del contexto. Durante la decodificación con una KV cache de 32K tokens, el único bloque de hilos debe iterar sobre todos los $32{,}768 / B_c$ bloques de clave-valor secuencialmente. El trabajo por bloque es pequeño (una sola fila de consulta multiplicada por $B_c$ claves), así que la GPU tiene miles de pasos secuenciales con paralelismo mínimo. El resultado: FlashAttention durante la decodificación es significativamente más lento que durante el prefill, en relación con el pico teórico.
FlashDecoding (introducido por Tri Dao y colaboradores) soluciona esto cambiando sobre qué dimensión paralelizamos . En lugar de paralelizar sobre filas de $Q$ (de las cuales solo hay una durante la decodificación), FlashDecoding paraleliza sobre la dimensión de longitud de secuencia KV . El algoritmo divide la KV cache en fragmentos a lo largo de la dimensión de secuencia y asigna cada fragmento a un bloque de hilos separado:
- Paso 1: Dividir la KV cache. Dividir los $s$ pares clave-valor en caché en $P$ fragmentos de aproximadamente $s / P$ tokens cada uno. Cada fragmento se asigna a un bloque de hilos de GPU separado.
- Paso 2: Calcular atención parcial por fragmento. Cada bloque de hilos carga la única fila de consulta y su fragmento KV asignado en SRAM, calcula las puntuaciones de atención parciales, ejecuta un softmax local (rastreando el máximo y suma locales), y produce un vector de salida parcial — la suma ponderada por softmax de valores dentro de ese fragmento.
- Paso 3: Reducción global. Combinar las $P$ salidas parciales en el resultado final. Esto requiere corregir el hecho de que cada fragmento calculó softmax con un máximo local (no global). La corrección usa el mismo truco de reescalado de softmax en línea: dado el máximo parcial $m_p$ y suma $l_p$ de cada fragmento $p$, calcular el máximo global $m = \max_p m_p$, reescalar cada salida parcial por $\exp(m_p - m) \cdot l_p$, y normalizar por la suma global corregida.
El paso de reducción añade una pequeña sobrecarga (combinar $P$ resultados parciales, cada uno de dimensión $d_k$), pero $P$ es típicamente unos pocos cientos como máximo, y $d_k$ es típicamente 128, así que el costo es despreciable comparado con el cálculo de atención en sí. El beneficio crítico es que todos los $P$ bloques de hilos se ejecutan en paralelo, utilizando completamente la GPU. Si tenemos 108 multiprocesadores de streaming y elegimos $P = 108$, cada SM está ocupado — una mejora dramática sobre el escenario de un solo bloque.
¿Cuánto ayuda esto en la práctica? La aceleración depende de la longitud del contexto. Para contextos cortos (unos pocos cientos de tokens), el único bloque de hilos de FlashAttention estándar puede procesar todos los bloques KV suficientemente rápido como para que la sobrecarga de paralelismo de FlashDecoding no valga la pena. Pero para contextos largos — 8K, 32K, 128K tokens — FlashDecoding proporciona aceleraciones sustanciales (hasta 8 veces para secuencias muy largas) porque convierte un escaneo secuencial sobre miles de bloques KV en una operación paralela a través de toda la GPU.
Esto importa porque la decodificación es el cuello de botella para la latencia de extremo a extremo en la mayoría de los escenarios de servicio. Como discutimos en el artículo 1, la fase de decodificación genera tokens de uno en uno y domina el tiempo real para cualquier respuesta más larga que unos pocos tokens. Hacer la atención de la fase de decodificación más rápida reduce directamente el tiempo por token, lo que reduce directamente la latencia percibida por el usuario.
Multi-Head Latent Attention (MLA)
En el artículo 2, vimos cómo Grouped-Query Attention (GQA) reduce la KV cache compartiendo proyecciones de clave-valor entre grupos de cabezas de consulta. GQA intercambia algo de capacidad representacional por ahorros de memoria, y funciona bien — LLaMA 2/3, Mistral, y la mayoría de los modelos modernos lo usan. Pero hay un enfoque fundamentalmente diferente: en lugar de compartir K y V entre cabezas, comprimirlos en un espacio latente de bajo rango. Eso es Multi-Head Latent Attention (MLA) , introducida en la arquitectura DeepSeek-V2 (DeepSeek-AI, 2024) .
En la atención multi-cabeza estándar (MHA), la KV cache almacena vectores de clave y valor separados para cada cabeza en cada posición de token. Para un modelo con $n_h$ cabezas y dimensión de cabeza $d_h$, eso son $2 \times n_h \times d_h$ valores por token por capa (el factor de 2 para claves y valores). Con GQA usando $n_{\text{kv}}$ grupos de clave-valor, esto baja a $2 \times n_{\text{kv}} \times d_h$. MLA toma un camino completamente diferente: en lugar de almacenar en caché los vectores completos de clave y valor, almacena en caché un solo vector comprimido $c_t \in \mathbb{R}^{d_c}$ por token, donde $d_c$ es la dimensión latente y $d_c \ll n_h \times d_h$.
La compresión funciona a través de un par de proyecciones aprendidas. Cuando el token $t$ se procesa por primera vez, en lugar de calcular y almacenar en caché $K_t$ y $V_t$ directamente, MLA calcula una proyección hacia abajo:
Esta representación comprimida $c_t$ es lo que se almacena en la KV cache — solo $d_c$ valores por token en lugar de $2 \times n_h \times d_h$. En tiempo de atención, las claves y valores se reconstruyen al vuelo mediante proyecciones hacia arriba:
donde $W_K^{\text{up}} \in \mathbb{R}^{d_c \times (n_h \cdot d_h)}$ y $W_V^{\text{up}} \in \mathbb{R}^{d_c \times (n_h \cdot d_h)}$ son matrices aprendidas que descomprimen el latente de vuelta a vectores completos de clave y valor para todas las cabezas.
Los ahorros en KV cache son dramáticos. Recorramos los números de DeepSeek-V2. El modelo tiene $n_h = 128$ cabezas de atención con $d_h = 128$, así que MHA estándar almacenaría en caché $2 \times 128 \times 128 = 32{,}768$ valores por token por capa. MLA usa $d_c = 512$, almacenando solo 512 valores por token por capa. Eso es una reducción de 64 veces en el tamaño de la KV cache — mucho más allá de lo que GQA logra.
import json, js
# Compare KV cache: MHA vs GQA vs MLA for DeepSeek-V2 scale
n_h = 128 # query heads
d_h = 128 # head dimension
n_kv_gqa = 8 # GQA groups (hypothetical)
d_c = 512 # MLA latent dimension
L = 60 # layers
b_prec = 2 # FP16 bytes
configs = {
"MHA": 2 * n_h * d_h, # full K + V per head
"GQA (8 groups)": 2 * n_kv_gqa * d_h,
"MLA (d_c=512)": d_c, # single compressed vector
}
seq_lengths = [2048, 8192, 32768, 131072]
rows = []
for name, per_token_vals in configs.items():
for s in seq_lengths:
cache_bytes = L * per_token_vals * s * b_prec
if cache_bytes < 1024**3:
size_str = f"{cache_bytes / 1024**2:.0f} MB"
else:
size_str = f"{cache_bytes / 1024**3:.1f} GB"
rows.append([name, str(per_token_vals), f"{s:,}", size_str])
js.window.py_table_data = json.dumps({
"headers": ["Attention Type", "Values/Token/Layer", "Seq Length", "KV Cache (FP16, 60 layers)"],
"rows": rows
})
mha_per_token = 2 * n_h * d_h
mla_per_token = d_c
print(f"MHA caches {mha_per_token:,} values per token per layer")
print(f"MLA caches {mla_per_token:,} values per token per layer")
print(f"Reduction factor: {mha_per_token / mla_per_token:.0f}x")
print(f"At 128K context, 60 layers: MHA needs ~{60 * mha_per_token * 131072 * 2 / 1024**3:.0f} GB, MLA needs ~{60 * mla_per_token * 131072 * 2 / 1024**3:.1f} GB")
La pregunta obvia: ¿el costo de descompresión no elimina los ahorros? Las proyecciones hacia arriba $W_K^{\text{up}}$ y $W_V^{\text{up}}$ multiplican un vector de dimensión $d_c$ por una matriz grande para reconstruir las claves y valores completos. Eso son FLOPs adicionales. Pero aquí está la idea crucial: durante la decodificación , el modelo está limitado por el ancho de banda de memoria , no por el cómputo (artículo 1). Las unidades aritméticas de la GPU están en gran parte inactivas, esperando a que los datos lleguen de la HBM. Añadir cómputo que reduce el tráfico de memoria es esencialmente gratis — la GPU hace las multiplicaciones de matrices extra durante tiempo que de otra manera pasaría esperando. Las matrices de pesos de proyección hacia arriba $W_K^{\text{up}}$ y $W_V^{\text{up}}$ son parte de los pesos del modelo (cargados una vez, compartidos entre todos los tokens), mientras que la caché comprimida $c_t$ es mucho más pequeña de lo que sería la KV cache completa. El efecto neto es menos tráfico total de memoria, aunque estemos haciendo más aritmética.
Durante el prefill (que está limitado por el cómputo), la descompresión extra sí añade costo medible. DeepSeek-V2 aborda esto absorbiendo la proyección hacia arriba en el cálculo de atención algebraicamente: en lugar de descomprimir $K$ y $V$ explícitamente y luego calcular la atención, el modelo reformula las ecuaciones de atención para operar directamente sobre las representaciones comprimidas. La matemática es equivalente, pero la implementación evita materializar los tensores $K$ y $V$ de tamaño completo, ahorrando tanto memoria como algo de cómputo.
Elegir optimizaciones de atención
Hemos cubierto cuatro optimizaciones de atención que apuntan a diferentes aspectos del problema. ¿Cómo se relacionan entre sí y cuándo deberías usar cada una?
FlashAttention es universal. Produce resultados exactos, usa menos memoria, y es más rápido que la atención estándar en cada escenario. No hay compensación y no hay razón para no usarlo. Si tu framework lo soporta (PyTorch 2.0+, HuggingFace Transformers, y virtualmente todos los motores de inferencia modernos lo hacen), debería estar habilitado por defecto. FlashAttention optimiza cómo se calcula la atención (tiling consciente de IO) sin cambiar qué se calcula.
Grouped-Query Attention (GQA) es una decisión arquitectónica tomada en tiempo de entrenamiento. Un modelo debe ser entrenado con GQA desde el inicio (o convertido mediante up-training, como demuestra el artículo de Ainslie et al.). No puedes aplicar retroactivamente GQA a un modelo entrenado con MHA estándar. GQA reduce la KV cache compartiendo proyecciones de clave-valor entre grupos de cabezas de consulta, típicamente dando una reducción de caché de 4-8 veces con pérdida de calidad despreciable. LLaMA-2 70B, LLaMA-3, Mistral 7B, Mixtral, Gemma, y la mayoría de los modelos post-2023 usan GQA.
Multi-Head Latent Attention (MLA) es también una decisión arquitectónica, y más reciente. Logra una compresión de KV cache mucho más agresiva que GQA (hasta 64 veces vs 8 veces) aprendiendo una representación latente de bajo rango. MLA es actualmente usada por las familias de modelos DeepSeek-V2 y DeepSeek-V3. Añade cómputo durante la descompresión, pero esto se compensa con los enormes ahorros de memoria durante la decodificación. A medida que más arquitecturas adopten MLA (o estrategias de compresión similares), puede volverse tan estándar como lo es GQA hoy.
FlashDecoding aborda un cuello de botella específico: la subutilización de la GPU durante la fase de decodificación con contextos largos. Es una optimización de runtime (como FlashAttention), no una decisión arquitectónica, y aplica a cualquier modelo. Proporciona las mayores aceleraciones para la decodificación con contextos largos — exactamente el escenario donde el cálculo de atención es más costoso en relación con el resto del modelo. Para contextos cortos, el beneficio es mínimo.
Crucialmente, estas técnicas se componen . FlashAttention maneja el tiling eficiente en IO. GQA o MLA reduce el tamaño de lo que se almacena y carga. FlashDecoding asegura que la GPU se utilice completamente al calcular atención durante la decodificación. Un stack de inferencia moderno típicamente combina todas las técnicas aplicables: por ejemplo, servir LLaMA-3 70B usa FlashAttention (tiling consciente de IO) + GQA (8 grupos de cabezas KV, reduciendo la caché por 8 veces) + FlashDecoding (decodificación paralela sobre la dimensión de secuencia KV) + PagedAttention de vLLM (gestión eficiente de memoria de caché) + cuantización INT8 de la KV cache. Cada técnica apunta a una parte diferente del problema, y sus beneficios se apilan multiplicativamente.
import json, js
rows = [
["FlashAttention",
"Runtime (IO-aware kernel)",
"Any model",
"2-4x speed, O(s) memory",
"None (exact, always faster)"],
["FlashDecoding",
"Runtime (decode parallelism)",
"Any model, long-context decode",
"Up to 8x decode speedup",
"Minimal for short contexts"],
["GQA",
"Architecture (training time)",
"LLaMA-2/3, Mistral, Gemma, etc.",
"4-8x KV cache reduction",
"Slight quality loss vs MHA"],
["MLA",
"Architecture (training time)",
"DeepSeek-V2/V3",
"Up to 64x KV cache reduction",
"Extra decompression FLOPs"],
]
js.window.py_table_data = json.dumps({
"headers": ["Technique", "Type", "Applies To", "Benefit", "Trade-off"],
"rows": rows
})
print("Attention optimisation summary")
print("FlashAttention + GQA/MLA + FlashDecoding is the standard stack")
print("All compose: each targets a different bottleneck")
Quiz
Pon a prueba tu comprensión de las técnicas de optimización de atención.
¿Por qué la atención estándar es lenta a pesar de que las GPUs modernas tienen enorme capacidad de cómputo?
¿Cuál es la propiedad clave que diferencia a FlashAttention de los métodos de atención aproximada?
¿Por qué FlashAttention tiene bajo rendimiento durante la fase de decodificación, motivando FlashDecoding?
En Multi-Head Latent Attention (MLA), ¿por qué el cómputo extra para descomprimir claves y valores durante la decodificación es esencialmente gratis?