Evaluation

sentence_transformers.evaluation defines several classes that can be used to evaluate the model during training.

BinaryClassificationEvaluator

class sentence_transformers.evaluation.BinaryClassificationEvaluator(sentences1: List[str], sentences2: List[str], labels: List[int], name: str = '', batch_size: int = 32, show_progress_bar: bool = False, write_csv: bool = True, truncate_dim: Optional[int] = None)

Evaluate a model based on the similarity of the embeddings by calculating the accuracy of identifying similar and dissimilar sentences. The metrics are cosine similarity, dot product, Euclidean distance, and Manhattan distance. For each metric, accuracy, F1, precision, recall, and average precision are reported; the primary metric is the highest average precision across the metrics.

The results are written to a CSV file. If the CSV file already exists, the values are appended.

The labels need to be 0 for dissimilar pairs and 1 for similar pairs.
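
A minimal pure-Python sketch of how a thresholded accuracy like this can be computed, assuming cosine similarity as the score and the 0/1 labels described above: the pairs are sorted by score and every cut point between distinct scores is tried as a threshold. The helper best_accuracy_threshold is hypothetical and is not the library's implementation.

```python
from typing import List, Tuple


def best_accuracy_threshold(scores: List[float], labels: List[int]) -> Tuple[float, float]:
    """Find the threshold maximizing accuracy when `score > threshold` means "similar".

    Hypothetical helper: assumes higher scores mean more similar (e.g. cosine
    similarity) and labels of 0 (dissimilar) or 1 (similar).
    """
    if not scores or len(scores) != len(labels):
        raise ValueError("scores and labels must be non-empty and of equal length")
    if any(label not in (0, 1) for label in labels):
        raise ValueError("labels must be 0 or 1")

    # Sort pairs from most to least similar.
    rows = sorted(zip(scores, labels), key=lambda row: row[0], reverse=True)
    total_negatives = sum(1 for _, label in rows if label == 0)

    # Threshold at the highest score: every pair is predicted dissimilar.
    best_correct, best_threshold = total_negatives, rows[0][0]
    positives_above = negatives_above = 0

    for i, (score, label) in enumerate(rows):
        if label == 1:
            positives_above += 1
        else:
            negatives_above += 1
        # Only cut between distinct scores so that tied pairs are never split.
        if i + 1 < len(rows) and rows[i + 1][0] == score:
            continue
        # Pairs above the cut are predicted similar, all others dissimilar.
        correct = positives_above + (total_negatives - negatives_above)
        if correct > best_correct:
            next_score = rows[i + 1][0] if i + 1 < len(rows) else score - 1.0
            best_correct, best_threshold = correct, (score + next_score) / 2
    return best_correct / len(rows), best_threshold


# Illustrative cosine similarities and gold labels for four sentence pairs.
accuracy, threshold = best_accuracy_threshold([0.9, 0.7, 0.6, 0.1], [1, 0, 1, 0])
print(f"accuracy={accuracy:.2f} threshold={threshold:.2f}")  # accuracy=0.75 threshold=0.80
```

Only cuts between distinct scores are considered, so pairs with equal similarity always receive the same prediction.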

Parameters
  • sentences1 (List[str]) – The first column of sentences.

  • sentences2 (List[str]) – The second column of sentences.

  • labels (List[int]) – labels[i] is the label for the pair (sentences1[i], sentences2[i]). Must be 0 or 1.

  • name (str, optional) – Name for the output. Defaults to “”.

  • batch_size (int, optional) – Batch size used to compute embeddings. Defaults to 32.

  • show_progress_bar (bool, optional) – If true, prints a progress bar. Defaults to False.

  • write_csv (bool, optional) – Write results to a CSV file. Defaults to True.

  • truncate_dim (Optional[int], optional) – The dimension to truncate sentence embeddings to. None uses the model’s current truncation dimension. Defaults to None.

Example

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import BinaryClassificationEvaluator
from datasets import load_dataset

# Load a model
model = SentenceTransformer('all-mpnet-base-v2')

# Load a dataset with two text columns and a class label column (https://huggingface.co/datasets/sentence-transformers/quora-duplicates)
eval_dataset = load_dataset("sentence-transformers/quora-duplicates", "pair-class", split="train[-1000:]")

# Initialize the evaluator
binary_acc_evaluator = BinaryClassificationEvaluator(
    sentences1=eval_dataset["sentence1"],
    sentences2=eval_dataset["sentence2"],
    labels=eval_dataset["label"],
    name="quora-duplicates-dev",
)
results = binary_acc_evaluator(model)
'''
Binary Accuracy Evaluation of the model on the quora-duplicates-dev dataset:
Accuracy with Cosine-Similarity:           81.60    (Threshold: 0.8352)
F1 with Cosine-Similarity:                 75.27    (Threshold: 0.7715)
Precision with Cosine-Similarity:          65.81
Recall with Cosine-Similarity:             87.89
Average Precision with Cosine-Similarity:  76.03

Accuracy with Dot-Product:           81.60  (Threshold: 0.8352)
F1 with Dot-Product:                 75.27  (Threshold: 0.7715)
Precision with Dot-Product:          65.81
Recall with Dot-Product:             87.89
Average Precision with Dot-Product:  76.03

Accuracy with Manhattan-Distance:           81.50   (Threshold: 12.0727)
F1 with Manhattan-Distance:                 74.97   (Threshold: 15.2269)
Precision with Manhattan-Distance:          63.89
Recall with Manhattan-Distance:             90.68
Average Precision with Manhattan-Distance:  75.66

Accuracy with Euclidean-Distance:           81.60   (Threshold: 0.5741)
F1 with Euclidean-Distance:                 75.27   (Threshold: 0.6760)
Precision with Euclidean-Distance:          65.81
Recall with Euclidean-Distance:             87.89
Average Precision with Euclidean-Distance:  76.03
'''
print(binary_acc_evaluator.primary_metric)
# => "quora-duplicates-dev_max_ap"
print(results[binary_acc_evaluator.primary_metric])
# => 0.760277070888393

EmbeddingSimilarityEvaluator

class sentence_transformers.evaluation.EmbeddingSimilarityEvaluator(sentences1: List[str], sentences2: List[str], scores: List[float], batch_size: int = 16, main_similarity: Optional[Union[str, sentence_transformers.similarity_functions.SimilarityFunction]] = None, name: str = '', show_progress_bar: bool = False, write_csv: bool = True, precision: Optional[Literal['float32', 'int8', 'uint8', 'binary', 'ubinary']] = None, truncate_dim: Optional[int] = None)

Evaluate a model based on the similarity of the embeddings by calculating the Spearman rank correlation and the Pearson correlation against the gold-standard scores. The metrics are cosine similarity, Manhattan distance, Euclidean distance, and dot product. The returned score is the Spearman correlation with a specified metric.
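
Spearman is a rank correlation while Pearson measures linear correlation. The sketch below, in plain Python, shows both applied to gold scores and predicted cosine similarities; the helper names are hypothetical, and in practice you would more likely call scipy.stats.spearmanr and scipy.stats.pearsonr.

```python
import math
from typing import List


def pearson(x: List[float], y: List[float]) -> float:
    """Pearson (linear) correlation of two equally long lists."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    return cov / math.sqrt(var_x * var_y)


def average_ranks(values: List[float]) -> List[float]:
    """1-based ranks of `values`; tied values share the average of their ranks."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        # Extend the run of tied values.
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        for k in range(start, end + 1):
            ranks[order[k]] = (start + end) / 2 + 1
        start = end + 1
    return ranks


def spearman(x: List[float], y: List[float]) -> float:
    """Spearman rank correlation: the Pearson correlation of the ranks."""
    return pearson(average_ranks(x), average_ranks(y))


# Illustrative gold scores (0-5 scale) and predicted cosine similarities.
gold = [0.0, 1.5, 3.0, 4.5, 5.0]
predicted = [0.1, 0.3, 0.2, 0.8, 0.9]
print(round(spearman(gold, predicted), 4))  # 0.9
```

Because Spearman only looks at ranks, any monotonic rescaling of the predicted similarities leaves it unchanged, whereas Pearson would change.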

Example

from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SimilarityFunction

# Load a model
model = SentenceTransformer('all-mpnet-base-v2')

# Load the STSB dataset (https://huggingface.co/datasets/sentence-transformers/stsb)
eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")

# Initialize the evaluator
dev_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=eval_dataset["sentence1"],
    sentences2=eval_dataset["sentence2"],
    scores=eval_dataset["score"],
    main_similarity=SimilarityFunction.COSINE,
    name="sts-dev",
)
dev_evaluator(model)
'''
EmbeddingSimilarityEvaluator: Evaluating the model on the sts-dev dataset:
Cosine-Similarity :       Pearson: 0.7874 Spearman: 0.8004
Manhattan-Distance:       Pearson: 0.7823 Spearman: 0.7827
Euclidean-Distance:       Pearson: 0.7824 Spearman: 0.7827
Dot-Product-Similarity:   Pearson: 0.7192 Spearman: 0.7126
'''
# => {'sts-dev_pearson_cosine': 0.880607226102985, 'sts-dev_spearman_cosine': 0.881019449484294, ...}

Constructs an evaluator for the given dataset.

Parameters
  • sentences1 (List[str]) – List with the first sentence in a pair.

  • sentences2 (List[str]) – List with the second sentence in a pair.

  • scores (List[float]) – Similarity score between sentences1[i] and sentences2[i].

  • batch_size (int, optional) – The batch size for processing the sentences. Defaults to 16.

  • main_similarity (Optional[Union[str, SimilarityFunction]], optional) – The main similarity function to use. Can be a string (e.g. “cosine”, “dot”) or a SimilarityFunction object. Defaults to None.

  • name (str, optional) – The name of the evaluator. Defaults to “”.

  • show_progress_bar (bool, optional) – Whether to show a progress bar during evaluation. Defaults to False.

  • write_csv (bool, optional) – Whether to write the evaluation results to a CSV file. Defaults to True.

  • precision (Optional[Literal["float32", "int8", "uint8", "binary", "ubinary"]], optional) – The precision to use for the embeddings. Can be “float32”, “int8”, “uint8”, “binary”, or “ubinary”. Defaults to None.

  • truncate_dim (Optional[int], optional) – The dimension to truncate sentence embeddings to. None uses the model’s current truncation dimension. Defaults to None.

InformationRetrievalEvaluator

class sentence_transformers.evaluation.InformationRetrievalEvaluator(queries: Dict[str, str], corpus: Dict[str, str], relevant_docs: Dict[str, Set[str]], corpus_chunk_size: int = 50000, mrr_at_k: List[int] = [10], ndcg_at_k: List[int] = [10], accuracy_at_k: List[int] = [1, 3, 5, 10], precision_recall_at_k: List[int] = [1, 3, 5, 10], map_at_k: List[int] = [100], show_progress_bar: bool = False, batch_size: int = 32, name: str = '', write_csv: bool = True, truncate_dim: Optional[int] = None, score_functions: Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = {'cosine': <function cos_sim>, 'dot': <function dot_score>}, main_score_function: Optional[Union[str, sentence_transformers.similarity_functions.SimilarityFunction]] = None)

This class evaluates a model in an Information Retrieval (IR) setting.

Given a set of queries and a large corpus, it retrieves the top-k most similar documents for each query. It measures Accuracy@k, Precision@k, Recall@k, Mean Reciprocal Rank (MRR), Normalized Discounted Cumulative Gain (NDCG), and Mean Average Precision (MAP).
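
For a single query, most of these metrics can be sketched as follows, assuming binary relevance (a document either is or is not in the query's set of relevant documents) and a ranking already sorted by descending similarity. The helper ir_metrics_for_query is hypothetical; an evaluator would average such per-query values over all queries.

```python
import math
from typing import Dict, List, Set


def ir_metrics_for_query(ranked_ids: List[str], relevant: Set[str], k: int) -> Dict[str, float]:
    """Accuracy@k, Precision@k, Recall@k, MRR@k and NDCG@k for one query.

    Hypothetical helper: `ranked_ids` are corpus ids sorted by descending
    similarity to the query; relevance is binary (in `relevant` or not).
    """
    if not relevant or k < 1:
        raise ValueError("need at least one relevant document and k >= 1")
    hits = [doc_id in relevant for doc_id in ranked_ids[:k]]
    num_hits = sum(hits)

    # Reciprocal rank of the first relevant document within the top k (0 if none).
    reciprocal_rank = next((1.0 / rank for rank, hit in enumerate(hits, start=1) if hit), 0.0)

    # DCG with binary gains, normalized by the best achievable DCG at k.
    dcg = sum(1.0 / math.log2(rank + 1) for rank, hit in enumerate(hits, start=1) if hit)
    ideal_dcg = sum(1.0 / math.log2(rank + 1) for rank in range(1, min(len(relevant), k) + 1))

    return {
        "accuracy": 1.0 if num_hits else 0.0,
        "precision": num_hits / k,
        "recall": num_hits / len(relevant),
        "mrr": reciprocal_rank,
        "ndcg": dcg / ideal_dcg,
    }


# Illustrative ranking for one query with two relevant documents.
metrics = ir_metrics_for_query(["d3", "d1", "d7", "d2"], {"d1", "d2"}, k=3)
print({name: round(value, 4) for name, value in metrics.items()})
```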

Example

import random
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from datasets import load_dataset

# Load a model
model = SentenceTransformer('all-mpnet-base-v2')

# Load the Quora IR dataset (https://huggingface.co/datasets/BeIR/quora, https://huggingface.co/datasets/BeIR/quora-qrels)
corpus = load_dataset("BeIR/quora", "corpus", split="corpus")
queries = load_dataset("BeIR/quora", "queries", split="queries")
relevant_docs_data = load_dataset("BeIR/quora-qrels", split="validation")

# Shrink the corpus heavily to only the relevant documents + 10,000 random documents
# (a set keeps the membership test in the filter fast)
required_corpus_ids = set(map(str, relevant_docs_data["corpus-id"]))
required_corpus_ids |= set(random.sample(corpus["_id"], k=10_000))
corpus = corpus.filter(lambda x: x["_id"] in required_corpus_ids)

# Convert the datasets to dictionaries
corpus = dict(zip(corpus["_id"], corpus["text"]))  # Our corpus (cid => document)
queries = dict(zip(queries["_id"], queries["text"]))  # Our queries (qid => question)
relevant_docs = {}  # Query ID to relevant documents (qid => set([relevant_cids]))
for qid, corpus_id in zip(relevant_docs_data["query-id"], relevant_docs_data["corpus-id"]):
    qid = str(qid)
    corpus_id = str(corpus_id)
    if qid not in relevant_docs:
        relevant_docs[qid] = set()
    relevant_docs[qid].add(corpus_id)

# Given queries, a corpus and a mapping with relevant documents, the InformationRetrievalEvaluator computes different IR metrics.
ir_evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    name="BeIR-quora-dev",
)
results = ir_evaluator(model)
'''
Information Retrieval Evaluation of the model on the BeIR-quora-dev dataset:
Queries: 5000
Corpus: 17476

Score-Function: cosine
Accuracy@1: 96.26%
Accuracy@3: 99.38%
Accuracy@5: 99.74%
Accuracy@10: 99.94%
Precision@1: 96.26%
Precision@3: 43.01%
Precision@5: 27.66%
Precision@10: 14.58%
Recall@1: 82.93%
Recall@3: 96.28%
Recall@5: 98.38%
Recall@10: 99.55%
MRR@10: 0.9782
NDCG@10: 0.9807
MAP@100: 0.9732
Score-Function: dot
Accuracy@1: 96.26%
Accuracy@3: 99.38%
Accuracy@5: 99.74%
Accuracy@10: 99.94%
Precision@1: 96.26%
Precision@3: 43.01%
Precision@5: 27.66%
Precision@10: 14.58%
Recall@1: 82.93%
Recall@3: 96.28%
Recall@5: 98.38%
Recall@10: 99.55%
MRR@10: 0.9782
NDCG@10: 0.9807
MAP@100: 0.9732
'''
print(ir_evaluator.primary_metric)
# => "BeIR-quora-dev_cosine_map@100"
print(results[ir_evaluator.primary_metric])
# => 0.9732046108457585

Initializes the InformationRetrievalEvaluator.

Parameters
  • queries (Dict[str, str]) – A dictionary mapping query IDs to queries.

  • corpus (Dict[str, str]) – A dictionary mapping document IDs to documents.

  • relevant_docs (Dict[str, Set[str]]) – A dictionary mapping query IDs to a set of relevant document IDs.

  • corpus_chunk_size (int) – The number of corpus documents that are embedded and scored against the queries at a time, to limit memory usage. Defaults to 50000.

  • mrr_at_k (List[int]) – A list of integers representing the values of k for MRR calculation. Defaults to [10].

  • ndcg_at_k (List[int]) – A list of integers representing the values of k for NDCG calculation. Defaults to [10].

  • accuracy_at_k (List[int]) – A list of integers representing the values of k for accuracy calculation. Defaults to [1, 3, 5, 10].

  • precision_recall_at_k (List[int]) – A list of integers representing the values of k for precision and recall calculation. Defaults to [1, 3, 5, 10].

  • map_at_k (List[int]) – A list of integers representing the values of k for MAP calculation. Defaults to [100].

  • show_progress_bar (bool) – Whether to show a progress bar during evaluation. Defaults to False.

  • batch_size (int) – The batch size for evaluation. Defaults to 32.

  • name (str) – A name for the evaluation. Defaults to “”.

  • write_csv (bool) – Whether to write the evaluation results to a CSV file. Defaults to True.

  • truncate_dim (int, optional) – The dimension to truncate the embeddings to. Defaults to None.

  • score_functions (Dict[str, Callable[[Tensor, Tensor], Tensor]]) – A dictionary mapping score function names to score functions. Defaults to {SimilarityFunction.COSINE.value: cos_sim, SimilarityFunction.DOT_PRODUCT.value: dot_score}.

  • main_score_function (Union[str, SimilarityFunction], optional) – The main score function to use for evaluation. Defaults to None.

MSEEvaluator

class sentence_transformers.evaluation.MSEEvaluator(source_sentences: List[str], target_sentences: List[str], teacher_model=None, show_progress_bar: bool = False, batch_size: int = 32, name: str = '', write_csv: bool = True, truncate_dim: Optional[int] = None)

Computes the mean squared error (x100) between the computed sentence embeddings and some target sentence embeddings.

The MSE is computed between teacher.encode(source_sentences) and student.encode(target_sentences).

For multilingual knowledge distillation (https://arxiv.org/abs/2004.09813), source_sentences are in English and target_sentences are in a different language, such as German, Chinese, or Spanish.
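
As a sketch, assuming the MSE is averaged over every dimension of every embedding pair, the function below computes the scaled MSE from precomputed embeddings. The helper mse_times_100 is hypothetical, not the library's code. In the example, the primary metric is the negated value (negative_mse), so that a higher score is better.

```python
from typing import Sequence


def mse_times_100(
    teacher_embeddings: Sequence[Sequence[float]],
    student_embeddings: Sequence[Sequence[float]],
) -> float:
    """Mean squared error over all embedding dimensions, scaled by 100.

    Hypothetical helper: row i of both inputs embeds the i-th sentence
    (teacher: source sentence, student: target sentence).
    """
    if not teacher_embeddings or len(teacher_embeddings) != len(student_embeddings):
        raise ValueError("need the same, non-zero number of embeddings")
    total, count = 0.0, 0
    for teacher_vec, student_vec in zip(teacher_embeddings, student_embeddings):
        if len(teacher_vec) != len(student_vec):
            raise ValueError("embedding dimensions differ")
        for t, s in zip(teacher_vec, student_vec):
            total += (t - s) ** 2
            count += 1
    return total / count * 100


# Illustrative 2-dimensional embeddings for two sentences.
mse = mse_times_100([[1.0, 0.0], [0.0, 1.0]], [[0.9, 0.1], [0.0, 0.8]])
print(f"MSE (*100): {mse:.4f}")  # MSE (*100): 1.5000
print(f"negative_mse: {-mse:.4f}")  # negative_mse: -1.5000
```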

Parameters
  • source_sentences (List[str]) – Source sentences to embed with the teacher model.

  • target_sentences (List[str]) – Target sentences to embed with the student model.

  • teacher_model (SentenceTransformer, optional) – The teacher model to compute the source sentence embeddings. Defaults to None.

  • show_progress_bar (bool, optional) – Show progress bar when computing embeddings. Defaults to False.

  • batch_size (int, optional) – Batch size to compute sentence embeddings. Defaults to 32.

  • name (str, optional) – Name of the evaluator. Defaults to “”.

  • write_csv (bool, optional) – Write results to CSV file. Defaults to True.

  • truncate_dim (int, optional) – The dimension to truncate sentence embeddings to. None uses the model’s current truncation dimension. Defaults to None.

Example

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import MSEEvaluator
from datasets import load_dataset

# Load the student and teacher models
student_model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
teacher_model = SentenceTransformer('all-mpnet-base-v2')

# Load any dataset with some texts
dataset = load_dataset("sentence-transformers/stsb", split="validation")
sentences = dataset["sentence1"] + dataset["sentence2"]

# The MSEEvaluator compares the teacher's embeddings of source_sentences with the student's embeddings of target_sentences
mse_evaluator = MSEEvaluator(
    source_sentences=sentences,
    target_sentences=sentences,
    teacher_model=teacher_model,
    name="stsb-dev",
)
results = mse_evaluator(student_model)
'''
MSE evaluation (lower = better) on the stsb-dev dataset:
MSE (*100):  0.805045
'''
print(mse_evaluator.primary_metric)
# => "stsb-dev_negative_mse"
print(results[mse_evaluator.primary_metric])
# => -0.8050452917814255

ParaphraseMiningEvaluator

class sentence_transformers.evaluation.ParaphraseMiningEvaluator(sentences_map: Dict[str, str], duplicates_list: Optional[List[Tuple[str, str]]] = None, duplicates_dict: Optional[Dict[str, Dict[str, bool]]] = None, add_transitive_closure: bool = False, query_chunk_size: int = 5000, corpus_chunk_size: int = 100000, max_pairs: int = 500000, top_k: int = 100, show_progress_bar: bool = False, batch_size: int = 16, name: str = '', write_csv: bool = True, truncate_dim: Optional[int] = None)

Given a large set of sentences, this evaluator performs paraphrase (duplicate) mining and identifies the pairs with the highest similarity. It compares the extracted paraphrase pairs with a set of gold labels and computes precision, recall, F1, and average precision.
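
The comparison against gold labels can be sketched as follows: rank all mined candidate pairs by similarity, walk down the list, and track precision, recall, and F1 at each correctly mined pair; average precision averages the precision at each hit over the number of gold pairs. This is a hypothetical pure-Python illustration (mining_scores is not a library function), assuming unique, unordered candidate pairs; the library's exact threshold convention may differ.

```python
from typing import Dict, FrozenSet, List, Set, Tuple


def mining_scores(
    candidates: List[Tuple[float, str, str]], gold_pairs: Set[FrozenSet[str]]
) -> Dict[str, float]:
    """Average precision and best-F1 operating point for mined (score, id1, id2) pairs.

    Hypothetical helper: candidate pairs are assumed unique and unordered; a pair
    counts as accepted when its score is >= the returned threshold.
    """
    if not gold_pairs:
        raise ValueError("need at least one gold duplicate pair")
    ranked = sorted(candidates, key=lambda c: c[0], reverse=True)
    true_positives, precision_sum = 0, 0.0
    best = {"f1": 0.0, "precision": 0.0, "recall": 0.0, "threshold": 0.0}

    for rank, (score, id1, id2) in enumerate(ranked, start=1):
        if frozenset((id1, id2)) not in gold_pairs:
            continue
        true_positives += 1
        precision = true_positives / rank
        recall = true_positives / len(gold_pairs)
        precision_sum += precision  # precision at each correctly mined pair
        f1 = 2 * precision * recall / (precision + recall)
        if f1 > best["f1"]:
            best = {"f1": f1, "precision": precision, "recall": recall, "threshold": score}

    # Gold pairs that were never mined contribute zero precision.
    return {"average_precision": precision_sum / len(gold_pairs), **best}


# Illustrative gold duplicates and mined candidates (ids in any order).
gold = {frozenset(("q1", "q2")), frozenset(("q1", "q5")), frozenset(("q7", "q8"))}
mined = [(0.95, "q1", "q2"), (0.90, "q3", "q4"), (0.85, "q5", "q1"), (0.60, "q2", "q6")]
print(mining_scores(mined, gold))
```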

Example

from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import ParaphraseMiningEvaluator

# Load a model
model = SentenceTransformer('all-mpnet-base-v2')

# Load the Quora Duplicates Mining dataset
questions_dataset = load_dataset("sentence-transformers/quora-duplicates-mining", "questions", split="dev")
duplicates_dataset = load_dataset("sentence-transformers/quora-duplicates-mining", "duplicates", split="dev")

# Create a mapping from qid to question & a list of duplicates (qid1, qid2)
qid_to_questions = dict(zip(questions_dataset["qid"], questions_dataset["question"]))
duplicates = list(zip(duplicates_dataset["qid1"], duplicates_dataset["qid2"]))

# Initialize the paraphrase mining evaluator
paraphrase_mining_evaluator = ParaphraseMiningEvaluator(
    sentences_map=qid_to_questions,
    duplicates_list=duplicates,
    name="quora-duplicates-dev",
)
results = paraphrase_mining_evaluator(model)
'''
Paraphrase Mining Evaluation of the model on the quora-duplicates-dev dataset:
Number of candidate pairs: 250564
Average Precision: 56.51
Optimal threshold: 0.8325
Precision: 52.76
Recall: 59.19
F1: 55.79
'''
print(paraphrase_mining_evaluator.primary_metric)
# => "quora-duplicates-dev_average_precision"
print(results[paraphrase_mining_evaluator.primary_metric])
# => 0.5650940787776353

Initializes the ParaphraseMiningEvaluator.

Parameters
  • sentences_map (Dict[str, str]) – A dictionary that maps sentence-ids to sentences. For example, sentences_map[id] => sentence.

  • duplicates_list (List[Tuple[str, str]], optional) – A list with id pairs [(id1, id2), (id1, id5)] that identifies the duplicates / paraphrases in the sentences_map. Defaults to None.

  • duplicates_dict (Dict[str, Dict[str, bool]], optional) – A default dictionary mapping [id1][id2] to True if id1 and id2 are duplicates. Must be symmetric, i.e., if [id1][id2] => True, then [id2][id1] => True. Defaults to None.

  • add_transitive_closure (bool, optional) – If True, adds the transitive closure, i.e. if dup[a][b] and dup[b][c], then dup[a][c]. Defaults to False.

  • query_chunk_size (int, optional) – To identify the paraphrases, the cosine similarity between all sentence pairs is computed. As this might require a lot of memory, we perform a batched computation: query_chunk_size sentences are compared against up to corpus_chunk_size sentences at a time. In the default setting, 5000 sentences are grouped together and compared against up to 100k other sentences. Defaults to 5000.

  • corpus_chunk_size (int, optional) – The corpus is batched to reduce the memory requirement. Defaults to 100000.

  • max_pairs (int, optional) – We will only extract up to max_pairs potential paraphrase candidates. Defaults to 500000.

  • top_k (int, optional) – For each query, we extract the top_k most similar pairs and add them to a sorted list, i.e., no more than top_k paraphrases can be found for a single sentence. Defaults to 100.

  • show_progress_bar (bool, optional) – Output a progress bar. Defaults to False.

  • batch_size (int, optional) – Batch size for computing sentence embeddings. Defaults to 16.

  • name (str, optional) – Name of the experiment. Defaults to “”.

  • write_csv (bool, optional) – Write results to CSV file. Defaults to True.

  • truncate_dim (Optional[int], optional) – The dimension to truncate sentence embeddings to. None uses the model’s current truncation dimension. Defaults to None.
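
The add_transitive_closure option can be understood as grouping ids into connected components of the duplicate graph: any two ids in the same component become duplicates. The sketch below uses a union-find structure; the function transitive_closure is hypothetical and the library may compute the closure differently. Its result is symmetric, matching the requirement stated for duplicates_dict.

```python
from collections import defaultdict
from typing import Dict, Iterable, Set, Tuple


def transitive_closure(pairs: Iterable[Tuple[str, str]]) -> Dict[str, Set[str]]:
    """Map every id to all ids it is (transitively) a duplicate of.

    Hypothetical helper using union-find: ids connected by a chain of
    duplicate pairs end up in the same group.
    """
    parent: Dict[str, str] = {}

    def find(node: str) -> str:
        parent.setdefault(node, node)
        while parent[node] != node:
            parent[node] = parent[parent[node]]  # path halving keeps trees shallow
            node = parent[node]
        return node

    for a, b in pairs:
        root_a, root_b = find(a), find(b)
        if root_a != root_b:
            parent[root_a] = root_b  # merge the two groups

    groups: Dict[str, Set[str]] = defaultdict(set)
    for node in list(parent):
        groups[find(node)].add(node)

    # Symmetric result: b in closure[a] if and only if a in closure[b].
    return {node: members - {node} for members in groups.values() for node in members}


closure = transitive_closure([("a", "b"), ("b", "c"), ("x", "y")])
print(sorted(closure["a"]))  # ['b', 'c']
```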

RerankingEvaluator

class sentence_transformers.evaluation.RerankingEvaluator(samples, at_k: int = 10, name: str = '', write_csv: bool = True, similarity_fct: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = <function cos_sim>, batch_size: int = 64, show_progress_bar: bool = False, use_batched_encoding: bool = True, truncate_dim: Optional[int] = None, mrr_at_k: Optional[int] = None)

This class evaluates a SentenceTransformer model for the task of re-ranking.

Given a query and a list of documents, it computes the score [query, doc_i] for all possible documents and sorts them in decreasing order. Then, MRR@k and NDCG@k (with k = at_k, 10 by default) and MAP are computed to measure the quality of the ranking.
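
For a single sample, the MRR and MAP parts of this procedure can be sketched as below (NDCG is omitted for brevity). The helper rerank_metrics is hypothetical; it assumes the query-document similarities are already computed, and MAP would then be the mean of the per-sample average precision values.

```python
from typing import Dict, List


def rerank_metrics(
    positive_scores: List[float], negative_scores: List[float], at_k: int = 10
) -> Dict[str, float]:
    """MRR@at_k and average precision for one (query, positives, negatives) sample.

    Hypothetical helper: scores are similarities between the query and each document.
    Ties keep positives first (stable sort), which is an optimistic choice.
    """
    if not positive_scores:
        raise ValueError("need at least one positive document")
    ranked = sorted(
        [(score, True) for score in positive_scores] + [(score, False) for score in negative_scores],
        key=lambda item: item[0],
        reverse=True,
    )
    relevance = [is_positive for _, is_positive in ranked]

    # Reciprocal rank of the first positive within the top at_k documents.
    mrr = next((1.0 / rank for rank, rel in enumerate(relevance[:at_k], start=1) if rel), 0.0)

    # Average precision over the full ranking.
    hits, precision_sum = 0, 0.0
    for rank, rel in enumerate(relevance, start=1):
        if rel:
            hits += 1
            precision_sum += hits / rank
    return {"mrr": mrr, "ap": precision_sum / len(positive_scores)}


# Illustrative query-document similarities for one sample.
print(rerank_metrics(positive_scores=[0.7, 0.4], negative_scores=[0.9, 0.5, 0.1]))
# {'mrr': 0.5, 'ap': 0.5}
```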

Parameters
  • samples (list) – A list of dictionaries, where each dictionary represents a sample and has the following keys: ‘query’ (the search query), ‘positive’ (a list of positive, i.e. relevant, documents), and ‘negative’ (a list of negative, i.e. irrelevant, documents).

  • at_k (int, optional) – Only consider the top k most similar documents to each query for the evaluation. Defaults to 10.

  • name (str, optional) – Name of the evaluator. Defaults to “”.

  • write_csv (bool, optional) – Write results to CSV file. Defaults to True.

  • similarity_fct (Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optional) – Similarity function between sentence embeddings. Defaults to cos_sim (cosine similarity).

  • batch_size (int, optional) – Batch size to compute sentence embeddings. Defaults to 64.

  • show_progress_bar (bool, optional) – Show progress bar when computing embeddings. Defaults to False.

  • use_batched_encoding (bool, optional) – Whether or not to encode queries and documents in batches for greater speed, or 1-by-1 to save memory. Defaults to True.

  • truncate_dim (Optional[int], optional) – The dimension to truncate sentence embeddings to. None uses the model’s current truncation dimension. Defaults to None.

  • mrr_at_k (Optional[int], optional) – Deprecated parameter. Please use at_k instead. Defaults to None.

SentenceEvaluator

class sentence_transformers.evaluation.SentenceEvaluator

Base class for all evaluators.

Extend this class and implement __call__ for custom evaluators.

Notably, this class introduces the greater_is_better and primary_metric attributes. The former is a boolean indicating whether a higher evaluation score is better, which is used for choosing the best checkpoint if load_best_model_at_end is set to True in the training arguments.

The latter is a string naming the primary metric of the evaluator. It must be defined whenever the evaluator returns a dictionary of metrics, and it is the key of the metric that is used for model selection and/or logging.

SequentialEvaluator

class sentence_transformers.evaluation.SequentialEvaluator(evaluators: Iterable[sentence_transformers.evaluation.SentenceEvaluator.SentenceEvaluator], main_score_function=<function SequentialEvaluator.<lambda>>)

This evaluator allows multiple sub-evaluators to be passed. When the model is evaluated, the data is passed sequentially to all sub-evaluators.

All scores are passed to main_score_function, which derives one final score value.
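
To illustrate the role of main_score_function, the sketch below mimics the combination step in plain Python, assuming the function receives one score per sub-evaluator in evaluation order. last_score mirrors the documented default; mean_score and combine are hypothetical and not part of the library.

```python
from statistics import mean
from typing import Callable, List


def last_score(scores: List[float]) -> float:
    """Mirrors the documented default: keep the last sub-evaluator's score."""
    return scores[-1]


def mean_score(scores: List[float]) -> float:
    """Hypothetical alternative: average all sub-evaluator scores."""
    return mean(scores)


def combine(
    scores: List[float], main_score_function: Callable[[List[float]], float] = last_score
) -> float:
    """Hypothetical stand-in for the step that derives one final score."""
    return main_score_function(scores)


# Illustrative scores, one per sub-evaluator, in evaluation order.
sub_scores = [0.5, 0.75, 1.0]
print(combine(sub_scores))  # 1.0
print(combine(sub_scores, mean_score))  # 0.75
```

Averaging is only meaningful when the sub-evaluators report scores on comparable scales; for example, a negated MSE and an accuracy are not directly comparable.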

Initializes a SequentialEvaluator object.

Parameters
  • evaluators (Iterable[SentenceEvaluator]) – A collection of SentenceEvaluator objects.

  • main_score_function (function, optional) – A function that takes a list of scores and returns the main score. Defaults to selecting the last score in the list.

Example

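from sentence_transformers.evaluation import BinaryClassificationEvaluator, InformationRetrievalEvaluator, MSEEvaluator, SequentialEvaluator

evaluator1 = BinaryClassificationEvaluator(...)
evaluator2 = InformationRetrievalEvaluator(...)
evaluator3 = MSEEvaluator(...)
seq_evaluator = SequentialEvaluator([evaluator1, evaluator2, evaluator3])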

TranslationEvaluator

class sentence_transformers.evaluation.TranslationEvaluator(source_sentences: List[str], target_sentences: List[str], show_progress_bar: bool = False, batch_size: int = 16, name: str = '', print_wrong_matches: bool = False, write_csv: bool = True, truncate_dim: Optional[int] = None)

Given two sets of sentences in different languages, e.g. (en_1, en_2, en_3, …) and (fr_1, fr_2, fr_3, …), where fr_i is assumed to be the translation of en_i, this evaluator checks whether vec(en_i) has the highest similarity to vec(fr_i). It computes the accuracy in both directions.
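
The bidirectional check can be sketched in plain Python as follows, assuming cosine similarity and precomputed embeddings where row i of both lists forms a translation pair. translation_accuracy is a hypothetical helper; the mean of both directions corresponds to the mean_accuracy value reported in the example.

```python
import math
from typing import List, Sequence, Tuple


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def translation_accuracy(
    source_embeddings: List[Sequence[float]], target_embeddings: List[Sequence[float]]
) -> Tuple[float, float, float]:
    """Return (src2trg, trg2src, mean) accuracy; row i of each list is a translation pair.

    Hypothetical helper assuming cosine similarity.
    """
    n = len(source_embeddings)
    if n == 0 or n != len(target_embeddings):
        raise ValueError("need the same, non-zero number of source and target embeddings")
    sims = [[cosine(s, t) for t in target_embeddings] for s in source_embeddings]
    # Source i is correct if its most similar target is target i.
    src2trg = sum(max(range(n), key=lambda col: sims[i][col]) == i for i in range(n)) / n
    # Target j is correct if its most similar source is source j.
    trg2src = sum(max(range(n), key=lambda row: sims[row][j]) == j for j in range(n)) / n
    return src2trg, trg2src, (src2trg + trg2src) / 2


# Illustrative 2-dimensional embeddings for three translation pairs.
src = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
trg = [[0.9, 0.1], [0.2, 0.8], [0.1, 0.9]]
print([round(x, 4) for x in translation_accuracy(src, trg)])  # [0.3333, 0.6667, 0.5]
```

The two directions can differ because the most similar target for a source is not necessarily the one for which that source is the most similar.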

Example

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import TranslationEvaluator
from datasets import load_dataset

# Load a model
model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')

# Load a parallel sentences dataset
dataset = load_dataset("sentence-transformers/parallel-sentences-news-commentary", "en-nl", split="train[:1000]")

# Initialize the TranslationEvaluator using the same texts from two languages
translation_evaluator = TranslationEvaluator(
    source_sentences=dataset["english"],
    target_sentences=dataset["non_english"],
    name="news-commentary-en-nl",
)
results = translation_evaluator(model)
'''
Evaluating translation matching Accuracy of the model on the news-commentary-en-nl dataset:
Accuracy src2trg: 90.80
Accuracy trg2src: 90.40
'''
print(translation_evaluator.primary_metric)
# => "news-commentary-en-nl_mean_accuracy"
print(results[translation_evaluator.primary_metric])
# => 0.906

Constructs an evaluator for the given dataset.

Parameters
  • source_sentences (List[str]) – List of sentences in the source language.

  • target_sentences (List[str]) – List of sentences in the target language.

  • show_progress_bar (bool) – Whether to show a progress bar when computing embeddings. Defaults to False.

  • batch_size (int) – The batch size to compute sentence embeddings. Defaults to 16.

  • name (str) – The name of the evaluator. Defaults to an empty string.

  • print_wrong_matches (bool) – Whether to print incorrect matches. Defaults to False.

  • write_csv (bool) – Whether to write the evaluation results to a CSV file. Defaults to True.

  • truncate_dim (int, optional) – The dimension to truncate sentence embeddings to. If None, the model’s current truncation dimension will be used. Defaults to None.

TripletEvaluator

class sentence_transformers.evaluation.TripletEvaluator(anchors: List[str], positives: List[str], negatives: List[str], main_distance_function: Optional[Union[str, sentence_transformers.similarity_functions.SimilarityFunction]] = None, name: str = '', batch_size: int = 16, show_progress_bar: bool = False, write_csv: bool = True, truncate_dim: Optional[int] = None)

Evaluate a model based on a triplet: (sentence, positive_example, negative_example). Checks if distance(sentence, positive_example) < distance(sentence, negative_example).
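
A plain-Python sketch of this check with cosine distance (1 minus cosine similarity) on precomputed embeddings; triplet_accuracy is a hypothetical helper, not the library's code. For a similarity rather than a distance, such as the dot product, the condition flips to sim(sentence, positive_example) > sim(sentence, negative_example).

```python
import math
from typing import List, Sequence


def cosine_distance(u: Sequence[float], v: Sequence[float]) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    norms = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / norms


def triplet_accuracy(
    anchors: List[Sequence[float]],
    positives: List[Sequence[float]],
    negatives: List[Sequence[float]],
) -> float:
    """Fraction of triplets where the anchor is closer to the positive than to the negative."""
    if not anchors or not (len(anchors) == len(positives) == len(negatives)):
        raise ValueError("need equally many, non-zero anchors, positives and negatives")
    correct = sum(
        cosine_distance(a, p) < cosine_distance(a, n)
        for a, p, n in zip(anchors, positives, negatives)
    )
    return correct / len(anchors)


# Illustrative 2-dimensional embeddings for two triplets.
anchors = [[1.0, 0.0], [0.0, 1.0]]
positives = [[0.9, 0.1], [1.0, 0.0]]
negatives = [[0.0, 1.0], [0.0, 1.0]]
print(triplet_accuracy(anchors, positives, negatives))  # 0.5
```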

Example

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import TripletEvaluator
from datasets import load_dataset

# Load a model
model = SentenceTransformer('all-mpnet-base-v2')

# Load a dataset with (anchor, positive, negative) triplets
dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")

# Initialize the TripletEvaluator using anchors, positives, and negatives
triplet_evaluator = TripletEvaluator(
    anchors=dataset[:1000]["anchor"],
    positives=dataset[:1000]["positive"],
    negatives=dataset[:1000]["negative"],
    name="all-nli-dev",
)
results = triplet_evaluator(model)
'''
TripletEvaluator: Evaluating the model on the all-nli-dev dataset:
Accuracy Cosine Distance:        95.60
Accuracy Dot Product:            4.40
Accuracy Manhattan Distance:     95.40
Accuracy Euclidean Distance:     95.60
'''
print(triplet_evaluator.primary_metric)
# => "all-nli-dev_max_accuracy"
print(results[triplet_evaluator.primary_metric])
# => 0.956

Initializes a TripletEvaluator object.

Parameters
  • anchors (List[str]) – Sentences to check similarity to (e.g. a query).

  • positives (List[str]) – List of positive sentences.

  • negatives (List[str]) – List of negative sentences.

  • main_distance_function (Union[str, SimilarityFunction], optional) – The distance function to use. If not specified, cosine similarity, dot product, Euclidean distance, and Manhattan distance are all used. Defaults to None.

  • name (str) – Name for the output. Defaults to “”.

  • batch_size (int) – Batch size used to compute embeddings. Defaults to 16.

  • show_progress_bar (bool) – If true, prints a progress bar. Defaults to False.

  • write_csv (bool) – Write results to a CSV file. Defaults to True.

  • truncate_dim (int, optional) – The dimension to truncate sentence embeddings to. None uses the model’s current truncation dimension. Defaults to None.