What If the Model Could Remember Beyond Its Window?
Sparse attention (the previous article) limits which tokens attend to which, cutting the quadratic cost. But the total context is still bounded: a sliding window of size $w$ can only see $w$ tokens into the past, and even global + local hybrid patterns still operate over a single, finite input sequence. What if the model had a separate memory that persists across context windows — a store of information from the past that the model can query at each step, regardless of how far back the original tokens appeared?
There are three broad approaches to building this memory:
- Exact retrieval: store all past KV pairs in a big external cache, and retrieve the most relevant ones via (approximate) k-nearest-neighbour (kNN) search when needed.
- Compressive memory: compress old KV pairs into a fixed-size memory matrix. The memory never grows, no matter how much past context has been processed, but it's lossy.
- Learned neural memory: make the memory a small neural network that is updated via gradient descent at test time. New inputs are "written" to memory by training the network on them.
This is different from RAG, which retrieves from an external corpus of documents the model has never seen during its forward pass. Memory-augmented transformers remember their own past context: tokens that already flowed through the model in earlier segments of the same sequence. RAG extends what the model knows; memory-augmented architectures extend what it remembers.
Memorizing Transformers: Exact kNN Retrieval
The simplest idea: just store all past key-value pairs in a big external cache, and at each attention step, search that cache for the most relevant entries. This is the approach of Memorizing Transformers (Wu et al., 2022).
During attention at layer $l$, the model runs two parallel attention computations:
- Local attention: standard scaled dot-product attention over the current context window, producing $A_{\text{local}}$. This is ordinary self-attention — nothing new.
- kNN attention: use the current query vectors to search the external KV cache via approximate kNN. Retrieve the top-$k$ most similar past key-value pairs for each query, compute attention scores over just those $k$ retrieved pairs, and produce $A_{\text{kNN}}$.
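The two attention outputs are combined with a learned gate:
$$A = g \cdot A_{\text{kNN}} + (1 - g) \cdot A_{\text{local}}, \qquad g = \operatorname{sigmoid}(b_g)$$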
The gate is a learned scalar (passed through a sigmoid) that lets the model decide, per head, how much to rely on retrieved memories versus local context. When gate $\to 0$, the layer ignores the external cache entirely and behaves like a standard transformer. When gate $\to 1$, the layer relies entirely on retrieved past context. In the paper, kNN attention is added to just one layer near the top of the stack rather than to every layer; lower layers tend to handle local syntactic patterns that don't benefit from long-range retrieval.
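A minimal NumPy sketch of the two parallel computations and the gate, for a single query and a single head. It rests on some simplifying assumptions. An exact top-$k$ search via `np.argpartition` stands in for the approximate kNN index. The names `attend`, `knn_attend`, `gated_output` and `b_g` are illustrative and are not taken from any released implementation.

```python
import numpy as np


def softmax(x):
    # Numerically stable softmax over a 1-D score vector
    e = np.exp(x - x.max())
    return e / e.sum()


def attend(q, K, V):
    """Scaled dot-product attention of one query over the rows of K and V."""
    return softmax(K @ q / np.sqrt(q.shape[0])) @ V


def knn_attend(q, cache_K, cache_V, k):
    """Retrieve the k cached keys with the largest dot product and attend over only those.

    Exact top-k via argpartition stands in for an approximate kNN index.
    """
    sims = cache_K @ q
    top = np.argpartition(-sims, k - 1)[:k]
    return attend(q, cache_K[top], cache_V[top])


def gated_output(q, local_K, local_V, cache_K, cache_V, k, b_g):
    g = 1.0 / (1.0 + np.exp(-b_g))  # gate = sigmoid(b_g), one learned scalar per head
    a_local = attend(q, local_K, local_V)
    a_knn = knn_attend(q, cache_K, cache_V, k)
    return g * a_knn + (1.0 - g) * a_local


rng = np.random.default_rng(0)
d = 16
q = rng.normal(size=d)
local_K, local_V = rng.normal(size=(8, d)), rng.normal(size=(8, d))
cache_K, cache_V = rng.normal(size=(1000, d)), rng.normal(size=(1000, d))
out = gated_output(q, local_K, local_V, cache_K, cache_V, k=32, b_g=0.0)
print(out.shape)  # (16,)
```

Pushing `b_g` far negative makes the output match plain local attention. Pushing it far positive makes the output match the kNN branch alone.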
The external cache itself is a FIFO buffer: as the model processes new segments, it appends the current segment's KV pairs to the cache and evicts the oldest entries when the cache is full. Wu et al. show the cache can scale to 262K tokens with efficient approximate kNN search using product quantization (the same indexing technique used in vector databases for RAG). Only the top-$k$ entries are retrieved (typically $k = 32$ or $k = 64$); attention never runs over the full cache. The attention computation over external memory therefore costs $O(k \cdot d)$ per query, constant with respect to the cache size. The nearest-neighbour search itself still depends on the cache size, which is why an approximate index is used rather than a brute-force scan.
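The eviction policy can be sketched with a hypothetical `FIFOKVCache` class. The sketch only models which entries stay in the cache. It leaves out the nearest-neighbour index, which a real system would also have to update as entries arrive and leave.

```python
import numpy as np


class FIFOKVCache:
    """Hypothetical fixed-capacity KV cache: append whole segments, evict the oldest entries."""

    def __init__(self, capacity, d_k, d_v):
        self.capacity = capacity
        self.keys = np.empty((0, d_k))
        self.values = np.empty((0, d_v))

    def append(self, K, V):
        # Concatenate, then keep only the newest `capacity` rows (FIFO eviction)
        self.keys = np.concatenate([self.keys, K])[-self.capacity:]
        self.values = np.concatenate([self.values, V])[-self.capacity:]

    def __len__(self):
        return len(self.keys)


cache = FIFOKVCache(capacity=1024, d_k=16, d_v=16)
for seg in range(5):  # five 512-token segments = 2,560 tokens in total
    K = np.full((512, 16), float(seg))  # tag every key with its segment index
    cache.append(K, K.copy())
print(len(cache), cache.keys[0, 0], cache.keys[-1, 0])  # 1024 3.0 4.0
```

Only segments 3 and 4 survive. Slicing with `[-capacity:]` copies the arrays on every append, which a ring buffer would avoid.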
Infini-Attention: Compressive Memory
Storing exact KV pairs works, but the cache grows linearly with the amount of past context. Can we do better? Infini-attention (Munkhdalai et al., 2024) takes a different approach: instead of storing exact KV pairs, compress all past context into a fixed-size memory matrix that never grows, no matter how long the sequence gets.
Each attention layer maintains two persistent objects: a memory matrix $M \in \mathbb{R}^{d_k \times d_v}$ and a normaliser vector $z \in \mathbb{R}^{d_k}$. These are initialised to zeros and updated as the model processes successive segments of the input.
After processing segment $s$ (with keys $K_s$ and values $V_s$ from standard attention), the memory is updated:
$$M_s = M_{s-1} + \sigma(K_s)^T V_s, \qquad z_s = z_{s-1} + \sum_{i} \sigma(k_{s,i})$$
where $\sigma$ is a nonlinearity (specifically ELU + 1, which ensures all values are positive). Let's unpack what this update does. The term $\sigma(K_s)^T V_s$ is a sum of outer products: each key vector $\sigma(k_{s,i})$ (of dimension $d_k$) is multiplied by its corresponding value vector $v_{s,i}$ (of dimension $d_v$), producing a $d_k \times d_v$ matrix. Summing over all tokens in the segment gives a matrix that encodes key-value associations. This gets added to the existing memory $M_{s-1}$, accumulating information across segments. The normaliser $z_s$ tracks the total "mass" of the keys that have been written, which is needed for proper normalisation during retrieval.
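The claim that $\sigma(K_s)^T V_s$ is a sum of per-token outer products is easy to check numerically. The NumPy sketch below uses illustrative function names.

```python
import numpy as np


def elu_plus_one(x):
    # ELU(x) + 1 is x + 1 for x > 0 and exp(x) for x <= 0, so it is always positive
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))


def update_memory(M, z, K, V):
    """Write one segment (keys K: n x d_k, values V: n x d_v) into the memory."""
    sK = elu_plus_one(K)
    return M + sK.T @ V, z + sK.sum(axis=0)


rng = np.random.default_rng(0)
d_k, d_v, n = 8, 4, 16
K, V = rng.normal(size=(n, d_k)), rng.normal(size=(n, d_v))
M, z = update_memory(np.zeros((d_k, d_v)), np.zeros(d_k), K, V)

# The same matrix built explicitly as a sum of per-token outer products
M_loop = sum(np.outer(elu_plus_one(K[i]), V[i]) for i in range(n))
print(M.shape, np.allclose(M, M_loop))  # (8, 4) True
```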
To retrieve from memory, the model uses the current segment's queries against the memory accumulated over the previous segments (retrieval happens before segment $s$ is written in):
$$A_{\text{mem}} = \frac{\sigma(Q_s)\, M_{s-1}}{\sigma(Q_s)\, z_{s-1}}$$
This is a linear attention lookup. Each query $\sigma(q_{s,i})$ is multiplied by the memory matrix $M_{s-1}$ to produce a $d_v$-dimensional output (the "retrieved" value). That output is then divided by the scalar $\sigma(q_{s,i})^T z_{s-1}$ for normalisation, analogous to how softmax normalises standard attention scores. This retrieval costs $O(d_k \cdot d_v)$ per query, completely independent of how many past tokens have been compressed into $M$. That's the key advantage: the memory has fixed size $d_k \times d_v$ regardless of sequence length.
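The fixed-size property shows up in a short NumPy sketch. It assumes the ELU + 1 feature map and a read that happens only after at least one segment has been written; with $z$ all zeros the normaliser would be zero. After 100 segments the memory is still $d_k \times d_v$. A query's read-out is a normalised weighted average of the stored values, so if every stored value is the same vector, the read returns exactly that vector.

```python
import numpy as np


def elu_plus_one(x):
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))


def retrieve(M, z, Q):
    """Linear-attention read: one d_v-dimensional vector per row of Q."""
    sQ = elu_plus_one(Q)
    return (sQ @ M) / (sQ @ z)[:, None]  # row-wise normalisation, one scalar per query


rng = np.random.default_rng(0)
d_k, d_v, seg_len = 8, 4, 32
M, z = np.zeros((d_k, d_v)), np.zeros(d_k)
for _ in range(100):  # 100 segments = 3,200 past tokens
    K, V = rng.normal(size=(seg_len, d_k)), rng.normal(size=(seg_len, d_v))
    sK = elu_plus_one(K)
    M, z = M + sK.T @ V, z + sK.sum(axis=0)
out = retrieve(M, z, rng.normal(size=(seg_len, d_k)))
print(M.shape, out.shape)  # (8, 4) (32, 4)

# If every stored value is the same vector v, the normalised read returns v exactly
v = np.array([1.0, -2.0, 0.5, 3.0])
sK = elu_plus_one(rng.normal(size=(10, d_k)))
M_same, z_same = sK.T @ np.tile(v, (10, 1)), sK.sum(axis=0)
print(np.allclose(retrieve(M_same, z_same, rng.normal(size=(5, d_k))), v))  # True
```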
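The final output combines local attention $A_{\text{local}}$ (standard dot-product attention within the current segment) and the memory read-out $A_{\text{mem}}$ via a learned gate $\beta$:
$$A = \sigma(\beta) \odot A_{\text{mem}} + (1 - \sigma(\beta)) \odot A_{\text{local}}$$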
Here $\sigma(\beta)$ denotes the logistic sigmoid applied to a learned parameter $\beta$ (one per head). Despite sharing the symbol, it is not the ELU + 1 feature map used above. Boundary analysis: when $\beta \to -\infty$, $\sigma(\beta) \to 0$ and the output is pure local attention. The memory is completely ignored, and the layer behaves like standard attention. When $\beta \to +\infty$, $\sigma(\beta) \to 1$ and the output is pure memory, with local context ignored. In between, the model blends both sources. Each head can learn its own $\beta$, so some heads can specialise in local patterns while others focus on long-range memory.
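A single-head Infini-attention layer over a sequence of segments might be sketched as follows. The assumptions are as follows:

- `infini_layer` and `causal_attention` are hypothetical names.
- Local attention is causal within each segment.
- A small `eps` is added to the normaliser, so the first segment, read before anything has been written, gets an all-zero memory output. The original implementation may handle that first read differently.

```python
import numpy as np


def elu_plus_one(x):
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))


def causal_attention(Q, K, V):
    """Standard causal softmax attention within one segment."""
    n, d = Q.shape
    scores = Q @ K.T / np.sqrt(d)
    scores = np.where(np.tril(np.ones((n, n), dtype=bool)), scores, -np.inf)
    w = np.exp(scores - scores.max(axis=1, keepdims=True))
    return (w / w.sum(axis=1, keepdims=True)) @ V


def infini_layer(segments, beta, eps=1e-6):
    """Single-head sketch; `segments` is a list of (Q, K, V) arrays, one per segment."""
    d_k, d_v = segments[0][1].shape[1], segments[0][2].shape[1]
    M, z = np.zeros((d_k, d_v)), np.zeros(d_k)
    gate = 1.0 / (1.0 + np.exp(-beta))  # logistic sigmoid, not the ELU + 1 map
    outputs = []
    for Q, K, V in segments:
        sQ = elu_plus_one(Q)
        # 1. Read from the memory of earlier segments (eps: empty memory reads as zero)
        A_mem = (sQ @ M) / (sQ @ z + eps)[:, None]
        # 2. Local attention inside the current segment
        A_local = causal_attention(Q, K, V)
        # 3. Gated blend of the two sources
        outputs.append(gate * A_mem + (1.0 - gate) * A_local)
        # 4. Write the current segment into memory for later segments
        sK = elu_plus_one(K)
        M, z = M + sK.T @ V, z + sK.sum(axis=0)
    return outputs


rng = np.random.default_rng(0)
segments = [tuple(rng.normal(size=(16, 8)) for _ in range(3)) for _ in range(4)]
outs = infini_layer(segments, beta=0.0)
print(len(outs), outs[0].shape)  # 4 (16, 8)
```

With `beta` strongly negative the outputs reduce to per-segment causal attention. With `beta` strongly positive, the first segment's output is zero because the memory is still empty.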
Titans: Learning to Memorize at Test Time
Both Memorizing Transformers and Infini-attention store past context as data (KV pairs or compressed matrices). Titans (Behrouz et al., 2025) takes a more radical approach: the memory is a neural network (a small MLP) that is updated via gradient descent at test time. New information is "written" to memory by training the MLP on it, and "read" by running a forward pass through it.
Titans organises information into three memory systems:
- Short-term memory: standard attention over a local window. This is the normal context window — precise but bounded.
- Long-term memory: a small MLP that is trained via SGD during inference. Past tokens are encoded into the MLP's weights.
- Persistent memory: fixed learnable parameters (similar to prompt tuning) that encode model-level knowledge shared across all inputs. These are trained during pretraining and frozen at test time.
The long-term memory is where the novelty lies. At each time step $t$, the memory MLP $M$ is updated:
$$M_t = M_{t-1} - \eta \, \nabla_{M} \mathcal{L}(M_{t-1}, x_t)$$
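This is a standard gradient descent step, but running during inference, not training. The Titans paper refines it in two ways: a momentum term carries "past surprise" forward across steps, and the step sizes are data-dependent. The plain gradient step captures the core idea. Here's how it works: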
- Recall: run a forward pass through the memory MLP with the current input $x_t$ to produce a prediction.
- Surprise: compute a reconstruction loss $\mathcal{L}(M_{t-1}, x_t)$ measuring how well the memory predicted the current input. If the input is surprising (high loss), the memory was missing relevant information.
- Store: backpropagate through the MLP and take a gradient step to update the memory weights. Surprising inputs cause larger gradient updates, writing them more strongly into memory.
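The recall / surprise / store loop can be sketched with a tiny two-layer MLP and hand-written gradients. The specifics are assumptions, not the Titans architecture:

- The loss takes the associative form $\tfrac{1}{2}\lVert M(k_t) - v_t \rVert^2$ for a key $k_t$ and value $v_t$ derived from $x_t$.
- The hidden activation is `tanh`.
- `MLPMemory` is a hypothetical class.
- `eta` and `lam` play the roles of $\eta$ and $\lambda$.

```python
import numpy as np


class MLPMemory:
    """Hypothetical two-layer memory MLP mapping a key to a predicted value."""

    def __init__(self, d, hidden, seed=0):
        rng = np.random.default_rng(seed)
        self.W1 = rng.normal(scale=0.1, size=(hidden, d))
        self.W2 = rng.normal(scale=0.1, size=(d, hidden))

    def recall(self, k):
        # Forward pass only: read the memory without changing it
        return self.W2 @ np.tanh(self.W1 @ k)

    def step(self, k, v, eta, lam=0.0):
        """Recall, measure surprise, store with one SGD step, then apply weight decay."""
        h = np.tanh(self.W1 @ k)
        err = self.W2 @ h - v                    # recall error
        loss = 0.5 * float(err @ err)            # surprise: 0.5 * ||M(k) - v||^2
        grad_W2 = np.outer(err, h)               # dL/dW2
        grad_W1 = np.outer((self.W2.T @ err) * (1.0 - h**2), k)  # dL/dW1, using tanh' = 1 - h^2
        self.W2 = (1.0 - lam) * (self.W2 - eta * grad_W2)
        self.W1 = (1.0 - lam) * (self.W1 - eta * grad_W1)
        return loss


rng = np.random.default_rng(1)
d = 8
mem = MLPMemory(d, hidden=32)
k, v = rng.normal(size=d), rng.normal(size=d)
losses = [mem.step(k, v, eta=0.02) for _ in range(100)]
print(f"surprise on first write: {losses[0]:.3f}, after 100 repeats: {losses[-1]:.5f}")

# A pair the memory has never seen is still surprising
k_new, v_new = rng.normal(size=d), rng.normal(size=d)
novel = 0.5 * float(np.sum((mem.recall(k_new) - v_new) ** 2))
familiar = 0.5 * float(np.sum((mem.recall(k) - v) ** 2))
print(novel > familiar)
```

Repeating the same pair drives the loss down, so a familiar input stops triggering large writes. A fresh pair still produces a large loss when it first arrives, and therefore a large gradient.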
There are two key hyperparameters. The memory learning rate $\eta$ controls how quickly new information is written. The forgetting rate $\lambda$ applies weight decay after each update:
$$M_t \leftarrow (1 - \lambda)\, M_t$$
This weight decay causes old memories to gradually fade unless they are reinforced by recurring patterns. Boundary analysis:

- When $\eta = 0$, the memory receives no gradient updates and remains frozen. The model has no long-term memory and relies only on its local attention window.
- When $\lambda = 0$, there is no forgetting. Nothing pulls the weights back toward zero, so stale associations persist indefinitely and compete with new ones for the MLP's fixed capacity. Nothing keeps the weights from drifting in magnitude over very long sequences either.
- When $\lambda = 1$, the memory is completely erased at every step. The weight decay zeros out all parameters, giving the model no memory at all.

The sweet spot is small $\eta$ (gentle writes) and small $\lambda$ (slow forgetting). Together they let the MLP build up a persistent representation of past context.
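The forgetting rate has a simple quantitative reading. Suppose a particular write is never reinforced, and treat the decay factor as the only thing acting on it. This is a simplification, since later gradient steps also move the weights. After $t$ steps, a fraction $(1 - \lambda)^t$ of the write survives. A few values:

```python
import math


def retention(lam, steps):
    """Fraction of an unreinforced write left after `steps` decays by (1 - lam)."""
    return (1.0 - lam) ** steps


def half_life(lam):
    """Steps until an unreinforced write halves; defined for 0 < lam < 1."""
    return math.log(0.5) / math.log(1.0 - lam)


for lam in (0.0, 0.01, 0.1, 1.0):
    print(f"lambda = {lam}: {retention(lam, 100):.3g} of the write survives 100 steps")
print(f"half-life at lambda = 0.01: {half_life(0.01):.1f} steps")
```

At $\lambda = 0.01$, an unreinforced write keeps about 37% of its magnitude after 100 steps, and its half-life is roughly 69 steps. At $\lambda = 0.1$, almost nothing is left after 100 steps.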
Why is this novel? It reframes memory as a neural network continuously trained during inference. The memory has constant size (the MLP's parameters) regardless of sequence length, just like Infini-attention's memory matrix. But unlike a linear memory matrix, an MLP can represent nonlinear associations — in principle, it can learn complex patterns that a simple outer-product memory cannot. The cost is additional compute: each memory update requires a forward pass, loss computation, and backward pass through the MLP. This connects to the broader emerging idea of test-time training / test-time adaptation, where models continue learning during inference rather than using frozen weights.
Landmark Attention: Retrieval Within Attention
The approaches above add external memory modules to the transformer. Landmark attention (Mohtashami & Jaggi, 2023) takes a different angle: instead of a separate memory, restructure the attention mechanism itself to enable random-access retrieval over arbitrarily long sequences.
The idea is to insert special landmark tokens at regular intervals, every $k$ tokens. Each landmark is trained to summarise the block of $k$ tokens that precedes it — it acts as a learned "table of contents" entry for its block. During attention, the model uses a two-stage process:
- Stage 1 — Attend to landmarks: each query first attends to all landmark tokens. With a sequence of $n$ tokens and one landmark every $k$ tokens, there are only $n / k$ landmarks. This is cheap.
- Stage 2 — Select and attend to full blocks: based on the landmark attention scores, select the top-$b$ most relevant blocks (the ones whose landmarks received the highest attention). Load the full KV pairs for those $b$ blocks and compute fine-grained attention over them.
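A NumPy sketch of the two stages for a single query. The key assumption is flagged in the code as well: each landmark key is simply the mean of its block's keys. This is a hypothetical stand-in for the trained landmark tokens. `landmark_attention`, `block` (the block length $k$) and `b` are illustrative names.

```python
import numpy as np


def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()


def landmark_attention(q, K, V, block, b):
    """Two-stage attention for one query over the n = len(K) cached tokens.

    Assumption: each landmark key is the mean of its block's keys, a hypothetical
    stand-in for the trained landmark tokens.
    """
    n_blocks = len(K) // block
    d = K.shape[1]
    K_blocks = K[: n_blocks * block].reshape(n_blocks, block, d)
    V_blocks = V[: n_blocks * block].reshape(n_blocks, block, -1)
    landmarks = K_blocks.mean(axis=1)            # one landmark key per block: n/k of them
    # Stage 1: score every landmark against the query
    chosen = np.argsort(landmarks @ q)[-b:]      # indices of the top-b blocks
    # Stage 2: fine-grained attention over the b * block tokens of the chosen blocks
    K_sel = K_blocks[chosen].reshape(-1, d)
    V_sel = V_blocks[chosen].reshape(b * block, -1)
    return softmax(K_sel @ q / np.sqrt(d)) @ V_sel, chosen


rng = np.random.default_rng(0)
n, d, block = 4096, 32, 64
K, V = 0.1 * rng.normal(size=(n, d)), rng.normal(size=(n, d))
q = rng.normal(size=d)
K[128:192] += q  # block 2, about 3,900 tokens before the end, now matches the query
out, chosen = landmark_attention(q, K, V, block=block, b=2)
print(2 in chosen, out.shape)
```

The planted block is found even though it lies far outside any recent window. Selecting every block (`b = n // block`) reproduces ordinary full attention over the same tokens.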
The total attention cost per query is:
$$\text{cost} = \frac{n}{k} + b \cdot k$$
The first term $n/k$ is the cost of attending to all landmarks (coarse retrieval). The second term $b \cdot k$ is the cost of attending to the $b$ selected blocks of $k$ tokens each (fine-grained attention). Boundary analysis: when $b = n/k$ (select all blocks), the cost reduces to $n/k + n = O(n)$, which is essentially full attention. When $b = 1$ (select only the single most relevant block), the cost is $n/k + k$. For balanced performance, choosing $k = \sqrt{n}$ and $b = O(1)$ gives a cost of $O(\sqrt{n})$ per query — sublinear in sequence length. This is far less than the $O(n)$ of full attention when only a few blocks are relevant.
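Setting the derivative of $n/k + b k$ with respect to $k$ to zero gives $k = \sqrt{n/b}$ and a minimum cost of $2\sqrt{nb}$. When $b = 1$ this recovers the $k = \sqrt{n}$ choice. A quick tabulation with $n = 65{,}536$:

```python
import math


def landmark_cost(n, k, b):
    """Per-query attention cost in dot products: n/k landmarks + b blocks of k tokens."""
    return n / k + b * k


n = 65_536
for k, b in [(256, 1), (256, 4), (64, 4), (256, n // 256)]:
    print(f"k={k:4d}, b={b:4d}: {landmark_cost(n, k, b):8.0f} vs full attention {n}")

# For fixed b, n/k + b*k is minimised at k = sqrt(n / b), giving 2 * sqrt(n * b)
b = 4
k_opt = math.sqrt(n / b)
print(k_opt, landmark_cost(n, k_opt, b))  # 128.0 1024.0
```

The last row of the loop selects all 256 blocks. Its cost is 65,792 dot products, slightly more than full attention, because the landmarks are scored on top of every token.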
The key advantage over sliding-window attention is random access. A sliding window can only see the most recent $w$ tokens; it has no way to jump back to an important passage 50,000 tokens ago. Landmark attention can reach any block in the sequence, as long as its landmark is deemed relevant. The trade-off is that the landmarks must be good summaries of their blocks. If a landmark fails to capture the relevant information in its block, the block won't be selected and its tokens will be invisible to the query.
Choosing the Right Memory Approach
These four approaches form a spectrum from exact retrieval to learned abstraction. The following table compares them across the dimensions that matter in practice:
```python
import json

# `js` is Pyodide's bridge to the host web page; it only exists when this
# runs in the browser, so fall back to plain-text output elsewhere.
try:
    import js
except ImportError:
    js = None

headers = [
    "Method",
    "Memory Type",
    "Memory Cost",
    "Fidelity",
    "Integration",
    "Forgetting Mechanism",
]

rows = [
    ["Memorizing Transformers\n(Wu et al., 2022)",
     "Exact KV cache + kNN",
     "O(n) memory",
     "Highest (exact recall)",
     "Upper layers only",
     "Cache eviction (FIFO)"],
    ["Infini-Attention\n(Munkhdalai et al., 2024)",
     "Compressed matrix",
     "O(d_k * d_v) fixed",
     "Lossy (linear approx.)",
     "All layers",
     "Accumulation in matrix"],
    ["Titans\n(Behrouz et al., 2025)",
     "Neural net (MLP) + SGD",
     "O(MLP params) fixed",
     "Lossy (nonlinear)",
     "Separate module",
     "Weight decay forgetting"],
    ["Landmark Attention\n(Mohtashami & Jaggi, 2023)",
     "Landmark-guided retrieval",
     "O(n) KV cache; O(n/k) landmarks scanned",
     "Exact (for selected blocks)",
     "Within attention",
     "Block selection via scores"],
]

print("Comparison of memory-augmented transformer approaches")
print()
if js is not None:
    # In the browser: hand the table to the page's JavaScript renderer
    js.window.py_table_data = json.dumps({"headers": headers, "rows": rows})
else:
    # Elsewhere: print one row per method, flattening the two-line method names
    for row in rows:
        print(" | ".join(cell.replace("\n", " ") for cell in row))
    print()
print("Exact retrieval = highest fidelity, highest memory cost")
print("Compressed = fixed memory regardless of length, but lossy")
print("Learned (Titans) = most flexible, adds compute (test-time SGD)")
print("Landmark = random-access retrieval, depends on landmark quality")
```
At one end, Memorizing Transformers store exact KV pairs and retrieve via kNN — maximum fidelity, but memory cost grows linearly with past context. At the other end, Titans compresses everything into a small MLP's weights — fixed memory footprint and the ability to capture nonlinear associations, but at the cost of test-time compute (forward + backward through the MLP at every step) and being the newest and least proven approach at scale. Infini-attention sits in the middle: fixed-size memory with simple linear updates, no test-time training, but limited by the expressiveness of the linear outer-product compression. Landmark attention is orthogonal: it doesn't add a separate memory, but restructures attention to enable efficient random access.
In practice, production models such as GPT-4, Claude and Gemini are not publicly documented in full. They appear to rely mainly on long context windows plus efficient exact-attention kernels such as FlashAttention (Dao et al., 2022), rather than on explicit external memory. Context windows have grown large enough (128K–1M+ tokens) that many use cases fit within a single window. FlashAttention keeps memory traffic manageable at those lengths, although its compute still grows quadratically with context length. But the field is moving fast. Models are increasingly asked to process entire codebases, book-length documents, or persistent multi-turn conversations spanning thousands of turns. As that happens, the need for explicit memory beyond the context window will likely grow, and the approaches in this article represent the leading research directions for meeting that need.
Quiz
Test your understanding of memory-augmented transformer architectures.
In Memorizing Transformers, how are local attention and kNN attention combined?
What is the key advantage of Infini-attention's memory matrix over storing exact KV pairs?
In Titans, what determines how strongly a new input is written into the long-term memory?
How does landmark attention achieve sublinear attention cost?