Training Overview
Why Finetune?
Cross Encoder models are very often used as 2nd stage rerankers in a Retrieve and Rerank search stack. In such a situation, the Cross Encoder reranks the top X candidates from the retriever (which can be a Sentence Transformer model). To avoid the reranker model reducing the performance on your use case, finetuning it can be crucial. Rerankers always have just 1 output label.
Beyond that, Cross Encoder models can also be used as pair classifiers. For example, a model trained on Natural Language Inference data can be used to classify pairs of texts as “contradiction”, “entailment”, and “neutral”. Pair Classifiers generally have more than 1 output label.
See Training Examples for numerous training scripts for common real-world applications that you can adopt.
Training Components
Training Cross Encoder models involves between 3 and 5 components, just like training Sentence Transformer models:
Dataset
The CrossEncoderTrainer trains and evaluates using datasets.Dataset (one dataset) or datasets.DatasetDict instances (multiple datasets, see also Multi-dataset training).
If you want to load data from the Hugging Face Datasets, then you should use datasets.load_dataset():
from datasets import load_dataset
train_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split="train")
eval_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split="dev")
print(train_dataset)
"""
Dataset({
features: ['premise', 'hypothesis', 'label'],
num_rows: 942069
})
"""
Some datasets (including sentence-transformers/all-nli) require you to provide a “subset” alongside the dataset name. sentence-transformers/all-nli has 4 subsets, each with different data formats: pair, pair-class, pair-score, triplet.
Note
Many Hugging Face datasets that work out of the box with Sentence Transformers have been tagged with sentence-transformers, allowing you to easily find them by browsing to https://huggingface.co/datasets?other=sentence-transformers. We strongly recommend that you browse these datasets to find training datasets that might be useful for your tasks.
If you have local data in common file-formats, then you can load this data easily using datasets.load_dataset():
from datasets import load_dataset
dataset = load_dataset("csv", data_files="my_file.csv")
or:
from datasets import load_dataset
dataset = load_dataset("json", data_files="my_file.json")
If you have local data that requires some extra pre-processing, my recommendation is to initialize your dataset using datasets.Dataset.from_dict() and a dictionary of lists, like so:
from datasets import Dataset
anchors = []
positives = []
# Open a file, do preprocessing, filtering, cleaning, etc.
# and append to the lists
dataset = Dataset.from_dict({
"anchor": anchors,
"positive": positives,
})
Each key from the dictionary will become a column in the resulting dataset.
Dataset Format
It is important that your dataset format matches your loss function (or that you choose a loss function that matches your dataset format and model). Verifying whether a dataset format and model work with a loss function involves three steps:
All columns not named “label”, “labels”, “score”, or “scores” are considered Inputs according to the Loss Overview table. The number of remaining columns must match the number of valid inputs for your chosen loss. The names of these columns are irrelevant; only the order matters.
If your loss function requires a Label according to the Loss Overview table, then your dataset must have a column named “label”, “labels”, “score”, or “scores”. This column is automatically taken as the label.
The number of model output labels must match what is required for the loss according to the Loss Overview table.
For example, given a dataset with columns ["text1", "text2", "label"] where the “label” column has float similarity scores ranging from 0 to 1 and a model outputting 1 label, we can use it with BinaryCrossEntropyLoss because:
the dataset has a “label” column, as is required for this loss function.
the dataset has 2 non-label columns, exactly the number required by this loss function.
the model has 1 output label, exactly as required by this loss function.
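The three checks above can be written down as a small, pure-Python sketch. The helper check_dataset_format and its parameters are hypothetical and not part of Sentence Transformers; in practice, the number of inputs, whether a label is required, and the required number of output labels come from the Loss Overview table.

```python
# Column names that are treated as labels rather than inputs
LABEL_COLUMN_NAMES = {"label", "labels", "score", "scores"}


def check_dataset_format(column_names, num_inputs, requires_label, model_num_labels, required_num_labels):
    """Hypothetical helper: return a list of problems (an empty list means the combination looks compatible)."""
    problems = []
    # Step 1: every non-label column counts as an input, in column order
    input_columns = [name for name in column_names if name not in LABEL_COLUMN_NAMES]
    label_columns = [name for name in column_names if name in LABEL_COLUMN_NAMES]
    if len(input_columns) != num_inputs:
        problems.append(f"expected {num_inputs} input columns, found {len(input_columns)}: {input_columns}")
    # Step 2: a label column must exist if the loss requires a Label
    if requires_label and not label_columns:
        problems.append("the loss requires a 'label', 'labels', 'score' or 'scores' column")
    # Step 3: the model's number of output labels must match the loss
    if model_num_labels != required_num_labels:
        problems.append(f"model has {model_num_labels} output labels, the loss expects {required_num_labels}")
    return problems


# The example from the text: two inputs plus a "label" column, with a 1-label model
print(check_dataset_format(["text1", "text2", "label"], num_inputs=2, requires_label=True,
                           model_num_labels=1, required_num_labels=1))  # → []
```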
Be sure to re-order your dataset columns with Dataset.select_columns if your columns are not ordered correctly. For example, if your dataset has ["good_answer", "bad_answer", "question"] as columns, then this dataset can technically be used with a loss that requires (anchor, positive, negative) triplets, but the good_answer column will be taken as the anchor, bad_answer as the positive, and question as the negative.
Additionally, if your dataset has extraneous columns (e.g. sample_id, metadata, source, type), you should remove these with Dataset.remove_columns as they will be used as inputs otherwise. You can also use Dataset.select_columns to keep only the desired columns.
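To make the positional behaviour concrete, the sketch below pairs non-label columns with loss roles purely by position, using plain lists of column names in place of a real dataset. The helper assign_roles is hypothetical; the list manipulations only mimic the effect that the text attributes to Dataset.select_columns and Dataset.remove_columns.

```python
ROLES = ("anchor", "positive", "negative")
LABEL_COLUMN_NAMES = {"label", "labels", "score", "scores"}


def assign_roles(column_names, roles=ROLES):
    """Hypothetical helper: pair each non-label column with a loss role by position (names are ignored)."""
    inputs = [name for name in column_names if name not in LABEL_COLUMN_NAMES]
    if len(inputs) != len(roles):
        raise ValueError(f"{len(inputs)} input columns for {len(roles)} roles: {inputs}")
    return dict(zip(roles, inputs))


# Original column order: the roles come out wrong
print(assign_roles(["good_answer", "bad_answer", "question"]))
# → {'anchor': 'good_answer', 'positive': 'bad_answer', 'negative': 'question'}

# Re-ordered columns, as you would get after selecting them in the right order
print(assign_roles(["question", "good_answer", "bad_answer"]))
# → {'anchor': 'question', 'positive': 'good_answer', 'negative': 'bad_answer'}

# An extraneous column would count as a fourth input unless it is removed first
with_extra = ["question", "good_answer", "bad_answer", "sample_id"]
cleaned = [name for name in with_extra if name != "sample_id"]
print(assign_roles(cleaned)["anchor"])  # → question
```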
Hard Negatives Mining
The success of training CrossEncoder models often depends on the quality of the negatives, i.e. the passages for which the query-negative score should be low. Negatives can be divided into two types:
Soft negatives: passages that are completely unrelated.
Hard negatives: passages that seem like they might be relevant for the query, but are not.
A concise example is:
Query: Where was Apple founded?
Soft Negative: The Cache River Bridge is a Parker pony truss that spans the Cache River between Walnut Ridge and Paragould, Arkansas.
Hard Negative: The Fuji apple is an apple cultivar developed in the late 1930s, and brought to market in 1962.
The strongest CrossEncoder models are generally trained to recognize hard negatives, and so it’s valuable to be able to “mine” hard negatives. Sentence Transformers supports a strong mine_hard_negatives() function that can assist, given a dataset of query-answer pairs:
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import mine_hard_negatives
# Load the GooAQ dataset: https://huggingface.co/datasets/sentence-transformers/gooaq
train_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
print(train_dataset)
# Mine hard negatives using a very efficient embedding model
embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
hard_train_dataset = mine_hard_negatives(
train_dataset,
embedding_model,
num_negatives=5, # How many negatives per question-answer pair
range_min=10, # Skip the x most similar samples
range_max=100, # Consider only the x most similar samples
max_score=0.8, # Only consider samples with a similarity score of at most x
margin=0.1, # Similarity between query and negative samples should be x lower than query-positive similarity
sampling_strategy="top", # Sample the top negatives from the range
batch_size=4096, # Use a batch size of 4096 for the embedding model
output_format="labeled-pair", # The output format is (query, passage, label), as required by BinaryCrossEntropyLoss
use_faiss=True, # Using FAISS is recommended to keep memory usage low (pip install faiss-gpu or pip install faiss-cpu)
)
print(hard_train_dataset)
print(hard_train_dataset[1])
The script produces the following output:
Dataset({
features: ['question', 'answer'],
num_rows: 100000
})
Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:01<00:00, 13.74it/s]
Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 36.49it/s]
Querying FAISS index: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:19<00:00, 2.80s/it]
Metric Positive Negative Difference
Count 100,000 436,925
Mean 0.5882 0.4040 0.2157
Median 0.5989 0.4024 0.1836
Std 0.1425 0.0905 0.1013
Min -0.0514 0.1405 0.1014
25% 0.4993 0.3377 0.1352
50% 0.5989 0.4024 0.1836
75% 0.6888 0.4681 0.2699
Max 0.9748 0.7486 0.7545
Skipped 2420871 potential negatives (23.97%) due to the margin of 0.1.
Skipped 43 potential negatives (0.00%) due to the maximum score of 0.8.
Could not find enough negatives for 63075 samples (12.62%). Consider adjusting the range_max, range_min, margin and max_score parameters if you'd like to find more valid negatives.
Dataset({
features: ['question', 'answer', 'label'],
num_rows: 536925
})
{
'question': 'how to transfer bookmarks from one laptop to another?',
'answer': 'Using an External Drive Just about any external drive, including a USB thumb drive, or an SD card can be used to transfer your files from one laptop to another. Connect the drive to your old laptop; drag your files to the drive, then disconnect it and transfer the drive contents onto your new laptop.',
'label': 0
}
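To build intuition for range_min, range_max, max_score, margin, and sampling_strategy="top", here is a simplified, pure-Python sketch of how such filters could be applied to the candidates of a single query. It is an illustration under assumptions, not the actual mine_hard_negatives() implementation, which works on embeddings in batches (optionally with FAISS) and supports more options; the helper select_hard_negatives is hypothetical.

```python
def select_hard_negatives(candidates, positive_score, num_negatives,
                          range_min=0, range_max=100, max_score=None, margin=None):
    """Hypothetical sketch: pick hard negatives for one query from (passage, similarity) pairs.

    The positive itself is assumed to be excluded from `candidates` already.
    """
    ranked = sorted(candidates, key=lambda item: item[1], reverse=True)
    window = ranked[range_min:range_max]  # skip the range_min most similar, stop after range_max
    selected = []
    for passage, score in window:
        if max_score is not None and score > max_score:
            continue  # too similar in absolute terms, possibly a false negative
        if margin is not None and score > positive_score - margin:
            continue  # not sufficiently less similar than the known positive
        selected.append(passage)
        if len(selected) == num_negatives:
            break  # "top" sampling: keep the most similar remaining candidates
    return selected


candidates = [("a", 0.95), ("b", 0.72), ("c", 0.50), ("d", 0.45), ("e", 0.20)]
print(select_hard_negatives(candidates, positive_score=0.7, num_negatives=2, max_score=0.8, margin=0.1))
# → ['c', 'd']
```

When fewer than num_negatives candidates survive the filters, fewer negatives are returned, which corresponds to the "Could not find enough negatives" message in the output above.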
Loss Function
Loss functions quantify how well a model performs for a given batch of data, allowing an optimizer to update the model weights to produce more favourable (i.e., lower) loss values. This is the core of the training process.
Sadly, there is no single loss function that works best for all use-cases. Instead, which loss function to use greatly depends on your available data and on your target task. See Dataset Format to learn what datasets are valid for which loss functions. Additionally, the Loss Overview will be your best friend to learn about the options.
Most loss functions can be initialized with just the CrossEncoder that you’re training, alongside some optional parameters, e.g.:
from datasets import load_dataset
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.losses import MultipleNegativesRankingLoss
# Load a model to train/finetune
model = CrossEncoder("xlm-roberta-base", num_labels=1) # num_labels=1 is for rerankers
# Initialize the MultipleNegativesRankingLoss
# This loss requires pairs of related texts or triplets
loss = MultipleNegativesRankingLoss(model)
# Load an example training dataset that works with our loss function:
train_dataset = load_dataset("sentence-transformers/gooaq", split="train")
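For intuition about what an in-batch negatives loss such as MultipleNegativesRankingLoss optimizes, here is a simplified sketch for a single query: the scores for the positive and for some negatives (e.g. other passages in the same batch) are treated as logits of a classification problem in which the positive, at index 0, is the correct class. The function in_batch_ranking_loss and its scale parameter are illustrative assumptions; the library's CrossEncoder loss differs in details such as how negatives are selected and how scores are transformed.

```python
import math


def in_batch_ranking_loss(positive_score, negative_scores, scale=1.0):
    """Cross-entropy of picking the positive (index 0) among [positive, *negatives] via softmax."""
    logits = [scale * positive_score] + [scale * score for score in negative_scores]
    # Numerically stable log-sum-exp
    max_logit = max(logits)
    log_normalizer = max_logit + math.log(sum(math.exp(logit - max_logit) for logit in logits))
    return log_normalizer - logits[0]


# Equal scores: the positive cannot be told apart, so the loss is log(number of candidates)
print(in_batch_ranking_loss(0.0, [0.0, 0.0]))  # ≈ 1.0986 (log 3)
# A clearly higher score for the positive gives a much lower loss
print(in_batch_ranking_loss(5.0, [0.0, 0.0]))  # ≈ 0.0134
```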
Training Arguments
The CrossEncoderTrainingArguments class can be used to specify parameters for influencing training performance as well as defining the tracking/debugging parameters. Although it is optional, it is strongly recommended to experiment with the various useful arguments.
learning_rate
lr_scheduler_type
warmup_ratio
num_train_epochs
max_steps
per_device_train_batch_size
per_device_eval_batch_size
auto_find_batch_size
fp16
bf16
load_best_model_at_end
metric_for_best_model
gradient_accumulation_steps
gradient_checkpointing
eval_accumulation_steps
optim
dataloader_num_workers
dataloader_prefetch_factor
batch_sampler
multi_dataset_batch_sampler
Here is an example of how CrossEncoderTrainingArguments can be initialized:
from sentence_transformers.cross_encoder import CrossEncoderTrainingArguments
from sentence_transformers.training_args import BatchSamplers
args = CrossEncoderTrainingArguments(
# Required parameter:
output_dir="models/reranker-MiniLM-msmarco-v1",
# Optional training parameters:
num_train_epochs=1,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
learning_rate=2e-5,
warmup_ratio=0.1,
fp16=True, # Set to False if you get an error that your GPU can't run on FP16
bf16=False, # Set to True if you have a GPU that supports BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # losses that use "in-batch negatives" benefit from no duplicates
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=100,
save_strategy="steps",
save_steps=100,
save_total_limit=2,
logging_steps=100,
run_name="reranker-MiniLM-msmarco-v1", # Will be used in W&B if `wandb` is installed
)
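When choosing warmup_ratio, eval_steps, or save_steps, it helps to estimate how many optimizer steps a run will take. The rough sketch below assumes a single device, no max_steps, and a dataloader that keeps the last partial batch; the helper estimate_steps is hypothetical, and batch samplers such as NO_DUPLICATES or multi-GPU training can change the actual count.

```python
import math


def estimate_steps(num_samples, per_device_train_batch_size, num_train_epochs=1,
                   gradient_accumulation_steps=1, warmup_ratio=0.0):
    """Hypothetical helper: rough (total optimizer steps, warmup steps) for a single-device run."""
    batches_per_epoch = math.ceil(num_samples / per_device_train_batch_size)
    updates_per_epoch = max(batches_per_epoch // gradient_accumulation_steps, 1)
    total_steps = math.ceil(num_train_epochs * updates_per_epoch)
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return total_steps, warmup_steps


# e.g. 99,000 training pairs (100k GooAQ samples minus a 1k test split), batch size 64, 1 epoch
print(estimate_steps(99_000, 64, warmup_ratio=0.1))  # → (1547, 155)
```

With eval_steps=100, such a run would evaluate roughly 15 times.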
Evaluator
You can provide the CrossEncoderTrainer with an eval_dataset to get the evaluation loss during training, but it may be useful to get more concrete metrics during training, too. For this, you can use evaluators to assess the model’s performance with useful metrics before, during, or after training. You can use both an eval_dataset and an evaluator, one or the other, or neither. They evaluate based on the eval_strategy and eval_steps Training Arguments.
Here are the implemented Evaluators that come with Sentence Transformers:
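Evaluator | Required Data
---|---
CrossEncoderClassificationEvaluator | Pairs with class labels (binary or multiclass).
CrossEncoderCorrelationEvaluator | Pairs with similarity scores.
CrossEncoderNanoBEIREvaluator | No data required.
CrossEncoderRerankingEvaluator | List of {"query": ..., "positive": [...], "documents": [...]} dictionaries.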
Additionally, SequentialEvaluator can be used to combine multiple evaluators into one Evaluator that can be passed to the CrossEncoderTrainer.
Sometimes you don’t have the required evaluation data to prepare one of these evaluators on your own, but you still want to track how well the model performs on some common benchmarks. In that case, you can use these evaluators with data from Hugging Face.
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CrossEncoderNanoBEIREvaluator
# Load a model
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")
# Initialize the evaluator. Unlike most other evaluators, this one loads the relevant datasets
# directly from Hugging Face, so there are no mandatory arguments
dev_evaluator = CrossEncoderNanoBEIREvaluator()
# You can run evaluation like so:
# results = dev_evaluator(model)
Preparing data for CrossEncoderRerankingEvaluator can be difficult, as you need negatives in addition to your query-positive data.
The mine_hard_negatives() function has a convenient include_positives parameter, which can be set to True to also mine for the positive texts. When supplied as documents (which must 1. be ranked and 2. contain the positives) to CrossEncoderRerankingEvaluator, the evaluator will evaluate not just the reranking performance of the CrossEncoder, but also the original ranking produced by the embedding model used for mining.
For example:
CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev dataset:
Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 49.0, Mean 49.1, Max 50.0
Base -> Reranked
MAP: 53.28 -> 67.28
MRR@10: 52.40 -> 66.65
NDCG@10: 59.12 -> 71.35
Note that by default, if you are using CrossEncoderRerankingEvaluator with documents, the evaluator will rerank with all positives, even if they are not in the documents. This is useful for getting a stronger signal out of your evaluator, but does give slightly unrealistic performance numbers. After all, the maximum performance is now 100, whereas normally it is bounded by whether the first-stage retriever actually retrieved the positives.
You can enable the realistic behaviour by setting always_rerank_positives=False when initializing CrossEncoderRerankingEvaluator. Repeating the same script with this realistic two-stage setup results in:
CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev dataset:
Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 49.0, Mean 49.1, Max 50.0
Base -> Reranked
MAP: 53.28 -> 66.12
MRR@10: 52.40 -> 65.61
NDCG@10: 59.12 -> 70.10
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CrossEncoderRerankingEvaluator
from sentence_transformers.util import mine_hard_negatives
# Load a model
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")
# Load the GooAQ dataset: https://huggingface.co/datasets/sentence-transformers/gooaq
full_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]
print(eval_dataset)
"""
Dataset({
features: ['question', 'answer'],
num_rows: 1000
})
"""
# Mine hard negatives using a very efficient embedding model
embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
hard_eval_dataset = mine_hard_negatives(
eval_dataset,
embedding_model,
corpus=full_dataset["answer"], # Use the full dataset as the corpus
num_negatives=50, # How many negatives per question-answer pair
batch_size=4096, # Use a batch size of 4096 for the embedding model
output_format="n-tuple", # The output format is (query, positive, negative1, negative2, ...) for the evaluator
include_positives=True, # Key: Include the positive answer in the list of negatives
use_faiss=True, # Using FAISS is recommended to keep memory usage low (pip install faiss-gpu or pip install faiss-cpu)
)
print(hard_eval_dataset)
"""
Dataset({
features: ['question', 'answer', 'negative_1', 'negative_2', 'negative_3', 'negative_4', 'negative_5', 'negative_6', 'negative_7', 'negative_8', 'negative_9', 'negative_10', 'negative_11', 'negative_12', 'negative_13', 'negative_14', 'negative_15', 'negative_16', 'negative_17', 'negative_18', 'negative_19', 'negative_20', 'negative_21', 'negative_22', 'negative_23', 'negative_24', 'negative_25', 'negative_26', 'negative_27', 'negative_28', 'negative_29', 'negative_30', 'negative_31', 'negative_32', 'negative_33', 'negative_34', 'negative_35', 'negative_36', 'negative_37', 'negative_38', 'negative_39', 'negative_40', 'negative_41', 'negative_42', 'negative_43', 'negative_44', 'negative_45', 'negative_46', 'negative_47', 'negative_48', 'negative_49', 'negative_50'],
num_rows: 1000
})
"""
reranking_evaluator = CrossEncoderRerankingEvaluator(
samples=[
{
"query": sample["question"],
"positive": [sample["answer"]],
"documents": [sample[column_name] for column_name in hard_eval_dataset.column_names[2:]],
}
for sample in hard_eval_dataset
],
batch_size=32,
name="gooaq-dev",
)
# You can run evaluation like so
results = reranking_evaluator(model)
"""
CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev dataset:
Queries: 1000 Positives: Min 1.0, Mean 1.0, Max 1.0 Negatives: Min 49.0, Mean 49.1, Max 50.0
Base -> Reranked
MAP: 53.28 -> 67.28
MRR@10: 52.40 -> 66.65
NDCG@10: 59.12 -> 71.35
"""
# {'gooaq-dev_map': 0.6728370126462222, 'gooaq-dev_mrr@10': 0.6665190476190477, 'gooaq-dev_ndcg@10': 0.7135068904582963, 'gooaq-dev_base_map': 0.5327714512001362, 'gooaq-dev_base_mrr@10': 0.5239674603174603, 'gooaq-dev_base_ndcg@10': 0.5912299141913905}
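The difference between the default and the realistic always_rerank_positives=False setting can be illustrated with a toy, pure-Python sketch of MRR@10 for one query. The helpers mrr_at_k, rerank, and toy_score are hypothetical simplifications, not the internals of CrossEncoderRerankingEvaluator; toy_score stands in for a CrossEncoder's predictions.

```python
def mrr_at_k(ranking, positives, k=10):
    """Reciprocal rank of the first positive within the top k, or 0.0 if none is found."""
    for rank, doc in enumerate(ranking[:k], start=1):
        if doc in positives:
            return 1.0 / rank
    return 0.0


def rerank(documents, positives, score_fn, always_rerank_positives=True):
    """Sort candidates by score, optionally adding positives that the first-stage retriever missed."""
    candidates = list(documents)
    if always_rerank_positives:
        candidates += [doc for doc in positives if doc not in candidates]
    return sorted(candidates, key=score_fn, reverse=True)


# Toy scores standing in for a CrossEncoder's predictions for one query
TOY_SCORES = {"p": 0.9, "d1": 0.4, "d2": 0.7, "d3": 0.1}


def toy_score(doc):
    return TOY_SCORES[doc]


retrieved = ["d1", "d2", "d3"]  # the first-stage retriever missed the positive "p"
positives = ["p"]

print(mrr_at_k(rerank(retrieved, positives, toy_score, always_rerank_positives=True), positives))  # → 1.0
print(mrr_at_k(rerank(retrieved, positives, toy_score, always_rerank_positives=False), positives))  # → 0.0
```

The gap between the two numbers is exactly the "free" signal the default setting adds when the retriever missed a positive.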
from datasets import load_dataset
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CrossEncoderCorrelationEvaluator
# Load a model
model = CrossEncoder("cross-encoder/stsb-TinyBERT-L4")
# Load the STSB dataset (https://huggingface.co/datasets/sentence-transformers/stsb)
eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
pairs = list(zip(eval_dataset["sentence1"], eval_dataset["sentence2"]))
# Initialize the evaluator
dev_evaluator = CrossEncoderCorrelationEvaluator(
sentence_pairs=pairs,
scores=eval_dataset["score"],
name="sts_dev",
)
# You can run evaluation like so:
# results = dev_evaluator(model)
from datasets import load_dataset
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CrossEncoderClassificationEvaluator
# Load a model
model = CrossEncoder("cross-encoder/nli-deberta-v3-base")
# Load pairs with class labels from the AllNLI dataset (https://huggingface.co/datasets/sentence-transformers/all-nli)
max_samples = 1000
eval_dataset = load_dataset("sentence-transformers/all-nli", "pair-class", split=f"dev[:{max_samples}]")
# Create a list of pairs, and map the labels to the labels that the model knows
pairs = list(zip(eval_dataset["premise"], eval_dataset["hypothesis"]))
label_mapping = {0: 1, 1: 2, 2: 0}
labels = [label_mapping[label] for label in eval_dataset["label"]]
# Initialize the evaluator
cls_evaluator = CrossEncoderClassificationEvaluator(
sentence_pairs=pairs,
labels=labels,
name="all-nli-dev",
)
# You can run evaluation like so:
# results = cls_evaluator(model)
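The label_mapping above translates the dataset's label ids into the model's label ids. One way to derive such a mapping is to match label names on both sides. The two name lists below are assumptions for illustration (chosen to be consistent with the mapping used above); check the dataset card and the model's id2label configuration before relying on them. The helper build_label_mapping is hypothetical.

```python
def build_label_mapping(dataset_labels, model_labels):
    """Hypothetical helper: map each dataset label id to the model label id with the same name."""
    model_ids = {name: idx for idx, name in enumerate(model_labels)}
    return {idx: model_ids[name] for idx, name in enumerate(dataset_labels)}


# Assumed label orders, for illustration only
dataset_labels = ["entailment", "neutral", "contradiction"]  # dataset label ids 0, 1, 2
model_labels = ["contradiction", "entailment", "neutral"]    # model output ids 0, 1, 2

print(build_label_mapping(dataset_labels, model_labels))  # → {0: 1, 1: 2, 2: 0}
```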
Warning
When using Distributed Training, the evaluator only runs on the first device, unlike the training and evaluation datasets, which are shared across all devices.
Trainer
The CrossEncoderTrainer is where all previous components come together. We only have to specify the trainer with the model, training arguments (optional), training dataset, evaluation dataset (optional), loss function, evaluator (optional) and we can start training. Let’s have a look at a script where all of these components come together:
import logging
import traceback
from datasets import load_dataset
from sentence_transformers.cross_encoder import (
CrossEncoder,
CrossEncoderModelCardData,
CrossEncoderTrainer,
CrossEncoderTrainingArguments,
)
from sentence_transformers.cross_encoder.evaluation import CrossEncoderNanoBEIREvaluator
from sentence_transformers.cross_encoder.losses import CachedMultipleNegativesRankingLoss
# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
model_name = "microsoft/MiniLM-L12-H384-uncased"
train_batch_size = 64
num_epochs = 1
num_rand_negatives = 5 # How many random negatives should be used for each question-answer pair
# 1a. Load a model to finetune with 1b. (Optional) model card data
model = CrossEncoder(
model_name,
model_card_data=CrossEncoderModelCardData(
language="en",
license="apache-2.0",
model_name="MiniLM-L12-H384 trained on GooAQ",
),
)
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)
# 2. Load the GooAQ dataset: https://huggingface.co/datasets/sentence-transformers/gooaq
logging.info("Read the gooaq training dataset")
full_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]
logging.info(train_dataset)
logging.info(eval_dataset)
# 3. Define our training loss.
loss = CachedMultipleNegativesRankingLoss(
model=model,
num_negatives=num_rand_negatives,
mini_batch_size=32, # Informs the memory usage
)
# 4. Use CrossEncoderNanoBEIREvaluator, a light-weight evaluator for English reranking
evaluator = CrossEncoderNanoBEIREvaluator(
dataset_names=["msmarco", "nfcorpus", "nq"],
batch_size=train_batch_size,
)
evaluator(model)
# 5. Define the training arguments
short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
run_name = f"reranker-{short_model_name}-gooaq-cmnrl"
args = CrossEncoderTrainingArguments(
# Required parameter:
output_dir=f"models/{run_name}",
# Optional training parameters:
num_train_epochs=num_epochs,
per_device_train_batch_size=train_batch_size,
per_device_eval_batch_size=train_batch_size,
learning_rate=2e-5,
warmup_ratio=0.1,
fp16=False, # Set to False if you get an error that your GPU can't run on FP16
bf16=True, # Set to True if you have a GPU that supports BF16
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=100,
save_strategy="steps",
save_steps=100,
save_total_limit=2,
logging_steps=50,
logging_first_step=True,
run_name=run_name, # Will be used in W&B if `wandb` is installed
seed=12,
)
# 6. Create the trainer & start training
trainer = CrossEncoderTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=loss,
evaluator=evaluator,
)
trainer.train()
# 7. Evaluate the final model, useful to include these in the model card
evaluator(model)
# 8. Save the final model
final_output_dir = f"models/{run_name}/final"
model.save_pretrained(final_output_dir)
# 9. (Optional) save the model to the Hugging Face Hub!
# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
try:
    model.push_to_hub(run_name)
except Exception:
    logging.error(
        f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
        f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
        f"and saving it using `model.push_to_hub('{run_name}')`."
    )
import logging
import traceback
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import (
CrossEncoder,
CrossEncoderModelCardData,
CrossEncoderTrainer,
CrossEncoderTrainingArguments,
)
from sentence_transformers.cross_encoder.evaluation import (
CrossEncoderNanoBEIREvaluator,
CrossEncoderRerankingEvaluator,
)
from sentence_transformers.cross_encoder.losses import BinaryCrossEntropyLoss
from sentence_transformers.evaluation import SequentialEvaluator
from sentence_transformers.util import mine_hard_negatives
# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
def main():
    model_name = "answerdotai/ModernBERT-base"
    train_batch_size = 64
    num_epochs = 1
    num_hard_negatives = 5 # How many hard negatives should be mined for each question-answer pair
    # 1a. Load a model to finetune with 1b. (Optional) model card data
    model = CrossEncoder(
        model_name,
        model_card_data=CrossEncoderModelCardData(
            language="en",
            license="apache-2.0",
            model_name="ModernBERT-base trained on GooAQ",
        ),
    )
    print("Model max length:", model.max_length)
    print("Model num labels:", model.num_labels)
    # 2a. Load the GooAQ dataset: https://huggingface.co/datasets/sentence-transformers/gooaq
    logging.info("Read the gooaq training dataset")
    full_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
    dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
    train_dataset = dataset_dict["train"]
    eval_dataset = dataset_dict["test"]
    logging.info(train_dataset)
    logging.info(eval_dataset)
    # 2b. Modify our training dataset to include hard negatives using a very efficient embedding model
    embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
    hard_train_dataset = mine_hard_negatives(
        train_dataset,
        embedding_model,
        num_negatives=num_hard_negatives, # How many negatives per question-answer pair
        margin=0, # Similarity between query and negative samples should be x lower than query-positive similarity
        range_min=0, # Skip the x most similar samples
        range_max=100, # Consider only the x most similar samples
        sampling_strategy="top", # Sample the top negatives from the range
        batch_size=4096, # Use a batch size of 4096 for the embedding model
        output_format="labeled-pair", # The output format is (query, passage, label), as required by BinaryCrossEntropyLoss
        use_faiss=True,
    )
    logging.info(hard_train_dataset)
    # 2c. (Optionally) Save the hard training dataset to disk
    # hard_train_dataset.save_to_disk("gooaq-hard-train")
    # Load again with:
    # hard_train_dataset = load_from_disk("gooaq-hard-train")
    # 3. Define our training loss.
    # pos_weight is recommended to be set to the ratio of negatives to positives, i.e. `num_hard_negatives`
    loss = BinaryCrossEntropyLoss(model=model, pos_weight=torch.tensor(num_hard_negatives))
    # 4a. Define evaluators. We use the CrossEncoderNanoBEIREvaluator, which is a light-weight evaluator for English reranking
    nano_beir_evaluator = CrossEncoderNanoBEIREvaluator(
        dataset_names=["msmarco", "nfcorpus", "nq"],
        batch_size=train_batch_size,
    )
    # 4b. Define a reranking evaluator by mining hard negatives given query-answer pairs
    # We include the positive answer in the list of negatives, so the evaluator can use the performance of the
    # embedding model as a baseline.
    hard_eval_dataset = mine_hard_negatives(
        eval_dataset,
        embedding_model,
        corpus=full_dataset["answer"], # Use the full dataset as the corpus
        num_negatives=30, # How many documents to rerank
        batch_size=4096,
        include_positives=True,
        output_format="n-tuple",
        use_faiss=True,
    )
    logging.info(hard_eval_dataset)
    reranking_evaluator = CrossEncoderRerankingEvaluator(
        samples=[
            {
                "query": sample["question"],
                "positive": [sample["answer"]],
                "documents": [sample[column_name] for column_name in hard_eval_dataset.column_names[2:]],
            }
            for sample in hard_eval_dataset
        ],
        batch_size=train_batch_size,
        name="gooaq-dev",
        # Realistic setting: only rerank the positives that the retriever found
        # Set to True to rerank *all* positives
        always_rerank_positives=False,
    )
    # 4c. Combine the evaluators & run the base model on them
    evaluator = SequentialEvaluator([reranking_evaluator, nano_beir_evaluator])
    evaluator(model)
    # 5. Define the training arguments
    short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
    run_name = f"reranker-{short_model_name}-gooaq-bce"
    args = CrossEncoderTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=num_epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=train_batch_size,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        fp16=False, # Set to False if you get an error that your GPU can't run on FP16
        bf16=True, # Set to True if you have a GPU that supports BF16
        dataloader_num_workers=4,
        load_best_model_at_end=True,
        metric_for_best_model="eval_gooaq-dev_ndcg@10",
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=1000,
        save_strategy="steps",
        save_steps=1000,
        save_total_limit=2,
        logging_steps=200,
        logging_first_step=True,
        run_name=run_name, # Will be used in W&B if `wandb` is installed
        seed=12,
    )
    # 6. Create the trainer & start training
    trainer = CrossEncoderTrainer(
        model=model,
        args=args,
        train_dataset=hard_train_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()
    # 7. Evaluate the final model, useful to include these in the model card
    evaluator(model)
    # 8. Save the final model
    final_output_dir = f"models/{run_name}/final"
    model.save_pretrained(final_output_dir)
    # 9. (Optional) save the model to the Hugging Face Hub!
    # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
    try:
        model.push_to_hub(run_name)
    except Exception:
        logging.error(
            f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
            f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
            f"and saving it using `model.push_to_hub('{run_name}')`."
        )


if __name__ == "__main__":
    main()
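The script above passes pos_weight=torch.tensor(num_hard_negatives) to BinaryCrossEntropyLoss. The arithmetic behind that choice can be checked with a small pure-Python version of binary cross-entropy on logits with a positive-class weight, following the formula PyTorch documents for BCEWithLogitsLoss with pos_weight. This is an illustration, not the Sentence Transformers code: with 1 positive and 5 negatives per question and an uninformed model (all logits 0), the weighted positive term balances the five negative terms.

```python
import math


def softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def weighted_bce_with_logits(logit, label, pos_weight=1.0):
    """Binary cross-entropy on a raw logit, with positives (label=1) weighted by pos_weight."""
    # Uses log(sigmoid(x)) = -softplus(-x) and log(1 - sigmoid(x)) = -softplus(x)
    return pos_weight * label * softplus(-logit) + (1 - label) * softplus(logit)


num_hard_negatives = 5
batch = [(0.0, 1)] + [(0.0, 0)] * num_hard_negatives  # (logit, label): 1 positive, 5 negatives

positive_total = sum(weighted_bce_with_logits(x, y, num_hard_negatives) for x, y in batch if y == 1)
negative_total = sum(weighted_bce_with_logits(x, y, num_hard_negatives) for x, y in batch if y == 0)
print(round(positive_total, 4), round(negative_total, 4))  # → 3.4657 3.4657
```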
Callbacks
The CrossEncoderTrainer integrates support for various transformers.TrainerCallback subclasses, such as:
WandbCallback to automatically log training metrics to W&B if wandb is installed.
TensorBoardCallback to log training metrics to TensorBoard if tensorboard is accessible.
CodeCarbonCallback to track the carbon emissions of your model during training if codecarbon is installed. Note: These carbon emissions will be included in your automatically generated model card.
See the Transformers Callbacks documentation for more information on the integrated callbacks and how to write your own callbacks.
Multi-Dataset Training
The top performing models are trained using many datasets at once. Normally, this is rather tricky, as each dataset has a different format. However, CrossEncoderTrainer can train with multiple datasets without having to convert each dataset to the same format. It can even apply different loss functions to each of the datasets. The steps to train with multiple datasets are:
Use a dictionary of Dataset instances (or a DatasetDict) as the train_dataset (and optionally also eval_dataset).
(Optional) Use a dictionary of loss functions mapping dataset names to losses. Only required if you wish to use different loss functions for different datasets.
Each training/evaluation batch will only contain samples from one of the datasets. The order in which batches are sampled from the multiple datasets is defined by the MultiDatasetBatchSamplers enum, which can be passed to the CrossEncoderTrainingArguments via multi_dataset_batch_sampler. Valid options are:
MultiDatasetBatchSamplers.ROUND_ROBIN: Round-robin sampling from each dataset until one is exhausted. With this strategy, it’s likely that not all samples from each dataset are used, but each dataset is sampled from equally.
MultiDatasetBatchSamplers.PROPORTIONAL (default): Sample from each dataset in proportion to its size. With this strategy, all samples from each dataset are used and larger datasets are sampled from more frequently.
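The two strategies can be compared with a toy simulation that only tracks which dataset each batch comes from, given a number of batches per dataset. The helpers round_robin_order and proportional_order, and the dataset names, are hypothetical; the real samplers also shuffle samples within each dataset and handle seeds and epochs in their own way.

```python
import itertools
import random


def round_robin_order(batches_per_dataset):
    """Alternate between datasets, stopping as soon as one of them runs out of batches."""
    remaining = dict(batches_per_dataset)
    order = []
    if not remaining:
        return order
    for name in itertools.cycle(list(remaining)):
        if remaining[name] == 0:
            return order
        remaining[name] -= 1
        order.append(name)


def proportional_order(batches_per_dataset, seed=12):
    """Use every batch of every dataset in a shuffled order, so larger datasets appear more often."""
    order = [name for name, count in batches_per_dataset.items() for _ in range(count)]
    random.Random(seed).shuffle(order)
    return order


batches = {"gooaq": 3, "msmarco": 1}
print(round_robin_order(batches))           # → ['gooaq', 'msmarco', 'gooaq']
print(sorted(proportional_order(batches)))  # → ['gooaq', 'gooaq', 'gooaq', 'msmarco']
```

Round-robin leaves one gooaq batch unused here, while the proportional order uses all four batches.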
Training Tips
Cross Encoder models have their own unique quirks, so here are some tips to help you out:
CrossEncoder models overfit rather quickly, so it’s recommended to use an evaluator like CrossEncoderNanoBEIREvaluator or CrossEncoderRerankingEvaluator together with the load_best_model_at_end and metric_for_best_model training arguments to load the model with the best evaluation performance after training.
CrossEncoder models are particularly receptive to strong hard negatives (mine_hard_negatives()). They teach the model to be very strict, which is useful e.g. when distinguishing between passages that answer a question and passages that merely relate to a question.
Note that if you only use hard negatives, your model may unexpectedly perform worse on easier tasks. This can mean that reranking the top 200 results from a first-stage retrieval system (e.g. with a SentenceTransformer model) can actually give worse top-10 results than reranking the top 100. Training using random negatives alongside hard negatives can mitigate this.
Don’t underestimate BinaryCrossEntropyLoss: it remains a very strong option despite being simpler than learning-to-rank (LambdaLoss, ListNetLoss) or in-batch negatives (CachedMultipleNegativesRankingLoss, MultipleNegativesRankingLoss) losses, and its data is easy to prepare, especially using mine_hard_negatives().
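One possible way to combine random negatives with mined hard negatives is sketched below in pure Python: for each question, keep the hard negatives and add a few passages drawn at random from the corpus, excluding the positive and the hard negatives. The helper mix_negatives is hypothetical and only one of several ways to do this; the resulting rows follow the (query, passage, label) layout used with BinaryCrossEntropyLoss.

```python
import random


def mix_negatives(positive, hard_negatives, corpus, num_random, seed=12):
    """Hypothetical helper: combine mined hard negatives with random (soft) negatives from the corpus."""
    excluded = {positive, *hard_negatives}
    pool = [passage for passage in corpus if passage not in excluded]
    rng = random.Random(seed)
    random_negatives = rng.sample(pool, k=min(num_random, len(pool)))
    return list(hard_negatives) + random_negatives


corpus = [f"passage {i}" for i in range(10)]
query, positive = "example question", "passage 0"
hard_negatives = ["passage 1", "passage 2"]

negatives = mix_negatives(positive, hard_negatives, corpus, num_random=3)
# (query, passage, label) rows: label 1 for the positive, 0 for every negative
rows = [(query, positive, 1)] + [(query, negative, 0) for negative in negatives]
print(len(rows))  # → 6
```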
Deprecated Training
Prior to the Sentence Transformers v4.0 release, models would be trained with the CrossEncoder.fit() method and a DataLoader of InputExample, which looked something like this:
from sentence_transformers import CrossEncoder, InputExample
from torch.utils.data import DataLoader
# Define the model. Either from scratch or by loading a pre-trained model
model = CrossEncoder("distilbert/distilbert-base-uncased")
# Define your train examples. You need more than just two examples...
train_examples = [
InputExample(texts=["What are pandas?", "The giant panda ..."], label=1),
InputExample(texts=["What's a panda?", "Mount Vesuvius is a ..."], label=0),
]
# Define the train dataloader
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
# Tune the model
model.fit(train_dataloader=train_dataloader, epochs=1, warmup_steps=100)
Since the v4.0 release, using CrossEncoder.fit() is still possible, but it will initialize a CrossEncoderTrainer behind the scenes. It is recommended to use the Trainer directly, as you will have more control via the CrossEncoderTrainingArguments, but existing training scripts relying on CrossEncoder.fit() should still work.
In case there are issues with the updated CrossEncoder.fit(), you can also get exactly the old behaviour by calling CrossEncoder.old_fit() instead, but this method is planned to be deprecated fully in the future.
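When migrating an old fit() script to the CrossEncoderTrainer, the list of InputExample objects can be turned into the dictionary-of-lists layout that Dataset.from_dict() expects. The sketch below uses plain (texts, label) tuples in place of InputExample objects; the helper examples_to_columns and the column names "query" and "passage" are hypothetical.

```python
def examples_to_columns(examples, column_names=("query", "passage")):
    """Hypothetical helper: turn (texts, label) pairs into one column per text position plus "label"."""
    columns = {name: [] for name in column_names}
    columns["label"] = []
    for texts, label in examples:
        if len(texts) != len(column_names):
            raise ValueError(f"expected {len(column_names)} texts, got {len(texts)}")
        for name, text in zip(column_names, texts):
            columns[name].append(text)
        columns["label"].append(label)
    return columns


train_examples = [
    (["What are pandas?", "The giant panda ..."], 1),
    (["What's a panda?", "Mount Vesuvius is a ..."], 0),
]
columns = examples_to_columns(train_examples)
print(columns["label"])  # → [1, 0]
```

The resulting dictionary could then be passed to Dataset.from_dict() to create a training dataset.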
Comparisons with SentenceTransformer Training
Training CrossEncoder models is very similar to training SentenceTransformer models, with some key differences:
Instead of just score and label, columns named scores and labels will also be considered “label columns” for CrossEncoder training. As you can see in the Loss Overview documentation, some losses require specific labels/scores in a column with one of these names.
In SentenceTransformer training, you cannot use lists of inputs (e.g. texts) in a column of the training/evaluation dataset(s). For CrossEncoder training, you can use (variably sized) lists of texts in a column. This is required for the ListNetLoss class, for example.
See the Sentence Transformer > Training Overview documentation for more details on training SentenceTransformer models.
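To illustrate the variably sized list columns mentioned above, here is a sketch of a layout with one query per row, a list of documents, and a parallel list of relevance labels in a "labels" column. The "docs" column name, the texts, and the helper check_listwise_rows are hypothetical; the check only verifies the shape, so consult the Loss Overview for the exact format that ListNetLoss expects.

```python
data = {
    "query": ["how do pandas eat?", "what is vesuvius?"],
    "docs": [
        ["Pandas eat bamboo ...", "The giant panda ...", "Mount Vesuvius is ..."],
        ["Mount Vesuvius is a volcano ...", "The giant panda ..."],
    ],
    "labels": [[1, 0, 0], [1, 0]],
}


def check_listwise_rows(data, docs_column="docs", labels_column="labels"):
    """Hypothetical helper: every row needs as many labels as documents; row lengths may differ."""
    for row, (docs, labels) in enumerate(zip(data[docs_column], data[labels_column])):
        if len(docs) != len(labels):
            raise ValueError(f"row {row}: {len(docs)} documents but {len(labels)} labels")
    return [len(docs) for docs in data[docs_column]]


print(check_listwise_rows(data))  # → [3, 2]
```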