¿Cómo Aprende un Modelo el Lenguaje Sin Etiquetas?
En el artículo anterior construimos un transformer decoder-only y lo entrenamos para invertir secuencias cortas. Esa tarea de juguete usó unos pocos miles de parámetros y un conjunto de datos sintético (nada que se parezca al lenguaje real). El salto de un modelo de juguete a algo como GPT-4 o BERT usa los mismos bloques arquitectónicos, pero el entrenamiento cambia completamente. La primera etapa, el pre-training , expone al modelo a enormes cantidades de texto crudo con un objetivo auto-supervisado que no requiere etiquetas humanas. El modelo aprende sintaxis, hechos, patrones de razonamiento y conocimiento del mundo puramente prediciendo texto.
Hay dos objetivos de pre-training dominantes, y se corresponden directamente con las dos arquitecturas que estudiamos en el artículo 8. El modelado de lenguaje causal (CLM) se usa para modelos decoder-only como GPT (Radford et al., 2018) . El modelo lee tokens de izquierda a derecha y, en cada posición, predice el siguiente token. La pérdida es el promedio de la log-verosimilitud negativa sobre todas las posiciones:
Esta es exactamente la pérdida que usamos en el artículo 9 para la tarea de inversión, solo que aplicada al lenguaje natural a escala. Cada oración en el corpus de entrenamiento es un ejemplo de entrenamiento gratuito (no se necesita anotación). La máscara causal asegura que el modelo no pueda ver tokens futuros, por lo que cada posición proporciona una tarea de predicción genuina.
El modelado de lenguaje enmascarado (MLM) se usa para modelos encoder como BERT (Devlin et al., 2019) . En lugar de predecir el siguiente token, el modelo recibe una secuencia donde aproximadamente el 15% de los tokens han sido reemplazados con un token [MASK], y predice la identidad original de cada posición enmascarada. Debido a que los encoders tienen atención bidireccional (sin máscara causal), el modelo puede usar contexto de ambos lados de una posición enmascarada, lo que tiende a producir representaciones más ricas para tareas de clasificación posteriores.
donde $\mathcal{M}$ es el conjunto de posiciones enmascaradas y $x_{\setminus \mathcal{M}}$ denota la secuencia con esas posiciones reemplazadas. La pérdida se calcula solo sobre las posiciones enmascaradas, no sobre toda la secuencia, lo que significa que cada ejemplo de entrenamiento proporciona menos señales de gradiente que CLM (aproximadamente el 15% de los tokens frente al 100%). Esta es una razón por la que los modelos encoder típicamente necesitan más datos de entrenamiento o epochs para alcanzar una saturación comparable.
El condicionamiento sobre $x_{\setminus \mathcal{M}}$ (los tokens no enmascarados) es lo que le da al MLM su carácter bidireccional: el modelo ve tokens a ambos lados del hueco al predecir qué pertenece en él. Si enmascaráramos el 100% de los tokens, el modelo no tendría contexto alguno y la pérdida se reduciría a predecir tokens desde nada (esencialmente un modelo de lenguaje unigrama). Si enmascaráramos el 0%, no habría señal de entrenamiento. La tasa de enmascaramiento del 15% es un compromiso: suficiente contexto para que el modelo haga predicciones informadas, suficientes posiciones enmascaradas para producir un gradiente útil. Devlin et al. probaron otras tasas y encontraron que el 15% funciona bien empíricamente, aunque trabajos posteriores como SpanBERT mostraron que enmascarar tramos contiguos en lugar de tokens individuales aleatorios puede mejorar el rendimiento en tareas posteriores.
T5 (Raffel et al., 2020) introdujo una tercera variante para modelos encoder-decoder: la corrupción de tramos , donde tramos contiguos de tokens se reemplazan con un único token centinela y el decoder genera los tramos faltantes. Esto une los dos enfoques: el encoder ve contexto bidireccional corrupto, y el decoder genera las piezas faltantes de forma autoregresiva.
¿Qué Datos y Cómputo Requiere el Pre-training?
La auto-supervisión significa que podemos usar cualquier texto como datos de entrenamiento, lo que desplaza el cuello de botella de la anotación a la recolección de datos y el cómputo. Los corpus de pre-training modernos son masivos. Common Crawl contiene petabytes de datos de rastreo web acumulados desde 2008. The Pile (Gao et al., 2020) curó 825 GB de texto diverso en inglés de 22 fuentes incluyendo libros, artículos académicos, código de GitHub y Stack Exchange. RedPajama (Together, 2023) replicó y extendió la receta de datos de entrenamiento de LLaMA con más de 1,2 billones de tokens de web, libros, Wikipedia, código y fuentes académicas.
Pero más datos por sí solos no son suficientes, porque también necesitamos suficientes parámetros para absorberlos y suficiente cómputo para ejecutar la optimización. Estas tres cantidades (datos, parámetros, cómputo) están estrechamente vinculadas. Kaplan et al. (2020) caracterizó por primera vez esta relación, mostrando que la pérdida sigue leyes de potencia suaves como función del tamaño del modelo, el tamaño del dataset y el presupuesto de cómputo. Sin embargo, su análisis sugería que los modelos deberían escalarse más rápido que los datos, lo que llevaba a modelos grandes entrenados con relativamente pocos tokens.
Hoffmann et al. (2022) revisitó esta cuestión con Chinchilla y llegó a una conclusión diferente: para un presupuesto de cómputo fijo, el enfoque óptimo es escalar los parámetros del modelo y los tokens de entrenamiento aproximadamente al mismo ritmo. Específicamente, encontraron que el número de tokens de entrenamiento debería ser aproximadamente 20 veces el número de parámetros. Un modelo de 10 mil millones de parámetros, según esta estimación, debería ver aproximadamente 200 mil millones de tokens para aprovechar al máximo el cómputo disponible.
El impacto práctico fue inmediato. Antes de Chinchilla, modelos como Gopher (280B parámetros, 300B tokens) estaban posiblemente sub-entrenados en relación con su tamaño. Chinchilla mismo (70B parámetros, 1,4 billones de tokens) superó a Gopher a pesar de ser cuatro veces más pequeño, precisamente porque fue entrenado con proporcionalmente más datos. La lección es que ni el tamaño del modelo ni el tamaño de los datos es una única palanca; hay un balance óptimo para cualquier presupuesto de cómputo dado, y ese balance generalmente involucra más tokens de lo que la práctica anterior sugería.
Después del pre-training, tenemos un modelo que puede predecir el siguiente token con una precisión notable. Ha absorbido gramática, hechos e incluso cierta capacidad de razonamiento de la estructura estadística de sus datos de entrenamiento. Pero si le damos un prompt con una pregunta como "Explica el entrelazamiento cuántico en términos simples", es tan probable que continúe con otra pregunta, un párrafo estilo Wikipedia o una publicación de foro como que produzca una respuesta útil. El pre-training enseña al modelo a imitar la distribución del texto en internet, no a seguir instrucciones. Esa brecha es lo que aborda el fine-tuning.
¿Cómo Aprende un Modelo Pre-entrenado a Seguir Instrucciones?
El ajuste fino supervisado (SFT) toma un modelo de lenguaje pre-entrenado y continúa entrenándolo en un conjunto de datos curado de pares (instrucción, respuesta). Los datos podrían verse así:
# Example SFT training pair
instruction = "Explain why the sky is blue in two sentences."
response = "Sunlight contains all wavelengths of visible light. When it hits Earth's atmosphere, shorter blue wavelengths scatter more than longer red ones (Rayleigh scattering), so the sky appears blue from the ground."
# The model sees the concatenation as one sequence:
# [instruction tokens] [response tokens]
# and we compute the loss only on the response tokens.
El objetivo de entrenamiento sigue siendo la predicción del siguiente token (la misma pérdida de entropía cruzada del pre-training) pero aplicada solo a los tokens de la respuesta. Los tokens de la instrucción se alimentan al modelo como contexto (pasan por el pase forward e influyen en la atención), pero anulamos su contribución a la pérdida. Esto enseña al modelo a generar completaciones útiles condicionadas a instrucciones, sin penalizarlo por no predecir la instrucción misma.
Los datasets para SFT son mucho más pequeños que los corpus de pre-training (típicamente decenas de miles a unos pocos cientos de miles de ejemplos, en lugar de miles de millones de tokens). InstructGPT (Ouyang et al., 2022) usó aproximadamente 13.000 ejemplos de demostración para su etapa de SFT. Alpaca (Taori et al., 2023) mostró que incluso 52.000 pares de instrucción-respuesta generados por GPT-4 podían convertir un modelo base LLaMA en un seguidor de instrucciones aceptable. La razón por la que datasets tan pequeños funcionan es que el SFT no está enseñando conocimiento nuevo al modelo (el conocimiento ya está en los pesos pre-entrenados). El SFT está enseñando al modelo un nuevo formato : leer la instrucción, luego producir una respuesta directa en lugar de continuar en cualquier estilo que los datos de pre-training contenían.
Formalmente, dado un ejemplo de SFT donde la instrucción tiene $T_I$ tokens y la respuesta tiene $T_R$ tokens, la pérdida es:
Esto es idéntico a la pérdida CLM pero sumada solo sobre la porción de respuesta $[T_I+1, \ldots, T_I + T_R]$, lo que a veces se llama enmascaramiento de instrucciones o enmascaramiento de prompt . Si observamos el caso límite donde $T_I = 0$ (sin instrucción, solo una respuesta), la pérdida SFT se reduce a la pérdida CLM estándar sobre toda la secuencia, que es exactamente pre-training. El SFT es realmente solo pre-training en una distribución más dirigida.
¿Dónde Falla el SFT?
El SFT produce modelos que siguen instrucciones y generan respuestas coherentes y bien formateadas. Es el cambio de comportamiento más grande en el pipeline de entrenamiento (la diferencia entre un modelo base que divaga y un modelo de chat que responde preguntas). Pero tiene limitaciones estructurales que se hacen evidentes una vez que miramos de cerca cómo la pérdida trata los tokens individuales.
La pérdida de entropía cruzada trata cada token en la respuesta objetivo por igual. Si la respuesta de referencia es "La capital de Francia es París", el modelo es penalizado en la misma cantidad por equivocarse en "La" que por equivocarse en "París". Pero claramente estos tokens no llevan la misma información: "París" es la respuesta real, mientras que "La capital de Francia es" es una formulación estándar que razonablemente podría tomar muchas formas. Una pérdida que ponderara más los tokens críticos para la respuesta proporcionaría una mejor señal de aprendizaje, pero el SFT estándar no tiene mecanismo para distinguir qué tokens importan.
Hay un problema más profundo. El SFT fuerza un único camino de generación específico. Si un ejemplo de entrenamiento responde "¿Cuánto es 2+2?" con "La respuesta es 4", el modelo aprende a producir exactamente esos tokens en ese orden. Pero "4" y "2+2 es igual a 4" y "Cuatro" son todas respuestas válidas. El SFT penaliza al modelo por producir cualquiera de las alternativas, incluso las correctas, porque la pérdida de entropía cruzada mide la divergencia respecto a la secuencia de referencia única, no respecto al conjunto de respuestas aceptables. El modelo está siendo entrenado para imitar, no para ser correcto.
Algunos trabajos recientes abordan el problema de la ponderación de tokens directamente. Los enfoques de SFT con ponderación de tokens asignan diferentes pesos de pérdida a diferentes tokens en la respuesta según su importancia. Un método es usar un modelo de recompensa separado para puntuar la contribución de cada token a la calidad de la respuesta, y luego aumentar el peso de los tokens que el modelo de recompensa considera importantes. La reflexión selectiva para fine-tuning (Li et al., 2023) toma un enfoque relacionado, usando un modelo profesor para identificar qué tokens en una respuesta son más críticos y ponderando la pérdida en consecuencia. Estos métodos tienden a mejorar el rendimiento en benchmarks en comparación con el SFT de peso uniforme, pero requieren saber cómo puntuar tokens individuales, lo que introduce su propia complejidad.
También podemos ver el problema de rigidez desde una perspectiva distribucional. El modelo pre-entrenado tiene una distribución amplia sobre posibles continuaciones para cualquier prefijo dado. El SFT estrecha esa distribución para coincidir con los ejemplos de entrenamiento, lo cual es deseable (queremos que el modelo produzca buenas respuestas) pero también frágil (el modelo aprende a producir estas específicas buenas respuestas en lugar de la clase general de buenas respuestas). El modelo se convierte en un imitador preciso de sus datos de entrenamiento, lo cual funciona bien cuando los datos de entrenamiento son exhaustivos pero falla cuando el modelo encuentra preguntas novedosas que requieren formulaciones novedosas.
¿Qué Viene Después del SFT?
Hagamos un balance de dónde estamos. El pre-training nos da un modelo que entiende el lenguaje. El SFT remodela su comportamiento para que siga instrucciones en lugar de imitar texto de internet. El SFT con ponderación de tokens aborda parcialmente el problema de la ponderación igualitaria diciéndole al modelo qué tokens importan más. Pero todos estos enfoques comparten una restricción fundamental: operan a nivel de token, diciéndole al modelo exactamente qué tokens producir (o ponderándolos), una posición a la vez.
¿Qué pasaría si tomáramos un enfoque completamente diferente? En lugar de especificar los tokens de salida correctos token por token, podríamos dejar que el modelo genere una respuesta completa como quiera, y luego puntuar la respuesta completa con una señal de recompensa. Una respuesta que es correcta, útil y bien razonada recibe una recompensa alta; una que es incorrecta o inútil recibe una recompensa baja. El modelo no recibe orientación sobre qué tokens cambiar, así que tiene que descubrirlo por sí mismo explorando diferentes estrategias de generación y observando cuáles conducen a recompensas más altas.
Esto es el aprendizaje por refuerzo a partir de retroalimentación humana (RLHF) , introducido para modelos de lenguaje por Ouyang et al. (2022) en el artículo de InstructGPT. El pipeline tiene tres etapas: (1) pre-entrenar un modelo base, (2) hacer fine-tuning con SFT para obtener un modelo que siga instrucciones, y (3) refinar aún más con RL usando un modelo de recompensa entrenado con datos de preferencias humanas. Cada etapa aborda las limitaciones de la anterior: el pre-training proporciona conocimiento, el SFT proporciona formato, y el RL proporciona alineamiento con lo que los humanos realmente quieren.
El RL resuelve en principio ambos problemas centrales del SFT. No requiere especificar qué tokens importan, porque la recompensa se calcula sobre la salida completa, así que el modelo aprende por sí mismo qué tokens son fundamentales. Y no fuerza un único camino de generación, porque cualquier respuesta que logre una recompensa alta es reforzada, independientemente de si coincide con una referencia específica. Si "4" y "Cuatro" y "La respuesta es 4" reciben la misma recompensa, el modelo aprende que las tres son válidas.
Por supuesto, el RL introduce sus propios desafíos: reward hacking (el modelo encuentra atajos que explotan el modelo de recompensa), inestabilidad de entrenamiento (los métodos de gradiente de política son notoriamente de alta varianza), y la necesidad de un buen modelo de recompensa en primer lugar. Estos son los temas de la serie de aprendizaje por refuerzo, que continúa exactamente donde termina este artículo.
La trayectoria desde esta serie a la siguiente es una línea recta. Hemos ido desde "¿qué es la atención?" a "¿cómo encajan todas las piezas?" a "¿cómo entrenamos un modelo a escala?" a "¿cómo hacemos que siga instrucciones?" hasta la pregunta abierta: "¿cómo hacemos que sea realmente bueno siguiendo instrucciones, no solo imitativo?" Esa pregunta (y su respuesta a través del aprendizaje por refuerzo) es donde comienza la siguiente serie.
Quiz
Pon a prueba tu comprensión del pre-training y el ajuste fino supervisado.
¿Cuál es la diferencia clave entre CLM y MLM como objetivos de pre-training?
Según las leyes de escalado de Chinchilla, ¿cuál es la proporción óptima aproximada de tokens de entrenamiento a parámetros del modelo?
¿Por qué el SFT funciona con datasets relativamente pequeños (decenas de miles de ejemplos) comparado con el pre-training (billones de tokens)?
¿Qué limitación fundamental del SFT aborda el aprendizaje por refuerzo?