SentenceTransformer¶
class sentence_transformers.SentenceTransformer(model_name_or_path: Optional[str] = None, modules: Optional[collections.abc.Iterable[torch.nn.modules.module.Module]] = None, device: Optional[str] = None, prompts: Optional[dict] = None, default_prompt_name: Optional[str] = None, similarity_fn_name: Optional[Union[str, sentence_transformers.similarity_functions.SimilarityFunction]] = None, cache_folder: Optional[str] = None, trust_remote_code: bool = False, revision: Optional[str] = None, local_files_only: bool = False, token: Optional[Union[bool, str]] = None, use_auth_token: Optional[Union[bool, str]] = None, truncate_dim: Optional[int] = None, model_kwargs: Optional[dict] = None, tokenizer_kwargs: Optional[dict] = None, config_kwargs: Optional[dict] = None, model_card_data: Optional[sentence_transformers.model_card.SentenceTransformerModelCardData] = None, backend: Literal['torch', 'onnx', 'openvino'] = 'torch')[source]¶
Loads or creates a SentenceTransformer model that can be used to map sentences / text to embeddings.
- Parameters
model_name_or_path (str, optional) – If it is a filepath on disc, it loads the model from that path. If it is not a path, it first tries to download a pre-trained SentenceTransformer model. If that fails, tries to construct a model from the Hugging Face Hub with that name.
modules (Iterable[nn.Module], optional) – A list of torch Modules that should be called sequentially, can be used to create custom SentenceTransformer models from scratch.
device (str, optional) – Device (like “cuda”, “cpu”, “mps”, “npu”) that should be used for computation. If None, checks if a GPU can be used.
prompts (Dict[str, str], optional) – A dictionary with prompts for the model. The key is the prompt name, the value is the prompt text. The prompt text will be prepended before any text to encode. For example: {“query”: “query: “, “passage”: “passage: “} or {“clustering”: “Identify the main category based on the titles in “}.
default_prompt_name (str, optional) – The name of the prompt that should be used by default. If not set, no prompt will be applied.
similarity_fn_name (str or SimilarityFunction, optional) – The name of the similarity function to use. Valid options are “cosine”, “dot”, “euclidean”, and “manhattan”. If not set, it is automatically set to “cosine” if similarity or similarity_pairwise are called while model.similarity_fn_name is still None.
cache_folder (str, optional) – Path to store models. Can also be set by the SENTENCE_TRANSFORMERS_HOME environment variable.
trust_remote_code (bool, optional) – Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
revision (str, optional) – The specific model version to use. It can be a branch name, a tag name, or a commit id, for a stored model on Hugging Face.
local_files_only (bool, optional) – Whether or not to only look at local files (i.e., do not try to download the model).
token (bool or str, optional) – Hugging Face authentication token to download private models.
use_auth_token (bool or str, optional) – Deprecated argument. Please use token instead.
truncate_dim (int, optional) – The dimension to truncate sentence embeddings to. None does no truncation. Truncation is only applicable during inference when SentenceTransformer.encode() is called.
model_kwargs (Dict[str, Any], optional) – Additional model configuration parameters to be passed to the Hugging Face Transformers model. Particularly useful options are:
torch_dtype: Override the default torch.dtype and load the model under a specific dtype. The different options are:
1. torch.float16, torch.bfloat16 or torch.float: load in a specified dtype, ignoring the model's config.torch_dtype if one exists. If not specified, the model will be loaded in torch.float (fp32).
2. "auto": a torch_dtype entry in the config.json file of the model will be used if present. If this entry isn't found, the dtype of the first floating-point weight in the checkpoint is used instead. This loads the model with the dtype it was saved in at the end of training; it cannot be used as an indicator of how the model was trained, since a model trained in half precision may still be saved in fp32.
attn_implementation: The attention implementation to use in the model (if relevant). Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. Otherwise, the default is the manual "eager" implementation.
provider: If backend is "onnx", the execution provider to use for inference, for example "CPUExecutionProvider", "CUDAExecutionProvider", etc. See https://onnxruntime.ai/docs/execution-providers/ for all ONNX execution providers.
file_name: If backend is "onnx" or "openvino", the file name to load, useful for loading optimized or quantized ONNX or OpenVINO models.
export: If backend is "onnx" or "openvino", a boolean flag specifying whether this model should be exported to the backend. If not specified, the model is exported only if the model repository or directory does not already contain an exported model.
See the PreTrainedModel.from_pretrained documentation for more details.
tokenizer_kwargs (Dict[str, Any], optional) – Additional tokenizer configuration parameters to be passed to the Hugging Face Transformers tokenizer. See the AutoTokenizer.from_pretrained documentation for more details.
config_kwargs (Dict[str, Any], optional) – Additional model configuration parameters to be passed to the Hugging Face Transformers config. See the AutoConfig.from_pretrained documentation for more details.
model_card_data (SentenceTransformerModelCardData, optional) – A model card data object that contains information about the model. This is used to generate a model card when saving the model. If not set, a default model card data object is created.
backend (str) – The backend to use for inference. Can be one of "torch" (default), "onnx", or "openvino". See https://sbert.net/docs/sentence_transformer/usage/efficiency.html for benchmarking information on the different backends.
Example
from sentence_transformers import SentenceTransformer

# Load a pre-trained SentenceTransformer model
model = SentenceTransformer('all-mpnet-base-v2')

# Encode some texts
sentences = [
    "The weather is lovely today.",
    "It's so sunny outside!",
    "He drove to the stadium.",
]
embeddings = model.encode(sentences)
print(embeddings.shape)
# (3, 768)

# Get the similarity scores between all sentences
similarities = model.similarity(embeddings, embeddings)
print(similarities)
# tensor([[1.0000, 0.6817, 0.0492],
#         [0.6817, 1.0000, 0.0421],
#         [0.0492, 0.0421, 1.0000]])
Initializes internal Module state, shared by both nn.Module and ScriptModule.
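The constructor options above can be combined. Below is a minimal sketch; the prompt texts, truncate_dim value, and device are illustrative choices and not defaults of this model.

from sentence_transformers import SentenceTransformer

# Minimal sketch of the constructor options described above. The prompt texts,
# truncate_dim value, and device are illustrative; all-mpnet-base-v2 was not
# trained with prompts, so treat this purely as an API demonstration.
model = SentenceTransformer(
    "all-mpnet-base-v2",
    prompts={"query": "query: ", "passage": "passage: "},
    default_prompt_name="query",
    similarity_fn_name="cosine",
    truncate_dim=256,
    device="cpu",
)

# The default "query" prompt is prepended automatically during encoding.
embeddings = model.encode(["What is the capital of France?"])
print(embeddings.shape)
# (1, 256) because of truncate_dim=256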
active_adapters() → list[source]¶
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT official documentation: https://huggingface.co/docs/peft
Gets the current active adapters of the model. In case of multi-adapter inference (combining multiple adapters for inference) returns the list of all active adapters so that users can deal with them accordingly.
For previous PEFT versions (that do not support multi-adapter inference), module.active_adapter will return a single string.
add_adapter(*args, **kwargs) → None[source]¶
Adds a fresh new adapter to the current model for training purposes. If no adapter name is passed, a default name is assigned to the adapter to follow the convention of the PEFT library (in PEFT we use "default" as the default adapter name).
Requires peft as a backend to load the adapter weights and the underlying model to be compatible with PEFT.
- Parameters
*args – Positional arguments to pass to the underlying AutoModel add_adapter function. More information can be found in the transformers documentation https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.add_adapter
**kwargs – Keyword arguments to pass to the underlying AutoModel add_adapter function. More information can be found in the transformers documentation https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.add_adapter
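A minimal sketch of adding a LoRA adapter, assuming a recent peft installation and a PEFT-compatible underlying transformer; the LoRA hyperparameters are illustrative only.

from sentence_transformers import SentenceTransformer
from peft import LoraConfig, TaskType

model = SentenceTransformer("all-mpnet-base-v2")

# Illustrative LoRA configuration; tune r, lora_alpha, and lora_dropout for your task.
peft_config = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)
model.add_adapter(peft_config)

# Adapters can be switched off and back on around inference or training.
model.disable_adapters()
model.enable_adapters()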
bfloat16() → T¶
Casts all floating point parameters and buffers to bfloat16 datatype.
Note
This method modifies the module in-place.
- Returns
self
- Return type
Module
compile(*args, **kwargs)¶
Compile this Module's forward using torch.compile().
This Module's __call__ method is compiled and all arguments are passed as-is to torch.compile().
See torch.compile() for details on the arguments for this function.
cpu() → T¶
Moves all model parameters and buffers to the CPU.
Note
This method modifies the module in-place.
- Returns
self
- Return type
Module
cuda(device: Optional[Union[int, torch.device]] = None) → T¶
Moves all model parameters and buffers to the GPU.
This also makes associated parameters and buffers different objects. So it should be called before constructing the optimizer if the module will live on GPU while being optimized.
Note
This method modifies the module in-place.
- Parameters
device (int, optional) – if specified, all parameters will be copied to that device
- Returns
self
- Return type
Module
property device¶
Get torch.device from module, assuming that the whole module has one device. In case there are no PyTorch parameters, fall back to CPU.
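For example (a small sketch; the reported device and index depend on your hardware):

import torch
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-mpnet-base-v2")
print(model.device)
# e.g. device(type='cpu')

# Move the model if a GPU is available; the printed index depends on your setup.
if torch.cuda.is_available():
    model = model.to("cuda")
    print(model.device)
    # e.g. device(type='cuda', index=0)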
disable_adapters() → None[source]¶
Disable all adapters that are attached to the model. This leads to inferring with the base model only.
double() → T¶
Casts all floating point parameters and buffers to double datatype.
Note
This method modifies the module in-place.
- Returns
self
- Return type
Module
enable_adapters() → None[source]¶
Enable adapters that are attached to the model. The model will use self.active_adapter().
encode_multi_process(sentences: list[str], pool: dict, prompt_name: Optional[str] = None, prompt: Optional[str] = None, batch_size: int = 32, chunk_size: Optional[int] = None, show_progress_bar: Optional[bool] = None, precision: Literal['float32', 'int8', 'uint8', 'binary', 'ubinary'] = 'float32', normalize_embeddings: bool = False) → numpy.ndarray[source]¶
Encodes a list of sentences using multiple processes and GPUs via SentenceTransformer.encode. The sentences are chunked into smaller packages and sent to individual processes, which encode them on different GPUs or CPUs. This method is only suitable for encoding large sets of sentences.
- Parameters
sentences (List[str]) – List of sentences to encode.
pool (Dict[Literal["input", "output", "processes"], Any]) – A pool of workers started with SentenceTransformer.start_multi_process_pool.
prompt_name (Optional[str], optional) – The name of the prompt to use for encoding. Must be a key in the prompts dictionary, which is either set in the constructor or loaded from the model configuration. For example, if prompt_name is "query" and prompts is {"query": "query: ", …}, then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?" because the sentence is appended to the prompt. If prompt is also set, this argument is ignored. Defaults to None.
prompt (Optional[str], optional) – The prompt to use for encoding. For example, if the prompt is "query: ", then the sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?" because the sentence is appended to the prompt. If prompt is set, prompt_name is ignored. Defaults to None.
batch_size (int) – Encode sentences with batch size. (default: 32)
chunk_size (int) – Sentences are chunked and sent to the individual processes. If None, it determines a sensible size. Defaults to None.
show_progress_bar (bool, optional) – Whether to output a progress bar when encoding sentences. Defaults to None.
precision (Literal["float32", "int8", "uint8", "binary", "ubinary"]) – The precision to use for the embeddings. Can be “float32”, “int8”, “uint8”, “binary”, or “ubinary”. All non-float32 precisions are quantized embeddings. Quantized embeddings are smaller in size and faster to compute, but may have lower accuracy. They are useful for reducing the size of the embeddings of a corpus for semantic search, among other tasks. Defaults to “float32”.
normalize_embeddings (bool) – Whether to normalize returned vectors to have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used. Defaults to False.
- Returns
A 2D numpy array with shape [num_inputs, output_dimension].
- Return type
np.ndarray
Example
from sentence_transformers import SentenceTransformer

def main():
    model = SentenceTransformer("all-mpnet-base-v2")
    sentences = [
        "The weather is so nice!",
        "It's so sunny outside.",
        "He's driving to the movie theater.",
        "She's going to the cinema.",
    ] * 1000

    pool = model.start_multi_process_pool()
    embeddings = model.encode_multi_process(sentences, pool)
    model.stop_multi_process_pool(pool)

    print(embeddings.shape)
    # => (4000, 768)

if __name__ == "__main__":
    main()
eval() → T¶
Sets the module in evaluation mode.
This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.
This is equivalent with self.train(False).
See Locally disabling gradient computation for a comparison between .eval() and several similar mechanisms that may be confused with it.
- Returns
self
- Return type
Module
evaluate(evaluator: SentenceEvaluator, output_path: str = None) → dict[str, float] | float[source]¶
Evaluate the model based on an evaluator.
- Parameters
evaluator (SentenceEvaluator) – The evaluator used to evaluate the model.
output_path (str, optional) – The path where the evaluator can write the results. Defaults to None.
- Returns
The evaluation results.
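A minimal sketch using EmbeddingSimilarityEvaluator; the sentence pairs and gold scores below are illustrative toy data, and in practice a real dev set (e.g. STS-B) would be used.

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

model = SentenceTransformer("all-mpnet-base-v2")

# Toy dev set: sentence pairs with gold similarity scores in [0, 1].
evaluator = EmbeddingSimilarityEvaluator(
    sentences1=["The cat sits outside", "A man is playing guitar"],
    sentences2=["The cat plays in the garden", "A woman watches TV"],
    scores=[0.8, 0.1],
    name="toy-dev",
)
results = model.evaluate(evaluator)
print(results)
# e.g. a dict of metrics such as Pearson/Spearman correlations of the cosine similarities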
fit(train_objectives: Iterable[tuple[DataLoader, nn.Module]], evaluator: SentenceEvaluator = None, epochs: int = 1, steps_per_epoch=None, scheduler: str = 'WarmupLinear', warmup_steps: int = 10000, optimizer_class: type[Optimizer] = <class 'torch.optim.adamw.AdamW'>, optimizer_params: dict[str, object] = {'lr': 2e-05}, weight_decay: float = 0.01, evaluation_steps: int = 0, output_path: str = None, save_best_model: bool = True, max_grad_norm: float = 1, use_amp: bool = False, callback: Callable[[float, int, int], None] = None, show_progress_bar: bool = True, checkpoint_path: str = None, checkpoint_save_steps: int = 500, checkpoint_save_total_limit: int = 0) → None[source]¶
Deprecated training method from before Sentence Transformers v3.0; it is recommended to use SentenceTransformerTrainer instead. This method uses SentenceTransformerTrainer behind the scenes, but does not provide as much flexibility as the Trainer itself.
This training approach uses a list of DataLoaders and Loss functions to train the model. Each DataLoader is sampled in turn for one batch. We sample only as many batches from each DataLoader as there are in the smallest one to make sure of equal training with each dataset, i.e. round robin sampling.
This method should produce equivalent results in v3.0+ as before v3.0, but if you encounter any issues with your existing training scripts, then you may wish to use SentenceTransformer.old_fit instead. That uses the old training method from before v3.0.
- Parameters
train_objectives – Tuples of (DataLoader, LossFunction). Pass more than one for multi-task learning
evaluator – An evaluator (sentence_transformers.evaluation) evaluates the model performance during training on held-out dev data. It is used to determine the best model that is saved to disc.
epochs – Number of epochs for training
steps_per_epoch – Number of training steps per epoch. If set to None (default), one epoch is equal to the DataLoader size from train_objectives.
scheduler – Learning rate scheduler. Available schedulers: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts
warmup_steps – Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is increased from 0 up to the maximal learning rate. After these many training steps, the learning rate is decreased linearly back to zero.
optimizer_class – Optimizer
optimizer_params – Optimizer parameters
weight_decay – Weight decay for model parameters
evaluation_steps – If > 0, evaluate the model using evaluator after each number of training steps
output_path – Storage path for the model and evaluation files
save_best_model – If true, the best model (according to evaluator) is stored at output_path
max_grad_norm – Used for gradient normalization.
use_amp – Use Automatic Mixed Precision (AMP). Only for Pytorch >= 1.6.0
callback – Callback function that is invoked after each evaluation. It must accept the following three parameters in this order: score, epoch, steps
show_progress_bar – If True, output a tqdm progress bar
checkpoint_path – Folder to save checkpoints during training
checkpoint_save_steps – Will save a checkpoint after so many steps
checkpoint_save_total_limit – Total number of checkpoints to store
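For reference, a minimal sketch of this deprecated API; the model name and toy training pairs are illustrative, and new code should prefer SentenceTransformerTrainer.

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

model = SentenceTransformer("all-mpnet-base-v2")

# Toy (anchor, positive) pairs; replace with your own training data.
train_examples = [
    InputExample(texts=["A person is eating food.", "Someone is having a meal."]),
    InputExample(texts=["A plane is taking off.", "An airplane departs."]),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
train_loss = losses.MultipleNegativesRankingLoss(model)

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=1,
    warmup_steps=10,
    show_progress_bar=True,
)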
float() → T¶
Casts all floating point parameters and buffers to float datatype.
Note
This method modifies the module in-place.
- Returns
self
- Return type
Module
get_adapter_state_dict(*args, **kwargs) → dict[source]¶
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT official documentation: https://huggingface.co/docs/peft
Gets the adapter state dict that should only contain the weights tensors of the specified adapter_name adapter. If no adapter_name is passed, the active adapter is used.
- Parameters
*args – Positional arguments to pass to the underlying AutoModel get_adapter_state_dict function. More information can be found in the transformers documentation https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.get_adapter_state_dict
**kwargs – Keyword arguments to pass to the underlying AutoModel get_adapter_state_dict function. More information can be found in the transformers documentation https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.get_adapter_state_dict
get_backend() → Literal['torch', 'onnx', 'openvino'][source]¶
Return the backend used for inference, which can be one of "torch", "onnx", or "openvino".
- Returns
The backend used for inference.
- Return type
str
get_max_seq_length()[source]¶
Returns the maximal sequence length that the model accepts. Longer inputs will be truncated.
- Returns
The maximal sequence length that the model accepts, or None if it is not defined.
- Return type
Optional[int]
get_sentence_embedding_dimension()[source]¶
Returns the number of dimensions in the output of SentenceTransformer.encode.
- Returns
The number of dimensions in the output of encode. If it’s not known, it’s None.
- Return type
Optional[int]
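For example (the values shown are those of this particular checkpoint; other models will differ):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-mpnet-base-v2")
print(model.get_max_seq_length())
# => 384
print(model.get_sentence_embedding_dimension())
# => 768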
half() → T¶
Casts all floating point parameters and buffers to half datatype.
Note
This method modifies the module in-place.
- Returns
self
- Return type
Module
load_adapter(*args, **kwargs) → None[source]¶
Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT official documentation: https://huggingface.co/docs/peft
Requires peft as a backend to load the adapter weights and the underlying model to be compatible with PEFT.
- Parameters
*args – Positional arguments to pass to the underlying AutoModel load_adapter function. More information can be found in the transformers documentation https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.load_adapter
**kwargs – Keyword arguments to pass to the underlying AutoModel load_adapter function. More information can be found in the transformers documentation https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.load_adapter
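A minimal sketch, assuming peft is installed; the adapter repository name below is a placeholder for a real Hub repository or local directory containing PEFT adapter weights.

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-mpnet-base-v2")

# "username/my-lora-adapter" is a placeholder; point this at a real PEFT adapter
# repository on the Hub or at a local directory with saved adapter weights.
model.load_adapter("username/my-lora-adapter")
embeddings = model.encode(["An example sentence."])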
property max_seq_length¶
Returns the maximal input sequence length for the model. Longer inputs will be truncated.
- Returns
The maximal input sequence length.
- Return type
int
Example
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-mpnet-base-v2")
print(model.max_seq_length)
# => 384
old_fit(train_objectives: Iterable[tuple[DataLoader, nn.Module]], evaluator: SentenceEvaluator = None, epochs: int = 1, steps_per_epoch=None, scheduler: str = 'WarmupLinear', warmup_steps: int = 10000, optimizer_class: type[Optimizer] = <class 'torch.optim.adamw.AdamW'>, optimizer_params: dict[str, object] = {'lr': 2e-05}, weight_decay: float = 0.01, evaluation_steps: int = 0, output_path: str = None, save_best_model: bool = True, max_grad_norm: float = 1, use_amp: bool = False, callback: Callable[[float, int, int], None] = None, show_progress_bar: bool = True, checkpoint_path: str = None, checkpoint_save_steps: int = 500, checkpoint_save_total_limit: int = 0) → None[source]¶
Deprecated training method from before Sentence Transformers v3.0; it is recommended to use sentence_transformers.trainer.SentenceTransformerTrainer instead. This method should only be used if you encounter issues with your existing training scripts after upgrading to v3.0+.
This training approach uses a list of DataLoaders and Loss functions to train the model. Each DataLoader is sampled in turn for one batch. We sample only as many batches from each DataLoader as there are in the smallest one to make sure of equal training with each dataset, i.e. round robin sampling.
- Parameters
train_objectives – Tuples of (DataLoader, LossFunction). Pass more than one for multi-task learning
evaluator – An evaluator (sentence_transformers.evaluation) evaluates the model performance during training on held-out dev data. It is used to determine the best model that is saved to disc.
epochs – Number of epochs for training
steps_per_epoch – Number of training steps per epoch. If set to None (default), one epoch is equal to the DataLoader size from train_objectives.
scheduler – Learning rate scheduler. Available schedulers: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts
warmup_steps – Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is increased from 0 up to the maximal learning rate. After these many training steps, the learning rate is decreased linearly back to zero.
optimizer_class – Optimizer
optimizer_params – Optimizer parameters
weight_decay – Weight decay for model parameters
evaluation_steps – If > 0, evaluate the model using evaluator after each number of training steps
output_path – Storage path for the model and evaluation files
save_best_model – If true, the best model (according to evaluator) is stored at output_path
max_grad_norm – Used for gradient normalization.
use_amp – Use Automatic Mixed Precision (AMP). Only for Pytorch >= 1.6.0
callback – Callback function that is invoked after each evaluation. It must accept the following three parameters in this order: score, epoch, steps
show_progress_bar – If True, output a tqdm progress bar
checkpoint_path – Folder to save checkpoints during training
checkpoint_save_steps – Will save a checkpoint after so many steps
checkpoint_save_total_limit – Total number of checkpoints to store
push_to_hub(repo_id: str, token: Optional[str] = None, private: Optional[bool] = None, safe_serialization: bool = True, commit_message: Optional[str] = None, local_model_path: Optional[str] = None, exist_ok: bool = False, replace_model_card: bool = False, train_datasets: Optional[list] = None, revision: Optional[str] = None, create_pr: bool = False) → str[source]¶
Uploads all elements of this Sentence Transformer to a new HuggingFace Hub repository.
- Parameters
repo_id (str) – Repository name for your model in the Hub, including the user or organization.
token (str, optional) – An authentication token (See https://huggingface.co/settings/token)
private (bool, optional) – Set to true for hosting a private model
safe_serialization (bool, optional) – If true, save the model using safetensors. If false, save the model the traditional PyTorch way
commit_message (str, optional) – Message to commit while pushing.
local_model_path (str, optional) – Path of the model locally. If set, this file path will be uploaded. Otherwise, the current model will be uploaded
exist_ok (bool, optional) – If true, saving to an existing repository is OK. If false, saving only to a new repository is possible
replace_model_card (bool, optional) – If true, replace an existing model card in the hub with the automatically created model card
train_datasets (List[str], optional) – Datasets used to train the model. If set, the datasets will be added to the model card in the Hub.
revision (str, optional) – Branch to push the uploaded files to
create_pr (bool, optional) – If True, create a pull request instead of pushing directly to the main branch
- Returns
The url of the commit of your model in the repository on the Hugging Face Hub.
- Return type
str
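A minimal sketch; the repository ID is a placeholder, and pushing requires a valid Hugging Face authentication token (e.g. from huggingface-cli login or the token argument).

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-mpnet-base-v2")

# "username/my-sentence-model" is a placeholder repository ID.
url = model.push_to_hub(
    "username/my-sentence-model",
    commit_message="Add SentenceTransformer model",
    exist_ok=True,
)
print(url)
# URL of the commit in the repository on the Hugging Face Hub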
save(path: str, model_name: Optional[str] = None, create_model_card: bool = True, train_datasets: Optional[list] = None, safe_serialization: bool = True) → None[source]¶
Saves a model and its configuration files to a directory, so that it can be loaded with SentenceTransformer(path) again.
- Parameters
path (str) – Path on disc where the model will be saved.
model_name (str, optional) – Optional model name.
create_model_card (bool, optional) – If True, create a README.md with basic information about this model.
train_datasets (List[str], optional) – Optional list with the names of the datasets used to train the model.
safe_serialization (bool, optional) – If True, save the model using safetensors. If False, save the model the traditional (but unsafe) PyTorch way.
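For example (the output directory is arbitrary; any local path works):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-mpnet-base-v2")

# Save to a local directory of your choice, then reload from the same path.
model.save("output/my-sentence-model")
reloaded = SentenceTransformer("output/my-sentence-model")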
save_pretrained(path: str, model_name: Optional[str] = None, create_model_card: bool = True, train_datasets: Optional[list] = None, safe_serialization: bool = True) → None[source]¶
Saves a model and its configuration files to a directory, so that it can be loaded with SentenceTransformer(path) again.
- Parameters
path (str) – Path on disc where the model will be saved.
model_name (str, optional) – Optional model name.
create_model_card (bool, optional) – If True, create a README.md with basic information about this model.
train_datasets (List[str], optional) – Optional list with the names of the datasets used to train the model.
safe_serialization (bool, optional) – If True, save the model using safetensors. If False, save the model the traditional (but unsafe) PyTorch way.
set_adapter(*args, **kwargs) → None[source]¶
Sets a specific adapter by forcing the model to use that adapter and disabling the other adapters.
- Parameters
*args – Positional arguments to pass to the underlying AutoModel set_adapter function. More information can be found in the transformers documentation https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.set_adapter
**kwargs – Keyword arguments to pass to the underlying AutoModel set_adapter function. More information can be found in the transformers documentation https://huggingface.co/docs/transformers/main/en/peft#transformers.integrations.PeftAdapterMixin.set_adapter
set_pooling_include_prompt(include_prompt: bool) → None[source]¶
Sets the include_prompt attribute in the pooling layer in the model, if there is one.
This is useful for INSTRUCTOR models, as the prompt should be excluded from the pooling strategy for these models.
- Parameters
include_prompt (bool) – Whether to include the prompt in the pooling layer.
- Returns
None
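For example (a sketch; this call only has an effect if the model contains a Pooling module, and whether excluding the prompt helps depends on the model):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-mpnet-base-v2")

# Exclude prompt tokens from pooling, as recommended for INSTRUCTOR-style models.
model.set_pooling_include_prompt(include_prompt=False)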
property similarity¶
Compute the similarity between two collections of embeddings. The output will be a matrix with the similarity scores between all embeddings from the first parameter and all embeddings from the second parameter. This differs from similarity_pairwise which computes the similarity between each pair of embeddings.
- Parameters
embeddings1 (Union[Tensor, ndarray]) – [num_embeddings_1, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
embeddings2 (Union[Tensor, ndarray]) – [num_embeddings_2, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
- Returns
A [num_embeddings_1, num_embeddings_2]-shaped torch tensor with similarity scores.
- Return type
Tensor
Example
>>> model = SentenceTransformer("all-mpnet-base-v2")
>>> sentences = [
...     "The weather is so nice!",
...     "It's so sunny outside.",
...     "He's driving to the movie theater.",
...     "She's going to the cinema.",
... ]
>>> embeddings = model.encode(sentences, normalize_embeddings=True)
>>> model.similarity(embeddings, embeddings)
tensor([[1.0000, 0.7235, 0.0290, 0.1309],
        [0.7235, 1.0000, 0.0613, 0.1129],
        [0.0290, 0.0613, 1.0000, 0.5027],
        [0.1309, 0.1129, 0.5027, 1.0000]])
>>> model.similarity_fn_name
"cosine"
>>> model.similarity_fn_name = "euclidean"
>>> model.similarity(embeddings, embeddings)
tensor([[-0.0000, -0.7437, -1.3935, -1.3184],
        [-0.7437, -0.0000, -1.3702, -1.3320],
        [-1.3935, -1.3702, -0.0000, -0.9973],
        [-1.3184, -1.3320, -0.9973, -0.0000]])
property similarity_fn_name¶
Return the name of the similarity function used by SentenceTransformer.similarity() and SentenceTransformer.similarity_pairwise().
- Returns
The name of the similarity function. Can be None if not set, in which case it will default to "cosine" when first called.
- Return type
Optional[str]
Example
>>> model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
>>> model.similarity_fn_name
'dot'
property similarity_pairwise¶
Compute the similarity between two collections of embeddings. The output will be a vector with the similarity scores between each pair of embeddings.
- Parameters
embeddings1 (Union[Tensor, ndarray]) – [num_embeddings, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
embeddings2 (Union[Tensor, ndarray]) – [num_embeddings, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
- Returns
A [num_embeddings]-shaped torch tensor with pairwise similarity scores.
- Return type
Tensor
Example
>>> model = SentenceTransformer("all-mpnet-base-v2")
>>> sentences = [
...     "The weather is so nice!",
...     "It's so sunny outside.",
...     "He's driving to the movie theater.",
...     "She's going to the cinema.",
... ]
>>> embeddings = model.encode(sentences, normalize_embeddings=True)
>>> model.similarity_pairwise(embeddings[::2], embeddings[1::2])
tensor([0.7235, 0.5027])
>>> model.similarity_fn_name
"cosine"
>>> model.similarity_fn_name = "euclidean"
>>> model.similarity_pairwise(embeddings[::2], embeddings[1::2])
tensor([-0.7437, -0.9973])
smart_batching_collate(batch: list[InputExample]) → tuple[list[dict[str, Tensor]], Tensor][source]¶
Transforms a batch from a SmartBatchingDataset to a batch of tensors for the model. Here, batch is a list of InputExample instances: [InputExample(…), …]
- Parameters
batch – a batch from a SmartBatchingDataset
- Returns
a batch of tensors for the model
start_multi_process_pool(target_devices: Optional[list] = None) → dict[source]¶
Starts a multi-process pool to process the encoding with several independent processes via SentenceTransformer.encode_multi_process.
This method is recommended if you want to encode on multiple GPUs or CPUs. It is advised to start only one process per GPU. This method works together with encode_multi_process and stop_multi_process_pool.
- Parameters
target_devices (List[str], optional) – PyTorch target devices, e.g. [“cuda:0”, “cuda:1”, …], [“npu:0”, “npu:1”, …], or [“cpu”, “cpu”, “cpu”, “cpu”]. If target_devices is None and CUDA/NPU is available, then all available CUDA/NPU devices will be used. If target_devices is None and CUDA/NPU is not available, then 4 CPU devices will be used.
- Returns
A dictionary with the target processes, an input queue, and an output queue.
- Return type
Dict[str, Any]
static stop_multi_process_pool(pool: dict) → None[source]¶
Stops all processes started with start_multi_process_pool.
- Parameters
pool (Dict[str, object]) – A dictionary containing the input queue, output queue, and process list.
- Returns
None
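A minimal sketch with explicit target devices; the two CUDA devices are assumed purely for illustration, so adjust the list (e.g. to ["cpu"] * 4) for your hardware.

from sentence_transformers import SentenceTransformer

def main():
    model = SentenceTransformer("all-mpnet-base-v2")
    # Two CUDA devices are assumed here for illustration only.
    pool = model.start_multi_process_pool(target_devices=["cuda:0", "cuda:1"])
    embeddings = model.encode_multi_process(["An example sentence."] * 100, pool, batch_size=16)
    model.stop_multi_process_pool(pool)
    print(embeddings.shape)
    # => (100, 768)

if __name__ == "__main__":
    main()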
to(*args, **kwargs)¶
Moves and/or casts the parameters and buffers.
This can be called as
- to(device=None, dtype=None, non_blocking=False)
- to(dtype, non_blocking=False)
- to(tensor, non_blocking=False)
- to(memory_format=torch.channels_last)
Its signature is similar to torch.Tensor.to(), but only accepts floating point or complex dtypes. In addition, this method will only cast the floating point or complex parameters and buffers to dtype (if given). The integral parameters and buffers will be moved to device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.
See below for examples.
Note
This method modifies the module in-place.
- Parameters
device (torch.device) – the desired device of the parameters and buffers in this module
dtype (torch.dtype) – the desired floating point or complex dtype of the parameters and buffers in this module
tensor (torch.Tensor) – Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module
memory_format (torch.memory_format) – the desired memory format for 4D parameters and buffers in this module (keyword only argument)
- Returns
self
- Return type
Module
Examples:
>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)
>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j,  0.2382+0.j],
        [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
tokenize(texts: Union[List[str], List[Dict], List[Tuple[str, str]]])[source]¶
Tokenizes the texts.
- Parameters
texts (Union[List[str], List[Dict], List[Tuple[str, str]]]) – A list of texts to be tokenized.
- Returns
A dictionary of tensors with the tokenized texts. Common keys are "input_ids", "attention_mask", and "token_type_ids".
- Return type
Dict[str, Tensor]
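For example (the exact keys and sequence length depend on the underlying tokenizer):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-mpnet-base-v2")
features = model.tokenize(["The weather is lovely today.", "It's so sunny outside!"])
print(features.keys())
# typically includes 'input_ids' and 'attention_mask'
print(features["input_ids"].shape)
# (2, sequence_length)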
property tokenizer¶
Property to get the tokenizer that is used by this model.
train(mode: bool = True) → T¶
Sets the module in training mode.
This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.
- Parameters
mode (bool) – whether to set training mode (True) or evaluation mode (False). Default: True.
- Returns
self
- Return type
Module
truncate_sentence_embeddings(truncate_dim: Optional[int])[source]¶
In this context, SentenceTransformer.encode outputs sentence embeddings truncated at dimension truncate_dim.
This may be useful when you are using the same model for different applications where different dimensions are needed.
- Parameters
truncate_dim (int, optional) – The dimension to truncate sentence embeddings to. None does no truncation.
Example
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-mpnet-base-v2")

with model.truncate_sentence_embeddings(truncate_dim=16):
    embeddings_truncated = model.encode(["hello there", "hiya"])
assert embeddings_truncated.shape[-1] == 16
SentenceTransformerModelCardData¶
class sentence_transformers.model_card.SentenceTransformerModelCardData[source]¶
A dataclass storing data used in the model card.
- Parameters
language (Optional[Union[str, List[str]]]) – The model language, either a string or a list, e.g. “en” or [“en”, “de”, “nl”]
license (Optional[str]) – The license of the model, e.g. “apache-2.0”, “mit”, or “cc-by-nc-sa-4.0”
model_name (Optional[str]) – The pretty name of the model, e.g. “SentenceTransformer based on microsoft/mpnet-base”.
model_id (Optional[str]) – The model ID when pushing the model to the Hub, e.g. “tomaarsen/sbert-mpnet-base-allnli”.
train_datasets (List[Dict[str, str]]) – A list of the names and/or Hugging Face dataset IDs of the training datasets. e.g. [{“name”: “SNLI”, “id”: “stanfordnlp/snli”}, {“name”: “MultiNLI”, “id”: “nyu-mll/multi_nli”}, {“name”: “STSB”}]
eval_datasets (List[Dict[str, str]]) – A list of the names and/or Hugging Face dataset IDs of the evaluation datasets. e.g. [{“name”: “SNLI”, “id”: “stanfordnlp/snli”}, {“id”: “mteb/stsbenchmark-sts”}]
task_name (str) – The human-readable task the model is trained on, e.g. “semantic textual similarity, semantic search, paraphrase mining, text classification, clustering, and more”.
tags (Optional[List[str]]) – A list of tags for the model, e.g. [“sentence-transformers”, “sentence-similarity”, “feature-extraction”].
Tip
Install codecarbon to automatically track carbon emission usage and include it in your model cards.
Example:
>>> model = SentenceTransformer(
...     "microsoft/mpnet-base",
...     model_card_data=SentenceTransformerModelCardData(
...         model_id="tomaarsen/sbert-mpnet-base-allnli",
...         train_datasets=[{"name": "SNLI", "id": "stanfordnlp/snli"}, {"name": "MultiNLI", "id": "nyu-mll/multi_nli"}],
...         eval_datasets=[{"name": "SNLI", "id": "stanfordnlp/snli"}, {"name": "MultiNLI", "id": "nyu-mll/multi_nli"}],
...         license="apache-2.0",
...         language="en",
...     ),
... )
SimilarityFunction¶
class sentence_transformers.SimilarityFunction(value)[source]¶
Enum class for supported similarity functions. The following functions are supported:
SimilarityFunction.COSINE ("cosine"): Cosine similarity
SimilarityFunction.DOT_PRODUCT ("dot", "dot_product"): Dot product similarity
SimilarityFunction.EUCLIDEAN ("euclidean"): Euclidean distance
SimilarityFunction.MANHATTAN ("manhattan"): Manhattan distance
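For example, the enum (or its string value) can be passed when constructing a model; the model name here is illustrative:

from sentence_transformers import SentenceTransformer, SimilarityFunction

model = SentenceTransformer(
    "all-mpnet-base-v2",
    similarity_fn_name=SimilarityFunction.DOT_PRODUCT,
)
print(model.similarity_fn_name)
# => "dot"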
static possible_values() → list[source]¶
Returns a list of possible values for the SimilarityFunction enum.
- Returns
A list of possible values for the SimilarityFunction enum.
- Return type
list
Example
>>> possible_values = SimilarityFunction.possible_values()
>>> possible_values
['cosine', 'dot', 'euclidean', 'manhattan']
static to_similarity_fn(similarity_function: Union[str, SimilarityFunction])[source]¶
Converts a similarity function name or enum value to the corresponding similarity function.
- Parameters
similarity_function (Union[str, SimilarityFunction]) – The name or enum value of the similarity function.
- Returns
The corresponding similarity function.
- Return type
Callable[[Union[Tensor, ndarray], Union[Tensor, ndarray]], Tensor]
- Raises
ValueError – If the provided function is not supported.
Example
>>> similarity_fn = SimilarityFunction.to_similarity_fn("cosine")
>>> similarity_scores = similarity_fn(embeddings1, embeddings2)
>>> similarity_scores
tensor([[0.3952, 0.0554],
        [0.0992, 0.1570]])
static to_similarity_pairwise_fn(similarity_function: Union[str, SimilarityFunction])[source]¶
Converts a similarity function into a pairwise similarity function.
The pairwise similarity function returns the diagonal vector from the similarity matrix, i.e. it only computes the similarity(a[i], b[i]) for each i in the range of the input tensors, rather than computing the similarity between all pairs of a and b.
- Parameters
similarity_function (Union[str, SimilarityFunction]) – The name or enum value of the similarity function.
- Returns
The pairwise similarity function.
- Return type
Callable[[Union[Tensor, ndarray], Union[Tensor, ndarray]], Tensor]
- Raises
ValueError – If the provided similarity function is not supported.
Example
>>> pairwise_fn = SimilarityFunction.to_similarity_pairwise_fn("cosine")
>>> similarity_scores = pairwise_fn(embeddings1, embeddings2)
>>> similarity_scores
tensor([0.3952, 0.1570])