¿Por Qué Reemplazar la Atención?
Todos los mecanismos de atención que hemos visto hasta ahora — completa, dispersa, ventana deslizante — comparten un costo fundamental: cada token debe interactuar con otros tokens para decidir qué es relevante. La atención completa es $O(n^2)$ tanto en cómputo como en memoria. La atención dispersa (artículo 3) reduce el factor constante o elimina algunas interacciones, pero sigue siendo fundamentalmente cuadrática o, en el mejor de los casos, $O(n \sqrt{n})$. La atención de ventana deslizante es $O(n \cdot w)$, lineal en $n$ pero solo porque renuncia al acceso directo de largo alcance. ¿Qué pasaría si pudiéramos construir un modelo de secuencias que sea $O(n)$ desde el principio — procesando cada token en tiempo constante con respecto a la longitud de la secuencia — mientras aún captura dependencias de largo alcance?
Esa es la promesa de los State Space Models (SSMs) . En lugar de calcular interacciones por pares entre todos los tokens (como hace la atención), un SSM mantiene un estado oculto de tamaño fijo que se actualiza a medida que llega cada nuevo token. Piénsalo como una red neuronal recurrente (RNN) — pero con una estructura matemática muy específica que lo hace entrenable en secuencias largas. La idea clave: comprimir todo el historial en un vector de estado $h \in \mathbb{R}^N$, luego actualizar ese estado en tiempo $O(N)$ por paso, donde $N$ es la dimensión del estado (típicamente 16–64), no la longitud de la secuencia.
En los extremos: si la dimensión del estado $N = 1$, tenemos un solo número resumiendo todo el historial — irremediablemente con pérdida. Si $N = n$ (estado tan grande como la secuencia), recuperamos algo equivalente a almacenar la secuencia completa, perdiendo la ventaja de eficiencia. El régimen práctico es $N \ll n$: un estado compacto (digamos, 16 dimensiones) que captura los patrones esenciales a lo largo de millones de tokens.
State Space Models: La Base en Tiempo Continuo
El modelo Structured State Space for Sequence Modeling (S4) (Gu et al., 2022) parte de una idea clásica de la teoría de control: un state space model en tiempo continuo . El sistema tiene una señal de entrada $x(t)$ (un escalar en cada tiempo $t$), un estado oculto $h(t) \in \mathbb{R}^N$, y una señal de salida $y(t)$. La dinámica está gobernada por cuatro matrices:
Definamos cada símbolo. $h(t) \in \mathbb{R}^N$ es el estado oculto en el tiempo $t$ — un vector que comprime toda la información del pasado. $h'(t)$ es su derivada temporal (qué tan rápido está cambiando el estado). $x(t) \in \mathbb{R}$ es la entrada en el tiempo $t$ (un canal — S4 procesa cada canal independientemente). $y(t) \in \mathbb{R}$ es la salida. $A \in \mathbb{R}^{N \times N}$ es la matriz de transición de estado — controla cómo evoluciona el estado oculto con el tiempo. Esta es la matriz más importante: determina qué recuerda el modelo y qué olvida. $B \in \mathbb{R}^{N \times 1}$ es la matriz de proyección de entrada — cómo la nueva entrada alimenta al estado. $C \in \mathbb{R}^{1 \times N}$ es la matriz de proyección de salida — cómo leemos el estado para producir la salida. $D \in \mathbb{R}$ es una conexión residual de la entrada a la salida (frecuentemente se establece en cero o se trata como un residual).
Pero trabajamos con secuencias discretas (tokens), no con señales continuas. S4 discretiza el sistema continuo usando un paso de tiempo $\Delta > 0$. El enfoque estándar (retención de orden cero) convierte $(A, B)$ en matrices discretas $(\bar{A}, \bar{B})$:
En la práctica, frecuentemente se usa una aproximación de primer orden: $\bar{A} \approx I + \Delta A$ y $\bar{B} \approx \Delta B$. Después de la discretización, el modelo se convierte en una recurrencia simple:
¡Esto luce exactamente como una RNN! En cada paso $k$, tomamos el estado anterior $h_{k-1}$, multiplicamos por $\bar{A}$ (la transición de estado discreta), sumamos la entrada $x_k$ escalada por $\bar{B}$, y producimos el siguiente estado $h_k$. La lectura con $C$ da la salida. Procesar una secuencia de longitud $n$ toma tiempo $O(n \cdot N^2)$ en modo recurrente — lineal en la longitud de la secuencia.
Pero hay un truco crucial que hace a S4 entrenable (a diferencia de las RNNs convencionales). Como $\bar{A}$, $\bar{B}$, $C$ y $D$ son fijos para todos los pasos de tiempo (no dependen de la entrada), podemos desenrollar la recurrencia en una convolución. La salida en el paso $k$ es:
El kernel $K_j = C \bar{A}^j \bar{B}$ puede precalcularse, y luego toda la secuencia de salida es una convolución $y = K * x$. Las convoluciones pueden calcularse en $O(n \log n)$ vía FFT — masivamente paralelizable en GPUs, a diferencia de la recurrencia secuencial. Así que S4 tiene dos modos : un modo convolucional para entrenamiento paralelo ($O(n \log n)$) y un modo recurrente para inferencia autoregresiva eficiente ($O(n)$).
La pieza final de S4 es la inicialización HiPPO (Gu et al., 2020) . La inicialización aleatoria de $A$ conduce a estados que explotan o se desvanecen en secuencias largas (el mismo problema que afecta a las RNNs). HiPPO (High-order Polynomial Projection Operators) inicializa $A$ con una matriz específica que proyecta continuamente el historial de entrada sobre una base de polinomios de Legendre. Intuitivamente, el estado $h(t)$ almacena una aproximación polinómica comprimida de todo lo que el modelo ha visto hasta ahora. Esto le da a S4 la capacidad de recordar información a lo largo de decenas de miles de pasos — algo con lo que las RNNs convencionales e incluso las LSTMs tienen dificultades.
import json, js
import math
# Demonstrate the S4 recurrence for a tiny example
# State dim N=2, sequence length=8
N = 2 # state dimension
# Simplified discrete parameters (not real HiPPO, just illustrative)
delta = 0.1
# A is a decay matrix (diagonal for simplicity)
A_bar = [[1 - delta * 0.5, 0], [0, 1 - delta * 0.3]] # slow + fast decay
B_bar = [[delta * 1.0], [delta * 0.5]]
C = [1.0, 1.0]
D = 0.0
# Input sequence
x = [0, 0, 1, 0, 0, 0, 0, 0] # impulse at step 2
# Run recurrence
h = [0.0, 0.0]
rows = []
for k in range(len(x)):
# h_k = A_bar @ h_{k-1} + B_bar * x_k
h_new = [
A_bar[0][0] * h[0] + A_bar[0][1] * h[1] + B_bar[0][0] * x[k],
A_bar[1][0] * h[0] + A_bar[1][1] * h[1] + B_bar[1][0] * x[k],
]
y_k = C[0] * h_new[0] + C[1] * h_new[1] + D * x[k]
rows.append([
str(k),
str(x[k]),
f"[{h_new[0]:.4f}, {h_new[1]:.4f}]",
f"{y_k:.4f}"
])
h = h_new
js.window.py_table_data = json.dumps({
"headers": ["Paso k", "Entrada x_k", "Estado h_k", "Salida y_k"],
"rows": rows
})
print("Recurrencia S4 con estado N=2, impulso en el paso 2")
print("Observa: el estado decae gradualmente, así que la salida 'recuerda' el impulso")
print("La dimensión lenta (decaimiento 0.95) retiene más que la rápida (decaimiento 0.97)")
Mamba: Haciendo los SSMs Dependientes de los Datos
S4 tiene una limitación fundamental: las matrices $A$, $B$, $C$ y $\Delta$ son las mismas para cada token de entrada . Ya sea que el modelo vea la palabra "importante" o la palabra "el", aplica la misma transición de estado. Esto significa que S4 no puede enfocarse selectivamente en o ignorar entradas específicas — la compresión es agnóstica al contenido. Para ver por qué esto importa, considera una tarea simple: dada una secuencia como "clave: 42 ... ruido ... ruido ... consulta: clave", el modelo debe recuperar 42. Un modelo S4 procesa "42" y "ruido" con la misma matriz $B$, así que ambos se escriben en el estado por igual. No tiene mecanismo para decir "este token es importante, escríbelo con más fuerza" o "esto es ruido, sáltalo."
Mamba (Gu & Dao, 2023) resuelve esto haciendo los parámetros del SSM dependientes de la entrada . Específicamente, Mamba hace que $B$, $C$ y $\Delta$ sean funciones de la entrada actual $x_k$:
donde cada $\text{Linear}$ es una proyección aprendida de la dimensión de entrada a la forma apropiada. El softplus en $\Delta_k$ asegura que siempre sea positivo (ya que $\Delta$ es un paso de tiempo, debe ser $> 0$). La matriz $A$ permanece fija (inicializada con HiPPO) — hacerla dependiente de la entrada rompería las propiedades estructuradas que permiten el cálculo eficiente.
¿Por qué es $\Delta_k$ tan importante? Recuerda la transición de estado discretizada: $\bar{A}_k = e^{\Delta_k A}$. Cuando $\Delta_k$ es grande, $\bar{A}_k$ decae más (el estado "olvida" más del pasado) y $\bar{B}_k$ es mayor (la nueva entrada se escribe con más fuerza). Cuando $\Delta_k$ es pequeño, $\bar{A}_k \approx I$ (el estado se preserva tal cual) y $\bar{B}_k$ es pequeño (la entrada se ignora en su mayoría). Así que $\Delta_k$ actúa como una compuerta selectiva : el modelo aprende a establecer $\Delta_k$ grande para tokens importantes ("presta atención, actualiza el estado") y pequeño para tokens irrelevantes ("sáltalo, mantén el estado").
En los extremos: si $\Delta_k \to 0$ para cada token, $\bar{A}_k \to I$ y $\bar{B}_k \to 0$ — el estado nunca cambia y el modelo ignora toda entrada. Si $\Delta_k \to \infty$ para cada token, $\bar{A}_k \to 0$ y el estado se borra por completo en cada paso — el modelo no tiene memoria alguna, solo ve el token actual. Los valores aprendidos de $\Delta_k$ se sitúan entre estos extremos, y crucialmente, varían por token.
¡Pero hacer los parámetros dependientes de la entrada rompe el truco de convolución! Con parámetros fijos, el kernel $K_j = C\bar{A}^j\bar{B}$ puede precalcularse. Con $B_k$, $C_k$, $\Delta_k$ dependientes de la entrada, el kernel cambia en cada posición. Mamba no puede usar convoluciones basadas en FFT — entonces ¿cómo entrena eficientemente?
La respuesta es el algoritmo de escaneo selectivo , implementado con un enfoque consciente del hardware. La idea clave es que la recurrencia $h_k = \bar{A}_k h_{k-1} + \bar{B}_k x_k$ es una suma de prefijos paralela (también llamada escaneo). Así como puedes calcular sumas acumuladas en $O(n)$ de trabajo con $O(\log n)$ pasos paralelos, puedes calcular esta recurrencia lineal en paralelo usando un escaneo asociativo. Mamba implementa esto como un kernel CUDA personalizado que:
- Carga la entrada y los parámetros desde la HBM (memoria de alto ancho de banda) de la GPU a la SRAM rápida
- Calcula la discretización ($\bar{A}_k$, $\bar{B}_k$) en SRAM
- Ejecuta el escaneo paralelo completamente en SRAM
- Escribe solo las salidas finales de vuelta a la HBM
Esto evita el cuello de botella de memoria que de otro modo haría lento el escaneo en GPUs. El resultado: Mamba entrena a velocidades comparables con Transformers optimizados, mientras escala linealmente con la longitud de la secuencia. En una secuencia de 1M de tokens, la atención necesitaría $\sim$1 billón de interacciones por pares. Mamba la procesa en tiempo $O(n \cdot N \cdot D)$, donde $N$ es la dimensión del estado (típicamente 16) y $D$ es la dimensión del modelo.
import json, js
import math
# Compare compute scaling: attention vs Mamba
seq_lengths = [1024, 4096, 16384, 65536, 262144, 1048576]
labels = ["1K", "4K", "16K", "64K", "256K", "1M"]
d_model = 768
n_heads = 12
d_head = d_model // n_heads
N_state = 16 # Mamba state dimension
rows = []
for n, label in zip(seq_lengths, labels):
# Attention: 2 * n^2 * d (QK^T + softmax @ V, per head, simplified)
attn_flops = 2 * n * n * d_model
# Mamba: n * N * D (scan + input projections, simplified)
mamba_flops = n * N_state * d_model
ratio = attn_flops / mamba_flops
rows.append([
label,
f"{attn_flops:.2e}",
f"{mamba_flops:.2e}",
f"{ratio:.0f}x"
])
js.window.py_table_data = json.dumps({
"headers": ["Long. Secuencia", "FLOPs Atención", "FLOPs Mamba", "Atención / Mamba"],
"rows": rows
})
print("Comparación simplificada de cómputo (una capa, d_model=768)")
print("La atención escala como O(n^2 * d), Mamba como O(n * N * d)")
print(f"Con 1M de tokens, la atención necesita {rows[-1][3]} más cómputo que Mamba")
Mamba-2 y la Dualidad Estructurada de Espacios de Estado
Si los SSMs y la atención parecen paradigmas completamente diferentes, Mamba-2 (Dao & Gu, 2024) revela una conexión matemática profunda entre ellos. El artículo introduce el marco de Structured State Space Duality (SSD) , que muestra que un SSM lineal es equivalente a una forma específica de atención — y viceversa.
Aquí está la intuición. Escribe la recurrencia del SSM para cada posición y recopila las salidas en una ecuación matricial. Para una secuencia de longitud $n$, la salida $y_k$ depende de todas las entradas anteriores $x_0, \ldots, x_k$ a través del estado acumulado. Si escribes la matriz $M$ donde $M_{k,j}$ da el peso de la entrada $x_j$ sobre la salida $y_k$, obtienes:
donde $\bar{A}_{k:j+1} = \bar{A}_k \bar{A}_{k-1} \cdots \bar{A}_{j+1}$ es el producto de todas las transiciones de estado del paso $j+1$ al $k$. Esta matriz $M$ es triangular inferior (causal — las salidas solo dependen de entradas pasadas) y estructurada (cada entrada está determinada por los parámetros del SSM, no se aprende libremente).
Ahora compara esto con la atención causal. La matriz de atención también es triangular inferior (la máscara causal), y cada entrada es una función del query y key en esas posiciones. La idea de SSD es que cuando la matriz de atención tiene la estructura específica $M_{k,j} = Q_k^\top S_{k:j+1} K_j$ (donde $S$ es una máscara estructurada que decae con la distancia), esto es matemáticamente idéntico a un SSM lineal. Los queries juegan el rol de $C$, los keys juegan el rol de $B$, y la máscara estructurada $S$ codifica las transiciones de estado $\bar{A}$.
Esta dualidad tiene una consecuencia práctica: Mamba-2 puede calcularse usando cualquiera de los dos : la recurrencia del SSM (eficiente para secuencias largas) o una formulación de multiplicación matricial (eficiente en hardware moderno con tensor cores). Mamba-2 usa una descomposición por bloques : divide la secuencia en fragmentos, usa la formulación cuadrática (tipo atención) dentro de cada fragmento (para aprovechar los tensor cores), y usa la recurrencia lineal (SSM) para propagar el estado entre fragmentos. Este algoritmo híbrido es 2–8$\times$ más rápido que el escaneo selectivo de Mamba-1 en GPUs modernas.
Mamba-2 también simplifica la arquitectura. Donde Mamba-1 usaba una dimensión de estado $N$ independiente de la dimensión de cabeza, Mamba-2 introduce SSM multi-cabeza (análogo a la atención multi-cabeza), donde cada cabeza tiene sus propias matrices $A$, $B$, $C$. La estructura de cabezas coincide con la disposición de los tensor cores, mejorando aún más la utilización del hardware.
SSMs vs Atención: Fortalezas y Debilidades
Con SSMs como Mamba y Transformers basados en atención como los dos paradigmas principales, ¿cuándo deberías usar cuál? Ninguno domina en todas partes. Comparémoslos en las dimensiones que importan para sistemas en producción.
Escalado de cómputo. La atención es $O(n^2)$ en longitud de secuencia. Incluso con FlashAttention (que reduce la memoria a $O(n)$), el cómputo sigue siendo cuadrático. Los SSMs son $O(n)$ — procesar una secuencia de 1M de tokens cuesta 1000$\times$ menos que la atención. Para secuencias muy largas (libros, bases de código, genómica), este es el factor decisivo.
Memoria durante la inferencia. Los modelos basados en atención necesitan una caché KV que crece linealmente con la longitud de la secuencia (y es costosa — el ejemplo de LLaMA 3 70B del artículo 6 mostró 42 GB para 128K tokens incluso con GQA). Los SSMs solo necesitan el estado de tamaño fijo $h \in \mathbb{R}^N$ — típicamente unos pocos kilobytes por capa. Con 1M de tokens, la caché KV de un Transformer podría ser de cientos de gigabytes; el estado del SSM se mantiene en unos pocos megabytes.
Aprendizaje en contexto. Aquí es donde la atención brilla. La atención puede realizar búsqueda exacta a nivel de token : dado "La capital de Francia es [espacio]", la atención puede atender directamente al token "París" donde sea que haya aparecido en el contexto. Los SSMs comprimen todo el historial en un estado de tamaño fijo, lo que significa que la recuperación exacta de un token específico de miles de pasos atrás es difícil. Empíricamente, los modelos Mamba rinden por debajo de los Transformers en tareas que requieren copia precisa o recuperación del contexto (por ejemplo, recuerdo asociativo, búsqueda en directorio telefónico).
Rendimiento de entrenamiento. Durante el entrenamiento, la atención se beneficia de un paralelismo masivo (todos los pares pueden calcularse simultáneamente). El modo de convolución de S4 y la descomposición por bloques de Mamba-2 logran un rendimiento de entrenamiento competitivo, pero la atención + FlashAttention en secuencias cortas a medianas (hasta ~8K) es difícil de superar debido a años de optimización de hardware y software. En secuencias más allá de 16K, los SSMs comienzan a tomar ventaja.
import json, js
# Summary comparison table
rows = [
["Escalado de cómputo", "O(n^2)", "O(n)", "SSM"],
["Memoria (inferencia)", "O(n) caché KV", "O(1) estado", "SSM"],
["Aprendizaje en contexto", "Excelente", "Limitado", "Atención"],
["Recuperación precisa", "Exacta (atiende a cualquier token)", "Aproximada (estado comprimido)", "Atención"],
["Entrenamiento (seq. corta)", "Altamente optimizado", "Competitivo", "Atención"],
["Entrenamiento (seq. larga)", "Costo cuadrático", "Costo lineal", "SSM"],
["Madurez de hardware", "Años de optimización", "Más nuevo, poniéndose al día", "Atención"],
["Streaming / tiempo real", "Debe recalcular o cachear", "Natural (recurrente)", "SSM"],
]
js.window.py_table_data = json.dumps({
"headers": ["Dimensión", "Atención", "SSM (Mamba)", "Ventaja"],
"rows": rows
})
print("Ninguna arquitectura domina en todas las dimensiones")
print("La atención gana en recuperación precisa y aprendizaje en contexto")
print("Los SSMs ganan en eficiencia con secuencias largas e inferencia en streaming")
La conclusión: los SSMs no son un reemplazo directo de la atención. Sobresalen en regímenes diferentes. Para contextos cortos a medianos (hasta ~8K tokens) donde el aprendizaje en contexto y la recuperación precisa importan, los Transformers basados en atención siguen siendo superiores. Para secuencias muy largas (64K+) donde el costo cuadrático se vuelve prohibitivo, o para aplicaciones de streaming donde mantener una caché KV es impráctico, los SSMs ofrecen una alternativa convincente. Y como veremos en el próximo artículo, la dirección más prometedora puede ser combinar ambos .
Quiz
Pon a prueba tu comprensión de los state space models, S4 y Mamba.
En el modelo S4, ¿cuál es el rol de la inicialización HiPPO para la matriz $A$?
¿Cuál es la innovación clave de Mamba sobre S4?
En Mamba, ¿qué sucede cuando el paso de tiempo aprendido $\Delta_k$ es muy pequeño para un token dado?
¿Qué revela el marco de Structured State Space Duality (SSD) en Mamba-2?