Cross-Encoders¶
SentenceTransformers also supports the option to train Cross-Encoders for sentence pair score and sentence pair classification tasks. For more details on what Cross-Encoders are and the difference between Cross- and Bi-Encoders, see Cross-Encoders.
Examples¶
See the following examples how to train Cross-Encoders:
training_stsbenchmark.py - Example how to train for Semantic Textual Similarity (STS) on the STS benchmark dataset.
training_quora_duplicate_questions.py - Example how to train a Cross-Encoder to predict if two questions are duplicates. Uses Quora Duplicate Questions as training dataset.
training_nli.py - Example for a multilabel classification task for Natural Language Inference (NLI) task.
Training CrossEncoders¶
The CrossEncoder
class is a wrapper around Hugging Face AutoModelForSequenceClassification
, but with some methods to make training and predicting scores a little bit easier. The saved models are 100% compatible with Hugging Face and can also be loaded with their classes.
First, you need some sentence pair data. You can either have a continuous score, like:
from sentence_transformers import InputExample
train_samples = [
InputExample(texts=["sentence1", "sentence2"], label=0.3),
InputExample(texts=["Another", "pair"], label=0.8),
]
Or you have distinct classes as in the training_nli.py example:
from sentence_transformers import InputExample
label2int = {"contradiction": 0, "entailment": 1, "neutral": 2}
train_samples = [
InputExample(texts=["sentence1", "sentence2"], label=label2int["neutral"]),
InputExample(texts=["Another", "pair"], label=label2int["entailment"]),
]
Then, you define the base model and the number of labels. You can take any Hugging Face pre-trained model that is compatible with AutoModel:
model = CrossEncoder('distilroberta-base', num_labels=1)
For binary tasks and tasks with continuous scores (like STS), we set num_labels=1. For classification tasks, we set it to the number of labels we have.
We start the training by calling model.fit()
:
model.fit(
train_dataloader=train_dataloader,
evaluator=evaluator,
epochs=num_epochs,
warmup_steps=warmup_steps,
output_path=model_save_path,
)