Why Replace Attention?
Every attention mechanism we've seen so far — full, sparse, sliding window — shares one fundamental cost: each token must interact with other tokens to decide what's relevant. Full attention is $O(n^2)$ in both compute and memory. Sparse attention (article 3) reduces the constant factor or drops some interactions, but it's still fundamentally quadratic or at best $O(n \sqrt{n})$. Sliding window attention is $O(n \cdot w)$, linear in $n$ but only because it gives up direct long-range access. What if we could build a sequence model that is $O(n)$ from the very start — processing each token in constant time with respect to sequence length — while still capturing long-range dependencies?
That's the promise of State Space Models (SSMs). Instead of computing pairwise interactions between all tokens (as attention does), an SSM maintains a fixed-size hidden state that gets updated as each new token arrives. Think of it like a recurrent neural network (RNN) — but with a very specific mathematical structure that makes it trainable on long sequences. The key idea: compress the entire history into a state vector $h \in \mathbb{R}^N$, then update that state in $O(N)$ time per step, where $N$ is the state dimension (typically 16–64), not the sequence length.
At the boundary: if the state dimension $N = 1$, we have a single number summarising all of history — hopelessly lossy. If $N = n$ (state as large as the sequence), we recover something equivalent to storing the entire sequence, losing the efficiency advantage. The practical regime is $N \ll n$: a compact state (say, 16 dimensions) that captures the essential patterns across millions of tokens.
State Space Models: The Continuous-Time Foundation
The Structured State Space for Sequence Modeling (S4) model (Gu et al., 2022) starts from a classical idea in control theory: a continuous-time state space model. The system has an input signal $x(t)$ (a scalar at each time $t$), a hidden state $h(t) \in \mathbb{R}^N$, and an output signal $y(t)$. The dynamics are governed by four matrices:

$$h'(t) = A h(t) + B x(t)$$
$$y(t) = C h(t) + D x(t)$$
Let's define every symbol. $h(t) \in \mathbb{R}^N$ is the hidden state at time $t$ — a vector that compresses all information from the past. $h'(t)$ is its time derivative (how fast the state is changing). $x(t) \in \mathbb{R}$ is the input at time $t$ (one channel — S4 processes each channel independently). $y(t) \in \mathbb{R}$ is the output. $A \in \mathbb{R}^{N \times N}$ is the state transition matrix — it controls how the hidden state evolves over time. This is the most important matrix: it determines what the model remembers and what it forgets. $B \in \mathbb{R}^{N \times 1}$ is the input projection matrix — how the new input feeds into the state. $C \in \mathbb{R}^{1 \times N}$ is the output projection matrix — how we read the state to produce the output. $D \in \mathbb{R}$ is a skip connection from input to output (often set to zero or treated as a residual).
But we're working with discrete sequences (tokens), not continuous signals. S4 discretises the continuous system using a step size $\Delta > 0$. The standard approach (zero-order hold) converts $(A, B)$ into discrete matrices $(\bar{A}, \bar{B})$:

$$\bar{A} = e^{\Delta A}, \qquad \bar{B} = (\Delta A)^{-1}\left(e^{\Delta A} - I\right)\Delta B$$
In practice, a first-order approximation is often used: $\bar{A} \approx I + \Delta A$ and $\bar{B} \approx \Delta B$. After discretisation, the model becomes a simple recurrence:

$$h_k = \bar{A} h_{k-1} + \bar{B} x_k, \qquad y_k = C h_k + D x_k$$
This looks exactly like an RNN! At each step $k$, we take the previous state $h_{k-1}$, multiply by $\bar{A}$ (the discrete state transition), add the input $x_k$ scaled by $\bar{B}$, and produce the next state $h_k$. Reading out with $C$ gives the output. Processing a sequence of length $n$ takes $O(n \cdot N^2)$ time in recurrent mode — linear in sequence length.
But there's a crucial trick that makes S4 trainable (unlike vanilla RNNs). Because $\bar{A}$, $\bar{B}$, $C$, and $D$ are fixed for all time steps (they don't depend on the input), we can unroll the recurrence into a convolution. The output at step $k$ is:

$$y_k = \sum_{j=0}^{k} C \bar{A}^{\,k-j} \bar{B}\, x_j = \sum_{j=0}^{k} K_{k-j}\, x_j$$
The kernel $K_j = C \bar{A}^j \bar{B}$ can be precomputed, and then the entire output sequence is a convolution $y = K * x$. Convolutions can be computed in $O(n \log n)$ via FFT — massively parallelisable on GPUs, unlike the sequential recurrence. So S4 has two modes : a convolutional mode for parallel training ($O(n \log n)$) and a recurrent mode for efficient autoregressive inference ($O(n)$).
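The equivalence of the two modes can be checked in a few lines of plain Python. This is a toy sketch (made-up values, diagonal $\bar{A}$, direct summation instead of FFT), not how a real S4 layer computes its kernel:

```python
# Toy check: with FIXED discrete parameters, the recurrent and
# convolutional views of an SSM produce identical outputs.
A_bar = [[0.95, 0.0], [0.0, 0.97]]   # diagonal state transition (illustrative)
B_bar = [0.10, 0.05]                  # input projection (N x 1, flattened)
C = [1.0, 1.0]                        # output projection (1 x N)
x = [0.0, 1.0, 0.5, 0.0, 0.0, 2.0, 0.0, 0.0]
n = len(x)

# Mode 1: recurrence  h_k = A_bar h_{k-1} + B_bar x_k,  y_k = C h_k
h = [0.0, 0.0]
y_rec = []
for k in range(n):
    h = [A_bar[0][0] * h[0] + B_bar[0] * x[k],
         A_bar[1][1] * h[1] + B_bar[1] * x[k]]
    y_rec.append(C[0] * h[0] + C[1] * h[1])

# Mode 2: convolution with precomputed kernel K_j = C A_bar^j B_bar
K = []
Aj_B = B_bar[:]                       # holds A_bar^j @ B_bar
for j in range(n):
    K.append(C[0] * Aj_B[0] + C[1] * Aj_B[1])
    Aj_B = [A_bar[0][0] * Aj_B[0], A_bar[1][1] * Aj_B[1]]

y_conv = [sum(K[k - j] * x[j] for j in range(k + 1)) for k in range(n)]

assert all(abs(a - b) < 1e-9 for a, b in zip(y_rec, y_conv))
print("recurrent and convolutional outputs match")
```

The direct summation here is $O(n^2)$; the $O(n \log n)$ claim comes from evaluating the same convolution with an FFT, which only works because $K$ is the same at every position.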
The final piece of S4 is HiPPO initialisation (Gu et al., 2020). Random initialisation of $A$ leads to exploding or vanishing states over long sequences (the same problem that plagues RNNs). HiPPO (High-order Polynomial Projection Operators) initialises $A$ to a specific matrix that continuously projects the input history onto a basis of Legendre polynomials. Intuitively, the state $h(t)$ stores a compressed polynomial approximation of everything the model has seen so far. This gives S4 the ability to remember information over tens of thousands of steps — something vanilla RNNs and even LSTMs struggle with.
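To make this concrete, here is a sketch of the HiPPO-LegS transition matrix as written in the HiPPO paper (sign conventions vary between papers; the entries here are written so that the dynamics are stable, and real S4 implementations work with a structured decomposition of this matrix rather than the dense form):

```python
import math

def hippo_legs(N):
    """HiPPO-LegS transition matrix (entries per Gu et al., 2020)."""
    A = [[0.0] * N for _ in range(N)]
    for n in range(N):
        for k in range(N):
            if n > k:
                A[n][k] = -math.sqrt((2 * n + 1) * (2 * k + 1))
            elif n == k:
                A[n][k] = -(n + 1)
            # entries above the diagonal stay zero
    return A

A = hippo_legs(4)
for row in A:
    print(["%7.3f" % v for v in row])
# Lower triangular with a negative diagonal: each state dimension decays
# rather than explodes, while off-diagonal terms mix information across
# polynomial orders -- this is what keeps long-range memory stable.
```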
import json, js
import math
# Demonstrate the S4 recurrence for a tiny example
# State dim N=2, sequence length=8
N = 2  # state dimension
# Simplified discrete parameters (not real HiPPO, just illustrative)
delta = 0.1
# A_bar is a decay matrix (diagonal for simplicity):
# dimension 0 decays faster (factor 0.95), dimension 1 slower (factor 0.97)
A_bar = [[1 - delta * 0.5, 0], [0, 1 - delta * 0.3]]
B_bar = [[delta * 1.0], [delta * 0.5]]
C = [1.0, 1.0]
D = 0.0
# Input sequence
x = [0, 0, 1, 0, 0, 0, 0, 0]  # impulse at step 2
# Run recurrence
h = [0.0, 0.0]
rows = []
for k in range(len(x)):
    # h_k = A_bar @ h_{k-1} + B_bar * x_k
    h_new = [
        A_bar[0][0] * h[0] + A_bar[0][1] * h[1] + B_bar[0][0] * x[k],
        A_bar[1][0] * h[0] + A_bar[1][1] * h[1] + B_bar[1][0] * x[k],
    ]
    y_k = C[0] * h_new[0] + C[1] * h_new[1] + D * x[k]
    rows.append([
        str(k),
        str(x[k]),
        f"[{h_new[0]:.4f}, {h_new[1]:.4f}]",
        f"{y_k:.4f}",
    ])
    h = h_new
js.window.py_table_data = json.dumps({
    "headers": ["Step k", "Input x_k", "State h_k", "Output y_k"],
    "rows": rows,
})
print("S4 recurrence with N=2 state, impulse at step 2")
print("Notice: the state decays gradually, so the output 'remembers' the impulse")
print("The slowly decaying dimension (factor 0.97) retains more than the fast one (factor 0.95)")
Mamba: Making SSMs Data-Dependent
S4 has a fundamental limitation: the matrices $A$, $B$, $C$, and $\Delta$ are the same for every input token. Whether the model sees the word "important" or the word "the", it applies the same state transition. This means S4 cannot selectively focus on or ignore specific inputs — the compression is content-agnostic. To see why this matters, consider a simple task: given a sequence like "key: 42 ... noise ... noise ... query: key", the model must retrieve 42. An S4 model processes "42" and "noise" with the same $B$ matrix, so both get written into the state equally. It has no mechanism to say "this token is important, write it more strongly" or "this is noise, skip it."
Mamba (Gu & Dao, 2023) solves this by making the SSM parameters input-dependent. Specifically, Mamba makes $B$, $C$, and $\Delta$ functions of the current input $x_k$:

$$B_k = \text{Linear}_B(x_k), \qquad C_k = \text{Linear}_C(x_k), \qquad \Delta_k = \text{softplus}(\text{Linear}_\Delta(x_k))$$
where each $\text{Linear}$ is a learned projection from the input dimension to the appropriate shape. The softplus on $\Delta_k$ ensures it's always positive (since $\Delta$ is a time step, it must be $> 0$). The matrix $A$ remains fixed (initialised with HiPPO) — making it input-dependent would break the structured properties that enable efficient computation.
Why is $\Delta_k$ so important? Recall the discretised state transition: $\bar{A}_k = e^{\Delta_k A}$. When $\Delta_k$ is large, $\bar{A}_k$ decays more (the state "forgets" more of the past) and $\bar{B}_k$ is larger (the new input is written more strongly). When $\Delta_k$ is small, $\bar{A}_k \approx I$ (the state is preserved as-is) and $\bar{B}_k$ is small (the input is mostly ignored). So $\Delta_k$ acts as a selective gate: the model learns to set $\Delta_k$ large for important tokens ("pay attention, update the state") and small for irrelevant tokens ("skip this, keep the state").
At the boundaries: if $\Delta_k \to 0$ for every token, $\bar{A}_k \to I$ and $\bar{B}_k \to 0$ — the state never changes and the model ignores all input. If $\Delta_k \to \infty$ for every token, $\bar{A}_k \to 0$ and the state is wiped clean at every step — the model has no memory at all, only seeing the current token. The learned $\Delta_k$ values sit between these extremes, and crucially, they vary per token.
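A scalar example makes the gating behaviour visible. Here $A = -1$ is an arbitrary stable value (not a trained parameter), and the exact zero-order-hold formulas are applied in their scalar form:

```python
import math

# Effect of the step size on the discretised parameters, scalar case.
# ZOH: A_bar = e^{delta * A},  B_bar = (e^{delta * A} - 1) / A * B
A = -1.0
B = 1.0
for delta in [0.001, 0.1, 1.0, 10.0]:
    A_bar = math.exp(delta * A)
    B_bar = (A_bar - 1.0) / A * B
    print(f"delta={delta:6.3f}  A_bar={A_bar:.4f}  B_bar={B_bar:.4f}")
# Small delta: A_bar near 1 (state preserved), B_bar near 0 (input ignored).
# Large delta: A_bar near 0 (state reset), B_bar near 1 (input dominates).
```

The printed values interpolate smoothly between the two extremes described above, which is exactly what lets a learned, per-token $\Delta_k$ act as a soft gate.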
But making the parameters input-dependent breaks the convolution trick! With fixed parameters, the kernel $K_j = C\bar{A}^j\bar{B}$ can be precomputed. With input-dependent $B_k$, $C_k$, $\Delta_k$, the kernel changes at every position. Mamba can't use FFT-based convolutions — so how does it train efficiently?
The answer is the selective scan algorithm, implemented with a hardware-aware approach. The key insight is that the recurrence $h_k = \bar{A}_k h_{k-1} + \bar{B}_k x_k$ is a parallel prefix sum (also called a scan). Just as you can compute cumulative sums in $O(n)$ work with $O(\log n)$ parallel steps, you can compute this linear recurrence in parallel using an associative scan. Mamba implements this as a custom CUDA kernel that:
- Loads the input and parameters from GPU HBM (high bandwidth memory) into fast SRAM
- Computes the discretisation ($\bar{A}_k$, $\bar{B}_k$) in SRAM
- Runs the parallel scan entirely in SRAM
- Writes only the final outputs back to HBM
This avoids the memory bottleneck that would otherwise make the scan slow on GPUs. The result: Mamba trains at speeds comparable to optimised Transformers, while scaling linearly with sequence length. On a 1M-token sequence, attention would need $\sim$1 trillion pairwise interactions. Mamba processes it in $O(n \cdot N \cdot D)$ time, where $N$ is the state dimension (typically 16) and $D$ is the model dimension.
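The scan itself can be illustrated in plain Python with a scalar state. The pair `(a_k, b_k)` below stands for $(\bar{A}_k, \bar{B}_k x_k)$, and `combine` is the associative operator that makes parallel evaluation possible (the loop here is sequential for clarity; the real kernel evaluates the same operator as a tree in SRAM with $O(\log n)$ depth):

```python
# combine((a1, b1), (a2, b2)) composes "h -> a1*h + b1" then "h -> a2*h + b2"
# into a single affine step. This operator is associative, which is what
# allows the recurrence to be computed as a parallel scan.
def combine(s1, s2):
    a1, b1 = s1
    a2, b2 = s2
    return (a1 * a2, a2 * b1 + b2)

def prefix_scan(steps):
    """All-prefix combine (sequential here; a GPU does it in log-depth)."""
    out, acc = [], (1.0, 0.0)          # identity element of `combine`
    for s in steps:
        acc = combine(acc, s)
        out.append(acc)
    return out

# Per-step (a_k, b_k) with values that vary per token, as in Mamba
a = [0.9, 0.5, 0.99, 0.7]
b = [1.0, 0.2, 0.0, 0.3]
steps = list(zip(a, b))

# Reference: plain recurrence from h_{-1} = 0
h, ref = 0.0, []
for k in range(4):
    h = a[k] * h + b[k]
    ref.append(h)

# From the scan, h_k = (prefix a) * h_{-1} + (prefix b) = prefix b here
scanned = [bp for (_, bp) in prefix_scan(steps)]
assert all(abs(p - q) < 1e-12 for p, q in zip(ref, scanned))
print("associative scan reproduces the recurrence")
```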
import json, js
import math
# Compare compute scaling: attention vs Mamba
seq_lengths = [1024, 4096, 16384, 65536, 262144, 1048576]
labels = ["1K", "4K", "16K", "64K", "256K", "1M"]
d_model = 768
n_heads = 12
d_head = d_model // n_heads
N_state = 16  # Mamba state dimension
rows = []
for n, label in zip(seq_lengths, labels):
    # Attention: 2 * n^2 * d (QK^T + softmax @ V, per head, simplified)
    attn_flops = 2 * n * n * d_model
    # Mamba: n * N * D (scan + input projections, simplified)
    mamba_flops = n * N_state * d_model
    ratio = attn_flops / mamba_flops
    rows.append([
        label,
        f"{attn_flops:.2e}",
        f"{mamba_flops:.2e}",
        f"{ratio:.0f}x",
    ])
js.window.py_table_data = json.dumps({
    "headers": ["Seq Length", "Attention FLOPs", "Mamba FLOPs", "Attention / Mamba"],
    "rows": rows,
})
print("Simplified compute comparison (single layer, d_model=768)")
print("Attention scales as O(n^2 * d), Mamba as O(n * N * d)")
print(f"At 1M tokens, attention needs {rows[-1][3]} more compute than Mamba")
Mamba-2 and Structured State Space Duality
If SSMs and attention seem like completely different paradigms, Mamba-2 (Dao & Gu, 2024) reveals a deep mathematical connection between them. The paper introduces the Structured State Space Duality (SSD) framework, which shows that a linear SSM is equivalent to a specific form of attention — and vice versa.
Here's the intuition. Write out the SSM recurrence for every position and collect the outputs into a matrix equation. For a sequence of length $n$, the output $y_k$ depends on all previous inputs $x_0, \ldots, x_k$ through the accumulated state. If you write the matrix $M$ where $M_{k,j}$ gives the weight of input $x_j$ on output $y_k$, you get:

$$y = M x, \qquad M_{k,j} = C_k\, \bar{A}_{k:j+1}\, \bar{B}_j$$
where $\bar{A}_{k:j+1} = \bar{A}_k \bar{A}_{k-1} \cdots \bar{A}_{j+1}$ is the product of all state transitions from step $j+1$ to $k$. This matrix $M$ is lower triangular (causal — outputs only depend on past inputs) and structured (every entry is determined by the SSM parameters, not freely learned).
Now compare this to causal attention. The attention matrix is also lower triangular (the causal mask), and each entry is a function of the query and key at those positions. The SSD insight is that when the attention matrix has the specific structure $M_{k,j} = Q_k^\top S_{k:j+1} K_j$ (where $S$ is a structured mask that decays with distance), this is mathematically identical to a linear SSM. The queries play the role of $C$, the keys play the role of $B$, and the structured mask $S$ encodes the state transitions $\bar{A}$.
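The duality can be checked numerically with a scalar state. The per-token values below are made up; the point is that materialising the structured matrix $M$ and running the recurrence give identical outputs:

```python
# Scalar illustration of the SSD matrix form.
# M[k][j] = C_k * (a_k * ... * a_{j+1}) * B_j, lower triangular (causal).
n = 5
a = [0.9, 0.8, 0.95, 0.7, 0.85]   # per-step transitions (A_bar_k)
B = [0.5, 1.0, 0.2, 0.9, 0.4]     # input-dependent B_k (key-like)
C = [1.0, 0.5, 1.2, 0.8, 1.0]     # input-dependent C_k (query-like)
x = [1.0, 0.0, 2.0, 0.0, 1.0]

M = [[0.0] * n for _ in range(n)]
for k in range(n):
    for j in range(k + 1):
        prod = 1.0
        for i in range(j + 1, k + 1):
            prod *= a[i]               # accumulated transitions A_bar_{k:j+1}
        M[k][j] = C[k] * prod * B[j]

# "Attention-like" formulation: one matrix-vector product
y_matrix = [sum(M[k][j] * x[j] for j in range(n)) for k in range(n)]

# SSM formulation: h_k = a_k h_{k-1} + B_k x_k,  y_k = C_k h_k
h, y_rec = 0.0, []
for k in range(n):
    h = a[k] * h + B[k] * x[k]
    y_rec.append(C[k] * h)

assert all(abs(p - q) < 1e-12 for p, q in zip(y_matrix, y_rec))
print("matrix (attention-like) and recurrent forms agree")
```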
This duality has a practical consequence: Mamba-2 can be computed using either the SSM recurrence (efficient for long sequences) or a matrix multiplication formulation (efficient on modern hardware with tensor cores). Mamba-2 uses a block decomposition: it splits the sequence into chunks, uses the quadratic (attention-like) formulation within each chunk (to exploit tensor cores), and uses the linear (SSM) recurrence to propagate state between chunks. This hybrid algorithm is 2–8$\times$ faster than Mamba-1's selective scan on modern GPUs.
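A scalar sketch of the chunking idea follows. This is illustrative only: the real algorithm batches these loops into tensor-core matrix multiplications, but the decomposition into a within-chunk quadratic part plus a between-chunk recurrent part is the same.

```python
# Chunked evaluation of a scalar selective SSM (toy values).
n, chunk = 8, 4
a = [0.9] * n                       # per-step transitions
B = [0.5] * n
C = [1.0] * n
x = [1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0]

y = []
h_in = 0.0                          # state carried between chunks
for start in range(0, n, chunk):
    idx = range(start, start + chunk)
    for k in idx:
        # recurrent part: contribution of the carried-in state
        prod = 1.0
        for i in range(start, k + 1):
            prod *= a[i]
        yk = C[k] * prod * h_in
        # quadratic part: within-chunk, "attention-like" contributions
        for j in range(start, k + 1):
            p = 1.0
            for i in range(j + 1, k + 1):
                p *= a[i]
            yk += C[k] * p * B[j] * x[j]
        y.append(yk)
    # advance the carried state past this chunk
    for k in idx:
        h_in = a[k] * h_in + B[k] * x[k]

# Reference: plain recurrence over the full sequence
h, ref = 0.0, []
for k in range(n):
    h = a[k] * h + B[k] * x[k]
    ref.append(C[k] * h)

assert all(abs(p - q) < 1e-12 for p, q in zip(y, ref))
print("chunked computation matches the full recurrence")
```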
Mamba-2 also simplifies the architecture. Where Mamba-1 used a state dimension $N$ that was independent of the head dimension, Mamba-2 introduces multi-head SSM (analogous to multi-head attention), where each head has its own $A$, $B$, $C$ matrices. The head structure matches the tensor core layout, further improving hardware utilisation.
SSMs vs Attention: Strengths and Weaknesses
With SSMs like Mamba and attention-based Transformers as the two main paradigms, when should you use which? Neither dominates everywhere. Let's compare them on the dimensions that matter for production systems.
Compute scaling. Attention is $O(n^2)$ in sequence length. Even with FlashAttention (which reduces memory to $O(n)$), the compute remains quadratic. SSMs are $O(n)$ — at 1M tokens, the per-layer compute gap runs to several orders of magnitude. For very long sequences (books, codebases, genomics), this is the deciding factor.
Memory during inference. Attention-based models need a KV cache that grows linearly with sequence length (and is expensive — the LLaMA 3 70B example from article 6 showed 42 GB for 128K tokens even with GQA). SSMs need only the fixed-size state $h \in \mathbb{R}^N$ — typically a few kilobytes per layer. At 1M tokens, the KV cache for a Transformer might be hundreds of gigabytes; the SSM state stays at a few megabytes.
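A back-of-envelope comparison makes the gap concrete. The configuration below (80 layers, 8 KV heads of dimension 128 under GQA, fp16, and an assumed per-layer SSM state of $d_{\text{model}} \times N$) is a hypothetical 70B-class setup chosen for illustration, not a measured profile of any specific model:

```python
# Hypothetical 70B-class config (illustrative assumptions, not measurements)
n_layers = 80
n_kv_heads = 8          # GQA
d_head = 128
bytes_per = 2           # fp16

d_model = 8192
N_state = 16            # SSM state dimension

for n_tokens in [128_000, 1_000_000]:
    # KV cache: 2 tensors (K and V) per layer, grows with sequence length
    kv_cache = 2 * n_layers * n_kv_heads * d_head * n_tokens * bytes_per
    # SSM state: fixed size, independent of sequence length
    ssm_state = n_layers * d_model * N_state * bytes_per
    print(f"{n_tokens:>9} tokens: KV cache {kv_cache / 1e9:6.1f} GB, "
          f"SSM state {ssm_state / 1e6:6.1f} MB")
```

With these assumptions, the KV cache lands around 42 GB at 128K tokens (matching the figure quoted above) and in the hundreds of gigabytes at 1M tokens, while the SSM state stays in the tens of megabytes regardless of length.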
In-context learning. This is where attention shines. Attention can perform exact token-level lookup: given "The capital of France is [blank]", attention can directly attend to the token "Paris" wherever it appeared in the context. SSMs compress all of history into a fixed-size state, which means exact retrieval of a specific token from thousands of steps ago is difficult. Empirically, Mamba models underperform Transformers on tasks that require precise copying or retrieval from context (e.g., associative recall, phone book lookup).
Training throughput. During training, attention benefits from massive parallelism (all pairs can be computed simultaneously). S4's convolution mode and Mamba-2's block decomposition achieve competitive training throughput, but attention + FlashAttention on short-to-medium sequences (up to ~8K) is hard to beat due to years of hardware and software optimisation. On sequences beyond 16K, SSMs start to pull ahead.
import json, js
# Summary comparison table
rows = [
    ["Compute scaling", "O(n^2)", "O(n)", "SSM"],
    ["Memory (inference)", "O(n) KV cache", "O(1) state", "SSM"],
    ["In-context learning", "Excellent", "Limited", "Attention"],
    ["Precise recall", "Exact (attends to any token)", "Approximate (compressed state)", "Attention"],
    ["Training (short seq)", "Highly optimised", "Competitive", "Attention"],
    ["Training (long seq)", "Quadratic cost", "Linear cost", "SSM"],
    ["Hardware maturity", "Years of optimisation", "Newer, catching up", "Attention"],
    ["Streaming / real-time", "Must recompute or cache", "Natural (recurrent)", "SSM"],
]
js.window.py_table_data = json.dumps({
    "headers": ["Dimension", "Attention", "SSM (Mamba)", "Advantage"],
    "rows": rows,
})
print("Neither architecture dominates across all dimensions")
print("Attention wins at precise recall and in-context learning")
print("SSMs win at long-sequence efficiency and streaming inference")
The takeaway: SSMs are not a drop-in replacement for attention. They excel in different regimes. For short-to-medium contexts (up to ~8K tokens) where in-context learning and precise recall matter, attention-based Transformers remain superior. For very long sequences (64K+) where the quadratic cost becomes prohibitive, or for streaming applications where maintaining a KV cache is impractical, SSMs offer a compelling alternative. And as we'll see in the next article, the most promising direction may be combining both.
Quiz
Test your understanding of state space models, S4, and Mamba.
In the S4 model, what is the role of HiPPO initialisation for the matrix $A$?
What is Mamba's key innovation over S4?
In Mamba, what happens when the learned step size $\Delta_k$ is very small for a given token?
What does the Structured State Space Duality (SSD) framework in Mamba-2 reveal?