Computing Sentence Embeddings¶

The basic function to compute sentence embeddings looks like this:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")

# The sentences we want to encode
sentences = [
    "This framework generates embeddings for each input sentence",
    "Sentences are passed as a list of strings.",
    "The quick brown fox jumps over the lazy dog.",
]

# Sentences are encoded by calling model.encode()
embeddings = model.encode(sentences)

# Print the embeddings
for sentence, embedding in zip(sentences, embeddings):
    print("Sentence:", sentence)
    print("Embedding:", embedding)
    print("")

Note: Even though we talk about sentence embeddings, these models can also be used for shorter phrases as well as for longer texts with multiple sentences. See the section on Input Sequence Length for more notes on embeddings for paragraphs.

First, we load a sentence-transformer model:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("model_name_or_path")

You can either specify a pre-trained model or pass a path on your disc to load the sentence-transformer model from that folder.

If available, the model is automatically executed on the GPU. You can specify the device for the model like this:

model = SentenceTransformer("model_name_or_path", device="cuda")

Here, device can be any PyTorch device (like "cpu", "cuda", "cuda:0", etc.).

The relevant method to encode a set of sentences / texts is model.encode(). Below, you can find the parameters this method accepts. Some relevant parameters are batch_size (depending on your GPU, a different batch size may be optimal), as well as convert_to_numpy (returns a numpy matrix) and convert_to_tensor (returns a pytorch tensor).
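
For example, a minimal sketch of how these parameters can be combined (the sentences, batch size, and printed shapes are illustrative):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")

sentences = ["This is an example sentence", "Each sentence is converted to a vector"]

# Default: returns a numpy matrix; here with an explicit batch size
embeddings_np = model.encode(sentences, batch_size=64, convert_to_numpy=True)

# Alternatively: return one stacked pytorch tensor
embeddings_pt = model.encode(sentences, convert_to_tensor=True)

print(embeddings_np.shape)  # e.g. (2, 384) for all-MiniLM-L6-v2
print(embeddings_pt.shape)  # e.g. torch.Size([2, 384])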

class sentence_transformers.SentenceTransformer(model_name_or_path: Optional[str] = None, modules: Optional[Iterable[torch.nn.modules.module.Module]] = None, device: Optional[str] = None, prompts: Optional[Dict[str, str]] = None, default_prompt_name: Optional[str] = None, cache_folder: Optional[str] = None, trust_remote_code: bool = False, revision: Optional[str] = None, token: Optional[Union[bool, str]] = None, use_auth_token: Optional[Union[bool, str]] = None)¶

Loads or creates a SentenceTransformer model that can be used to map sentences / text to embeddings.

Parameters
  • model_name_or_path – If it is a filepath on disc, it loads the model from that path. If it is not a path, it first tries to download a pre-trained SentenceTransformer model. If that fails, tries to construct a model from the Hugging Face Hub with that name.

  • modules – A list of torch Modules that should be called sequentially, can be used to create custom SentenceTransformer models from scratch.

  • device – Device (like “cuda”, “cpu”, “mps”, “npu”) that should be used for computation. If None, checks if a GPU can be used.

  • prompts – A dictionary with prompts for the model. The key is the prompt name, the value is the prompt text. The prompt text will be prepended before any text to encode. For example: {"query": "query: ", "passage": "passage: "} or {"clustering": "Identify the main category based on the titles in "}.

  • default_prompt_name – The name of the prompt that should be used by default. If not set, no prompt will be applied.

  • cache_folder – Path to store models. Can also be set by the SENTENCE_TRANSFORMERS_HOME environment variable.

  • revision – The specific model version to use. It can be a branch name, a tag name, or a commit id, for a stored model on Hugging Face.

  • trust_remote_code – Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.

  • token – Hugging Face authentication token to download private models.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

encode(sentences: Union[str, List[str]], prompt_name: Optional[str] = None, prompt: Optional[str] = None, batch_size: int = 32, show_progress_bar: Optional[bool] = None, output_value: str = 'sentence_embedding', convert_to_numpy: bool = True, convert_to_tensor: bool = False, device: Optional[str] = None, normalize_embeddings: bool = False) → Union[List[torch.Tensor], numpy.ndarray, torch.Tensor]¶

Computes sentence embeddings.

Parameters
  • sentences – the sentences to embed.

  • prompt_name – The name of the prompt to use for encoding. Must be a key in the prompts dictionary, which is either set in the constructor or loaded from the model configuration. For example, if prompt_name is "query" and prompts is {"query": "query: ", ...}, then the sentence “What is the capital of France?” will be encoded as “query: What is the capital of France?” because the sentence is appended to the prompt. If prompt is also set, this argument is ignored.

  • prompt – The prompt to use for encoding. For example, if the prompt is "query: ", then the sentence “What is the capital of France?” will be encoded as “query: What is the capital of France?” because the sentence is appended to the prompt. If prompt is set, prompt_name is ignored.

  • batch_size – the batch size used for the computation.

  • show_progress_bar – Whether to output a progress bar when encoding sentences.

  • output_value – The type of embeddings to return: “sentence_embedding” to get sentence embeddings, “token_embeddings” to get wordpiece token embeddings, and None, to get all output values. Defaults to “sentence_embedding”.

  • convert_to_numpy – Whether the output should be a list of numpy vectors. If False, it is a list of PyTorch tensors.

  • convert_to_tensor – Whether the output should be one large tensor. Overwrites convert_to_numpy.

  • device – Which torch.device to use for the computation.

  • normalize_embeddings – Whether to normalize returned vectors to have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used.

Returns

By default, a list of tensors is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned.
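
As an illustration of these return types and of normalize_embeddings, the following sketch (with made-up sentences) compares util.dot_score and util.cos_sim on unit-length embeddings:

from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("all-MiniLM-L6-v2")

sentences = ["The weather is lovely today", "It is sunny outside"]

# With normalize_embeddings=True, all returned vectors have length 1
embeddings = model.encode(sentences, convert_to_tensor=True, normalize_embeddings=True)

# The dot product then equals the cosine similarity (up to floating point error)
print(util.dot_score(embeddings, embeddings))
print(util.cos_sim(embeddings, embeddings))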

Prompt Templates¶

Some models require using specific text prompts to achieve optimal performance. For example, with intfloat/multilingual-e5-large you should prefix all queries with query: and all passages with passage: . Another example is BAAI/bge-large-en-v1.5, which performs best for retrieval when the input texts are prefixed with Represent this sentence for searching relevant passages: .
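
For example, these prefixes can be applied by hand via the prompt argument of encode() described above (the query and passage texts here are illustrative):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("intfloat/multilingual-e5-large")

# Prefix queries with "query: " and passages with "passage: ", as recommended for this model
query_embedding = model.encode("How many people live in London?", prompt="query: ")
passage_embedding = model.encode("Around 9 million people live in London.", prompt="passage: ")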

Sentence Transformer models can be initialized with prompts and default_prompt_name parameters:

  • prompts is an optional argument that accepts a dictionary of prompts with prompt names to prompt texts. The prompt will be prepended to the input text during inference. For example,

    model = SentenceTransformer(
        "intfloat/multilingual-e5-large",
        prompts={
            "classification": "Classify the following text: ",
            "retrieval": "Retrieve semantically similar text: ",
            "clustering": "Identify the topic or theme based on the text: ",
        },
    )
    # or
    model.prompts = {
        "classification": "Classify the following text: ",
        "retrieval": "Retrieve semantically similar text: ",
        "clustering": "Identify the topic or theme based on the text: ",
    }
    
  • default_prompt_name is an optional argument that determines the default prompt to be used. It has to correspond with a prompt name from prompts. If None, then no prompt is used by default. For example,

    model = SentenceTransformer(
        "intfloat/multilingual-e5-large",
        prompts={
            "classification": "Classify the following text: ",
            "retrieval": "Retrieve semantically similar text: ",
            "clustering": "Identify the topic or theme based on the text: ",
        },
        default_prompt_name="retrieval",
    )
    # or
    model.default_prompt_name="retrieval"
    

Both of these parameters can also be specified in the config_sentence_transformers.json file of a saved model. That way, you won’t have to specify these options manually when loading. When you save a Sentence Transformer model, these options will be automatically saved as well.
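
As a rough sketch, the relevant part of config_sentence_transformers.json for the model above might look as follows (the exact contents of the file depend on the library version and may contain additional keys):

{
    "prompts": {
        "classification": "Classify the following text: ",
        "retrieval": "Retrieve semantically similar text: ",
        "clustering": "Identify the topic or theme based on the text: "
    },
    "default_prompt_name": "retrieval"
}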

During inference, prompts can be applied in a few different ways. All of these scenarios result in identical texts being embedded:

  1. Explicitly using the prompt option in SentenceTransformer.encode:

    embeddings = model.encode("How to bake a strawberry cake", prompt="Retrieve semantically similar text: ")
    
  2. Explicitly using the prompt_name option in SentenceTransformer.encode by relying on the prompts loaded from a) initialization or b) the model config.

    embeddings = model.encode("How to bake a strawberry cake", prompt_name="retrieval")
    
  3. If neither prompt nor prompt_name is specified in SentenceTransformer.encode, then the prompt specified by default_prompt_name will be applied. If it is None, then no prompt will be applied.

    embeddings = model.encode("How to bake a strawberry cake")
    

Input Sequence Length¶

For transformer models like BERT / RoBERTa / DistilBERT etc., the runtime and the memory requirement grow quadratically with the input length. This limits transformers to inputs of certain lengths. A common value for BERT & Co. is 512 word pieces, which corresponds to about 300-400 words (for English). Texts longer than this are truncated to the first x word pieces.

By default, the provided methods use a limit of 128 word pieces; longer inputs will be truncated. You can get and set the maximal sequence length like this:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")

print("Max Sequence Length:", model.max_seq_length)

# Change the length to 200
model.max_seq_length = 200

print("Max Sequence Length:", model.max_seq_length)

Note: You cannot increase the length beyond what is maximally supported by the respective transformer model. Also note that if a model was trained on short texts, the representations for long texts might not be as good.

Storing & Loading Embeddings¶

The easiest method is to use pickle to store pre-computed embeddings on disc and to load them from disc later. This can be especially useful if you need to encode a large set of sentences.

from sentence_transformers import SentenceTransformer
import pickle

model = SentenceTransformer("all-MiniLM-L6-v2")
sentences = [
    "This framework generates embeddings for each input sentence",
    "Sentences are passed as a list of string.",
    "The quick brown fox jumps over the lazy dog.",
]


embeddings = model.encode(sentences)

# Store sentences & embeddings on disc
with open("embeddings.pkl", "wb") as fOut:
    pickle.dump({"sentences": sentences, "embeddings": embeddings}, fOut, protocol=pickle.HIGHEST_PROTOCOL)

# Load sentences & embeddings from disc
with open("embeddings.pkl", "rb") as fIn:
    stored_data = pickle.load(fIn)
    stored_sentences = stored_data["sentences"]
    stored_embeddings = stored_data["embeddings"]

Multi-Process / Multi-GPU Encoding¶

You can encode input texts with more than one GPU (or with multiple processes on a CPU machine). For an example, see: computing_embeddings_multi_gpu.py.

The relevant method is start_multi_process_pool(), which starts multiple processes that are used for encoding.

SentenceTransformer.start_multi_process_pool(target_devices: Optional[List[str]] = None)¶

Starts a multi-process pool to process the encoding with several independent processes. This method is recommended if you want to encode on multiple GPUs or CPUs. It is advised to start only one process per GPU. This method works together with encode_multi_process and stop_multi_process_pool.

Parameters

target_devices – PyTorch target devices, e.g. [“cuda:0”, “cuda:1”, …], [“npu:0”, “npu:1”, …] or [“cpu”, “cpu”, “cpu”, “cpu”]. If target_devices is None and CUDA/NPU is available, then all available CUDA/NPU devices will be used. If target_devices is None and CUDA/NPU is not available, then 4 CPU devices will be used.

Returns

Returns a dict with the target processes, an input queue, and an output queue.
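
A minimal sketch of how start_multi_process_pool(), encode_multi_process(), and stop_multi_process_pool() fit together (the sentences are illustrative; multi-process code should be placed under the __main__ guard):

from sentence_transformers import SentenceTransformer

if __name__ == "__main__":
    model = SentenceTransformer("all-MiniLM-L6-v2")
    sentences = ["This is sentence {}".format(i) for i in range(10000)]

    # Start one process per available GPU (or 4 CPU processes as a fallback)
    pool = model.start_multi_process_pool()

    # Encode the sentences, distributed over the processes in the pool
    embeddings = model.encode_multi_process(sentences, pool)
    print("Embeddings computed. Shape:", embeddings.shape)

    # Free the resources of the pool
    model.stop_multi_process_pool(pool)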

Sentence Embeddings with Transformers¶

Most of our pre-trained models are based on Huggingface.co/Transformers and are also hosted in the Hugging Face models repository. It is possible to use our sentence embedding models without installing sentence-transformers:

from transformers import AutoTokenizer, AutoModel
import torch


# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask


# Sentences we want sentence embeddings for
sentences = [
    "This framework generates embeddings for each input sentence",
    "Sentences are passed as a list of string.",
    "The quick brown fox jumps over the lazy dog.",
]

# Load AutoModel from huggingface model repository
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

# Tokenize sentences
encoded_input = tokenizer(
    sentences, padding=True, truncation=True, max_length=128, return_tensors="pt"
)

# Compute token embeddings
with torch.no_grad():
    model_output = model(**encoded_input)

# Perform pooling. In this case, mean pooling
sentence_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])

You can find the available models here: https://huggingface.co/sentence-transformers

In the above example we add mean pooling on top of the AutoModel (which will load a BERT model). We also have models that use max-pooling or the CLS token. To see how to apply this pooling correctly, have a look at sentence-transformers/bert-base-nli-max-tokens and sentence-transformers/bert-base-nli-cls-token.
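
As an illustration, a CLS-token pooling function analogous to the mean pooling above could look like this (a sketch; make sure it matches the pooling configuration of the model you load):

# CLS pooling - take the embedding of the first ([CLS]) token as the sentence embedding
def cls_pooling(model_output):
    token_embeddings = model_output[0]  # all token embeddings
    return token_embeddings[:, 0]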