Models

sentence_transformers.models defines different building blocks that can be used to create SentenceTransformer networks from scratch. For more details, see Training Overview.

Main Classes

class sentence_transformers.models.Transformer(model_name_or_path: str, max_seq_length: Optional[int] = None, model_args: Dict = {}, cache_dir: Optional[str] = None, tokenizer_args: Dict = {}, do_lower_case: bool = False, tokenizer_name_or_path: Optional[str] = None)

Huggingface AutoModel used to generate token embeddings. It loads the correct model class (e.g. BERT, RoBERTa, etc.).

Parameters
  • model_name_or_path – Huggingface model name (https://huggingface.co/models) or path to a local model

  • max_seq_length – Truncate any inputs longer than max_seq_length

  • model_args – Arguments (key, value pairs) passed to the Huggingface Transformers model

  • cache_dir – Cache dir for Huggingface Transformers to store/load models

  • tokenizer_args – Arguments (key, value pairs) passed to the Huggingface Tokenizer model

  • do_lower_case – If true, lowercases the input (regardless of whether the model is cased or not)

  • tokenizer_name_or_path – Name or path of the tokenizer. If None, model_name_or_path is used
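
The plain-Python sketch below shows how do_lower_case, max_seq_length and the tokenizer-name fallback interact. It is a simplified, hypothetical illustration. preprocess and resolve_tokenizer_name are not part of sentence-transformers. Whitespace splitting stands in for the real Huggingface tokenizer, which works on subword tokens.

```python
from typing import List, Optional


def resolve_tokenizer_name(model_name_or_path: str, tokenizer_name_or_path: Optional[str] = None) -> str:
    # Hypothetical helper: fall back to the model name when no tokenizer name is given
    return tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path


def preprocess(text: str, max_seq_length: Optional[int] = None, do_lower_case: bool = False) -> List[str]:
    # Hypothetical helper: lowercase first, independent of whether the model is cased
    if do_lower_case:
        text = text.lower()
    # Whitespace split as a stand-in for the real (subword) tokenizer
    tokens = text.split()
    # Truncate everything beyond max_seq_length
    if max_seq_length is not None:
        tokens = tokens[:max_seq_length]
    return tokens


print(preprocess("The Quick Brown Fox", max_seq_length=3, do_lower_case=True))  # ['the', 'quick', 'brown']
print(resolve_tokenizer_name("bert-base-uncased"))  # bert-base-uncased
```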

class sentence_transformers.models.Pooling(word_embedding_dimension: int, pooling_mode: Optional[str] = None, pooling_mode_cls_token: bool = False, pooling_mode_max_tokens: bool = False, pooling_mode_mean_tokens: bool = True, pooling_mode_mean_sqrt_len_tokens: bool = False, pooling_mode_weightedmean_tokens: bool = False, pooling_mode_lasttoken: bool = False, include_prompt=True)

Performs pooling (e.g. max or mean) on the token embeddings.

Pooling turns the variable number of token embeddings of a sentence into a fixed-size sentence embedding. This layer also allows using the CLS token if the underlying word embedding model returns it. Multiple pooling modes can be enabled at once, in which case their outputs are concatenated.

Parameters
  • word_embedding_dimension – Dimension of the word embeddings

  • pooling_mode – Either “cls”, “lasttoken”, “max”, “mean”, “mean_sqrt_len_tokens”, or “weightedmean”. If set, overrides the other pooling_mode_* settings

  • pooling_mode_cls_token – Use the first token (CLS token) as the text representation

  • pooling_mode_max_tokens – Use max in each dimension over all tokens.

  • pooling_mode_mean_tokens – Perform mean-pooling

  • pooling_mode_mean_sqrt_len_tokens – Perform mean-pooling, but divide by sqrt(input_length).

  • pooling_mode_weightedmean_tokens – Perform (position) weighted mean pooling. See SGPT: GPT Sentence Embeddings for Semantic Search.

  • pooling_mode_lasttoken – Perform last token pooling. See SGPT: GPT Sentence Embeddings for Semantic Search and Text and Code Embeddings by Contrastive Pre-Training.
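
To make the pooling modes concrete, the following plain-Python sketch applies them to a single sentence. It is an illustration written for this page, not the library's implementation, which works on batched PyTorch tensors. Several details are assumptions of the sketch: the function name pool, the list-of-lists inputs, and right padding for last-token pooling. The mode strings mirror the pooling_mode values listed above.

```python
import math
from typing import List, Sequence


def pool(token_embeddings: List[List[float]], attention_mask: List[int], modes: Sequence[str]) -> List[float]:
    """Hypothetical sketch of the pooling modes for one sentence.

    token_embeddings: one vector per token (padding included)
    attention_mask: 1 for real tokens, 0 for padding
    modes: pooling modes to apply; their outputs are concatenated in this order
    """
    dim = len(token_embeddings[0])
    real = [emb for emb, m in zip(token_embeddings, attention_mask) if m == 1]
    n = len(real)
    output: List[float] = []
    for mode in modes:
        if mode == "cls":
            vec = list(token_embeddings[0])  # first token
        elif mode == "max":
            vec = [max(e[d] for e in real) for d in range(dim)]
        elif mode == "mean":
            vec = [sum(e[d] for e in real) / n for d in range(dim)]
        elif mode == "mean_sqrt_len_tokens":
            vec = [sum(e[d] for e in real) / math.sqrt(n) for d in range(dim)]
        elif mode == "weightedmean":
            # Position-weighted mean: token i (1-based) gets weight i, padding gets 0
            weights = [i + 1 if m == 1 else 0 for i, m in enumerate(attention_mask)]
            total = sum(weights)
            vec = [sum(w * e[d] for w, e in zip(weights, token_embeddings)) / total for d in range(dim)]
        elif mode == "lasttoken":
            # Last non-padding token (assumes right padding)
            last = max(i for i, m in enumerate(attention_mask) if m == 1)
            vec = list(token_embeddings[last])
        else:
            raise ValueError(f"unknown pooling mode: {mode}")
        output.extend(vec)
    return output


emb = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]  # last row is padding
mask = [1, 1, 0]
print(pool(emb, mask, ["cls", "mean"]))  # [1.0, 2.0, 2.0, 3.0]
```

With k modes enabled, the sentence embedding therefore has k × word_embedding_dimension entries.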

class sentence_transformers.models.Dense(in_features: int, out_features: int, bias: bool = True, activation_function=Tanh(), init_weight: Optional[torch.Tensor] = None, init_bias: Optional[torch.Tensor] = None)

Feed-forward layer with an activation function.

This layer takes a fixed-sized sentence embedding and passes it through a feed-forward layer. Can be used to generate deep averaging networks (DAN).

Parameters
  • in_features – Size of the input dimension

  • out_features – Output size

  • bias – Add a bias vector

  • activation_function – PyTorch activation function applied to the output

  • init_weight – Initial value for the matrix of the linear layer

  • init_bias – Initial value for the bias of the linear layer
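
A Dense layer computes activation(W·x + b) on a sentence embedding. The sketch below spells this out in plain Python for a single embedding. Three things are assumptions of the sketch. dense is a hypothetical helper. The weights are laid out as out_features rows of in_features values, the same layout as torch.nn.Linear.weight. math.tanh plays the role of the default Tanh() activation.

```python
import math
from typing import Callable, List, Optional


def dense(
    x: List[float],
    weight: List[List[float]],
    bias: Optional[List[float]] = None,
    activation: Callable[[float], float] = math.tanh,
) -> List[float]:
    """Hypothetical sketch of Dense: activation(weight @ x + bias)."""
    out = []
    for j, row in enumerate(weight):  # one row per output feature
        z = sum(w * xi for w, xi in zip(row, x))
        if bias is not None:
            z += bias[j]
        out.append(activation(z))
    return out


# in_features=2, out_features=2, identity activation for readable numbers
print(dense([1.0, 2.0], [[1.0, 0.0], [1.0, 1.0]], [0.0, 1.0], activation=lambda z: z))  # [1.0, 4.0]
```

Stacking several such layers on top of averaged word embeddings yields a deep averaging network.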

Further Classes

class sentence_transformers.models.Asym(sub_modules: Dict[str, List[torch.nn.modules.module.Module]], allow_empty_key: bool = True)

This model allows creating asymmetric SentenceTransformer models that apply different modules depending on the specified input key.

In the example below, we create two different Dense models for "query" and "doc". Text passed as {"query": "My query"} goes through the first Dense model, and text passed as {"doc": "My document"} uses the other Dense model.

Note that when you call encode(), all inputs in one call must be of the same type; mixed types cannot be encoded.

Example::

    from sentence_transformers import InputExample, SentenceTransformer, models

    model_name = "bert-base-uncased"  # example checkpoint; any Huggingface model name works
    word_embedding_model = models.Transformer(model_name)
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
    asym_model = models.Asym({
        "query": [models.Dense(word_embedding_model.get_word_embedding_dimension(), 128)],
        "doc": [models.Dense(word_embedding_model.get_word_embedding_dimension(), 128)],
    })
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model, asym_model])

    # Each encode() call may only contain inputs of one type
    model.encode([{"query": "Q1"}, {"query": "Q2"}])
    model.encode([{"doc": "Doc1"}, {"doc": "Doc2"}])

    # You can train it with InputExample like this. Note that the order must always be the same:
    train_example = InputExample(texts=[{"query": "Train query"}, {"doc": "Document"}], label=1)

Parameters
  • sub_modules – Dict in the format str -> List[models]. The models in the specified list will be applied for input marked with the respective key.

  • allow_empty_key – If true, inputs without a key can be processed. If false, an exception will be thrown if no key is specified.
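
The routing that Asym performs can be sketched without any model code. The sketch uses a hypothetical asym_forward helper and works as follows:
  • Each input is a single-entry dict such as {"query": ...}.
  • Every value is passed through the list of callables registered for its key.
  • A batch that mixes keys is rejected, reflecting the note that mixed types cannot be encoded.
  • Unkeyed inputs pass through unchanged when allow_empty_key is true, which is a simplification made here.

Plain functions stand in for the Dense modules of the example above.

```python
from typing import Any, Callable, Dict, List


def asym_forward(
    batch: List[Any],
    sub_modules: Dict[str, List[Callable[[Any], Any]]],
    allow_empty_key: bool = True,
) -> List[Any]:
    """Hypothetical sketch of Asym routing for one encode() call."""
    if not batch:
        return []
    # Collect the key of every input; None marks an input passed without a key
    keys = {next(iter(item)) if isinstance(item, dict) else None for item in batch}
    if len(keys) > 1:
        raise ValueError("mixed input types cannot be encoded in one call")
    key = keys.pop()
    if key is None:
        if not allow_empty_key:
            raise ValueError("input has no key and allow_empty_key is False")
        return list(batch)  # unkeyed inputs pass through unchanged in this sketch
    outputs = []
    for item in batch:
        value = item[key]
        for module in sub_modules[key]:  # apply the modules registered for this key, in order
            value = module(value)
        outputs.append(value)
    return outputs


sub_modules = {
    "query": [lambda v: [2 * x for x in v]],
    "doc": [lambda v: [x + 1 for x in v]],
}
print(asym_forward([{"query": [1, 2]}, {"query": [3, 4]}], sub_modules))  # [[2, 4], [6, 8]]
print(asym_forward([{"doc": [1, 2]}], sub_modules))  # [[2, 3]]
```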

class sentence_transformers.models.BoW(vocab: List[str], word_weights: Dict[str, float] = {}, unknown_word_weight: float = 1, cumulative_term_frequency: bool = True)

Implements a Bag-of-Words (BoW) model to derive sentence embeddings.

A weighting can be added to allow the generation of tf-idf vectors. The output vector has the size of the vocab.
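
The sketch below builds a weighted bag-of-words vector of vocabulary size. With idf values as word_weights, the accumulated counts become tf-idf scores. It is a hypothetical plain-Python helper, not the library's code, and it makes these assumptions:
  • Out-of-vocabulary tokens are skipped.
  • Words missing from word_weights receive unknown_word_weight.
  • With cumulative_term_frequency=True, the weight is added once per occurrence; with False, it is recorded once per word.

```python
from typing import Dict, List, Optional


def bow_vector(
    tokens: List[str],
    vocab: List[str],
    word_weights: Optional[Dict[str, float]] = None,
    unknown_word_weight: float = 1.0,
    cumulative_term_frequency: bool = True,
) -> List[float]:
    """Hypothetical sketch: weighted bag-of-words vector with one entry per vocab word."""
    word_weights = word_weights or {}
    index = {word: i for i, word in enumerate(vocab)}
    vector = [0.0] * len(vocab)
    for token in tokens:
        if token not in index:
            continue  # out-of-vocabulary tokens are skipped in this sketch
        i = index[token]
        weight = word_weights.get(token, unknown_word_weight)
        if cumulative_term_frequency:
            vector[i] += weight  # add the weight for every occurrence
        else:
            vector[i] = weight  # record the weight once, regardless of count
    return vector


vocab = ["the", "cat", "sat"]
weights = {"the": 0.5, "cat": 2.0}  # e.g. idf values
print(bow_vector(["the", "cat", "the", "dog"], vocab, weights))  # [1.0, 2.0, 0.0]
```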

class sentence_transformers.models.CNN(in_word_embedding_dimension: int, out_channels: int = 256, kernel_sizes: List[int] = [1, 3, 5], stride_sizes: Optional[List[int]] = None)

CNN layer with multiple kernel sizes over the word embeddings.
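
One way to see what the CNN layer produces is to compute its output shape for a sentence. The hypothetical helper cnn_output_shape applies the standard Conv1d length formula under these assumptions:
  • Each kernel size gets its own 1-D convolution with out_channels filters.
  • Padding is (kernel_size - 1) // 2.
  • The stride defaults to 1.
  • The per-kernel outputs are concatenated for each token.

```python
from typing import List, Optional, Sequence, Tuple


def cnn_output_shape(
    seq_len: int,
    out_channels: int = 256,
    kernel_sizes: Sequence[int] = (1, 3, 5),
    stride_sizes: Optional[List[int]] = None,
) -> Tuple[int, int]:
    """Hypothetical helper: (number of output positions, features per position)."""
    if stride_sizes is None:
        stride_sizes = [1] * len(kernel_sizes)
    lengths = []
    for kernel_size, stride in zip(kernel_sizes, stride_sizes):
        padding = (kernel_size - 1) // 2  # assumed padding; keeps the length for odd kernels
        # Standard Conv1d output length: floor((L + 2p - k) / s) + 1
        lengths.append((seq_len + 2 * padding - kernel_size) // stride + 1)
    if len(set(lengths)) != 1:
        raise ValueError(f"kernel outputs have different lengths {lengths}; cannot concatenate")
    # Each kernel size contributes out_channels features per position
    return lengths[0], out_channels * len(kernel_sizes)


print(cnn_output_shape(10))  # (10, 768)
```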

class sentence_transformers.models.LSTM(word_embedding_dimension: int, hidden_dim: int, num_layers: int = 1, dropout: float = 0, bidirectional: bool = True)

LSTM running over word embeddings (bidirectional by default).
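
With bidirectional=True, the word embeddings are read in both directions and the two hidden states of each token are concatenated. Each output token vector therefore has 2 × hidden_dim entries. The sketch below shows only that bidirectional wiring. bidirectional_run is hypothetical, and running_sum is a toy recurrence standing in for a real LSTM cell.

```python
from typing import Callable, List

Vector = List[float]


def bidirectional_run(tokens: List[Vector], step: Callable[[Vector, Vector], Vector], init_state: Vector) -> List[Vector]:
    # Left-to-right pass: state after reading tokens[0..i]
    forward, state = [], init_state
    for vec in tokens:
        state = step(state, vec)
        forward.append(state)
    # Right-to-left pass: state after reading tokens[i..end]
    backward, state = [], init_state
    for vec in reversed(tokens):
        state = step(state, vec)
        backward.append(state)
    backward.reverse()  # realign with token positions
    # Concatenate both directions per token, doubling the output size
    return [f + b for f, b in zip(forward, backward)]


def running_sum(state: Vector, vec: Vector) -> Vector:
    # Toy stand-in for an LSTM cell
    return [s + v for s, v in zip(state, vec)]


print(bidirectional_run([[1.0], [2.0], [3.0]], running_sum, [0.0]))  # [[1.0, 6.0], [3.0, 5.0], [6.0, 3.0]]
```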

class sentence_transformers.models.Normalize

This layer normalizes embeddings to unit length (L2 norm).
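
Normalizing to unit length means dividing each embedding by its L2 norm. After that, the dot product of two embeddings equals their cosine similarity. A minimal plain-Python sketch follows. normalize is a hypothetical helper, and the small eps guard against division by zero is an assumption of this sketch.

```python
import math
from typing import List


def normalize(vec: List[float], eps: float = 1e-12) -> List[float]:
    # Divide by the L2 norm; eps avoids dividing by zero for all-zero vectors
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / max(norm, eps) for x in vec]


a = normalize([3.0, 4.0])
b = normalize([4.0, 3.0])
print(a)  # [0.6, 0.8]
print(sum(x * y for x, y in zip(a, b)))  # dot product of unit vectors = cosine similarity
```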

class sentence_transformers.models.WeightedLayerPooling(word_embedding_dimension, num_hidden_layers: int = 12, layer_start: int = 4, layer_weights=None)

Token embeddings are computed as a weighted mean of their different hidden layer representations.
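
For a single token, the weighted mean over layers can be sketched as follows. weighted_layer_pooling is a hypothetical helper, and the sketch assumes, as the defaults suggest:
  • The model provides num_hidden_layers + 1 representations per token: the embedding output plus one per hidden layer.
  • Only the representations from index layer_start onward are used.
  • layer_weights defaults to uniform weights.

```python
from typing import List, Optional


def weighted_layer_pooling(
    layer_embeddings: List[List[float]],
    layer_start: int = 4,
    layer_weights: Optional[List[float]] = None,
) -> List[float]:
    """Hypothetical sketch: weighted mean of one token's per-layer vectors."""
    selected = layer_embeddings[layer_start:]  # drop the lower layers
    if layer_weights is None:
        layer_weights = [1.0] * len(selected)  # uniform weights by default
    if len(layer_weights) != len(selected):
        raise ValueError("need one weight per selected layer")
    total = sum(layer_weights)
    dim = len(selected[0])
    return [sum(w * layer[d] for w, layer in zip(layer_weights, selected)) / total for d in range(dim)]


# 12 hidden layers + embedding output = 13 vectors; here layer i is simply [i]
layers = [[float(i)] for i in range(13)]
print(weighted_layer_pooling(layers))  # [8.0], the mean of layers 4..12
```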

class sentence_transformers.models.WordEmbeddings(tokenizer: sentence_transformers.models.tokenizer.WordTokenizer.WordTokenizer, embedding_weights, update_embeddings: bool = False, max_seq_length: int = 1000000)

class sentence_transformers.models.WordWeights(vocab: List[str], word_weights: Dict[str, float], unknown_word_weight: float = 1)

This model can weight word embeddings, for example with idf values.

Parameters
  • vocab – Vocabulary of the tokenizer

  • word_weights – Mapping of tokens to a float weight value. Word embeddings are multiplied by this value. The tokens in word_weights need not match the vocab exactly (it can contain more or fewer entries)

  • unknown_word_weight – Weight for words in vocab that do not appear in the word_weights lookup, for example rare words for which no weight exists.
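
The sketch below computes idf weights from a toy corpus and multiplies each word embedding by its token's weight. Tokens not in the lookup fall back to unknown_word_weight. Both helpers are hypothetical. The idf variant log(N / df) is one common choice and not necessarily what a given sentence-transformers setup uses.

```python
import math
from collections import Counter
from typing import Dict, List


def idf_weights(documents: List[List[str]]) -> Dict[str, float]:
    # idf(t) = log(N / df(t)), where df(t) counts the documents containing t
    n = len(documents)
    df = Counter(token for doc in documents for token in set(doc))
    return {token: math.log(n / count) for token, count in df.items()}


def weight_word_embeddings(
    tokens: List[str],
    embeddings: List[List[float]],
    word_weights: Dict[str, float],
    unknown_word_weight: float = 1.0,
) -> List[List[float]]:
    # Scale every word embedding by its token's weight
    weighted = []
    for token, emb in zip(tokens, embeddings):
        w = word_weights.get(token, unknown_word_weight)
        weighted.append([w * x for x in emb])
    return weighted


weights = idf_weights([["the", "cat"], ["the", "dog"]])
print(weights["the"])  # 0.0, since "the" appears in every document
print(weight_word_embeddings(["the", "cat", "bird"], [[1.0, 1.0], [1.0, 2.0], [3.0, 4.0]], weights))
```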