Why Fine-Tune?
A general-purpose embedding model trained on Common Crawl pairs works well across a broad range of tasks, but it is not optimised for our domain. Legal documents, medical records, and source code all have distributions that differ significantly from web text, so fine-tuning on domain-specific retrieval pairs tends to improve recall and NDCG on domain benchmarks (often by 5-15 percentage points on hard queries, though the gain varies with domain distance and data quality).
The essential ingredient is (query, positive_document) pairs, ideally accompanied by hard negatives. These can come from existing user logs (query, clicked result), from annotation, or from synthetic generation with a language model.
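Whatever the source, it helps to normalise the data into one record per query that lists its positives and, later, its mined hard negatives. The sketch below does this for click logs. The log layout, the `build_pairs` helper and the `min_clicks` noise threshold are hypothetical choices made for illustration, not a standard format.

```python
from collections import defaultdict

# Hypothetical click-log events: (query, clicked_doc_id). This layout is an
# assumption for illustration, not any particular logging system's schema.
click_log = [
    ("reset router password", "doc_17"),
    ("reset router password", "doc_17"),
    ("reset router password", "doc_42"),
    ("python list comprehension", "doc_03"),
]

def build_pairs(log, min_clicks=1):
    """Turn (query, clicked doc) events into one training record per query."""
    counts = defaultdict(int)
    for query, doc_id in log:
        counts[(query, doc_id)] += 1
    records = defaultdict(lambda: {"pos": [], "neg": []})
    for (query, doc_id), n in counts.items():
        if n >= min_clicks:  # drop rarely clicked docs as likely noise
            records[query]["pos"].append(doc_id)
    # "neg" stays empty here; hard negatives are mined in a later step.
    return [{"query": q, **r} for q, r in records.items()]

pairs = build_pairs(click_log, min_clicks=2)
print(pairs)  # [{'query': 'reset router password', 'pos': ['doc_17'], 'neg': []}]
```

Requiring repeated clicks is one simple way to suppress accidental clicks. Annotated or synthetic pairs would fill the same `pos` lists.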
InfoNCE and In-Batch Negatives
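The standard training objective for bi-encoder retrieval is InfoNCE (Information Noise-Contrastive Estimation), closely related to NT-Xent and often simply called the contrastive loss. For a batch of $B$ (query, positive document) pairs, the loss for query $i$ is:
$$\mathcal{L}_i = -\log \frac{\exp\left(s(q_i, d_i^+)/\tau\right)}{\sum_{j=1}^{B} \exp\left(s(q_i, d_j^+)/\tau\right)}$$
where $s(\cdot,\cdot)$ is the similarity (dot product or cosine) between query and document embeddings, $d_j^+$ is the positive document of pair $j$ (so for $j \neq i$ it acts as a negative for query $i$), and $\tau$ is the temperature. The batch loss is the mean of $\mathcal{L}_i$ over the $B$ queries.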
The temperature $\tau$ is a scalar that controls how peaked the softmax distribution is. A small $\tau$ (say 0.01) produces a sharp distribution. Score gaps are magnified by $1/\tau$, so the highest-scoring (hardest) negatives absorb almost all of the probability mass assigned to negatives, and therefore almost all of the gradient, while easier negatives receive nearly none. A large $\tau$ (say 1.0) spreads probability across all negatives. This provides gradient signal to more of them but penalises the hardest ones less specifically, which makes training less discriminative.
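The claim about where the gradient goes can be checked directly. For the per-query loss $-\log$ (softmax probability of the positive), the derivative with respect to a negative's score $s_j$ is $p_j/\tau$, where $p_j$ is that negative's softmax probability. Each negative's share of the gradient is therefore $p_j$ renormalised over the negatives. The scores below are made up for illustration, and `negative_grad_shares` is a hypothetical helper.

```python
import math

def softmax(xs):
    # Subtract the max for numerical stability
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def negative_grad_shares(pos_score, neg_scores, tau):
    """Fraction of the InfoNCE gradient (w.r.t. negative scores) each negative gets.

    dL/ds_j = p_j / tau for every negative j, so the shares are the softmax
    probabilities of the negatives renormalised to sum to 1.
    """
    probs = softmax([s / tau for s in [pos_score] + neg_scores])
    neg_probs = probs[1:]
    total = sum(neg_probs)
    return [p / total for p in neg_probs]

# Illustrative scores: one positive and four negatives, hardest negative first.
pos, negs = 0.80, [0.75, 0.60, 0.40, 0.20]
for tau in (0.01, 0.1, 1.0):
    shares = negative_grad_shares(pos, negs, tau)
    print(f"tau={tau}: " + ", ".join(f"{s:.3f}" for s in shares))
```

At $\tau = 0.01$ essentially the entire gradient lands on the hardest negative (score 0.75). At $\tau = 1.0$ it is spread over all four, with the hardest receiving roughly a third.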
In code, temperature scaling is applied to the logits before the softmax. Dividing every score by $\tau$ (equivalently, multiplying by $1/\tau$) stretches or compresses the gaps between scores. This changes how sharp the resulting distribution is while leaving the ranking of documents unchanged. The following snippet simulates this for one query against eight documents and plots the resulting probability distributions at four different temperatures:
import math, json
try:
    import js  # only available inside Pyodide (in-browser Python)
except ImportError:
    js = None
def softmax(logits):
    # Subtract the max before exponentiating for numerical stability
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]
# Simulated similarity scores for one query against 8 docs (index 0, labelled Doc 1, is the positive)
raw_scores = [0.85, 0.62, 0.58, 0.71, 0.45, 0.39, 0.67, 0.52]
temperatures = [0.05, 0.1, 0.5, 1.0]
doc_labels = [f"Doc {i+1}" for i in range(len(raw_scores))]
plot_lines = []
colors = ["#3b82f6", "#10b981", "#f59e0b", "#ef4444"]
for temp, color in zip(temperatures, colors):
    scaled = [s / temp for s in raw_scores]  # temperature is applied to the logits
    probs = softmax(scaled)
    plot_lines.append({
        "label": f"τ={temp}",
        "data": [round(p, 4) for p in probs],
        "color": color
    })
plot_data = [
    {
        "title": "Temperature Scaling: Softmax Distribution over Candidates",
        "x_label": "Document (Doc 1 = positive)",
        "y_label": "Probability",
        "x_data": doc_labels,
        "lines": plot_lines
    }
]
if js is not None:
    js.window.py_plot_data = json.dumps(plot_data)  # hand the data to the page's plotting code
else:
    print(json.dumps(plot_data, ensure_ascii=False, indent=2))
The plot illustrates the practical effect. With $\tau=0.05$, Doc 1 (the positive) captures about 90% of the probability mass. Most of the remainder, and therefore most of the gradient on the negatives, goes to the two highest-scoring negatives (Doc 4 and Doc 7). With $\tau=1.0$, the probability spreads widely (the positive gets only about 16%), and easier negatives still receive gradient signal, which can be helpful early in training.
In-batch negatives are the other key implementation detail. For a batch of $B$ pairs, the $B \times B$ similarity matrix is computed in one matrix multiplication. The diagonal holds the positive pairs, and the off-diagonal entries within the same row serve as negatives. This gives $B(B-1)$ negative pairs per batch at no extra encoding cost, since every document in the batch has already been embedded. Larger batches generally improve training quality because each query sees more, and more likely harder, negatives. This is why embedding models are often trained with large batches, sometimes gathering embeddings across GPUs so that each query is contrasted against the whole global batch. Returns tend to diminish beyond a certain point.
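In PyTorch, the whole batch loss reduces to one matrix multiplication followed by cross-entropy, with the diagonal indices as targets. The sketch below assumes L2-normalised embeddings (cosine similarity) and $\tau = 0.05$. Both are illustrative choices rather than fixed requirements, and `in_batch_infonce` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F

def in_batch_infonce(q_emb, d_emb, tau=0.05):
    """InfoNCE with in-batch negatives.

    q_emb: (B, dim) query embeddings; d_emb: (B, dim) document embeddings,
    where row i of d_emb is the positive for row i of q_emb.
    """
    # Cosine similarity via L2-normalised embeddings (assumed here).
    q = F.normalize(q_emb, dim=-1)
    d = F.normalize(d_emb, dim=-1)
    logits = q @ d.T / tau                               # (B, B); diagonal = positives
    targets = torch.arange(q.size(0), device=q.device)  # row i's positive is column i
    # Row-wise softmax cross-entropy = mean of the per-query InfoNCE losses.
    return F.cross_entropy(logits, targets)

torch.manual_seed(0)
B, dim = 8, 64
queries = torch.randn(B, dim)
docs = queries + 0.1 * torch.randn(B, dim)  # positives lie close to their queries
aligned = in_batch_infonce(queries, docs)
rolled = in_batch_infonce(queries, docs.roll(1, dims=0))  # every pair mismatched
print(aligned.item(), rolled.item())
```

Because row $i$'s target is column $i$, `F.cross_entropy` applies exactly the per-query InfoNCE loss averaged over the batch. Rolling the documents by one position misaligns every pair, so the second loss comes out much larger.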
Hard Negative Mining
In-batch negatives are the positives of other queries in the batch. They are therefore effectively random documents and typically easy to distinguish from the true positive: unrelated documents that any adequately trained model already scores lower. Training on only easy negatives produces a model that rarely learns to discriminate between nearly-relevant and truly-relevant documents.
Hard negatives are documents that are topically similar to the positive but not actually relevant (the kind that confuse a retriever in production). Mining them typically involves using BM25 or a weaker embedding model to retrieve the top-$K$ candidates and then excluding the annotated positives; the remaining top candidates become the hard negatives.
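A minimal sketch of this mining step follows. It assumes the first-stage retriever is an existing embedding model whose query and corpus embeddings are already computed; a BM25 scorer would produce the score matrix instead. `mine_hard_negatives`, `k` and `num_neg` are hypothetical names, and the toy vectors stand in for real embeddings.

```python
import torch

def mine_hard_negatives(q_emb, doc_emb, positives, k=10, num_neg=3):
    """Mine hard negatives with a first-stage retriever.

    q_emb: (Q, dim) query embeddings; doc_emb: (N, dim) corpus embeddings.
    positives: list of sets, positives[i] = annotated positive doc ids for query i.
    Returns a list of hard-negative doc ids per query.
    """
    scores = q_emb @ doc_emb.T                # (Q, N) retrieval scores
    top_ids = scores.topk(k, dim=-1).indices  # top-k candidates, best first
    hard_negs = []
    for i, cands in enumerate(top_ids.tolist()):
        # Drop annotated positives; the highest-ranked survivors are hard negatives.
        negs = [d for d in cands if d not in positives[i]]
        hard_negs.append(negs[:num_neg])
    return hard_negs

q = torch.tensor([[1.0, 0.0, 0.0, 0.0]])
docs = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],  # doc 0: the annotated positive
    [0.9, 0.1, 0.0, 0.0],  # doc 1: very similar, not annotated
    [0.8, 0.0, 0.2, 0.0],  # doc 2
    [0.0, 1.0, 0.0, 0.0],  # doc 3: unrelated
    [0.5, 0.0, 0.0, 1.0],  # doc 4
])
print(mine_hard_negatives(q, docs, positives=[{0}], k=4, num_neg=2))  # [[1, 2]]
```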
The main risk with this approach is false negatives. If a "hard negative" is actually relevant to the query but was not annotated, training will push the model to score it lower, which is incorrect. To mitigate this, hard negative mining pipelines usually include a filtering step where candidate negatives are scored with a more powerful model (a cross-encoder or a frontier language model) and any that score above a relevance threshold are discarded.
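The filtering step amounts to a threshold on a stronger scorer. In this sketch `relevance_fn` stands in for a cross-encoder or LLM judge. The word-overlap `toy_relevance` function and the 0.5 threshold are placeholders chosen only so the example runs; they are not a recommended relevance model.

```python
from typing import Callable, List

def filter_false_negatives(
    query: str,
    candidates: List[str],
    relevance_fn: Callable[[str, str], float],
    threshold: float = 0.5,
) -> List[str]:
    """Keep only candidate negatives that the stronger scorer judges non-relevant.

    relevance_fn is assumed to return scores on the same scale as `threshold`.
    """
    return [doc for doc in candidates if relevance_fn(query, doc) < threshold]

def toy_relevance(query: str, doc: str) -> float:
    # Hypothetical stand-in for a cross-encoder: fraction of query words in the doc.
    q, d = set(query.lower().split()), set(doc.lower().split())
    return len(q & d) / max(len(q), 1)

candidates = [
    "how to reset your router password",  # actually relevant: should be dropped
    "router firmware update guide",
    "resetting a forgotten email password",
]
kept = filter_false_negatives("reset router password", candidates, toy_relevance)
print(kept)
```

Only candidates scoring below the threshold survive as negatives. The first candidate matches every query word, so it is treated as a likely false negative and discarded.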
ANCE (Approximate nearest-neighbour Negative Contrastive Estimation) (Xiong et al., 2020) addresses a subtler problem: hard negatives go stale. Negatives mined once before training stop being hard as the model improves, so they contribute less and less useful gradient. ANCE instead refreshes them periodically. Every $T$ steps it re-encodes the corpus with a recent checkpoint, rebuilds the approximate nearest-neighbour index, and mines new hard negatives against the fresh index while training continues (in the original paper the refresh runs asynchronously). This keeps the negatives close to the model's current failure cases.
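The refresh loop can be sketched on toy data. This is a simplified, synchronous version under several assumptions. A linear layer stands in for the encoder, and exact search over a 50-document toy corpus stands in for the ANN index. The refresh interval, temperature and sizes are arbitrary, and `mine` is a hypothetical helper. The original ANCE refreshes the index asynchronously from a recent checkpoint while training continues.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_docs, num_queries, feat_dim, emb_dim = 50, 20, 32, 16

# Toy data (illustration only): query i's positive is document i.
doc_feats = torch.randn(num_docs, feat_dim)
query_feats = doc_feats[:num_queries] + 0.3 * torch.randn(num_queries, feat_dim)

encoder = torch.nn.Linear(feat_dim, emb_dim)  # stand-in for a real encoder
opt = torch.optim.Adam(encoder.parameters(), lr=1e-2)

def mine(k=4):
    """Re-encode the corpus with the current encoder and mine top-k negatives."""
    with torch.no_grad():
        d = F.normalize(encoder(doc_feats), dim=-1)
        q = F.normalize(encoder(query_feats), dim=-1)
        scores = q @ d.T
        idx = torch.arange(num_queries)
        scores[idx, idx] = float("-inf")  # never pick a query's own positive
        return scores.topk(k, dim=-1).indices  # (num_queries, k)

refresh_every, steps, tau = 25, 100, 0.05
refreshes = 0
for step in range(steps):
    if step % refresh_every == 0:  # ANCE-style periodic refresh
        hard_negs = mine()
        refreshes += 1
    q = F.normalize(encoder(query_feats), dim=-1)       # (Q, E)
    d_all = F.normalize(encoder(doc_feats), dim=-1)     # (N, E)
    cands = torch.cat([d_all[:num_queries].unsqueeze(1),  # positive in slot 0
                       d_all[hard_negs]], dim=1)           # (Q, 1 + k, E)
    logits = torch.einsum("qe,qce->qc", q, cands) / tau
    loss = F.cross_entropy(logits, torch.zeros(num_queries, dtype=torch.long))
    opt.zero_grad()
    loss.backward()
    opt.step()
print(f"refreshes={refreshes}, final loss={loss.item():.3f}")
```

In a real system the re-encoding inside `mine` is the expensive part, because it touches the whole corpus. The refresh interval is therefore a trade-off between fresh negatives and compute.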
Knowledge Distillation from Cross-Encoders
Cross-encoders are more accurate than bi-encoders but too slow for first-stage retrieval, so knowledge distillation lets us transfer the cross-encoder's rich relevance judgements into the bi-encoder's dot products (training a fast model to behave like a slow one).
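The teacher cross-encoder scores each query-document pair with a real-valued relevance score $s_j^{\text{CE}}$. Applying a softmax over the candidate list (optionally with a temperature $T$; $T = 1$ gives the plain softmax) produces the teacher's soft labels $p_j^{\text{teacher}} = \operatorname{softmax}(\mathbf{s}^{\text{CE}}/T)_j$. The student bi-encoder's dot products $s_j^{\text{BE}}$ give $p_j^{\text{student}} = \operatorname{softmax}(\mathbf{s}^{\text{BE}}/T)_j$ in the same way. We train the student to match the teacher's distribution by minimising the KL divergence:
$$\mathcal{L}_{\text{KD}} = \mathrm{KL}\left(\mathbf{p}^{\text{teacher}} \,\|\, \mathbf{p}^{\text{student}}\right) = \sum_{j} p_j^{\text{teacher}} \log \frac{p_j^{\text{teacher}}}{p_j^{\text{student}}}$$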
It helps to expand the KL divergence as $\mathrm{KL}(p^{\text{teacher}} \| p^{\text{student}}) = H(p^{\text{teacher}}, p^{\text{student}}) - H(p^{\text{teacher}})$. The second term (the entropy of the teacher distribution) does not depend on the student parameters. Minimising KL is therefore equivalent to minimising the cross-entropy $H(p^{\text{teacher}}, p^{\text{student}})$ with respect to the student: the two objectives have identical gradients, and their values differ only by that constant. In code, it is enough to compute the cross-entropy between the teacher soft labels and the student log-probabilities. The following function implements this, including the $T^2$ gradient-magnitude correction from Hinton et al. (2015):
import torch
import torch.nn.functional as F
def distillation_loss(student_logits, teacher_logits, temperature=1.0):
    """
    student_logits: (batch_size, num_candidates): bi-encoder dot products
    teacher_logits: (batch_size, num_candidates): cross-encoder scores
    KL(teacher || student) = CE(teacher_probs, student_log_probs) - H(teacher)
    H(teacher) is constant w.r.t. student parameters, so we minimize CE only.
    """
    # Scale by temperature: softer distributions transfer more information.
    # The teacher is a fixed target, so no gradient flows into it.
    teacher_probs = F.softmax(teacher_logits.detach() / temperature, dim=-1)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    # Cross-entropy: -sum_j teacher_probs_j * student_log_probs_j, averaged over the batch.
    # (F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') has the same
    # gradients but also computes the constant H(teacher) term.)
    loss = -(teacher_probs * student_log_probs).sum(dim=-1).mean()
    return loss * (temperature ** 2)  # rescale so gradient magnitude matches T = 1
# The temperature ** 2 factor comes from the gradient analysis:
# when logits are scaled by 1/T, the soft-target gradients shrink roughly as 1/T^2,
# so multiplying by T^2 restores gradient magnitude to the T=1 baseline.
# This was shown in the original knowledge distillation paper (Hinton et al., 2015).
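The $T^2$ rescaling addresses a subtle side-effect of temperature scaling. When we divide logits by temperature $T$, the gradients of the soft-target loss with respect to the logits shrink roughly as $1/T^2$. One factor of $1/T$ comes from the chain rule through $z/T$. The other arises because the softened teacher and student distributions differ by an amount that itself scales as $1/T$ (Hinton et al. (2015) derive this in the high-temperature limit). Multiplying the loss by $T^2$ restores the gradient magnitude to roughly the $T=1$ baseline. As a result, changing the temperature does not silently change the effective learning rate or the weight of this loss relative to other terms.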
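The scaling argument can be checked numerically with autograd. The sketch computes the gradient of the softened cross-entropy with respect to the student logits $z$ at several temperatures, using made-up, zero-mean logits. The zero-mean assumption matches the setting of Hinton et al.'s approximation. It then compares $T^2$ times the gradient with the high-temperature limit $(z_i - v_i)/N$, where $v$ are the teacher logits and $N$ is the number of candidates. `soft_ce` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F

def soft_ce(student_logits, teacher_logits, T):
    """Cross-entropy between temperature-softened teacher and student distributions."""
    p_teacher = F.softmax(teacher_logits / T, dim=-1)
    log_p_student = F.log_softmax(student_logits / T, dim=-1)
    return -(p_teacher * log_p_student).sum()

# Zero-mean logits for one query over N = 5 candidates (made-up values).
z = torch.tensor([1.0, -0.5, 0.2, -0.7, 0.0], requires_grad=True)  # student
v = torch.tensor([0.5, 0.3, -0.4, 0.1, -0.5])                      # teacher
N = z.numel()

for T in (1.0, 5.0, 20.0):
    (grad,) = torch.autograd.grad(soft_ce(z, v, T), z)
    print(f"T={T:>4}: |grad| = {grad.norm():.5f}, |T^2 * grad| = {(T**2 * grad).norm():.5f}")

# High-temperature limit (Hinton et al., 2015): dL/dz_i ~ (z_i - v_i) / (N T^2)
T = 20.0
(grad,) = torch.autograd.grad(soft_ce(z, v, T), z)
print(torch.allclose(T**2 * grad, (z.detach() - v) / N, atol=1e-2))
```

The raw gradient norm falls by orders of magnitude as $T$ grows, while the $T^2$-rescaled norm stays of the same order. The rescaling relies on exactly this behaviour.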
BGE-M3 and Multi-Loss Training
FlagEmbedding's BGE-M3 (Chen et al., 2024) combines all of the ideas above in a single training loop. For each batch it computes three InfoNCE losses, one per retrieval head (dense, sparse and multi-vector). It adds a self-knowledge distillation loss in which the ensemble score of the three heads serves as a soft teacher, and optimises the combination of these terms.
The self-KD teacher score is a weighted combination of all three retrieval heads ($s_{\text{teacher}} = \alpha\, s_{\text{dense}} + \beta\, s_{\text{sparse}} + \gamma\, s_{\text{multi-vec}}$), computed with the current model parameters so no separate teacher model is needed. Because the ensemble score tends to be more accurate than any individual head, distilling it back into each head pushes them to be consistent with each other.
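A sketch of the self-distillation term follows, assuming each head has already produced a score matrix over the same candidate list. The weights `(1.0, 0.3, 1.0)`, the temperature, and detaching the teacher are illustrative assumptions rather than FlagEmbedding's exact implementation. `self_kd_loss` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F

def self_kd_loss(s_dense, s_sparse, s_multi, weights=(1.0, 0.3, 1.0), tau=0.05):
    """Self-knowledge distillation across retrieval heads (sketch).

    Each s_*: (num_queries, num_candidates) scores from one head, all computed
    by the same model. The weighted sum acts as the teacher, and each head is
    trained towards the teacher's softmax distribution.
    """
    alpha, beta, gamma = weights
    s_teacher = alpha * s_dense + beta * s_sparse + gamma * s_multi
    # Treat the ensemble as a fixed target for this step (an assumption): the
    # distillation gradient does not flow back through the teacher scores.
    p_teacher = F.softmax(s_teacher.detach() / tau, dim=-1)
    losses = []
    for s in (s_dense, s_sparse, s_multi):
        log_p = F.log_softmax(s / tau, dim=-1)
        losses.append(-(p_teacher * log_p).sum(dim=-1).mean())  # cross-entropy per head
    return sum(losses) / len(losses)

torch.manual_seed(0)
scores = [torch.randn(2, 4, requires_grad=True) for _ in range(3)]  # toy head scores
loss = self_kd_loss(*scores)
loss.backward()
print(loss.item(), scores[0].grad.shape)
```

In a full training loop this term would be added to the three per-head InfoNCE losses.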
The following code implements the three scoring functions. For ColBERT-style multi-vector scoring, the MaxSim computation uses the same einsum pattern we saw in the late interaction article. The sparse scoring function computes a dot product over learned term weights (one weight per vocabulary token), which is analogous to BM25 but with learned rather than statistical weights:
import torch
def colbert_score(q_reps, p_reps):
    """
    Multi-vector (ColBERT-style) scoring via MaxSim.
    q_reps: (num_queries, num_query_tokens, dim)
    p_reps: (num_passages, num_passage_tokens, dim)
    Returns: (num_queries, num_passages) relevance scores
    Padding tokens are not masked in this sketch.
    """
    # All pairwise token similarities: (Q, P, q_tokens, p_tokens)
    token_scores = torch.einsum('qin,pjn->qpij', q_reps, p_reps)
    # MaxSim: for each query token, max over passage tokens (last dim),
    # then sum over query tokens (the new last dim) -> (Q, P)
    scores = token_scores.max(dim=-1).values.sum(dim=-1)
    return scores
def sparse_score(q_weights, p_weights):
    """
    Learned sparse scoring (analogous to BM25 with learned term weights).
    q_weights: (num_queries, vocab_size): per-token importance weights
    p_weights: (num_passages, vocab_size)
    Returns: (num_queries, num_passages) relevance scores
    """
    # Dot product over sparse term-weight vectors.
    # In practice these are very sparse (ReLU output), so only
    # overlapping terms contribute, similar to lexical matching.
    scores = torch.matmul(q_weights, p_weights.T)
    return scores
def dense_score(q_reps, p_reps, temperature=0.02):
    """
    Dense scoring via pooled sentence embeddings.
    q_reps: (num_queries, dim)
    p_reps: (num_passages, dim)
    Returns: (num_queries, num_passages) logits, already divided by temperature
    """
    # Simple dot product, then scale by 1/temperature before softmax.
    # Lower temperature -> sharper softmax -> loss concentrates on the hardest negatives
    scores = torch.matmul(q_reps, p_reps.T)
    return scores / temperature  # division happens BEFORE softmax in the loss
Notice `scores / temperature` in the dense score function. The temperature is applied to the logits before they enter the softmax in the InfoNCE loss, not to the probabilities afterwards. This is the same operation as temperature scaling in language model generation, where lowering the temperature makes the distribution more peaked. In retrieval training, a lower temperature (e.g. 0.02) magnifies every score gap by $1/\tau$. The loss then becomes dominated by whichever negatives score close to (or above) the positive, and easy negatives contribute almost nothing. In that sense it sets a harder training problem, focused on the most confusable documents.
Quiz
Check your understanding of embedding fine-tuning for retrieval.
In the InfoNCE loss, what is the effect of using a very small temperature (e.g. τ = 0.02)?
Why is minimising KL(teacher || student) equivalent to minimising cross-entropy H(teacher, student) with respect to the student?
ANCE improves hard negative mining by:
In knowledge distillation code, why is the loss multiplied by T² when using temperature T > 1?
In BGE-M3's self-knowledge distillation, what serves as the teacher?