¿Por qué recalcular lo que ya conoces?
Cuando un transformer genera texto, trabaja autorregresivamente : produce un token a la vez, alimentando cada nuevo token de vuelta para generar el siguiente. En el paso $t$, el modelo calcula puntuaciones de atención sobre todos los tokens anteriores $1, 2, \ldots, t-1$ para decidir a qué prestar atención al predecir el token $t$. Pero aquí está el problema: las claves y valores para los tokens $1$ hasta $t-1$ ya fueron calculados en pasos anteriores, y no han cambiado. El mecanismo de atención proyecta el embedding de cada token a través de matrices de pesos aprendidas $W_K$ y $W_V$ para producir vectores de clave y valor. Esas proyecciones dependen solo de la posición del token y los pesos fijos del modelo — no de lo que viene después. Así que cada vez que los recalculamos, estamos haciendo trabajo puramente redundante.
¿Cuánto desperdicio estamos generando? En el paso de generación $t$, recalcular los $t$ pares clave-valor cuesta $O(t \cdot d)$ operaciones (donde $d$ es la dimensión del modelo), pero solo el par del nuevo token es información nueva. Sobre una secuencia completa de longitud $s$, el cómputo redundante total es $O(1 + 2 + \ldots + s) = O(s^2)$, un costo cuadrático que crece dolorosamente a medida que las secuencias se alargan. Para una generación de 4,096 tokens, eso son aproximadamente 8 millones de re-proyecciones innecesarias a través de todos los pasos.
La KV cache elimina este desperdicio por completo. En lugar de recalcular claves y valores para cada token anterior en cada paso, los almacenamos la primera vez que se calculan y los reutilizamos. En el paso $t$, solo calculamos los vectores de consulta, clave y valor del nuevo token ($q_t$, $k_t$, $v_t$), añadimos $k_t$ y $v_t$ a la caché, y luego calculamos la atención entre $q_t$ y todas las claves en caché $[k_1, \ldots, k_t]$. El costo de generación por paso baja de $O(t \cdot d)$ para proyecciones a $O(d)$ — aún pagamos $O(t \cdot d)$ por los productos punto de atención (la nueva consulta debe atender a todas las claves en caché), pero hemos eliminado todo el costo de proyección redundante.
Cómo funciona la KV cache
Recorramos la mecánica paso a paso. En la atención multi-cabeza estándar, cada capa proyecta la entrada a través de tres matrices de pesos para producir consultas ($Q$), claves ($K$) y valores ($V$). Sin una KV cache, en el paso de generación $t$ el modelo debe calcular $Q$, $K$ y $V$ para los $t$ tokens en la secuencia y luego ejecutar el cálculo completo de atención. El costo de proyección solo es $O(t \cdot d)$ por matriz, y el producto punto de atención cuesta otros $O(t^2 \cdot d_{\text{head}})$ a través de todas las cabezas.
Con una KV cache, el panorama cambia drásticamente. En el paso $t$, calculamos $q_t$, $k_t$ y $v_t$ para solo el nuevo token — un costo fijo de proyección $O(d)$ independientemente de la longitud de la secuencia. Añadimos $k_t$ y $v_t$ a la caché (que ya contiene $[k_1, \ldots, k_{t-1}]$ y $[v_1, \ldots, v_{t-1}]$), y luego calculamos la atención: $q_t$ atiende a las $t$ claves en caché para producir una combinación ponderada de los $t$ valores en caché. Los productos punto de atención aún cuestan $O(t \cdot d_{\text{head}})$ por cabeza (una consulta contra $t$ claves), pero hemos eliminado el costo $O(t \cdot d)$ de recalcular todas las proyecciones anteriores.
¿Qué almacena exactamente la caché? Para cada capa del modelo, contiene dos tensores: los vectores de clave acumulados $[k_1, k_2, \ldots, k_t]$ y los vectores de valor acumulados $[v_1, v_2, \ldots, v_t]$. Cada vector individual de clave o valor tiene dimensión $d_{\text{head}}$, y hay $n_{\text{heads}}$ cabezas por capa, así que la caché almacena $2 \times n_{\text{heads}} \times d_{\text{head}}$ valores por token por capa (un vector de clave más un vector de valor para cada cabeza).
Dado que $n_{\text{heads}} \times d_{\text{head}} = d_{\text{model}}$, el costo de caché por token por capa se simplifica a $2 \times d_{\text{model}}$ valores. En FP16 (2 bytes por valor), eso son $4 d_{\text{model}}$ bytes por token por capa. El tamaño total de la KV cache a través de todo el modelo y la secuencia es:
Desglosemos cada variable:
- $2$ — un conjunto de claves y un conjunto de valores. Siempre almacenamos ambos porque la atención necesita claves para calcular puntuaciones y valores para calcular la salida ponderada.
- $L$ — número de capas del transformer. Cada capa tiene su propio mecanismo de atención con sus propias proyecciones $W_K$ y $W_V$, así que cada capa necesita su propia caché.
- $n_{\text{heads}}$ — número de cabezas de atención por capa. En la atención multi-cabeza estándar, cada cabeza tiene una proyección independiente de clave y valor.
- $d_{\text{head}}$ — dimensión de cada cabeza de atención. Típicamente $d_{\text{model}} / n_{\text{heads}}$.
- $s$ — longitud de la secuencia (número de tokens en caché hasta ahora). Esta es la dimensión que crece durante la generación y hace costosa la inferencia con contextos largos.
- $b_{\text{precision}}$ — bytes por valor. 2 para FP16/BF16, 4 para FP32, 1 para INT8.
Dado que $n_{\text{heads}} \times d_{\text{head}} = d_{\text{model}}$, podemos simplificar:
Ahora veamos qué significa esto en la práctica. Tomemos LLaMA-7B: $L = 32$ capas, $d_{\text{model}} = 4096$, en FP16 ($b_{\text{precision}} = 2$ bytes). Con una longitud de secuencia de 4,096 tokens:
Dos gigabytes de memoria de GPU solo para la KV cache de una sola solicitud. Los pesos del modelo para LLaMA-7B son aproximadamente 14 GB en FP16, así que con contexto de 4K la caché ya es el 14% del tamaño del modelo. Ahora llevemos la longitud de la secuencia a 32K tokens:
Dieciséis gigabytes — más grande que el modelo en sí. Y eso es por solicitud. Si estás sirviendo a 64 usuarios concurrentes, multiplica por 64: más de 1 TB de memoria de KV cache. Por eso la gestión de la KV cache es central para la optimización de inferencia: a escala, la caché domina la memoria de la GPU mucho más que los pesos del modelo.
El código a continuación calcula los tamaños de KV cache para varias arquitecturas de modelos populares en diferentes longitudes de secuencia, haciendo concreto el escalamiento:
import json, js
models = {
"LLaMA-7B": {"L": 32, "d_model": 4096, "n_kv_heads": 32, "d_head": 128},
"LLaMA-13B": {"L": 40, "d_model": 5120, "n_kv_heads": 40, "d_head": 128},
"LLaMA-70B": {"L": 80, "d_model": 8192, "n_kv_heads": 8, "d_head": 128},
"Mistral-7B": {"L": 32, "d_model": 4096, "n_kv_heads": 8, "d_head": 128},
"GPT-3 175B": {"L": 96, "d_model": 12288, "n_kv_heads": 96, "d_head": 128},
}
seq_lengths = [2048, 4096, 8192, 32768, 131072]
b_prec = 2 # FP16
rows = []
for name, cfg in models.items():
for s in seq_lengths:
kv_bytes = 2 * cfg["L"] * cfg["n_kv_heads"] * cfg["d_head"] * s * b_prec
if kv_bytes < 1024**3:
size_str = f"{kv_bytes / 1024**2:.0f} MB"
else:
size_str = f"{kv_bytes / 1024**3:.1f} GB"
rows.append([name, str(s), str(cfg["L"]), str(cfg["n_kv_heads"]),
str(cfg["d_head"]), size_str])
js.window.py_table_data = json.dumps({
"headers": ["Model", "Seq Len", "Layers", "KV Heads", "d_head", "KV Cache (FP16)"],
"rows": rows
})
print("KV cache sizes for popular models at various sequence lengths (FP16)")
print("Note: LLaMA-70B and Mistral-7B use GQA (fewer KV heads), so their caches are smaller than MHA equivalents")
El problema de memoria de la KV cache
La tabla anterior cuenta una historia clara: para secuencias largas, la KV cache se convierte en el consumidor dominante de memoria de GPU, eclipsando los pesos del modelo en sí. Pongamos números concretos al problema. Consideremos un modelo de 70B parámetros (como LLaMA-2 70B) con $L = 80$ capas, $d_{\text{model}} = 8192$, y atención multi-cabeza estándar con $n_{\text{heads}} = 64$. En FP16, los pesos del modelo ocupan aproximadamente 140 GB. Ahora calculemos la KV cache con contexto de 32K:
Ochenta gigabytes para la KV cache de una sola solicitud — más de la mitad de la huella de pesos del propio modelo. Y esto es por solicitud. Si quieres servir incluso a 8 usuarios concurrentes, la KV cache sola necesita $8 \times 80 = 640$ GB, superando con creces la memoria de cualquier GPU individual (una A100 tiene 80 GB, una H100 tiene 80 GB). Esto significa que en producción, el tamaño máximo de batch no está limitado por la huella de memoria del modelo — está limitado por cuántas KV caches caben en la memoria restante de la GPU después de cargar el modelo.
Esto crea una tensión directa entre rendimiento (servir a más usuarios simultáneamente) y longitud de contexto (soportar conversaciones o documentos más largos). Duplica la longitud de la secuencia y reduces a la mitad el número de solicitudes concurrentes que puedes servir. Por eso las extensiones de longitud de contexto son tan difíciles: soportar ventanas de contexto de 128K o 1M no es solo un desafío de modelado — es fundamentalmente un problema de gestión de memoria. La KV cache crece linealmente con la longitud de la secuencia, y la memoria de la GPU es finita.
Esta pared de memoria es lo que motiva casi todas las técnicas en el resto de este track:
- Grouped-Query Attention (GQA): reducir el número de cabezas KV para que haya menos que almacenar en caché por token (cubierto a continuación).
- PagedAttention: gestionar la memoria de la KV cache como un sistema operativo gestiona la memoria virtual — asignar y liberar en páginas para evitar la fragmentación (cubierto en el artículo 4 sobre batching continuo).
- Cuantización de la KV cache: almacenar claves y valores en caché en INT8 o INT4 en lugar de FP16, reduciendo la memoria 2-4 veces.
- Atención de ventana deslizante: limitar el tamaño de la caché atendiendo solo a los últimos $w$ tokens (cubierto en la Sección 5).
Atención Multi-Query y Grouped-Query
Si la KV cache es el cuello de botella, la solución más directa es reducir cuánto almacenamos en caché por token. En la atención multi-cabeza (MHA) estándar, cada una de las $n_{\text{heads}}$ cabezas de atención tiene sus propias proyecciones independientes de clave y valor. Eso significa que almacenamos $n_{\text{heads}}$ vectores de clave separados y $n_{\text{heads}}$ vectores de valor separados por token por capa. Pero ¿todas esas cabezas realmente necesitan sus propias claves y valores? Las proyecciones de consulta difieren por cabeza (cada cabeza aprende a hacer preguntas diferentes), pero quizás las claves y valores (la información que se busca) pueden compartirse.
Multi-Query Attention (MQA) (Shazeer, 2019) lleva esta idea a su extremo: todas las cabezas de atención comparten un único conjunto de proyecciones de clave y valor. Cada cabeza aún tiene su propia proyección de consulta $W_Q^{(h)}$, así que cada cabeza hace una pregunta diferente, pero todas buscan contra las mismas claves y recuperan de los mismos valores. La KV cache se reduce por un factor de $n_{\text{heads}}$, porque almacenamos solo un vector de clave y un vector de valor por token por capa en lugar de $n_{\text{heads}}$ de cada uno.
Para LLaMA-7B con 32 cabezas, esta es una reducción dramática. La caché se reduce de 2 GB (con contexto de 4K en FP16) a $2{,}048 / 32 = 64$ MB. Eso es un ahorro de 32 veces — suficiente para servir 32 veces más solicitudes concurrentes en la misma memoria de GPU, o extender la longitud de contexto por 32 veces. La compensación es la calidad: con un solo conjunto de claves y valores, las cabezas de atención ya no pueden especializar sus representaciones de clave-valor. En la práctica, MQA muestra una ligera degradación de calidad en comparación con MHA estándar, particularmente en tareas que requieren razonamiento multi-aspecto de grano fino.
Grouped-Query Attention (GQA) (Ainslie et al., 2023) encuentra el punto medio. En lugar de que todas las cabezas compartan un conjunto KV (MQA) o que cada cabeza tenga el suyo (MHA), GQA organiza las $n_{\text{heads}}$ cabezas de consulta en $n_{\text{kv\_heads}}$ grupos, donde cada grupo comparte un conjunto de claves y valores. Dentro de un grupo, múltiples cabezas de consulta (específicamente $n_{\text{heads}} / n_{\text{kv\_heads}}$ de ellas) atienden a las mismas claves y valores pero con diferentes proyecciones de consulta.
La fórmula de la caché se convierte en:
Observa que $n_{\text{kv\_heads}}$ ha reemplazado a $n_{\text{heads}}$ de la fórmula original. Los casos límite son reveladores: cuando $n_{\text{kv\_heads}} = n_{\text{heads}}$, cada cabeza de consulta tiene su propio par KV y recuperamos MHA estándar. Cuando $n_{\text{kv\_heads}} = 1$, todas las cabezas comparten un par KV y recuperamos MQA. Cualquier valor intermedio nos da GQA, con el factor de reducción de caché siendo $n_{\text{heads}} / n_{\text{kv\_heads}}$.
LLaMA-2 70B usa GQA con $n_{\text{kv\_heads}} = 8$ y $n_{\text{heads}} = 64$, dando una reducción de caché de 8 veces en comparación con MHA estándar. El artículo de Ainslie et al. demostró que GQA con un número moderado de grupos KV recupera casi toda la calidad de MHA completo mientras proporciona la mayoría de los ahorros de memoria de MQA. Esto ha hecho de GQA el estándar de facto para los modelos de lenguaje grandes modernos: LLaMA-2, LLaMA-3, Mistral, Mixtral, y muchos otros lo usan.
La tabla a continuación compara los tamaños de KV cache para modelos a escala LLaMA-7B bajo MHA, GQA y MQA, haciendo concretos los ahorros:
import json, js
# LLaMA-7B scale: L=32, d_model=4096, n_heads=32, d_head=128
L = 32
d_head = 128
n_heads = 32
b_prec = 2 # FP16
configs = {
"MHA (32 KV heads)": 32,
"GQA-8 (8 KV heads)": 8,
"GQA-4 (4 KV heads)": 4,
"MQA (1 KV head)": 1,
}
seq_lengths = [2048, 4096, 8192, 32768, 131072]
rows = []
for config_name, n_kv in configs.items():
for s in seq_lengths:
kv_bytes = 2 * L * n_kv * d_head * s * b_prec
if kv_bytes < 1024**3:
size_str = f"{kv_bytes / 1024**2:.0f} MB"
else:
size_str = f"{kv_bytes / 1024**3:.1f} GB"
reduction = n_heads / n_kv
rows.append([config_name, str(s), size_str, f"{reduction:.0f}x"])
js.window.py_table_data = json.dumps({
"headers": ["Attention Type", "Seq Len", "KV Cache (FP16)", "Reduction vs MHA"],
"rows": rows
})
print("KV cache comparison: MHA vs GQA vs MQA (LLaMA-7B scale, FP16)")
print(f"Model: L={L}, d_head={d_head}, n_query_heads={n_heads}")
print()
print("GQA-8 gives 4x reduction; MQA gives 32x but with potential quality loss")
Ventana deslizante y otras estrategias de caché
GQA reduce el costo de caché por token, pero la caché aún crece linealmente con la longitud de la secuencia. Para secuencias muy largas (32K, 128K, o más), incluso una caché comprimida con GQA puede exceder la memoria disponible. La atención de ventana deslizante toma un enfoque fundamentalmente diferente: en lugar de almacenar en caché toda la secuencia, solo atiende a los $w$ tokens más recientes. El tamaño de la caché está limitado a $w$ entradas por capa independientemente de cuán larga sea la secuencia total.
Mistral 7B usa atención de ventana deslizante con $w = 4{,}096$. En cualquier paso de generación $t$, el modelo solo atiende a tokens en posiciones $\max(1, t - w + 1)$ hasta $t$. La caché contiene como máximo $w$ pares clave-valor por capa, así que el tamaño máximo de caché es:
Observa que $s$ (la longitud de secuencia siempre creciente) ha sido reemplazada por $w$ (el tamaño fijo de la ventana). Para Mistral 7B ($L = 32$, $n_{\text{kv\_heads}} = 8$, $d_{\text{head}} = 128$, $w = 4096$), la caché está limitada a $2 \times 32 \times 8 \times 128 \times 4096 \times 2 = 512$ MB en FP16 — independientemente de si la conversación total tiene 4K, 32K o 128K tokens de largo.
Los casos límite clarifican el espacio de diseño: cuando $w = s$ (la ventana iguala la secuencia completa), la ventana deslizante se reduce a atención completa estándar — cada token atiende a todos los demás tokens y no hay ahorro de memoria. Cuando $w = 1$, el modelo no tiene contexto en absoluto — cada token solo se ve a sí mismo, haciendo imposible la generación coherente. El punto óptimo práctico está entre ambos: una ventana lo suficientemente grande para capturar el contexto local relevante para la mayoría de las tareas, pero lo suficientemente pequeña para mantener la memoria acotada.
Pero ¿descartar tokens antiguos no pierde información importante? No completamente. La información de tokens más allá de la ventana aún puede propagarse a través de las capas. Consideremos un modelo con $L$ capas y un tamaño de ventana $w$. En la capa 1, el token $t$ ve tokens en $[t - w + 1, t]$. Pero las representaciones de esos tokens fueron construidas en la capa 0 a partir de sus propias ventanas, que se extienden hasta $t - 2w + 1$. Después de $L$ capas de este apilamiento, la información de hasta $L \times w$ tokens atrás puede teóricamente influir en la salida actual — no a través de atención directa, sino por estar incorporada en las representaciones intermedias. Para Mistral 7B con $L = 32$ y $w = 4096$, el campo receptivo teórico es $32 \times 4096 = 131{,}072$ tokens. En la práctica la señal se atenúa a lo largo de muchas capas, pero por eso la atención de ventana deslizante funciona mejor de lo que el tamaño de la ventana solo podría sugerir.
Más allá de las ventanas deslizantes, varias otras estrategias ayudan a gestionar la memoria de la KV cache:
- KV cache paginada (PagedAttention): en lugar de pre-asignar memoria contigua para la longitud máxima posible de secuencia de cada solicitud, PagedAttention asigna memoria de caché en páginas de tamaño fijo (como la memoria virtual de un sistema operativo). Las páginas se asignan bajo demanda y se liberan cuando una solicitud se completa, eliminando la fragmentación interna que desperdicia el 60-80% de la memoria de la KV cache en implementaciones ingenuas. Cubrimos esto en profundidad en el artículo 4 sobre batching continuo.
- Cuantización de la KV cache: almacenar claves y valores en caché en INT8 (1 byte) o INT4 (0.5 bytes) en lugar de FP16 (2 bytes). Esto da una reducción de tamaño de 2-4 veces con impacto mínimo en la calidad, porque los valores en caché son activaciones intermedias (no pesos aprendidos) y toleran bien el ruido de cuantización. Combinado con GQA, las KV caches cuantizadas pueden ser 16-32 veces más pequeñas que las caches MHA estándar en FP16.
- Evicción de tokens: identificar y desalojar los tokens en caché menos importantes basándose en sus puntuaciones de atención. Los tokens que consistentemente reciben atención cercana a cero en pasos recientes probablemente no serán necesarios y pueden descartarse de forma segura. Enfoques como H2O (Heavy-Hitter Oracle) mantienen solo los tokens que históricamente han recibido más atención, limitando el tamaño de la caché mientras preservan el contexto más relevante.
Quiz
Pon a prueba tu comprensión de la KV cache y sus optimizaciones.
¿Por qué no necesita recalcularse la KV cache en cada paso de generación?
Para un modelo con $L = 40$ capas, $d_{\text{model}} = 5120$, a longitud de secuencia 4096 en FP16, ¿cuál es el tamaño de la KV cache por solicitud?
En Grouped-Query Attention (GQA), ¿qué sucede cuando $n_{\text{kv\_heads}} = 1$?
¿Cómo puede la información de tokens fuera de la ventana deslizante aún influir en la salida del modelo?