¿Qué Son los Queries, Keys y Values?

El artículo anterior terminó con una pregunta: si eliminamos el RNN y dependemos solo de la atención, ¿cómo decide cada posición qué otras posiciones son importantes? La atención de Bahdanau usaba el estado oculto del decoder para consultar al encoder, pero en un transformer no hay estados ocultos construidos a través de recurrencia. En su lugar, cada posición en la secuencia obtiene tres representaciones diferentes — un query , un key y un value — cada uno producido por una proyección lineal aprendida separada del mismo embedding de entrada.

Dada una matriz de entrada $X \in \mathbb{R}^{T \times d_{\text{model}}}$ (donde $T$ es la longitud de la secuencia y $d_{\text{model}}$ es la dimensión del embedding), calculamos:

$$Q = XW^Q, \quad K = XW^K, \quad V = XW^V$$

donde $W^Q, W^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$ y $W^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$. Cada fila de $Q$ es un vector query para una posición, cada fila de $K$ es un vector key, y cada fila de $V$ es un vector value.

La intuición detrás de esta separación merece reflexión, porque es fácil pasarla por alto. Un query codifica lo que una posición está buscando (representa la pregunta "¿qué otras posiciones son relevantes para mí?"). Un key codifica lo que una posición contiene (representa la respuesta "esto es lo que tengo para ofrecer"). Un value codifica qué información lleva realmente una posición (es la carga útil que se transmite una vez que se establece la relevancia). El query y el key interactúan para determinar cuánta atención prestar, y el value determina qué información se transfiere.

¿Por qué tres proyecciones separadas en lugar de usar simplemente los embeddings crudos? Porque el mismo token podría necesitar ser encontrado por queries muy diferentes dependiendo del contexto. La palabra "banco" en "banco del río" y "cuenta bancaria" probablemente debería tener keys similares cuando un query es sobre ubicaciones pero keys diferentes cuando un query es sobre finanzas. Tener proyecciones aprendidas separadas permite al modelo crear representaciones especializadas para cada rol, y el espacio de queries, el espacio de keys y el espacio de values pueden capturar cada uno diferentes aspectos del significado que sirven para diferentes propósitos en el cálculo de atención.

💡 Piénsalo como una biblioteca. El query es el término de búsqueda que escribimos. El key es la etiqueta de metadatos de cada libro (título, tema, palabras clave). El value es el contenido real del libro. Comparamos nuestro término de búsqueda con los metadatos para encontrar libros relevantes, y luego leemos el contenido de los que coincidieron. Separar el "contra qué comparar" del "qué recuperar" le da al sistema una flexibilidad que una sola representación no puede proporcionar.

¿Por Qué Productos Punto, y Por Qué Escalarlos?

Una vez que tenemos queries y keys, necesitamos una forma de medir qué tan bien coinciden. La opción más simple es un producto punto: para el query $q_i$ (una fila de $Q$) y el key $k_j$ (una fila de $K$), el puntaje de atención crudo es $q_i \cdot k_j = \sum_{m=1}^{d_k} q_{im} k_{jm}$. Dos vectores apuntando en direcciones similares producen un producto punto positivo grande; vectores ortogonales producen cero; vectores opuestos producen un valor negativo grande. Esto es rápido, no requiere parámetros y captura la señal de similitud que necesitamos.

Calcular todos los productos punto por pares de una vez nos da la matriz completa de puntajes de atención. Para $T$ tokens con queries y keys de $d_k$ dimensiones:

$$S = QK^\top \in \mathbb{R}^{T \times T}$$

La entrada $S_{ij}$ nos dice cuánto coincide el query de la posición $i$ con el key de la posición $j$. Pero estos puntajes crudos tienen un problema que se vuelve severo a medida que $d_k$ crece. Si las entradas de $Q$ y $K$ se extraen aproximadamente de una distribución normal estándar (media 0, varianza 1), cada producto punto es una suma de $d_k$ productos independientes, por lo que la varianza de $S_{ij}$ escala linealmente con $d_k$:

$$\text{Var}(q_i \cdot k_j) = \text{Var}\!\left(\sum_{m=1}^{d_k} q_{im} k_{jm}\right) = d_k$$

Con $d_k = 64$ (típico para una sola cabeza de atención), los productos punto tienen una desviación estándar de $\sqrt{64} = 8$. Con $d_k = 512$, la desviación estándar sube a aproximadamente 22.6. Productos punto más grandes en valor absoluto significan que el softmax en el siguiente paso empujará casi toda su masa de probabilidad hacia una o dos entradas, produciendo pesos de atención que son casi one-hot. ¿Por qué es eso malo? Porque los gradientes del softmax en el régimen saturado son extremadamente pequeños (cercanos a cero para casi todas las entradas). El entrenamiento se ralentiza drásticamente o se detiene por completo, porque el modelo no puede ajustar a qué posiciones atender.

La solución de Vaswani et al. (2017) es dividir por $\sqrt{d_k}$ antes de aplicar softmax:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

