Evaluation¶

sentence_transformers.evaluation defines different classes that can be used to evaluate the model during training.

class sentence_transformers.evaluation.BinaryClassificationEvaluator(sentences1: List[str], sentences2: List[str], labels: List[int], name: str = '', batch_size: int = 32, show_progress_bar: bool = False, write_csv: bool = True)¶

Evaluate a model based on the similarity of the embeddings by calculating the accuracy of identifying similar and dissimilar sentences. The metrics are cosine similarity as well as Euclidean and Manhattan distance. The returned score is the accuracy with the specified metric.

The results are written in a CSV. If a CSV already exists, then values are appended.

The labels need to be 0 for dissimilar pairs and 1 for similar pairs.

Parameters
  • sentences1 – The first column of sentences

  • sentences2 – The second column of sentences

  • labels – labels[i] is the label for the pair (sentences1[i], sentences2[i]). Must be 0 or 1

  • name – Name for the output

  • batch_size – Batch size used to compute embeddings

  • show_progress_bar – If true, prints a progress bar

  • write_csv – Write results to a CSV file
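
A minimal usage sketch (the sentence pairs, labels, and model checkpoint name below are illustrative, not part of this API reference): the evaluator is built from three parallel lists and can be called directly with a model, or passed as the evaluator argument to model.fit:

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.evaluation import BinaryClassificationEvaluator

    # Illustrative data: label 1 = similar pair, label 0 = dissimilar pair
    sentences1 = ["A cat sits on the mat", "The weather is nice today"]
    sentences2 = ["A cat is sitting on a mat", "Stock prices fell sharply"]
    labels = [1, 0]

    model = SentenceTransformer("all-MiniLM-L6-v2")  # any SentenceTransformer checkpoint
    evaluator = BinaryClassificationEvaluator(sentences1, sentences2, labels, name="dev-pairs")

    # Calling the evaluator returns a single float score (the best accuracy);
    # the CSV is only written when an output_path is provided.
    print(evaluator(model))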

class sentence_transformers.evaluation.EmbeddingSimilarityEvaluator(sentences1: List[str], sentences2: List[str], scores: List[float], batch_size: int = 16, main_similarity: Optional[sentence_transformers.evaluation.SimilarityFunction.SimilarityFunction] = None, name: str = '', show_progress_bar: bool = False, write_csv: bool = True)¶

Evaluate a model based on the similarity of the embeddings by calculating the Spearman and Pearson rank correlation in comparison to the gold standard labels. The metrics are cosine similarity as well as Euclidean and Manhattan distance. The returned score is the Spearman correlation with the specified metric.

The results are written in a CSV. If a CSV already exists, then values are appended.

Constructs an evaluator for the given dataset.

The labels need to indicate the similarity between the sentences.

Parameters
  • sentences1 – List with the first sentence in a pair

  • sentences2 – List with the second sentence in a pair

  • scores – Similarity score between sentences1[i] and sentences2[i]

  • write_csv – Write results to a CSV file
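
A minimal sketch, assuming gold similarity scores normalized to [0, 1] (the sentence pairs, scores, and model name are illustrative):

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

    # Illustrative STS-style data with gold similarity scores in [0, 1]
    sentences1 = ["A man is playing guitar", "The movie was great"]
    sentences2 = ["Someone plays a guitar", "I hated the film"]
    scores = [0.9, 0.1]

    model = SentenceTransformer("all-MiniLM-L6-v2")
    evaluator = EmbeddingSimilarityEvaluator(sentences1, sentences2, scores, name="sts-dev")

    # Returns the Spearman correlation (cosine similarity unless main_similarity is set)
    print(evaluator(model))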

class sentence_transformers.evaluation.InformationRetrievalEvaluator(queries: Dict[str, str], corpus: Dict[str, str], relevant_docs: Dict[str, Set[str]], corpus_chunk_size: int = 50000, mrr_at_k: List[int] = [10], ndcg_at_k: List[int] = [10], accuracy_at_k: List[int] = [1, 3, 5, 10], precision_recall_at_k: List[int] = [1, 3, 5, 10], map_at_k: List[int] = [100], show_progress_bar: bool = False, batch_size: int = 32, name: str = '', write_csv: bool = True, score_functions: Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = {'cos_sim': <function cos_sim>, 'dot_score': <function dot_score>}, main_score_function: Optional[str] = None)¶

This class evaluates an Information Retrieval (IR) setting.

Given a set of queries and a large corpus, it will retrieve for each query the top-k most similar documents. It measures Mean Reciprocal Rank (MRR), Recall@k, and Normalized Discounted Cumulative Gain (NDCG).
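
A toy sketch of the expected input structure, where queries, corpus, and relevant_docs are keyed by string ids (the data and model name are illustrative):

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.evaluation import InformationRetrievalEvaluator

    # qid -> query text, did -> document text, qid -> set of relevant dids
    queries = {"q1": "How do I reset my password?"}
    corpus = {
        "d1": "Click 'Forgot password' on the login page to reset it.",
        "d2": "Our office is closed on public holidays.",
    }
    relevant_docs = {"q1": {"d1"}}

    model = SentenceTransformer("all-MiniLM-L6-v2")
    evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs, name="toy-ir")

    # Returns a single aggregate score; the per-metric values (MRR@k, NDCG@k,
    # Accuracy@k, Precision/Recall@k, MAP@k) go to the CSV when an output_path is given.
    print(evaluator(model))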

class sentence_transformers.evaluation.LabelAccuracyEvaluator(dataloader: torch.utils.data.dataloader.DataLoader, name: str = '', softmax_model=None, write_csv: bool = True)¶

Evaluate a model based on its accuracy on a labeled dataset

This requires a model with LossFunction.SOFTMAX

The results are written in a CSV. If a CSV already exists, then values are appended.

Constructs an evaluator for the given dataset

Parameters
  • dataloader – The data for the evaluation
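
A rough sketch, assuming a model trained with losses.SoftmaxLoss, whose loss module is passed as softmax_model (the data, label count, and model name are illustrative):

    from torch.utils.data import DataLoader
    from sentence_transformers import SentenceTransformer, InputExample, losses
    from sentence_transformers.evaluation import LabelAccuracyEvaluator

    model = SentenceTransformer("all-MiniLM-L6-v2")
    softmax_loss = losses.SoftmaxLoss(
        model=model,
        sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
        num_labels=3,  # illustrative: e.g. entailment / neutral / contradiction
    )

    # Illustrative labeled sentence pairs
    examples = [InputExample(texts=["A man eats pizza", "A person eats food"], label=0)]
    dataloader = DataLoader(examples, batch_size=16)

    evaluator = LabelAccuracyEvaluator(dataloader, softmax_model=softmax_loss, name="nli-dev")
    print(evaluator(model))  # returns the classification accuracy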

class sentence_transformers.evaluation.MSEEvaluator(source_sentences: List[str], target_sentences: List[str], teacher_model=None, show_progress_bar: bool = False, batch_size: int = 32, name: str = '', write_csv: bool = True)¶

Computes the mean squared error (x100) between the computed sentence embedding and some target sentence embedding.

The MSE is computed between teacher.encode(source_sentences) and student.encode(target_sentences).

For multilingual knowledge distillation (https://arxiv.org/abs/2004.09813), source_sentences are in English and target_sentences are in a different language like German, Chinese, Spanish…

Parameters
  • source_sentences – Source sentences are embedded with the teacher model

  • target_sentences – Target sentences are embedded with the student model.

  • show_progress_bar – Show progress bar when computing embeddings

  • batch_size – Batch size to compute sentence embeddings

  • name – Name of the evaluator

  • write_csv – Write results to CSV file
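
A sketch of the multilingual distillation use case (checkpoint names are illustrative; teacher and student must produce embeddings of the same dimensionality):

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.evaluation import MSEEvaluator

    # English source sentences (embedded by the teacher) and their German
    # translations (embedded by the student)
    source_sentences = ["Hello world", "How are you?"]
    target_sentences = ["Hallo Welt", "Wie geht es dir?"]

    teacher = SentenceTransformer("all-MiniLM-L6-v2")
    student = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")

    evaluator = MSEEvaluator(source_sentences, target_sentences, teacher_model=teacher, name="en-de")

    # The evaluator is called with the student model; the returned score is
    # reported so that higher is better, while the MSE (x100) goes to the CSV.
    print(evaluator(student))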

class sentence_transformers.evaluation.MSEEvaluatorFromDataFrame(dataframe: List[Dict[str, str]], teacher_model: SentenceTransformer, combinations: List[Tuple[str, str]], batch_size: int = 8, name='', write_csv: bool = True)¶

Computes the mean squared error (x100) between the computed sentence embedding and some target sentence embedding.

Parameters
  • dataframe –

    It must have the following format: each row contains parallel sentences, and the columns are the respective language codes:

    [{'en': 'My sentence', 'es': 'Sentence in Spanish', 'fr': 'Sentence in French'...},
     {'en': 'My second sentence', ...}]
    

  • combinations – Must be of the format [('en', 'es'), ('en', 'fr'), ...]. The first entry in a tuple is the source language: the sentence in that language is fetched from the dataframe and passed to the teacher model. The second entry in a tuple is the target language: that sentence is fetched from the dataframe and passed to the student model.

class sentence_transformers.evaluation.ParaphraseMiningEvaluator(sentences_map: Dict[str, str], duplicates_list: Optional[List[Tuple[str, str]]] = None, duplicates_dict: Optional[Dict[str, Dict[str, bool]]] = None, add_transitive_closure: bool = False, query_chunk_size: int = 5000, corpus_chunk_size: int = 100000, max_pairs: int = 500000, top_k: int = 100, show_progress_bar: bool = False, batch_size: int = 16, name: str = '', write_csv: bool = True)¶

Given a large set of sentences, this evaluator performs paraphrase (duplicate) mining and identifies the pairs with the highest similarity. It compares the extracted paraphrase pairs with a set of gold labels and computes the F1 score.

Parameters
  • sentences_map – A dictionary that maps sentence-ids to sentences, i.e. sentences_map[id] => sentence.

  • duplicates_list – Duplicates_list is a list with id pairs [(id1, id2), (id1, id5)] that identifies the duplicates / paraphrases in the sentences_map

  • duplicates_dict – A default dictionary mapping [id1][id2] to true if id1 and id2 are duplicates. Must be symmetric, i.e., if [id1][id2] => True, then [id2][id1] => True.

  • add_transitive_closure – If true, it adds a transitive closure, i.e. if dup[a][b] and dup[b][c], then dup[a][c]

  • query_chunk_size – To identify the paraphrases, the cosine similarity between all sentence pairs will be computed. As this might require a lot of memory, we perform a batched computation: #query_chunk_size sentences are compared against up to #corpus_chunk_size sentences. In the default setting, 5000 sentences are grouped together and compared against up to 100k other sentences.

  • corpus_chunk_size – The corpus will be batched, to reduce the memory requirement

  • max_pairs – We will only extract up to #max_pairs potential paraphrase candidates.

  • top_k – For each query, we extract the top_k most similar pairs and add them to a sorted list. I.e., for one sentence we cannot find more than top_k paraphrases

  • show_progress_bar – Output a progress bar

  • batch_size – Batch size for computing sentence embeddings

  • name – Name of the experiment

  • write_csv – Write results to CSV file
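
A small sketch (the sentence ids, texts, and model name are illustrative):

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.evaluation import ParaphraseMiningEvaluator

    # id -> sentence, plus the gold duplicate pairs among those ids
    sentences_map = {
        "1": "How can I learn Python quickly?",
        "2": "What is the fastest way to learn Python?",
        "3": "What is the capital of France?",
    }
    duplicates_list = [("1", "2")]

    model = SentenceTransformer("all-MiniLM-L6-v2")
    evaluator = ParaphraseMiningEvaluator(sentences_map, duplicates_list=duplicates_list, name="dup-dev")

    # Returns a single score summarizing mining quality; the best F1, precision,
    # recall, and decision threshold are written to the CSV.
    print(evaluator(model))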

class sentence_transformers.evaluation.RerankingEvaluator(samples, at_k: int = 10, name: str = '', write_csv: bool = True, similarity_fct=<function cos_sim>, batch_size: int = 64, show_progress_bar: bool = False, use_batched_encoding: bool = True, mrr_at_k: Optional[int] = None)¶

This class evaluates a SentenceTransformer model for the task of re-ranking.

Given a query and a list of documents, it computes the score [query, doc_i] for all possible documents and sorts them in decreasing order. Then, MRR@10, NDCG@10, and MAP are computed to measure the quality of the ranking.

Parameters
  • samples – Must be a list and each element is of the form: {'query': '', 'positive': [], 'negative': []}. Query is the search query, positive is a list of positive (relevant) documents, negative is a list of negative (irrelevant) documents.
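
A sketch of the required samples format (the query, documents, and model name are illustrative):

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.evaluation import RerankingEvaluator

    samples = [
        {
            "query": "python read file",
            "positive": ["How to read a text file in Python"],
            "negative": ["Best restaurants in Paris", "How to train a dog"],
        }
    ]

    model = SentenceTransformer("all-MiniLM-L6-v2")
    evaluator = RerankingEvaluator(samples, name="rerank-dev")

    # Returns the mean MAP over all samples; MRR@k is also computed and written to the CSV.
    print(evaluator(model))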

class sentence_transformers.evaluation.SentenceEvaluator¶

Base class for all evaluators

Extend this class and implement __call__ for custom evaluators.
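
A minimal sketch of a custom evaluator (the class name, its data, and the scoring rule are hypothetical): implement __call__(model, output_path, epoch, steps) and return a float where higher means better, so the evaluator can be passed to model.fit:

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.evaluation import SentenceEvaluator
    from sentence_transformers.util import cos_sim

    class MeanPairSimilarityEvaluator(SentenceEvaluator):
        """Hypothetical evaluator: mean cosine similarity of aligned sentence pairs."""

        def __init__(self, sentences1, sentences2):
            self.sentences1 = sentences1
            self.sentences2 = sentences2

        def __call__(self, model, output_path=None, epoch=-1, steps=-1):
            emb1 = model.encode(self.sentences1, convert_to_tensor=True)
            emb2 = model.encode(self.sentences2, convert_to_tensor=True)
            # Diagonal of the pairwise cosine-similarity matrix = aligned pairs
            return cos_sim(emb1, emb2).diagonal().mean().item()

    model = SentenceTransformer("all-MiniLM-L6-v2")
    evaluator = MeanPairSimilarityEvaluator(["A cat sits"], ["A cat is sitting"])
    print(evaluator(model))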

class sentence_transformers.evaluation.SequentialEvaluator(evaluators: Iterable[sentence_transformers.evaluation.SentenceEvaluator.SentenceEvaluator], main_score_function=<function SequentialEvaluator.<lambda>>)¶

This evaluator allows multiple sub-evaluators to be passed. When the model is evaluated, the data is passed sequentially to all sub-evaluators.

All scores are passed to ‘main_score_function’, which derives one final score value.
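
A sketch combining two sub-evaluators (data and model name are illustrative); by default the score of the last sub-evaluator is used as the final score, which main_score_function can override:

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.evaluation import (
        BinaryClassificationEvaluator,
        EmbeddingSimilarityEvaluator,
        SequentialEvaluator,
    )

    model = SentenceTransformer("all-MiniLM-L6-v2")

    dev_pairs = BinaryClassificationEvaluator(
        ["A cat sits", "Nice weather"], ["A cat is sitting", "Stocks fell"], [1, 0], name="pairs"
    )
    dev_sts = EmbeddingSimilarityEvaluator(
        ["A man plays guitar", "The movie was great"],
        ["Someone plays a guitar", "I hated the film"],
        [0.9, 0.1],
        name="sts",
    )

    evaluator = SequentialEvaluator([dev_pairs, dev_sts])
    print(evaluator(model))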

class sentence_transformers.evaluation.TranslationEvaluator(source_sentences: List[str], target_sentences: List[str], show_progress_bar: bool = False, batch_size: int = 16, name: str = '', print_wrong_matches: bool = False, write_csv: bool = True)¶

Given two sets of sentences in different languages, e.g. (en_1, en_2, en_3, …) and (fr_1, fr_2, fr_3, …), where fr_i is assumed to be the translation of en_i, this evaluator checks if vec(en_i) has the highest similarity to vec(fr_i). It computes the accuracy in both directions.

Constructs an evaluator for the given parallel dataset.

Parameters
  • source_sentences – List of sentences in source language

  • target_sentences – List of sentences in target language

  • print_wrong_matches – Prints incorrect matches

  • write_csv – Write results to CSV file
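
A sketch with a small parallel corpus (the sentence pairs and the multilingual checkpoint name are illustrative):

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.evaluation import TranslationEvaluator

    # target_sentences[i] is the translation of source_sentences[i]
    source_sentences = ["Hello world", "How are you?"]
    target_sentences = ["Hallo Welt", "Wie geht es dir?"]

    model = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")
    evaluator = TranslationEvaluator(source_sentences, target_sentences, name="en-de")

    # Returns the mean accuracy over both directions (src -> trg and trg -> src)
    print(evaluator(model))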

class sentence_transformers.evaluation.TripletEvaluator(anchors: List[str], positives: List[str], negatives: List[str], main_distance_function: Optional[sentence_transformers.evaluation.SimilarityFunction.SimilarityFunction] = None, name: str = '', batch_size: int = 16, show_progress_bar: bool = False, write_csv: bool = True)¶

Evaluate a model based on a triplet: (sentence, positive_example, negative_example).

Checks if distance(sentence, positive_example) < distance(sentence, negative_example).

Parameters
  • anchors – Sentences to check similarity to (e.g. a query)

  • positives – List of positive sentences

  • negatives – List of negative sentences

  • main_distance_function – One of 0 (Cosine), 1 (Euclidean) or 2 (Manhattan). Defaults to None, returning all 3.

  • name – Name for the output

  • batch_size – Batch size used to compute embeddings

  • show_progress_bar – If true, prints a progress bar

  • write_csv – Write results to a CSV file
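
A minimal sketch (the triplet and the model name are illustrative):

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.evaluation import TripletEvaluator

    anchors = ["What is the capital of France?"]
    positives = ["Paris is the capital of France."]
    negatives = ["Berlin is the capital of Germany."]

    model = SentenceTransformer("all-MiniLM-L6-v2")
    evaluator = TripletEvaluator(anchors, positives, negatives, name="triplet-dev")

    # Returns the fraction of triplets where
    # distance(anchor, positive) < distance(anchor, negative)
    print(evaluator(model))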