util

sentence_transformers.util defines various helper functions for working with text embeddings.

Helper Functions

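sentence_transformers.util.community_detection(embeddings: torch.Tensor | np.ndarray, threshold: float = 0.75, min_community_size: int = 10, batch_size: int = 1024, show_progress_bar: bool = False) → list[list[int]]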

Function for Fast Community Detection.

Finds all communities in the embeddings, i.e. groups of embeddings whose cosine similarity to a central embedding is at least threshold. Only communities with at least min_community_size members are returned, sorted by size in decreasing order, and overlapping communities are pruned so that each embedding belongs to at most one community. The first element in each list is the central point of the community.

Parameters
  • embeddings (torch.Tensor or numpy.ndarray) – The input embeddings.

  • threshold (float) – The minimum cosine similarity for two embeddings to be considered close. Defaults to 0.75.

  • min_community_size (int) – The minimum size of a community to be considered. Defaults to 10.

  • batch_size (int) – The batch size for computing cosine similarity scores. Defaults to 1024.

  • show_progress_bar (bool) – Whether to show a progress bar during computation. Defaults to False.

Returns

A list of communities, where each community is represented as a list of indices.

Return type

List[List[int]]
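To make the procedure above concrete, here is a minimal, unbatched sketch in plain PyTorch. It assumes a 2-D float tensor, cosine similarity, and a greedy overlap-removal step; the helper name simple_community_detection is hypothetical, and the library's own implementation (which batches the similarity computation) may differ in details.

```python
from __future__ import annotations

import torch
import torch.nn.functional as F


def simple_community_detection(
    embeddings: torch.Tensor, threshold: float = 0.75, min_community_size: int = 10
) -> list[list[int]]:
    """Hypothetical, unbatched sketch of threshold-based community detection."""
    # Unit-normalize rows so that a dot product equals cosine similarity
    emb = F.normalize(embeddings, p=2, dim=1)
    scores = emb @ emb.T  # (n, n) cosine similarity matrix

    candidates = []
    for i in range(scores.size(0)):
        # All points close enough to i (i itself included, with self-similarity ~1.0)
        members = torch.nonzero(scores[i] >= threshold).flatten()
        if len(members) >= min_community_size:
            # Order members by similarity to i, so i normally comes first as the central point
            order = torch.argsort(scores[i, members], descending=True)
            candidates.append(members[order].tolist())

    # Consider the largest candidate communities first
    candidates.sort(key=len, reverse=True)

    # Greedily keep communities, removing members already assigned to a larger one
    communities: list[list[int]] = []
    assigned: set[int] = set()
    for community in candidates:
        remaining = [idx for idx in community if idx not in assigned]
        if len(remaining) >= min_community_size:
            communities.append(remaining)
            assigned.update(remaining)
    return communities


torch.manual_seed(0)
# Two tight groups of 2-D points: five near (1, 0) and three near (0, 1)
emb = torch.cat([
    torch.tensor([1.0, 0.0]) + 0.01 * torch.randn(5, 2),
    torch.tensor([0.0, 1.0]) + 0.01 * torch.randn(3, 2),
])
communities = simple_community_detection(emb, threshold=0.9, min_community_size=3)
print([sorted(c) for c in communities])  # [[0, 1, 2, 3, 4], [5, 6, 7]]
```

The full n × n score matrix is what batch_size avoids materializing at once in the real function; this sketch trades memory for readability.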

sentence_transformers.util.http_get(url: str, path: str) → None

Downloads a URL to a given path on disk.

Parameters
  • url (str) – The URL to download.

  • path (str) – The path to save the downloaded file.

Raises

requests.HTTPError – If the HTTP request returns a non-200 status code.

Returns

None

sentence_transformers.util.is_training_available() → bool

Returns True if the dependencies required for training Sentence Transformers models are installed, i.e. the Hugging Face datasets and accelerate libraries.
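A rough equivalent of this check can be sketched with the standard library. This is an assumption for illustration only: the helper name training_dependencies_installed is hypothetical, and it merely checks that the datasets and accelerate packages can be located, not their versions.

```python
import importlib.util


def training_dependencies_installed() -> bool:
    """Hypothetical helper: True if both `datasets` and `accelerate` can be found.

    Only checks that the packages are locatable by name; versions are not verified.
    """
    return all(importlib.util.find_spec(pkg) is not None for pkg in ("datasets", "accelerate"))


if not training_dependencies_installed():
    print("The datasets and accelerate packages are required for training.")
```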

sentence_transformers.util.mine_hard_negatives(dataset: Dataset, model: SentenceTransformer, anchor_column_name: str | None = None, positive_column_name: str | None = None, corpus: list[str] | None = None, cross_encoder: CrossEncoder | None = None, range_min: int = 0, range_max: int | None = None, max_score: float | None = None, min_score: float | None = None, margin: float | None = None, num_negatives: int = 3, sampling_strategy: Literal['random', 'top'] = 'top', as_triplets: bool = True, batch_size: int = 32, faiss_batch_size: int = 16384, use_faiss: bool = False, verbose: bool = True) → Dataset

Add hard negatives to a dataset of (anchor, positive) pairs to create (anchor, positive, negative) triplets or (anchor, positive, negative_1, …, negative_n) tuples.

Hard negative mining is a technique to improve the quality of a dataset by adding hard negatives: texts that appear similar to the anchor but are not actually a correct match for it. Using hard negatives can improve the performance of models trained on the dataset.

This function uses a SentenceTransformer model to embed the sentences in the dataset, and then finds the closest matches to each anchor sentence in the dataset. It then samples negatives from the closest matches, optionally using a CrossEncoder model to rescore the candidates.

You can influence the candidate negative selection in various ways:

  • range_min: Minimum rank of the closest matches to consider as negatives. Useful for skipping the most similar texts, which may actually be unlabeled positives rather than true negatives.

  • range_max: Maximum rank of the closest matches to consider as negatives. Useful for limiting the number of candidates to sample negatives from. A lower value makes processing faster, but may leave fewer candidate negatives that satisfy the margin or max_score conditions.

  • max_score: Maximum score to consider as a negative: useful to skip candidates that are too similar to the anchor.

  • min_score: Minimum score to consider as a negative: useful to skip candidates that are too dissimilar to the anchor.

  • margin: Margin for hard negative mining. Useful for skipping candidate negatives whose similarity to the anchor is within a certain margin of the positive pair's similarity. A value of 0 can be used to enforce that the negative is always further away from the anchor than the positive.

  • sampling_strategy: Sampling strategy for negatives: “top” or “random”. “top” always takes the top num_negatives candidates as negatives, while “random” samples num_negatives negatives at random from the candidates that satisfy the margin or max_score conditions.
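How these options interact can be sketched for a single anchor as follows. This is a simplified illustration, not the library's implementation: the helper select_negatives is hypothetical, it assumes candidate_scores is already sorted by decreasing similarity to the anchor (so list index equals rank) with the positive itself removed, it treats ranks as the half-open range [range_min, range_max), and it ignores cross-encoder rescoring.

```python
from __future__ import annotations

import random


def select_negatives(
    candidate_scores: list[float],
    positive_score: float,
    range_min: int = 0,
    range_max: int | None = None,
    max_score: float | None = None,
    min_score: float | None = None,
    margin: float | None = None,
    num_negatives: int = 3,
    sampling_strategy: str = "top",
) -> list[int]:
    """Hypothetical helper: return the ranks of the selected negatives for one anchor."""
    # Assumed semantics: consider ranks in [range_min, range_max)
    end = len(candidate_scores) if range_max is None else min(range_max, len(candidate_scores))
    eligible = []
    for rank in range(range_min, end):
        score = candidate_scores[rank]
        if max_score is not None and score > max_score:
            continue  # too similar to the anchor
        if min_score is not None and score < min_score:
            continue  # too dissimilar to the anchor
        if margin is not None and score + margin > positive_score:
            continue  # not far enough below the positive pair's score
        eligible.append(rank)

    if sampling_strategy == "top":
        return eligible[:num_negatives]
    # "random": sample without replacement among all eligible candidates
    return random.sample(eligible, min(num_negatives, len(eligible)))


scores = [0.95, 0.85, 0.7, 0.6, 0.5, 0.3]
# 0.95 fails max_score; 0.85 fails the margin because 0.85 + 0.1 > 0.82
print(select_negatives(scores, positive_score=0.82, max_score=0.9, margin=0.1, num_negatives=2))  # [2, 3]
```

Note how range_min and the margin prune the same "too close" region from two directions: one by rank, the other by score relative to the positive.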

Example

>>> from sentence_transformers.util import mine_hard_negatives
>>> from sentence_transformers import SentenceTransformer
>>> from datasets import load_dataset
>>> # Load a Sentence Transformer model
>>> model = SentenceTransformer("all-MiniLM-L6-v2")
>>>
>>> # Load a dataset to mine hard negatives from
>>> dataset = load_dataset("sentence-transformers/natural-questions", split="train")
>>> dataset
Dataset({
    features: ['query', 'answer'],
    num_rows: 100231
})
>>> dataset = mine_hard_negatives(
...     dataset=dataset,
...     model=model,
...     range_min=10,
...     range_max=50,
...     max_score=0.8,
...     margin=0.1,
...     num_negatives=5,
...     sampling_strategy="random",
...     batch_size=128,
...     use_faiss=True,
... )
Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 784/784 [00:43<00:00, 17.83it/s]
Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 784/784 [00:07<00:00, 99.60it/s]
Querying FAISS index: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 784/784 [00:00<00:00, 884.99it/s]
Metric       Positive       Negative     Difference
Count         100,231        431,255        431,255
Mean           0.6866         0.4289         0.2804
Median         0.7010         0.4193         0.2740
Std            0.1125         0.0754         0.0999
Min            0.0303         0.1720         0.1001
25%            0.6221         0.3747         0.1991
50%            0.7010         0.4193         0.2740
75%            0.7667         0.4751         0.3530
Max            0.9584         0.7743         0.7003
Skipped 1289492 potential negatives (25.23%) due to the margin of 0.1.
Skipped 39 potential negatives (0.00%) due to the maximum score of 0.8.
Could not find enough negatives for 69900 samples (13.95%). Consider adjusting the range_max, range_min, margin and max_score parameters if you'd like to find more valid negatives.
>>> # Note: The minimum similarity difference is 0.1001 due to our margin of 0.1
>>> dataset
Dataset({
    features: ['query', 'answer', 'negative'],
    num_rows: 431255
})
>>> dataset[0]
{
    'query': 'when did richmond last play in a preliminary final',
    'answer': "Richmond Football Club Richmond began 2017 with 5 straight wins, a feat it had not achieved since 1995. A series of close losses hampered the Tigers throughout the middle of the season, including a 5-point loss to the Western Bulldogs, 2-point loss to Fremantle, and a 3-point loss to the Giants. Richmond ended the season strongly with convincing victories over Fremantle and St Kilda in the final two rounds, elevating the club to 3rd on the ladder. Richmond's first final of the season against the Cats at the MCG attracted a record qualifying final crowd of 95,028; the Tigers won by 51 points. Having advanced to the first preliminary finals for the first time since 2001, Richmond defeated Greater Western Sydney by 36 points in front of a crowd of 94,258 to progress to the Grand Final against Adelaide, their first Grand Final appearance since 1982. The attendance was 100,021, the largest crowd to a grand final since 1986. The Crows led at quarter time and led by as many as 13, but the Tigers took over the game as it progressed and scored seven straight goals at one point. They eventually would win by 48 points – 16.12 (108) to Adelaide's 8.12 (60) – to end their 37-year flag drought.[22] Dustin Martin also became the first player to win a Premiership medal, the Brownlow Medal and the Norm Smith Medal in the same season, while Damien Hardwick was named AFL Coaches Association Coach of the Year. Richmond's jump from 13th to premiers also marked the biggest jump from one AFL season to the next.",
    'negative': "2018 NRL Grand Final The 2018 NRL Grand Final was the conclusive and premiership-deciding game of the 2018 National Rugby League season and was played on Sunday September 30 at Sydney's ANZ Stadium.[1] The match was contested between minor premiers the Sydney Roosters and defending premiers the Melbourne Storm. In front of a crowd of 82,688, Sydney won the match 21–6 to claim their 14th premiership title and their first since 2013. Roosters five-eighth Luke Keary was awarded the Clive Churchill Medal as the game's official man of the match."
}
>>> dataset.push_to_hub("natural-questions-hard-negatives", "triplet-all")

Parameters
  • dataset (Dataset) – A dataset containing (anchor, positive) pairs.

  • model (SentenceTransformer) – A SentenceTransformer model to use for embedding the sentences.

  • anchor_column_name (str, optional) – The column name in dataset that contains the anchor/query. Defaults to None, in which case the first column in dataset will be used.

  • positive_column_name (str, optional) – The column name in dataset that contains the positive candidates. Defaults to None, in which case the second column in dataset will be used.

  • corpus (List[str], optional) – A list containing documents as strings that will be used as candidate negatives in addition to the second column in dataset. Defaults to None, in which case the second column in dataset will exclusively be used as the negative candidate corpus.

  • cross_encoder (CrossEncoder, optional) – A CrossEncoder model to use for rescoring the candidates. Defaults to None.

  • range_min (int) – Minimum rank of the closest matches to consider as negatives. Defaults to 0.

  • range_max (int, optional) – Maximum rank of the closest matches to consider as negatives. Defaults to None.

  • max_score (float, optional) – Maximum score to consider as a negative. Defaults to None.

  • min_score (float, optional) – Minimum score to consider as a negative. Defaults to None.

  • margin (float, optional) – Margin for hard negative mining. Defaults to None.

  • num_negatives (int) – Number of negatives to sample. Defaults to 3.

  • sampling_strategy (Literal["random", "top"]) – Sampling strategy for negatives: “top” or “random”. Defaults to “top”.

  • as_triplets (bool) – If True, returns up to num_negatives (anchor, positive, negative) triplets for each input sample. If False, returns 1 (anchor, positive, negative_1, …, negative_n) tuple for each input sample. Defaults to True.

  • batch_size (int) – Batch size for encoding the dataset. Defaults to 32.

  • faiss_batch_size (int) – Batch size for FAISS top-k search. Defaults to 16384.

  • use_faiss (bool) – Whether to use FAISS for similarity search. May be recommended for large datasets. Defaults to False.

  • verbose (bool) – Whether to print statistics and logging. Defaults to True.

Returns

A dataset containing (anchor, positive, negative) triplets or (anchor, positive, negative_1, …, negative_n) tuples.

Return type

Dataset

sentence_transformers.util.normalize_embeddings(embeddings: torch.Tensor) → torch.Tensor

Normalizes the embeddings matrix, so that each sentence embedding has unit length.

Parameters

embeddings (Tensor) – The input embeddings matrix.

Returns

The normalized embeddings matrix.

Return type

Tensor
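In plain PyTorch, this corresponds to dividing each row by its L2 norm, which torch.nn.functional.normalize does in one call. The sketch assumes a 2-D embeddings matrix, and normalize_rows is a hypothetical name. A useful consequence: after normalization, the dot product of two embeddings equals their cosine similarity.

```python
import torch
import torch.nn.functional as F


def normalize_rows(embeddings: torch.Tensor) -> torch.Tensor:
    # Divide every row by its L2 norm (F.normalize clamps tiny norms with eps to avoid division by zero)
    return F.normalize(embeddings, p=2, dim=1)


emb = torch.tensor([[3.0, 4.0], [0.0, 2.0]])
normed = normalize_rows(emb)
print(normed)  # rows are [0.6, 0.8] and [0.0, 1.0]

# After normalization, a plain dot product gives the cosine similarity (0.8 off-diagonal here)
print(normed @ normed.T)
```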

sentence_transformers.util.paraphrase_mining(model, sentences: list, show_progress_bar: bool = False, batch_size: int = 32, query_chunk_size: int = 5000, corpus_chunk_size: int = 100000, max_pairs: int = 500000, top_k: int = 100, score_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = <function cos_sim>) → list

Given a list of sentences / texts, this function performs paraphrase mining. It compares all sentences against all other sentences and returns a list of the pairs with the highest similarity scores (cosine similarity by default; see score_function).

Parameters
  • model (SentenceTransformer) – SentenceTransformer model for embedding computation

  • sentences (List[str]) – A list of strings (texts or sentences)

  • show_progress_bar (bool, optional) – Whether to show a progress bar. Defaults to False.

  • batch_size (int, optional) – Number of texts that are encoded simultaneously by the model. Defaults to 32.

  • query_chunk_size (int, optional) – Number of sentences for which the most similar pairs are searched simultaneously. Decrease it to lower the memory footprint (at the cost of a longer run-time). Defaults to 5000.

  • corpus_chunk_size (int, optional) – Number of other sentences that each sentence is compared against simultaneously. Decrease it to lower the memory footprint (at the cost of a longer run-time). Defaults to 100000.

  • max_pairs (int, optional) – Maximum number of text pairs returned. Defaults to 500000.

  • top_k (int, optional) – For each sentence, we retrieve up to top_k other sentences. Defaults to 100.

  • score_function (Callable[[Tensor, Tensor], Tensor], optional) – Function for computing scores. By default, cosine similarity. Defaults to cos_sim.

Returns

A list of triplets with the format [score, id1, id2], sorted by decreasing score.

Return type

List[List[Union[float, int]]]
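Ignoring chunking and the model call, the core idea can be sketched on precomputed embeddings: score all pairs, keep each sentence's top_k partners, drop self-pairs and duplicates, then sort. This is a simplified, memory-hungry illustration under those assumptions; brute_force_paraphrase_pairs is a hypothetical name, and the library's function additionally encodes the sentences with the model and processes the score matrix in chunks.

```python
from __future__ import annotations

import torch
import torch.nn.functional as F


def brute_force_paraphrase_pairs(
    embeddings: torch.Tensor, top_k: int = 100, max_pairs: int = 500000
) -> list[list[float | int]]:
    """Hypothetical helper: [score, id1, id2] triplets sorted by decreasing cosine similarity."""
    emb = F.normalize(embeddings, p=2, dim=1)
    scores = emb @ emb.T
    scores.fill_diagonal_(float("-inf"))  # never pair a sentence with itself

    k = min(top_k, scores.size(0) - 1)
    top_scores, top_ids = scores.topk(k, dim=1)

    best: dict[tuple[int, int], float] = {}
    for i in range(scores.size(0)):
        for score, j in zip(top_scores[i].tolist(), top_ids[i].tolist()):
            # (i, j) and (j, i) describe the same pair; store it once with id1 < id2
            best[(min(i, j), max(i, j))] = score

    pairs = [[score, id1, id2] for (id1, id2), score in best.items()]
    pairs.sort(key=lambda p: p[0], reverse=True)
    return pairs[:max_pairs]


emb = torch.tensor([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]])
pairs = brute_force_paraphrase_pairs(emb)
print(pairs[0][1:])  # [0, 1]: the two nearly parallel vectors
```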

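sentence_transformers.util.semantic_search(query_embeddings: torch.Tensor, corpus_embeddings: torch.Tensor, query_chunk_size: int = 100, corpus_chunk_size: int = 500000, top_k: int = 10, score_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = <function cos_sim>) → list[list[dict[str, int | float]]]

This function performs a cosine similarity search between a list of query embeddings and a list of corpus embeddings. It can be used for Information Retrieval / Semantic Search for corpora of up to about 1 million entries.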

Parameters
  • query_embeddings (Tensor) – A 2 dimensional tensor with the query embeddings.

  • corpus_embeddings (Tensor) – A 2 dimensional tensor with the corpus embeddings.

  • query_chunk_size (int, optional) – Number of queries processed simultaneously. Increasing this value increases the speed, but requires more memory. Defaults to 100.

  • corpus_chunk_size (int, optional) – Number of corpus entries scanned at a time. Increasing this value increases the speed, but requires more memory. Defaults to 500000.

  • top_k (int, optional) – Retrieve top k matching entries. Defaults to 10.

  • score_function (Callable[[Tensor, Tensor], Tensor], optional) – Function for computing scores. By default, cosine similarity. Defaults to cos_sim.

Returns

A list with one entry for each query. Each entry is a list of dictionaries with the keys ‘corpus_id’ and ‘score’, sorted by decreasing score.

Return type

List[List[Dict[str, Union[int, float]]]]
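A brute-force version of this search fits in a few lines when the full score matrix fits in memory. The sketch below assumes cosine similarity and 2-D tensors; brute_force_search is a hypothetical name. The query_chunk_size and corpus_chunk_size parameters of sentence_transformers.util.semantic_search avoid materializing this full matrix at once.

```python
from __future__ import annotations

import torch
import torch.nn.functional as F


def brute_force_search(
    query_embeddings: torch.Tensor, corpus_embeddings: torch.Tensor, top_k: int = 10
) -> list[list[dict[str, int | float]]]:
    """Hypothetical helper: exhaustive cosine-similarity search without chunking."""
    q = F.normalize(query_embeddings, p=2, dim=1)
    c = F.normalize(corpus_embeddings, p=2, dim=1)
    scores = q @ c.T  # (num_queries, corpus_size), fully materialized

    k = min(top_k, scores.size(1))
    top_scores, top_ids = torch.topk(scores, k=k, dim=1)  # sorted by decreasing score
    return [
        [{"corpus_id": int(idx), "score": float(s)} for s, idx in zip(row_scores.tolist(), row_ids.tolist())]
        for row_scores, row_ids in zip(top_scores, top_ids)
    ]


corpus = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
queries = torch.tensor([[1.0, 0.1]])
hits = brute_force_search(queries, corpus, top_k=2)
print([hit["corpus_id"] for hit in hits[0]])  # [0, 2]
```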

Model Optimization

sentence_transformers.backend.export_dynamic_quantized_onnx_model(model: SentenceTransformer, quantization_config: QuantizationConfig | Literal['arm64', 'avx2', 'avx512', 'avx512_vnni'], model_name_or_path: str, push_to_hub: bool = False, create_pr: bool = False, file_suffix: str | None = None) → None

Export a quantized ONNX model from a SentenceTransformer model.

This function applies dynamic quantization, i.e. without a calibration dataset. Each of the default quantization configurations quantizes the model to int8, allowing for faster inference on CPUs, although the result is likely slower on GPUs.

See https://sbert.net/docs/sentence_transformer/usage/efficiency.html for more information & benchmarks.

Parameters
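  • model (SentenceTransformer) – The SentenceTransformer model to be quantized. Must be loaded with backend="onnx".

  • quantization_config (QuantizationConfig | Literal["arm64", "avx2", "avx512", "avx512_vnni"]) – The quantization configuration, or the name of one of the default configurations.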

  • model_name_or_path (str) – The path or Hugging Face Hub repository name where the quantized model will be saved.

  • push_to_hub (bool, optional) – Whether to push the quantized model to the Hugging Face Hub. Defaults to False.

  • create_pr (bool, optional) – Whether to create a pull request when pushing to the Hugging Face Hub. Defaults to False.

  • file_suffix (str | None, optional) – The suffix to add to the quantized model file name. Defaults to None.

Raises
  • ImportError – If the required packages optimum and onnxruntime are not installed.

  • ValueError – If the provided model is not a valid SentenceTransformer model loaded with backend="onnx".

  • ValueError – If the provided quantization_config is not valid.

Returns

None

sentence_transformers.backend.export_optimized_onnx_model(model: SentenceTransformer, optimization_config: OptimizationConfig | Literal['O1', 'O2', 'O3', 'O4'], model_name_or_path: str, push_to_hub: bool = False, create_pr: bool = False, file_suffix: str | None = None) → None

Export an optimized ONNX model from a SentenceTransformer model.

The O1-O4 optimization levels are defined by Optimum and are documented here: https://huggingface.co/docs/optimum/main/en/onnxruntime/usage_guides/optimization

The optimization levels are:

  • O1: basic general optimizations.

  • O2: basic and extended general optimizations, transformers-specific fusions.

  • O3: same as O2 with GELU approximation.

  • O4: same as O3 with mixed precision (fp16, GPU-only).

See https://sbert.net/docs/sentence_transformer/usage/efficiency.html for more information & benchmarks.

Parameters
  • model (SentenceTransformer) – The SentenceTransformer model to be optimized. Must be loaded with backend="onnx".

  • optimization_config (OptimizationConfig | Literal["O1", "O2", "O3", "O4"]) – The optimization configuration or level.

  • model_name_or_path (str) – The path or Hugging Face Hub repository name where the optimized model will be saved.

  • push_to_hub (bool, optional) – Whether to push the optimized model to the Hugging Face Hub. Defaults to False.

  • create_pr (bool, optional) – Whether to create a pull request when pushing to the Hugging Face Hub. Defaults to False.

  • file_suffix (str | None, optional) – The suffix to add to the optimized model file name. Defaults to None.

Raises
  • ImportError – If the required packages optimum and onnxruntime are not installed.

  • ValueError – If the provided model is not a valid SentenceTransformer model loaded with backend="onnx".

  • ValueError – If the provided optimization_config is not valid.

Returns

None

sentence_transformers.backend.export_static_quantized_openvino_model(model: SentenceTransformer, quantization_config: OVQuantizationConfig | dict | None, model_name_or_path: str, dataset_name: str | None = None, dataset_config_name: str | None = None, dataset_split: str | None = None, column_name: str | None = None, push_to_hub: bool = False, create_pr: bool = False, file_suffix: str = 'qint8_quantized') → None

Export a quantized OpenVINO model from a SentenceTransformer model.

This function applies Post-Training Static Quantization (PTQ) using a calibration dataset, which calibrates quantization constants without requiring model retraining. Each default quantization configuration converts the model to int8 precision, enabling faster inference while aiming to preserve accuracy.

See https://sbert.net/docs/sentence_transformer/usage/efficiency.html for more information & benchmarks.

Parameters
  • model (SentenceTransformer) – The SentenceTransformer model to be quantized. Must be loaded with backend="openvino".

  • quantization_config (OVQuantizationConfig | dict | None) – The quantization configuration. If None, default values are used.

  • model_name_or_path (str) – The path or Hugging Face Hub repository name where the quantized model will be saved.

  • dataset_name (str, optional) – The name of the dataset to load for calibration. If not specified, the sst2 subset of the glue dataset will be used by default.

  • dataset_config_name (str, optional) – The specific configuration of the dataset to load.

  • dataset_split (str, optional) – The split of the dataset to load (e.g., ‘train’, ‘test’). Defaults to None.

  • column_name (str, optional) – The column name in the dataset to use for calibration. Defaults to None.

  • push_to_hub (bool, optional) – Whether to push the quantized model to the Hugging Face Hub. Defaults to False.

  • create_pr (bool, optional) – Whether to create a pull request when pushing to the Hugging Face Hub. Defaults to False.

  • file_suffix (str, optional) – The suffix to add to the quantized model file name. Defaults to qint8_quantized.

Raises
  • ImportError – If the required packages optimum and openvino are not installed.

  • ValueError – If the provided model is not a valid SentenceTransformer model loaded with backend="openvino".

  • ValueError – If the provided quantization_config is not valid.

Returns

None
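To illustrate the difference between dynamic quantization (no calibration data) and static, calibration-based quantization, the sketch below shows symmetric per-tensor int8 quantization with the scale chosen in two ways. It is a conceptual illustration only: all helper names are hypothetical, and ONNX Runtime and OpenVINO may use more elaborate schemes (for example per-channel scales or zero points).

```python
from __future__ import annotations

import torch


def quantize_int8(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Symmetric int8: divide by the scale, round, and clip to [-127, 127]
    return torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)


def dequantize(q: torch.Tensor, scale: float) -> torch.Tensor:
    return q.to(torch.float32) * scale


def dynamic_scale(x: torch.Tensor) -> float:
    # Dynamic: the scale comes from the tensor being quantized, computed at inference time
    return max(x.abs().max().item(), 1e-8) / 127


def static_scale(calibration_batches: list[torch.Tensor]) -> float:
    # Static: the scale is fixed ahead of time from representative calibration data
    return max(max(b.abs().max().item() for b in calibration_batches), 1e-8) / 127


x = torch.tensor([-1.0, 0.5, 1.0])
s_dyn = dynamic_scale(x)
print(dequantize(quantize_int8(x, s_dyn), s_dyn))  # close to x, within s_dyn / 2

s_static = static_scale([torch.tensor([0.5, -0.25]), torch.tensor([2.0])])
# Values outside the calibrated range are clipped: 3.0 maps to 127, i.e. back to about 2.0
print(dequantize(quantize_int8(torch.tensor([3.0]), s_static), s_static))
```

Static quantization avoids computing scales at run time, but depends on the calibration data being representative of real inputs.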

Similarity Metrics

sentence_transformers.util.cos_sim(a: list | np.ndarray | torch.Tensor, b: list | np.ndarray | torch.Tensor) → torch.Tensor

Computes the cosine similarity between two tensors.

Parameters
  • a (Union[list, np.ndarray, Tensor]) – The first tensor.

  • b (Union[list, np.ndarray, Tensor]) – The second tensor.

Returns

Matrix with res[i][j] = cos_sim(a[i], b[j])

Return type

Tensor

sentence_transformers.util.dot_score(a: list | np.ndarray | torch.Tensor, b: list | np.ndarray | torch.Tensor) → torch.Tensor

Computes the dot-product dot_prod(a[i], b[j]) for all i and j.

Parameters
  • a (Union[list, np.ndarray, Tensor]) – The first tensor.

  • b (Union[list, np.ndarray, Tensor]) – The second tensor.

Returns

Matrix with res[i][j] = dot_prod(a[i], b[j])

Return type

Tensor

sentence_transformers.util.euclidean_sim(a: list | np.ndarray | torch.Tensor, b: list | np.ndarray | torch.Tensor) → torch.Tensor

Computes the Euclidean similarity (i.e., negative distance) between two tensors.

Parameters
  • a (Union[list, np.ndarray, Tensor]) – The first tensor.

  • b (Union[list, np.ndarray, Tensor]) – The second tensor.

Returns

Matrix with res[i][j] = -euclidean_distance(a[i], b[j])

Return type

Tensor

sentence_transformers.util.manhattan_sim(a: list | np.ndarray | torch.Tensor, b: list | np.ndarray | torch.Tensor) → torch.Tensor

Computes the Manhattan similarity (i.e., negative distance) between two tensors.

Parameters
  • a (Union[list, np.ndarray, Tensor]) – The first tensor.

  • b (Union[list, np.ndarray, Tensor]) – The second tensor.

Returns

Matrix with res[i][j] = -manhattan_distance(a[i], b[j])

Return type

Tensor
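The four matrix-valued similarity functions above can each be expressed in one line of plain PyTorch. This is a reference sketch assuming that a and b are 2-D float tensors of shape (n, d) and (m, d); the library functions also accept lists and NumPy arrays (per the parameter types above), and the *_ref names are hypothetical.

```python
import torch
import torch.nn.functional as F


def cos_sim_ref(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # res[i][j] = cosine similarity of a[i] and b[j]
    return F.normalize(a, p=2, dim=1) @ F.normalize(b, p=2, dim=1).T


def dot_score_ref(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # res[i][j] = dot product of a[i] and b[j]
    return a @ b.T


def euclidean_sim_ref(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # res[i][j] = -||a[i] - b[j]||_2, negated so that larger means more similar
    return -torch.cdist(a, b, p=2.0)


def manhattan_sim_ref(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # res[i][j] = -||a[i] - b[j]||_1
    return -torch.cdist(a, b, p=1.0)


a = torch.tensor([[1.0, 0.0], [3.0, 4.0]])
b = torch.tensor([[0.0, 2.0]])
print(cos_sim_ref(a, b))        # shape (2, 1): values 0.0 and 0.8
print(manhattan_sim_ref(a, b))  # values -3.0 and -5.0
```

Negating the distances keeps the convention shared by all four functions: a higher score means a closer match.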

sentence_transformers.util.pairwise_cos_sim(a: torch.Tensor, b: torch.Tensor) → torch.Tensor

Computes the pairwise cosine similarity cos_sim(a[i], b[i]).

Parameters
  • a (Union[list, np.ndarray, Tensor]) – The first tensor.

  • b (Union[list, np.ndarray, Tensor]) – The second tensor.

Returns

Vector with res[i] = cos_sim(a[i], b[i])

Return type

Tensor

sentence_transformers.util.pairwise_dot_score(a: torch.Tensor, b: torch.Tensor) → torch.Tensor

Computes the pairwise dot-product dot_prod(a[i], b[i]).

Parameters
  • a (Union[list, np.ndarray, Tensor]) – The first tensor.

  • b (Union[list, np.ndarray, Tensor]) – The second tensor.

Returns

Vector with res[i] = dot_prod(a[i], b[i])

Return type

Tensor

sentence_transformers.util.pairwise_euclidean_sim(a: list | np.ndarray | torch.Tensor, b: list | np.ndarray | torch.Tensor) → torch.Tensor

Computes the Euclidean similarity (i.e., negative distance) between pairs of tensors.

Parameters
  • a (Union[list, np.ndarray, Tensor]) – The first tensor.

  • b (Union[list, np.ndarray, Tensor]) – The second tensor.

Returns

Vector with res[i] = -euclidean_distance(a[i], b[i])

Return type

Tensor

sentence_transformers.util.pairwise_manhattan_sim(a: list | np.ndarray | torch.Tensor, b: list | np.ndarray | torch.Tensor) → torch.Tensor

Computes the Manhattan similarity (i.e., negative distance) between pairs of tensors.

Parameters
  • a (Union[list, np.ndarray, Tensor]) – The first tensor.

  • b (Union[list, np.ndarray, Tensor]) – The second tensor.

Returns

Vector with res[i] = -manhattan_distance(a[i], b[i])

Return type

Tensor
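The pairwise variants compare a[i] only with b[i], so they return the diagonal of the corresponding matrix-valued scores without computing the full matrix. A plain-PyTorch sketch, assuming a and b are 2-D float tensors of the same shape; all *_ref names are hypothetical.

```python
import torch
import torch.nn.functional as F


def pairwise_cos_sim_ref(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # res[i] = cosine similarity of a[i] and b[i]
    return (F.normalize(a, p=2, dim=1) * F.normalize(b, p=2, dim=1)).sum(dim=-1)


def pairwise_dot_score_ref(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # res[i] = dot product of a[i] and b[i]
    return (a * b).sum(dim=-1)


def pairwise_euclidean_sim_ref(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # res[i] = -||a[i] - b[i]||_2
    return -((a - b) ** 2).sum(dim=-1).sqrt()


def pairwise_manhattan_sim_ref(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # res[i] = -||a[i] - b[i]||_1
    return -(a - b).abs().sum(dim=-1)


a = torch.tensor([[1.0, 0.0], [3.0, 4.0]])
b = torch.tensor([[0.0, 2.0], [3.0, 0.0]])
print(pairwise_dot_score_ref(a, b))  # values 0.0 and 9.0
print((a @ b.T).diagonal())          # the same values, but via the full 2 x 2 matrix
```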