Why Is Standard Attention So Slow?

Every technique we've covered so far (the KV cache, quantisation, continuous batching, speculative decoding) optimises around the attention mechanism without changing how attention itself is computed. But the standard way of computing attention is itself a bottleneck. Let's look at exactly why.

Standard scaled dot-product attention computes:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

Here $Q$, $K$, and $V$ are the query, key, and value matrices for a single attention head, each with $s$ rows (one per token in the sequence) and $d_k$ columns (the head dimension). The first operation, $QK^T$, multiplies an $s \times d_k$ matrix by a $d_k \times s$ matrix, producing an $s \times s$ matrix of raw attention scores. That's $s^2$ elements. We then scale by $1/\sqrt{d_k}$ and apply softmax row-wise to normalise the scores into a probability distribution, producing another $s \times s$ matrix. Finally, we multiply by $V$ (also $s \times d_k$) to get the output.

The $s^2$ scaling is the problem. At $s = 2{,}048$ (a short context by modern standards), the attention matrix has about 4 million elements, which is manageable. But at $s = 32{,}768$ (32K context, well within what many modern models support; LLaMA-3.1, for example, handles 128K), the attention matrix has $32{,}768^2 \approx 1.07 \times 10^9$ elements, over a billion entries. In FP32, that's roughly 4 GB of memory per head, per layer. A model with 32 heads and 32 layers would need $32 \times 32 \times 4 = 4{,}096$ GB just to store the attention matrices during prefill. That's obviously impossible on any single GPU.
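The arithmetic above is easy to check. Here is a minimal sketch that reproduces it, assuming (as the worst case in the text does) that every head in every layer holds its full matrix at the same time. The sizes come out in binary gigabytes (GiB):

```python
# Back-of-the-envelope memory for materialising s x s attention matrices.
# Worst case: every head in every layer holds its full matrix simultaneously.

GIB = 1024 ** 3


def attn_matrix_bytes(s: int, bytes_per_elem: int) -> int:
    """Bytes needed for one head's s x s score matrix."""
    return s * s * bytes_per_elem


s = 32_768
n_heads, n_layers = 32, 32

per_head_fp32 = attn_matrix_bytes(s, 4)
per_head_fp16 = attn_matrix_bytes(s, 2)
whole_model_fp32 = per_head_fp32 * n_heads * n_layers

print(f"one head, FP32: {per_head_fp32 / GIB:.0f} GiB")
print(f"one head, FP16: {per_head_fp16 / GIB:.0f} GiB")
print(f"{n_heads} heads x {n_layers} layers, FP32: {whole_model_fp32 / GIB:,.0f} GiB")
```

Switching to FP16 halves every figure, which is the first mitigation the next paragraph describes.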

In practice, frameworks use FP16/BF16 (halving the memory) and process heads sequentially or in small groups rather than all at once, so the peak memory is much less than the worst-case calculation above. But the fundamental issue remains: the attention matrix must be computed, stored somewhere, read back for softmax normalisation, written again as the normalised scores, and then read back once more to multiply by $V$. Each of these steps involves a round trip to HBM (high-bandwidth memory) — the GPU's main memory. The arithmetic itself is fast; what's slow is shuttling the enormous $s \times s$ matrix back and forth through the memory bus.

Let's quantify the memory traffic. For one attention head, the standard implementation performs roughly these HBM operations:

  • Read $Q$ and $K$ from HBM ($2 \times s \times d_k$ elements), compute $QK^T$, write the $s \times s$ result to HBM.
  • Read the $s \times s$ scores from HBM, compute softmax row-wise, write the $s \times s$ normalised scores back to HBM.
  • Read the $s \times s$ normalised scores and $V$ from HBM, compute the matrix multiply, write the $s \times d_k$ output to HBM.

That's two writes and two reads of the $s \times s$ matrix (plus $O(s \cdot d_k)$ traffic for $Q$, $K$, $V$, and the output), so HBM traffic grows as $O(s^2)$. The arithmetic (matrix multiplies and softmax) is $O(s^2 d_k)$ FLOPs, but for most of these steps the GPU can perform the math faster than it can move the data; the softmax step in particular performs only a few operations per element it reads and writes. The ratio of useful compute to memory traffic (the arithmetic intensity) is too low, so the GPU sits idle waiting for data, not for math. This is the quadratic bottleneck of transformers: both memory and compute scale as $O(s^2)$, but memory bandwidth, not FLOPs, is the binding constraint.
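To make "the arithmetic intensity is too low" concrete, here is a rough sketch that divides each step's FLOPs by its HBM bytes and compares the result with the GPU's ridge point (peak FLOP/s divided by memory bandwidth). The hardware figures (312 TFLOP/s and 2 TB/s, roughly an A100 80 GB) and the count of about 5 FLOPs per softmax element are assumptions for illustration, not measurements:

```python
# Rough arithmetic intensity (FLOPs per byte of HBM traffic) for each step of
# standard attention on one head. Hardware numbers are assumed, approximate values.

PEAK_FLOPS = 312e12            # assumed dense FP16/BF16 tensor-core peak, FLOP/s
HBM_BW = 2.0e12                # assumed HBM bandwidth, bytes/s
RIDGE = PEAK_FLOPS / HBM_BW    # kernels below this intensity are memory-bound


def intensities(s: int, d: int, bytes_per_elem: int = 2, softmax_flops_per_elem: int = 5):
    b = bytes_per_elem
    steps = {
        # QK^T: 2*s*s*d FLOPs; read Q and K (2*s*d), write S (s*s)
        "QK^T": (2 * s * s * d, (2 * s * d + s * s) * b),
        # softmax: a handful of FLOPs per element; read S, write P (2*s*s)
        "softmax": (softmax_flops_per_elem * s * s, 2 * s * s * b),
        # PV: 2*s*s*d FLOPs; read P (s*s) and V (s*d), write O (s*d)
        "PV": (2 * s * s * d, (s * s + 2 * s * d) * b),
    }
    return {name: flops / nbytes for name, (flops, nbytes) in steps.items()}


for name, ai in intensities(s=32_768, d=128).items():
    bound = "memory-bound" if ai < RIDGE else "compute-bound"
    print(f"{name:8s} {ai:7.2f} FLOP/byte -> {bound} (ridge ~{RIDGE:.0f})")
```

Under these assumptions the two matrix multiplies land just below the ridge point, while softmax sits about two orders of magnitude below it, so none of the steps can keep the arithmetic units busy.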

import json, js

seq_lengths = [512, 2048, 8192, 32768, 131072]
d_k = 128  # typical head dimension
bytes_per_elem = 2  # FP16

rows = []
for s in seq_lengths:
    attn_elements = s * s
    attn_mem_bytes = attn_elements * bytes_per_elem
    # HBM traffic (approx): write S, read S, write P, read P (4 * s*s elements),
    # plus reading Q, K, V and writing O once each (4 * s * d_k elements)
    io_mem_bytes = 4 * s * d_k * bytes_per_elem
    hbm_traffic = 4 * attn_mem_bytes + io_mem_bytes
    if attn_mem_bytes < 1024**2:
        attn_str = f"{attn_mem_bytes / 1024:.0f} KB"
    elif attn_mem_bytes < 1024**3:
        attn_str = f"{attn_mem_bytes / 1024**2:.1f} MB"
    else:
        attn_str = f"{attn_mem_bytes / 1024**3:.1f} GB"
    if hbm_traffic < 1024**3:
        hbm_str = f"{hbm_traffic / 1024**2:.1f} MB"
    else:
        hbm_str = f"{hbm_traffic / 1024**3:.1f} GB"
    rows.append([f"{s:,}", f"{attn_elements:,.0f}", attn_str, hbm_str])

js.window.py_table_data = json.dumps({
    "headers": ["Seq Length", "Attention Elements", "Attention Matrix (FP16)", "Approx HBM Traffic (1 head)"],
    "rows": rows
})

print("Attention matrix size and HBM traffic per head (FP16, d_k=128)")
print("At 32K context, a single head's attention matrix is ~2 GB")
print("At 128K context, it's ~32 GB — far exceeding any GPU's SRAM")
💡 The table above is for a single attention head. A model with $n_h = 32$ heads multiplies these numbers by 32 (though in practice, heads are processed sequentially or in small groups to avoid allocating all attention matrices simultaneously). The key takeaway is that even for moderate sequence lengths, the HBM traffic for the $s \times s$ attention matrix dominates the total data movement.

FlashAttention: Never Materialise the Attention Matrix

If the bottleneck is reading and writing the $s \times s$ attention matrix to HBM, the solution is conceptually simple: don't store it there at all. That's the core insight of FlashAttention (Dao et al., 2022), later refined in FlashAttention-2 (Dao, 2023). Instead of computing the full $s \times s$ matrix in HBM and then reading it back repeatedly, FlashAttention computes attention in small tiles that fit entirely in SRAM, the GPU's on-chip memory (roughly 20 MB in total on an A100, split across its 108 streaming multiprocessors). The full attention matrix never exists in HBM; only the final output $O = \text{softmax}(QK^T / \sqrt{d_k}) \, V$ is written back.

But there's a mathematical obstacle. Softmax is a global operation: to normalise row $i$ of the attention matrix, we need the row maximum $m_i = \max_j S_{ij}$ and the sum $\ell_i = \sum_j \exp(S_{ij} - m_i)$ across the entire row. If we're processing $K$ and $V$ in blocks (only seeing a chunk of each row at a time), how do we compute an exact softmax without seeing the full row at once? FlashAttention solves this with the online softmax trick: maintain a running maximum and a running sum of exponentials, and rescale the partial results whenever a new block raises the maximum. Crucially, this produces the same output as standard attention (up to floating-point rounding); it is not an approximation.
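Written out for a single row $i$, the running statistics obey a simple recurrence over key blocks $j = 1, \dots, T$. In this sketch, $m^{(j)}$ is the running maximum, $\ell^{(j)}$ the running sum of exponentials, $o^{(j)}$ an unnormalised output accumulator, and $S^{(j)}_{ik}$ and $V^{(j)}_{k}$ the scores and value rows in block $j$ (notation introduced here for illustration):

$$m^{(j)} = \max\!\left(m^{(j-1)},\ \max_k S^{(j)}_{ik}\right)$$

$$\ell^{(j)} = \exp\!\left(m^{(j-1)} - m^{(j)}\right) \ell^{(j-1)} + \sum_k \exp\!\left(S^{(j)}_{ik} - m^{(j)}\right)$$

$$o^{(j)} = \exp\!\left(m^{(j-1)} - m^{(j)}\right) o^{(j-1)} + \sum_k \exp\!\left(S^{(j)}_{ik} - m^{(j)}\right) V^{(j)}_{k}$$

starting from $m^{(0)} = -\infty$, $\ell^{(0)} = 0$, $o^{(0)} = 0$. After the last block, the output row is $O_i = o^{(T)} / \ell^{(T)}$. Every rescaling factor multiplies $\ell$ and $o$ alike, so the ratio does not depend on which block happened to contain the row maximum, and it equals the ordinary softmax-weighted sum of values.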

The algorithm works as follows:

  • Step 1: Tile the inputs. Divide $Q$ into blocks of $B_r$ rows and $K$, $V$ into blocks of $B_c$ rows. The block sizes are chosen so that the working set (one $Q$ block, one $K$ block, one $V$ block, and partial outputs) fits in SRAM.
  • Step 2: Outer loop over Q blocks. Load a block of $Q$ rows into SRAM. For this block, we will accumulate the final output.
  • Step 3: Inner loop over K/V blocks. For each $K$/$V$ block, load it into SRAM, compute the partial attention scores $S_{\text{block}} = Q_{\text{block}} K_{\text{block}}^T / \sqrt{d_k}$, update the running max and sum for the online softmax, compute partial softmax weights, multiply by $V_{\text{block}}$, and accumulate into the output. All of this happens in SRAM — no HBM writes for intermediate results.
  • Step 4: Write the final output. After iterating over all $K$/$V$ blocks, the accumulated output for this $Q$ block is complete and exact. Write it to HBM. Move on to the next $Q$ block.

The critical observation is that the $s \times s$ attention matrix is never fully materialised anywhere. Each tile of scores is computed in SRAM, used immediately for the softmax-weighted $V$ accumulation, and then discarded. The only data written to HBM is the final output $O$ (size $s \times d_k$), plus a small amount of bookkeeping (the per-row max and log-sum-exp values, needed for the backward pass).

What does this save? The IO complexity (total bytes read from and written to HBM) tells the story. Let $M$ be the size of SRAM in elements:

$$\text{Standard attention HBM IO} = O(s^2 + s \cdot d_k) = O(s^2)$$
$$\text{FlashAttention HBM IO} = O\!\left(\frac{s^2 \cdot d_k^2}{M}\right) \quad \text{for } d_k \le M \le s \cdot d_k$$

Where does this come from? The tiles must fit in SRAM, so a block of rows of width $d_k$ can have at most $\Theta(M / d_k)$ rows, giving $B_r = \Theta(M / d_k)$. For every $Q$ block, the inner loop streams all of $K$ and $V$ from HBM, which is $\Theta(s \cdot d_k)$ elements. There are $s / B_r = \Theta(s \cdot d_k / M)$ such blocks, so the total traffic is $\Theta(s \cdot d_k / M) \cdot \Theta(s \cdot d_k) = \Theta(s^2 d_k^2 / M)$. We still compute all $s^2$ scores exactly; the tiling only changes how often data crosses the memory bus. Compared with the $\Theta(s^2)$ of standard attention, the saving factor is roughly $M / d_k^2$: larger SRAM means larger tiles and fewer passes over $K$ and $V$.

The boundary cases are illuminating. If SRAM could hold all of $Q$, $K$, and $V$ at once ($M = s \cdot d_k$), the formula gives $O(s \cdot d_k)$: read the inputs once and write the output once, the minimum possible. At the other extreme, if SRAM held only about $d_k^2$ elements, the tiles would shrink to roughly $d_k$ rows and the formula gives $O(s^2)$, no better than standard attention. Real hardware sits in between. Each tile is processed by a thread block on a single streaming multiprocessor, so the relevant $M$ is the per-SM SRAM (up to about 164 KB of shared memory per SM on the A100), not the ~20 MB total; the FlashAttention paper works with $M$ of around 100 KB. With $d_k$ between 64 and 128, $d_k^2$ is 4,096 to 16,384 elements, several times smaller than $M$, so FlashAttention makes several times fewer HBM accesses. That reduction in memory traffic, not any reduction in arithmetic, is why it achieves 2-4x wall-clock speedups while performing essentially the same number of FLOPs.

Beyond speed, FlashAttention also solves the memory problem. Standard attention allocates $O(s^2)$ memory for the attention matrix. FlashAttention needs only $O(s)$ extra memory beyond its inputs and output: the per-row max and log-sum-exp statistics. At 32K context, that's the difference between allocating about a billion elements per head and a few tens of thousands.

💡 FlashAttention is now the default fast path in PyTorch: torch.nn.functional.scaled_dot_product_attention automatically dispatches to a FlashAttention kernel when the inputs, data type, and hardware allow it. HuggingFace Transformers can use it via the attn_implementation="flash_attention_2" flag (which requires the separate flash-attn package) or through its SDPA backend. Wherever it is supported, there is little reason not to use it: it computes the same attention as the standard implementation (differing only by floating-point rounding), uses less memory, and runs faster.

The code below demonstrates the key idea conceptually: processing attention in blocks with a running (online) softmax, compared to the standard full-matrix approach. Both produce the same output.

import math

# Simulate tiled attention with online softmax vs standard attention
# Small example: s=8 tokens, d_k=4, block_size=2

s, d_k, B = 8, 4, 2

# Deterministic "random" Q, K, V using simple formula
def make_matrix(rows, cols, seed):
    return [[math.sin(seed + i * cols + j) * 0.5
             for j in range(cols)] for i in range(rows)]

Q = make_matrix(s, d_k, 1.0)
K = make_matrix(s, d_k, 2.0)
V = make_matrix(s, d_k, 3.0)

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def matmul(A, B_T):
    # A[m][k] x B_T[n][k] -> C[m][n]
    return [[dot(A[i], B_T[j]) for j in range(len(B_T))] for i in range(len(A))]

# ── Standard attention (full s x s matrix) ──
scale = 1.0 / math.sqrt(d_k)
S = [[dot(Q[i], K[j]) * scale for j in range(s)] for i in range(s)]

# Softmax each row
O_standard = []
for i in range(s):
    row_max = max(S[i])
    exps = [math.exp(S[i][j] - row_max) for j in range(s)]
    row_sum = sum(exps)
    weights = [e / row_sum for e in exps]
    out = [sum(weights[j] * V[j][d] for j in range(s)) for d in range(d_k)]
    O_standard.append(out)

# ── Tiled attention with online softmax (FlashAttention-style) ──
O_tiled = [[0.0] * d_k for _ in range(s)]

for q_start in range(0, s, B):
    q_end = min(q_start + B, s)
    # Reset accumulators for this Q block
    local_max = [-float('inf')] * (q_end - q_start)
    local_sum = [0.0] * (q_end - q_start)
    local_out = [[0.0] * d_k for _ in range(q_end - q_start)]

    for k_start in range(0, s, B):
        k_end = min(k_start + B, s)
        for qi in range(q_end - q_start):
            i = q_start + qi
            scores = [dot(Q[i], K[j]) * scale for j in range(k_start, k_end)]
            block_max = max(scores)
            # Online softmax update
            old_max = local_max[qi]
            new_max = max(old_max, block_max)
            # Rescale previous accumulator
            correction = math.exp(old_max - new_max) if old_max != -float('inf') else 0.0
            local_sum[qi] = local_sum[qi] * correction
            for d in range(d_k):
                local_out[qi][d] *= correction
            # Add new block contribution
            for idx, j in enumerate(range(k_start, k_end)):
                w = math.exp(scores[idx] - new_max)
                local_sum[qi] += w
                for d in range(d_k):
                    local_out[qi][d] += w * V[j][d]
            local_max[qi] = new_max

    # Normalise and write output
    for qi in range(q_end - q_start):
        for d in range(d_k):
            O_tiled[q_start + qi][d] = local_out[qi][d] / local_sum[qi]

# Compare outputs
max_diff = max(abs(O_standard[i][d] - O_tiled[i][d])
               for i in range(s) for d in range(d_k))

print(f"Sequence length: {s}, Head dim: {d_k}, Block size: {B}")
print(f"Standard attention: full {s}x{s} = {s*s} element matrix in memory")
print(f"Tiled attention:    {B}x{B} = {B*B} element blocks (never stores full matrix)")
print(f"Max absolute difference: {max_diff:.2e}")
print(f"Outputs match: {max_diff < 1e-10}")

FlashDecoding: Optimising the Decode Phase

FlashAttention was designed for the prefill phase, where the query matrix $Q$ has many rows (one per prompt token). The tiling strategy parallelises over $Q$ blocks: each GPU thread block handles a different subset of query rows, and there's enough work to fill the GPU's streaming multiprocessors. But what happens during decode? As we discussed in article 1, the decode phase generates one token at a time, so $Q$ has exactly one row. There's only a single $Q$ "block" to process, so for a given sequence and head the entire FlashAttention outer loop runs in a single thread block. FlashAttention also parallelises across the batch and across heads, but at small batch sizes that still yields far fewer thread blocks than the GPU has SMs. In the extreme case of one sequence and one head on a GPU with 108 streaming multiprocessors (like the A100), 107 of them sit idle.

This underutilisation problem gets worse as context length grows. During decode with a 32K-token KV cache, the single thread block must iterate over all $32{,}768 / B_c$ key-value blocks sequentially (256 steps at $B_c = 128$). The work per block is small (a single query row dotted with $B_c$ keys), so the GPU performs hundreds of sequential steps with minimal parallelism. The result: during decode, FlashAttention uses only a small fraction of the GPU's available throughput, far less than it achieves during prefill.

FlashDecoding (introduced by Tri Dao and collaborators) fixes this by changing which dimension we parallelise over. Instead of parallelising over $Q$ rows (of which there is only one during decode), FlashDecoding parallelises over the KV sequence length dimension. The algorithm splits the KV cache into chunks along the sequence dimension and assigns each chunk to a separate thread block:

  • Step 1: Split the KV cache. Divide the $s$ cached key-value pairs into $P$ chunks of roughly $s / P$ tokens each. Each chunk is assigned to a separate GPU thread block.
  • Step 2: Compute partial attention per chunk. Each thread block loads the single query row and its assigned KV chunk into SRAM, computes the partial attention scores, runs a local softmax (tracking the local max $m_p$ and the local sum of exponentials $l_p$), and produces a partial output vector $O_p$: the softmax-weighted sum of values within that chunk, normalised by $l_p$.
  • Step 3: Global reduction. Combine the $P$ partial outputs into the final result. This requires correcting for the fact that each chunk computed softmax with a local (not global) maximum. The correction uses the same online softmax rescaling trick: compute the global max $m = \max_p m_p$, weight each partial output by $\exp(m_p - m) \cdot l_p$, and divide by the corrected global sum, giving $O = \sum_p \exp(m_p - m)\, l_p\, O_p \,/\, \sum_p \exp(m_p - m)\, l_p$.
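The three steps can be sketched in plain Python for a single query row. This is an illustration of the split-and-reduce arithmetic only, under the assumption that each call to the hypothetical `chunk_partial` stands in for one GPU thread block; here the chunks run sequentially, and the result is checked against ordinary softmax attention over the whole cache:

```python
import math
import random


def scores_for(q, keys):
    """Scaled dot-product scores of one query row against a list of key rows."""
    scale = 1.0 / math.sqrt(len(q))
    return [sum(a * b for a, b in zip(q, k)) * scale for k in keys]


def reference_attention(q, keys, values):
    """Standard attention for one query row: softmax over the whole cache at once."""
    s = scores_for(q, keys)
    m = max(s)
    w = [math.exp(x - m) for x in s]
    total = sum(w)
    probs = [x / total for x in w]
    return [sum(p * v[c] for p, v in zip(probs, values)) for c in range(len(values[0]))]


def chunk_partial(q, keys, values):
    """Step 2 for one chunk: local max m_p, local sum l_p, and partial output O_p
    normalised by l_p. On a GPU, each chunk would be a separate thread block."""
    s = scores_for(q, keys)
    m_p = max(s)
    w = [math.exp(x - m_p) for x in s]
    l_p = sum(w)
    o_p = [sum(wi * v[c] for wi, v in zip(w, values)) / l_p for c in range(len(values[0]))]
    return m_p, l_p, o_p


def flash_decode(q, keys, values, num_chunks):
    """Steps 1-3: split the KV cache, compute per-chunk partials, then reduce."""
    size = math.ceil(len(keys) / num_chunks)                     # step 1: split
    partials = [chunk_partial(q, keys[i:i + size], values[i:i + size])
                for i in range(0, len(keys), size)]             # step 2: partials
    m = max(m_p for m_p, _, _ in partials)                       # step 3: global max
    weights = [math.exp(m_p - m) * l_p for m_p, l_p, _ in partials]
    total = sum(weights)                                         # corrected global sum
    dim = len(values[0])
    return [sum(w * o_p[c] for w, (_, _, o_p) in zip(weights, partials)) / total
            for c in range(dim)]


rng = random.Random(0)
d, s = 8, 1000
q = [rng.gauss(0, 1) for _ in range(d)]
K = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(s)]
V = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(s)]

ref = reference_attention(q, K, V)
for P in (1, 16, 108):
    out = flash_decode(q, K, V, num_chunks=P)
    print(P, max(abs(a - b) for a, b in zip(ref, out)))
```

The differences printed are at floating-point rounding level for every chunk count, which is the exactness property the text claims: the split only changes the order of the additions, not the result.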

The reduction step adds a small overhead (combining $P$ partial results, each of dimension $d_k$), but $P$ is typically a few hundred at most, and $d_k$ is typically 128, so the cost is negligible compared to the attention computation itself. The critical benefit is that all $P$ thread blocks run in parallel, fully utilising the GPU. If we have 108 streaming multiprocessors and choose $P = 108$, every SM is busy — a dramatic improvement over the single-block scenario.

How much does this help in practice? The speedup depends on context length. For short contexts (a few hundred tokens), standard FlashAttention's single thread block can process all KV blocks quickly enough that the parallelism overhead of FlashDecoding isn't worth it. But for long contexts — 8K, 32K, 128K tokens — FlashDecoding provides substantial speedups (up to 8x for very long sequences) because it converts a sequential scan over thousands of KV blocks into a parallel operation across the entire GPU.

This matters because decode is the bottleneck for end-to-end latency in most serving scenarios. As we discussed in article 1, the decode phase generates tokens one at a time and dominates wall-clock time for any response longer than a few tokens. Faster decode-phase attention reduces the time per output token, which in turn reduces the latency users perceive.

💡 FlashDecoding is conceptually the same idea as FlashAttention (tiled, IO-aware, never materialise the full attention matrix), just with the parallelism axis rotated from Q rows to KV sequence positions. Both produce exact results. The distinction matters only during decode, when Q has a single row; during prefill, standard FlashAttention already parallelises effectively over Q rows.

Multi-Head Latent Attention (MLA)

In article 2, we saw how Grouped-Query Attention (GQA) reduces the KV cache by sharing key-value projections across groups of query heads. GQA trades some representational capacity for memory savings, and it works well: LLaMA-2 70B, LLaMA-3, Mistral, and most modern models use it. But there's a fundamentally different approach: instead of sharing K and V across heads, compress them into a low-rank latent space. That's Multi-Head Latent Attention (MLA), introduced in the DeepSeek-V2 architecture (DeepSeek-AI, 2024).

In standard multi-head attention (MHA), the KV cache stores separate key and value vectors for every head at every token position. For a model with $n_h$ heads and head dimension $d_h$, that's $2 \times n_h \times d_h$ values per token per layer (the factor of 2 for keys and values). With GQA using $n_{\text{kv}}$ key-value groups, this drops to $2 \times n_{\text{kv}} \times d_h$. MLA takes a completely different path: instead of caching the full key and value vectors, it caches a single compressed vector $c_t \in \mathbb{R}^{d_c}$ per token, where $d_c$ is the latent dimension and $d_c \ll n_h \times d_h$.

The compression works through a pair of learned projections. When token $t$ is first processed, instead of computing and caching $K_t$ and $V_t$ directly, MLA computes a down-projection:

$$c_t = x_t \, W^{\text{down}} \quad \text{where} \quad W^{\text{down}} \in \mathbb{R}^{d_{\text{model}} \times d_c}$$

This compressed representation $c_t$ is what gets stored in the KV cache — just $d_c$ values per token instead of $2 \times n_h \times d_h$. At attention time, the keys and values are reconstructed on the fly via up-projections:

$$K_t = c_t \, W_K^{\text{up}}, \quad V_t = c_t \, W_V^{\text{up}}$$

where $W_K^{\text{up}} \in \mathbb{R}^{d_c \times (n_h \cdot d_h)}$ and $W_V^{\text{up}} \in \mathbb{R}^{d_c \times (n_h \cdot d_h)}$ are learned matrices that decompress the latent back into full key and value vectors for all heads.
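A toy sketch of the cache path follows, using made-up sizes and random matrices standing in for the learned $W^{\text{down}}$, $W_K^{\text{up}}$, and $W_V^{\text{up}}$. It ignores rotary position embeddings and every other detail of DeepSeek-V2; the point is only that the cache holds the latents $c_t$, and full keys and values are rebuilt from them on demand:

```python
import random


def matvec(x, W):
    """Row vector x (length n) times matrix W (n x m) -> row vector (length m)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]


def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0, 0.1) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
d_model, n_h, d_h, d_c = 64, 4, 16, 8       # toy sizes; d_c << n_h * d_h = 64

W_down = rand_matrix(d_model, d_c, rng)     # d_model x d_c (stand-in for learned weights)
W_K_up = rand_matrix(d_c, n_h * d_h, rng)   # d_c x (n_h * d_h)
W_V_up = rand_matrix(d_c, n_h * d_h, rng)

kv_cache = []                                # holds only the latents c_t
for t in range(10):                          # ten toy tokens
    x_t = [rng.gauss(0, 1) for _ in range(d_model)]
    kv_cache.append(matvec(x_t, W_down))     # c_t = x_t W_down

# At attention time, reconstruct keys and values for every cached token.
K = [matvec(c_t, W_K_up) for c_t in kv_cache]   # each row: n_h * d_h values
V = [matvec(c_t, W_V_up) for c_t in kv_cache]
head0_keys = [k[:d_h] for k in K]               # head 0's slice of each key

cached = sum(len(c) for c in kv_cache)
mha_equiv = len(kv_cache) * 2 * n_h * d_h
print(f"cached values: {cached}, MHA would cache: {mha_equiv}, "
      f"ratio {mha_equiv / cached:.0f}x")
```

With these toy sizes the ratio is $2 n_h d_h / d_c = 16$; the same formula gives the DeepSeek-V2 figure discussed next.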

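The KV cache savings are dramatic. Let's work through the DeepSeek-V2 numbers. The model has $n_h = 128$ attention heads with $d_h = 128$, so standard MHA would cache $2 \times 128 \times 128 = 32{,}768$ values per token per layer. MLA uses $d_c = 512$, caching just 512 latent values per token per layer, a 64x reduction. In practice DeepSeek-V2 also caches a small decoupled key of dimension 64 per token to carry rotary position information (RoPE), so the true total is 576 values: still roughly a 57x reduction in KV cache size, far beyond what GQA achieves.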

import json, js

# Compare KV cache: MHA vs GQA vs MLA for DeepSeek-V2 scale
n_h = 128       # query heads
d_h = 128       # head dimension
n_kv_gqa = 8    # GQA groups (hypothetical)
d_c = 512       # MLA latent dimension
d_rope = 64     # DeepSeek-V2 also caches a small decoupled RoPE key per token
L = 60          # layers
b_prec = 2      # FP16 bytes

configs = {
    "MHA": 2 * n_h * d_h,                      # full K + V per head
    "GQA (8 groups)": 2 * n_kv_gqa * d_h,      # K + V per group
    "MLA (d_c=512, +64 RoPE)": d_c + d_rope,   # compressed latent + RoPE key
}

seq_lengths = [2048, 8192, 32768, 131072]

rows = []
for name, per_token_vals in configs.items():
    for s in seq_lengths:
        cache_bytes = L * per_token_vals * s * b_prec
        if cache_bytes < 1024**3:
            size_str = f"{cache_bytes / 1024**2:.0f} MB"
        else:
            size_str = f"{cache_bytes / 1024**3:.1f} GB"
        rows.append([name, str(per_token_vals), f"{s:,}", size_str])

js.window.py_table_data = json.dumps({
    "headers": ["Attention Type", "Values/Token/Layer", "Seq Length", f"KV Cache (FP16, {L} layers)"],
    "rows": rows
})


def cache_gb(per_token_vals, s=131072):
    """KV cache size in GB for all L layers at sequence length s."""
    return L * per_token_vals * s * b_prec / 1024**3


mha_per_token = configs["MHA"]
mla_per_token = d_c + d_rope
print(f"MHA caches {mha_per_token:,} values per token per layer")
print(f"MLA caches {mla_per_token:,} values per token per layer ({d_c} latent + {d_rope} RoPE key)")
print(f"Reduction factor: {mha_per_token / mla_per_token:.0f}x")
print(f"At 128K context, {L} layers: MHA needs ~{cache_gb(mha_per_token):.0f} GB, MLA needs ~{cache_gb(mla_per_token):.1f} GB")

The obvious question: doesn't the decompression cost eat into the savings? The up-projections $W_K^{\text{up}}$ and $W_V^{\text{up}}$ multiply a $d_c$-dimensional vector by a large matrix to reconstruct the full keys and values. That's additional FLOPs. But here's the crucial insight: during decode, the model is memory-bandwidth-bound, not compute-bound (article 1). The GPU's arithmetic units are largely idle, waiting for data to arrive from HBM, so adding compute that reduces memory traffic is essentially free: the GPU does the extra matrix multiplies during time it would otherwise spend waiting. The up-projection matrices $W_K^{\text{up}}$ and $W_V^{\text{up}}$ are model weights, read once per layer per decode step no matter how many tokens are in the cache, whereas cache traffic grows with context length, and the compressed cache of $c_t$ vectors is far smaller than the full KV cache would be. The net effect is less total memory traffic, even though we're doing more arithmetic.

Where compute is the constraint, as in prefill, the decompression does add measurable cost. The DeepSeek-V2 paper points out that the up-projections can be folded away algebraically at inference time: $W_K^{\text{up}}$ can be absorbed into the query projection and $W_V^{\text{up}}$ into the output projection, so attention scores and outputs are computed directly from the cached latents $c_t$. The math is equivalent, but the implementation never materialises the full-size $K$ and $V$ tensors.
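The key-side absorption rests on a one-line identity: $q_h \cdot (c_j W_{K,h}^{\text{up}}) = (q_h (W_{K,h}^{\text{up}})^T) \cdot c_j$, where $W_{K,h}^{\text{up}}$ is the $d_c \times d_h$ slice of the up-projection belonging to head $h$. A toy check with random matrices (hypothetical sizes, no RoPE; the DeepSeek-V2 paper adds the decoupled RoPE key precisely because rotary embeddings would break this identity):

```python
import random

rng = random.Random(1)
d_c, d_h, n_tokens = 8, 16, 5   # toy sizes, chosen for illustration

# Hypothetical per-head key up-projection block: latent (d_c) -> key (d_h)
W_UK_h = [[rng.gauss(0, 0.1) for _ in range(d_h)] for _ in range(d_c)]
cache = [[rng.gauss(0, 1) for _ in range(d_c)] for _ in range(n_tokens)]  # latents c_j
q_h = [rng.gauss(0, 1) for _ in range(d_h)]                               # this head's query

# Route 1: decompress every cached key (k_j = c_j W_UK_h), then dot with the query.
keys = [[sum(c[a] * W_UK_h[a][b] for a in range(d_c)) for b in range(d_h)] for c in cache]
explicit = [sum(qb * kb for qb, kb in zip(q_h, k)) for k in keys]

# Route 2 (absorbed): project the query into latent space once (q_lat = q_h W_UK_h^T)
# and score it directly against each cached latent; no d_h-sized key is built.
q_lat = [sum(q_h[b] * W_UK_h[a][b] for b in range(d_h)) for a in range(d_c)]
absorbed = [sum(qa * ca for qa, ca in zip(q_lat, c)) for c in cache]

print(max(abs(x - y) for x, y in zip(explicit, absorbed)))
```

In a real model $q_h$ is itself produced by a learned query projection, so $(W_{K,h}^{\text{up}})^T$ can be multiplied into that projection's weights ahead of time; that is what "absorbed into the query projection" means. The value side works the same way with the output projection.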

💡 MLA can be seen as applying the idea of low-rank factorisation (the same principle behind LoRA for fine-tuning) to the KV cache. Instead of storing the full-rank key and value matrices, we store a low-rank bottleneck and reconstruct on the fly. The trade-off — more FLOPs, less memory — is exactly the right one for memory-bound inference.

Choosing Attention Optimisations

We've covered four attention optimisations that target different aspects of the problem. How do they relate to each other, and when should you use each one?

FlashAttention is close to universal. It computes exact attention, uses less memory, and runs faster than standard attention in practically every setting where its kernels are supported; the main constraints are GPU architecture, data type, and head dimension. If your framework supports it (PyTorch 2.0+, HuggingFace Transformers, and virtually every modern inference engine do), it should be enabled by default. FlashAttention optimises how attention is computed (IO-aware tiling) without changing what is computed.

Grouped-Query Attention (GQA) is an architectural decision made at training time. A model must either be trained with GQA from the start or converted by up-training: Ainslie et al. (2023) mean-pool the key and value heads within each group and then continue training for a small fraction (about 5%) of the original pre-training compute. You cannot simply switch an MHA model to GQA at inference time without that extra training. GQA reduces the KV cache by sharing key-value projections across groups of query heads, typically giving 4-8x cache reduction with negligible quality loss. LLaMA-2 70B, LLaMA-3, Mistral 7B, Mixtral, Gemma 2, and most post-2023 models use GQA.

Multi-Head Latent Attention (MLA) is also an architectural decision, and a more recent one. It achieves far more aggressive KV cache compression than GQA (roughly 57x for DeepSeek-V2, versus 4-8x for typical GQA configurations) by learning a low-rank latent representation. MLA is currently used by the DeepSeek-V2 and DeepSeek-V3 model families. It adds compute for decompression, but during decode this is more than offset by the massive memory savings. As more architectures adopt MLA (or similar compression strategies), it may become as standard as GQA is today.

FlashDecoding addresses a specific bottleneck: GPU underutilisation during the decode phase with long contexts. It's a runtime optimisation (like FlashAttention), not an architectural choice, and applies to any model. It provides the largest speedups for long-context decode — exactly the scenario where attention computation is most expensive relative to the rest of the model. For short contexts, the benefit is minimal.

Crucially, these techniques compose. FlashAttention handles the IO-efficient tiling. GQA or MLA reduces the size of what's being cached and loaded. FlashDecoding parallelises the decode-phase computation. A modern inference stack typically combines all applicable techniques. Serving LLaMA-3 70B, for example, might use FlashAttention (IO-aware tiling) + GQA (8 KV heads shared by 64 query heads, reducing the cache by 8x) + FlashDecoding (parallel decode over the KV sequence dimension) + PagedAttention from vLLM (efficient cache memory management) + INT8 KV cache quantisation. Each technique targets a different part of the problem, so their benefits largely compound rather than overlap.
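To see how two of these compound on cache size, here is a per-token KV cache calculation for a model shaped like LLaMA-3 70B (80 layers, 64 query heads, 8 KV heads, head dimension 128). The MHA row is hypothetical (the real model uses GQA), and the INT8 row ignores the small per-block scale factors that practical quantisation schemes also store:

```python
# Per-token KV cache for a LLaMA-3-70B-shaped model, showing how GQA and
# KV quantisation multiply together. The MHA row is a hypothetical baseline.


def kv_bytes_per_token(layers, kv_heads, head_dim, bytes_per_elem):
    return 2 * layers * kv_heads * head_dim * bytes_per_elem  # 2 = keys + values


layers, q_heads, kv_heads, head_dim = 80, 64, 8, 128
configs = {
    "MHA, FP16": kv_bytes_per_token(layers, q_heads, head_dim, 2),
    "GQA, FP16": kv_bytes_per_token(layers, kv_heads, head_dim, 2),
    "GQA, INT8": kv_bytes_per_token(layers, kv_heads, head_dim, 1),  # scales ignored
}
base = configs["MHA, FP16"]
for name, b in configs.items():
    print(f"{name}: {b / 1024:,.0f} KiB/token ({base / b:.0f}x reduction vs MHA)")
```

GQA's 8x and INT8's 2x multiply to 16x on cache size. FlashAttention, FlashDecoding, and PagedAttention leave this number unchanged; they act on how efficiently the cache is read and managed.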

import json, js

rows = [
    ["FlashAttention",
     "Runtime (IO-aware kernel)",
     "Any model",
     "2-4x speed, O(s) memory",
     "Minimal (exact; needs supported GPU/dtype)"],
    ["FlashDecoding",
     "Runtime (decode parallelism)",
     "Any model, long-context decode",
     "Up to 8x decode speedup",
     "Little benefit for short contexts"],
    ["GQA",
     "Architecture (training time)",
     "LLaMA-2 70B, LLaMA-3, Mistral, Gemma 2, etc.",
     "4-8x KV cache reduction",
     "Slight quality loss vs MHA"],
    ["MLA",
     "Architecture (training time)",
     "DeepSeek-V2/V3",
     "~57x KV cache reduction (DeepSeek-V2)",
     "Extra decompression FLOPs"],
]

js.window.py_table_data = json.dumps({
    "headers": ["Technique", "Type", "Applies To", "Benefit", "Trade-off"],
    "rows": rows
})

print("Attention optimisation summary")
print("FlashAttention + GQA/MLA + FlashDecoding is the standard stack")
print("All compose: each targets a different bottleneck")
💡 A useful mental model: FlashAttention reduces the cost of computing attention on whatever data the model gives it. GQA and MLA reduce the amount of data that needs to be stored and loaded. FlashDecoding ensures the GPU is fully utilised when computing attention during decode. Together, they attack the attention bottleneck from three complementary angles: IO efficiency, data volume, and parallelism.

Quiz

Test your understanding of attention optimisation techniques.

Why is standard attention slow despite modern GPUs having enormous compute throughput?

What is the key property that makes FlashAttention different from approximate attention methods?

Why does FlashAttention underperform during the decode phase, motivating FlashDecoding?

In Multi-Head Latent Attention (MLA), why is the extra compute for decompressing keys and values during decode essentially free?