Why Recompute What You Already Know?

When a transformer generates text, it works autoregressively: it produces one token at a time, feeding each new token back in to generate the next. At step $t$, the model computes attention scores over the tokens $1, 2, \ldots, t$ seen so far to decide what to attend to when predicting the next token. But here's the problem: the keys and values for tokens $1$ through $t-1$ were already computed in previous steps, and they haven't changed. The attention mechanism projects each token's hidden state through learned weight matrices $W_K$ and $W_V$ to produce key and value vectors. Under causal masking, that hidden state depends only on the token itself and the tokens before it, and the weights are fixed, so a token's keys and values never depend on what comes after. Every time we recompute them, we're doing pure redundant work.

How much waste are we talking about? Each projection is a matrix-vector product with a $d \times d$ weight matrix (where $d$ is the model dimension), costing $O(d^2)$ operations. At generation step $t$, recomputing all $t$ key-value pairs therefore costs $O(t \cdot d^2)$ operations per layer, but only the new token's pair is actually new information. Over a full sequence of length $s$, the number of redundant projections is $0 + 1 + \ldots + (s-1) = O(s^2)$, a quadratic cost that grows painfully as sequences get longer. For a 4,096-token generation, that's roughly 8.4 million unnecessary re-projections per layer, for each of $W_K$ and $W_V$.
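The count is easy to check. A minimal sketch, assuming one key (or value) projection per token per layer and roughly $2d^2$ FLOPs (multiply plus add) per $d \times d$ matrix-vector product; the helper names are illustrative:

```python
def redundant_projections(s: int) -> int:
    """Token projections a cache-less decoder recomputes for ONE projection
    matrix (W_K or W_V) in ONE layer over s generation steps.

    At step t all t tokens are re-projected but only the newest is new,
    so t - 1 of them are redundant.
    """
    return sum(t - 1 for t in range(1, s + 1))


def redundant_flops(s: int, d: int) -> int:
    """Approximate FLOPs wasted on those projections, assuming ~2*d*d FLOPs
    per d x d matrix-vector product."""
    return redundant_projections(s) * 2 * d * d


print(redundant_projections(4096))  # 8386560
print(f"{redundant_flops(4096, 4096):.2e} FLOPs per layer, per matrix")
```

Multiplying by $2 \times L$ (keys and values, every layer) gives the whole-model figure.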

The KV cache eliminates this waste entirely. Instead of recomputing keys and values for every previous token at every step, we store them the first time they're computed and reuse them. At step $t$, we only compute the new token's query, key, and value vectors ($q_t$, $k_t$, $v_t$), append $k_t$ and $v_t$ to the cache, and then compute attention between $q_t$ and all cached keys $[k_1, \ldots, k_t]$. The projection cost per step drops from $O(t \cdot d^2)$ to $O(d^2)$. We still pay $O(t \cdot d)$ for the attention dot products themselves, because the new query must attend to all cached keys, but the entire redundant projection cost is gone.
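To see that caching changes the cost but not the result, here is a toy single-head decoder step in pure Python. It is only a sketch: the weights are random, there is a single layer and a single head, and each `x` stands in for a token's hidden state. The helpers `rand_matrix`, `matvec`, and `attend` are illustrative, not from any library. At every step it compares attention over a growing cache with attention recomputed from scratch:

```python
import math
import random

random.seed(0)
d = 8  # toy model dimension (assumption: one head with d_head = d)

def rand_matrix(rows, cols):
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

def matvec(W, x):
    # Computes W @ x for a row-major matrix W.
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def attend(q, keys, values):
    # Scaled dot-product attention for a single query vector.
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)                                   # subtract max for stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))]

W_Q, W_K, W_V = (rand_matrix(d, d) for _ in range(3))
tokens = [[random.gauss(0, 1) for _ in range(d)] for _ in range(6)]

k_cache, v_cache = [], []
cached_outputs, recomputed_outputs = [], []
for t, x_t in enumerate(tokens, start=1):
    # With a cache: project only the new token, then append its k and v.
    q_t = matvec(W_Q, x_t)
    k_cache.append(matvec(W_K, x_t))
    v_cache.append(matvec(W_V, x_t))
    cached_outputs.append(attend(q_t, k_cache, v_cache))

    # Without a cache: re-project every token seen so far.
    keys = [matvec(W_K, x) for x in tokens[:t]]
    values = [matvec(W_V, x) for x in tokens[:t]]
    recomputed_outputs.append(attend(q_t, keys, values))

max_diff = max(abs(a - b)
               for out_c, out_r in zip(cached_outputs, recomputed_outputs)
               for a, b in zip(out_c, out_r))
print(f"max difference between cached and recomputed: {max_diff:.1e}")
print(f"cache holds {len(k_cache)} keys and {len(v_cache)} values")
```

Both paths produce the same outputs; the cached path simply performs one projection per matrix per step instead of $t$.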

💡 The KV cache trades compute for memory. We avoid $O(s^2)$ redundant key and value projections per layer across a full generation, but now we need GPU memory to store all those cached key and value tensors. As we'll see, this memory cost becomes the dominant bottleneck for long-context inference.

How the KV Cache Works

Let's walk through the mechanics step by step. In standard multi-head attention, each layer projects the input through three weight matrices to produce queries ($Q$), keys ($K$), and values ($V$). Without a KV cache, at generation step $t$ the model must compute $Q$, $K$, and $V$ for all $t$ tokens in the sequence and then run the full attention computation. The projection cost alone is $O(t \cdot d^2)$ per matrix, and the attention dot products cost another $O(t^2 \cdot d_{\text{head}})$ per head, or $O(t^2 \cdot d)$ across all heads.

With a KV cache, the picture changes dramatically. At step $t$, we compute $q_t$, $k_t$, and $v_t$ for only the single new token, a fixed $O(d^2)$ projection cost regardless of sequence length. We append $k_t$ and $v_t$ to the cache (which already holds $[k_1, \ldots, k_{t-1}]$ and $[v_1, \ldots, v_{t-1}]$), and then compute attention: $q_t$ attends to all $t$ cached keys to produce a weighted combination of all $t$ cached values. The attention dot products still cost $O(t \cdot d_{\text{head}})$ per head (one query against $t$ keys), but we've eliminated the $O(t \cdot d^2)$ cost of recomputing all previous projections.
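The per-step costs can be tallied directly. The sketch below counts multiply-accumulate operations (MACs) for a single attention layer, assuming square $d \times d$ projections and ignoring the output projection, softmax, and MLP; `step_macs` is a hypothetical helper, not from any library:

```python
def step_macs(t: int, d: int, cached: bool) -> dict:
    """Rough multiply-accumulate count for one attention layer at step t.

    Projections: one token through W_Q, W_K, W_V costs 3 * d * d MACs.
    Attention: q.k scores cost d per key and the weighted sum of values
    costs d per value, summed over all heads (n_heads * d_head = d).
    """
    if cached:
        proj = 3 * d * d          # only the new token is projected
        attn = 2 * t * d          # one query against t cached keys and values
    else:
        proj = 3 * t * d * d      # every token is re-projected
        attn = 2 * t * t * d      # every query against every key (full t x t)
    return {"projection": proj, "attention": attn, "total": proj + attn}


for t in (128, 1024, 4096):
    full = step_macs(t, 4096, cached=False)["total"]
    fast = step_macs(t, 4096, cached=True)["total"]
    print(f"t={t:5d}  no cache: {full:.2e}  cache: {fast:.2e}  ratio: {full / fast:.0f}x")
```

Under these assumptions the uncached step does exactly $t$ times the work of the cached one, because every term picks up an extra factor of $t$.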

What exactly does the cache store? For each layer in the model, it holds two tensors: the accumulated key vectors $[k_1, k_2, \ldots, k_t]$ and the accumulated value vectors $[v_1, v_2, \ldots, v_t]$. Each individual key or value vector has dimension $d_{\text{head}}$, and there are $n_{\text{heads}}$ heads per layer, so the cache stores $2 \times n_{\text{heads}} \times d_{\text{head}}$ values per token per layer (one key vector plus one value vector for each head).

Since $n_{\text{heads}} \times d_{\text{head}} = d_{\text{model}}$, the per-token-per-layer cache cost simplifies to $2 \times d_{\text{model}}$ values. In FP16 (2 bytes per value), that's $4 d_{\text{model}}$ bytes per token per layer. The total KV cache size across the full model and sequence is:

$$M_{\text{KV}} = 2 \times L \times n_{\text{heads}} \times d_{\text{head}} \times s \times b_{\text{precision}}$$

Let's unpack every variable:

  • $2$ — one set of keys and one set of values. We always store both because attention needs keys to compute scores and values to compute the weighted output.
  • $L$ — number of transformer layers. Each layer has its own attention mechanism with its own $W_K$ and $W_V$ projections, so each layer needs its own cache.
  • $n_{\text{heads}}$ — number of attention heads per layer. In standard multi-head attention, each head has an independent key and value projection.
  • $d_{\text{head}}$ — dimension of each attention head. Typically $d_{\text{model}} / n_{\text{heads}}$.
  • $s$ — sequence length (number of tokens cached so far). This is the dimension that grows during generation and makes long-context inference expensive.
  • $b_{\text{precision}}$ — bytes per value. 2 for FP16/BF16, 4 for FP32, 1 for INT8.

For standard MHA, where $n_{\text{heads}} \times d_{\text{head}} = d_{\text{model}}$, this simplifies to:

$$M_{\text{KV}} = 2 \times L \times d_{\text{model}} \times s \times b_{\text{precision}}$$
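Both forms translate directly into code. A minimal sketch (the function names are illustrative), evaluated for a 32-layer configuration with 32 heads of dimension 128, which is LLaMA-7B's shape:

```python
def kv_cache_bytes(n_layers, n_heads, d_head, seq_len, bytes_per_value=2):
    """General form: 2 (K and V) x L x n_heads x d_head x s x b."""
    return 2 * n_layers * n_heads * d_head * seq_len * bytes_per_value


def kv_cache_bytes_mha(n_layers, d_model, seq_len, bytes_per_value=2):
    """Simplified form, valid only when n_heads * d_head == d_model."""
    return 2 * n_layers * d_model * seq_len * bytes_per_value


GiB = 1024 ** 3
per_token = kv_cache_bytes(32, 32, 128, seq_len=1)               # FP16 by default
print(per_token // 1024, "KiB per token")                         # 512 KiB per token
print(kv_cache_bytes(32, 32, 128, 4096) / GiB, "GiB at 4K")       # 2.0 GiB at 4K
print(kv_cache_bytes_mha(32, 4096, 32768) / GiB, "GiB at 32K")    # 16.0 GiB at 32K
```

The general form is the one to keep: once the number of KV heads differs from the number of query heads, the simplified form no longer applies.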

Now let's see what this means in practice. Take LLaMA-7B: $L = 32$ layers, $d_{\text{model}} = 4096$, in FP16 ($b_{\text{precision}} = 2$ bytes). At a sequence length of 4,096 tokens:

$$M_{\text{KV}} = 2 \times 32 \times 4096 \times 4096 \times 2 = 2{,}147{,}483{,}648 \text{ bytes} = 2 \text{ GB}$$

Two gigabytes of GPU memory just for the KV cache of a single request. The model weights themselves for LLaMA-7B are about 14 GB in FP16, so at 4K context the cache is already 14% of the model size. Now push the sequence length to 32K tokens:

$$M_{\text{KV}} = 2 \times 32 \times 4096 \times 32768 \times 2 = 16 \text{ GB}$$

Sixteen gigabytes — larger than the model itself. And that's per request. If you're serving 64 concurrent users, multiply by 64: over 1 TB of KV cache memory. This is why KV cache management is central to inference optimization: at scale, the cache dominates GPU memory far more than the model weights do.

The code below computes KV cache sizes for several popular model architectures across different sequence lengths, making the scaling concrete:

import json

try:
    import js  # only available inside a Pyodide (in-browser) runtime
except ImportError:
    js = None

models = {
    "LLaMA-7B":   {"L": 32, "d_model": 4096,  "n_kv_heads": 32, "d_head": 128},
    "LLaMA-13B":  {"L": 40, "d_model": 5120,  "n_kv_heads": 40, "d_head": 128},
    "LLaMA-70B":  {"L": 80, "d_model": 8192,  "n_kv_heads": 8,  "d_head": 128},
    "Mistral-7B": {"L": 32, "d_model": 4096,  "n_kv_heads": 8,  "d_head": 128},
    "GPT-3 175B": {"L": 96, "d_model": 12288, "n_kv_heads": 96, "d_head": 128},
}

seq_lengths = [2048, 4096, 8192, 32768, 131072]
b_prec = 2  # bytes per value (FP16)

headers = ["Model", "Seq Len", "Layers", "KV Heads", "d_head", "KV Cache (FP16)"]
rows = []
for name, cfg in models.items():
    for s in seq_lengths:
        # 2 (keys + values) x layers x KV heads x head dim x tokens x bytes per value
        kv_bytes = 2 * cfg["L"] * cfg["n_kv_heads"] * cfg["d_head"] * s * b_prec
        if kv_bytes < 1024**3:
            size_str = f"{kv_bytes / 1024**2:.0f} MB"
        else:
            size_str = f"{kv_bytes / 1024**3:.1f} GB"
        rows.append([name, str(s), str(cfg["L"]), str(cfg["n_kv_heads"]),
                     str(cfg["d_head"]), size_str])

print("KV cache sizes for popular models at various sequence lengths (FP16)")
print("Note: LLaMA-70B and Mistral-7B use GQA (fewer KV heads), so their caches are smaller than MHA equivalents")

if js is not None:
    # In the browser, hand the table to the page for rendering.
    js.window.py_table_data = json.dumps({"headers": headers, "rows": rows})
else:
    # Outside the browser, print a plain-text table instead.
    for row in [headers] + rows:
        print(" | ".join(row))
💡 Notice how LLaMA-70B (8 KV heads via GQA) has a smaller KV cache than LLaMA-13B (40 KV heads via standard MHA) at the same sequence length, despite being a 5x larger model. Cache cost is set by the number of KV heads, the head dimension, and the layer count, not by the total parameter count. We'll explore GQA in detail in Section 4.

The KV Cache Memory Problem

The table above tells a clear story: for long sequences, the KV cache becomes the dominant consumer of GPU memory, eclipsing the model weights themselves. Let's put concrete numbers on the problem. Consider a 70B-parameter model with LLaMA-2 70B's dimensions ($L = 80$ layers, $d_{\text{model}} = 8192$, $n_{\text{heads}} = 64$) but with standard multi-head attention; the real LLaMA-2 70B uses grouped-query attention, which we return to below. In FP16, the model weights occupy roughly 140 GB. Now compute the KV cache at 32K context:

$$M_{\text{KV}} = 2 \times 80 \times 8192 \times 32768 \times 2 = 85{,}899{,}345{,}920 \text{ bytes} \approx 80 \text{ GB}$$

Eighty gigabytes for a single request's KV cache, more than half the model's own weight footprint. And this is per request. If you want to serve even 8 concurrent users, the KV cache alone needs $8 \times 80 = 640$ GB, far exceeding the memory of any single GPU (for comparison, an A100 or H100 typically has 80 GB). This means that in production, the maximum batch size is not limited by the model's memory footprint; it's limited by how many KV caches fit in the remaining GPU memory after the model is loaded.

This creates a direct tension between throughput (serving more users simultaneously) and context length (supporting longer conversations or documents). Double the sequence length and you halve the number of concurrent requests you can serve. This is why context length extensions are so hard: supporting 128K or 1M context windows isn't just a modelling challenge — it's fundamentally a memory management problem. The KV cache grows linearly with sequence length, and GPU memory is finite.
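The tradeoff falls straight out of the arithmetic. A rough sketch, assuming a single 80 GB GPU, LLaMA-7B's roughly 14 GB of FP16 weights, GB treated as GiB for simplicity, and nothing else in memory (real servers also reserve space for activations and framework overhead, so actual numbers are lower). Both helpers are hypothetical:

```python
GiB = 1024 ** 3

def kv_bytes_per_request(n_layers, d_model, seq_len, bytes_per_value=2):
    # Standard MHA cache for one request: keys + values across all layers.
    return 2 * n_layers * d_model * seq_len * bytes_per_value

def max_concurrent_requests(gpu_bytes, weight_bytes, per_request_bytes):
    # Whole requests whose caches fit in the memory left after the weights.
    return max(0, (gpu_bytes - weight_bytes) // per_request_bytes)

gpu, weights = 80 * GiB, 14 * GiB   # assumed: one 80 GB GPU, LLaMA-7B in FP16
for s in (4096, 8192, 16384, 32768):
    per_req = kv_bytes_per_request(32, 4096, s)
    n = max_concurrent_requests(gpu, weights, per_req)
    print(f"context {s:6d}: {per_req / GiB:4.0f} GiB per request -> {n} requests")
```

Each doubling of context roughly halves the number of requests that fit (33, 16, 8, 4 under these assumptions).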

This memory wall is what motivates nearly every technique in the rest of this track:

  • Grouped-Query Attention (GQA): reduce the number of KV heads so there's less to cache per token (covered next).
  • PagedAttention: manage KV cache memory like an operating system manages virtual memory — allocate and free in pages to avoid fragmentation (covered in article 4 on continuous batching).
  • KV cache quantisation: store cached keys and values in INT8 or INT4 instead of FP16, cutting memory by 2-4x.
  • Sliding window attention: bound the cache size by only attending to the last $w$ tokens (covered in Section 5).
💡 In practice, inference serving systems like vLLM report that KV cache memory is the primary bottleneck for throughput. The model weights are loaded once and shared across all requests, but each request needs its own KV cache. This is why KV cache optimisation has more impact on serving cost than model compression in many deployments.

Multi-Query and Grouped-Query Attention

If the KV cache is the bottleneck, the most direct fix is to reduce how much we cache per token. In standard multi-head attention (MHA), each of the $n_{\text{heads}}$ attention heads has its own independent key and value projections. That means we store $n_{\text{heads}}$ separate key vectors and $n_{\text{heads}}$ separate value vectors per token per layer. But do all those heads really need their own keys and values? The query projections differ per head (each head learns to ask different questions), but perhaps the keys and values (the information being looked up) can be shared.

Multi-Query Attention (MQA) (Shazeer, 2019) takes this idea to its extreme: all attention heads share a single set of key and value projections. Each head still has its own query projection $W_Q^{(h)}$, so each head asks a different question, but they all search against the same keys and retrieve from the same values. The KV cache shrinks by a factor of $n_{\text{heads}}$, because we store only one key vector and one value vector per token per layer instead of $n_{\text{heads}}$ of each.

For LLaMA-7B with 32 heads, this is a dramatic reduction. The cache shrinks from 2 GB (at 4K context in FP16) to $2{,}048 / 32 = 64$ MB. That's a 32x saving — enough to serve 32 times as many concurrent requests in the same GPU memory, or extend context length by 32x. The tradeoff is quality: with only one set of keys and values, the attention heads can no longer specialise their key-value representations. In practice, MQA shows a slight quality degradation compared to standard MHA, particularly on tasks requiring fine-grained multi-aspect reasoning.

Grouped-Query Attention (GQA) (Ainslie et al., 2023) finds the middle ground. Instead of all heads sharing one KV set (MQA) or each head having its own (MHA), GQA organises the $n_{\text{heads}}$ query heads into $n_{\text{kv\_heads}}$ groups, where each group shares one set of keys and values. Within a group, multiple query heads (specifically $n_{\text{heads}} / n_{\text{kv\_heads}}$ of them) attend to the same keys and values but with different query projections.

The cache formula becomes:

$$M_{\text{KV}} = 2 \times L \times n_{\text{kv\_heads}} \times d_{\text{head}} \times s \times b_{\text{precision}}$$

Notice that $n_{\text{kv\_heads}}$ has replaced $n_{\text{heads}}$ from the original formula. The boundary cases are revealing: when $n_{\text{kv\_heads}} = n_{\text{heads}}$, every query head has its own KV pair and we recover standard MHA. When $n_{\text{kv\_heads}} = 1$, all heads share one KV pair and we recover MQA. Any value in between gives us GQA, with the cache reduction factor being $n_{\text{heads}} / n_{\text{kv\_heads}}$.
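Mechanically, GQA only changes which cached key/value head each query head reads. A minimal sketch, assuming the common convention that consecutive query heads form a group (query head `h` reads KV head `h // (n_heads // n_kv_heads)`); the function names are hypothetical:

```python
import math

def kv_head_for(h: int, n_heads: int, n_kv_heads: int) -> int:
    """KV head read by query head h when consecutive query heads share a group."""
    assert n_heads % n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
    group_size = n_heads // n_kv_heads
    return h // group_size

def gqa_decode_step(queries, k_cache, v_cache):
    """One decode step of grouped-query attention.

    queries: n_heads vectors of length d_head (one per query head).
    k_cache, v_cache: n_kv_heads lists, each holding t cached vectors.
    Returns one output vector per query head.
    """
    n_heads, n_kv_heads = len(queries), len(k_cache)
    outputs = []
    for h, q in enumerate(queries):
        g = kv_head_for(h, n_heads, n_kv_heads)
        keys, values = k_cache[g], v_cache[g]      # shared by every head in group g
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        outputs.append([sum(wi * v[i] for wi, v in zip(w, values)) / z
                        for i in range(len(values[0]))])
    return outputs

# Boundary cases: n_kv_heads == n_heads is MHA, n_kv_heads == 1 is MQA.
print([kv_head_for(h, 8, 8) for h in range(8)])  # MHA:   [0, 1, 2, 3, 4, 5, 6, 7]
print([kv_head_for(h, 8, 2) for h in range(8)])  # GQA-2: [0, 0, 0, 0, 1, 1, 1, 1]
print([kv_head_for(h, 8, 1) for h in range(8)])  # MQA:   [0, 0, 0, 0, 0, 0, 0, 0]
```

The cache holds only `n_kv_heads` lists per layer, which is exactly where the $n_{\text{heads}} / n_{\text{kv\_heads}}$ saving comes from.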

LLaMA-2 70B uses GQA with $n_{\text{kv\_heads}} = 8$ and $n_{\text{heads}} = 64$, giving an 8x cache reduction compared to standard MHA. The Ainslie et al. paper showed that GQA with a moderate number of KV groups recovers nearly all the quality of full MHA while providing most of the memory savings of MQA. This has made GQA the de facto standard for modern large language models: the larger LLaMA-2 models, LLaMA-3, Mistral, Mixtral, and many others all use it.

The table below compares KV cache sizes for LLaMA-7B-scale models under MHA, GQA, and MQA, making the savings concrete:

import json

try:
    import js  # only available inside a Pyodide (in-browser) runtime
except ImportError:
    js = None

# LLaMA-7B scale: L=32, d_model=4096, n_heads=32, d_head=128
L = 32
d_head = 128
n_heads = 32
b_prec = 2  # bytes per value (FP16)

# Attention variant -> number of KV heads (the 32 query heads stay fixed)
configs = {
    "MHA (32 KV heads)": 32,
    "GQA-8 (8 KV heads)": 8,
    "GQA-4 (4 KV heads)": 4,
    "MQA (1 KV head)": 1,
}

seq_lengths = [2048, 4096, 8192, 32768, 131072]

headers = ["Attention Type", "Seq Len", "KV Cache (FP16)", "Reduction vs MHA"]
rows = []
for config_name, n_kv in configs.items():
    for s in seq_lengths:
        kv_bytes = 2 * L * n_kv * d_head * s * b_prec
        if kv_bytes < 1024**3:
            size_str = f"{kv_bytes / 1024**2:.0f} MB"
        else:
            size_str = f"{kv_bytes / 1024**3:.1f} GB"
        reduction = n_heads / n_kv  # cache shrinks by n_heads / n_kv_heads
        rows.append([config_name, str(s), size_str, f"{reduction:.0f}x"])

print("KV cache comparison: MHA vs GQA vs MQA (LLaMA-7B scale, FP16)")
print(f"Model: L={L}, d_head={d_head}, n_query_heads={n_heads}")
print()
print("GQA-8 gives 4x reduction; MQA gives 32x but with potential quality loss")

if js is not None:
    # In the browser, hand the table to the page for rendering.
    js.window.py_table_data = json.dumps({"headers": headers, "rows": rows})
else:
    # Outside the browser, print a plain-text table instead.
    for row in [headers] + rows:
        print(" | ".join(row))
💡 GQA is not just about memory: it also improves inference speed. Fewer KV heads mean less data to load from GPU memory during attention, and modern GPUs are often memory-bandwidth-bound during the decoding phase. The Ainslie et al. paper reports that GQA achieves quality close to MHA with inference speed close to MQA.

Sliding Window and Other Cache Strategies

GQA reduces the per-token cache cost, but the cache still grows linearly with sequence length. For very long sequences (32K, 128K, or beyond), even a GQA-compressed cache can exceed available memory. Sliding window attention takes a fundamentally different approach: instead of caching the entire sequence, it only attends to the most recent $w$ tokens. The cache size is bounded at $w$ entries per layer regardless of how long the total sequence gets.

Mistral 7B uses sliding window attention with $w = 4{,}096$. At any generation step $t$, the model only attends to tokens at positions $\max(1, t - w + 1)$ through $t$. The cache holds at most $w$ key-value pairs per layer, so the maximum cache size is:

$$M_{\text{KV}}^{\text{window}} = 2 \times L \times n_{\text{kv\_heads}} \times d_{\text{head}} \times w \times b_{\text{precision}}$$

Notice that $s$ (the ever-growing sequence length) has been replaced by $w$ (the fixed window size). For Mistral 7B ($L = 32$, $n_{\text{kv\_heads}} = 8$, $d_{\text{head}} = 128$, $w = 4096$), the cache is bounded at $2 \times 32 \times 8 \times 128 \times 4096 \times 2 = 512$ MB in FP16 — regardless of whether the total conversation is 4K, 32K, or 128K tokens long.
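A bounded cache like this behaves like a ring buffer. A minimal sketch using Python's `collections.deque` with `maxlen=w`, which drops the oldest entry automatically on append; the class `SlidingWindowKVCache` and the string stand-ins for key/value vectors are hypothetical, and real implementations typically preallocate a fixed tensor and overwrite slot $t \bmod w$ instead:

```python
from collections import deque

class SlidingWindowKVCache:
    """Per-layer KV cache that keeps only the most recent `window` positions."""

    def __init__(self, window: int):
        self.window = window
        self.keys = deque(maxlen=window)        # oldest entry evicted on append
        self.values = deque(maxlen=window)
        self.positions = deque(maxlen=window)

    def append(self, position: int, k, v) -> None:
        self.positions.append(position)
        self.keys.append(k)
        self.values.append(v)

    def visible_positions(self) -> list:
        return list(self.positions)


def window_cache_bytes(n_layers, n_kv_heads, d_head, window, bytes_per_value=2):
    # Upper bound on cache size: s is replaced by the fixed window w.
    return 2 * n_layers * n_kv_heads * d_head * window * bytes_per_value


cache = SlidingWindowKVCache(window=4)
for t in range(1, 11):                          # 10 tokens, 1-indexed positions
    cache.append(t, k=f"k{t}", v=f"v{t}")       # strings stand in for vectors
print(cache.visible_positions())                # [7, 8, 9, 10], i.e. max(1, t-w+1)..t
print(window_cache_bytes(32, 8, 128, 4096) / 1024**2, "MiB")  # 512.0 MiB
```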

The boundary cases clarify the design space: when $w = s$ (window equals the full sequence), sliding window reduces to standard full attention, where every token attends to every earlier token and there is no memory saving. When $w = 1$, the model has no context at all: each token only sees itself, making coherent generation impossible. The practical sweet spot lies in between: a window large enough to capture the relevant local context for most tasks, but small enough to keep memory bounded.

But doesn't throwing away old tokens lose important information? Not entirely. Information from tokens beyond the window can still propagate through layers. Consider a model with $L$ layers and a window size $w$. At layer 1, token $t$ sees positions $[t - w + 1, t]$. At layer 2, it attends to the layer-1 outputs at those positions, and each of those already summarises its own window; the output at position $t - w + 1$, for example, carries information from as far back as $t - 2w + 2$. Each additional layer extends the reach by another $w - 1$ positions, so after $L$ layers information from up to $L(w - 1) \approx L \times w$ tokens back can theoretically influence the current output, not through direct attention but by being baked into the intermediate representations. For Mistral 7B with $L = 32$ and $w = 4096$, the theoretical receptive field is roughly $32 \times 4096 = 131{,}072$ tokens. In practice the signal attenuates over many layers, but this is why sliding window attention works better than the window size alone might suggest.
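The stacking argument can be checked by brute force on small numbers. The sketch below tracks the set of input positions that can influence position $t$ after each layer (1-indexed positions, window $[p - w + 1, p]$ at every layer) and compares the earliest one with the closed form $t - l(w - 1)$; `earliest_reachable` is an illustrative helper:

```python
def earliest_reachable(t: int, w: int, n_layers: int) -> int:
    """Earliest input position whose information can reach position t
    after n_layers of causal sliding-window attention with window w."""
    reachable = {t}
    for _ in range(n_layers):
        # Each reachable position p attended to [max(1, p - w + 1), p] one layer down.
        reachable = {q for p in reachable for q in range(max(1, p - w + 1), p + 1)}
    return min(reachable)


t, w = 100, 4
for layers in range(1, 6):
    print(layers, earliest_reachable(t, w, layers), t - layers * (w - 1))

print("Mistral 7B reach:", 32 * (4096 - 1), "positions back")  # 131040, close to 32 x 4096
```

The simulated and closed-form columns agree until the reach is clipped at position 1.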

Beyond sliding windows, several other strategies help manage KV cache memory:

  • Paged KV cache (PagedAttention): rather than pre-allocating contiguous memory for each request's maximum possible sequence length, PagedAttention allocates cache memory in fixed-size pages (like an operating system's virtual memory). Pages are allocated on demand and freed when a request completes. This largely eliminates the fragmentation and over-reservation that waste 60-80% of KV cache memory in naive implementations; the only remaining waste is the unused tail of each request's last page. We cover this in depth in article 4 on continuous batching.
  • KV cache quantisation: store cached keys and values in INT8 (1 byte) or INT4 (0.5 bytes) instead of FP16 (2 bytes). This gives a 2-4x size reduction. INT8 caches typically have minimal quality impact; INT4 is more aggressive and tends to need more careful quantisation schemes to avoid degradation. Combined with 8x GQA, quantised KV caches can be 16-32x smaller than standard MHA caches in FP16.
  • Token eviction: identify and evict the least important cached tokens based on their attention scores. Tokens that consistently receive near-zero attention across recent steps are unlikely to be needed and can often be dropped with little quality loss. Approaches like H2O (Heavy-Hitter Oracle) keep a small window of recent tokens plus the "heavy hitters" that have accumulated the most attention, bounding cache size while preserving the most relevant context.
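The bookkeeping behind a paged cache is small. A sketch of a page allocator in the spirit of PagedAttention, with hypothetical names (`PagedKVAllocator`, `block_size`); it tracks only which physical pages each request owns, not the key/value tensors themselves, and is not vLLM's actual API:

```python
class PagedKVAllocator:
    """Hands out fixed-size KV pages on demand from a shared free pool."""

    def __init__(self, num_pages: int, block_size: int):
        self.block_size = block_size                # tokens per page
        self.free_pages = list(range(num_pages))    # physical page ids
        self.page_tables = {}                       # request id -> [page ids]
        self.lengths = {}                           # request id -> tokens stored

    def append_token(self, request_id: str) -> None:
        table = self.page_tables.setdefault(request_id, [])
        length = self.lengths.get(request_id, 0)
        if length % self.block_size == 0:           # no page yet, or last page full
            if not self.free_pages:
                raise MemoryError("out of KV cache pages")
            table.append(self.free_pages.pop())
        self.lengths[request_id] = length + 1

    def wasted_slots(self, request_id: str) -> int:
        # Only the unused tail of the request's last page is wasted.
        pages = len(self.page_tables[request_id])
        return pages * self.block_size - self.lengths[request_id]

    def free(self, request_id: str) -> None:
        # Return all of the request's pages to the shared pool.
        self.free_pages.extend(self.page_tables.pop(request_id))
        del self.lengths[request_id]


alloc = PagedKVAllocator(num_pages=8, block_size=16)
for _ in range(40):
    alloc.append_token("req-A")
print(len(alloc.page_tables["req-A"]), "pages,", alloc.wasted_slots("req-A"), "wasted slots")
alloc.free("req-A")
print(len(alloc.free_pages), "pages free")
```

A 40-token request occupies three 16-token pages and wastes only 8 slots, instead of holding a buffer sized for the maximum sequence length.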
💡 These techniques compose multiplicatively. GQA with 8 KV heads (an 8x reduction for a 64-head model) + sliding window at 4K (bounded size) + INT8 quantisation (2x reduction) + PagedAttention (eliminates fragmentation) can, at long context lengths, reduce effective KV cache memory by 50-100x compared to a naive MHA implementation with pre-allocated FP16 buffers. This is what makes it practical to serve long-context models to many concurrent users on commodity hardware.
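How large the combined saving is depends mostly on sequence length, since the window only helps once $s > w$. A back-of-the-envelope sketch, assuming a model with LLaMA-2 70B's shape (80 layers, 64 query heads of dimension 128) and ignoring fragmentation, which PagedAttention addresses separately; `kv_bytes` and `combined_reduction` are hypothetical helpers:

```python
def kv_bytes(n_layers, n_kv_heads, d_head, cached_tokens, bytes_per_value):
    # 2 (keys + values) x layers x KV heads x head dim x cached tokens x bytes
    return 2 * n_layers * n_kv_heads * d_head * cached_tokens * bytes_per_value


def combined_reduction(seq_len, window=4096):
    naive = kv_bytes(80, 64, 128, seq_len, 2)                  # MHA, FP16, full sequence
    optimised = kv_bytes(80, 8, 128, min(seq_len, window), 1)  # GQA-8, INT8, sliding window
    return naive / optimised


for s in (4096, 8192, 16384, 32768):
    print(f"seq_len {s:6d}: {combined_reduction(s):.0f}x smaller")
```

Under these assumptions the factor is 16x (8x from GQA times 2x from INT8) up to the window size, then grows linearly with $s / w$.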

Quiz

Test your understanding of the KV cache and its optimisations.

Why doesn't the KV cache need to be recomputed at each generation step?

For a model with $L = 40$ layers, $d_{\text{model}} = 5120$, and standard multi-head attention, at sequence length 4096 in FP16, what is the KV cache size per request?

In Grouped-Query Attention (GQA), what happens when $n_{\text{kv\_heads}} = 1$?

How can information from tokens outside the sliding window still influence the model's output?