¿Por qué hacer fine-tuning?
Un modelo de embeddings de propósito general entrenado con pares de Common Crawl funciona bien en una amplia gama de tareas, pero no está optimizado para nuestro dominio. Los documentos legales, registros médicos y código fuente tienen distribuciones que difieren significativamente del texto web, por lo que el fine-tuning con pares de recuperación específicos del dominio tiende a mejorar el recall y NDCG en benchmarks de dominio (frecuentemente entre 5-15 puntos porcentuales en consultas difíciles, aunque la ganancia varía según la distancia del dominio y la calidad de los datos).
El ingrediente esencial son pares (consulta, documento_positivo), idealmente acompañados de negativos difíciles. Estos pueden provenir de registros de usuarios existentes (consulta, resultado clickeado), de anotación, o de generación sintética con un modelo de lenguaje.
InfoNCE y negativos en lote
El objetivo de entrenamiento estándar para recuperación con bi-encoder es InfoNCE (Information Noise-Contrastive Estimation), también llamado NT-Xent o pérdida contrastiva. Para un lote de $B$ pares (consulta, documento positivo), la pérdida para la consulta $i$ es:
La temperatura $ au$ es un escalar que controla qué tan pronunciada es la distribución softmax. Un $ au$ pequeño (digamos 0.01) produce una distribución aguda donde el positivo obtiene casi toda la masa de probabilidad y el gradiente es mayormente cero para todos los negativos excepto el más difícil. Un $ au$ grande (digamos 1.0) distribuye la probabilidad entre todos los negativos, proporcionando más señal de gradiente pero un entrenamiento menos discriminativo.
En código, el escalado de temperatura ocurre antes del softmax. Dividir por $ au$ es equivalente a multiplicar los logits por $1/ au$ (cambia la nitidez de la distribución antes del softmax, no las puntuaciones en sí). El siguiente fragmento simula esto para una consulta contra ocho documentos y grafica las distribuciones de probabilidad resultantes a cuatro temperaturas diferentes:
import math, json
import js
def softmax(logits):
m = max(logits)
exps = [math.exp(x - m) for x in logits]
s = sum(exps)
return [e / s for e in exps]
# Simulated similarity scores for one query against 8 docs (doc 0 is positive)
raw_scores = [0.85, 0.62, 0.58, 0.71, 0.45, 0.39, 0.67, 0.52]
temperatures = [0.05, 0.1, 0.5, 1.0]
doc_labels = [f"Doc {i+1}" for i in range(len(raw_scores))]
plot_lines = []
colors = ["#3b82f6", "#10b981", "#f59e0b", "#ef4444"]
for temp, color in zip(temperatures, colors):
scaled = [s / temp for s in raw_scores]
probs = softmax(scaled)
plot_lines.append({
"label": f"τ={temp}",
"data": [round(p, 4) for p in probs],
"color": color
})
plot_data = [
{
"title": "Temperature Scaling: Softmax Distribution over Negatives",
"x_label": "Document (Doc 1 = positive)",
"y_label": "Probability",
"x_data": doc_labels,
"lines": plot_lines
}
]
js.window.py_plot_data = json.dumps(plot_data)
El gráfico ilustra el efecto práctico. Con $ au=0.05$, Doc 1 (el positivo) captura casi toda la masa de probabilidad, por lo que el gradiente fluye casi enteramente desde el único negativo más difícil. Con $ au=1.0$, la probabilidad se distribuye ampliamente y los negativos más fáciles aún reciben señal de gradiente, lo cual puede ser útil al inicio del entrenamiento.
Los negativos en lote son el otro detalle clave de implementación. Para un lote de $B$ pares, la matriz de similitud $B imes B$ se calcula en una sola multiplicación de matrices donde la diagonal contiene pares positivos y las entradas fuera de la diagonal en la misma fila sirven como negativos. Esto nos da $B(B-1)$ pares negativos por lote de forma gratuita. Lotes más grandes generalmente mejoran la calidad del entrenamiento (por eso el entrenamiento distribuido con tamaños de lote grandes por GPU importa para modelos de embeddings), aunque los retornos tienden a disminuir más allá de cierto punto.
Minería de negativos difíciles
Los negativos en lote son documentos aleatorios de otros ejemplos de entrenamiento, lo que significa que típicamente son fáciles de distinguir del verdadero positivo (documentos "no relacionados" que la mayoría de los modelos adecuadamente entrenados ya puntuarán más bajo). Entrenar solo con negativos fáciles produce un modelo que nunca aprende a discriminar entre documentos casi-relevantes y verdaderamente relevantes.
Los negativos difíciles son documentos que son temáticamente similares al positivo pero no realmente relevantes (el tipo que confunde a un recuperador en producción). Su minería típicamente involucra usar BM25 o un modelo de embeddings más débil para recuperar los top-$K$ candidatos y luego excluir los positivos anotados; los candidatos principales restantes se convierten en los negativos difíciles.
El principal riesgo con este enfoque son los falsos negativos. Si un "negativo difícil" es realmente relevante para la consulta pero no fue anotado, el entrenamiento empujará al modelo a puntuarlo más bajo, lo cual es incorrecto. Para mitigar esto, los pipelines de minería de negativos difíciles usualmente incluyen un paso de filtrado donde los negativos candidatos se puntúan con un modelo más potente (un cross-encoder o un modelo de lenguaje frontier) y cualquiera que puntúe por encima de un umbral de relevancia se descarta.
ANCE (Approximate nearest-neighbour Negative Contrastive Estimation) (Xiong et al., 2020) aborda un problema más sutil: los negativos difíciles se vuelven obsoletos. En lugar de minarlos una vez antes del entrenamiento, ANCE los actualiza cada $T$ pasos re-indexando todos los embeddings de documentos con el modelo actual, minando nuevos negativos difíciles contra el índice actualizado, y continuando el entrenamiento. Esto evita que el modelo memorice sus negativos y asegura que siempre entrene contra sus casos de fallo actuales.
Destilación de conocimiento desde cross-encoders
Los cross-encoders son más precisos que los bi-encoders pero demasiado lentos para la recuperación de primera etapa, por lo que la destilación de conocimiento nos permite transferir los ricos juicios de relevancia del cross-encoder a los productos punto del bi-encoder (entrenando un modelo rápido para que se comporte como uno lento).
El cross-encoder maestro puntúa cada par consulta-documento con una puntuación de relevancia de valor real $s_j^{ ext{CE}}$. Después de aplicar softmax sobre la lista de candidatos para producir las etiquetas suaves del maestro $\mathbf{p}^{ ext{teacher}}$, entrenamos al bi-encoder estudiante para que coincida con esta distribución minimizando la divergencia KL:
Es útil expandir la divergencia KL como $H(p^{ ext{teacher}}, p^{ ext{student}}) - H(p^{ ext{teacher}})$. El segundo término (la entropía de la distribución del maestro) no depende de los parámetros del estudiante, por lo que minimizar KL es equivalente a minimizar la entropía cruzada $H(p^{ ext{teacher}}, p^{ ext{student}})$ con respecto al estudiante. En código esto significa que solo necesitamos calcular la entropía cruzada entre las etiquetas suaves del maestro y los logits del estudiante (el término de entropía del maestro se cancela y nunca se calcula). La siguiente función implementa esto, incluyendo la corrección de magnitud de gradiente $T^2$ de Hinton et al. (2015) :
import torch
import torch.nn.functional as F
def distillation_loss(student_logits, teacher_logits, temperature=1.0):
"""
student_logits: (batch_size, num_candidates) — bi-encoder dot products
teacher_logits: (batch_size, num_candidates) — cross-encoder scores
KL(teacher || student) = CE(teacher_probs, student_log_probs) - H(teacher)
H(teacher) is constant w.r.t. student parameters, so we minimize CE only.
"""
# Scale by temperature — softer distributions transfer more information
teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
# Cross-entropy: -sum(teacher_probs * student_log_probs)
# F.kl_div expects log-probabilities for input, probabilities for target
loss = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
return loss * (temperature ** 2) # rescale to match original scale
# The temperature ** 2 factor comes from the gradient analysis:
# when logits are scaled by 1/T, the gradients shrink by 1/T^2,
# so multiplying by T^2 restores gradient magnitude to the T=1 baseline.
# This was shown in the original knowledge distillation paper (Hinton et al., 2015).
El reescalado $T^2$ aborda un efecto secundario sutil del escalado de temperatura. Cuando dividimos los logits por la temperatura $T$, los gradientes con respecto a los logits se reducen por $1/T^2$. Multiplicar la pérdida por $T^2$ restaura la magnitud del gradiente al valor base de $T=1$, lo que significa que la misma tasa de aprendizaje funciona independientemente de la temperatura que elijamos.
BGE-M3 y entrenamiento multi-pérdida
BGE-M3 de FlagEmbedding (Chen et al., 2024) combina todas las ideas anteriores en un solo bucle de entrenamiento. Para cada lote calcula tres pérdidas InfoNCE (una por cabeza de recuperación: densa, dispersa y multi-vector) más una pérdida de auto-destilación de conocimiento donde la puntuación del ensemble sirve como maestro suave:
La puntuación del maestro de auto-KD es una combinación ponderada de las tres cabezas de recuperación ($s_{ ext{teacher}} = \alpha\, s_{ ext{dense}} + \beta\, s_{ ext{sparse}} + \gamma\, s_{ ext{multi-vec}}$), calculada con los parámetros del modelo actual por lo que no se necesita un modelo maestro separado. Dado que la puntuación del ensemble tiende a ser más precisa que cualquier cabeza individual, destilarla de vuelta a cada cabeza las empuja a ser consistentes entre sí.
El siguiente código implementa las tres funciones de puntuación. Para la puntuación multi-vector estilo ColBERT, el cómputo de MaxSim usa el mismo patrón einsum que vimos en el artículo de interacción tardía. La función de puntuación dispersa calcula un producto punto sobre pesos de términos aprendidos (un peso por token del vocabulario), lo cual es análogo a BM25 pero con pesos aprendidos en lugar de estadísticos:
import torch
def colbert_score(q_reps, p_reps):
"""
Multi-vector (ColBERT-style) scoring via MaxSim.
q_reps: (num_queries, num_query_tokens, dim)
p_reps: (num_passages, num_passage_tokens, dim)
Returns: (num_queries, num_passages) relevance scores
"""
# All pairwise token similarities: (Q, P, q_tokens, p_tokens)
token_scores = torch.einsum('qin,pjn->qipj', q_reps, p_reps)
# MaxSim: for each query token, max over passage tokens,
# then sum over query tokens
scores = token_scores.max(dim=-1).values.sum(dim=-1)
return scores
def sparse_score(q_weights, p_weights):
"""
Learned sparse scoring (analogous to BM25 with learned term weights).
q_weights: (num_queries, vocab_size) — per-token importance weights
p_weights: (num_passages, vocab_size)
Returns: (num_queries, num_passages) relevance scores
"""
# Dot product over sparse term-weight vectors.
# In practice these are very sparse (ReLU output), so only
# overlapping terms contribute — similar to lexical matching.
scores = torch.matmul(q_weights, p_weights.T)
return scores
def dense_score(q_reps, p_reps, temperature=0.02):
"""
Dense scoring via pooled sentence embeddings.
q_reps: (num_queries, dim)
p_reps: (num_passages, dim)
"""
# Simple dot product, then scale by 1/temperature before softmax.
# Lower temperature -> sharper distribution -> harder negatives needed
scores = torch.matmul(q_reps, p_reps.T)
return scores / temperature # division happens BEFORE softmax in the loss
Observa `scores / temperature` en la función de puntuación densa. La temperatura se aplica a los logits antes de que entren al softmax en la pérdida InfoNCE, no a las probabilidades después. Esta es la misma operación que el escalado de temperatura en la generación de modelos de lenguaje, donde reducir la temperatura hace la distribución más pronunciada. Para el entrenamiento de recuperación, una temperatura más baja (por ejemplo, 0.02) crea un problema de entrenamiento más difícil porque el positivo debe puntuar significativamente más alto que todos los negativos para lograr una pérdida baja.
Quiz
Verifica tu comprensión del fine-tuning de embeddings para recuperación.
En la pérdida InfoNCE, ¿cuál es el efecto de usar una temperatura muy pequeña (por ejemplo, τ = 0.02)?
¿Por qué minimizar KL(teacher || student) es equivalente a minimizar la entropía cruzada H(teacher, student) con respecto al estudiante?
ANCE mejora la minería de negativos difíciles mediante:
En el código de destilación de conocimiento, ¿por qué se multiplica la pérdida por T² cuando se usa temperatura T > 1?
En la auto-destilación de conocimiento de BGE-M3, ¿qué sirve como maestro?