Loss Overview
Warning
To train a SparseEncoder, you need either SpladeLoss or CSRLoss, depending on the architecture. These are wrapper losses that add sparsity regularization on top of a main loss function, which must be provided as a parameter. The only loss that can be used independently is SparseMSELoss, as it performs embedding-level distillation, ensuring sparsity by directly copying the teacher’s sparse embedding.
Sparse-specific Loss Functions
SPLADE Loss
The SpladeLoss implements a specialized loss function for SPLADE (Sparse Lexical and Expansion) models. It combines a main loss function with regularization terms to control efficiency:

- Supports any of the losses mentioned below as the main loss, but the three principal choices are SparseMultipleNegativesRankingLoss, SparseMarginMSELoss, and SparseDistillKLDivLoss.
- Uses FlopsLoss for regularization to control sparsity by default, but supports custom regularizers.
- Balances effectiveness (via the main loss) with efficiency by regularizing both query and document representations.
- Allows using different regularizers for queries and documents via the query_regularizer and document_regularizer parameters, enabling fine-grained control over sparsity patterns for different types of inputs.
- Supports separate threshold values for queries and documents via the query_regularizer_threshold and document_regularizer_threshold parameters, allowing different sparsity strictness levels for each input type.
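For example, a typical configuration wraps SparseMultipleNegativesRankingLoss inside SpladeLoss. The sketch below is illustrative only: the checkpoint name is a placeholder, and the regularizer weight argument names (query_regularizer_weight, document_regularizer_weight) are assumptions that may differ across versions, so check the SpladeLoss API reference for the exact signature.

```python
from sentence_transformers import SparseEncoder
from sentence_transformers.sparse_encoder.losses import (
    SpladeLoss,
    SparseMultipleNegativesRankingLoss,
)

# Any SPLADE-style sparse encoder; the checkpoint name is only a placeholder.
model = SparseEncoder("naver/splade-cocondenser-ensembledistil")

# SpladeLoss wraps a main ranking loss and adds FLOPS regularization on top.
# The *_regularizer_weight argument names are assumptions; consult the API
# reference of your installed version for the exact parameters.
loss = SpladeLoss(
    model=model,
    loss=SparseMultipleNegativesRankingLoss(model=model),
    query_regularizer_weight=5e-5,
    document_regularizer_weight=3e-5,
)
```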
CSR Loss
If you are using the SparseAutoEncoder module, then you have to use the CSRLoss (Contrastive Sparse Representation Loss). It combines two components:

- A reconstruction loss, CSRReconstructionLoss, which ensures that the sparse representation can faithfully reconstruct the original embeddings.
- A main loss, which in the paper is a contrastive learning component using SparseMultipleNegativesRankingLoss that ensures semantically similar sentences have similar representations. However, it is theoretically possible to use any of the losses mentioned below as the main loss, just as with SpladeLoss.
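As with SpladeLoss, CSRLoss acts as a wrapper around a main loss. The snippet below is a minimal sketch under those assumptions: the model path is a placeholder for a SparseEncoder that contains a SparseAutoEncoder module, and any additional weighting parameters of CSRLoss are omitted, so consult the API reference for the full signature.

```python
from sentence_transformers import SparseEncoder
from sentence_transformers.sparse_encoder.losses import (
    CSRLoss,
    SparseMultipleNegativesRankingLoss,
)

# Placeholder path: a SparseEncoder whose architecture includes a SparseAutoEncoder module.
model = SparseEncoder("path/to/csr-sparse-encoder")

# CSRLoss combines CSRReconstructionLoss with a main (contrastive) loss.
loss = CSRLoss(
    model=model,
    loss=SparseMultipleNegativesRankingLoss(model=model),
)
```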
Loss Table
Loss functions play a critical role in the performance of your fine-tuned model. Sadly, there is no “one size fits all” loss function. Ideally, this table should help narrow down your choice of loss function(s) by matching them to your data formats.
Note
You can often convert one training data format into another, allowing more loss functions to be viable for your scenario. For example, (sentence_A, sentence_B) pairs with class labels can be converted into (anchor, positive, negative) triplets by sampling sentences with the same or different classes.
Note
The loss functions in SentenceTransformer > Loss Overview that appear here with the Sparse prefix are identical to their dense versions. The prefix is used only to indicate which losses can be used as main losses to train a SparseEncoder.
Inputs | Labels | Appropriate Loss Functions |
---|---|---|
(anchor, positive) pairs | none | SparseMultipleNegativesRankingLoss |
(sentence_A, sentence_B) pairs | float similarity score between 0 and 1 | SparseCoSENTLoss<br>SparseAnglELoss<br>SparseCosineSimilarityLoss |
(anchor, positive, negative) triplets | none | SparseMultipleNegativesRankingLoss<br>SparseTripletLoss |
(anchor, positive, negative_1, ..., negative_n) | none | SparseMultipleNegativesRankingLoss |
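To make the first row concrete, here is a small sketch of an (anchor, positive) dataset built with the Hugging Face datasets library; the column names and example texts are purely illustrative.

```python
from datasets import Dataset

# Columns are matched to the loss inputs by position, not by name:
# the first column acts as the anchor, the second as the positive.
train_dataset = Dataset.from_dict(
    {
        "anchor": [
            "What is the capital of France?",
            "How do plants make food?",
        ],
        "positive": [
            "Paris is the capital and largest city of France.",
            "Plants produce food through photosynthesis.",
        ],
    }
)

# This format pairs naturally with SparseMultipleNegativesRankingLoss,
# wrapped in SpladeLoss or CSRLoss as described above.
```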
Distillation
These loss functions are specifically designed to be used when distilling the knowledge from one model into another. This is commonly done when training sparse embedding models.
Texts | Labels | Appropriate Loss Functions |
---|---|---|
sentence | model sentence embeddings | SparseMSELoss |
sentence_1, sentence_2, ..., sentence_N | model sentence embeddings | SparseMSELoss |
(query, passage_one, passage_two) triplets | gold_sim(query, passage_one) - gold_sim(query, passage_two) | SparseMarginMSELoss |
(query, positive, negative) triplets | [gold_sim(query, positive), gold_sim(query, negative)] | SparseDistillKLDivLoss<br>SparseMarginMSELoss |
(query, positive, negative_1, ..., negative_n) | [gold_sim(query, positive) - gold_sim(query, negative_i) for i in 1..n] | SparseMarginMSELoss |
(query, positive, negative_1, ..., negative_n) | [gold_sim(query, positive), gold_sim(query, negative_i)...] | SparseDistillKLDivLoss<br>SparseMarginMSELoss |
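For instance, the (query, positive, negative) row with SparseMarginMSELoss uses the margin between two teacher scores as the label. The sketch below shows one way those gold_sim values might be produced with a CrossEncoder teacher; the checkpoint name is only an example, and in practice the scores are computed in bulk over a dataset rather than one triplet at a time.

```python
from sentence_transformers import CrossEncoder

# A cross-encoder teacher that scores (query, passage) pairs; example checkpoint.
teacher = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

query = "How do plants make food?"
positive = "Plants produce food through photosynthesis."
negative = "Photosynthesis was first described in the 18th century."

# Teacher similarity scores, i.e. gold_sim(query, passage).
pos_score, neg_score = teacher.predict([(query, positive), (query, negative)])

# Label for SparseMarginMSELoss: the margin between the teacher's scores.
label = float(pos_score - neg_score)
```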
Commonly used Loss Functions
In practice, not all loss functions get used equally often. The most common scenarios are:
- (anchor, positive) pairs without any labels: SparseMultipleNegativesRankingLoss (a.k.a. InfoNCE or in-batch negatives loss) is commonly used to train the top performing embedding models. This data is often relatively cheap to obtain, and the models are generally very performant. For sparse retrieval tasks, this format works well with SpladeLoss or CSRLoss, both typically using InfoNCE as their underlying loss function.
- (query, positive, negative_1, ..., negative_n) format: This structure with multiple negatives is particularly effective with SpladeLoss configured with SparseMarginMSELoss, especially in knowledge distillation scenarios where a teacher model provides similarity scores. The strongest models are trained with distillation losses like SparseDistillKLDivLoss or SparseMarginMSELoss.
Custom Loss Functions
Advanced users can create and train with their own loss functions. Custom loss functions only have a few requirements:
- They must be a subclass of torch.nn.Module.
- They must have model as the first argument in the constructor.
- They must implement a forward method that accepts sentence_features and labels. The former is a list of tokenized batches, one element for each column. These tokenized batches can be fed directly to the model being trained to produce embeddings. The latter is an optional tensor of labels. The method must return a single loss value or a dictionary of loss components (component names to loss values) that will be summed to produce the final loss value. When returning a dictionary, the individual components will be logged separately in addition to the summed loss, allowing you to monitor the individual components of the loss.
To get full support with the automatic model card generation, you may also wish to implement:
- a get_config_dict method that returns a dictionary of loss parameters.
- a citation property so your work gets cited in all models that train with the loss.
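A minimal sketch of such a custom loss is shown below, following the requirements above. The class name, the toy ranking and sparsity terms, and the sparsity_weight parameter are invented for illustration; only the overall structure (model as first constructor argument, a forward over sentence_features and labels, get_config_dict, and citation) reflects the documented contract.

```python
from __future__ import annotations

import torch
from torch import Tensor, nn


class CustomSparseLoss(nn.Module):
    """Hypothetical loss: in-batch cosine term plus an L1 sparsity penalty."""

    def __init__(self, model, sparsity_weight: float = 1e-4) -> None:
        super().__init__()
        # The SparseEncoder being trained must be the first constructor argument.
        self.model = model
        self.sparsity_weight = sparsity_weight

    def forward(
        self, sentence_features: list[dict[str, Tensor]], labels: Tensor | None = None
    ) -> dict[str, Tensor]:
        # Each element of sentence_features is a tokenized batch for one column;
        # feeding it to the model yields the (sparse) sentence embeddings.
        embeddings = [self.model(features)["sentence_embedding"] for features in sentence_features]
        anchors, positives = embeddings[0], embeddings[1]

        # Toy ranking term: pull each anchor towards its paired positive.
        ranking_loss = (1 - torch.nn.functional.cosine_similarity(anchors, positives)).mean()

        # Toy sparsity term: penalize the L1 norm of all produced embeddings.
        sparsity_loss = torch.cat(embeddings).abs().sum(dim=-1).mean()

        # Returning a dict logs each component separately; they are summed into the final loss.
        return {"ranking": ranking_loss, "sparsity": self.sparsity_weight * sparsity_loss}

    def get_config_dict(self) -> dict[str, float]:
        # Reported in the automatically generated model card.
        return {"sparsity_weight": self.sparsity_weight}

    @property
    def citation(self) -> str:
        # Optional BibTeX entry cited in models trained with this loss.
        return ""
```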
Consider inspecting existing loss functions to get a feel for how loss functions are commonly implemented.