Does Every Token Need to See Every Other Token?
Standard full (dense) attention computes a score between every pair of tokens in the sequence. For a sequence of $n$ tokens, that means $n^2$ attention scores, $n^2$ entries in the attention matrix, and $O(n^2)$ time and memory. When $n = 4{,}096$, that's about 16 million scores. At $n = 128{,}000$ (a Llama 3-class context), it's over 16 billion.
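A quick sanity check of the quadratic growth (plain Python, nothing model-specific):

```python
# Dense attention computes n^2 pairwise scores
for n in (4_096, 128_000):
    print(f"n = {n:>7,}: {n * n:>14,} attention scores")
```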
But is all that compute actually useful? Empirically, most attention weights concentrate on a small subset of positions: nearby tokens (local context) and a handful of globally important ones (beginning-of-sequence, punctuation, task-specific anchors). The vast majority of the $n^2$ scores are near zero and contribute almost nothing to the output. If we could skip computing those near-zero scores, we would save most of the work without meaningfully changing the result.
That's the core idea behind sparse attention patterns: instead of letting every token attend to every other token, we define a pattern that selects which (query, key) pairs are actually computed. The rest are treated as if their score were $-\infty$ (masked out before the softmax). The savings depend on the pattern: the sparser the mask, the fewer scores we compute, and the faster (and more memory-efficient) the layer becomes.
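Masking before the softmax can be sketched in a few lines. This is an illustrative toy for a single query; real kernels apply the mask to score tensors and skip masked blocks entirely:

```python
import math

def masked_softmax(scores, mask):
    # Positions with mask[j] == False are treated as score -inf:
    # exp(-inf) = 0, so they receive exactly zero attention weight.
    masked = [s if keep else float("-inf") for s, keep in zip(scores, mask)]
    mx = max(masked)
    exps = [math.exp(s - mx) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]

scores = [2.0, 0.1, -1.0, 3.0]     # raw q.k scores for one query
mask = [True, False, False, True]  # only two (query, key) pairs are in the pattern
weights = masked_softmax(scores, mask)
print(weights)  # masked positions get weight exactly 0.0
```

Note that the masked positions contribute nothing to the normaliser, so the surviving weights still sum to 1.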
Sliding Window Attention
The simplest sparse pattern is the sliding window: each token only attends to the $w$ most recent tokens (including itself). If the token is at position $i$, it sees positions $\max(0, \, i - w + 1)$ through $i$. Everything before that window is masked out.
Let's check this at the boundaries. When $w = n$ (window equals the full sequence), every token sees every other token and we recover full attention at $O(n^2)$. When $w = 1$, each token sees only itself and attention degenerates to a pointwise operation at $O(n)$. In practice, models like Mistral 7B use $w = 4{,}096$: each token attends to the last 4,096 positions, giving $O(4096 \cdot n)$ cost per layer, which is linear in $n$.
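The boundary cases can be made concrete with a toy mask builder (a sketch only; real kernels never materialise the full $n \times n$ mask). Note that in the causal setting, $w = n$ recovers the full causal mask with $n(n+1)/2$ scores:

```python
def sliding_window_mask(n, w):
    # mask[i][j] is True iff query i attends to key j (causal window of size w)
    return [[max(0, i - w + 1) <= j <= i for j in range(n)] for i in range(n)]

n = 8

def count(w):
    return sum(sum(row) for row in sliding_window_mask(n, w))

print(f"w = n = {n}: {count(n)} scores (full causal attention, n(n+1)/2)")
print(f"w = 1:     {count(1)} scores (each token sees only itself)")
print(f"w = 3:     {count(3)} scores (roughly n*w once i >= w)")
```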
But if each layer can only see $w$ tokens back, how does the model handle long-range dependencies? Through information propagation across layers. After one layer, token $i$ has information about tokens $[i - w + 1, \, i]$. After two layers, token $i$ has indirect information about tokens $[i - 2w + 2, \, i]$, because the tokens in its window already attended to their own windows in the previous layer. After $L$ layers, the effective receptive field is:

$$\text{receptive field} = L \times (w - 1) + 1 \approx L \times w$$
For Mistral 7B ($L = 32$, $w = 4{,}096$): $32 \times 4{,}096 = 131{,}072$ tokens. At $L = 1$, a token can only access $w = 4{,}096$ positions. At $L = 32$, information from up to 131K tokens back can influence the output, though it is increasingly attenuated with each hop (each layer mixes but also dilutes signals). This is why sliding window models can handle contexts far beyond $w$ at the cost of reduced fidelity for very distant tokens.
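The linear growth of the receptive field with depth is easy to tabulate, using the Mistral-style numbers from above:

```python
w = 4_096  # window size
for L in (1, 2, 4, 8, 16, 32):
    print(f"{L:2d} layers -> receptive field ~ {L * w:>7,} tokens")
```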
A major practical benefit is that the KV cache is bounded. Instead of caching all $n$ past key-value pairs per layer (which grows linearly with generated tokens), we only need to store the last $w$ entries. For Mistral with $w = 4{,}096$, the KV cache is fixed at 4,096 entries per layer regardless of total sequence length. This is implemented as a rolling buffer (circular buffer): new entries overwrite the oldest ones at position $i \bmod w$. No memory reallocation, no growing cache. For an in-depth discussion of KV cache management and sliding window inference, see our KV cache article.
```python
# Sliding window attention: rolling buffer KV cache
w = 8  # window size

class RollingKVCache:
    def __init__(self, window_size):
        self.w = window_size
        self.buffer = [None] * window_size
        self.count = 0

    def insert(self, kv_entry):
        pos = self.count % self.w  # circular index
        evicted = self.buffer[pos]
        self.buffer[pos] = kv_entry
        self.count += 1
        return pos, evicted

    def active_entries(self):
        return [e for e in self.buffer if e is not None]

cache = RollingKVCache(window_size=w)

# Simulate inserting 12 tokens into a window of size 8
for token_id in range(12):
    pos, evicted = cache.insert(f"tok_{token_id}")
    if evicted:
        print(f"Step {token_id:2d}: insert tok_{token_id} at slot {pos}, evicted {evicted}")
    else:
        print(f"Step {token_id:2d}: insert tok_{token_id} at slot {pos}")

print(f"\nFinal buffer (window={w}): {cache.buffer}")
print(f"Tokens 0-3 were evicted; only the last {w} remain.")
```
The limitation is clear: there is no direct attention path between distant tokens. A token at position 50,000 cannot directly attend to a token at position 0. Information must hop through intermediate representations, layer by layer. Each hop attenuates the signal, so while the theoretical receptive field is $L \times w$, the practical influence of very distant tokens is weak. Patterns like global tokens (covered next) address this.
Longformer and BigBird: Structured Sparsity
Sliding windows work well for local context, but some tasks demand genuine long-range connections: classifying a legal document based on a clause buried thousands of tokens away from the [CLS] token, or answering a question whose evidence spans multiple paragraphs. Two influential papers introduced structured sparse patterns that combine local and global attention.
Longformer (Beltagy et al., 2020) combines three attention patterns in a single layer:
- Sliding window: each token attends to $w$ neighbors on each side. This captures local syntactic and semantic context, just like the pattern we described above.
- Dilated sliding window: instead of attending to $w$ contiguous neighbors, attend to every 2nd (or $d$-th) token within a larger range. This is analogous to dilated convolutions in CNNs: by skipping tokens, the window covers a broader span ($w \times d$ positions) at the same computational cost as a standard window of size $w$.
- Global attention: a small set of designated tokens (e.g., the [CLS] token, or the first token of each paragraph) attend to all positions, and all positions attend back to them. If we have $g$ global tokens, this adds $O(g \cdot n)$ scores.
The global tokens are the key insight. They act as information bottlenecks that create shortcuts for long-range information flow. Without them, a signal from position 0 must hop through $\lceil n/w \rceil$ layers to reach the last position. With a single global token, any position can reach any other in at most 2 hops (position $\rightarrow$ global token $\rightarrow$ target position). Since $g$ is small (often just 1 or 2), the total cost remains $O(n \cdot w + g \cdot n) = O(n \cdot (w + g))$, which is linear in $n$.
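The hop-count argument amounts to two lines of arithmetic (illustrative numbers):

```python
import math

n, w = 131_072, 4_096
hops_window_only = math.ceil(n / w)  # signal crosses ~w positions per layer
hops_with_global = 2                 # source -> global token -> target
print(hops_window_only, hops_with_global)  # 32 vs 2
```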
BigBird (Zaheer et al., 2020) takes a different approach to ensuring connectivity. It combines:
- Sliding window: same as Longformer, for local context.
- Global tokens: same idea, a few tokens attend to/from all positions.
- Random attention: each token additionally attends to $r$ randomly chosen positions. This is the novel component. The random edges ensure the attention graph is well-connected: with high probability, any two tokens are separated by a short path (logarithmic in $n$) through the graph.
BigBird's theoretical contribution is proving that this combination of local + global + random attention is a universal approximator : it can approximate any function that full attention can compute, provided $g$ and $r$ are set appropriately. The random edges are critical to this result. From graph theory, a graph with local connections plus a few random long-range edges is an expander graph with high probability, meaning any two nodes are connected by a short path. In attention terms: any two tokens can exchange information through just a few hops of attention.
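The expander-graph claim can be checked empirically with a small BFS experiment. This is a sketch under simplifying assumptions: an undirected attention graph with a local window plus $r$ random edges per node, not BigBird's actual block-sparse implementation:

```python
import random
from collections import deque

def bfs_dist(adj, src, dst):
    """Number of attention hops from src to dst over the graph."""
    dist = {src: 0}
    q = deque([src])
    while q:
        u = q.popleft()
        if u == dst:
            return dist[u]
        for v in adj[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                q.append(v)
    return None

def build_graph(n, w, r, seed=0):
    rng = random.Random(seed)
    adj = [set() for _ in range(n)]
    for i in range(n):
        for j in range(max(0, i - w), min(n, i + w + 1)):  # local window
            if j != i:
                adj[i].add(j); adj[j].add(i)
        for j in rng.sample(range(n), r):                  # r random long-range edges
            if j != i:
                adj[i].add(j); adj[j].add(i)
    return adj

n, w = 512, 2
d_local = bfs_dist(build_graph(n, w, r=0), 0, n - 1)
d_random = bfs_dist(build_graph(n, w, r=2), 0, n - 1)
print(f"local only:       {d_local} hops")   # ~n/(2w): the signal crawls
print(f"local + 2 random: {d_random} hops")  # random shortcuts cut this sharply
```

The random edges can only shorten paths (the local edges are still present), and in practice they collapse the end-to-end distance from hundreds of hops to a handful.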
Both models achieve $O(n)$ complexity (with constants depending on window size $w$, number of global tokens $g$, and random edges $r$). To see why, count the attention edges per token: $w$ from the sliding window, $g$ from global tokens (constant), and $r$ from random attention (constant). The total per token is $w + g + r$, all constants independent of $n$, so the total across all $n$ tokens is $O(n \cdot (w + g + r)) = O(n)$.
```python
# Count the attention scores in BigBird's three patterns (local + global + random)
import random

n = 16        # sequence length (small for illustration)
w = 3         # window half-size
g_idxs = [0]  # global token indices
r = 2         # random edges per token

random.seed(42)

def make_mask(n, w, g_idxs, r):
    """Build BigBird-style attention mask: local + global + random."""
    mask = [[0] * n for _ in range(n)]
    local_count = 0
    global_count = 0
    random_count = 0
    for i in range(n):
        # Sliding window
        for j in range(max(0, i - w), min(n, i + w + 1)):
            if mask[i][j] == 0:
                mask[i][j] = 1
                local_count += 1
        # Global tokens
        for g in g_idxs:
            if mask[i][g] == 0:
                mask[i][g] = 1
                global_count += 1
            if mask[g][i] == 0:
                mask[g][i] = 1
                global_count += 1
        # Random
        candidates = [j for j in range(n) if mask[i][j] == 0]
        chosen = random.sample(candidates, min(r, len(candidates)))
        for j in chosen:
            mask[i][j] = 2  # mark random edges with a distinct value
            random_count += 1
    return mask, local_count, global_count, random_count

mask, lc, gc, rc = make_mask(n, w, g_idxs, r)
full_attention_scores = n * n
sparse_scores = sum(1 for row in mask for v in row if v > 0)

print(f"Sequence length: {n}")
print(f"Window half-size: {w} (each token sees {2*w+1} neighbors)")
print(f"Global tokens: {g_idxs}")
print(f"Random edges per token: {r}")
print()
print(f"Full attention scores: {full_attention_scores}")
print(f"BigBird sparse scores: {sparse_scores}")
print(f"  - Local (window): {lc}")
print(f"  - Global: {gc}")
print(f"  - Random: {rc}")
print(f"Sparsity: {1 - sparse_scores/full_attention_scores:.1%} of scores skipped")
```
Dilated and Strided Attention
Rather than using the same attention pattern in every layer, some architectures vary the pattern across layers to build a hierarchical receptive field. The idea is borrowed from dilated convolutions in CNNs: early layers look at fine-grained local detail, while deeper layers look at broader context with coarser resolution.
Concretely, consider a model with 4 layers and a base window of $w = 4{,}096$ tokens:
- Layer 0: attend to the nearest 4,096 tokens (stride 1). Dense local context.
- Layer 1: attend to every 2nd token up to 8,192 positions (stride 2). Broader range, half the density.
- Layer 2: attend to every 4th token up to 16,384 positions (stride 4). Even broader, even sparser.
- Layer 3: attend to every 8th token up to 32,768 positions (stride 8). Coarse global coverage.
Each layer computes exactly $w$ attention scores per token (the window size is fixed), so the per-layer cost is always $O(n \cdot w)$, linear in $n$. But the effective range doubles with each layer. At stride $s$ and window $w$, the attention spans $w \times s$ positions. At the boundaries: stride $s = 1$ gives the standard local window ($w$ positions); stride $s = n/w$ covers the entire sequence with exactly $w$ samples. As stride grows, we trade precision (skipping intermediate tokens) for reach (covering more of the sequence).
The combination across layers is powerful: layer 0 captures exact local word order, layer 1 captures paragraph-level structure, layer 2 captures section-level structure, and layer 3 captures document-level structure. By stacking these, the model gets dense local detail and sparse global context, all at $O(n \cdot w)$ total cost per layer. The Longformer paper's dilated window is a special case of this idea applied within a single layer.
```python
# Dilated attention: same compute per layer, increasing reach
w = 8  # tokens attended per layer (budget)
layers = [
    {"name": "Layer 0", "stride": 1},
    {"name": "Layer 1", "stride": 2},
    {"name": "Layer 2", "stride": 4},
    {"name": "Layer 3", "stride": 8},
]

print(f"Window budget per layer: {w} tokens")
print(f"{'Layer':<10} {'Stride':<8} {'Range (positions)':<20} {'Cost (scores/token)'}")
print("-" * 60)
for layer in layers:
    s = layer["stride"]
    reach = w * s
    print(f"{layer['name']:<10} {s:<8} {reach:<20} {w}")

total_range = w * layers[-1]["stride"]
print(f"\nAfter {len(layers)} layers, the model covers {total_range} positions")
print(f"Total cost per token: {w * len(layers)} scores ({w} per layer x {len(layers)} layers)")
print(f"Full attention would cost: {total_range} scores per token")
```
Ring Attention: Distributing Across GPUs
All the patterns above reduce the number of attention scores computed. But for very long sequences (millions of tokens), there's a more fundamental problem: the sequence doesn't fit in a single GPU's memory at all. Even with sliding window attention, storing the activations for millions of tokens during training exceeds the memory of any single device. We need to distribute the sequence itself across multiple GPUs.
Ring Attention (Liu et al., 2023) does exactly this. Arrange $d$ GPUs in a logical ring. The input sequence of $n$ tokens is split into $d$ contiguous blocks of $n/d$ tokens each. Each GPU $i$ holds:
- Its local Q block: the query vectors for its $n/d$ tokens. This stays on the GPU and never moves.
- One KV block: the key-value vectors for some chunk of the sequence. This rotates around the ring.
The algorithm proceeds in $d$ rounds. In each round, every GPU computes the attention between its local Q block and the current KV block it's holding, producing a partial attention output. Then, all GPUs simultaneously send their KV block to the next GPU in the ring and receive the KV block from the previous GPU. After $d$ rounds, every Q block has attended to every KV block in the sequence.
The key to efficiency is overlapping communication with computation . While GPU $i$ is computing attention for the current KV block, it is simultaneously sending that block to GPU $i+1$ and receiving the next block from GPU $i-1$. As long as the compute time for one block is greater than or equal to the transfer time, the communication cost is fully hidden.
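A back-of-envelope check of that overlap condition, with assumed (not measured) hardware numbers, roughly A100-class fp16 throughput and interconnect bandwidth, and an assumed block size:

```python
# All numbers below are illustrative assumptions, not measurements.
block = 65_536       # tokens per KV block (n/d)
d_model = 4_096      # model width
peak_flops = 312e12  # assumed fp16 tensor-core throughput (FLOP/s)
bandwidth = 300e9    # assumed ring-link bandwidth (bytes/s)

# One (Q block, KV block) pair: QK^T and AV each cost ~2 * block^2 * d_model FLOPs
compute_time = (4 * block**2 * d_model) / peak_flops
# One KV block in fp16: K and V, each block x d_model x 2 bytes
transfer_time = (2 * block * d_model * 2) / bandwidth

print(f"compute per block:  {compute_time * 1e3:7.1f} ms")
print(f"transfer per block: {transfer_time * 1e3:7.1f} ms")
print("communication fully hidden" if compute_time >= transfer_time
      else "communication exposed")
```

Under these assumptions compute per block exceeds transfer by well over an order of magnitude. Because compute scales with block² while transfer scales with block, the blocks can shrink considerably before communication stops being hidden.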
Memory per GPU scales as $O(n/d)$: each GPU only stores queries for $n/d$ tokens and one KV block of size $n/d$ at a time. With $d$ GPUs, we can process sequences $d$ times longer than a single GPU could hold. This is how models like Gemini process 10M+ token contexts: distribute the sequence across hundreds of TPUs in a ring, with each device handling a manageable chunk.
```python
# Ring Attention simulation: d GPUs, n tokens
n = 24          # total sequence length (tokens)
d = 4           # number of GPUs
block = n // d  # tokens per GPU

print(f"Sequence length: {n} tokens")
print(f"GPUs: {d}")
print(f"Block size: {block} tokens per GPU")
print(f"Memory per GPU: O({n}/{d}) = O({block}) instead of O({n})")
print()

# Simulate the ring
for round_num in range(d):
    print(f"Round {round_num}:")
    for gpu in range(d):
        kv_source = (gpu - round_num) % d  # which GPU's KV block we're computing with
        q_range = f"tokens [{gpu*block}-{(gpu+1)*block - 1}]"
        kv_range = f"tokens [{kv_source*block}-{(kv_source+1)*block - 1}]"
        print(f"  GPU {gpu}: Q={q_range} x KV={kv_range}")
    if round_num < d - 1:
        print(f"  >> All GPUs pass KV block to next neighbor")
    print()

print(f"After {d} rounds: every Q block has seen every KV block")
print(f"Result: EXACT full attention, computed in distributed fashion")
print(f"Communication: {d-1} KV block transfers per GPU (overlapped with compute)")
```
Choosing the Right Pattern
Each pattern trades off simplicity, compute savings, and information flow. The right choice depends on the task, the model architecture, and the hardware constraints. Here's a summary:
| Pattern | Complexity | Accuracy | Best For | Example |
|---|---|---|---|---|
| Full attention | $O(n^2)$ | Exact | Single GPU, short context | GPT-2, BERT |
| Sliding window | $O(n \cdot w)$ | Approx (local) | Streaming / inference | Mistral 7B ($w$=4096) |
| Longformer | $O(n)$ | Approx (local + global) | Long-doc classification | Longformer-4096 |
| BigBird | $O(n)$ | Approx (local + global + random) | Long-doc QA/NER | BigBird-4096 |
| Dilated / strided | $O(n \cdot w)$ | Approx (hierarchical) | Multi-scale context | Various research |
| Ring Attention | $O(n^2/d)$ | Exact (distributed) | Training on 1M+ tokens | Gemini, long-ctx training |

Complexity notes: $n$ = sequence length, $w$ = window size, $d$ = number of GPUs. Sliding window is linear in $n$ for fixed $w$; Ring Attention performs the same total work as full attention, but memory is $O(n/d)$ per GPU.
A few practical guidelines:
- For autoregressive LLM inference: sliding window attention is the dominant choice. It bounds the KV cache, simplifies memory management, and works well for conversational and generative tasks where local context matters most. Mistral and its derivatives have proven this at scale.
- For encoder tasks on long documents: Longformer and BigBird patterns (local + global tokens) remain strong choices. The global tokens ensure that classification or extraction tasks can gather document-wide signals.
- For training on extremely long sequences: Ring Attention is the go-to distribution strategy. It allows exact attention over million-token contexts by spreading memory across GPUs. It composes with any local attention pattern.
- All patterns compose with FlashAttention: FlashAttention handles the low-level block computation (tiling, SRAM management) regardless of which sparse pattern selects the blocks. Sparse patterns choose what to compute; FlashAttention optimises how to compute it.
Quiz
Test your understanding of efficient attention patterns.
In sliding window attention with window size $w$ and $L$ layers, what is the effective receptive field of the final layer?
What is the purpose of global tokens in Longformer and BigBird?
How does Ring Attention differ from sparse attention patterns like sliding window or BigBird?
Why does BigBird include random attention edges in addition to local and global attention?