cross_encoder

For an introduction to Cross-Encoders, see Cross-Encoders.

class sentence_transformers.cross_encoder.CrossEncoder(model_name: str, num_labels: Optional[int] = None, max_length: Optional[int] = None, device: Optional[str] = None, tokenizer_args: Dict = {}, automodel_args: Dict = {}, revision: Optional[str] = None, default_activation_function=None, classifier_dropout: Optional[float] = None)

A CrossEncoder takes exactly two sentences / texts as input and predicts either a score or a label for this sentence pair. For example, it can predict the similarity of the sentence pair on a scale of 0 … 1.

It does not yield a sentence embedding and does not work for individual sentences.

Parameters
  • model_name – A model name from Hugging Face Hub that can be loaded with AutoModel, or a path to a local model. We provide several pre-trained CrossEncoder models that can be used for common tasks.

  • num_labels – Number of labels of the classifier. If 1, the CrossEncoder is a regression model that outputs a continuous score 0…1. If > 1, it outputs several scores that can be soft-maxed to get probability scores for the different classes.

  • max_length – Max length for input sequences. Longer sequences will be truncated. If None, the max length of the model will be used.

  • device – Device that should be used for the model. If None, it will use CUDA if available.

  • tokenizer_args – Arguments passed to AutoTokenizer

  • automodel_args – Arguments passed to AutoModelForSequenceClassification

  • revision – The specific model version to use. It can be a branch name, a tag name, or a commit id, for a stored model on Hugging Face.

  • default_activation_function – Callable (like nn.Sigmoid) to use as the default activation function on top of model.predict(). If None, nn.Sigmoid() will be used if num_labels=1, else nn.Identity().

  • classifier_dropout – The dropout ratio for the classification head.
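
As a rough illustration of the num_labels and default_activation_function descriptions above, the following plain-Python sketch shows how a sigmoid maps a single raw logit to a score between 0 and 1. It does not call the library; the helper name sigmoid and the logit values are hypothetical.

```python
import math

def sigmoid(logit: float) -> float:
    """Map a raw logit to the open interval (0, 1) in a numerically stable way."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    e = math.exp(logit)
    return e / (1.0 + e)

# Hypothetical raw outputs of a model with num_labels=1.
raw_logits = [-2.0, 0.0, 3.0]
scores = [sigmoid(x) for x in raw_logits]
print(scores)
```

Larger logits give scores closer to 1, which is why the order of the raw logits and the order of the activated scores agree.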

fit(train_dataloader: torch.utils.data.dataloader.DataLoader, evaluator: Optional[sentence_transformers.evaluation.SentenceEvaluator.SentenceEvaluator] = None, epochs: int = 1, loss_fct=None, activation_fct=Identity(), scheduler: str = 'WarmupLinear', warmup_steps: int = 10000, optimizer_class: Type[torch.optim.optimizer.Optimizer] = <class 'torch.optim.adamw.AdamW'>, optimizer_params: Dict[str, object] = {'lr': 2e-05}, weight_decay: float = 0.01, evaluation_steps: int = 0, output_path: Optional[str] = None, save_best_model: bool = True, max_grad_norm: float = 1, use_amp: bool = False, callback: Optional[Callable[[float, int, int], None]] = None, show_progress_bar: bool = True)

Train the model with the given training objective. Each training objective is sampled in turn for one batch. We sample only as many batches from each objective as there are in the smallest one to ensure equal training with each dataset.

Parameters
  • train_dataloader – DataLoader with training InputExamples

  • evaluator – An evaluator (sentence_transformers.evaluation) evaluates the model performance during training on held-out dev data. It is used to determine the best model that is saved to disk.

  • epochs – Number of epochs for training

  • loss_fct – Which loss function to use for training. If None, will use nn.BCEWithLogitsLoss() if self.config.num_labels == 1 else nn.CrossEntropyLoss()

  • activation_fct – Activation function applied on top of logits output of model.

  • scheduler – Learning rate scheduler. Available schedulers: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts

  • warmup_steps – Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is increased from 0 up to the maximal learning rate over this many training steps. Afterwards, the learning rate is decreased linearly back to zero.

  • optimizer_class – Optimizer

  • optimizer_params – Optimizer parameters

  • weight_decay – Weight decay for model parameters

  • evaluation_steps – If > 0, evaluate the model using evaluator after every evaluation_steps training steps

  • output_path – Storage path for the model and evaluation files

  • save_best_model – If true, the best model (according to evaluator) is stored at output_path

  • max_grad_norm – Used for gradient normalization.

  • use_amp – Use Automatic Mixed Precision (AMP). Only for PyTorch >= 1.6.0

  • callback – Callback function that is invoked after each evaluation. It must accept the following three parameters in this order: score, epoch, steps

  • show_progress_bar – If True, output a tqdm progress bar
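
The WarmupLinear behavior described under warmup_steps can be sketched in plain Python as a multiplier on the maximal learning rate. This is a minimal sketch, not the library's scheduler: the function name warmup_linear_factor is hypothetical, and total_steps is assumed to be the number of batches per epoch times epochs.

```python
def warmup_linear_factor(step: int, warmup_steps: int, total_steps: int) -> float:
    """Multiplier applied to the maximal learning rate at a given step."""
    if step < warmup_steps:
        # Linear ramp from 0 up to 1 during warmup.
        return step / max(1, warmup_steps)
    # Linear decay from 1 down to 0 over the remaining steps.
    remaining = max(1, total_steps - warmup_steps)
    return max(0.0, (total_steps - step) / remaining)

max_lr = 2e-5  # matches the default 'lr' in optimizer_params
for step in (0, 50, 100, 550, 1000):
    factor = warmup_linear_factor(step, warmup_steps=100, total_steps=1000)
    print(step, max_lr * factor)
```

With the default warmup_steps of 10000, a short training run may never leave the ramp-up phase, so it is worth comparing warmup_steps against the total number of training steps.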

predict(sentences: List[List[str]], batch_size: int = 32, show_progress_bar: Optional[bool] = None, num_workers: int = 0, activation_fct=None, apply_softmax=False, convert_to_numpy: bool = True, convert_to_tensor: bool = False)

Performs predictions with the CrossEncoder on the given sentence pairs.

Parameters
  • sentences – A list of sentence pairs [[Sent1, Sent2], [Sent3, Sent4]]

  • batch_size – Batch size for encoding

  • show_progress_bar – Output progress bar

  • num_workers – Number of workers for tokenization

  • activation_fct – Activation function applied on the logits output of the CrossEncoder. If None, nn.Sigmoid() will be used if num_labels=1, else nn.Identity

  • convert_to_numpy – Convert the output to a numpy matrix.

  • apply_softmax – If there are more than 2 dimensions and apply_softmax=True, applies softmax on the logits output

  • convert_to_tensor – Convert the output to a tensor.

Returns

Predictions for the passed sentence pairs
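
To make the apply_softmax option more concrete, the sketch below applies a softmax to each row of logits from a hypothetical multi-class model and picks the most likely class. It is plain Python and does not call predict(); the logit values are made up for illustration.

```python
import math
from typing import List

def softmax(logits: List[float]) -> List[float]:
    """Convert one row of logits into probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for two sentence pairs from a 3-class model.
logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
probs = [softmax(row) for row in logits]
labels = [row.index(max(row)) for row in probs]
print(probs, labels)
```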

rank(query: str, documents: List[str], top_k: Optional[int] = None, return_documents: bool = False, batch_size: int = 32, show_progress_bar: Optional[bool] = None, num_workers: int = 0, activation_fct=None, apply_softmax=False, convert_to_numpy: bool = True, convert_to_tensor: bool = False) → List[Dict]

Performs ranking with the CrossEncoder on the given query and documents. Returns a sorted list with the document indices and scores.

Example:
from sentence_transformers import CrossEncoder
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

query = "Who wrote 'To Kill a Mockingbird'?"
documents = [
    "'To Kill a Mockingbird' is a novel by Harper Lee published in 1960. It was immediately successful, winning the Pulitzer Prize, and has become a classic of modern American literature.",
    "The novel 'Moby-Dick' was written by Herman Melville and first published in 1851. It is considered a masterpiece of American literature and deals with complex themes of obsession, revenge, and the conflict between good and evil.",
    "Harper Lee, an American novelist widely known for her novel 'To Kill a Mockingbird', was born in 1926 in Monroeville, Alabama. She received the Pulitzer Prize for Fiction in 1961.",
    "Jane Austen was an English novelist known primarily for her six major novels, which interpret, critique and comment upon the British landed gentry at the end of the 18th century.",
    "The 'Harry Potter' series, which consists of seven fantasy novels written by British author J.K. Rowling, is among the most popular and critically acclaimed books of the modern era.",
    "'The Great Gatsby', a novel written by American author F. Scott Fitzgerald, was published in 1925. The story is set in the Jazz Age and follows the life of millionaire Jay Gatsby and his pursuit of Daisy Buchanan."
]

model.rank(query, documents, return_documents=True)
[{'corpus_id': 0,
'score': 10.67858,
'text': "'To Kill a Mockingbird' is a novel by Harper Lee published in 1960. It was immediately successful, winning the Pulitzer Prize, and has become a classic of modern American literature."},
{'corpus_id': 2,
'score': 9.761677,
'text': "Harper Lee, an American novelist widely known for her novel 'To Kill a Mockingbird', was born in 1926 in Monroeville, Alabama. She received the Pulitzer Prize for Fiction in 1961."},
{'corpus_id': 1,
'score': -3.3099542,
'text': "The novel 'Moby-Dick' was written by Herman Melville and first published in 1851. It is considered a masterpiece of American literature and deals with complex themes of obsession, revenge, and the conflict between good and evil."},
{'corpus_id': 5,
'score': -4.8989105,
'text': "'The Great Gatsby', a novel written by American author F. Scott Fitzgerald, was published in 1925. The story is set in the Jazz Age and follows the life of millionaire Jay Gatsby and his pursuit of Daisy Buchanan."},
{'corpus_id': 4,
'score': -5.082967,
'text': "The 'Harry Potter' series, which consists of seven fantasy novels written by British author J.K. Rowling, is among the most popular and critically acclaimed books of the modern era."}]
Parameters
  • query – A single query

  • documents – A list of documents

  • top_k – Return the top-k documents. If None, all documents are returned.

  • return_documents – If True, also returns the documents. If False, only returns the indices and scores.

  • batch_size – Batch size for encoding

  • show_progress_bar – Output progress bar

  • num_workers – Number of workers for tokenization

  • activation_fct – Activation function applied on the logits output of the CrossEncoder. If None, nn.Sigmoid() will be used if num_labels=1, else nn.Identity

  • convert_to_numpy – Convert the output to a numpy matrix.

  • apply_softmax – If there are more than 2 dimensions and apply_softmax=True, applies softmax on the logits output

  • convert_to_tensor – Convert the output to a tensor.

Returns

A sorted list with the document indices and scores, and optionally also documents.
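
The interplay of top_k and return_documents can be sketched without a model. In the plain-Python sketch below, rank_documents and the toy scorer word_overlap are hypothetical stand-ins; a real CrossEncoder would supply the scores instead.

```python
from typing import Callable, Dict, List, Optional

def rank_documents(
    query: str,
    documents: List[str],
    score_fn: Callable[[str, str], float],
    top_k: Optional[int] = None,
    return_documents: bool = False,
) -> List[Dict]:
    """Score every (query, document) pair and sort by descending score."""
    results = []
    for idx, doc in enumerate(documents):
        entry = {"corpus_id": idx, "score": score_fn(query, doc)}
        if return_documents:
            entry["text"] = doc
        results.append(entry)
    results.sort(key=lambda r: r["score"], reverse=True)
    return results[:top_k]  # top_k=None keeps every result

# Toy scorer (hypothetical): number of query words found in the document.
def word_overlap(query: str, doc: str) -> float:
    doc_words = set(doc.lower().split())
    return float(sum(w in doc_words for w in query.lower().split()))

docs = ["harper lee wrote the novel", "melville wrote moby-dick", "a jazz age story"]
print(rank_documents("who wrote the novel", docs, word_overlap, top_k=2))
```

The corpus_id in each result is the position of the document in the input list, so results can be mapped back to the original documents even when return_documents is False.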

save(path: str) → None

Saves the model and tokenizer to path.

save_pretrained(path: str) → None

Same function as save.

Evaluation

CrossEncoders have their own evaluation classes, which are located in sentence_transformers.cross_encoder.evaluation.

class sentence_transformers.cross_encoder.evaluation.CEBinaryAccuracyEvaluator(sentence_pairs: List[List[str]], labels: List[int], name: str = '', threshold: float = 0.5, write_csv: bool = True)

This evaluator can be used with the CrossEncoder class.

It is designed for CrossEncoders with 1 output. It measures the accuracy of the predicted class vs. the gold labels. It uses a fixed threshold to determine the label (0 vs 1).

See CEBinaryClassificationEvaluator for an evaluator that determines automatically the optimal threshold.
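
The fixed-threshold accuracy described above can be sketched as follows. This is plain Python with hypothetical scores, assuming a score strictly above the threshold counts as label 1; the library's exact comparison may differ.

```python
from typing import List

def binary_accuracy(scores: List[float], labels: List[int], threshold: float = 0.5) -> float:
    """Fraction of pairs whose thresholded score matches the gold label."""
    predicted = [1 if s > threshold else 0 for s in scores]
    correct = sum(p == g for p, g in zip(predicted, labels))
    return correct / len(labels)

# Hypothetical model scores and gold labels for four sentence pairs.
scores = [0.9, 0.2, 0.6, 0.4]
gold = [1, 0, 0, 1]
print(binary_accuracy(scores, gold))
print(binary_accuracy(scores, gold, threshold=0.7))
```

The two calls give different accuracies on the same scores, which shows how much the result depends on the chosen threshold.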

class sentence_transformers.cross_encoder.evaluation.CEBinaryClassificationEvaluator(sentence_pairs: List[List[str]], labels: List[int], name: str = '', show_progress_bar: bool = False, write_csv: bool = True)

This evaluator can be used with the CrossEncoder class. Given sentence pairs and binary labels (0 and 1), it computes the average precision and the best possible F1 score.
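
For readers unfamiliar with these two metrics, here is a minimal plain-Python sketch of average precision and of a search for the best F1 over candidate thresholds. The function names and data are hypothetical, and the library's own threshold search and tie handling may differ.

```python
from typing import List, Tuple

def average_precision(scores: List[float], labels: List[int]) -> float:
    """Mean of precision@k taken at the rank k of each positive pair."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    hits, precisions = 0, []
    for rank, i in enumerate(order, start=1):
        if labels[i] == 1:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(precisions) if precisions else 0.0

def best_f1(scores: List[float], labels: List[int]) -> Tuple[float, float]:
    """Try each score as a cut-off (score >= cut-off means label 1)."""
    best = (0.0, 0.0)  # (f1, threshold)
    for threshold in sorted(set(scores)):
        tp = sum(s >= threshold and g == 1 for s, g in zip(scores, labels))
        fp = sum(s >= threshold and g == 0 for s, g in zip(scores, labels))
        fn = sum(s < threshold and g == 1 for s, g in zip(scores, labels))
        f1 = 2 * tp / (2 * tp + fp + fn) if tp else 0.0
        if f1 > best[0]:
            best = (f1, threshold)
    return best

# Hypothetical scores and binary gold labels.
scores = [0.9, 0.8, 0.4, 0.3]
gold = [1, 0, 1, 0]
print(average_precision(scores, gold), best_f1(scores, gold))
```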

class sentence_transformers.cross_encoder.evaluation.CECorrelationEvaluator(sentence_pairs: List[List[str]], scores: List[float], name: str = '', write_csv: bool = True)

This evaluator can be used with the CrossEncoder class. Given sentence pairs and continuous scores, it computes the Pearson and Spearman correlation between the predicted score for the sentence pair and the gold score.
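
The two correlation measures can be sketched in plain Python as shown below. This is an illustrative sketch with hypothetical data, not the evaluator's implementation; Spearman is computed as the Pearson correlation of the ranks, with ties sharing an average rank.

```python
import math
from typing import List

def pearson(x: List[float], y: List[float]) -> float:
    """Pearson correlation: covariance divided by both standard deviations."""
    mean_x, mean_y = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    return cov / math.sqrt(var_x * var_y)

def average_ranks(values: List[float]) -> List[float]:
    """1-based ranks; tied values share the average of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        shared = (start + end) / 2 + 1
        for pos in range(start, end + 1):
            ranks[order[pos]] = shared
        start = end + 1
    return ranks

def spearman(x: List[float], y: List[float]) -> float:
    """Spearman correlation is the Pearson correlation of the ranks."""
    return pearson(average_ranks(x), average_ranks(y))

# Hypothetical predicted scores and gold scores.
predicted = [0.1, 0.4, 0.35, 0.8]
gold = [0.0, 0.5, 0.6, 1.0]
print(pearson(predicted, gold), spearman(predicted, gold))
```

Pearson is sensitive to the actual score values, while Spearman only depends on their ordering, so a model with a different score scale can still reach a high Spearman correlation.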

class sentence_transformers.cross_encoder.evaluation.CEF1Evaluator(sentence_pairs: List[List[str]], labels: List[int], *, batch_size: int = 32, show_progress_bar: bool = False, name: str = '', write_csv: bool = True)

CrossEncoder F1 score based evaluator for binary and multiclass tasks.

The task type (binary or multiclass) is determined from the labels array. For binary tasks the returned metric is binary F1 score. For the multiclass tasks the returned metric is macro F1 score.

Parameters
  • sentence_pairs (list[list[str]]) – A list of sentence pairs, where each pair is a list of two strings.

  • labels (list[int]) – A list of integer labels corresponding to each sentence pair.

  • batch_size (int) – Batch size for prediction. Defaults to 32.

  • show_progress_bar (bool) – Show tqdm progress bar.

  • name (str, optional) – An optional name for the CSV file with stored results. Defaults to an empty string.

  • write_csv (bool, optional) – Flag to determine if the data should be saved to a CSV file. Defaults to True.
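
The difference between binary F1 and macro F1 can be sketched as follows. This is plain Python with hypothetical predictions; it assumes a task counts as binary when the gold labels contain exactly two distinct values and that class 1 is the positive class, which may not match the evaluator's internals exactly.

```python
from typing import List

def f1_for_class(predicted: List[int], gold: List[int], cls: int) -> float:
    """F1 score treating `cls` as the positive class."""
    tp = sum(p == cls and g == cls for p, g in zip(predicted, gold))
    fp = sum(p == cls and g != cls for p, g in zip(predicted, gold))
    fn = sum(p != cls and g == cls for p, g in zip(predicted, gold))
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0

def f1_score(predicted: List[int], gold: List[int]) -> float:
    """Binary F1 for two label values, macro-averaged F1 otherwise."""
    classes = sorted(set(gold))
    if len(classes) == 2:
        return f1_for_class(predicted, gold, cls=1)
    return sum(f1_for_class(predicted, gold, c) for c in classes) / len(classes)

# Hypothetical binary task: labels are 0 or 1.
print(f1_score([1, 0, 0, 1], [1, 0, 1, 1]))
# Hypothetical multiclass task: labels are 0, 1 or 2.
print(f1_score([0, 2, 2, 2], [0, 1, 2, 2]))
```

Macro F1 weights every class equally, so a class that is never predicted correctly pulls the average down regardless of how rare it is.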

class sentence_transformers.cross_encoder.evaluation.CESoftmaxAccuracyEvaluator(sentence_pairs: List[List[str]], labels: List[int], name: str = '', write_csv: bool = True)

This evaluator can be used with the CrossEncoder class.

It is designed for CrossEncoders with 2 or more outputs. It measures the accuracy of the predicted class vs. the gold labels.
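
For a model with several outputs, the predicted class is the index of the highest score. The plain-Python sketch below computes accuracy that way; the function name and logit values are hypothetical.

```python
from typing import List

def argmax_accuracy(logits: List[List[float]], labels: List[int]) -> float:
    """Fraction of pairs whose highest-scoring class equals the gold label."""
    predicted = [row.index(max(row)) for row in logits]
    return sum(p == g for p, g in zip(predicted, labels)) / len(labels)

# Hypothetical outputs of a 3-class model for three sentence pairs.
logits = [[2.1, -0.3, 0.4], [0.2, 1.5, -1.0], [-0.5, 0.1, 0.0]]
gold = [0, 1, 2]
print(argmax_accuracy(logits, gold))
```

Because softmax preserves the order of the logits, taking the argmax before or after softmax gives the same predicted class.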

class sentence_transformers.cross_encoder.evaluation.CERerankingEvaluator(samples, at_k: int = 10, name: str = '', write_csv: bool = True, mrr_at_k: Optional[int] = None)

This class evaluates a CrossEncoder model for the task of re-ranking.

Given a query and a list of documents, it computes the score [query, doc_i] for all possible documents and sorts them in decreasing order. Then, MRR@10 and NDCG@10 are computed to measure the quality of the ranking.

Parameters
  • samples – Must be a list and each element is of the form: {'query': '', 'positive': [], 'negative': []}. Query is the search query, positive is a list of positive (relevant) documents, negative is a list of negative (irrelevant) documents.
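
The sample format and the two ranking metrics can be sketched together in plain Python. The function rerank_metrics and the toy scorer word_overlap below are hypothetical stand-ins for the evaluator and the CrossEncoder, and relevance is assumed to be binary (positive = 1, negative = 0).

```python
import math
from typing import Callable, Dict, List

def rerank_metrics(
    samples: List[Dict], score_fn: Callable[[str, str], float], at_k: int = 10
) -> Dict[str, float]:
    """Mean MRR@k and NDCG@k over all samples, with binary relevance."""
    mrr_total, ndcg_total = 0.0, 0.0
    for sample in samples:
        docs = sample["positive"] + sample["negative"]
        relevant = [1] * len(sample["positive"]) + [0] * len(sample["negative"])
        scores = [score_fn(sample["query"], d) for d in docs]
        order = sorted(range(len(docs)), key=lambda i: scores[i], reverse=True)
        ranked = [relevant[i] for i in order][:at_k]
        # MRR: reciprocal rank of the first relevant document in the top k.
        mrr_total += next((1.0 / r for r, rel in enumerate(ranked, 1) if rel), 0.0)
        # NDCG: discounted gain divided by the gain of an ideal ordering.
        dcg = sum(rel / math.log2(r + 1) for r, rel in enumerate(ranked, 1))
        ideal = sorted(relevant, reverse=True)[:at_k]
        idcg = sum(rel / math.log2(r + 1) for r, rel in enumerate(ideal, 1))
        ndcg_total += dcg / idcg if idcg else 0.0
    n = len(samples)
    return {"mrr": mrr_total / n, "ndcg": ndcg_total / n}

# Toy scorer (hypothetical): number of query words found in the document.
def word_overlap(query: str, doc: str) -> float:
    doc_words = set(doc.lower().split())
    return float(sum(w in doc_words for w in query.lower().split()))

samples = [
    {"query": "capital of france",
     "positive": ["paris is the capital of france"],
     "negative": ["berlin is in germany", "france borders spain"]},
    {"query": "largest ocean",
     "positive": ["the pacific is the biggest"],
     "negative": ["the largest ocean liner sank"]},
]
print(rerank_metrics(samples, word_overlap))
```

In the second sample the toy scorer ranks the negative document first, so that query contributes a reciprocal rank of 0.5 rather than 1.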