Losses

sentence_transformers.cross_encoder.losses defines different loss functions that can be used to fine-tune cross-encoder models on training data. The choice of loss function plays a critical role when fine-tuning the model. It determines how well our model will work for the specific downstream task.

Sadly, there is no “one size fits all” loss function. Which loss function is suitable depends on the available training data and on the target task. Consider checking out the Loss Overview to help narrow down your choice of loss function(s).

BinaryCrossEntropyLoss

class sentence_transformers.cross_encoder.losses.BinaryCrossEntropyLoss(model: CrossEncoder, activation_fn: Module = Identity(), pos_weight: Tensor | None = None, **kwargs)

Computes the Binary Cross Entropy Loss for a CrossEncoder model. This loss is used to train a model to predict a high logit for positive pairs and a low logit for negative pairs. The model should be initialized with num_labels = 1 (a.k.a. the default) to predict one class.

It has been used to train many of the strong CrossEncoder MS MARCO Reranker models.

Parameters:
  • model (CrossEncoder) – A CrossEncoder model to be trained.

  • activation_fn (Module) – Activation function applied to the logits before computing the loss. Defaults to Identity.

  • pos_weight (Tensor, optional) – The weight of positive examples. Must be a torch.Tensor, e.g. torch.tensor(4) for a weight of 4. Defaults to None.

  • **kwargs – Additional keyword arguments passed to the underlying torch.nn.BCEWithLogitsLoss.
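For reference, the per-pair quantity that torch.nn.BCEWithLogitsLoss computes can be written out in plain Python. The sketch below is illustrative only (the helper names bce_with_logits and mean_bce are hypothetical, not part of the library). It assumes the default Identity activation, so the raw logit is used directly, and it shows that pos_weight scales only the positive-label term and that float labels in [0, 1] go through the same formula.

```python
import math


def bce_with_logits(logit, label, pos_weight=1.0):
    """Binary cross entropy on one raw logit, following the torch.nn.BCEWithLogitsLoss formula:
    -(pos_weight * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x)))."""
    # Numerically stable log(sigmoid(x))
    if logit >= 0:
        log_sig = -math.log1p(math.exp(-logit))
    else:
        log_sig = logit - math.log1p(math.exp(logit))
    # log(1 - sigmoid(x)) = log(sigmoid(-x)) = log(sigmoid(x)) - x
    log_one_minus_sig = log_sig - logit
    return -(pos_weight * label * log_sig + (1.0 - label) * log_one_minus_sig)


def mean_bce(logits, labels, pos_weight=1.0):
    """Average over a batch of (logit, label) pairs, like the default 'mean' reduction."""
    losses = [bce_with_logits(x, y, pos_weight) for x, y in zip(logits, labels)]
    return sum(losses) / len(losses)
```

A logit of 0 corresponds to a predicted probability of 0.5, so it costs log(2) for either label; setting pos_weight to 4 makes a missed positive four times as expensive.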


Requirements:
  1. Your model must be initialized with num_labels = 1 (a.k.a. the default) to predict one class.

Inputs:

Texts | Labels | Number of Model Output Labels
(anchor, positive/negative) pairs | 1 if positive, 0 if negative | 1
(sentence_A, sentence_B) pairs | float similarity score between 0 and 1 | 1

Recommendations:
  • Use mine_hard_negatives with output_format="labeled-pair" to convert question-answer pairs to the (anchor, positive/negative) pairs format with labels as 1 or 0, using hard negatives.

Example

from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer, losses
from datasets import Dataset

model = CrossEncoder("microsoft/mpnet-base")
train_dataset = Dataset.from_dict({
    "query": ["What are pandas?", "What are pandas?"],
    "response": ["Pandas are a kind of bear.", "Pandas are a kind of fish."],
    "label": [1, 0],
})
loss = losses.BinaryCrossEntropyLoss(model)

trainer = CrossEncoderTrainer(
    model=model,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()

CrossEntropyLoss

class sentence_transformers.cross_encoder.losses.CrossEntropyLoss(model: CrossEncoder, activation_fn: Module = Identity(), **kwargs)

Computes the Cross Entropy Loss for a CrossEncoder model. This loss is used to train a model to predict the correct class label for a given pair of sentences. The number of classes should be equal to the number of model output labels.

Parameters:
  • model (CrossEncoder) – A CrossEncoder model to be trained.

  • activation_fn (Module) – Activation function applied to the logits before computing the loss. Defaults to Identity.

  • **kwargs – Additional keyword arguments passed to the underlying torch.nn.CrossEntropyLoss.
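With num_labels > 1, the model outputs one logit per class, and torch.nn.CrossEntropyLoss takes the negative log-softmax probability of the gold class. A plain-Python sketch of that computation for a single pair (the helper name cross_entropy is hypothetical; class weights and label smoothing are ignored):

```python
import math


def cross_entropy(logits, target):
    """Negative log-softmax probability of the target class for one example."""
    m = max(logits)  # subtract the max logit for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


# With num_labels=2 (e.g. 0: not duplicate, 1: duplicate), equal logits cost log(2).
print(cross_entropy([0.0, 0.0], target=1))
```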


Requirements:
  1. Your model can be initialized with num_labels > 1 to predict multiple classes.

  2. The number of dataset classes should be equal to the number of model output labels (model.num_labels).

Inputs:

Texts | Labels | Number of Model Output Labels
(sentence_A, sentence_B) pairs | class | num_classes

Example

from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer, losses
from datasets import Dataset

model = CrossEncoder("microsoft/mpnet-base", num_labels=2)
train_dataset = Dataset.from_dict({
    "sentence1": ["How can I be a good geologist?", "What is the capital of France?"],
    "sentence2": ["What should I do to be a great geologist?", "What is the capital of Germany?"],
    "label": [1, 0],  # 1: duplicate, 0: not duplicate
})
loss = losses.CrossEntropyLoss(model)

trainer = CrossEncoderTrainer(
    model=model,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()

LambdaLoss

class sentence_transformers.cross_encoder.losses.LambdaLoss(model: CrossEncoder, weighting_scheme: BaseWeightingScheme | None = NDCGLoss2PPScheme(), k: int | None = None, sigma: float = 1.0, eps: float = 1e-10, reduction_log: Literal['natural', 'binary'] = 'binary', activation_fn: Module | None = Identity(), mini_batch_size: int | None = None)

The LambdaLoss Framework for Ranking Metric Optimization. This loss function implements the LambdaLoss framework for ranking metric optimization, which provides various weighting schemes including LambdaRank and NDCG variations. The implementation is optimized to handle padded documents efficiently by only processing valid documents during model inference.

Note

The number of documents per query can vary between samples with the LambdaLoss.
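Because each query can come with a different number of documents, a listwise batch is ragged. One way to picture the "only process valid documents" strategy is to pad the label lists to a common length, keep a validity mask, and send only the real (query, document) pairs to the model. This is an illustrative sketch with hypothetical helper names, not the library's internal code.

```python
def flatten_listwise_batch(queries, docs, labels, pad_value=-1.0):
    """Return the flat list of valid (query, doc) pairs plus padded label rows and a mask."""
    max_docs = max(len(d) for d in docs)
    pairs, padded_labels, mask = [], [], []
    for query, query_docs, query_labels in zip(queries, docs, labels):
        pairs.extend((query, doc) for doc in query_docs)  # only real documents get scored
        n_pad = max_docs - len(query_docs)
        padded_labels.append([float(y) for y in query_labels] + [pad_value] * n_pad)
        mask.append([True] * len(query_docs) + [False] * n_pad)
    return pairs, padded_labels, mask


def scatter_scores(flat_scores, mask, pad_score=float("-inf")):
    """Place the flat model scores back into the padded [num_queries, max_docs] layout."""
    it = iter(flat_scores)
    return [[next(it) if valid else pad_score for valid in row] for row in mask]
```

For a batch with 2 and 3 documents, only 5 pairs reach the model instead of 2 × 3 = 6; padded slots receive a score of -inf so they cannot affect a softmax over the documents.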

Parameters:
  • model (CrossEncoder) – CrossEncoder model to be trained

  • weighting_scheme (BaseWeightingScheme, optional) –

    Weighting scheme to use for the loss.

    Defaults to NDCGLoss2PPScheme. In the original LambdaLoss paper, the NDCGLoss2PPScheme was shown to reach the strongest performance, with the NDCGLoss2Scheme following closely.

  • k (int, optional) – Number of documents to consider for NDCG@K. Defaults to None (use all documents).

  • sigma (float) – Score difference weight used in the sigmoid. Defaults to 1.0.

  • eps (float) – Small constant for numerical stability. Defaults to 1e-10.

  • reduction_log (str) – Type of logarithm to use: “natural” for the natural logarithm (log) or “binary” for the binary logarithm (log2). Defaults to “binary”.

  • activation_fn (Module) – Activation function applied to the logits before computing the loss. Defaults to Identity.

  • mini_batch_size (int, optional) –

    Number of samples to process in each forward pass. This has a significant impact on the memory consumption and speed of the training process. Three cases are possible:

    • If mini_batch_size is None, the mini_batch_size is set to the batch size.

    • If mini_batch_size is greater than 0, the batch is split into mini-batches of size mini_batch_size.

    • If mini_batch_size is <= 0, the entire batch is processed at once.

    Defaults to None.
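The three mini_batch_size cases can be summarized by how a batch of samples would be chunked before the forward passes. A minimal sketch (the helper split_into_mini_batches is hypothetical):

```python
from typing import Optional


def split_into_mini_batches(samples: list, mini_batch_size: Optional[int]) -> list:
    """Chunk a training batch according to the three mini_batch_size cases."""
    if not samples:
        return []
    if mini_batch_size is None or mini_batch_size <= 0:
        # None: mini_batch_size is set to the batch size; <= 0: the whole batch at once.
        # Both cases yield a single chunk.
        mini_batch_size = len(samples)
    # > 0: consecutive chunks of at most mini_batch_size samples (the last may be smaller)
    return [samples[i : i + mini_batch_size] for i in range(0, len(samples), mini_batch_size)]
```

Smaller chunks lower peak memory, since fewer (query, document) pairs are in a single forward pass, at the cost of more passes per batch.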


Requirements:
  1. Query with multiple documents (listwise approach)

  2. Documents must have relevance scores/labels. Both binary and continuous labels are supported.

Inputs:

Texts | Labels | Number of Model Output Labels
(query, [doc1, doc2, …, docN]) | [score1, score2, …, scoreN] | 1

Recommendations:
  • Use mine_hard_negatives with output_format="labeled-list" to convert question-answer pairs to the required input format with hard negatives.

Relations:
  • LambdaLoss anecdotally performs better than the other losses with the same input format.

Example

from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer, losses
from datasets import Dataset

model = CrossEncoder("microsoft/mpnet-base")
train_dataset = Dataset.from_dict({
    "query": ["What are pandas?", "What is the capital of France?"],
    "docs": [
        ["Pandas are a kind of bear.", "Pandas are kind of like fish."],
        ["The capital of France is Paris.", "Paris is the capital of France.", "Paris is quite large."],
    ],
    "labels": [[1, 0], [1, 1, 0]],
})
loss = losses.LambdaLoss(model)

trainer = CrossEncoderTrainer(
    model=model,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()
class sentence_transformers.cross_encoder.losses.LambdaLoss.BaseWeightingScheme(*args, **kwargs)

Base class for implementing weighting schemes in LambdaLoss.

class sentence_transformers.cross_encoder.losses.NoWeightingScheme(*args, **kwargs)

Implementation of no weighting scheme (weights = 1.0).

class sentence_transformers.cross_encoder.losses.NDCGLoss1Scheme(*args, **kwargs)

Implementation of NDCG Loss1 weighting scheme.

It is used to optimize for the NDCG metric, but this weighting scheme is not recommended as the NDCGLoss2Scheme and NDCGLoss2PPScheme were shown to reach superior performance in the original LambdaLoss paper.

class sentence_transformers.cross_encoder.losses.NDCGLoss2Scheme(*args, **kwargs)

Implementation of NDCG Loss2 weighting scheme.

This scheme uses a tighter bound than NDCGLoss1Scheme and was shown to reach superior performance in the original LambdaLoss paper. It is used to optimize for the NDCG metric.

class sentence_transformers.cross_encoder.losses.LambdaRankScheme(*args, **kwargs)

Implementation of LambdaRank weighting scheme.

This weighting optimizes a coarse upper bound of NDCG.
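To make the pairwise weighting concrete, the sketch below computes a LambdaRank-style loss for a single query in plain Python. Each pair (i, j) with label_i > label_j contributes |ΔNDCG_ij| · log2(1 + exp(-sigma · (s_i - s_j))), where |ΔNDCG_ij| = |G_i - G_j| · |1/D_i - 1/D_j|, G_i = (2^label_i - 1) / maxDCG and D_i = log2(1 + rank_i), with ranks taken from the current model scores. This follows the LambdaLoss paper's formulation as we understand it; the library's handling of k, eps and normalization may differ, and the function names are hypothetical.

```python
import math


def _log2_1p_exp(z):
    """log2(1 + exp(z)) without overflow for large |z|."""
    return (max(z, 0.0) + math.log1p(math.exp(-abs(z)))) / math.log(2)


def lambdarank_loss(scores, labels, sigma=1.0):
    """LambdaRank-weighted pairwise logistic loss for a single query."""
    n = len(scores)
    # rank[d]: 1-based position of document d when sorted by predicted score, descending
    rank = [0] * n
    for position, d in enumerate(sorted(range(n), key=lambda d: -scores[d]), start=1):
        rank[d] = position
    # maxDCG: DCG of the ideal ordering (labels sorted descending)
    ideal = sorted(labels, reverse=True)
    max_dcg = sum((2**y - 1) / math.log2(1 + r) for r, y in enumerate(ideal, start=1))
    if max_dcg == 0:
        return 0.0  # no relevant documents, so there is nothing to reorder
    gain = [(2**y - 1) / max_dcg for y in labels]
    discount = [math.log2(1 + r) for r in rank]
    loss = 0.0
    for i in range(n):
        for j in range(n):
            if labels[i] > labels[j]:
                delta_ndcg = abs(gain[i] - gain[j]) * abs(1 / discount[i] - 1 / discount[j])
                loss += delta_ndcg * _log2_1p_exp(-sigma * (scores[i] - scores[j]))
    return loss
```

Compared with unweighted RankNet, mistakes near the top of the ranking (where 1/D changes fastest) receive the largest weights.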

class sentence_transformers.cross_encoder.losses.NDCGLoss2PPScheme(mu: float = 10.0)

Implementation of NDCG Loss2++ weighting scheme.

It is a hybrid weighting scheme that combines the NDCGLoss2 and LambdaRank schemes. It was shown to reach the strongest performance in the original LambdaLoss paper.

ListMLELoss

class sentence_transformers.cross_encoder.losses.ListMLELoss(model: CrossEncoder, activation_fn: Module | None = Identity(), mini_batch_size: int | None = None, respect_input_order: bool = True)

This loss function implements the ListMLE learning to rank algorithm, which uses a list-wise approach based on maximum likelihood estimation of permutations. It maximizes the likelihood of the permutation induced by the ground truth labels.
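Concretely, ListMLE orders the documents by their ground-truth labels and minimizes the negative log-likelihood of that permutation under a Plackett-Luce model of the predicted scores: loss = Σ_i [log Σ_{j ≥ i} exp(s_j) - s_i], with documents indexed in ground-truth order. A plain-Python sketch for one query (the function name is hypothetical; the respect_input_order flag mirrors the parameter description, with ties kept in input order):

```python
import math


def listmle_loss(scores, labels, respect_input_order=True):
    """Negative Plackett-Luce log-likelihood of the ground-truth permutation for one query."""
    if respect_input_order:
        order = list(range(len(scores)))  # assume docs are already sorted, most relevant first
    else:
        order = sorted(range(len(scores)), key=lambda d: -labels[d])  # stable sort by label
    s = [scores[d] for d in order]
    loss = 0.0
    for i in range(len(s)):
        tail = s[i:]  # documents still "available" when filling position i
        m = max(tail)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in tail))
        loss += log_sum_exp - s[i]
    return loss
```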

Note

The number of documents per query can vary between samples with the ListMLELoss.

Parameters:
  • model (CrossEncoder) – CrossEncoder model to be trained

  • activation_fn (Module) – Activation function applied to the logits before computing the loss. Defaults to Identity.

  • mini_batch_size (int, optional) –

    Number of samples to process in each forward pass. This has a significant impact on the memory consumption and speed of the training process. Three cases are possible:

    • If mini_batch_size is None, the mini_batch_size is set to the batch size.

    • If mini_batch_size is greater than 0, the batch is split into mini-batches of size mini_batch_size.

    • If mini_batch_size is <= 0, the entire batch is processed at once.

    Defaults to None.

  • respect_input_order (bool) – Whether to respect the original input order of documents. If True, assumes the input documents are already ordered by relevance (most relevant first). If False, sorts documents by label values. Defaults to True.


Requirements:
  1. Query with multiple documents (listwise approach)

  2. Documents must have relevance scores/labels. Both binary and continuous labels are supported.

  3. Documents must be sorted in a defined rank order.

Inputs:

Texts | Labels | Number of Model Output Labels
(query, [doc1, doc2, …, docN]) | [score1, score2, …, scoreN] | 1

Recommendations:
  • Use mine_hard_negatives with output_format="labeled-list" to convert question-answer pairs to the required input format with hard negatives.


Example

from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer, losses
from datasets import Dataset

model = CrossEncoder("microsoft/mpnet-base")
train_dataset = Dataset.from_dict({
    "query": ["What are pandas?", "What is the capital of France?"],
    "docs": [
        ["Pandas are a kind of bear.", "Pandas are kind of like fish."],
        ["The capital of France is Paris.", "Paris is the capital of France.", "Paris is quite large."],
    ],
    "labels": [[1, 0], [1, 1, 0]],
})

# Standard ListMLE loss respecting input order
loss = losses.ListMLELoss(model)

trainer = CrossEncoderTrainer(
    model=model,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()

PListMLELoss

class sentence_transformers.cross_encoder.losses.PListMLELoss(model: CrossEncoder, lambda_weight: PListMLELambdaWeight | None = PListMLELambdaWeight(), activation_fn: Module | None = Identity(), mini_batch_size: int | None = None, respect_input_order: bool = True)

PListMLE loss for learning to rank with position-aware weighting. This loss function implements the ListMLE ranking algorithm which uses a list-wise approach based on maximum likelihood estimation of permutations. It maximizes the likelihood of the permutation induced by the ground truth labels with position-aware weighting.

This loss is also known as Position-Aware ListMLE or p-ListMLE.
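Position-aware weighting multiplies each term of the ListMLE sum by a weight that depends on its rank. Assuming the default discount stated for PListMLELambdaWeight, 2^(num_docs - rank) - 1 with 1-based ranks, the top position gets the largest weight and the last position gets weight 0. The sketch below is a plain-Python illustration under that assumption with hypothetical names; the library may additionally normalize the weights or the loss.

```python
import math


def default_rank_discount(num_docs):
    """Discount per 1-based rank: 2^(num_docs - rank) - 1."""
    return [2 ** (num_docs - rank) - 1 for rank in range(1, num_docs + 1)]


def plistmle_loss(scores_in_label_order, rank_discount_fn=default_rank_discount):
    """Position-weighted ListMLE for one query whose docs are already sorted by relevance."""
    s = scores_in_label_order
    weights = rank_discount_fn(len(s))
    loss = 0.0
    for i, w in enumerate(weights):
        tail = s[i:]
        m = max(tail)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in tail))
        loss += w * (log_sum_exp - s[i])  # errors at the top positions are penalized most
    return loss
```

Passing a discount function that returns all ones recovers plain ListMLE.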

Note

The number of documents per query can vary between samples with the PListMLELoss.

Parameters:
  • model (CrossEncoder) – CrossEncoder model to be trained

  • lambda_weight (PListMLELambdaWeight, optional) – Weighting scheme to use, which applies different weights to different rank positions (Position-Aware ListMLE). Defaults to PListMLELambdaWeight() with its default discount. If None, no positional weighting is applied (standard ListMLE).

  • activation_fn (Module) – Activation function applied to the logits before computing the loss. Defaults to Identity.

  • mini_batch_size (int, optional) –

    Number of samples to process in each forward pass. This has a significant impact on the memory consumption and speed of the training process. Three cases are possible:

    • If mini_batch_size is None, the mini_batch_size is set to the batch size.

    • If mini_batch_size is greater than 0, the batch is split into mini-batches of size mini_batch_size.

    • If mini_batch_size is <= 0, the entire batch is processed at once.

    Defaults to None.

  • respect_input_order (bool) – Whether to respect the original input order of documents. If True, assumes the input documents are already ordered by relevance (most relevant first). If False, sorts documents by label values. Defaults to True.


Requirements:
  1. Query with multiple documents (listwise approach)

  2. Documents must have relevance scores/labels. Both binary and continuous labels are supported.

  3. Documents must be sorted in a defined rank order.

Inputs:

Texts | Labels | Number of Model Output Labels
(query, [doc1, doc2, …, docN]) | [score1, score2, …, scoreN] | 1

Recommendations:
  • Use mine_hard_negatives with output_format="labeled-list" to convert question-answer pairs to the required input format with hard negatives.


Example

import torch
from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer, losses
from datasets import Dataset

model = CrossEncoder("microsoft/mpnet-base")
train_dataset = Dataset.from_dict({
    "query": ["What are pandas?", "What is the capital of France?"],
    "docs": [
        ["Pandas are a kind of bear.", "Pandas are kind of like fish."],
        ["The capital of France is Paris.", "Paris is the capital of France.", "Paris is quite large."],
    ],
    "labels": [[1, 0], [1, 1, 0]],
})

# Either: Position-Aware ListMLE with default weighting
lambda_weight = losses.PListMLELambdaWeight()
loss = losses.PListMLELoss(model, lambda_weight=lambda_weight)

# or: Position-Aware ListMLE with custom weighting function
def custom_discount(ranks): # e.g. ranks: [1, 2, 3, 4, 5]
    return 1.0 / torch.log1p(ranks)
lambda_weight = losses.PListMLELambdaWeight(rank_discount_fn=custom_discount)
loss = losses.PListMLELoss(model, lambda_weight=lambda_weight)

trainer = CrossEncoderTrainer(
    model=model,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()
class sentence_transformers.cross_encoder.losses.PListMLELambdaWeight(rank_discount_fn=None)

Base class for implementing weighting schemes in Position-Aware ListMLE Loss.

Initializes a lambda weight for PListMLELoss.

Parameters:
  • rank_discount_fn – Function that computes a discount for each rank position. If None, uses the default discount of 2^(num_docs - rank) - 1.

ListNetLoss

class sentence_transformers.cross_encoder.losses.ListNetLoss(model: CrossEncoder, activation_fn: Module | None = Identity(), mini_batch_size: int | None = None)

ListNet loss for learning to rank. This loss function implements the ListNet ranking algorithm which uses a list-wise approach to learn ranking models. It minimizes the cross entropy between the predicted ranking distribution and the ground truth ranking distribution. The implementation is optimized to handle padded documents efficiently by only processing valid documents during model inference.
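In the common top-one formulation of ListNet, both the labels and the predicted scores are turned into probability distributions over the documents with a softmax, and the loss is the cross entropy between them. A plain-Python sketch under that assumption (hypothetical names; the exact label transformation used by the library may differ):

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def listnet_loss(scores, labels):
    """Cross entropy between softmax(labels) (target) and softmax(scores) (prediction)."""
    target = softmax([float(y) for y in labels])
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    # -sum_i P_target(i) * log P_pred(i), where log P_pred(i) = s_i - log Z
    return -sum(p * (s - log_z) for p, s in zip(target, scores))
```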

Note

The number of documents per query can vary between samples with the ListNetLoss.

Parameters:
  • model (CrossEncoder) – CrossEncoder model to be trained

  • activation_fn (Module) – Activation function applied to the logits before computing the loss. Defaults to Identity.

  • mini_batch_size (int, optional) –

    Number of samples to process in each forward pass. This has a significant impact on the memory consumption and speed of the training process. Three cases are possible:

    • If mini_batch_size is None, the mini_batch_size is set to the batch size.

    • If mini_batch_size is greater than 0, the batch is split into mini-batches of size mini_batch_size.

    • If mini_batch_size is <= 0, the entire batch is processed at once.

    Defaults to None.


Requirements:
  1. Query with multiple documents (listwise approach)

  2. Documents must have relevance scores/labels. Both binary and continuous labels are supported.

Inputs:

Texts | Labels | Number of Model Output Labels
(query, [doc1, doc2, …, docN]) | [score1, score2, …, scoreN] | 1

Recommendations:
  • Use mine_hard_negatives with output_format="labeled-list" to convert question-answer pairs to the required input format with hard negatives.

Relations:
  • LambdaLoss takes the same inputs and generally outperforms this loss.

Example

from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer, losses
from datasets import Dataset

model = CrossEncoder("microsoft/mpnet-base")
train_dataset = Dataset.from_dict({
    "query": ["What are pandas?", "What is the capital of France?"],
    "docs": [
        ["Pandas are a kind of bear.", "Pandas are kind of like fish."],
        ["The capital of France is Paris.", "Paris is the capital of France.", "Paris is quite large."],
    ],
    "labels": [[1, 0], [1, 1, 0]],
})
loss = losses.ListNetLoss(model)

trainer = CrossEncoderTrainer(
    model=model,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()

MultipleNegativesRankingLoss

class sentence_transformers.cross_encoder.losses.MultipleNegativesRankingLoss(model: CrossEncoder, num_negatives: int | None = 4, scale: float = 10.0, activation_fn: Module | None = Sigmoid())

Given a list of (anchor, positive) pairs or (anchor, positive, negative) triplets, this loss optimizes the following:

  • Given an anchor (e.g. a question), assign the highest similarity to the corresponding positive (i.e. answer) out of every single positive and negative (e.g. all answers) in the batch.

If you provide the optional negatives, they will all be used as extra options from which the model must pick the correct positive. Within reason, the harder this “picking” is, the stronger the model will become. Because of this, a higher batch size results in more in-batch negatives, which then increases performance (to a point).

This loss function works well for training rerankers in retrieval setups where you only have positive pairs (e.g. (query, answer)), as it randomly samples num_negatives in-batch negatives for each anchor from the other documents in the batch.

This loss is also known as InfoNCE loss, SimCSE loss, Cross-Entropy Loss with in-batch negatives, or simply in-batch negatives loss.

Parameters:
  • model (CrossEncoder) – A CrossEncoder model to be trained.

  • num_negatives (int, optional) – Number of in-batch negatives to sample for each anchor. Defaults to 4.

  • scale (float, optional) – The activated scores are multiplied by this value before the cross-entropy loss is computed. Defaults to 10.0.

  • activation_fn (Module) – Activation function applied to the logits before computing the loss. Defaults to Sigmoid.
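The in-batch objective can be sketched in plain Python. For each anchor i, the candidates are its own positive plus num_negatives documents sampled from the other rows of the batch; the raw logits go through the activation (Sigmoid by default), are multiplied by scale, and a cross entropy with the positive as the target is computed. Here logit(i, j) is a hypothetical stand-in for the CrossEncoder's score of (anchor_i, document_j). The sampling details are an assumption of this sketch, which also ignores any explicitly provided hard negatives.

```python
import math
import random


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def mnrl_loss(logit, batch_size, num_negatives=4, scale=10.0, seed=0):
    """In-batch negatives loss; logit(i, j) is the raw score of anchor i with the document of row j."""
    rng = random.Random(seed)
    total = 0.0
    for i in range(batch_size):
        others = [j for j in range(batch_size) if j != i]
        negatives = rng.sample(others, min(num_negatives, len(others)))
        candidates = [i] + negatives  # position 0 holds the anchor's own positive
        scores = [sigmoid(logit(i, j)) * scale for j in candidates]  # activation, then scale
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_z - scores[0]  # cross entropy with the positive as the target class
    return total / batch_size
```

Because Sigmoid bounds each score to [0, 1], scale caps the gap between any two candidates at scale, which controls how peaked the softmax can become.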

Note

The current default values are subject to change in the future. Experimentation is encouraged.


Requirements:
  1. Your model must be initialized with num_labels = 1 (a.k.a. the default) to predict one class.

Inputs:

Texts | Labels | Number of Model Output Labels
(anchor, positive) pairs | none | 1
(anchor, positive, negative) triplets | none | 1
(anchor, positive, negative_1, …, negative_n) | none | 1

Recommendations:
  • Use BatchSamplers.NO_DUPLICATES to ensure that no in-batch negatives are duplicates of the anchor or positive samples.

  • Use mine_hard_negatives with output_format="n-tuple" or output_format="triplet" to convert question-answer pairs to triplets with hard negatives.

Relations:
  • CachedMultipleNegativesRankingLoss is equivalent to this loss, but it uses caching that allows for much higher batch sizes (and thus better performance) without extra memory usage. However, it is slightly slower.

Example

from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer, losses
from datasets import Dataset

model = CrossEncoder("microsoft/mpnet-base")
train_dataset = Dataset.from_dict({
    "query": ["What are pandas?", "What is the capital of France?"],
    "answer": ["Pandas are a kind of bear.", "The capital of France is Paris."],
})
loss = losses.MultipleNegativesRankingLoss(model)

trainer = CrossEncoderTrainer(
    model=model,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()

CachedMultipleNegativesRankingLoss

class sentence_transformers.cross_encoder.losses.CachedMultipleNegativesRankingLoss(model: CrossEncoder, num_negatives: int | None = 4, scale: float = 10.0, activation_fn: Module | None = Sigmoid(), mini_batch_size: int = 32, show_progress_bar: bool = False)

Boosted version of MultipleNegativesRankingLoss that caches the gradients of the loss with respect to the logits. This allows for much higher batch sizes without extra memory usage. However, it is slightly slower.

In detail:

  1. It first runs a quick prediction step without gradients or computation graphs to get all the logits;

  2. It then calculates the loss, backpropagates up to the logits, and caches the gradients with respect to the logits;

  3. Finally, it runs a second prediction step with gradients and computation graphs, and connects the cached gradients into the backward chain.

Notes: All steps are done with mini-batches. In the original implementation of GradCache, step (2) is not done in mini-batches and requires a lot of memory when the batch size is large. According to the GradCache paper, gradient caching costs around 20% extra computation time.

Given a list of (anchor, positive) pairs or (anchor, positive, negative) triplets, this loss optimizes the following:

  • Given an anchor (e.g. a question), assign the highest similarity to the corresponding positive (i.e. answer) out of every single positive and negative (e.g. all answers) in the batch.

If you provide the optional negatives, they will all be used as extra options from which the model must pick the correct positive. Within reason, the harder this “picking” is, the stronger the model will become. Because of this, a higher batch size results in more in-batch negatives, which then increases performance (to a point).

This loss function works well for training rerankers in retrieval setups where you only have positive pairs (e.g. (query, answer)), as it randomly samples num_negatives in-batch negatives for each anchor from the other documents in the batch.

This loss is also known as InfoNCE loss with GradCache.
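The three gradient caching steps listed above can be illustrated without a deep learning framework. In this toy sketch (all names hypothetical), the "model" has a single weight w and produces logit_i = w · x_i, and the loss is a softmax cross entropy over all logits with index 0 as the target. Step 1 computes logits per mini-batch, step 2 computes dL/dlogit once for the full batch and caches it, and step 3 chains the cached values through the model per mini-batch. The result matches the gradient of the full-batch loss.

```python
import math


def full_batch_loss(w, xs):
    """Toy loss: softmax cross entropy over logits z_i = w * x_i, with index 0 as the target."""
    logits = [w * x for x in xs]
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[0]


def gradcache_grad(w, xs, mini_batch_size):
    """dL/dw computed with the three gradient caching steps."""
    chunks = [xs[i : i + mini_batch_size] for i in range(0, len(xs), mini_batch_size)]
    # Step 1: forward pass per mini-batch, keeping only the logits (no computation graph)
    logits = [w * x for chunk in chunks for x in chunk]
    # Step 2: gradient of the loss w.r.t. every logit, computed once and cached:
    # dL/dz_i = softmax(z)_i - 1[i == 0]
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    cached = [e / total - (1.0 if i == 0 else 0.0) for i, e in enumerate(exps)]
    # Step 3: revisit each mini-batch and chain the cached gradients through the model.
    # For this toy model dz_i/dw = x_i; a real network would redo its forward pass here
    # with gradients enabled and backpropagate the cached logit gradients.
    grad_w, offset = 0.0, 0
    for chunk in chunks:
        grad_w += sum(cached[offset + k] * x for k, x in enumerate(chunk))
        offset += len(chunk)
    return grad_w
```

Only one mini-batch needs a computation graph at a time, which is why peak memory depends on mini_batch_size rather than on the full batch size.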

Parameters:
  • model (CrossEncoder) – A CrossEncoder model to be trained.

  • num_negatives (int, optional) – Number of in-batch negatives to sample for each anchor. Defaults to 4.

  • scale (float, optional) – The activated scores are multiplied by this value before the cross-entropy loss is computed. Defaults to 10.0.

  • activation_fn (Module) – Activation function applied to the logits before computing the loss. Defaults to Sigmoid.

  • mini_batch_size (int, optional) – Mini-batch size for the forward pass. This determines the memory usage. Defaults to 32.

  • show_progress_bar (bool, optional) – Whether to show a progress bar during the forward pass. Defaults to False.

Note

The current default values are subject to change in the future. Experimentation is encouraged.


Requirements:
  1. Your model must be initialized with num_labels = 1 (a.k.a. the default) to predict one class.

  2. Should be used with a large per_device_train_batch_size and a low mini_batch_size for superior performance, at the cost of slower training than MultipleNegativesRankingLoss.

Inputs:

Texts | Labels | Number of Model Output Labels
(anchor, positive) pairs | none | 1
(anchor, positive, negative) triplets | none | 1
(anchor, positive, negative_1, …, negative_n) | none | 1

Recommendations:
  • Use BatchSamplers.NO_DUPLICATES to ensure that no in-batch negatives are duplicates of the anchor or positive samples.

  • Use mine_hard_negatives with output_format="n-tuple" or output_format="triplet" to convert question-answer pairs to triplets with hard negatives.


Example

from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer, losses
from datasets import Dataset

model = CrossEncoder("microsoft/mpnet-base")
train_dataset = Dataset.from_dict({
    "query": ["What are pandas?", "What is the capital of France?"],
    "answer": ["Pandas are a kind of bear.", "The capital of France is Paris."],
})
loss = losses.CachedMultipleNegativesRankingLoss(model, mini_batch_size=32)

trainer = CrossEncoderTrainer(
    model=model,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()

MSELoss

class sentence_transformers.cross_encoder.losses.MSELoss(model: CrossEncoder, activation_fn: Module = Identity(), **kwargs)

Computes the MSE loss between the computed query-passage score and a target query-passage score. This loss is used to distill a cross-encoder model from a teacher cross-encoder model or gold labels.

Parameters:
  • model (CrossEncoder) – A CrossEncoder model to be trained.

  • activation_fn (Module) – Activation function applied to the logits before computing the loss. Defaults to Identity.

  • **kwargs – Additional keyword arguments passed to the underlying torch.nn.MSELoss.

Note

Be mindful of the magnitude of both the labels and what the model produces. If the teacher model produces logits with Sigmoid to bound them to [0, 1], then you may wish to use a Sigmoid activation function in the loss.
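To see why the magnitudes matter, the plain-Python sketch below compares raw student logits against teacher scores that were squashed to [0, 1], with and without a Sigmoid activation on the student side. The numbers are illustrative and the helper names are hypothetical.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def mse_loss(student_logits, teacher_scores, activation_fn=lambda x: x):
    """Mean squared error between activated student logits and teacher targets."""
    preds = [activation_fn(x) for x in student_logits]
    return sum((p - t) ** 2 for p, t in zip(preds, teacher_scores)) / len(preds)


student_logits = [2.0, -2.0]  # raw, unbounded CrossEncoder outputs
teacher_scores = [sigmoid(2.0), sigmoid(-2.0)]  # teacher outputs bounded to [0, 1]

print(mse_loss(student_logits, teacher_scores))           # ≈ 2.87: the scales don't match
print(mse_loss(student_logits, teacher_scores, sigmoid))  # 0.0: same scale
```

Even though the student already ranks the two pairs exactly like the teacher, the Identity activation leaves a large loss purely because of the scale mismatch.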


Requirements:
  1. Your model must be initialized with num_labels = 1 (a.k.a. the default) to predict one class.

  2. Usually uses a finetuned CrossEncoder teacher model in a knowledge distillation setup.

Inputs:

Texts | Labels | Number of Model Output Labels
(sentence_A, sentence_B) pairs | similarity score | 1

Relations:
  • MarginMSELoss is similar to this loss, but with a margin through a negative pair.

Example

from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer, losses
from datasets import Dataset

student_model = CrossEncoder("microsoft/mpnet-base")
teacher_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L12-v2")
train_dataset = Dataset.from_dict({
    "query": ["What are pandas?", "What is the capital of France?"],
    "answer": ["Pandas are a kind of bear.", "The capital of France is Paris."],
})

def compute_labels(batch):
    return {
        "label": teacher_model.predict(list(zip(batch["query"], batch["answer"])))
    }

train_dataset = train_dataset.map(compute_labels, batched=True)
loss = losses.MSELoss(student_model)

trainer = CrossEncoderTrainer(
    model=student_model,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()

MarginMSELoss

class sentence_transformers.cross_encoder.losses.MarginMSELoss(model: CrossEncoder, activation_fn: Module = Identity(), **kwargs)

Computes the MSE loss between sim(Query, Pos) - sim(Query, Neg) and gold_sim(Query, Pos) - gold_sim(Query, Neg). This loss is often used to distill a cross-encoder model from a teacher cross-encoder model or gold labels.

In contrast to MultipleNegativesRankingLoss, the two passages do not have to be strictly positive and negative; both can be relevant or not relevant for a given query. This can be an advantage of MarginMSELoss over MultipleNegativesRankingLoss.

Note

Be mindful of the magnitude of both the labels and what the model produces. If the teacher model produces logits with Sigmoid to bound them to [0, 1], then you may wish to use a Sigmoid activation function in the loss.

Parameters:
  • model (CrossEncoder) – A CrossEncoder model to be trained.

  • activation_fn (Module) – Activation function applied to the logits before computing the loss. Defaults to Identity.

  • **kwargs – Additional keyword arguments passed to the underlying torch.nn.MSELoss.
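The margin objective can be written out for a batch of (query, passage_one, passage_two) triplets: the student's margin s(query, passage_one) - s(query, passage_two) is regressed onto the label gold_sim(query, passage_one) - gold_sim(query, passage_two). A plain-Python sketch with hypothetical names, assuming the default Identity activation:

```python
def margin_mse_loss(student_pos, student_neg, teacher_margins):
    """MSE between student margins (pos - neg) and teacher margins, averaged over the batch."""
    student_margins = [p - n for p, n in zip(student_pos, student_neg)]
    return sum((s - t) ** 2 for s, t in zip(student_margins, teacher_margins)) / len(teacher_margins)


# Only the difference matters: shifting both student scores by the same constant leaves
# the loss unchanged, and both passages may be relevant (a small teacher margin).
teacher_margins = [3.0, 0.5]
print(margin_mse_loss([4.0, 2.0], [1.0, 1.5], teacher_margins))      # 0.0
print(margin_mse_loss([14.0, 12.0], [11.0, 11.5], teacher_margins))  # 0.0
```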


Requirements:
  1. Your model must be initialized with num_labels = 1 (a.k.a. the default) to predict one class.

  2. Usually uses a finetuned CrossEncoder teacher model in a knowledge distillation setup.

Inputs:

Texts | Labels | Number of Model Output Labels
(query, passage_one, passage_two) triplets | gold_sim(query, passage_one) - gold_sim(query, passage_two) | 1

Relations:
  • MSELoss is similar to this loss, but without a margin through the negative pair.

Example

from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer, losses
from datasets import Dataset

student_model = CrossEncoder("microsoft/mpnet-base")
teacher_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L12-v2")
train_dataset = Dataset.from_dict({
    "query": ["What are pandas?", "What is the capital of France?"],
    "positive": ["Pandas are a kind of bear.", "The capital of France is Paris."],
    "negative": ["Pandas are a kind of fish.", "The capital of France is Berlin."],
})

def compute_labels(batch):
    positive_scores = teacher_model.predict(list(zip(batch["query"], batch["positive"])))
    negative_scores = teacher_model.predict(list(zip(batch["query"], batch["negative"])))
    return {
        "label": positive_scores - negative_scores
    }

train_dataset = train_dataset.map(compute_labels, batched=True)
loss = losses.MarginMSELoss(student_model)

trainer = CrossEncoderTrainer(
    model=student_model,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()

RankNetLoss

class sentence_transformers.cross_encoder.losses.RankNetLoss(model: CrossEncoder, k: int | None = None, sigma: float = 1.0, eps: float = 1e-10, reduction_log: Literal['natural', 'binary'] = 'binary', activation_fn: Module | None = Identity(), mini_batch_size: int | None = None)

RankNet loss implementation for learning to rank. This loss function implements the RankNet algorithm, which learns a ranking function by optimizing pairwise document comparisons using a neural network. The implementation is optimized to handle padded documents efficiently by only processing valid documents during model inference.
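RankNet treats every pair of documents with different labels as a binary classification problem: for label_i > label_j, the probability that i is ranked above j is modeled as sigmoid(sigma · (s_i - s_j)), and the loss is the negative log of that probability. A plain-Python sketch for one query, using the binary logarithm on the assumption that this corresponds to the default reduction_log="binary" (the function name is hypothetical; the library's handling of k, eps and normalization may differ):

```python
import math


def ranknet_loss(scores, labels, sigma=1.0):
    """Sum over pairs with labels[i] > labels[j] of -log2(sigmoid(sigma * (s_i - s_j)))."""
    loss = 0.0
    for i in range(len(scores)):
        for j in range(len(scores)):
            if labels[i] > labels[j]:
                z = sigma * (scores[i] - scores[j])
                # -log2(sigmoid(z)) = log2(1 + exp(-z)), computed without overflow
                loss += (max(-z, 0.0) + math.log1p(math.exp(-abs(z)))) / math.log(2)
    return loss
```

Every mis-ordered or tied pair costs the same regardless of where it sits in the ranking; weighting each term by |ΔNDCG_ij| is what turns this into a LambdaLoss variant.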

Parameters:
  • model (CrossEncoder) – CrossEncoder model to be trained

  • sigma (float) – Score difference weight used in the sigmoid. Defaults to 1.0.

  • eps (float) – Small constant for numerical stability. Defaults to 1e-10.

  • activation_fn (Module) – Activation function applied to the logits before computing the loss. Defaults to Identity.

  • mini_batch_size (int, optional) –

    Number of samples to process in each forward pass. This has a significant impact on the memory consumption and speed of the training process. Three cases are possible:

    • If mini_batch_size is None, the mini_batch_size is set to the batch size.

    • If mini_batch_size is greater than 0, the batch is split into mini-batches of size mini_batch_size.

    • If mini_batch_size is <= 0, the entire batch is processed at once.

    Defaults to None.


Requirements:
  1. Query with multiple documents (pairwise approach)

  2. Documents must have relevance scores/labels. Both binary and continuous labels are supported.

Inputs:

Texts | Labels | Number of Model Output Labels
(query, [doc1, doc2, …, docN]) | [score1, score2, …, scoreN] | 1

Recommendations:
  • Use mine_hard_negatives with output_format="labeled-list" to convert question-answer pairs to the required input format with hard negatives.

Relations:
  • LambdaLoss can be seen as an extension of this loss where each score pair is weighted. Alternatively, this loss can be seen as a special case of the LambdaLoss without a weighting scheme.

  • LambdaLoss with its default NDCGLoss2++ weighting scheme anecdotally performs better than the other losses with the same input format.

Example

from sentence_transformers.cross_encoder import CrossEncoder, CrossEncoderTrainer, losses
from datasets import Dataset

model = CrossEncoder("microsoft/mpnet-base")
train_dataset = Dataset.from_dict({
    "query": ["What are pandas?", "What is the capital of France?"],
    "docs": [
        ["Pandas are a kind of bear.", "Pandas are kind of like fish."],
        ["The capital of France is Paris.", "Paris is the capital of France.", "Paris is quite large."],
    ],
    "labels": [[1, 0], [1, 1, 0]],
})
loss = losses.RankNetLoss(model)

trainer = CrossEncoderTrainer(
    model=model,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()