Examinemos qué hace cada parte de esta fórmula considerando qué se rompería si la elimináramos.

Sin el escalado $\frac{1}{\sqrt{d_k}}$: como acabamos de discutir, los productos punto crecen con la dimensión, el softmax se satura y los gradientes se desvanecen. El modelo efectivamente toma decisiones de atención duras desde el inicio del entrenamiento, antes de haber aprendido algo útil, y luego no puede actualizar esas decisiones porque los gradientes son demasiado pequeños.

Sin el softmax: los pesos de atención no sumarían 1, por lo que la salida sería una combinación lineal con coeficientes arbitrarios (posiblemente negativos). El softmax asegura que los pesos formen una distribución de probabilidad válida, lo que significa que la salida es una combinación convexa de vectores value, que se encuentra dentro del casco convexo de los values, manteniendo la escala de salida acotada e interpretable como "cuánta atención prestar a cada posición".

Sin la multiplicación por $V$: tendríamos pesos de atención pero nada a lo que atender realmente . Los pesos nos dicen qué posiciones importan, pero los values llevan la información real. Sin $V$, la salida de la capa de atención sería la propia matriz de pesos (una matriz $T \times T$ de escalares), no una secuencia de vectores de $d_v$ dimensiones que puede alimentar la siguiente capa.

💡 El factor $\sqrt{d_k}$ normaliza la varianza de los productos punto de vuelta a aproximadamente 1, independientemente de $d_k$. Después de dividir, $\text{Var}(S_{ij} / \sqrt{d_k}) = d_k / d_k = 1$. Esto mantiene las entradas del softmax en un rango moderado donde la distribución tiene entropía (puede distribuir probabilidad entre múltiples posiciones) y los gradientes fluyen.

Para ver esto concretamente, consideremos el caso extremo donde $d_k = 1$. Entonces cada "producto punto" es simplemente el producto de dos escalares. La varianza es 1, las entradas del softmax son moderadas y todo funciona bien, así que no necesitamos escalado. Ahora aumentemos a $d_k = 512$: sin escalado, un producto punto típico podría ser $\pm 22$, y $\text{softmax}([22, -3, 1, -5])$ da casi $[1, 0, 0, 0]$ (toda la masa en una entrada, sin gradiente útil para las demás). Dividir por $\sqrt{512} \approx 22.6$ nos devuelve a valores moderados alrededor de $\pm 1$, donde el softmax produce una distribución suave y el aprendizaje puede proceder.

Cómo la Suma Ponderada Produce la Salida

Después del softmax, tenemos una matriz $A = \text{softmax}(QK^\top / \sqrt{d_k}) \in \mathbb{R}^{T \times T}$, donde cada fila $A_i$ es una distribución de probabilidad sobre las $T$ posiciones. El paso final es multiplicar esto por la matriz de values $V \in \mathbb{R}^{T \times d_v}$:

$$\text{Output}_i = \sum_{j=1}^{T} A_{ij} \, V_j$$

Cada posición de salida $i$ es un promedio ponderado de todos los vectores value, donde los pesos provienen de qué tan bien el query de la posición $i$ coincidió con el key de cada posición. Si la posición $i$ atiende fuertemente a la posición 3 ($A_{i3}$ es grande) y débilmente a todo lo demás, entonces $\text{Output}_i \approx V_3$ (la salida en la posición $i$ es aproximadamente el vector value de la posición 3). Si la atención se distribuye uniformemente, la salida es la media de todos los vectores value, lo que tiende a producir una representación borrosa y menos útil.

Este es el mecanismo completo, y podemos implementarlo en unas pocas líneas de Python. El código a continuación calcula la atención de producto punto escalado de una sola cabeza desde cero usando solo NumPy, para que podamos ver exactamente qué sucede en cada paso.

import numpy as np

np.random.seed(42)

# Dimensions
T = 4       # sequence length (4 tokens)
d_model = 8 # embedding dimension
d_k = 4     # query/key dimension
d_v = 4     # value dimension

# Random input embeddings (T tokens, each d_model-dimensional)
X = np.random.randn(T, d_model)

# Learned projection matrices (in practice, these are trained)
W_Q = np.random.randn(d_model, d_k) * 0.1
W_K = np.random.randn(d_model, d_k) * 0.1
W_V = np.random.randn(d_model, d_v) * 0.1

# Project inputs to queries, keys, values
Q = X @ W_Q  # (T, d_k)
K = X @ W_K  # (T, d_k)
V = X @ W_V  # (T, d_v)

# Raw attention scores
scores = Q @ K.T  # (T, T)
print("Raw scores (before scaling):")
print(np.round(scores, 3))
print(f"\nScore std dev: {scores.std():.3f} (expected ~sqrt(d_k)={np.sqrt(d_k):.3f})")

# Scaled scores
scaled_scores = scores / np.sqrt(d_k)
print(f"\nScaled scores std dev: {scaled_scores.std():.3f} (expected ~1.0)")

# Softmax (row-wise)
def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))  # subtract max for stability
    return e / e.sum(axis=-1, keepdims=True)

