Lo Mejor de Ambos Mundos
El artículo anterior dejó claro el compromiso: la atención sobresale en recuperación precisa y aprendizaje en contexto , mientras que los SSMs sobresalen en eficiencia de largo alcance y streaming . Los modelos Mamba puros tienen dificultades en tareas que requieren copia exacta o recuperación del contexto — si le pides a un SSM puro que repita un número de teléfono que vio hace 50,000 tokens, a menudo falla porque esa información específica fue comprimida en el estado de tamaño fijo. Los modelos de atención pura manejan esto fácilmente (simplemente atienden al token), pero su costo $O(n^2)$ hace que los contextos de millones de tokens sean prohibitivamente caros.
¿Qué pasaría si pudiéramos usar atención para las capas donde la recuperación precisa importa y SSMs para las capas donde la compresión de contexto de largo alcance es suficiente? Esa es la idea central detrás de las arquitecturas híbridas : apilar capas de atención y SSM en un solo modelo, dejando que cada tipo de capa haga lo que mejor sabe hacer. Las capas de atención proporcionan la memoria "nítida" — acceso directo token a token para contexto reciente o información crítica. Las capas SSM proporcionan la memoria "suave" — compresión eficiente de patrones y tendencias de largo alcance.
La pregunta de diseño clave es la proporción : ¿cuántas capas SSM versus capas de atención? En un extremo (toda atención), obtenemos un Transformer estándar. En el otro extremo (todo SSM), obtenemos Mamba puro. Entre estos extremos hay un espectro de híbridos, y la proporción óptima depende del caso de uso. Más capas de atención mejoran la recuperación a costa de la eficiencia en secuencias largas. Más capas SSM mejoran la eficiencia a costa de la recuperación precisa. Como veremos, los modelos de producción han convergido en proporciones alrededor de 6:1 a 8:1 (SSM-a-atención), usando apenas suficiente atención para manejar tareas que requieren mucha recuperación mientras mantienen el costo general cercano a lineal.
Jamba: Mamba + Atención + MoE
El primer modelo híbrido a gran escala que demostró que mezclar capas de atención y SSM funciona en la práctica es Jamba (Lieber et al., 2024) , un modelo de 52 mil millones de parámetros de AI21 Labs que soporta una ventana de contexto de 256K tokens . Jamba combina tres ideas arquitectónicas: capas SSM de Mamba, capas de atención estándar y capas feed-forward de Mixture of Experts (MoE).
La arquitectura de Jamba está construida a partir de bloques repetitivos. Cada bloque contiene una capa Mamba o una capa de atención, seguida de una red feed-forward (que puede o no usar MoE). Los bloques están dispuestos en un patrón específico: por cada 1 capa de atención, hay 7 capas Mamba — una proporción de 7:1. Concretamente, el modelo tiene 4 capas de atención de un total de 32 capas de mezcla de secuencia. Las capas de atención están espaciadas uniformemente (aproximadamente cada 8 capas), de modo que la información pasa a través de varias capas Mamba (comprimiendo contexto de largo alcance) antes de llegar a una capa de atención (que puede realizar recuperación precisa y operaciones en contexto).
¿Por qué funciona la proporción 7:1? Considera lo que contribuye cada tipo de capa. Las capas Mamba procesan la mayor parte de la secuencia con costo $O(n)$, manteniendo y actualizando representaciones de estado comprimidas. Las capas de atención, que aparecen cada 8 capas, actúan como "puntos de control" donde el modelo puede comparar tokens directamente y realizar el tipo de búsqueda precisa con la que los SSMs tienen dificultades. Cuatro capas de atención son suficientes para la mayoría de las tareas de recuperación porque el modelo no necesita realizar coincidencia exacta de tokens en cada capa — solo lo necesita en unos pocos puntos estratégicos.
El componente MoE aborda una preocupación diferente: la capacidad del modelo . Algunas de las capas feed-forward usan 16 expertos con enrutamiento top-2, lo que significa que solo 2 de 16 expertos se activan por token. Esto le da a Jamba la capacidad de un modelo de 52B parámetros mientras usa solo ~12B parámetros activos por token. Las capas MoE aparecen en un subconjunto de bloques (cada dos bloques), manteniendo la sobrecarga computacional manejable.
Los resultados son impactantes. Con contexto de 256K, Jamba cabe en una sola GPU de 80 GB — algo que sería imposible para un modelo de atención pura de 52B (solo el KV cache excedería la memoria). Esto se debe a que las 28 capas Mamba no contribuyen ningún KV cache en absoluto — solo las 4 capas de atención necesitan almacenamiento KV. El KV cache es así aproximadamente $4/32 = 12.5\%$ de lo que requeriría un modelo de atención completa. En el límite: con 0 capas de atención (Mamba puro), el KV cache es cero pero la recuperación se resiente. Con 32 capas de atención (Transformer puro), la recuperación es perfecta pero el KV cache de 256K no cabe en memoria.
import json, js
# KV cache comparison: Jamba (4 attn layers) vs full Transformer (32 attn layers)
n_layers_total = 32
n_attn_jamba = 4
d_model = 4096
n_kv_heads = 8 # GQA
d_head = 128
bytes_per_param = 2 # FP16
seq_lengths = [4096, 32768, 131072, 262144]
labels = ["4K", "32K", "128K", "256K"]
rows = []
for n, label in zip(seq_lengths, labels):
# Full Transformer: all 32 layers have KV cache
full_kv = 2 * n_layers_total * n_kv_heads * d_head * n * bytes_per_param
full_gb = full_kv / 1e9
# Jamba: only 4 layers have KV cache
jamba_kv = 2 * n_attn_jamba * n_kv_heads * d_head * n * bytes_per_param
jamba_gb = jamba_kv / 1e9
saving = (1 - jamba_kv / full_kv) * 100
rows.append([label, f"{full_gb:.1f} GB", f"{jamba_gb:.1f} GB", f"{saving:.0f}%"])
js.window.py_table_data = json.dumps({
"headers": ["Seq Length", "Full Transformer KV", "Jamba KV (4 attn layers)", "KV Memory Saved"],
"rows": rows
})
print(f"Config: {n_layers_total} layers, {n_kv_heads} KV heads, d_head={d_head}, FP16")
print(f"Jamba uses only {n_attn_jamba}/{n_layers_total} attention layers")
print(f"At 256K tokens: full Transformer needs {rows[-1][1]}, Jamba needs {rows[-1][2]}")
Zamba y StripedHyena
Jamba no es el único enfoque híbrido. Otros equipos han explorado diferentes formas de particionar capas de atención y SSM, cada uno haciendo diferentes compromisos.
Zamba (Glorioso et al., 2024) de Zyphra adopta un enfoque más agresivo: usa una sola capa de atención compartida intercalada entre muchas capas Mamba. En lugar de distribuir múltiples capas de atención a lo largo de la red (como hace Jamba), Zamba coloca una capa de atención y comparte sus pesos en múltiples posiciones de la red. El modelo pasa por este bloque de atención compartido a intervalos regulares, de modo que aparece en múltiples puntos durante el pase forward pero usa solo un conjunto de parámetros de atención. Esto reduce aún más el conteo de parámetros y el KV cache (ya que solo hay una capa de atención de keys y values para almacenar en caché) mientras sigue proporcionando la capacidad de recuperación que le falta a Mamba puro.
StripedHyena (Together AI, 2023) adopta un enfoque diferente a la cuestión híbrida. En lugar de usar Mamba, alterna entre capas Hyena (una alternativa basada en convolución a la atención con escalado sub-cuadrático) y capas de atención estándar. El patrón "a rayas" — alternando Hyena y atención — le da al modelo tanto la eficiencia de largo alcance de la mezcla de secuencias sub-cuadrática como el aprendizaje en contexto preciso de la atención. StripedHyena-7B igualó el rendimiento de LLaMA 2 7B en benchmarks estándar mientras ofrecía mejor escalado en secuencias más largas.
Estos modelos comparten una lección común: el mecanismo específico no-atención importa menos que el principio de mezcla . Ya sea que uses Mamba, Hyena, RWKV u otra capa sub-cuadrática, el enfoque híbrido de combinarlo con unas pocas capas de atención supera consistentemente a cualquier enfoque puro. Las capas de atención actúan como una "columna vertebral de precisión" — sin ellas, todos estos modelos muestran degradación en tareas que requieren mucha recuperación. Con incluso una pequeña fracción de capas de atención (1 de cada 6 a 1 de cada 8), el rendimiento de recuperación se recupera a niveles cercanos al Transformer.
Cuándo Usar Qué
Con tres paradigmas disponibles — atención pura, SSM puro e híbrido — ¿cómo eliges? La respuesta depende de tu longitud de secuencia , requisitos de la tarea y restricciones de despliegue . Aquí hay una comparación práctica:
import json, js
rows = [
[
"Pure Attention (GPT, LLaMA)",
"O(n^2)",
"Short-medium (up to ~32K)",
"Excellent",
"General-purpose LLMs, chat, code generation"
],
[
"Pure SSM (Mamba)",
"O(n)",
"Very long (256K+)",
"Limited",
"Genomics, long audio, streaming, on-device"
],
[
"Hybrid (Jamba, Zamba)",
"~O(n)",
"Long (128K-256K+)",
"Good",
"Production LLMs needing long context + recall"
],
]
js.window.py_table_data = json.dumps({
"headers": ["Architecture", "Compute", "Best Context Range", "In-Context Recall", "Best Use Cases"],
"rows": rows
})
print("Architecture selection guide")
print("Pure attention: best quality, worst scaling")
print("Pure SSM: best scaling, weakest recall")
print("Hybrid: practical sweet spot for most production needs")
Repasemos la lógica de decisión para escenarios comunes:
- Chatbot con conversaciones cortas (< 8K tokens): Atención pura (Transformer). El costo cuadrático es insignificante a esta longitud, y obtienes el mejor aprendizaje en contexto. No hay razón para añadir la complejidad de SSM.
- Analizar una base de código completa o un documento legal extenso (64K-256K tokens): Arquitectura híbrida. Necesitas tanto comprensión de largo alcance (las capas SSM manejan esto económicamente) como la capacidad de referenciar secciones específicas con precisión (las capas de atención manejan esto). Los modelos estilo Jamba están diseñados exactamente para este régimen.
- Procesar secuencias genómicas o datos de sensores continuos (1M+ tokens): SSM puro o híbrido agresivo (1-2 capas de atención). La secuencia es demasiado larga para cualquier cantidad significativa de atención, y la tarea típicamente requiere reconocimiento de patrones en rangos largos en lugar de recuperación precisa de tokens individuales.
- Inferencia en dispositivo o streaming: SSM puro. El estado de tamaño fijo (sin KV cache creciente) es ideal para entornos con memoria limitada, y el modo recurrente procesa naturalmente un token a la vez sin necesidad de almacenar tokens pasados.
La tendencia es clara: a medida que las longitudes de contexto crecen y los modelos se despliegan en entornos más diversos, las arquitecturas híbridas se están convirtiendo en el estándar de producción . Los Transformers de atención pura dominaron porque los contextos eran cortos (2K-8K tokens) y el hardware estaba optimizado para multiplicaciones de matrices. A medida que avanzamos hacia 256K y más allá, el costo cuadrático se vuelve insostenible, y el enfoque híbrido — usando SSMs para la mayor parte del procesamiento y atención para las capas de precisión — ofrece el mejor compromiso entre calidad y costo.
El framework SSD de Mamba-2 (artículo anterior) sugiere que esta convergencia puede profundizarse aún más: si los SSMs y la atención son matemáticamente duales, las arquitecturas futuras podrían ni siquiera tener "capas de atención" y "capas SSM" discretas — podrían tener capas que interpolan suavemente entre los dos modos dependiendo de la entrada y la posición en la secuencia.
Quiz
Pon a prueba tu comprensión de las arquitecturas híbridas que combinan atención y SSMs.
¿Por qué los modelos SSM puros (como Mamba) tienen dificultades con tareas que requieren recuperación precisa del contexto?
Jamba usa una proporción de 7:1 de capas Mamba-a-atención (28 capas Mamba y 4 capas de atención). ¿Cuál es el beneficio principal de esta proporción para la inferencia de contexto largo?
¿Cómo difiere el enfoque de Zamba para la arquitectura híbrida del de Jamba?
¿Para qué escenario seguiría siendo un Transformer de atención pura la mejor opción sobre un modelo híbrido o SSM?