Why Would a Single Head Fall Short?

By now we have a complete single-head attention mechanism: queries and keys produce a score matrix, we scale by $\sqrt{d_k}$, apply softmax, and use the result to weight the values. That mechanism works, but it forces every token to summarize everything it needs from the sequence into a single attention pattern. Consider the word "bank" in the sentence "The bank by the river issued a statement." To understand this sentence, "bank" needs to attend to "river" (to resolve that this is a riverbank, not a financial institution) and simultaneously attend to "issued" (to pick up that "bank" is the subject of the verb). But a head produces exactly one probability distribution over positions: the attention weights must sum to 1, so focusing heavily on "river" necessarily takes weight away from "issued."

This is not a contrived edge case. Language constantly requires tracking multiple relationships at once: syntactic dependencies (what is the subject of this verb?), semantic similarity (which words share meaning?), coreference (which noun does this pronoun refer to?), and positional proximity (what came right before this word?). A single attention head compresses all of these into one weighted average, which forces the model to compromise between competing demands.

The fix proposed by Vaswani et al. (2017) is to run several attention heads in parallel, each with its own learned projections, so that different heads can specialize in different types of relationships. The outputs are then concatenated and projected back to the model dimension. This is multi-head attention.

How Does the Model Split Into Multiple Heads?

Multi-head attention splits d_model into H heads (here 4 heads of 64 dimensions each). Each head runs its own Q/K/V projections and attention, then the results are concatenated and projected back to d_model.

Suppose our model dimension is $d_{\text{model}} = 512$ and we want $H = 8$ heads. Rather than giving each head its own full-sized projections (which would multiply the parameter count by 8), we split the representation: each head operates on a slice of size $d_k = d_{\text{model}} / H = 64$. The total computation stays roughly the same as single-head attention over the full $d_{\text{model}}$, because we run $H$ smaller attentions instead of one large one.

Each head $i$ has its own projection matrices $W_i^Q$, $W_i^K$, and $W_i^V$, each of shape $(d_{\text{model}}, d_k)$. Given an input $X$ of shape $(T, d_{\text{model}})$ where $T$ is the sequence length, the projections for head $i$ are computed as follows.

$$Q_i = X W_i^Q, \quad K_i = X W_i^K, \quad V_i = X W_i^V \qquad \text{each of shape } (T, d_k)$$

Each head then runs standard scaled dot-product attention independently, producing an output of shape $(T, d_k)$. Since there are $H$ heads, we get $H$ outputs of shape $(T, d_k)$, which we concatenate along the last dimension to recover shape $(T, d_{\text{model}})$. A final output projection $W^O$ of shape $(d_{\text{model}}, d_{\text{model}})$ mixes information across heads.

$$\text{head}_i = \text{Attention}(Q_i, K_i, V_i) = \text{softmax}\!\left(\frac{Q_i K_i^\top}{\sqrt{d_k}}\right) V_i$$
$$\text{MultiHead}(X) = \text{Concat}(\text{head}_1, \ldots, \text{head}_H) \, W^O$$
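The equations above can be followed literally in a few lines of PyTorch. This is a minimal sketch with one $(d_{\text{model}}, d_k)$ projection triple per head, using random weights purely to check shapes (all variable names here are illustrative, not from a library):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, d_model, H = 10, 512, 8
d_k = d_model // H

X = torch.randn(T, d_model)

# One (d_model, d_k) projection triple per head, as in the equations
W_Q = [torch.randn(d_model, d_k) / d_model**0.5 for _ in range(H)]
W_K = [torch.randn(d_model, d_k) / d_model**0.5 for _ in range(H)]
W_V = [torch.randn(d_model, d_k) / d_model**0.5 for _ in range(H)]
W_O = torch.randn(d_model, d_model) / d_model**0.5

heads = []
for i in range(H):
    Q_i, K_i, V_i = X @ W_Q[i], X @ W_K[i], X @ W_V[i]   # each (T, d_k)
    scores = Q_i @ K_i.T / d_k**0.5                      # (T, T)
    heads.append(F.softmax(scores, dim=-1) @ V_i)        # (T, d_k)

out = torch.cat(heads, dim=-1) @ W_O                     # (T, d_model)
print(out.shape)  # torch.Size([10, 512])
```

The explicit Python loop over heads is for readability only; the implementation later in this section replaces it with batched tensor operations.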

Let's trace the shapes through a concrete example with $d_{\text{model}} = 512$, $H = 8$, $d_k = 64$, and a sequence of $T = 10$ tokens:

  • Input: $X$ is $(10, 512)$.
  • Per-head projection: $Q_i = X W_i^Q$ maps $(10, 512) \times (512, 64) \to (10, 64)$. Same for $K_i$ and $V_i$.
  • Per-head attention scores: $Q_i K_i^\top$ maps $(10, 64) \times (64, 10) \to (10, 10)$. Each head produces its own $T \times T$ attention matrix.
  • Per-head output: $\text{softmax}(\ldots) \, V_i$ maps $(10, 10) \times (10, 64) \to (10, 64)$.
  • Concatenation: stack all 8 heads along the last axis: $(10, 64) \times 8 \to (10, 512)$.
  • Output projection: $W^O$ maps $(10, 512) \times (512, 512) \to (10, 512)$.

The output has exactly the same shape as the input, $(T, d_{\text{model}})$, which means multi-head attention is a drop-in replacement for single-head attention. This shape-preserving property also matters for stacking layers, since the output of one attention block becomes the input to the next.

💡 The output projection $W^O$ is important because concatenation alone just pastes the head outputs side by side. Without $W^O$, each dimension in the output would come from exactly one head. The projection lets information from different heads mix, so any output dimension can be a learned combination of all heads' contributions.
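A quick toy check makes this concrete (illustrative code, not part of the model): perturb one head's output and see which output dimensions change before and after a random projection.

```python
import torch

torch.manual_seed(0)
T, H, d_k = 10, 8, 64
d_model = H * d_k

heads = [torch.randn(T, d_k) for _ in range(H)]
W_O = torch.randn(d_model, d_model)

concat = torch.cat(heads, dim=-1)

# Perturb only head 2, leave the other heads alone
heads[2] = heads[2] + 1.0
concat2 = torch.cat(heads, dim=-1)

# Before the projection: only head 2's slice of dimensions moves
changed = (concat != concat2).any(dim=0)
print(changed.sum().item())            # 64 dims changed (the head-2 slice)

# After the projection: the perturbation spreads across all dimensions
changed_after = ((concat @ W_O) != (concat2 @ W_O)).any(dim=0)
print(changed_after.sum().item())      # 512 dims changed
```

Before $W^O$, the change is confined to dimensions 128–191 (head 2's slice); after it, every output dimension reflects the perturbation.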

What Does This Look Like in Code?

In practice, we don't loop over heads one by one. Instead, we project the full $d_{\text{model}}$ for all heads at once using a single weight matrix of shape $(d_{\text{model}}, d_{\text{model}})$, then reshape the result to separate the heads. This is mathematically identical to having $H$ separate $(d_{\text{model}}, d_k)$ projections, but it runs in one matrix multiply instead of $H$, which is much faster on GPUs.
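This equivalence is easy to verify directly. In the sketch below (names like W_big are hypothetical), the $H$ per-head projection matrices are simply column slices of one big $(d_{\text{model}}, d_{\text{model}})$ matrix:

```python
import torch

torch.manual_seed(0)
T, d_model, H = 10, 512, 8
d_k = d_model // H

X = torch.randn(T, d_model)
W_big = torch.randn(d_model, d_model)   # one (512, 512) projection for all heads

# Batched: one matmul, then reshape to (H, T, d_k)
Q_all = (X @ W_big).view(T, H, d_k).transpose(0, 1)

# Per-head: H separate (512, 64) projections taken as column slices of W_big
for i in range(H):
    Q_i = X @ W_big[:, i * d_k : (i + 1) * d_k]
    assert torch.allclose(Q_all[i], Q_i, atol=1e-5)

print("one big projection matches 8 per-head projections")
```

The same argument applies to the key and value projections, which is why the implementation below uses three full-sized nn.Linear layers rather than $3H$ small ones.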

The following implementation keeps things explicit so we can see each step clearly. We project all heads at once, reshape to split them, run attention per head, concatenate, and apply the output projection.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # One big projection for all heads, then we reshape
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, X, mask=None):
        B, T, _ = X.shape  # batch, sequence length, d_model

        # Project all heads at once: (B, T, d_model) -> (B, T, d_model)
        Q = self.W_Q(X)
        K = self.W_K(X)
        V = self.W_V(X)

        # Reshape to (B, num_heads, T, d_k) so each head has its own slice
        Q = Q.view(B, T, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(B, T, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(B, T, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention per head
        # scores: (B, num_heads, T, T)
        scores = (Q @ K.transpose(-2, -1)) / (self.d_k ** 0.5)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)

        # Weighted sum of values: (B, num_heads, T, d_k)
        head_outputs = attn_weights @ V

        # Concatenate heads: (B, T, d_model)
        concat = head_outputs.transpose(1, 2).contiguous().view(B, T, self.d_model)

        # Final projection mixes information across heads
        return self.W_O(concat)

# --- Quick shape check ---
d_model, num_heads, T, batch = 512, 8, 10, 2
mha = MultiHeadAttention(d_model, num_heads)
X = torch.randn(batch, T, d_model)
output = mha(X)

print(f"Input shape:  {X.shape}")      # (2, 10, 512)
print(f"Output shape: {output.shape}")  # (2, 10, 512)
print(f"Heads: {num_heads}, head dim: {d_model // num_heads}")
print(f"Parameters: {sum(p.numel() for p in mha.parameters()):,}")

A few things to notice in the code. The view and transpose calls are the key operations: view(B, T, num_heads, d_k) splits the last dimension into separate heads, and transpose(1, 2) moves the head dimension before the sequence dimension so that the batch matrix multiply Q @ K.transpose(-2, -1) runs attention independently for each head in parallel. After attention, we reverse the operation with transpose(1, 2).contiguous().view(B, T, d_model) to concatenate the heads back together.
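To make the reshape mechanics concrete, here is a standalone trace of that round trip (toy tensors, illustrative only):

```python
import torch

B, T, H, d_k = 2, 10, 8, 64
d_model = H * d_k

x = torch.randn(B, T, d_model)

split = x.view(B, T, H, d_k)          # (2, 10, 8, 64): last dim cut into heads
per_head = split.transpose(1, 2)      # (2, 8, 10, 64): head axis before sequence

# Reverse: move the head axis back, then merge heads into d_model
merged = per_head.transpose(1, 2).contiguous().view(B, T, d_model)

print(per_head.shape)                 # torch.Size([2, 8, 10, 64])
print(torch.equal(merged, x))         # True: the round trip is lossless
```

The contiguous() call is needed because transpose returns a non-contiguous view of the underlying storage, and view requires contiguous memory.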

📌 The total parameter count is $4 \times d_{\text{model}}^2$ (one $(d_{\text{model}}, d_{\text{model}})$ matrix each for $W^Q$, $W^K$, $W^V$, and $W^O$). This is independent of the number of heads, because the head count only changes how we partition the dimensions, not the total number of weights.
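The arithmetic behind this note is worth checking once by hand: $H$ heads contribute $3H$ projection matrices of size $d_{\text{model}} \times d_k$, plus $W^O$, and the total is the same for every head count.

```python
d_model = 512

for H in (1, 2, 8, 16):
    d_k = d_model // H
    per_head = 3 * H * d_model * d_k      # W_i^Q, W_i^K, W_i^V across all heads
    total = per_head + d_model * d_model  # plus the output projection W^O
    assert total == 4 * d_model**2
    print(H, total)  # every head count gives 1048576
```

Since $H \cdot d_k = d_{\text{model}}$, the per-head term always simplifies to $3 d_{\text{model}}^2$, matching the parameter count printed by the shape check above.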

What Do Different Heads Actually Learn?

Nothing in the architecture forces heads to specialize, yet empirically they tend to. Several studies have analyzed what patterns emerge in trained transformers, and the findings are remarkably consistent.

Clark et al. (2019) analyzed BERT's attention heads and found that individual heads often learn interpretable roles. Some heads track syntactic dependencies (the head consistently attends from a verb to its subject, regardless of distance), while others focus on positional patterns (attending to the previous token or the next token). Voita et al. (2019) ("Analyzing Multi-Head Self-Attention") went further and identified three dominant head types in English-to-Russian translation models: positional heads (attending to an adjacent position), syntactic heads (attending along dependency parse edges), and rare-word heads (attending to the least frequent tokens in the sentence, which tend to carry the most disambiguating information).

This specialization is why multi-head attention outperforms simply increasing $d_k$ in a single head. A larger single head has more capacity, but it still produces one attention pattern per token. Multiple heads allow the model to route different types of information through different channels simultaneously, and the output projection $W^O$ learns to combine them. In graph-theoretic terms (recalling our earlier framing of the attention matrix as an adjacency matrix), each head defines a different graph over the same set of tokens, and the model reads all of these graphs at once.

💡 Voita et al. also found that many heads in trained models can be pruned without significant performance loss, suggesting that not all heads end up learning useful patterns. In a 6-layer, 8-head transformer, they pruned more than half the heads with minimal degradation. The surviving heads were predominantly the positional and syntactic ones.

Quiz

Test your understanding of multi-head attention and how it extends the single-head mechanism.

If d_model = 768 and we use 12 attention heads, what is the dimension of each head's query, key, and value vectors?

Why does multi-head attention use an output projection W_O after concatenating the heads?

What is the main advantage of multi-head attention over single-head attention with the same d_model?

How does the total parameter count of multi-head attention compare to single-head attention (assuming the same d_model)?