SentenceTransformer¶

This page documents the properties and methods available when you load a SentenceTransformer model:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("model-name")
class sentence_transformers.SentenceTransformer(model_name_or_path: Optional[str] = None, modules: Optional[Iterable[torch.nn.modules.module.Module]] = None, device: Optional[str] = None, prompts: Optional[Dict[str, str]] = None, default_prompt_name: Optional[str] = None, cache_folder: Optional[str] = None, trust_remote_code: bool = False, revision: Optional[str] = None, token: Optional[Union[bool, str]] = None, use_auth_token: Optional[Union[bool, str]] = None, truncate_dim: Optional[int] = None)¶

Loads or creates a SentenceTransformer model that can be used to map sentences / text to embeddings.

Parameters
  • model_name_or_path – If it is a filepath on disc, it loads the model from that path. If it is not a path, it first tries to download a pre-trained SentenceTransformer model. If that fails, tries to construct a model from the Hugging Face Hub with that name.

  • modules – A list of torch Modules that should be called sequentially, can be used to create custom SentenceTransformer models from scratch.

  • device – Device (like “cuda”, “cpu”, “mps”, “npu”) that should be used for computation. If None, checks if a GPU can be used.

  • prompts – A dictionary with prompts for the model. The key is the prompt name, the value is the prompt text. The prompt text will be prepended before any text to encode. For example: {"query": "query: ", "passage": "passage: "} or {"clustering": "Identify the main category based on the titles in "}.

  • default_prompt_name – The name of the prompt that should be used by default. If not set, no prompt will be applied.

  • cache_folder – Path to store models. Can also be set by the SENTENCE_TRANSFORMERS_HOME environment variable.

  • trust_remote_code – Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.

  • revision – The specific model version to use. It can be a branch name, a tag name, or a commit id, for a stored model on Hugging Face.

  • token – Hugging Face authentication token to download private models.

  • truncate_dim – The dimension to truncate sentence embeddings to. None does no truncation. Truncation is only applicable during inference when .encode is called.

Initializes internal Module state, shared by both nn.Module and ScriptModule.
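
For instance, a minimal sketch of loading a model with named prompts; the model name, prompt strings, and device here are illustrative:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer(
    "all-MiniLM-L6-v2",
    prompts={"query": "query: ", "passage": "passage: "},
    default_prompt_name="query",  # applied when encode() receives no prompt
    device="cpu",                 # omit to auto-detect a GPU
)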

property device¶

Get torch.device from module, assuming that the whole module has one device.

encode(sentences: Union[str, List[str]], prompt_name: Optional[str] = None, prompt: Optional[str] = None, batch_size: int = 32, show_progress_bar: Optional[bool] = None, output_value: Optional[Literal['sentence_embedding', 'token_embeddings']] = 'sentence_embedding', precision: Literal['float32', 'int8', 'uint8', 'binary', 'ubinary'] = 'float32', convert_to_numpy: bool = True, convert_to_tensor: bool = False, device: Optional[str] = None, normalize_embeddings: bool = False) → Union[List[torch.Tensor], numpy.ndarray, torch.Tensor]¶

Computes sentence embeddings.

Parameters
  • sentences – the sentences to embed.

  • prompt_name – The name of the prompt to use for encoding. Must be a key in the prompts dictionary, which is either set in the constructor or loaded from the model configuration. For example, if prompt_name is "query" and the prompts dictionary is {"query": "query: ", ...}, then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?" because the sentence is appended to the prompt. If prompt is also set, this argument is ignored.

  • prompt – The prompt to use for encoding. For example, if the prompt is "query: ", then the sentence “What is the capital of France?” will be encoded as “query: What is the capital of France?” because the sentence is appended to the prompt. If prompt is set, prompt_name is ignored.

  • batch_size – the batch size used for the computation.

  • show_progress_bar – Whether to output a progress bar when encoding sentences.

  • output_value – The type of embeddings to return: “sentence_embedding” to get sentence embeddings, “token_embeddings” to get wordpiece token embeddings, and None to get all output values. Defaults to “sentence_embedding”.

  • precision – The precision to use for the embeddings. Can be “float32”, “int8”, “uint8”, “binary”, or “ubinary”. All non-float32 precisions are quantized embeddings. Quantized embeddings are smaller in size and faster to compute, but may have a lower accuracy. They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks. Defaults to “float32”.

  • convert_to_numpy – Whether the output should be a list of numpy vectors. If False, it is a list of PyTorch tensors.

  • convert_to_tensor – Whether the output should be one large tensor. Overrides convert_to_numpy.

  • device – Which torch.device to use for the computation.

  • normalize_embeddings – Whether to normalize returned vectors to have length 1. In that case, the faster dot-product (util.dot_score) can be used instead of cosine similarity.

Returns

By default, a 2d numpy array with shape [num_inputs, output_dimension] is returned. If only one string input is provided, then the output is a 1d array with shape [output_dimension]. If convert_to_tensor, a torch Tensor is returned instead. If self.truncate_dim <= output_dimension then output_dimension is self.truncate_dim.
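
A short usage sketch; the model name is illustrative (all-MiniLM-L6-v2 produces 384-dimensional embeddings):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")

# By default a 2d numpy array of shape [num_inputs, output_dimension] is returned.
embeddings = model.encode(
    ["What is the capital of France?", "Paris is the capital of France."],
    batch_size=32,
    normalize_embeddings=True,  # unit-length vectors, so dot product equals cosine
)
print(embeddings.shape)  # (2, 384)

# A single string input yields a 1d array.
query_embedding = model.encode("What is the capital of France?")
print(query_embedding.shape)  # (384,)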

encode_multi_process(sentences: List[str], pool: Dict[str, object], prompt_name: Optional[str] = None, prompt: Optional[str] = None, batch_size: int = 32, chunk_size: Optional[int] = None, normalize_embeddings: bool = False)¶

This method allows you to run encode() on multiple GPUs. The sentences are chunked into smaller packages and sent to individual processes, which encode them on the different GPUs. This method is only suitable for encoding large sets of sentences.

Parameters
  • sentences – List of sentences

  • pool – A pool of workers started with SentenceTransformer.start_multi_process_pool

  • prompt_name – The name of the prompt to use for encoding. Must be a key in the prompts dictionary, which is either set in the constructor or loaded from the model configuration. For example, if prompt_name is "query" and the prompts dictionary is {"query": "query: ", ...}, then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?". If prompt is also set, this argument is ignored.

  • prompt – The prompt to use for encoding. For example, if the prompt is "query: ", then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?". If prompt is set, prompt_name is ignored.

  • batch_size – The batch size used for encoding.

  • chunk_size – Sentences are chunked and sent to the individual processes. If None, a sensible size is determined automatically.

  • normalize_embeddings – Whether to normalize returned vectors to have length 1. In that case, the faster dot-product (util.dot_score) can be used instead of cosine similarity.

Returns

2d numpy array with shape [num_inputs, output_dimension]
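
A sketch of the full pool lifecycle (start, encode, stop); the model name and device list are illustrative:

from sentence_transformers import SentenceTransformer

# The __main__ guard is required on platforms that spawn processes.
if __name__ == "__main__":
    model = SentenceTransformer("all-MiniLM-L6-v2")
    sentences = ["This is sentence number {}".format(i) for i in range(100000)]

    # One process per device; the device list here is an assumption.
    pool = model.start_multi_process_pool(target_devices=["cuda:0", "cuda:1"])
    embeddings = model.encode_multi_process(sentences, pool, batch_size=32)
    print(embeddings.shape)  # (100000, output_dimension)

    SentenceTransformer.stop_multi_process_pool(pool)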

evaluate(evaluator: sentence_transformers.evaluation.SentenceEvaluator.SentenceEvaluator, output_path: Optional[str] = None)¶

Evaluate the model

Parameters
  • evaluator – The evaluator used to assess the model.

  • output_path – The evaluator can write results to this path.
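
A minimal sketch using the built-in EmbeddingSimilarityEvaluator with toy data; the model name and output path are illustrative:

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

model = SentenceTransformer("all-MiniLM-L6-v2")

# Toy data: gold similarity scores in [0, 1] for each sentence pair.
evaluator = EmbeddingSimilarityEvaluator(
    sentences1=["A cat sits on the mat.", "The weather is nice."],
    sentences2=["A cat is sitting on a mat.", "It is raining heavily."],
    scores=[0.9, 0.1],
)
score = model.evaluate(evaluator, output_path="eval/")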

fit(train_objectives: Iterable[Tuple[torch.utils.data.dataloader.DataLoader, torch.nn.modules.module.Module]], evaluator: Optional[sentence_transformers.evaluation.SentenceEvaluator.SentenceEvaluator] = None, epochs: int = 1, steps_per_epoch=None, scheduler: str = 'WarmupLinear', warmup_steps: int = 10000, optimizer_class: Type[torch.optim.optimizer.Optimizer] = <class 'torch.optim.adamw.AdamW'>, optimizer_params: Dict[str, object] = {'lr': 2e-05}, weight_decay: float = 0.01, evaluation_steps: int = 0, output_path: Optional[str] = None, save_best_model: bool = True, max_grad_norm: float = 1, use_amp: bool = False, callback: Optional[Callable[[float, int, int], None]] = None, show_progress_bar: bool = True, checkpoint_path: Optional[str] = None, checkpoint_save_steps: int = 500, checkpoint_save_total_limit: int = 0)¶

Train the model with the given training objectives. Each training objective is sampled in turn for one batch. We sample only as many batches from each objective as there are in the smallest one, to ensure equal training with each dataset.

Parameters
  • train_objectives – Tuples of (DataLoader, LossFunction). Pass more than one for multi-task learning

  • evaluator – An evaluator (sentence_transformers.evaluation) evaluates the model performance during training on held-out dev data. It is used to determine the best model that is saved to disc.

  • epochs – Number of epochs for training

  • steps_per_epoch – Number of training steps per epoch. If set to None (default), one epoch equals the size of the smallest DataLoader in train_objectives.

  • scheduler – Learning rate scheduler. Available schedulers: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts

  • warmup_steps – Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is increased from zero up to the maximal learning rate. After these many training steps, the learning rate decreases linearly back to zero.

  • optimizer_class – Optimizer

  • optimizer_params – Optimizer parameters

  • weight_decay – Weight decay for model parameters

  • evaluation_steps – If > 0, evaluate the model using the evaluator after every evaluation_steps training steps.

  • output_path – Storage path for the model and evaluation files

  • save_best_model – If True, the best model (according to evaluator) is stored at output_path.

  • max_grad_norm – Used for gradient clipping: gradients with a norm larger than this value are scaled down.

  • use_amp – Use Automatic Mixed Precision (AMP). Only works for PyTorch >= 1.6.0.

  • callback – Callback function that is invoked after each evaluation. It must accept the following three parameters in this order: score, epoch, steps

  • show_progress_bar – If True, output a tqdm progress bar

  • checkpoint_path – Folder to save checkpoints during training

  • checkpoint_save_steps – Will save a checkpoint after this many training steps.

  • checkpoint_save_total_limit – Total number of checkpoints to store
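
A minimal training sketch with toy data; the model name and output path are illustrative, and real training needs far more examples:

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

model = SentenceTransformer("all-MiniLM-L6-v2")

# Toy training pairs with gold similarity labels.
train_examples = [
    InputExample(texts=["My first sentence", "My second sentence"], label=0.8),
    InputExample(texts=["Another pair", "Unrelated sentence"], label=0.3),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
train_loss = losses.CosineSimilarityLoss(model)

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=1,
    warmup_steps=100,
    output_path="output/my-model",
)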

get_max_seq_length()¶

Returns the maximal sequence length the model accepts as input. Longer inputs will be truncated.

get_sentence_embedding_dimension() → Optional[int]¶

Returns

The number of dimensions in the output of encode. If it’s not known, it’s None.

property max_seq_length¶

Property to get the maximal input sequence length for the model. Longer inputs will be truncated.
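
For instance, with model loaded as above (the printed values depend on the model, and the setter is an assumption based on recent library versions):

print(model.max_seq_length)                      # e.g. 256
print(model.get_sentence_embedding_dimension())  # e.g. 384

# The property is writable: lower it to truncate inputs earlier.
model.max_seq_length = 128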

push_to_hub(repo_id: str, token: Optional[str] = None, private: Optional[bool] = None, safe_serialization: bool = True, commit_message: str = 'Add new SentenceTransformer model.', local_model_path: Optional[str] = None, exist_ok: bool = False, replace_model_card: bool = False, train_datasets: Optional[List[str]] = None) → str¶

Uploads all elements of this Sentence Transformer to a new Hugging Face Hub repository.

Parameters
  • repo_id – Repository name for your model in the Hub, including the user or organization.

  • token – An authentication token (See https://huggingface.co/settings/token)

  • private – Set to True to host a private model.

  • safe_serialization – If True, save the model using safetensors. If False, save the model the traditional PyTorch way.

  • commit_message – Message to commit while pushing.

  • local_model_path – Path of the model locally. If set, this file path will be uploaded. Otherwise, the current model will be uploaded.

  • exist_ok – If True, saving to an existing repository is OK. If False, saving is only possible to a new repository.

  • replace_model_card – If True, replace an existing model card on the Hub with the automatically created model card.

  • train_datasets – Datasets used to train the model. If set, the datasets will be added to the model card in the Hub.

Returns

The url of the commit of your model in the repository on the Hugging Face Hub.
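
A sketch of a typical call; the repository name is illustrative, and a valid token (or a prior huggingface-cli login) is required:

url = model.push_to_hub(
    "my-username/my-sentence-transformer",
    private=True,
    commit_message="Add fine-tuned model",
)
print(url)  # URL of the commit on the Hugging Face Hub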

save(path: str, model_name: Optional[str] = None, create_model_card: bool = True, train_datasets: Optional[List[str]] = None, safe_serialization: bool = True)¶

Saves all elements of this SentenceTransformer model into sub-folders of the given path.

Parameters
  • path – Path on disc

  • model_name – Optional model name

  • create_model_card – If True, create a README.md with basic information about this model

  • train_datasets – Optional list with the names of the datasets used to train the model

  • safe_serialization – If true, save the model using safetensors. If false, save the model the traditional PyTorch way

set_pooling_include_prompt(include_prompt: bool) → None¶

Sets the include_prompt attribute in the pooling layer in the model, if there is one.

Parameters

include_prompt – Whether to include the prompt in the pooling layer.
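
For example, to exclude the prompt tokens from pooling, as some instruction-tuned models expect (a no-op if the model has no Pooling module):

model.set_pooling_include_prompt(False)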

smart_batching_collate(batch: List[InputExample]) → Tuple[List[Dict[str, torch.Tensor]], torch.Tensor]¶

Transforms a batch from a SmartBatchingDataset into a batch of tensors for the model. Here, batch is a list of InputExample instances: [InputExample(…), …].

Parameters

batch – a batch from a SmartBatchingDataset

Returns

a batch of tensors for the model

start_multi_process_pool(target_devices: Optional[List[str]] = None)¶

Starts a pool of several independent processes that run the encoding. This method is recommended if you want to encode on multiple GPUs or CPUs; it is advised to start only one process per GPU. It works together with encode_multi_process and stop_multi_process_pool (see the sketch under encode_multi_process above).

Parameters

target_devices – PyTorch target devices, e.g. [“cuda:0”, “cuda:1”, …], [“npu:0”, “npu:1”, …] or [“cpu”, “cpu”, “cpu”, “cpu”]. If target_devices is None and CUDA/NPU is available, then all available CUDA/NPU devices will be used. If target_devices is None and CUDA/NPU is not available, then 4 CPU devices will be used.

Returns

Returns a dict with the target processes, an input queue, and an output queue.

static stop_multi_process_pool(pool)¶

Stops all processes started with start_multi_process_pool

tokenize(texts: Union[List[str], List[Dict], List[Tuple[str, str]]])¶

Tokenizes the texts
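
For example, with model loaded as above (the exact keys depend on the underlying tokenizer):

features = model.tokenize(["Hello world", "How are you?"])
print(features.keys())  # typically input_ids, attention_mask, ...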

property tokenizer¶

Property to get the tokenizer that is used by this model

truncate_sentence_embeddings(truncate_dim: Optional[int])¶

This is a context manager: within its scope, model.encode outputs sentence embeddings truncated at dimension truncate_dim.

This may be useful when you are using the same model for different applications where different dimensions are needed.

Parameters

truncate_dim – The dimension to truncate sentence embeddings to. None does no truncation.

Example:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("model-name")

with model.truncate_sentence_embeddings(truncate_dim=16):
    embeddings_truncated = model.encode(["hello there", "hiya"])
assert embeddings_truncated.shape[-1] == 16