attn_weights = softmax(scaled_scores)
print("\nAttention weights (each row sums to 1):")
print(np.round(attn_weights, 3))
print("Row sums:", np.round(attn_weights.sum(axis=1), 6))

# Weighted sum of values
output = attn_weights @ V  # (T, d_v)
print("\nOutput (T x d_v):")
print(np.round(output, 3))

Observa algunas cosas en la salida. Los puntajes crudos tienen una desviación estándar cercana a $\sqrt{d_k} = 2.0$, que crecería más con un $d_k$ mayor. Después de dividir por $\sqrt{d_k}$, la desviación estándar baja nuevamente hacia 1.0. Cada fila de los pesos de atención suma exactamente 1 (una distribución de probabilidad válida), y la salida final tiene la misma forma que $V$ (un vector de $d_v$ dimensiones por posición).

💡 En el código anterior, las matrices de pesos $W^Q$, $W^K$, $W^V$ se inicializan con valores aleatorios pequeños (escalados por 0.1). En la práctica, una inicialización cuidadosa (como Xavier/Glorot) asegura que las proyecciones produzcan salidas con varianza unitaria desde el inicio, lo cual es importante para un entrenamiento estable.

De Una Cabeza a Muchas

Una sola cabeza de atención calcula un conjunto de pesos de atención (una forma para que cada posición mire a todas las demás posiciones). Pero un token podría necesitar atender a diferentes posiciones por diferentes razones: un patrón de atención podría capturar dependencias sintácticas ("¿a qué sustantivo modifica este adjetivo?") mientras que otro captura relaciones semánticas ("¿a qué token anterior se refiere este pronombre?"). Una sola cabeza debe comprimir todas estas necesidades en un solo conjunto de pesos.

La atención multi-cabeza (Vaswani et al., 2017) resuelve esto ejecutando $h$ cabezas de atención en paralelo, cada una con sus propias matrices de proyección $W_i^Q, W_i^K, W_i^V$, y luego concatenando y proyectando los resultados:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) \, W^O$$
$$\text{where } \text{head}_i = \text{Attention}(XW_i^Q, XW_i^K, XW_i^V)$$

Cada cabeza usa $d_k = d_{\text{model}} / h$, por lo que el cálculo total es aproximadamente el mismo que una sola cabeza con la dimensión completa $d_{\text{model}}$. Con 8 cabezas y $d_{\text{model}} = 512$, cada cabeza opera en un subespacio de 64 dimensiones. La proyección de salida $W^O \in \mathbb{R}^{hd_v \times d_{\text{model}}}$ mezcla las salidas de las cabezas de vuelta al espacio de representación del modelo.

¿Qué sucedería en los extremos? Si $h = 1$, recuperamos la atención de una sola cabeza: un conjunto de proyecciones Q, K, V y un patrón de atención por capa. El modelo aún puede aprender, pero cada capa solo puede calcular un promedio ponderado de values, forzándolo a comprimir todo en un solo patrón de atención. Si $h = d_{\text{model}}$, cada cabeza opera en un subespacio de 1 dimensión, por lo que cada producto punto $q^\top k$ es simplemente un escalar por un escalar. Las cabezas se vuelven demasiado estrechas para capturar relaciones significativas. La elección estándar de $h = 8$ o $h = 16$ se sitúa entre estos extremos, dando a cada cabeza suficientes dimensiones para aprender patrones útiles mientras proporciona suficientes cabezas para la especialización.

En la práctica, diferentes cabezas a menudo aprenden a atender a cosas diferentes. Clark et al. (2019) encontraron que en BERT, algunas cabezas se especializan en atender al token anterior o siguiente, otras rastrean dependencias sintácticas como la concordancia sujeto-verbo, y algunas atienden ampliamente a toda la oración. Esta especialización emergente es parte de lo que hace que la atención multi-cabeza sea tan efectiva: aprende un conjunto diverso de patrones de atención sin que se le diga qué buscar.

Ahora tenemos el panorama completo de cómo funcionan los puntajes de atención: proyectar a Q, K, V, calcular productos punto escalados, aplicar softmax, tomar una suma ponderada de values, y hacer esto múltiples veces en paralelo con cabezas separadas. El siguiente artículo aborda una restricción importante: ¿qué sucede cuando el modelo está generando texto y no debe mirar tokens futuros? Ahí es donde entra el enmascaramiento causal, y la matriz de atención adopta una estructura muy específica.

Cuestionario

Pon a prueba tu comprensión de la atención de producto punto escalado.

¿Por qué dividimos los puntajes del producto punto por √d_k antes de aplicar softmax?

¿Qué papel juega la matriz Value V en el cálculo de atención?

Si d_k = 256 y las entradas de query/key son aproximadamente normales estándar, ¿cuál es la desviación estándar aproximada de los productos punto crudos (sin escalar)?

¿Por qué la atención multi-cabeza usa h cabezas separadas con d_k = d_model / h en lugar de una cabeza con d_k = d_model?