Models

sentence_transformers.models defines different building blocks that can be used to create SentenceTransformer networks from scratch. For more details, see Training Overview.

Main Classes

class sentence_transformers.models.Transformer(model_name_or_path: str, max_seq_length: int | None = None, model_args: dict[str, Any] | None = None, tokenizer_args: dict[str, Any] | None = None, config_args: dict[str, Any] | None = None, cache_dir: str | None = None, do_lower_case: bool = False, tokenizer_name_or_path: str = None, backend: str = 'torch')

Hugging Face AutoModel to generate token embeddings. Loads the correct model class automatically, e.g. BERT or RoBERTa.

Parameters
  • model_name_or_path – Hugging Face model name (https://huggingface.co/models) or path to a local model

  • max_seq_length – Truncate any inputs longer than max_seq_length

  • model_args – Keyword arguments passed to the Hugging Face Transformers model

  • tokenizer_args – Keyword arguments passed to the Hugging Face Transformers tokenizer

  • config_args – Keyword arguments passed to the Hugging Face Transformers config

  • cache_dir – Cache dir for Hugging Face Transformers to store/load models

  • do_lower_case – If true, lowercases the input (regardless of whether the model is cased)

  • tokenizer_name_or_path – Name or path of the tokenizer. If None, model_name_or_path is used

  • backend – Backend used for model inference. Can be torch, onnx, or openvino. Default is torch.

class sentence_transformers.models.Pooling(word_embedding_dimension: int, pooling_mode: Optional[str] = None, pooling_mode_cls_token: bool = False, pooling_mode_max_tokens: bool = False, pooling_mode_mean_tokens: bool = True, pooling_mode_mean_sqrt_len_tokens: bool = False, pooling_mode_weightedmean_tokens: bool = False, pooling_mode_lasttoken: bool = False, include_prompt: bool = True)

Performs pooling (e.g. max or mean) on the token embeddings.

Pooling turns a variable-length sequence of token embeddings into a fixed-size sentence embedding. This layer can also use the CLS token if the underlying word embedding model returns it. You can enable multiple pooling modes at once; their outputs are concatenated.

Parameters
  • word_embedding_dimension – Dimension of the word embeddings

  • pooling_mode – Either “cls”, “lasttoken”, “max”, “mean”, “mean_sqrt_len_tokens”, or “weightedmean”. If set, overrides the other pooling_mode_* settings

  • pooling_mode_cls_token – Use the first token (CLS token) as the text representation

  • pooling_mode_max_tokens – Use max in each dimension over all tokens.

  • pooling_mode_mean_tokens – Perform mean-pooling

  • pooling_mode_mean_sqrt_len_tokens – Perform mean-pooling, but divide by sqrt(input_length).

  • pooling_mode_weightedmean_tokens – Perform (position) weighted mean pooling. See SGPT: GPT Sentence Embeddings for Semantic Search.

  • pooling_mode_lasttoken – Perform last token pooling. See SGPT: GPT Sentence Embeddings for Semantic Search and Text and Code Embeddings by Contrastive Pre-Training.
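The pooling modes listed above can be illustrated with a small pure-Python sketch. It is not the library's implementation (which works on batched PyTorch tensors); it assumes a single sentence given as a list of token vectors plus a 0/1 attention mask with right padding, and the helpers `pool` and `pool_concat` are hypothetical names.

```python
import math

Vector = list[float]


def pool(tokens: list[Vector], mask: list[int], mode: str) -> Vector:
    """Reduce a variable number of token vectors to one fixed-size vector."""
    # Only tokens with attention-mask value 1 take part (padding is ignored).
    valid = [t for t, m in zip(tokens, mask) if m == 1]
    if not valid:
        raise ValueError("at least one token must be unmasked")
    dim = len(tokens[0])
    if mode == "cls":
        # First token, e.g. BERT's [CLS] token.
        return list(tokens[0])
    if mode == "lasttoken":
        # Last non-padding token; assumes right padding.
        return list(valid[-1])
    if mode == "max":
        return [max(t[d] for t in valid) for d in range(dim)]
    sums = [sum(t[d] for t in valid) for d in range(dim)]
    if mode == "mean":
        return [s / len(valid) for s in sums]
    if mode == "mean_sqrt_len_tokens":
        # Sum divided by sqrt(number of tokens) instead of the token count.
        return [s / math.sqrt(len(valid)) for s in sums]
    if mode == "weightedmean":
        # Position weights 1, 2, ..., n: later tokens count more.
        weights = [i + 1 for i, m in enumerate(mask) if m == 1]
        weighted = [sum(w * t[d] for w, t in zip(weights, valid)) for d in range(dim)]
        return [x / sum(weights) for x in weighted]
    raise ValueError(f"unknown pooling mode: {mode}")


def pool_concat(tokens: list[Vector], mask: list[int], modes: list[str]) -> Vector:
    """Concatenate the outputs of several pooling modes."""
    out: Vector = []
    for mode in modes:
        out.extend(pool(tokens, mask, mode))
    return out
```

For example, with two real tokens `[1, 2]` and `[3, 4]` plus one padding token, "mean" gives `[2.0, 3.0]`, and combining "mean" with "max" gives a vector twice as wide, which is why the output dimension grows when several pooling flags are enabled.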

class sentence_transformers.models.Dense(in_features: int, out_features: int, bias: bool = True, activation_function=Tanh(), init_weight: Optional[torch.Tensor] = None, init_bias: Optional[torch.Tensor] = None)

Feed-forward layer with an activation function.

This layer takes a fixed-sized sentence embedding and passes it through a feed-forward layer. Can be used to generate deep averaging networks (DAN).

Parameters
  • in_features – Size of the input dimension

  • out_features – Size of the output dimension

  • bias – Whether to add a bias vector

  • activation_function – PyTorch activation function applied to the output

  • init_weight – Initial value for the matrix of the linear layer

  • init_bias – Initial value for the bias of the linear layer
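A minimal sketch of what the layer computes for one sentence embedding, assuming it behaves like a standard linear layer followed by the activation, i.e. activation(W·x + b) with W laid out as (out_features, in_features) like torch.nn.Linear. Plain lists stand in for tensors, and `dense_forward` is a hypothetical name; the default `math.tanh` mirrors the Tanh() default in the signature.

```python
import math
from typing import Callable, Optional


def dense_forward(
    x: list[float],
    weight: list[list[float]],
    bias: Optional[list[float]] = None,
    activation: Callable[[float], float] = math.tanh,
) -> list[float]:
    """Compute activation(W @ x + b) for a single embedding.

    `weight` has shape (out_features, in_features); bias=None plays the
    role of bias=False (no bias term is added).
    """
    if any(len(row) != len(x) for row in weight):
        raise ValueError("each weight row must have in_features entries")
    if bias is None:
        bias = [0.0] * len(weight)
    return [
        activation(sum(w * xi for w, xi in zip(row, x)) + b)
        for row, b in zip(weight, bias)
    ]
```

Stacking several such layers on top of an averaged embedding is the idea behind the deep averaging networks (DAN) mentioned above.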

Further Classes

class sentence_transformers.models.Asym(sub_modules: dict[str, list[nn.Module]], allow_empty_key: bool = True)

This model allows you to create asymmetric SentenceTransformer models that apply different modules depending on the specified input key.

In the example below, we create two different Dense models for ‘query’ and ‘doc’. Text passed as {‘query’: ‘My query’} is processed by the first Dense model, and text passed as {‘doc’: ‘My document’} uses the other Dense model.

Note that when you call encode(), only inputs of the same type can be encoded together. Mixed types cannot be encoded.

Example:

from sentence_transformers import InputExample, SentenceTransformer, models

model_name = "bert-base-uncased"  # placeholder: any Hugging Face model name works here
word_embedding_model = models.Transformer(model_name)
dim = word_embedding_model.get_word_embedding_dimension()
pooling_model = models.Pooling(dim)
asym_model = models.Asym({"query": [models.Dense(dim, 128)], "doc": [models.Dense(dim, 128)]})
model = SentenceTransformer(modules=[word_embedding_model, pooling_model, asym_model])

model.encode([{"query": "Q1"}, {"query": "Q2"}])
model.encode([{"doc": "Doc1"}, {"doc": "Doc2"}])

# You can train it with InputExample like this. Note that the order must always be the same:
train_example = InputExample(texts=[{"query": "Train query"}, {"doc": "Document"}], label=1)

Parameters
  • sub_modules – Dict in the format str -> List[nn.Module]. The modules in each list are applied to inputs marked with the respective key.

  • allow_empty_key – If true, inputs without a key can be processed. If false, an exception is raised if no key is specified.
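The routing idea behind Asym can be sketched with plain strings instead of tensor features. This is a simplified, hypothetical illustration, not the module's actual code: `route_batch` and the string-to-string "steps" are stand-ins for the per-key module lists, and letting unkeyed inputs pass through unchanged is an assumption made only for this sketch.

```python
from typing import Callable, Optional, Union

Step = Callable[[str], str]          # hypothetical stand-in for one sub-module
Input = Union[str, dict[str, str]]   # {"query": "..."} or a bare string


def route_batch(
    batch: list[Input],
    sub_modules: dict[str, list[Step]],
    allow_empty_key: bool = True,
) -> list[str]:
    """Apply the steps registered for the batch's key to every input."""
    if not batch:
        return []
    keys: set[Optional[str]] = set()
    texts: list[str] = []
    for item in batch:
        if isinstance(item, dict):
            # Each dict input must have exactly one key, e.g. {"query": "Q1"}.
            ((key, text),) = item.items()
            if key not in sub_modules:
                raise ValueError(f"no sub-modules registered for key {key!r}")
        elif allow_empty_key:
            key, text = None, item
        else:
            raise ValueError("input did not specify a key")
        keys.add(key)
        texts.append(text)
    # All inputs in one call must share a key: mixed types are rejected.
    if len(keys) > 1:
        raise ValueError("mixed input types cannot be encoded in one call")
    (key,) = keys
    # In this sketch, unkeyed inputs pass through unchanged (an assumption).
    steps = sub_modules[key] if key is not None else []
    for step in steps:
        texts = [step(t) for t in texts]
    return texts
```

The single-key check is what makes a batch such as `[{"query": "Q1"}, {"doc": "Doc1"}]` fail, matching the note that mixed types cannot be encoded together.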

class sentence_transformers.models.BoW(vocab: list[str], word_weights: dict[str, float] = {}, unknown_word_weight: float = 1, cumulative_term_frequency: bool = True)

Implements a Bag-of-Words (BoW) model to derive sentence embeddings.

A weighting can be added to allow the generation of tf-idf vectors. The output vector has the size of the vocab.
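A sketch of how such a weighted bag-of-words vector can be built, with one entry per vocab word. Assumptions for this illustration: the text is already split into words, words outside the vocab are ignored, vocab words missing from word_weights fall back to unknown_word_weight, and cumulative_term_frequency=True adds the weight once per occurrence while False only marks presence. `bow_vector` is a hypothetical helper, not part of the library.

```python
def bow_vector(
    tokens: list[str],
    vocab: list[str],
    word_weights: dict[str, float],
    unknown_word_weight: float = 1.0,
    cumulative_term_frequency: bool = True,
) -> list[float]:
    """Return a vector with one entry per vocab word."""
    index = {word: i for i, word in enumerate(vocab)}
    # Vocab words without an explicit weight use unknown_word_weight.
    weights = [word_weights.get(word, unknown_word_weight) for word in vocab]
    vector = [0.0] * len(vocab)
    for token in tokens:
        i = index.get(token)
        if i is None:
            continue  # tokens outside the vocab are ignored in this sketch
        if cumulative_term_frequency:
            # Term frequency times weight (tf-idf when the weights are idf values).
            vector[i] += weights[i]
        else:
            # Presence only: the weight is set once, no matter how often the word occurs.
            vector[i] = weights[i]
    return vector
```

With idf values as word_weights, the cumulative variant yields tf-idf vectors as described above.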

class sentence_transformers.models.CNN(in_word_embedding_dimension: int, out_channels: int = 256, kernel_sizes: list[int] = [1, 3, 5], stride_sizes: list[int] = None)

CNN layer with multiple kernel sizes over the word embeddings.

class sentence_transformers.models.LSTM(word_embedding_dimension: int, hidden_dim: int, num_layers: int = 1, dropout: float = 0, bidirectional: bool = True)

Bidirectional LSTM running over word embeddings.

class sentence_transformers.models.Normalize

This layer normalizes embeddings to unit length.
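A sketch of the operation on a single vector, assuming L2 (Euclidean) normalization with a small epsilon guard in the style of torch.nn.functional.normalize; `l2_normalize` and `dot` are hypothetical helper names. A practical consequence: once embeddings have unit length, their dot product equals their cosine similarity.

```python
import math


def l2_normalize(vector: list[float], eps: float = 1e-12) -> list[float]:
    """Scale a vector to unit Euclidean (L2) length."""
    norm = math.sqrt(sum(x * x for x in vector))
    # The eps floor avoids division by zero for an all-zero vector.
    return [x / max(norm, eps) for x in vector]


def dot(a: list[float], b: list[float]) -> float:
    """Plain dot product; equals cosine similarity for unit-length inputs."""
    return sum(x * y for x, y in zip(a, b))
```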

class sentence_transformers.models.StaticEmbedding(tokenizer: Tokenizer | PreTrainedTokenizerFast, embedding_weights: np.array | torch.Tensor | None = None, embedding_dim: int | None = None, **kwargs)

Initializes the StaticEmbedding model given a tokenizer. The model is a simple embedding bag model that takes the mean of trained per-token embeddings to compute text embeddings.

Parameters
  • tokenizer (Tokenizer | PreTrainedTokenizerFast) – The tokenizer to be used. Must be a fast tokenizer from transformers or tokenizers.

  • embedding_weights (np.array | torch.Tensor | None, optional) – Pre-trained embedding weights. Defaults to None.

  • embedding_dim (int | None, optional) – Dimension of the embeddings. Required if embedding_weights is not provided. Defaults to None.

Example:

from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding
from tokenizers import Tokenizer

# Pre-distilled embeddings:
static_embedding = StaticEmbedding.from_model2vec("minishlab/M2V_base_output")
# or distill your own embeddings:
static_embedding = StaticEmbedding.from_distillation("BAAI/bge-base-en-v1.5", device="cuda")
# or start with randomized embeddings:
tokenizer = Tokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
static_embedding = StaticEmbedding(tokenizer, embedding_dim=512)

model = SentenceTransformer(modules=[static_embedding])

embeddings = model.encode(["What are Pandas?", "The giant panda (Ailuropoda melanoleuca; Chinese: 大熊猫; pinyin: dàxióngmāo), also known as the panda bear or simply the panda, is a bear native to south central China."])
similarity = model.similarity(embeddings[0], embeddings[1])
# tensor([[0.9177]]) (If you use the distilled bge-base)

Raises
  • ValueError – If the tokenizer is not a fast tokenizer.

  • ValueError – If neither embedding_weights nor embedding_dim is provided.
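The core computation described for StaticEmbedding (the mean of trained per-token embeddings) can be sketched in plain Python. Assumptions: the tokenizer has already produced integer token ids, embedding_weights is a vocab_size × embedding_dim table, repeated tokens count once per occurrence, and an empty input yields a zero vector, as torch.nn.EmbeddingBag does in mean mode. `static_embed` is a hypothetical helper, not part of the library.

```python
def static_embed(token_ids: list[int], embedding_weights: list[list[float]]) -> list[float]:
    """Mean of the embedding rows selected by token_ids."""
    dim = len(embedding_weights[0])
    if not token_ids:
        return [0.0] * dim  # empty input: zero vector in this sketch
    # Look up one trained vector per token; no contextualization happens here.
    rows = [embedding_weights[i] for i in token_ids]
    return [sum(row[d] for row in rows) / len(rows) for d in range(dim)]
```

Because each token maps to a fixed vector, encoding is a table lookup plus an average, which is why such models are fast but context-independent.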

classmethod from_distillation(model_name: str, vocabulary: list[str] | None = None, device: str | None = None, pca_dims: int | None = 256, apply_zipf: bool = True, use_subword: bool = True) → StaticEmbedding

Creates a StaticEmbedding instance from a distillation process using the model2vec package.

Parameters
  • model_name (str) – The name of the model to distill.

  • vocabulary (list[str] | None, optional) – A list of vocabulary words to use. Defaults to None.

  • device (str | None, optional) – The device to run the distillation on (e.g., ‘cpu’, ‘cuda’). If not specified, the strongest available device is detected automatically. Defaults to None.

  • pca_dims (int | None, optional) – The number of dimensions for PCA reduction. Defaults to 256.

  • apply_zipf (bool) – Whether to apply Zipf’s law during distillation. Defaults to True.

  • use_subword (bool) – Whether to use subword tokenization. Defaults to True.

Returns

An instance of StaticEmbedding initialized with the distilled model’s tokenizer and embedding weights.

Return type

StaticEmbedding

Raises

ImportError – If the model2vec package is not installed.

classmethod from_model2vec(model_id_or_path: str) → sentence_transformers.models.StaticEmbedding.StaticEmbedding

Create a StaticEmbedding instance from a model2vec model. This method loads a pre-trained model2vec model and extracts the embedding weights and tokenizer to create a StaticEmbedding instance.

Parameters

model_id_or_path (str) – The identifier or path to the pre-trained model2vec model.

Returns

An instance of StaticEmbedding initialized with the tokenizer and embedding weights of the model2vec model.

Return type

StaticEmbedding

Raises

ImportError – If the model2vec package is not installed.

class sentence_transformers.models.WeightedLayerPooling(word_embedding_dimension, num_hidden_layers: int = 12, layer_start: int = 4, layer_weights=None)

Computes token embeddings as a weighted mean of their different hidden-layer representations.
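A minimal sketch of a per-token weighted mean over layers. Assumptions for this illustration: layer_embeddings holds the embedding output followed by each hidden layer, only layers from layer_start onward are used, and there is one weight per used layer; the library's exact indexing may differ, and `weighted_layer_pooling` is a hypothetical name.

```python
def weighted_layer_pooling(
    layer_embeddings: list[list[list[float]]],
    layer_start: int,
    layer_weights: list[float],
) -> list[list[float]]:
    """Weighted mean over layers, computed separately for each token.

    layer_embeddings[l][t] is the vector of token t at layer l, where layer 0
    is assumed to be the embedding output and layers 1..N the hidden layers.
    """
    selected = layer_embeddings[layer_start:]
    if len(selected) != len(layer_weights):
        raise ValueError("need one weight per layer from layer_start onward")
    total = sum(layer_weights)
    num_tokens = len(selected[0])
    dim = len(selected[0][0])
    return [
        [
            sum(w * layer[t][d] for w, layer in zip(layer_weights, selected)) / total
            for d in range(dim)
        ]
        for t in range(num_tokens)
    ]
```

With equal weights this reduces to a plain average of the selected layers; learned weights let the model emphasize the layers that help most.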

class sentence_transformers.models.WordEmbeddings(tokenizer: sentence_transformers.models.tokenizer.WordTokenizer.WordTokenizer, embedding_weights, update_embeddings: bool = False, max_seq_length: int = 1000000)

class sentence_transformers.models.WordWeights(vocab: list[str], word_weights: dict[str, float], unknown_word_weight: float = 1)

This model can weight word embeddings, for example, with idf-values.

Initializes the WordWeights class.

Parameters
  • vocab (List[str]) – Vocabulary of the tokenizer.

  • word_weights (Dict[str, float]) – Mapping of tokens to a float weight value. Word embeddings are multiplied by this float value. The tokens in word_weights need not match the vocab exactly (it can contain more or fewer entries).

  • unknown_word_weight (float, optional) – Weight for words in vocab that do not appear in the word_weights lookup. These can be, for example, rare words in the vocab where no weight exists. Defaults to 1.
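A sketch of the weighting step, assuming each input id indexes into vocab and each token embedding is scaled by its word's weight, with vocab words missing from word_weights falling back to unknown_word_weight. `apply_word_weights` is a hypothetical helper that also returns the per-token weights so that a later pooling step could normalize by them.

```python
def apply_word_weights(
    input_ids: list[int],
    token_embeddings: list[list[float]],
    vocab: list[str],
    word_weights: dict[str, float],
    unknown_word_weight: float = 1.0,
) -> tuple[list[list[float]], list[float]]:
    """Scale each token embedding by the weight of its vocab word."""
    # One weight per vocab entry; extra keys in word_weights are simply unused.
    weight_table = [word_weights.get(word, unknown_word_weight) for word in vocab]
    token_weights = [weight_table[i] for i in input_ids]
    scaled = [[w * x for x in emb] for w, emb in zip(token_weights, token_embeddings)]
    return scaled, token_weights
```

Summing the scaled embeddings and dividing by the summed weights gives a weighted average, for example an idf-weighted mean when word_weights holds idf values.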