¿Cómo Aumentamos la Probabilidad de Buenas Acciones?
Establecimos en el artículo anterior que el entrenamiento basado en RL permite al modelo generar libremente y recibir una recompensa escalar por la respuesta completa. La pregunta ahora es mecánica: dada esa señal de recompensa, ¿cómo actualizamos realmente los pesos del modelo? Necesitamos un gradiente — una dirección en el espacio de parámetros que, al seguirla, haga más probables las respuestas de alta recompensa y menos probables las de baja recompensa.
La respuesta más simple es el algoritmo REINFORCE (Williams, 1992) . La idea es clara: muestrear una trayectoria (generar una respuesta completa), observar la recompensa, y luego ajustar los parámetros del modelo para que las acciones tomadas durante esa trayectoria se vuelvan más o menos probables en proporción a cuán buena fue la recompensa. Formalmente, queremos maximizar el retorno esperado $J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}[R(\tau)]$, donde $\tau$ es una trayectoria muestreada de la policy $\pi_\theta$ y $R(\tau)$ es la recompensa. El teorema del policy gradient nos da el gradiente de este objetivo:
Cada parte de esta fórmula juega un papel específico. $\nabla_\theta \log \pi_\theta(a_t \mid s_t)$ es la dirección en el espacio de parámetros que aumentaría la probabilidad de tomar la acción $a_t$ en el estado $s_t$ (en términos de modelos de lenguaje, esta es la dirección que hace más probable que el modelo produzca el token $a_t$ dado el prompt y todos los tokens generados hasta el momento). $R_t$ es el retorno desde el paso temporal $t$ en adelante, y actúa como un factor de escala: cuando $R_t$ es grande y positivo, damos un paso grande en la dirección que aumenta la probabilidad de este token; cuando $R_t$ es negativo, damos un paso en la dirección opuesta, haciendo este token menos probable.
La esperanza $\mathbb{E}_{\tau \sim \pi_\theta}$ significa que promediamos sobre muchas trayectorias muestreadas. En la práctica, aproximamos esto con un batch de respuestas generadas, calculando el gradiente para cada una y promediando. Cuantas más trayectorias muestreemos, mejor será nuestra estimación del gradiente verdadero, pero cada trayectoria requiere un pase forward completo a través del modelo, por lo que hay un tradeoff directo entre calidad del gradiente y cómputo.
Para ver por qué esto funciona, consideremos qué sucede con una sola trayectoria. Si generamos una respuesta y recibe alta recompensa, cada token en esa respuesta ve aumentada su probabilidad. Si recibe baja recompensa, cada token ve disminuida su probabilidad. A lo largo de muchas trayectorias, los tokens que aparecen consistentemente en respuestas de alta recompensa verán aumentadas sus probabilidades, mientras que los tokens que aparecen en respuestas de baja recompensa serán suprimidos. El algoritmo descubre qué decisiones a nivel de token conducen a buenos resultados puramente a partir de retroalimentación a nivel de respuesta, sin necesitar nunca supervisión por token.
¿Por Qué REINFORCE Tiene Problemas con la Varianza?
REINFORCE es correcto en esperanza (dado un número infinito de trayectorias, converge al gradiente verdadero). Pero en la práctica trabajamos con batches finitos, y la varianza de la estimación del gradiente puede ser enorme. Supongamos que muestreamos dos respuestas al mismo prompt: una obtiene $R = 8$ y otra obtiene $R = 2$. Ambas son positivas, así que REINFORCE aumenta la probabilidad de ambas respuestas (solo más para la primera). Pero el gradiente de ninguna de las respuestas "sabe" que $8$ es bueno y $2$ es mediocre; solo ven valores absolutos de recompensa, no relativos.
La solución es restar una línea base $b$ de la recompensa, reemplazando $R_t$ con $R_t - b$. Restar una línea base constante no cambia el gradiente esperado (porque $\mathbb{E}[\nabla \log \pi \cdot b] = 0$ para cualquier constante $b$) pero puede reducir dramáticamente la varianza al centrar la señal de recompensa alrededor de cero. Si $b$ está cerca del retorno promedio, entonces las trayectorias por encima del promedio obtienen peso positivo (son empujadas hacia arriba) y las que están por debajo obtienen peso negativo (son empujadas hacia abajo).
La línea base más común es la función de valor $V^\pi(s_t)$, que estima el retorno esperado desde el estado $s_t$ bajo la policy actual. Restarla nos da la ventaja (advantage) :
La ventaja $A_t$ responde una pregunta precisa: ¿fue el resultado real mejor o peor de lo que esperábamos? Si $A_t > 0$, la acción tomada fue mejor que el promedio para ese estado, y debemos aumentar su probabilidad. Si $A_t < 0$, fue peor, y debemos disminuirla. Esto es estrictamente más informativo que la recompensa bruta porque tiene en cuenta el contexto — una recompensa de $5$ es excelente si esperábamos $2$, pero decepcionante si esperábamos $8$.
Con la ventaja, el policy gradient se convierte en:
En la práctica, entrenamos una red crítica separada (que a menudo comparte un backbone con la policy) para predecir $V^\pi(s_t)$. Este crítico se actualiza junto con la policy usando regresión estándar sobre los retornos observados. La combinación de una policy (el "actor") y una función de valor (el "crítico") se llama arquitectura actor-critic , y es la base de prácticamente todos los métodos modernos de policy gradient.
¿Cómo Previene PPO las Actualizaciones Destructivas?
Incluso con una buena estimación de la ventaja, los métodos de policy gradient básicos tienden a ser inestables. Un solo batch con una trayectoria de recompensa inusualmente alta puede producir un gradiente grande que se pasa de largo, cambiando drásticamente la policy de una manera que degrada el rendimiento. Una vez que la policy se ha desplazado demasiado, las estimaciones de la función de valor se vuelven obsoletas, los cálculos de ventaja se vuelven poco fiables y el entrenamiento puede entrar en espiral. Las policies de redes neuronales son particularmente frágiles aquí porque un pequeño cambio en los pesos puede producir un gran cambio en la distribución de salida.
Proximal Policy Optimization (PPO) (Schulman et al., 2017) resuelve esto con un objetivo sustituto recortado (clipped surrogate objective) que impide que la policy cambie demasiado en una sola actualización. En lugar de usar directamente $\nabla \log \pi_\theta \cdot A_t$, PPO trabaja con la razón de probabilidad entre las policies nueva y antigua:
Esta razón $r_t$ mide cuánto ha cambiado la probabilidad de la policy actualizada para la acción $a_t$ en relación con la policy que originalmente generó la trayectoria. Si $r_t = 1$, la nueva policy asigna la misma probabilidad que la antigua. Si $r_t = 1.5$, la nueva policy es 50% más probable de tomar esta acción. Si $r_t = 0.5$, es la mitad de probable. El objetivo de PPO usa esta razón junto con la ventaja:
Veamos qué hace el recorte (clipping) en cada caso.
Cuando $A_t > 0$ (una buena acción que queremos reforzar), el término sin recorte $r_t \cdot A_t$ crece a medida que aumentamos $r_t$, animando al optimizador a hacer esta acción cada vez más probable. Sin el recorte, una ventaja muy alta podría empujar $r_t$ a valores extremos, concentrando toda la masa de probabilidad en esta única acción. El término $\text{clip}(r_t, 1-\varepsilon, 1+\varepsilon)$ limita $r_t$ a $1+\varepsilon$, por lo que el término recortado se satura una vez que la razón de probabilidad excede $1+\varepsilon$. El $\min$ luego toma el valor más bajo, por lo que una vez que $r_t > 1+\varepsilon$, no hay más gradiente empujando la razón hacia arriba. El modelo todavía puede ser incentivado a tomar esta acción más, pero no excesivamente en una sola actualización.
Cuando $A_t < 0$ (una mala acción que queremos suprimir), el término sin recorte $r_t \cdot A_t$ se vuelve menos negativo a medida que $r_t$ disminuye (ya que estamos multiplicando un número positivo decreciente por una ventaja negativa), que es exactamente lo que la optimización quiere: reducir $r_t$ para minimizar la contribución negativa. Pero el recorte impide que $r_t$ caiga por debajo de $1-\varepsilon$, por lo que el modelo no puede huir en pánico de esta acción en un solo paso. Nuevamente, el $\min$ toma el valor más pesimista (más bajo), asegurando que el gradiente se desvanezca una vez que $r_t$ cae por debajo de $1-\varepsilon$.
El hiperparámetro $\varepsilon$ controla cuánto puede moverse la policy en una actualización y típicamente se establece entre $0.1$ y $0.2$. Un $\varepsilon$ más pequeño significa actualizaciones más conservadoras (más estables pero aprendizaje más lento); un $\varepsilon$ más grande permite pasos más grandes (más rápido pero más arriesgado). Con $\varepsilon = 0.2$, la probabilidad de cualquier acción individual puede cambiar como máximo un 20% por paso de actualización en cualquier dirección.
Para ver por qué el $\min$ es necesario, consideremos qué sucede sin él. Si solo tuviéramos el término recortado, el objetivo se estancaría fuera del rango de recorte pero no impediría activamente que la razón se moviera más. El $\min$ asegura que el objetivo sea siempre el más conservador de los dos términos (una cota inferior pesimista). Cuando la razón está dentro de $[1-\varepsilon, 1+\varepsilon]$, ambos términos son idénticos y el gradiente fluye normalmente. Cuando la razón se desvía fuera de ese rango, el gradiente se elimina, creando una región de confianza alrededor de la policy antigua.
¿Cómo Se Ve el Bucle de Entrenamiento?
Con el objetivo recortado definido, podemos esbozar el bucle de entrenamiento de PPO de principio a fin. El algoritmo alterna entre dos fases: (1) recolectar trayectorias dejando que la policy actual genere respuestas, y (2) ejecutar varias épocas de actualizaciones de gradiente sobre esas trayectorias usando el objetivo recortado. El siguiente pseudocódigo muestra la estructura.
# PPO training loop (simplified for language model fine-tuning)
for iteration in range(num_iterations):
# ── Phase 1: Collect trajectories ──────────────────────────
prompts = sample_batch(prompt_dataset, batch_size)
with torch.no_grad():
responses = policy.generate(prompts) # sample full responses
old_log_probs = policy.log_probs(prompts, responses) # π_old(a|s)
rewards = reward_model(prompts, responses) # scalar per response
values = critic(prompts, responses) # V(s) per token position
advantages = compute_gae(rewards, values, gamma, lam) # A_t via GAE
# ── Phase 2: PPO update (multiple epochs on same batch) ───
for epoch in range(ppo_epochs): # typically 2-4 epochs
new_log_probs = policy.log_probs(prompts, responses)
# Probability ratio r_t = π_new / π_old
ratio = torch.exp(new_log_probs - old_log_probs)
# Clipped surrogate objective
unclipped = ratio * advantages
clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
policy_loss = -torch.min(unclipped, clipped).mean()
# Value function loss (train the critic)
value_loss = F.mse_loss(critic(prompts, responses), returns)
# Combined loss
loss = policy_loss + 0.5 * value_loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
Hay algunas cosas que notar en este bucle. Los
old_log_probs
se calculan una vez con
torch.no_grad()
y se congelan; sirven como punto de referencia para la razón de probabilidad. El bucle interno ejecuta múltiples pasos de gradiente sobre el mismo batch de trayectorias, que es la ganancia clave de eficiencia de PPO sobre el policy gradient básico (donde cada batch se usaría para una sola actualización y luego se descartaría). El recorte asegura que estos múltiples pases no muevan la policy demasiado lejos de donde se recolectaron las trayectorias, manteniendo válidas las estimaciones de ventaja.
La función
compute_gae
calcula Generalized Advantage Estimation
(Schulman et al., 2016)
, que combina estimaciones de ventaja de un solo paso y de múltiples pasos usando un parámetro de decaimiento $\lambda$. En la práctica, $\gamma$ (el factor de descuento) suele estar cerca de $1.0$ para tareas de modelos de lenguaje ya que nos importa la recompensa total de la respuesta, y $\lambda$ típicamente es $0.95$.
Cuando este bucle se aplica a modelos de lenguaje (como en RLHF), hay un ingrediente adicional que aún no hemos mostrado: una penalización de KL que mantiene a la policy de alejarse demasiado del modelo SFT original. Esa penalización es crítica para el alineamiento y es el enfoque del siguiente artículo.
Quiz
Pon a prueba tu comprensión de policy gradients y PPO.
En el policy gradient de REINFORCE, ¿qué papel juega la recompensa R_t?
¿Por qué restamos una línea base de la recompensa para calcular la ventaja?
En el objetivo recortado de PPO, ¿qué sucede cuando A_t > 0 y la razón r_t excede 1 + ε?
¿Por qué PPO ejecuta múltiples épocas de gradiente sobre el mismo batch de trayectorias?