cross_encoder
For an introduction to Cross-Encoders, see Cross-Encoders.
- class sentence_transformers.cross_encoder.CrossEncoder(model_name: str, num_labels: Optional[int] = None, max_length: Optional[int] = None, device: Optional[str] = None, tokenizer_args: Dict = {}, automodel_args: Dict = {}, default_activation_function=None)
A CrossEncoder takes exactly two sentences (texts) as input and predicts either a score or a label for the pair. For example, it can predict the similarity of the two sentences on a scale from 0 to 1.
It does not produce sentence embeddings and cannot be applied to individual sentences.
- Parameters
model_name – Any model name from the Huggingface Models Repository that can be loaded with AutoModel. We provide several pre-trained CrossEncoder models that can be used for common tasks.
num_labels – Number of labels of the classifier. If 1, the CrossEncoder is a regression model that outputs a continuous score between 0 and 1. If > 1, it outputs several scores that can be passed through a softmax to get probabilities for the different classes.
max_length – Maximum length of input sequences; longer sequences are truncated. If None, the model's maximum length is used.
device – Device that should be used for the model. If None, it will use CUDA if available.
tokenizer_args – Arguments passed to AutoTokenizer
automodel_args – Arguments passed to AutoModelForSequenceClassification
default_activation_function – Callable (such as nn.Sigmoid) used as the default activation function on top of model.predict(). If None, nn.Sigmoid() is used if num_labels=1, else nn.Identity().
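The default-activation rule above can be illustrated in plain Python. This is a minimal sketch, not the library's code: `sigmoid`, `identity` and `default_activation` are hypothetical helpers, and a single raw logit stands in for the model output.

```python
import math
from typing import Callable


def sigmoid(x: float) -> float:
    """Map a raw logit to the open interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def identity(x: float) -> float:
    """Return the logit unchanged."""
    return x


def default_activation(num_labels: int) -> Callable[[float], float]:
    """Hypothetical helper mirroring the documented default:
    sigmoid for a single-output (regression) model, identity otherwise."""
    return sigmoid if num_labels == 1 else identity


# A regression-style CrossEncoder (num_labels=1) maps a logit to a 0...1 score.
score = default_activation(1)(2.0)
# A classifier (num_labels > 1) leaves each logit untouched by default.
raw = default_activation(3)(2.0)
print(round(score, 4), raw)
```

With num_labels > 1 the outputs stay as logits, which is why a softmax is needed to read them as class probabilities.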
- fit(train_dataloader: torch.utils.data.dataloader.DataLoader, evaluator: typing.Optional[sentence_transformers.evaluation.SentenceEvaluator.SentenceEvaluator] = None, epochs: int = 1, loss_fct=None, activation_fct=Identity(), scheduler: str = 'WarmupLinear', warmup_steps: int = 10000, optimizer_class: typing.Type[torch.optim.optimizer.Optimizer] = <class 'torch.optim.adamw.AdamW'>, optimizer_params: typing.Dict[str, object] = {'lr': 2e-05}, weight_decay: float = 0.01, evaluation_steps: int = 0, output_path: typing.Optional[str] = None, save_best_model: bool = True, max_grad_norm: float = 1, use_amp: bool = False, callback: typing.Optional[typing.Callable[[float, int, int], None]] = None, show_progress_bar: bool = True)
Train the model with the given training objective, using the batches provided by train_dataloader.
- Parameters
train_dataloader – DataLoader with training InputExamples
evaluator – An evaluator (sentence_transformers.evaluation) that measures model performance on held-out dev data during training. It is used to determine the best model, which is saved to disk.
epochs – Number of epochs for training
loss_fct – Which loss function to use for training. If None, nn.BCEWithLogitsLoss() is used if self.config.num_labels == 1, else nn.CrossEntropyLoss().
activation_fct – Activation function applied on top of the model's logits output.
scheduler – Learning rate scheduler. Available schedulers: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts
warmup_steps – Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is increased linearly from 0 up to the maximal learning rate over the first warmup_steps training steps. After that, the learning rate is decreased linearly back to zero.
optimizer_class – Optimizer
optimizer_params – Optimizer parameters
weight_decay – Weight decay for model parameters
evaluation_steps – If > 0, evaluate the model with evaluator every evaluation_steps training steps
output_path – Storage path for the model and evaluation files
save_best_model – If True, the best model (according to evaluator) is stored at output_path
max_grad_norm – Maximum gradient norm, used for gradient clipping.
use_amp – Use Automatic Mixed Precision (AMP). Only for PyTorch >= 1.6.0
callback – Callback function that is invoked after each evaluation. It must accept the following three parameters in this order: score, epoch, steps
show_progress_bar – If True, output a tqdm progress bar
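The WarmupLinear behavior described for warmup_steps can be written as a learning-rate function of the step. This sketch assumes the schedule has the standard linear-warmup, linear-decay shape (as in Hugging Face transformers' get_linear_schedule_with_warmup); `warmup_linear_lr` and `total_steps` are hypothetical names used only for illustration.

```python
def warmup_linear_lr(step: int, max_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Learning rate at a given optimizer step: linear warmup, then linear decay."""
    if step < warmup_steps:
        # Warmup phase: ramp linearly from 0 to max_lr.
        return max_lr * step / max(1, warmup_steps)
    # Decay phase: fall linearly from max_lr back to 0 at total_steps.
    remaining = max(0, total_steps - step)
    return max_lr * remaining / max(1, total_steps - warmup_steps)


# Example: lr=2e-5 (the fit() default in optimizer_params), 100 warmup steps, 1000 total steps.
for step in (0, 50, 100, 550, 1000):
    print(step, warmup_linear_lr(step, 2e-5, 100, 1000))
```

Note that the default warmup_steps=10000 is large; if the run has fewer total steps than that, the learning rate never reaches its maximum under this schedule shape.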
- predict(sentences: List[List[str]], batch_size: int = 32, show_progress_bar: Optional[bool] = None, num_workers: int = 0, activation_fct=None, apply_softmax=False, convert_to_numpy: bool = True, convert_to_tensor: bool = False)
Performs predictions with the CrossEncoder on the given sentence pairs.
- Parameters
sentences – A list of sentence pairs [[Sent1, Sent2], [Sent3, Sent4]]
batch_size – Batch size for encoding
show_progress_bar – Output progress bar
num_workers – Number of workers for tokenization
activation_fct – Activation function applied on the logits output of the CrossEncoder. If None, nn.Sigmoid() will be used if num_labels=1, else nn.Identity()
convert_to_numpy – Convert the output to a numpy matrix.
apply_softmax – If True and the model produces more than one output per sentence pair, applies softmax on the logits output
convert_to_tensor – Convert the output to a tensor.
- Returns
Predictions for the passed sentence pairs
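The effect of apply_softmax on multi-output predictions can be sketched in plain Python. This assumes predict() yields one row of logits per sentence pair and that softmax is applied row-wise only when a row has more than one entry; `maybe_softmax` is a hypothetical helper and the logits are made-up numbers.

```python
import math
from typing import List


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def maybe_softmax(logits: List[List[float]], apply_softmax: bool) -> List[List[float]]:
    """Apply softmax per sentence pair only if requested and there is more than one output."""
    if apply_softmax and len(logits[0]) > 1:
        return [softmax(row) for row in logits]
    return logits


# Made-up logits for two sentence pairs from a 3-label model.
logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
probs = maybe_softmax(logits, apply_softmax=True)
print([[round(p, 3) for p in row] for row in probs])
```

For a single-output model (num_labels=1) the rows are left unchanged, since a softmax over one value would always give 1.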
- save(path)
Saves the model and tokenizer to path
- save_pretrained(path)
Same function as save
Evaluation
CrossEncoders have their own evaluation classes, which are in sentence_transformers.cross_encoder.evaluation.
- class sentence_transformers.cross_encoder.evaluation.CEBinaryAccuracyEvaluator(sentence_pairs: List[List[str]], labels: List[int], name: str = '', threshold: float = 0.5, write_csv: bool = True)
This evaluator can be used with the CrossEncoder class.
It is designed for CrossEncoders with 1 output. It measures the accuracy of the predicted class vs. the gold labels, using a fixed threshold to determine the label (0 vs. 1).
See CEBinaryClassificationEvaluator for an evaluator that determines automatically the optimal threshold.
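Fixed-threshold accuracy can be sketched as follows. This assumes a predicted score strictly above threshold maps to label 1; `binary_accuracy` is a hypothetical helper and the scores and labels are made up.

```python
from typing import List


def binary_accuracy(scores: List[float], labels: List[int], threshold: float = 0.5) -> float:
    """Fraction of pairs whose thresholded score matches the gold 0/1 label."""
    predicted = [1 if s > threshold else 0 for s in scores]
    correct = sum(p == y for p, y in zip(predicted, labels))
    return correct / len(labels)


scores = [0.91, 0.40, 0.55, 0.12]
labels = [1, 0, 0, 0]
print(binary_accuracy(scores, labels))  # → 0.75
```

Raising the threshold to 0.6 would classify the third pair correctly here, which is the kind of tuning CEBinaryClassificationEvaluator avoids by searching for the threshold itself.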
- class sentence_transformers.cross_encoder.evaluation.CEBinaryClassificationEvaluator(sentence_pairs: List[List[str]], labels: List[int], name: str = '', show_progress_bar: bool = False, write_csv: bool = True)
This evaluator can be used with the CrossEncoder class. Given sentence pairs and binary labels (0 and 1), it computes the average precision and the best possible F1 score.
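How average precision and the best achievable F1 can be derived from scores and binary labels is sketched below in plain Python. This is an illustration only: the evaluator's own implementation may differ in details such as tie handling, `average_precision` and `best_f1` are hypothetical helpers, and the data is made up.

```python
from typing import List, Tuple


def average_precision(scores: List[float], labels: List[int]) -> float:
    """Mean of the precision values at the rank of each positive (no tie handling)."""
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    hits, precisions = 0, []
    for rank, (_, label) in enumerate(ranked, start=1):
        if label == 1:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(precisions) if precisions else 0.0


def best_f1(scores: List[float], labels: List[int]) -> Tuple[float, float]:
    """Return (best F1, threshold) over cut-offs placed at each observed score."""
    total_pos = sum(labels)
    best = (0.0, float("inf"))
    for threshold in sorted(set(scores), reverse=True):
        # Predict 1 for every pair whose score is at least the threshold.
        tp = sum(1 for s, y in zip(scores, labels) if s >= threshold and y == 1)
        predicted_pos = sum(1 for s in scores if s >= threshold)
        if tp == 0:
            continue
        precision = tp / predicted_pos
        recall = tp / total_pos
        f1 = 2 * precision * recall / (precision + recall)
        if f1 > best[0]:
            best = (f1, threshold)
    return best


scores = [0.9, 0.8, 0.7, 0.3, 0.2]
labels = [1, 0, 1, 1, 0]
print(average_precision(scores, labels))
print(best_f1(scores, labels))
```

Sweeping only the observed scores is enough, because precision and recall can change only when the cut-off crosses one of them.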
- class sentence_transformers.cross_encoder.evaluation.CECorrelationEvaluator(sentence_pairs: List[List[str]], scores: List[float], name: str = '', write_csv: bool = True)
This evaluator can be used with the CrossEncoder class. Given sentence pairs and continuous scores, it computes the Pearson and Spearman correlation between the predicted score for each sentence pair and the gold score.
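Both correlation measures can be computed without external libraries, as sketched below. Spearman is taken here as the Pearson correlation of average ranks (tied values share their mean rank); the helper names are hypothetical and the scores are made up.

```python
import math
from typing import List


def pearson(x: List[float], y: List[float]) -> float:
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def average_ranks(values: List[float]) -> List[float]:
    """1-based ranks, with tied values receiving the mean of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of tied values starting at position i.
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        mean_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    return ranks


def spearman(x: List[float], y: List[float]) -> float:
    """Spearman rank correlation: Pearson correlation of the average ranks."""
    return pearson(average_ranks(x), average_ranks(y))


predicted = [0.1, 0.4, 0.35, 0.8]
gold = [0.0, 0.5, 0.25, 1.0]
print(pearson(predicted, gold), spearman(predicted, gold))
```

Because Spearman only looks at ranks, it is 1 whenever the predicted ordering matches the gold ordering exactly, even if the score scales differ.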
- class sentence_transformers.cross_encoder.evaluation.CESoftmaxAccuracyEvaluator(sentence_pairs: List[List[str]], labels: List[int], name: str = '', write_csv: bool = True)
This evaluator can be used with the CrossEncoder class.
It is designed for CrossEncoders with 2 or more outputs. It measures the accuracy of the predicted class vs. the gold labels.
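For multi-output models, the predicted class can be taken as the index of the largest output, so accuracy reduces to an argmax comparison. A minimal sketch, assuming that rule; `argmax_accuracy` is a hypothetical helper and the logits are made up.

```python
from typing import List


def argmax_accuracy(logits: List[List[float]], labels: List[int]) -> float:
    """Fraction of pairs whose highest-scoring output index equals the gold label."""
    # Softmax is monotonic, so the argmax of raw logits equals the argmax of probabilities.
    predicted = [max(range(len(row)), key=row.__getitem__) for row in logits]
    return sum(p == y for p, y in zip(predicted, labels)) / len(labels)


logits = [[2.1, 0.3, -0.5], [0.1, 1.7, 0.2], [0.4, 0.2, 0.9]]
labels = [0, 1, 0]
print(argmax_accuracy(logits, labels))
```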
- class sentence_transformers.cross_encoder.evaluation.CERerankingEvaluator(samples, mrr_at_k: int = 10, name: str = '', write_csv: bool = True)
This class evaluates a CrossEncoder model for the task of re-ranking.
Given a query and a list of documents, it computes the score for each pair [query, doc_i] and sorts the documents in decreasing order of score. Then MRR@k (k = mrr_at_k, 10 by default) is computed to measure the quality of the ranking.
- Parameters
samples – A list in which each element has the form {'query': '', 'positive': [], 'negative': []}, where query is the search query, positive is a list of positive (relevant) documents, and negative is a list of negative (irrelevant) documents.
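The re-ranking metric can be sketched in plain Python, with a stand-in scoring function in place of CrossEncoder.predict() so the example runs without a model. The `score_fn` parameter, `toy_score` and the sample data are hypothetical; the sample dictionaries follow the documented {'query', 'positive', 'negative'} layout. Each query contributes 1/rank of the first relevant document if it appears within the top k, else 0; the evaluator itself may treat edge cases (such as samples without positives) differently.

```python
from typing import Callable, Dict, List


def mrr_at_k(samples: List[Dict], score_fn: Callable[[str, str], float], k: int = 10) -> float:
    """Mean reciprocal rank of the first positive document within the top k."""
    total = 0.0
    for sample in samples:
        docs = list(sample["positive"]) + list(sample["negative"])
        is_relevant = [True] * len(sample["positive"]) + [False] * len(sample["negative"])
        # Score every (query, doc) pair and sort documents by decreasing score.
        scores = [score_fn(sample["query"], doc) for doc in docs]
        ranking = sorted(range(len(docs)), key=lambda i: scores[i], reverse=True)
        for rank, idx in enumerate(ranking[:k], start=1):
            if is_relevant[idx]:
                total += 1.0 / rank
                break
    return total / len(samples)


def toy_score(query: str, doc: str) -> float:
    """Hypothetical stand-in for a CrossEncoder: counts shared lowercase words."""
    return len(set(query.lower().split()) & set(doc.lower().split()))


samples = [
    {"query": "capital of france", "positive": ["paris is the capital of france"],
     "negative": ["berlin is in germany", "france borders spain"]},
    {"query": "python list sort", "positive": ["sorted returns a new list"],
     "negative": ["python list sort method sorts in place"]},
]
print(mrr_at_k(samples, toy_score, k=10))
```

In the second sample the irrelevant document shares more words with the query, so the positive lands at rank 2 and contributes 1/2.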