Trainer

SentenceTransformerTrainer

class sentence_transformers.trainer.SentenceTransformerTrainer(model: SentenceTransformer | None = None, args: SentenceTransformerTrainingArguments = None, train_dataset: Dataset | DatasetDict | IterableDataset | dict[str, Dataset] | None = None, eval_dataset: Dataset | DatasetDict | IterableDataset | dict[str, Dataset] | None = None, loss: nn.Module | dict[str, nn.Module] | Callable[[SentenceTransformer], torch.nn.Module] | dict[str, Callable[[SentenceTransformer], torch.nn.Module]] | None = None, evaluator: SentenceEvaluator | list[SentenceEvaluator] | None = None, data_collator: DataCollator | None = None, tokenizer: PreTrainedTokenizerBase | Callable | None = None, model_init: Callable[[], SentenceTransformer] | None = None, compute_metrics: Callable[[EvalPrediction], dict] | None = None, callbacks: list[TrainerCallback] | None = None, optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None = None)

SentenceTransformerTrainer is a simple but feature-complete training and eval loop for PyTorch based on the 🤗 Transformers Trainer.

This trainer integrates support for various transformers.TrainerCallback subclasses, such as:

  • WandbCallback to automatically log training metrics to W&B if wandb is installed

  • TensorBoardCallback to log training metrics to TensorBoard if tensorboard is accessible.

  • CodeCarbonCallback to track the carbon emissions of your model during training if codecarbon is installed.

    • Note: These carbon emissions will be included in your automatically generated model card.

See the Transformers Callbacks documentation for more information on the integrated callbacks and how to write your own callbacks.

Parameters

Important attributes:

  • model – Always points to the core model. If using a transformers model, it will be a PreTrainedModel subclass.

  • model_wrapped – Always points to the most external model in case one or more other modules wrap the original model. This is the model that should be used for the forward pass. For example, under DeepSpeed, the inner model is wrapped in DeepSpeed and then again in torch.nn.DistributedDataParallel. If the inner model hasn’t been wrapped, then self.model_wrapped is the same as self.model.

  • is_model_parallel – Whether or not a model has been switched to a model parallel mode (different from data parallelism, this means some of the model layers are split on different GPUs).

  • place_model_on_device – Whether or not to automatically place the model on the device. It will be set to False if model parallelism or DeepSpeed is used, or if the default TrainingArguments.place_model_on_device is overridden to return False.

  • is_in_train – Whether or not a model is currently running train (e.g. when evaluate is called while in train).

add_callback(callback)

Add a callback to the current list of transformers.TrainerCallback.

Parameters

callback (type or transformers.TrainerCallback) – A transformers.TrainerCallback class or an instance of a transformers.TrainerCallback. In the first case, will instantiate a member of that class.

compute_loss(model: SentenceTransformer, inputs: dict[str, torch.Tensor | Any], return_outputs: bool = False) → torch.Tensor | tuple[torch.Tensor, dict[str, Any]]

Computes the loss for the SentenceTransformer model.

It uses self.loss to compute the loss, which can be a single loss function or a dictionary of loss functions for different datasets. If the loss is a dictionary, the dataset name is expected to be passed in the inputs under the key “dataset_name”. This is done automatically in the add_dataset_name_column method. Note that even if return_outputs = True, the outputs will be empty, as the SentenceTransformers losses do not return outputs.

Parameters
  • model (SentenceTransformer) – The SentenceTransformer model.

  • inputs (Dict[str, Union[torch.Tensor, Any]]) – The input data for the model.

  • return_outputs (bool, optional) – Whether to return the outputs along with the loss. Defaults to False.

Returns

The computed loss. If return_outputs is True, returns a tuple of loss and outputs. Otherwise, returns only the loss.

Return type

Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]
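The dispatch described for compute_loss can be illustrated with a small, library-free sketch. The names select_loss and compute_loss_sketch, the plain-callable losses, and the "features" key are all hypothetical stand-ins chosen for illustration; the real trainer works with nn.Module losses and tokenized inputs, and its implementation may differ.

```python
from typing import Any, Callable, Dict, Tuple, Union

# A "loss" here is any callable mapping a batch to a float. This is an
# assumption that stands in for the nn.Module losses of the real trainer.
LossFn = Callable[[Dict[str, Any]], float]


def select_loss(loss: Union[LossFn, Dict[str, LossFn]], inputs: Dict[str, Any]) -> LossFn:
    """Return the loss to apply to this batch (hypothetical helper)."""
    if isinstance(loss, dict):
        # With a dictionary of losses, the batch must name its source dataset.
        dataset_name = inputs.get("dataset_name")
        if dataset_name not in loss:
            raise KeyError(f"No loss registered for dataset {dataset_name!r}")
        return loss[dataset_name]
    return loss


def compute_loss_sketch(
    loss: Union[LossFn, Dict[str, LossFn]],
    inputs: Dict[str, Any],
    return_outputs: bool = False,
) -> Union[float, Tuple[float, Dict[str, Any]]]:
    value = select_loss(loss, inputs)(inputs)
    # As documented, the outputs are empty even when they are requested.
    return (value, {}) if return_outputs else value


# Two toy losses keyed by dataset name, and a batch tagged with its dataset.
losses = {"nli": lambda batch: 0.25, "sts": lambda batch: 0.75}
batch = {"dataset_name": "sts", "features": [[101, 102]]}

single = compute_loss_sketch(losses, batch)
with_outputs = compute_loss_sketch(losses, batch, return_outputs=True)
```

Here single holds only the loss value, while with_outputs pairs the same value with an empty dictionary.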

create_model_card(language: str | None = None, license: str | None = None, tags: str | list[str] | None = None, model_name: str | None = None, finetuned_from: str | None = None, tasks: str | list[str] | None = None, dataset_tags: str | list[str] | None = None, dataset: str | list[str] | None = None, dataset_args: str | list[str] | None = None, **kwargs) → None

Creates a draft of a model card using the information available to the Trainer.

Parameters
  • language (str, optional) – The language of the model (if applicable)

  • license (str, optional) – The license of the model. Will default to the license of the pretrained model used, if the original model given to the Trainer comes from a repo on the Hub.

  • tags (str or List[str], optional) – Some tags to be included in the metadata of the model card.

  • model_name (str, optional) – The name of the model.

  • finetuned_from (str, optional) – The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo of the original model given to the Trainer (if it comes from the Hub).

  • tasks (str or List[str], optional) – One or several task identifiers, to be included in the metadata of the model card.

  • dataset_tags (str or List[str], optional) – One or several dataset tags, to be included in the metadata of the model card.

  • dataset (str or List[str], optional) – One or several dataset identifiers, to be included in the metadata of the model card.

  • dataset_args (str or List[str], optional) – One or several dataset arguments, to be included in the metadata of the model card.

create_optimizer()

Set up the optimizer.

We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the Trainer’s init through optimizers, or subclass and override this method.

create_optimizer_and_scheduler(num_training_steps: int)

Set up the optimizer and the learning rate scheduler.

We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the Trainer’s init through optimizers, or subclass and override this method (or create_optimizer and/or create_scheduler).

create_scheduler(num_training_steps: int, optimizer: Optional[torch.optim.optimizer.Optimizer] = None)

Set up the scheduler. The optimizer of the trainer must have been set up either before this method is called or passed as an argument.

Parameters

num_training_steps (int) – The number of training steps to do.

evaluate(eval_dataset: Dataset | dict[str, Dataset] | None = None, ignore_keys: list[str] | None = None, metric_key_prefix: str = 'eval') → dict[str, float]

Runs evaluation and returns metrics.

The calling script will be responsible for providing a method to compute metrics, as they are task-dependent (pass it to the init compute_metrics argument).

You can also subclass and override this method to inject custom behavior.

Parameters
  • eval_dataset (Union[Dataset, Dict[str, Dataset]], optional) –

    Pass a dataset if you wish to override self.eval_dataset. If it is a datasets.Dataset, columns not accepted by the model.forward() method are automatically removed. If it is a dictionary, it will evaluate on each dataset, prepending the dictionary key to the metric name. Datasets must implement the __len__ method.

    Tip

    If you pass a dictionary with names of datasets as keys and datasets as values, evaluate will run separate evaluations on each dataset. This can be useful to monitor how training affects other datasets or simply to get a more fine-grained evaluation. When used with load_best_model_at_end, make sure metric_for_best_model references exactly one of the datasets. If you, for example, pass in {"data1": data1, "data2": data2} for two datasets data1 and data2, you could specify metric_for_best_model="eval_data1_loss" for using the loss on data1 and metric_for_best_model="eval_data2_loss" for the loss on data2.

  • ignore_keys (List[str], optional) – A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions.

  • metric_key_prefix (str, optional, defaults to “eval”) – An optional prefix to be used as the metrics key prefix. For example the metrics “bleu” will be named “eval_bleu” if the prefix is “eval” (default)

Returns

A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The dictionary also contains the epoch number which comes from the training state.
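The metric naming described for evaluate (a prefix, plus the dictionary key when several datasets are evaluated) can be sketched without the library. The helper prefix_metrics and the sample values are hypothetical; this only illustrates the naming scheme and is not the trainer's own code.

```python
from typing import Dict, Optional


def prefix_metrics(
    metrics: Dict[str, float],
    metric_key_prefix: str = "eval",
    dataset_name: Optional[str] = None,
) -> Dict[str, float]:
    """Prefix metric names, adding the dataset name when one is given."""
    prefix = metric_key_prefix if dataset_name is None else f"{metric_key_prefix}_{dataset_name}"
    return {
        # Leave keys that already carry the prefix untouched.
        (key if key.startswith(f"{prefix}_") else f"{prefix}_{key}"): value
        for key, value in metrics.items()
    }


raw = {"loss": 0.31, "bleu": 27.4}

# Single evaluation dataset: "bleu" becomes "eval_bleu".
single_dataset = prefix_metrics(raw)

# Dictionary of datasets: each key is inserted after the prefix, which gives
# names such as "eval_data1_loss" that metric_for_best_model can reference.
per_dataset: Dict[str, float] = {}
for name in ("data1", "data2"):
    per_dataset.update(prefix_metrics(raw, dataset_name=name))
```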

get_batch_sampler(dataset: Dataset, batch_size: int, drop_last: bool, valid_label_columns: list[str] | None = None, generator: torch.Generator | None = None) → BatchSampler | None

Returns the appropriate batch sampler based on the batch_sampler argument in self.args. This batch sampler class supports __len__ and __iter__ methods, and is used as the batch_sampler to create the torch.utils.data.DataLoader.

Note

Override this method to provide a custom batch sampler.

Parameters
  • dataset (Dataset) – The dataset to sample from.

  • batch_size (int) – Number of samples per batch.

  • drop_last (bool) – If True, drop the last incomplete batch if the dataset size is not divisible by the batch size.

  • valid_label_columns (List[str]) – List of column names to check for labels. The first column name from valid_label_columns found in the dataset will be used as the label column.

  • generator (torch.Generator, optional) – Optional random number generator for shuffling the indices.
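To make the batch sampler contract concrete (an object with __len__ and __iter__ that yields lists of indices), here is a minimal stand-alone sketch. SimpleBatchSampler is a hypothetical class, and it uses Python's random module with an integer seed in place of a torch.Generator so that it runs without PyTorch; a real override would return a torch.utils.data.BatchSampler.

```python
import random
from typing import Iterator, List, Optional


class SimpleBatchSampler:
    """Yields shuffled batches of dataset indices (illustrative only)."""

    def __init__(self, dataset_size: int, batch_size: int, drop_last: bool, seed: Optional[int] = None):
        self.dataset_size = dataset_size
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.seed = seed

    def __iter__(self) -> Iterator[List[int]]:
        indices = list(range(self.dataset_size))
        # A fresh seeded generator makes every pass reproducible for a fixed seed.
        random.Random(self.seed).shuffle(indices)
        for start in range(0, self.dataset_size, self.batch_size):
            batch = indices[start:start + self.batch_size]
            if len(batch) < self.batch_size and self.drop_last:
                break  # skip the final incomplete batch
            yield batch

    def __len__(self) -> int:
        if self.drop_last:
            return self.dataset_size // self.batch_size
        # Ceiling division: the last, smaller batch is kept.
        return (self.dataset_size + self.batch_size - 1) // self.batch_size


sampler = SimpleBatchSampler(dataset_size=10, batch_size=4, drop_last=True, seed=0)
batches = list(sampler)
```

With 10 samples and a batch size of 4, drop_last=True yields two full batches, while drop_last=False would add a third batch of two samples.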

get_eval_dataloader(eval_dataset: Dataset | None = None) → DataLoader

Returns the evaluation torch.utils.data.DataLoader.

Subclass and override this method if you want to inject some custom behavior.

Parameters

eval_dataset (torch.utils.data.Dataset, optional) – If provided, will override self.eval_dataset. If it is a datasets.Dataset, columns not accepted by the model.forward() method are automatically removed. It must implement __len__.

get_learning_rates()

Returns the learning rate of each parameter from self.optimizer.

get_multi_dataset_batch_sampler(dataset: ConcatDataset, batch_samplers: list[BatchSampler], generator: torch.Generator | None = None, seed: int | None = 0) → BatchSampler

Returns the appropriate multi-dataset batch sampler based on the multi_dataset_batch_sampler argument in self.args. This batch sampler class supports __len__ and __iter__ methods, and is used as the batch_sampler to create the torch.utils.data.DataLoader.

Note

Override this method to provide a custom multi-dataset batch sampler.

Parameters
  • dataset (ConcatDataset) – The concatenation of all datasets.

  • batch_samplers (List[BatchSampler]) – List of batch samplers for each dataset in the concatenated dataset.

  • generator (torch.Generator, optional) – Optional random number generator for shuffling the indices.

  • seed (int, optional) – Optional seed for the random number generator
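One way to combine per-dataset batch samplers over a concatenated dataset is to alternate between them and shift each batch's indices by the position of its dataset in the concatenation. The sketch below assumes a round-robin strategy that stops when any dataset is exhausted; round_robin_batches is a hypothetical function, and the strategy actually used depends on the multi_dataset_batch_sampler argument.

```python
from typing import Iterable, Iterator, List, Sequence


def round_robin_batches(
    dataset_sizes: Sequence[int],
    batch_samplers: Sequence[Iterable[List[int]]],
) -> Iterator[List[int]]:
    """Alternate between datasets, yielding indices into their concatenation."""
    # Offset of each dataset inside the concatenated dataset.
    offsets = [0]
    for size in dataset_sizes[:-1]:
        offsets.append(offsets[-1] + size)

    iterators = [iter(sampler) for sampler in batch_samplers]
    while True:
        for offset, iterator in zip(offsets, iterators):
            batch = next(iterator, None)
            if batch is None:
                return  # stop as soon as one dataset runs out of batches
            # Every batch holds samples from a single dataset only.
            yield [offset + index for index in batch]


# Two toy datasets of sizes 4 and 6, each with precomputed local batches.
sizes = [4, 6]
samplers = [[[0, 1], [2, 3]], [[0, 1], [2, 3], [4, 5]]]
mixed = list(round_robin_batches(sizes, samplers))
```

In this run the first dataset runs out after two batches, so the third batch of the second dataset is never drawn.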

get_num_trainable_parameters()

Get the number of trainable parameters.

get_optimizer_group(param: Optional[Union[str, torch.nn.parameter.Parameter]] = None)

Returns optimizer group for a parameter if given, else returns all optimizer groups for params.

Parameters

param (str or torch.nn.parameter.Parameter, optional) – The parameter for which optimizer group needs to be returned.

get_test_dataloader(test_dataset: datasets.arrow_dataset.Dataset) → torch.utils.data.dataloader.DataLoader

Returns the test torch.utils.data.DataLoader.

Subclass and override this method if you want to inject some custom behavior.

Parameters

test_dataset (torch.utils.data.Dataset, optional) – The test dataset to use. If it is a datasets.Dataset, columns not accepted by the model.forward() method are automatically removed. It must implement __len__.

get_train_dataloader() → torch.utils.data.dataloader.DataLoader

Returns the training torch.utils.data.DataLoader.

Will use no sampler if train_dataset does not implement __len__, a random sampler (adapted to distributed training if necessary) otherwise.

Subclass and override this method if you want to inject some custom behavior.

hyperparameter_search(hp_space, compute_objective, n_trials, direction, backend, hp_name, **kwargs)

Launch a hyperparameter search using optuna or Ray Tune or SigOpt. The optimized quantity is determined by compute_objective, which defaults to a function returning the evaluation loss when no metric is provided, the sum of all metrics otherwise.

Warning

To use this method, you must provide a model_init when initializing your Trainer, because the model is reinitialized at each new run. This is incompatible with the optimizers argument, so for a custom optimizer or scheduler you need to subclass Trainer and override Trainer.create_optimizer_and_scheduler.

Parameters
  • hp_space (Callable[[“optuna.Trial”], Dict[str, float]], optional) – A function that defines the hyperparameter search space. Will default to trainer_utils.default_hp_space_optuna or trainer_utils.default_hp_space_ray or trainer_utils.default_hp_space_sigopt depending on your backend.

  • compute_objective (Callable[[Dict[str, float]], float], optional) – A function computing the objective to minimize or maximize from the metrics returned by the evaluate method. Will default to trainer_utils.default_compute_objective.

  • n_trials (int, optional, defaults to 100) – The number of trial runs to test.

  • direction (str or List[str], optional, defaults to “minimize”) – For single-objective optimization, direction is a str, either “minimize” or “maximize”. For multi-objective optimization, direction is a List[str] of “minimize” and “maximize” values. In both cases, pick “minimize” when optimizing the validation loss and “maximize” when optimizing one or several metrics.

  • backend (str or training_utils.HPSearchBackend, optional) – The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending on which one is installed. If all are installed, will default to optuna.

  • hp_name (Callable[[“optuna.Trial”], str], optional) – A function that defines the trial/run name. Will default to None.

  • kwargs (Dict[str, Any], optional) –

    Additional keyword arguments passed along to optuna.create_study or ray.tune.run. For more information see:

Returns

All the information about the best run or best runs for multi-objective optimization. Experiment summary can be found in run_summary attribute for Ray backend.

Return type

[trainer_utils.BestRun or List[trainer_utils.BestRun]]
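The default objective described for the hyperparameter search (the evaluation loss when no other metric is present, otherwise the sum of all metrics) can be sketched as follows. The function name compute_objective_sketch, the key names "eval_loss" and "epoch", and the "_runtime" and "_per_second" suffixes treated as bookkeeping are assumptions made for illustration; the function the library actually uses may handle more cases.

```python
import copy
from typing import Dict


def compute_objective_sketch(metrics: Dict[str, float]) -> float:
    """Return the loss if it is the only metric, else the sum of the others."""
    metrics = copy.deepcopy(metrics)  # do not mutate the caller's dictionary
    loss = metrics.pop("eval_loss", None)
    metrics.pop("epoch", None)
    # Assumed bookkeeping entries that should not count as quality metrics.
    for key in [k for k in metrics if k.endswith("_runtime") or k.endswith("_per_second")]:
        metrics.pop(key)
    if not metrics:
        if loss is None:
            raise ValueError("metrics contain neither an evaluation loss nor any other metric")
        return loss
    return sum(metrics.values())


loss_only = compute_objective_sketch({"eval_loss": 0.4, "epoch": 3.0})
with_metrics = compute_objective_sketch(
    {"eval_loss": 0.4, "eval_accuracy": 0.5, "eval_f1": 0.25, "eval_runtime": 12.0}
)
```

A custom function of the same shape could be passed as compute_objective; the direction argument should then match whether the returned value is meant to be minimized or maximized.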

is_local_process_zero() → bool

Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several machines) main process.

is_world_process_zero() → bool

Whether or not this process is the global main process (when training in a distributed fashion on several machines, this is only going to be True for one process).

log(logs: Dict[str, float]) → None

Log logs on the various objects watching training.

Subclass and override this method to inject custom behavior.

Parameters

logs (Dict[str, float]) – The values to log.

pop_callback(callback)

Removes a callback from the current list of transformers.TrainerCallback and returns it.

If the callback is not found, returns None (and no error is raised).

Parameters

callback (type or transformers.TrainerCallback) – A transformers.TrainerCallback class or an instance of a transformers.TrainerCallback. In the first case, will pop the first member of that class found in the list of callbacks.

Returns

The callback removed, if found.

Return type

transformers.TrainerCallback
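The class-or-instance semantics of pop_callback can be shown with plain Python objects. The Callback classes and the free function pop_callback_sketch below are hypothetical stand-ins for transformers.TrainerCallback and the trainer method; they only mirror the documented behavior.

```python
from typing import List, Optional, Type, Union


class Callback:
    """Stand-in for transformers.TrainerCallback (assumption for this sketch)."""


class LoggingCallback(Callback):
    pass


class EarlyStopCallback(Callback):
    pass


def pop_callback_sketch(
    callbacks: List[Callback],
    callback: Union[Type[Callback], Callback],
) -> Optional[Callback]:
    """Remove and return a callback, or return None if it is not found."""
    for candidate in callbacks:
        if isinstance(callback, type):
            # A class was given: match the first member of that class.
            found = isinstance(candidate, callback)
        else:
            # An instance was given: match that exact object.
            found = candidate is callback
        if found:
            callbacks.remove(candidate)
            return candidate
    return None  # nothing matched, and no error is raised


first, second, stopper = LoggingCallback(), LoggingCallback(), EarlyStopCallback()
active = [first, stopper, second]

# Popping by class removes only the first LoggingCallback in the list.
popped = pop_callback_sketch(active, LoggingCallback)
```

After this call, active still contains stopper and second, in that order.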

propagate_args_to_deepspeed(auto_find_batch_size=False)

Sets values in the deepspeed plugin based on the Trainer args

push_to_hub(commit_message: Optional[str] = 'End of training', blocking: bool = True, token: Optional[str] = None, **kwargs) → str

Upload self.model and self.tokenizer to the 🤗 model hub on the repo self.args.hub_model_id.

Parameters
  • commit_message (str, optional, defaults to “End of training”) – Message to commit while pushing.

  • blocking (bool, optional, defaults to True) – Whether the function should return only when the git push has finished.

  • token (str, optional, defaults to None) – Token with write permission to overwrite Trainer’s original args.

  • kwargs (Dict[str, Any], optional) – Additional keyword arguments passed along to Trainer.create_model_card.

Returns

The URL of the repository where the model was pushed if blocking=True, or a Future object tracking the progress of the commit if blocking=False.

remove_callback(callback)

Remove a callback from the current list of transformers.TrainerCallback.

Parameters

callback (type or transformers.TrainerCallback) – A transformers.TrainerCallback class or an instance of a transformers.TrainerCallback. In the first case, will remove the first member of that class found in the list of callbacks.

save_model(output_dir: Optional[str] = None, _internal_call: bool = False)

Will save the model, so you can reload it using from_pretrained().

Will only save from the main process.

train(resume_from_checkpoint: Optional[Union[bool, str]] = None, trial: Union[optuna.Trial, Dict[str, Any]] = None, ignore_keys_for_eval: Optional[List[str]] = None, **kwargs)

Main training entry point.

Parameters
  • resume_from_checkpoint (str or bool, optional) – If a str, local path to a saved checkpoint as saved by a previous instance of Trainer. If a bool and equals True, load the last checkpoint in args.output_dir as saved by a previous instance of Trainer. If present, training will resume from the model/optimizer/scheduler states loaded here.

  • trial (optuna.Trial or Dict[str, Any], optional) – The trial run or the hyperparameter dictionary for hyperparameter search.

  • ignore_keys_for_eval (List[str], optional) – A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training.

  • kwargs (Dict[str, Any], optional) – Additional keyword arguments used to hide deprecated arguments