CrossEncoder

CrossEncoder

For an introduction to Cross-Encoders, see Cross-Encoders.

class sentence_transformers.cross_encoder.CrossEncoder(model_name_or_path: str, num_labels: int = None, max_length: int = None, activation_fn: Callable | None = None, device: str | None = None, cache_folder: str = None, trust_remote_code: bool = False, revision: str | None = None, local_files_only: bool = False, token: bool | str | None = None, model_kwargs: dict = None, tokenizer_kwargs: dict = None, config_kwargs: dict = None, model_card_data: CrossEncoderModelCardData | None = None)

A CrossEncoder takes exactly two sentences (texts) as input and predicts either a score or a label for the pair. For example, it can predict the similarity of the sentence pair on a scale from 0 to 1.

It does not yield a sentence embedding and does not work for individual sentences.

Parameters:
  • model_name_or_path (str) – A model name from the Hugging Face Hub that can be loaded with AutoModelForSequenceClassification, or a path to a local model. We provide several pre-trained CrossEncoder models that can be used for common tasks.

  • num_labels (int, optional) – Number of labels of the classifier. If 1, the CrossEncoder is a regression model that outputs a single continuous score between 0 and 1. If > 1, it outputs several scores that can be softmaxed to obtain probabilities for the different classes. Defaults to None.

  • max_length (int, optional) – Maximum length of input sequences; longer sequences are truncated. If None, the model's maximum length is used. Defaults to None.

  • activation_fn (Callable, optional) – A callable (such as nn.Sigmoid) to use as the default activation function on top of the logits in model.predict(). If None, nn.Sigmoid() is used if num_labels=1, else nn.Identity(). Defaults to None.

  • device (str, optional) – Device (like “cuda”, “cpu”, “mps”, “npu”) that should be used for computation. If None, checks if a GPU can be used.

  • cache_folder (str, Path, optional) – Path to the folder where cached files are stored.

  • trust_remote_code (bool, optional) – Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine. Defaults to False.

  • revision (str, optional) – The specific model version to use. It can be a branch name, a tag name, or a commit id, for a stored model on Hugging Face. Defaults to None.

  • local_files_only (bool, optional) – Whether or not to only look at local files (i.e., do not try to download the model).

  • token (bool or str, optional) – Hugging Face authentication token to download private models.

  • model_kwargs (Dict[str, Any], optional) –

    Additional model configuration parameters to be passed to the Hugging Face Transformers model. Particularly useful options are:

    • torch_dtype: Override the default torch.dtype and load the model under a specific dtype. The different options are:

      1. torch.float16, torch.bfloat16 or torch.float: load in the specified dtype, ignoring the model’s config.torch_dtype if one exists. If torch_dtype is not specified, the model is loaded in torch.float (fp32).

      2. "auto" - The torch_dtype entry in the model’s config.json file is used if present. If that entry is not found, the dtype of the first floating-point weight in the checkpoint is used instead. This loads the model in the dtype it was saved in at the end of training, which is not necessarily the dtype it was trained in: a model may have been trained in a half-precision dtype but saved in fp32.

    • attn_implementation: The attention implementation to use in the model (if relevant). Can be any of “eager” (manual implementation of the attention), “sdpa” (using F.scaled_dot_product_attention), or “flash_attention_2” (using Dao-AILab/flash-attention). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual “eager” implementation.

    See the AutoModelForSequenceClassification.from_pretrained documentation for more details.

  • tokenizer_kwargs (Dict[str, Any], optional) – Additional tokenizer configuration parameters to be passed to the Hugging Face Transformers tokenizer. See the AutoTokenizer.from_pretrained documentation for more details.

  • config_kwargs (Dict[str, Any], optional) – Additional model configuration parameters to be passed to the Hugging Face Transformers config. See the AutoConfig.from_pretrained documentation for more details. For example, you can set classifier_dropout via this parameter.

  • model_card_data (CrossEncoderModelCardData, optional) – A model card data object that contains information about the model. This is used to generate a model card when saving the model. If not set, a default model card data object is created.
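
The documented defaults for num_labels and activation_fn can be illustrated with a small plain-Python sketch. It is not the library's implementation: default_activation and softmax are hypothetical helpers that mimic nn.Sigmoid(), nn.Identity() and a softmax over the class logits.

```python
import math
from typing import Callable, List


def default_activation(num_labels: int) -> Callable[[float], float]:
    """Hypothetical helper mirroring the documented default for activation_fn=None:
    sigmoid when num_labels == 1, identity otherwise."""
    if num_labels == 1:
        return lambda x: 1.0 / (1.0 + math.exp(-x))  # like nn.Sigmoid()
    return lambda x: x  # like nn.Identity()


def softmax(logits: List[float]) -> List[float]:
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


# num_labels == 1: a single logit is squashed into the range (0, 1).
single_score = default_activation(1)(0.0)

# num_labels == 3: logits pass through unchanged; softmax turns them into class probabilities.
identity = default_activation(3)
class_logits = [identity(z) for z in [2.0, 1.0, 0.1]]
class_probs = softmax(class_logits)

print(single_score)  # 0.5
print(class_probs)
```

With num_labels=1 the sigmoid maps any logit into (0, 1), which is why such models report scores on a 0 … 1 scale; with num_labels > 1 the raw logits are returned unless a softmax is applied.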


bfloat16() → T

Casts all floating point parameters and buffers to bfloat16 datatype.

Note

This method modifies the module in-place.

Returns:

self

Return type:

Module

compile(*args, **kwargs)

Compile this Module’s forward using torch.compile().

This Module’s __call__ method is compiled and all arguments are passed as-is to torch.compile().

See torch.compile() for details on the arguments for this function.

cpu() → T

Moves all model parameters and buffers to the CPU.

Note

This method modifies the module in-place.

Returns:

self

Return type:

Module

cuda(device: int | device | None = None) → T

Moves all model parameters and buffers to the GPU.

This also makes associated parameters and buffers different objects, so it should be called before constructing the optimizer if the module will live on GPU while being optimized.

Note

This method modifies the module in-place.

Parameters:

device (int, optional) – if specified, all parameters will be copied to that device

Returns:

self

Return type:

Module

double() → T

Casts all floating point parameters and buffers to double datatype.

Note

This method modifies the module in-place.

Returns:

self

Return type:

Module

eval() → T

Sets the module in evaluation mode.

This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.

This is equivalent to self.train(False).

See Locally disabling gradient computation for a comparison between .eval() and several similar mechanisms that may be confused with it.

Returns:

self

Return type:

Module
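
The effect of eval() versus train() can be sketched with a toy, dependency-free module. ToyDropout is hypothetical and only mimics the pattern PyTorch modules follow: a training flag, train(mode) returning self, and eval() delegating to train(False).

```python
import random
from typing import List


class ToyDropout:
    """Hypothetical stand-in for a module (like Dropout) whose behavior depends on the training flag."""

    def __init__(self, p: float = 0.5) -> None:
        self.p = p
        self.training = True  # modules start in training mode

    def train(self, mode: bool = True) -> "ToyDropout":
        self.training = mode
        return self  # returning self allows chaining, e.g. layer.eval().forward(...)

    def eval(self) -> "ToyDropout":
        # eval() is simply train(False)
        return self.train(False)

    def forward(self, xs: List[float]) -> List[float]:
        if not self.training:
            return list(xs)  # identity in evaluation mode
        # Training mode: zero each element with probability p and rescale the
        # survivors by 1 / (1 - p) ("inverted" dropout).
        scale = 1.0 / (1.0 - self.p)
        return [x * scale if random.random() >= self.p else 0.0 for x in xs]


layer = ToyDropout(p=0.5)
eval_out = layer.eval().forward([1.0, 2.0, 3.0])
print(layer.training, eval_out)  # False [1.0, 2.0, 3.0]
```

Because evaluation mode makes the output deterministic, a model is normally switched to eval() before inference.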

fit(train_dataloader: DataLoader, evaluator: SentenceEvaluator = None, epochs: int = 1, loss_fct=None, activation_fct=Identity(), scheduler: str = 'WarmupLinear', warmup_steps: int = 10000, optimizer_class: type[Optimizer] = <class 'torch.optim.adamw.AdamW'>, optimizer_params: dict[str, object] = {'lr': 2e-05}, weight_decay: float = 0.01, evaluation_steps: int = 0, output_path: str = None, save_best_model: bool = True, max_grad_norm: float = 1, use_amp: bool = False, callback: Callable[[float, int, int], None] = None, show_progress_bar: bool = True) → None

Deprecated training method from before Sentence Transformers v4.0; it is recommended to use CrossEncoderTrainer instead. This method uses CrossEncoderTrainer behind the scenes, but does not provide as much flexibility as the Trainer itself.

This training approach uses a DataLoader and Loss function to train the model.

In v4.0, this method should produce results equivalent to those from before v4.0. If you encounter any issues with your existing training scripts, you may wish to use CrossEncoder.old_fit instead, which uses the old training method from before v4.0.

Parameters:
  • train_dataloader – The DataLoader with InputExample instances

  • evaluator – An evaluator (from sentence_transformers.cross_encoder.evaluation) that evaluates model performance on held-out dev data during training. It is used to determine the best model, which is saved to disk.

  • epochs – Number of epochs for training

  • loss_fct – Which loss function to use for training. If None, will use BinaryCrossEntropy() if self.config.num_labels == 1 else CrossEntropyLoss(). Defaults to None.

  • activation_fct – Activation function applied on top of logits output of model.

  • scheduler – Learning rate scheduler. Available schedulers: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts

  • warmup_steps – Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is increased linearly from 0 up to the maximal learning rate over the first warmup_steps training steps, and is then decreased linearly back to zero.

  • optimizer_class – Optimizer

  • optimizer_params – Optimizer parameters

  • weight_decay – Weight decay for model parameters

  • evaluation_steps – If > 0, evaluate the model using the evaluator every evaluation_steps training steps

  • output_path – Storage path for the model and evaluation files

  • save_best_model – If True, the best model (according to the evaluator) is stored at output_path

  • max_grad_norm – Maximum gradient norm used for gradient clipping.

  • use_amp – Use Automatic Mixed Precision (AMP). Only for PyTorch >= 1.6.0

  • callback – Callback function that is invoked after each evaluation. It must accept the following three parameters in this order: score, epoch, steps

  • show_progress_bar – If True, output a tqdm progress bar
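
The WarmupLinear behavior described for warmup_steps can be sketched as a function of the step count. This is a plain-Python illustration using a hypothetical warmup_linear_lr helper, not the library's scheduler; to our understanding it has the same shape as the linear warmup-and-decay schedule in Hugging Face Transformers (get_linear_schedule_with_warmup).

```python
from typing import List


def warmup_linear_lr(step: int, max_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Learning rate after `step` optimizer steps under linear warmup, then linear decay to zero.

    A sketch of the shape described for WarmupLinear, not the library's code.
    """
    if step < warmup_steps:
        # Warmup phase: rise linearly from 0 towards max_lr.
        return max_lr * step / max(1, warmup_steps)
    # Decay phase: fall linearly from max_lr to 0 at total_steps.
    decay_span = max(1, total_steps - warmup_steps)
    return max_lr * max(0.0, (total_steps - step) / decay_span)


# 2e-05 is the documented default learning rate; the step counts are illustrative only.
max_lr = 2e-05
schedule: List[float] = [
    warmup_linear_lr(s, max_lr, warmup_steps=100, total_steps=1000) for s in range(1001)
]
print(schedule[0], schedule[50], schedule[100], schedule[1000])
```

Note that the documented default of warmup_steps=10000 is large; with a small dataset the model may still be warming up when training ends, so it is often set relative to the total number of steps.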

float() → T

Casts all floating point parameters and buffers to float datatype.

Note

This method modifies the module in-place.

Returns:

self

Return type:

Module

half() → T

Casts all floating point parameters and buffers to half datatype.

Note

This method modifies the module in-place.

Returns:

self

Return type:

Module

old_fit(train_dataloader: DataLoader, evaluator: SentenceEvaluator | None = None, epochs: int = 1, loss_fct=None, activation_fct=Identity(), scheduler: str = 'WarmupLinear', warmup_steps: int = 10000, optimizer_class: type[Optimizer] = <class 'torch.optim.adamw.AdamW'>, optimizer_params: dict[str, object] = {'lr': 2e-05}, weight_decay: float = 0.01, evaluation_steps: int = 0, output_path: str | None = None, save_best_model: bool = True, max_grad_norm: float = 1, use_amp: bool = False, callback: Callable[[float, int, int], None] | None = None, show_progress_bar: bool = True) → None

Deprecated training method from before Sentence Transformers v4.0; it is recommended to use CrossEncoderTrainer instead. This method should only be used if you encounter issues with your existing training scripts after upgrading to v4.0.

This training approach uses a DataLoader and Loss function to train the model.

Parameters:
  • train_dataloader – The DataLoader with InputExample instances

  • evaluator – An evaluator (from sentence_transformers.cross_encoder.evaluation) that evaluates model performance on held-out dev data during training. It is used to determine the best model, which is saved to disk.

  • epochs – Number of epochs for training

  • loss_fct – Which loss function to use for training. If None, will use BinaryCrossEntropy() if self.config.num_labels == 1 else CrossEntropyLoss(). Defaults to None.

  • activation_fct – Activation function applied on top of logits output of model.

  • scheduler – Learning rate scheduler. Available schedulers: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts

  • warmup_steps – Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is increased linearly from 0 up to the maximal learning rate over the first warmup_steps training steps, and is then decreased linearly back to zero.

  • optimizer_class – Optimizer

  • optimizer_params – Optimizer parameters

  • weight_decay – Weight decay for model parameters

  • evaluation_steps – If > 0, evaluate the model using the evaluator every evaluation_steps training steps

  • output_path – Storage path for the model and evaluation files

  • save_best_model – If True, the best model (according to the evaluator) is stored at output_path

  • max_grad_norm – Maximum gradient norm used for gradient clipping.

  • use_amp – Use Automatic Mixed Precision (AMP). Only for PyTorch >= 1.6.0

  • callback – Callback function that is invoked after each evaluation. It must accept the following three parameters in this order: score, epoch, steps

  • show_progress_bar – If True, output a tqdm progress bar
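
How evaluation_steps, callback and save_best_model interact can be sketched as a simplified loop. run_training is hypothetical: the actual forward/backward pass is elided, evaluate stands in for the evaluator, and recording (epoch, step) stands in for saving the model to output_path. It models only the periodic evaluation described above, not every detail of the real training loop.

```python
from typing import Callable, List, Optional, Tuple


def run_training(
    num_epochs: int,
    steps_per_epoch: int,
    evaluation_steps: int,
    evaluate: Callable[[], float],
    callback: Optional[Callable[[float, int, int], None]] = None,
) -> Tuple[float, List[Tuple[int, int]]]:
    """Hypothetical skeleton: evaluate every `evaluation_steps` steps and track the best score."""
    best_score = float("-inf")
    saved_at: List[Tuple[int, int]] = []  # (epoch, step) where a new best model would be saved
    for epoch in range(num_epochs):
        for step in range(1, steps_per_epoch + 1):
            # ... forward pass, loss, backward pass and optimizer/scheduler steps would go here ...
            if evaluation_steps > 0 and step % evaluation_steps == 0:
                score = evaluate()
                if callback is not None:
                    callback(score, epoch, step)  # documented order: score, epoch, steps
                if score > best_score:  # with save_best_model=True, save only on improvement
                    best_score = score
                    saved_at.append((epoch, step))
    return best_score, saved_at


scores = iter([0.70, 0.75, 0.73, 0.80])
seen: List[Tuple[float, int, int]] = []
best, saved_at = run_training(
    num_epochs=2,
    steps_per_epoch=4,
    evaluation_steps=2,
    evaluate=lambda: next(scores),
    callback=lambda score, epoch, steps: seen.append((score, epoch, steps)),
)
print(best, saved_at)  # 0.8 [(0, 2), (0, 4), (1, 4)]
```

The callback fires on every evaluation, including the one at epoch 1, step 2 whose score (0.73) does not improve on the best and therefore triggers no save.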

predict(sentences: tuple[str, str] | list[str], batch_size: int = 32, show_progress_bar: bool | None = None, activation_fn: Callable | None = None, apply_softmax: bool | None = False, convert_to_numpy: Literal[False] = True, convert_to_tensor: Literal[False] = False) → torch.Tensor
predict(sentences: list[tuple[str, str]] | list[list[str]] | tuple[str, str] | list[str], batch_size: int = 32, show_progress_bar: bool | None = None, activation_fn: Callable | None = None, apply_softmax: bool | None = False, convert_to_numpy: Literal[True] = True, convert_to_tensor: Literal[False] = False) → np.ndarray
predict(sentences: list[tuple[str, str]] | list[list[str]] | tuple[str, str] | list[str], batch_size: int = 32, show_progress_bar: bool | None = None, activation_fn: Callable | None = None, apply_softmax: bool | None = False, convert_to_numpy: bool = True, convert_to_tensor: Literal[True] = False) → torch.Tensor
predict(sentences: list[tuple[str, str]] | list[list[str]], batch_size: int = 32, show_progress_bar: bool | None = None, activation_fn: Callable | None = None, apply_softmax: bool | None = False, convert_to_numpy: Literal[False] = True, convert_to_tensor: Literal[False] = False) → list[torch.Tensor]

Performs predictions with the CrossEncoder on the given sentence pairs.

Parameters:
  • sentences (Union[List[Tuple[str, str]], Tuple[str, str]]) – A list of sentence pairs [(Sent1, Sent2), (Sent3, Sent4)] or one sentence pair (Sent1, Sent2).

  • batch_size (int, optional) – Batch size for encoding. Defaults to 32.

  • show_progress_bar (bool, optional) – Output progress bar. Defaults to None.

  • activation_fn (callable, optional) – Activation function applied on the logits output of the CrossEncoder. If None, the model.activation_fn will be used, which defaults to torch.nn.Sigmoid if num_labels=1, else torch.nn.Identity. Defaults to None.

  • apply_softmax (bool, optional) – If set to True and model.num_labels > 1, applies softmax on the logits output such that for each sample, the scores of each class sum to 1. Defaults to False.

  • convert_to_numpy (bool, optional) – Whether to convert the output to a numpy array. If False (and convert_to_tensor is False), the output is a list of PyTorch tensors. Defaults to True.

  • convert_to_tensor (bool, optional) – Whether the output should be one large tensor. Overrides convert_to_numpy. Defaults to False.

Returns:

Predictions for the passed sentence pairs. The return type depends on the convert_to_numpy and convert_to_tensor parameters. If convert_to_tensor is True, the output will be a torch.Tensor. If convert_to_numpy is True, the output will be a numpy.ndarray. Otherwise, the output will be a list of torch.Tensor values.

Return type:

Union[List[torch.Tensor], np.ndarray, torch.Tensor]

Examples

from sentence_transformers import CrossEncoder

model = CrossEncoder("cross-encoder/stsb-roberta-base")
sentences = [["I love cats", "Cats are amazing"], ["I prefer dogs", "Dogs are loyal"]]
model.predict(sentences)
# => array([0.6912767, 0.4303499], dtype=float32)
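
The steps predict() is documented to perform (scoring pairs in batches of batch_size, applying activation_fn to the logits, and optionally applying softmax when there are several labels) can be sketched without loading a model. toy_forward is a hypothetical stand-in for the transformer forward pass and predict_sketch is an illustration, not the library's implementation. When activation_fn is None the real method falls back to model.activation_fn, whereas the sketch simply skips activation; it also ignores the numpy/tensor conversion options.

```python
import math
from typing import Callable, List, Optional, Sequence, Tuple, Union

Pair = Tuple[str, str]


def toy_forward(batch: Sequence[Pair]) -> List[List[float]]:
    """Hypothetical stand-in for the model's forward pass (num_labels == 1):
    the single logit is the number of shared lowercase words."""
    return [[float(len(set(a.lower().split()) & set(b.lower().split())))] for a, b in batch]


def softmax(row: List[float]) -> List[float]:
    m = max(row)
    exps = [math.exp(z - m) for z in row]
    total = sum(exps)
    return [e / total for e in exps]


def predict_sketch(
    pairs: Sequence[Pair],
    forward: Callable[[Sequence[Pair]], List[List[float]]],
    batch_size: int = 32,
    activation_fn: Optional[Callable[[float], float]] = None,
    apply_softmax: bool = False,
) -> List[Union[float, List[float]]]:
    outputs: List[List[float]] = []
    for start in range(0, len(pairs), batch_size):
        # Score the pairs batch_size at a time.
        for row in forward(pairs[start : start + batch_size]):
            if activation_fn is not None:
                row = [activation_fn(z) for z in row]
            if apply_softmax and len(row) > 1:
                row = softmax(row)  # only applied when there are several labels
            outputs.append(row)
    # With a single label, report one score per pair rather than a one-element row.
    return [row[0] if len(row) == 1 else row for row in outputs]


pairs = [("a cat sat", "the cat sat"), ("dogs bark", "birds sing")]
raw = predict_sketch(pairs, toy_forward, batch_size=1)
sigmoid_scores = predict_sketch(pairs, toy_forward, activation_fn=lambda x: 1.0 / (1.0 + math.exp(-x)))
three_label = predict_sketch(pairs, lambda batch: [[2.0, 1.0, 0.1] for _ in batch], apply_softmax=True)
print(raw)  # [2.0, 0.0]
```

The batch size changes only how many pairs are scored at once, not the resulting scores, which is why raw is the same for any batch_size.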

push_to_hub(repo_id: str, *, token: str | None = None, private: bool | None = None, safe_serialization: bool = True, commit_message: str | None = None, exist_ok: bool = False, revision: str | None = None, create_pr: bool = False, tags: list[str] | None = None) → str

Upload the CrossEncoder model to the Hugging Face Hub.

Example

from sentence_transformers import CrossEncoder

model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")
model.push_to_hub("username/my-crossencoder-model")
# => "https://huggingface.co/username/my-crossencoder-model"
Parameters:
  • repo_id (str) – The name of the repository on the Hugging Face Hub, e.g. “username/repo_name”, “organization/repo_name” or just “repo_name”.

  • token (str, optional) – The authentication token to use for the Hugging Face Hub API. If not provided, will use the token stored via the Hugging Face CLI.

  • private (bool, optional) – Whether to create a private repository. If not specified, the repository will be public.

  • safe_serialization (bool, optional) – Whether to save the model weights in safetensors format for safer serialization. Defaults to True.

  • commit_message (str, optional) – The commit message to use for the push. Defaults to “Add new CrossEncoder model”.

  • exist_ok (bool, optional) – If True, do not raise an error if the repository already exists. Ignored if create_pr=True. Defaults to False.

  • revision (str, optional) – The git branch to commit to. Defaults to the head of the ‘main’ branch.

  • create_pr (bool, optional) – Whether to create a Pull Request with the upload or directly commit. Defaults to False.

  • tags (list[str], optional) – A list of tags to add to the model card. Defaults to None.

Returns:

URL of the commit or pull request (if create_pr=True)

Return type:

str

rank(query: str, documents: list[str], top_k: int | None = None, return_documents: bool = False, batch_size: int = 32, show_progress_bar: bool = None, activation_fn: Callable | None = None, apply_softmax=False, convert_to_numpy: bool = True, convert_to_tensor: bool = False) → list[dict[Literal['corpus_id', 'score', 'text'], int | float | str]]

Performs ranking with the CrossEncoder on the given query and documents. Returns a sorted list with the document indices and scores.

Parameters:
  • query (str) – A single query.

  • documents (List[str]) – A list of documents.

  • top_k (Optional[int], optional) – Return the top-k documents. If None, all documents are returned. Defaults to None.

  • return_documents (bool, optional) – If True, also returns the documents. If False, only returns the indices and scores. Defaults to False.

  • batch_size (int, optional) – Batch size for encoding. Defaults to 32.

  • show_progress_bar (bool, optional) – Output progress bar. Defaults to None.

  • activation_fn (Callable, optional) – Activation function applied on the logits output of the CrossEncoder. If None, nn.Sigmoid() will be used if num_labels=1, else nn.Identity(). Defaults to None.

  • convert_to_numpy (bool, optional) – Convert the output to a numpy matrix. Defaults to True.

  • apply_softmax (bool, optional) – If set to True and the model outputs more than one score per pair (num_labels > 1), applies softmax on the logits output. Defaults to False.

  • convert_to_tensor (bool, optional) – Convert the output to a tensor. Defaults to False.

Returns:

A sorted list with the “corpus_id”, “score”, and optionally “text” of the documents.

Return type:

List[Dict[Literal[“corpus_id”, “score”, “text”], Union[int, float, str]]]

Example

from sentence_transformers import CrossEncoder
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")

query = "Who wrote 'To Kill a Mockingbird'?"
documents = [
    "'To Kill a Mockingbird' is a novel by Harper Lee published in 1960. It was immediately successful, winning the Pulitzer Prize, and has become a classic of modern American literature.",
    "The novel 'Moby-Dick' was written by Herman Melville and first published in 1851. It is considered a masterpiece of American literature and deals with complex themes of obsession, revenge, and the conflict between good and evil.",
    "Harper Lee, an American novelist widely known for her novel 'To Kill a Mockingbird', was born in 1926 in Monroeville, Alabama. She received the Pulitzer Prize for Fiction in 1961.",
    "Jane Austen was an English novelist known primarily for her six major novels, which interpret, critique and comment upon the British landed gentry at the end of the 18th century.",
    "The 'Harry Potter' series, which consists of seven fantasy novels written by British author J.K. Rowling, is among the most popular and critically acclaimed books of the modern era.",
    "'The Great Gatsby', a novel written by American author F. Scott Fitzgerald, was published in 1925. The story is set in the Jazz Age and follows the life of millionaire Jay Gatsby and his pursuit of Daisy Buchanan."
]

model.rank(query, documents, return_documents=True)
# Output (the entry for corpus_id 3 is not shown):
[{'corpus_id': 0,
'score': 10.67858,
'text': "'To Kill a Mockingbird' is a novel by Harper Lee published in 1960. It was immediately successful, winning the Pulitzer Prize, and has become a classic of modern American literature."},
{'corpus_id': 2,
'score': 9.761677,
'text': "Harper Lee, an American novelist widely known for her novel 'To Kill a Mockingbird', was born in 1926 in Monroeville, Alabama. She received the Pulitzer Prize for Fiction in 1961."},
{'corpus_id': 1,
'score': -3.3099542,
'text': "The novel 'Moby-Dick' was written by Herman Melville and first published in 1851. It is considered a masterpiece of American literature and deals with complex themes of obsession, revenge, and the conflict between good and evil."},
{'corpus_id': 5,
'score': -4.8989105,
'text': "'The Great Gatsby', a novel written by American author F. Scott Fitzgerald, was published in 1925. The story is set in the Jazz Age and follows the life of millionaire Jay Gatsby and his pursuit of Daisy Buchanan."},
{'corpus_id': 4,
'score': -5.082967,
'text': "The 'Harry Potter' series, which consists of seven fantasy novels written by British author J.K. Rowling, is among the most popular and critically acclaimed books of the modern era."}]
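
rank() can be thought of as predict() over the pairs (query, doc) for each document, followed by sorting. A minimal sketch of the sorting and selection step, assuming the per-document scores have already been computed (for instance with model.predict); rank_from_scores is hypothetical and not part of the library, and the scores below are illustrative numbers, not model output.

```python
from typing import List, Optional, Sequence


def rank_from_scores(
    scores: Sequence[float],
    documents: Optional[Sequence[str]] = None,
    top_k: Optional[int] = None,
) -> List[dict]:
    """Hypothetical helper: turn per-document scores into a rank()-style result list.

    scores[i] is assumed to be the CrossEncoder score for the pair (query, documents[i]).
    """
    results = [{"corpus_id": i, "score": s} for i, s in enumerate(scores)]
    results.sort(key=lambda r: r["score"], reverse=True)  # highest score first
    if top_k is not None:
        results = results[:top_k]
    if documents is not None:  # corresponds to return_documents=True
        for r in results:
            r["text"] = documents[r["corpus_id"]]
    return results


docs = ["doc about cats", "doc about dogs", "doc about birds"]
toy_scores = [0.2, 1.5, -0.3]  # illustrative numbers, not real model output
ranked = rank_from_scores(toy_scores, documents=docs, top_k=2)
print([r["corpus_id"] for r in ranked])  # [1, 0]
```

Keeping corpus_id in each result is what lets callers map the sorted scores back to the original documents list.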

save_pretrained(path: str, *, safe_serialization: bool = True, **kwargs) → None

Saves the model and tokenizer to path; identical to save().

set_config_value(key: str, value) → None

Set a value in the underlying model’s config.

Parameters:
  • key (str) – The key to set.

  • value – The value to set.

to(device: int | str | device | None = None) → None

Moves and/or casts the parameters and buffers.

This can be called as

to(device=None, dtype=None, non_blocking=False)
to(dtype, non_blocking=False)
to(tensor, non_blocking=False)
to(memory_format=torch.channels_last)

Its signature is similar to torch.Tensor.to(), but only accepts floating point or complex dtypes. In addition, this method will only cast the floating point or complex parameters and buffers to dtype (if given). The integral parameters and buffers will be moved to device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.

See below for examples.

Note

This method modifies the module in-place.

Parameters:
  • device (torch.device) – the desired device of the parameters and buffers in this module

  • dtype (torch.dtype) – the desired floating point or complex dtype of the parameters and buffers in this module

  • tensor (torch.Tensor) – Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module

  • memory_format (torch.memory_format) – the desired memory format for 4D parameters and buffers in this module (keyword only argument)

Returns:

self

Return type:

Module

Examples:

>>> # xdoctest: +IGNORE_WANT("non-deterministic")
>>> import torch
>>> from torch import nn
>>> linear = nn.Linear(2, 2)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]])
>>> linear.to(torch.double)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1913, -0.3420],
        [-0.5113, -0.2325]], dtype=torch.float64)
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)
>>> gpu1 = torch.device("cuda:1")
>>> linear.to(gpu1, dtype=torch.half, non_blocking=True)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')
>>> cpu = torch.device("cpu")
>>> linear.to(cpu)
Linear(in_features=2, out_features=2, bias=True)
>>> linear.weight
Parameter containing:
tensor([[ 0.1914, -0.3420],
        [-0.5112, -0.2324]], dtype=torch.float16)

>>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)
>>> linear.weight
Parameter containing:
tensor([[ 0.3741+0.j,  0.2382+0.j],
        [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)
>>> linear(torch.ones(3, 2, dtype=torch.cdouble))
tensor([[0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j],
        [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)

train(mode: bool = True) → T

Sets the module in training mode.

This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.

Parameters:

mode (bool) – whether to set training mode (True) or evaluation mode (False). Default: True.

Returns:

self

Return type:

Module

CrossEncoderModelCardData

class sentence_transformers.cross_encoder.model_card.CrossEncoderModelCardData(language: str | list[str] | None = <factory>, license: str | None = None, model_name: str | None = None, model_id: str | None = None, train_datasets: list[dict[str, str]] = <factory>, eval_datasets: list[dict[str, str]] = <factory>, task_name: str = None, tags: list[str] | None = <factory>, generate_widget_examples: Literal['deprecated'] = 'deprecated')

A dataclass storing data used in the model card.

Parameters:
  • language (Optional[Union[str, List[str]]]) – The model language, either a string or a list, e.g. “en” or [“en”, “de”, “nl”]

  • license (Optional[str]) – The license of the model, e.g. “apache-2.0”, “mit”, or “cc-by-nc-sa-4.0”

  • model_name (Optional[str]) – The pretty name of the model, e.g. “CrossEncoder based on answerdotai/ModernBERT-base”.

  • model_id (Optional[str]) – The model ID when pushing the model to the Hub, e.g. “tomaarsen/ce-mpnet-base-ms-marco”.

  • train_datasets (List[Dict[str, str]]) – A list of the names and/or Hugging Face dataset IDs of the training datasets. e.g. [{“name”: “SNLI”, “id”: “stanfordnlp/snli”}, {“name”: “MultiNLI”, “id”: “nyu-mll/multi_nli”}, {“name”: “STSB”}]

  • eval_datasets (List[Dict[str, str]]) – A list of the names and/or Hugging Face dataset IDs of the evaluation datasets. e.g. [{“name”: “SNLI”, “id”: “stanfordnlp/snli”}, {“id”: “mteb/stsbenchmark-sts”}]

  • task_name (str) – The human-readable task the model is trained on, e.g. “semantic search and paraphrase mining”.

  • tags (Optional[List[str]]) – A list of tags for the model, e.g. [“sentence-transformers”, “cross-encoder”].

Tip

Install codecarbon to automatically track carbon emissions and include them in your model cards.

Example:

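>>> from sentence_transformers import CrossEncoder
>>> from sentence_transformers.cross_encoder.model_card import CrossEncoderModelCardData
>>> model = CrossEncoder(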
...     "microsoft/mpnet-base",
...     model_card_data=CrossEncoderModelCardData(
...         model_id="tomaarsen/ce-mpnet-base-allnli",
...         train_datasets=[{"name": "SNLI", "id": "stanfordnlp/snli"}, {"name": "MultiNLI", "id": "nyu-mll/multi_nli"}],
...         eval_datasets=[{"name": "SNLI", "id": "stanfordnlp/snli"}, {"name": "MultiNLI", "id": "nyu-mll/multi_nli"}],
...         license="apache-2.0",
...         language="en",
...     ),
... )