Models¶
sentence_transformers.models defines different building blocks that can be used to create SentenceTransformer networks from scratch. For more details, see Training Overview.
Main Classes¶
- class sentence_transformers.models.Transformer(model_name_or_path: str, max_seq_length: Optional[int] = None, model_args: Optional[dict] = None, tokenizer_args: Optional[dict] = None, config_args: Optional[dict] = None, cache_dir: Optional[str] = None, do_lower_case: bool = False, tokenizer_name_or_path: Optional[str] = None, backend: str = 'torch')[source]¶

  Hugging Face AutoModel to generate token embeddings. Loads the correct class, e.g. BERT / RoBERTa etc.
- Parameters
model_name_or_path – Hugging Face model name (https://huggingface.co/models)
max_seq_length – Truncate any inputs longer than max_seq_length
model_args – Keyword arguments passed to the Hugging Face Transformers model
tokenizer_args – Keyword arguments passed to the Hugging Face Transformers tokenizer
config_args – Keyword arguments passed to the Hugging Face Transformers config
cache_dir – Cache dir for Hugging Face Transformers to store/load models
do_lower_case – If true, lowercases the input (regardless of whether the model is cased or not)
tokenizer_name_or_path – Name or path of the tokenizer. If None, model_name_or_path is used
backend – Backend used for model inference. Can be torch, onnx, or openvino. Default is torch.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
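For instance, a Transformer module can be instantiated directly from a Hugging Face model id and then combined with other modules; a minimal sketch (the model name below is only illustrative):

    from sentence_transformers import models

    # Load a Hugging Face model as a token-embedding module
    word_embedding_model = models.Transformer('bert-base-uncased', max_seq_length=256)
    print(word_embedding_model.get_word_embedding_dimension())  # 768 for a BERT-base model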
- class sentence_transformers.models.Pooling(word_embedding_dimension: int, pooling_mode: Optional[str] = None, pooling_mode_cls_token: bool = False, pooling_mode_max_tokens: bool = False, pooling_mode_mean_tokens: bool = True, pooling_mode_mean_sqrt_len_tokens: bool = False, pooling_mode_weightedmean_tokens: bool = False, pooling_mode_lasttoken: bool = False, include_prompt: bool = True)[source]¶

  Performs pooling (max or mean) on the token embeddings.
Using pooling, it generates a fixed-sized sentence embedding from a variable-sized sentence. This layer also allows using the CLS token if it is returned by the underlying word embedding model. You can concatenate multiple pooling modes together.
- Parameters
word_embedding_dimension – Dimensions for the word embeddings
pooling_mode – Either “cls”, “lasttoken”, “max”, “mean”, “mean_sqrt_len_tokens”, or “weightedmean”. If set, overwrites the other pooling_mode_* settings
pooling_mode_cls_token – Use the first token (CLS token) as the text representation
pooling_mode_max_tokens – Use max in each dimension over all tokens.
pooling_mode_mean_tokens – Perform mean-pooling
pooling_mode_mean_sqrt_len_tokens – Perform mean-pooling, but divide by sqrt(input_length).
pooling_mode_weightedmean_tokens – Perform (position) weighted mean pooling. See SGPT: GPT Sentence Embeddings for Semantic Search.
pooling_mode_lasttoken – Perform last token pooling. See SGPT: GPT Sentence Embeddings for Semantic Search and Text and Code Embeddings by Contrastive Pre-Training.
include_prompt – If set to false, the prompt tokens are not included in the pooling. This is useful for reproducing work that does not include the prompt tokens in the pooling, like INSTRUCTOR, but is otherwise not recommended.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
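A common pattern is to stack a Pooling module on top of a Transformer module to turn per-token embeddings into one fixed-sized sentence embedding; a minimal sketch (model name illustrative):

    from sentence_transformers import SentenceTransformer, models

    word_embedding_model = models.Transformer('bert-base-uncased', max_seq_length=256)
    # Mean pooling over all token embeddings
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode='mean')

    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
    embeddings = model.encode(['This is a sentence', 'This is another sentence'])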
- class sentence_transformers.models.Dense(in_features: int, out_features: int, bias: bool = True, activation_function=Tanh(), init_weight: Optional[torch.Tensor] = None, init_bias: Optional[torch.Tensor] = None)[source]¶

  Feed-forward function with activation function.
This layer takes a fixed-sized sentence embedding and passes it through a feed-forward layer. Can be used to generate deep averaging networks (DAN).
- Parameters
in_features – Size of the input dimension
out_features – Output size
bias – Add a bias vector
activation_function – Pytorch activation function applied on output
init_weight – Initial value for the matrix of the linear layer
init_bias – Initial value for the bias of the linear layer
Initializes internal Module state, shared by both nn.Module and ScriptModule.
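A Dense module is typically appended after the pooling layer, for example to project sentence embeddings down to a smaller dimension; a minimal sketch (model name and output size are illustrative):

    from torch import nn
    from sentence_transformers import SentenceTransformer, models

    word_embedding_model = models.Transformer('bert-base-uncased')
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
    # Project the pooled sentence embedding down to 256 dimensions with a Tanh activation
    dense_model = models.Dense(in_features=pooling_model.get_sentence_embedding_dimension(),
                               out_features=256, activation_function=nn.Tanh())

    model = SentenceTransformer(modules=[word_embedding_model, pooling_model, dense_model])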
Further Classes¶
- class sentence_transformers.models.Asym(sub_modules: dict, allow_empty_key: bool = True)[source]¶

  This model allows creating asymmetric SentenceTransformer models that apply different models depending on the specified input key.
In the example below, we create two different Dense models for 'query' and 'doc'. Text passed as {'query': 'My query'} will be passed through the first Dense model, and text passed as {'doc': 'My document'} will use the other Dense model.
Note that when you call encode(), only inputs of the same type can be encoded; mixed types cannot be encoded.
- Example::

    word_embedding_model = models.Transformer(model_name)
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
    asym_model = models.Asym({
        'query': [models.Dense(word_embedding_model.get_word_embedding_dimension(), 128)],
        'doc': [models.Dense(word_embedding_model.get_word_embedding_dimension(), 128)],
    })
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model, asym_model])

    model.encode([{'query': 'Q1'}, {'query': 'Q2'}])
    model.encode([{'doc': 'Doc1'}, {'doc': 'Doc2'}])

    # You can train it with InputExample like this. Note that the order must always be the same:
    train_example = InputExample(texts=[{'query': 'Train query'}, {'doc': 'Document'}], label=1)
- Parameters
sub_modules – Dict in the format str -> List[models]. The models in the specified list will be applied for input marked with the respective key.
allow_empty_key – If true, inputs without a key can be processed. If false, an exception will be thrown if no key is specified.
- class sentence_transformers.models.BoW(vocab: list, word_weights: dict = {}, unknown_word_weight: float = 1, cumulative_term_frequency: bool = True)[source]¶

  Implements a Bag-of-Words (BoW) model to derive sentence embeddings.
A weighting can be added to allow the generation of tf-idf vectors. The output vector has the size of the vocab.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
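As a sketch, a BoW module can act as the only module of a SentenceTransformer; the vocabulary below is illustrative:

    from sentence_transformers import SentenceTransformer, models

    vocab = ['the', 'panda', 'is', 'a', 'bear', 'native', 'to', 'china']
    bow_model = models.BoW(vocab=vocab)

    model = SentenceTransformer(modules=[bow_model])
    embeddings = model.encode(['The panda is a bear'])  # one vector of size len(vocab)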
- class sentence_transformers.models.CNN(in_word_embedding_dimension: int, out_channels: int = 256, kernel_sizes: list = [1, 3, 5], stride_sizes: Optional[list] = None)[source]¶

  CNN-layer with multiple kernel-sizes over the word embeddings
Initializes internal Module state, shared by both nn.Module and ScriptModule.
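The CNN module is usually placed between a word-embedding module and the pooling module; a minimal sketch, assuming a local GloVe text file for the WordEmbeddings module (file name illustrative):

    from sentence_transformers import SentenceTransformer, models

    word_embedding_model = models.WordEmbeddings.from_text_file('glove.6B.300d.txt.gz')
    cnn = models.CNN(in_word_embedding_dimension=word_embedding_model.get_word_embedding_dimension(),
                     out_channels=256, kernel_sizes=[1, 3, 5])
    pooling_model = models.Pooling(cnn.get_word_embedding_dimension())

    model = SentenceTransformer(modules=[word_embedding_model, cnn, pooling_model])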
- class sentence_transformers.models.LSTM(word_embedding_dimension: int, hidden_dim: int, num_layers: int = 1, dropout: float = 0, bidirectional: bool = True)[source]¶

  Bidirectional LSTM running over word embeddings.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
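The LSTM module fits into the same position as the CNN module above; a minimal sketch (file name and hidden size illustrative):

    from sentence_transformers import SentenceTransformer, models

    word_embedding_model = models.WordEmbeddings.from_text_file('glove.6B.300d.txt.gz')
    lstm = models.LSTM(word_embedding_dimension=word_embedding_model.get_word_embedding_dimension(),
                       hidden_dim=1024)
    # bidirectional=True by default, so the output dimension is 2 * hidden_dim
    pooling_model = models.Pooling(lstm.get_word_embedding_dimension())

    model = SentenceTransformer(modules=[word_embedding_model, lstm, pooling_model])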
- class sentence_transformers.models.Normalize[source]¶

  This layer normalizes embeddings to unit length
Initializes internal Module state, shared by both nn.Module and ScriptModule.
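Normalize is typically appended as the last module so that the returned embeddings have unit L2 norm (the dot product then equals cosine similarity); a minimal sketch (model name illustrative):

    from sentence_transformers import SentenceTransformer, models

    word_embedding_model = models.Transformer('bert-base-uncased')
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
    normalize = models.Normalize()

    model = SentenceTransformer(modules=[word_embedding_model, pooling_model, normalize])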
- class sentence_transformers.models.StaticEmbedding(tokenizer: Tokenizer | PreTrainedTokenizerFast, embedding_weights: np.array | torch.Tensor | None = None, embedding_dim: int | None = None, **kwargs)[source]¶

  Initializes the StaticEmbedding model given a tokenizer. The model is a simple embedding bag model that takes the mean of trained per-token embeddings to compute text embeddings.
- Parameters
tokenizer (Tokenizer | PreTrainedTokenizerFast) – The tokenizer to be used. Must be a fast tokenizer from transformers or tokenizers.
embedding_weights (np.array | torch.Tensor | None, optional) – Pre-trained embedding weights. Defaults to None.
embedding_dim (int | None, optional) – Dimension of the embeddings. Required if embedding_weights is not provided. Defaults to None.
Example:

    from sentence_transformers import SentenceTransformer
    from sentence_transformers.models import StaticEmbedding
    from tokenizers import Tokenizer

    # Pre-distilled embeddings:
    static_embedding = StaticEmbedding.from_model2vec("minishlab/M2V_base_output")
    # or distill your own embeddings:
    static_embedding = StaticEmbedding.from_distillation("BAAI/bge-base-en-v1.5", device="cuda")
    # or start with randomized embeddings:
    tokenizer = Tokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
    static_embedding = StaticEmbedding(tokenizer, embedding_dim=512)

    model = SentenceTransformer(modules=[static_embedding])

    embeddings = model.encode(["What are Pandas?", "The giant panda (Ailuropoda melanoleuca; Chinese: 大熊猫; pinyin: dàxióngmāo), also known as the panda bear or simply the panda, is a bear native to south central China."])
    similarity = model.similarity(embeddings[0], embeddings[1])
    # tensor([[0.9177]])  (If you use the distilled bge-base)
- Raises
ValueError – If the tokenizer is not a fast tokenizer.
ValueError – If neither embedding_weights nor embedding_dim is provided.
- classmethod from_distillation(model_name: str, vocabulary: list[str] | None = None, device: str | None = None, pca_dims: int | None = 256, apply_zipf: bool = True, use_subword: bool = True)[source]¶

  Creates a StaticEmbedding instance from a distillation process using the model2vec package.
- Parameters
model_name (str) – The name of the model to distill.
vocabulary (list[str] | None, optional) – A list of vocabulary words to use. Defaults to None.
device (str) – The device to run the distillation on (e.g., ‘cpu’, ‘cuda’). If not specified, the strongest device is automatically detected. Defaults to None.
pca_dims (int | None, optional) – The number of dimensions for PCA reduction. Defaults to 256.
apply_zipf (bool) – Whether to apply Zipf’s law during distillation. Defaults to True.
use_subword (bool) – Whether to use subword tokenization. Defaults to True.
- Returns
An instance of StaticEmbedding initialized with the distilled model's tokenizer and embedding weights.
- Return type
StaticEmbedding
- Raises
ImportError – If the model2vec package is not installed.
- classmethod from_model2vec(model_id_or_path: str) → sentence_transformers.models.StaticEmbedding.StaticEmbedding[source]¶

  Create a StaticEmbedding instance from a model2vec model. This method loads a pre-trained model2vec model and extracts the embedding weights and tokenizer to create a StaticEmbedding instance.
- Parameters
model_id_or_path (str) – The identifier or path to the pre-trained model2vec model.
- Returns
An instance of StaticEmbedding initialized with the tokenizer and embedding weights from the model2vec model.
- Return type
StaticEmbedding
- Raises
ImportError – If the model2vec package is not installed.
- class sentence_transformers.models.WeightedLayerPooling(word_embedding_dimension, num_hidden_layers: int = 12, layer_start: int = 4, layer_weights=None)[source]¶

  Token embeddings are a weighted mean of their different hidden layer representations.
Initializes internal Module state, shared by both nn.Module and ScriptModule.
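A possible sketch: the Transformer module needs to output all hidden states (via output_hidden_states) so that WeightedLayerPooling can combine them before the regular pooling step; the model name and layer_start below are illustrative and assume a 12-layer model:

    from sentence_transformers import SentenceTransformer, models

    word_embedding_model = models.Transformer('bert-base-uncased',
                                              model_args={'output_hidden_states': True})
    weighted_pooling = models.WeightedLayerPooling(word_embedding_model.get_word_embedding_dimension(),
                                                   num_hidden_layers=12, layer_start=9)
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())

    model = SentenceTransformer(modules=[word_embedding_model, weighted_pooling, pooling_model])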
- class sentence_transformers.models.WordEmbeddings(tokenizer: sentence_transformers.models.tokenizer.WordTokenizer.WordTokenizer, embedding_weights, update_embeddings: bool = False, max_seq_length: int = 1000000)[source]¶

  Initializes internal Module state, shared by both nn.Module and ScriptModule.
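As a sketch, WordEmbeddings is usually loaded from a text file of pre-trained vectors (e.g. GloVe) and averaged with a Pooling module; the file name below is illustrative:

    from sentence_transformers import SentenceTransformer, models

    word_embedding_model = models.WordEmbeddings.from_text_file('glove.6B.300d.txt.gz')
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())

    # Simple average-word-embedding model
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])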
- class sentence_transformers.models.WordWeights(vocab: list, word_weights: dict, unknown_word_weight: float = 1)[source]¶

  This model can weight word embeddings, for example, with idf-values.
Initializes the WordWeights class.
- Parameters
vocab (List[str]) – Vocabulary of the tokenizer.
word_weights (Dict[str, float]) – Mapping of tokens to a float weight value. Word embeddings are multiplied by this float value. The tokens in word_weights do not need to match the vocab exactly (it can contain more or fewer entries).
unknown_word_weight (float, optional) – Weight for words in vocab that do not appear in the word_weights lookup. These can be, for example, rare words in the vocab where no weight exists. Defaults to 1.
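As a sketch, WordWeights sits between the word-embedding module and the pooling module. The vocab should be the vocabulary of the preceding module's tokenizer, and the weights would typically be idf values computed on your own corpus; the get_vocab() call and the idf values below are assumptions for illustration:

    from sentence_transformers import SentenceTransformer, models

    word_embedding_model = models.WordEmbeddings.from_text_file('glove.6B.300d.txt.gz')

    vocab = word_embedding_model.tokenizer.get_vocab()   # assumed helper returning the tokenizer vocabulary
    idf_values = {'panda': 3.4, 'bear': 2.1}             # illustrative idf weights from your corpus
    word_weights = models.WordWeights(vocab=vocab, word_weights=idf_values, unknown_word_weight=1.0)

    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
    model = SentenceTransformer(modules=[word_embedding_model, word_weights, pooling_model])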