Training Arguments

SentenceTransformerTrainingArguments

class sentence_transformers.training_args.SentenceTransformerTrainingArguments(output_dir: str, overwrite_output_dir: bool = False, do_train: bool = False, do_eval: bool = False, do_predict: bool = False, evaluation_strategy: Union[transformers.trainer_utils.IntervalStrategy, str] = 'no', prediction_loss_only: bool = False, per_device_train_batch_size: int = 8, per_device_eval_batch_size: int = 8, per_gpu_train_batch_size: Optional[int] = None, per_gpu_eval_batch_size: Optional[int] = None, gradient_accumulation_steps: int = 1, eval_accumulation_steps: Optional[int] = None, eval_delay: Optional[float] = 0, learning_rate: float = 5e-05, weight_decay: float = 0.0, adam_beta1: float = 0.9, adam_beta2: float = 0.999, adam_epsilon: float = 1e-08, max_grad_norm: float = 1.0, num_train_epochs: float = 3.0, max_steps: int = -1, lr_scheduler_type: Union[transformers.trainer_utils.SchedulerType, str] = 'linear', lr_scheduler_kwargs: Optional[Dict] = <factory>, warmup_ratio: float = 0.0, warmup_steps: int = 0, log_level: Optional[str] = 'passive', log_level_replica: Optional[str] = 'warning', log_on_each_node: bool = True, logging_dir: Optional[str] = None, logging_strategy: Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps', logging_first_step: bool = False, logging_steps: float = 500, logging_nan_inf_filter: bool = True, save_strategy: Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps', save_steps: float = 500, save_total_limit: Optional[int] = None, save_safetensors: Optional[bool] = True, save_on_each_node: bool = False, save_only_model: bool = False, no_cuda: bool = False, use_cpu: bool = False, use_mps_device: bool = False, seed: int = 42, data_seed: Optional[int] = None, jit_mode_eval: bool = False, use_ipex: bool = False, bf16: bool = False, fp16: bool = False, fp16_opt_level: str = 'O1', half_precision_backend: str = 'auto', bf16_full_eval: bool = False, fp16_full_eval: bool = False, tf32: Optional[bool] = None, local_rank: int = -1, 
ddp_backend: Optional[str] = None, tpu_num_cores: Optional[int] = None, tpu_metrics_debug: bool = False, debug: Union[str, List[transformers.debug_utils.DebugOption]] = '', dataloader_drop_last: bool = False, eval_steps: Optional[float] = None, dataloader_num_workers: int = 0, past_index: int = -1, run_name: Optional[str] = None, disable_tqdm: Optional[bool] = None, remove_unused_columns: Optional[bool] = True, label_names: Optional[List[str]] = None, load_best_model_at_end: Optional[bool] = False, metric_for_best_model: Optional[str] = None, greater_is_better: Optional[bool] = None, ignore_data_skip: bool = False, fsdp: Optional[Union[List[transformers.trainer_utils.FSDPOption], str]] = '', fsdp_min_num_params: int = 0, fsdp_config: Optional[str] = None, fsdp_transformer_layer_cls_to_wrap: Optional[str] = None, deepspeed: Optional[str] = None, label_smoothing_factor: float = 0.0, optim: Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch', optim_args: Optional[str] = None, adafactor: bool = False, group_by_length: bool = False, length_column_name: Optional[str] = 'length', report_to: Optional[List[str]] = None, ddp_find_unused_parameters: Optional[bool] = None, ddp_bucket_cap_mb: Optional[int] = None, ddp_broadcast_buffers: Optional[bool] = None, dataloader_pin_memory: bool = True, dataloader_persistent_workers: bool = False, skip_memory_metrics: bool = True, use_legacy_prediction_loop: bool = False, push_to_hub: bool = False, resume_from_checkpoint: Optional[str] = None, hub_model_id: Optional[str] = None, hub_strategy: Union[transformers.trainer_utils.HubStrategy, str] = 'every_save', hub_token: Optional[str] = None, hub_private_repo: bool = False, hub_always_push: bool = False, gradient_checkpointing: bool = False, gradient_checkpointing_kwargs: Optional[dict] = None, include_inputs_for_metrics: bool = False, fp16_backend: str = 'auto', push_to_hub_model_id: Optional[str] = None, push_to_hub_organization: Optional[str] = None, 
push_to_hub_token: Optional[str] = None, mp_parameters: str = '', auto_find_batch_size: bool = False, full_determinism: bool = False, torchdynamo: Optional[str] = None, ray_scope: Optional[str] = 'last', ddp_timeout: Optional[int] = 1800, torch_compile: bool = False, torch_compile_backend: Optional[str] = None, torch_compile_mode: Optional[str] = None, dispatch_batches: Optional[bool] = None, split_batches: Optional[bool] = False, include_tokens_per_second: Optional[bool] = False, include_num_input_tokens_seen: Optional[bool] = False, neftune_noise_alpha: Optional[float] = None, batch_sampler: Union[sentence_transformers.training_args.BatchSamplers, str] = <BatchSamplers.BATCH_SAMPLER: 'batch_sampler'>, multi_dataset_batch_sampler: Union[sentence_transformers.training_args.MultiDatasetBatchSamplers, str] = <MultiDatasetBatchSamplers.PROPORTIONAL: 'proportional'>)

SentenceTransformerTrainingArguments extends TrainingArguments with additional arguments specific to Sentence Transformers. See TrainingArguments for the complete list of available arguments.

Parameters
  • output_dir (str) – The output directory where the model checkpoints will be written.

  • batch_sampler (Union[BatchSamplers, str], optional) – The batch sampler to use. See BatchSamplers for valid options. Defaults to BatchSamplers.BATCH_SAMPLER.

  • multi_dataset_batch_sampler (Union[MultiDatasetBatchSamplers, str], optional) – The multi-dataset batch sampler to use. See MultiDatasetBatchSamplers for valid options. Defaults to MultiDatasetBatchSamplers.PROPORTIONAL.

property ddp_timeout_delta

The timeout for torch.distributed.init_process_group: ddp_timeout (in seconds) converted to a datetime.timedelta, since that is the type init_process_group expects.

property device

The device used by this process.

property eval_batch_size

The actual batch size for evaluation (may differ from per_gpu_eval_batch_size in distributed training).

get_process_log_level()

Returns the log level to be used depending on whether this process is the main process of node 0, the main process of a node other than 0, or a non-main process.

For the main process, the log level defaults to the current logging level (logging.WARNING if you haven’t changed it) unless overridden by the log_level argument.

For the replica processes, the log level defaults to logging.WARNING unless overridden by the log_level_replica argument.

The choice between the main and replica process settings is made according to the return value of should_log.
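The selection logic described above can be sketched in plain Python with the standard logging module. This is a minimal illustration under stated assumptions, not the library's code: the LOG_LEVELS mapping and the choose_log_level helper are hypothetical, and "passive" is modeled as a sentinel of -1 that falls back to the currently configured level.

```python
import logging

# Hypothetical mapping from the string choices to stdlib levels.
# "passive" is modeled as -1, meaning "leave the current level unchanged".
LOG_LEVELS = {
    "debug": logging.DEBUG,
    "info": logging.INFO,
    "warning": logging.WARNING,
    "error": logging.ERROR,
    "critical": logging.CRITICAL,
    "passive": -1,
}


def choose_log_level(should_log: bool, log_level: str = "passive",
                     log_level_replica: str = "warning",
                     current_level: int = logging.WARNING) -> int:
    """Sketch: pick the main-process or replica level based on should_log."""
    main_level = LOG_LEVELS[log_level]
    replica_level = LOG_LEVELS[log_level_replica]
    # A "passive" setting defers to whatever level is already configured.
    if main_level == -1:
        main_level = current_level
    if replica_level == -1:
        replica_level = current_level
    return main_level if should_log else replica_level


print(choose_log_level(True))                     # 30 (logging.WARNING)
print(choose_log_level(True, log_level="info"))   # 20 (logging.INFO)
print(choose_log_level(False, log_level="info"))  # 30: replicas keep their own setting
```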

get_warmup_steps(num_training_steps: int)

Returns the number of steps used for a linear warmup.
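A minimal sketch of that computation, assuming the rule documented for set_lr_scheduler (a positive warmup_steps overrides warmup_ratio) and assuming the ratio-based count is rounded up. The warmup_steps_for function is hypothetical, not the method itself.

```python
import math


def warmup_steps_for(num_training_steps: int, warmup_steps: int = 0,
                     warmup_ratio: float = 0.0) -> int:
    """Sketch: number of linear-warmup steps for a run of num_training_steps."""
    if warmup_steps > 0:
        # An explicit step count takes precedence over the ratio.
        return warmup_steps
    # Otherwise, warm up for a fraction of the run, rounded up.
    return math.ceil(num_training_steps * warmup_ratio)


print(warmup_steps_for(1000, warmup_ratio=0.05))                    # 50
print(warmup_steps_for(1000, warmup_steps=200, warmup_ratio=0.05))  # 200
```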

property local_process_index

The index of the local process used.

main_process_first(local=True, desc='work')

A context manager for a torch distributed environment where one needs to do something on the main process while blocking the replicas, and then release the replicas once the main process has finished.

One such use is the datasets map feature: to be efficient, it should run once on the main process, which saves a cached version of the results upon completion; the replicas then load that cached version automatically.

Parameters
  • local (bool, optional, defaults to True) – If True, “first” means the process of rank 0 on each node; if False, “first” means the process of rank 0 on node rank 0. In a multi-node environment with a shared filesystem, you will most likely want to use local=False so that only the main process of the first node does the processing. If, however, the filesystem is not shared, the main process of each node needs to do the processing, which is the default behavior.

  • desc (str, optional, defaults to “work”) – A work description to be used in debug logs.
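The "main process first" pattern can be simulated in a single Python process. In this hypothetical sketch, threads stand in for distributed processes and threading.Barrier stands in for a distributed barrier: replicas wait at the barrier before running the block, while the main process runs the block and only then reaches the barrier, releasing everyone. The names main_process_first and simulate here are illustrative, not the library's implementation.

```python
import threading
from contextlib import contextmanager


@contextmanager
def main_process_first(rank: int, barrier: threading.Barrier):
    """Sketch: rank 0 runs the block first; the other ranks wait for it."""
    is_main = rank == 0
    if not is_main:
        # Replicas block until the main process reaches the barrier below.
        barrier.wait()
    try:
        yield
    finally:
        if is_main:
            # The main process has finished its work: release the replicas.
            barrier.wait()


def simulate(world_size: int = 4) -> list:
    """Run one thread per 'process' and record the order in which they do the work."""
    barrier = threading.Barrier(world_size)
    order = []
    lock = threading.Lock()

    def worker(rank: int) -> None:
        with main_process_first(rank, barrier):
            with lock:
                order.append(rank)  # e.g. "preprocess" on rank 0, "load cache" elsewhere

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return order


print(simulate()[0])  # 0: the main process always does its work first
```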

property n_gpu

The number of GPUs used by this process.

Note

This will only be greater than one when you have multiple GPUs available but are not using distributed training. For distributed training, it will always be 1.

property parallel_mode

The current mode used for parallelism if multiple GPUs/TPU cores are available. One of:

  • ParallelMode.NOT_PARALLEL: no parallelism (CPU or one GPU).

  • ParallelMode.NOT_DISTRIBUTED: several GPUs in one single process (uses torch.nn.DataParallel).

  • ParallelMode.DISTRIBUTED: several GPUs, each having its own process (uses torch.nn.DistributedDataParallel).

  • ParallelMode.TPU: several TPU cores.

property place_model_on_device

Whether or not to automatically place the model on the device. Can be subclassed and overridden for some specific integrations.

property process_index

The index of the current process used.

set_dataloader(train_batch_size: int = 8, eval_batch_size: int = 8, drop_last: bool = False, num_workers: int = 0, pin_memory: bool = True, persistent_workers: bool = False, auto_find_batch_size: bool = False, ignore_data_skip: bool = False, sampler_seed: Optional[int] = None)

A method that regroups all arguments linked to dataloader creation.

Parameters
  • drop_last (bool, optional, defaults to False) – Whether to drop the last incomplete batch (if the length of the dataset is not divisible by the batch size) or not.

  • num_workers (int, optional, defaults to 0) – Number of subprocesses to use for data loading (PyTorch only). 0 means that the data will be loaded in the main process.

  • pin_memory (bool, optional, defaults to True) – Whether to pin memory in data loaders.

  • persistent_workers (bool, optional, defaults to False) – If True, the data loader will not shut down the worker processes after a dataset has been consumed once, which keeps the workers’ Dataset instances alive. This can speed up training but increases RAM usage.

  • auto_find_batch_size (bool, optional, defaults to False) – Whether to automatically find a batch size that fits into memory through exponential decay, avoiding CUDA Out-of-Memory errors. Requires accelerate to be installed (pip install accelerate).

  • ignore_data_skip (bool, optional, defaults to False) – When resuming training, whether or not to skip the epochs and batches to get the data loading at the same stage as in the previous training. If set to True, the training will begin faster (as that skipping step can take a long time) but will not yield the same results as the interrupted training would have.

  • sampler_seed (int, optional) – Random seed to be used with data samplers. If not set, random generators for data sampling will use the same seed as self.seed. This can be used to ensure reproducibility of data sampling, independent of the model seed.

Example:

```py
>>> from transformers import TrainingArguments

>>> args = TrainingArguments("working_dir")
>>> args = args.set_dataloader(train_batch_size=16, eval_batch_size=64)
>>> args.per_device_train_batch_size
16
```
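To make drop_last concrete, here is a small arithmetic sketch of how many batches one epoch produces. The num_batches helper is hypothetical and assumes a map-style dataset of known length with a fixed batch size.

```python
import math


def num_batches(dataset_len: int, batch_size: int, drop_last: bool = False) -> int:
    """Sketch: how many batches one epoch yields for a map-style dataset."""
    if drop_last:
        # The incomplete final batch is discarded.
        return dataset_len // batch_size
    # The final batch may be smaller than batch_size.
    return math.ceil(dataset_len / batch_size)


print(num_batches(1000, 16))                  # 63 (62 full batches + one batch of 8)
print(num_batches(1000, 16, drop_last=True))  # 62
```

When the dataset length is divisible by the batch size, both settings give the same count.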
set_evaluate(strategy: Union[str, transformers.trainer_utils.IntervalStrategy] = 'no', steps: int = 500, batch_size: int = 8, accumulation_steps: Optional[int] = None, delay: Optional[float] = None, loss_only: bool = False, jit_mode: bool = False)

A method that regroups all arguments linked to evaluation.

Parameters
  • strategy (str or [~trainer_utils.IntervalStrategy], optional, defaults to “no”) –

    The evaluation strategy to adopt during training. Possible values are:

    • “no”: No evaluation is done during training.

    • “steps”: Evaluation is done (and logged) every eval_steps (the value of the steps argument).

    • “epoch”: Evaluation is done at the end of each epoch.

    Setting a strategy different from “no” will set self.do_eval to True.

  • steps (int, optional, defaults to 500) – Number of update steps between two evaluations if strategy=“steps”.

  • batch_size (int, optional, defaults to 8) – The batch size per device (GPU/TPU core/CPU…) used for evaluation.

  • accumulation_steps (int, optional) – Number of prediction steps to accumulate the output tensors for before moving the results to the CPU. If left unset, all predictions are accumulated on the GPU/TPU before being moved to the CPU (faster, but requires more memory).

  • delay (float, optional) – Number of epochs or steps to wait before the first evaluation can be performed, depending on the evaluation strategy.

  • loss_only (bool, optional, defaults to False) – Ignores all outputs except the loss.

  • jit_mode (bool, optional) – Whether or not to use PyTorch jit trace for inference.

Example:

```py
>>> from transformers import TrainingArguments

>>> args = TrainingArguments("working_dir")
>>> args = args.set_evaluate(strategy="steps", steps=100)
>>> args.eval_steps
100
```
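With strategy=“steps”, the steps and delay arguments together determine which update steps trigger an evaluation. The eval_schedule helper below is a hypothetical sketch, assuming an evaluation fires whenever the global step is a multiple of steps and is at least delay (with delay interpreted in steps).

```python
def eval_schedule(max_steps: int, steps: int, delay: float = 0) -> list:
    """Sketch: global steps at which an evaluation would run under strategy="steps"."""
    return [
        step
        for step in range(1, max_steps + 1)
        # Evaluate every `steps` update steps, but not before `delay` steps have passed.
        if step % steps == 0 and step >= delay
    ]


print(eval_schedule(1000, steps=250))             # [250, 500, 750, 1000]
print(eval_schedule(1000, steps=250, delay=600))  # [750, 1000]
```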
set_logging(strategy: Union[str, transformers.trainer_utils.IntervalStrategy] = 'steps', steps: int = 500, report_to: Union[str, List[str]] = 'none', level: str = 'passive', first_step: bool = False, nan_inf_filter: bool = False, on_each_node: bool = False, replica_level: str = 'passive')

A method that regroups all arguments linked to logging.

Parameters
  • strategy (str or [~trainer_utils.IntervalStrategy], optional, defaults to “steps”) –

    The logging strategy to adopt during training. Possible values are:

    • “no”: No logging is done during training.

    • “epoch”: Logging is done at the end of each epoch.

    • “steps”: Logging is done every logging_steps.

  • steps (int, optional, defaults to 500) – Number of update steps between two logs if strategy=“steps”.

  • level (str, optional, defaults to “passive”) – Logger log level to use on the main process. Possible choices are the log levels as strings: “debug”, “info”, “warning”, “error” and “critical”, plus a “passive” level which doesn’t set anything and lets the application set the level.

  • report_to (str or List[str], optional, defaults to “none”) – The list of integrations to report the results and logs to. Supported platforms are “azure_ml”, “clearml”, “codecarbon”, “comet_ml”, “dagshub”, “dvclive”, “flyte”, “mlflow”, “neptune”, “tensorboard”, and “wandb”. Use “all” to report to all installed integrations, or “none” for no integrations.

  • first_step (bool, optional, defaults to False) – Whether to log and evaluate the first global_step or not.

  • nan_inf_filter (bool, optional, defaults to False) –

    Whether to filter nan and inf losses for logging. If set to True, the loss of every step that is nan or inf is filtered out and the average loss of the current logging window is taken instead.

    <Tip>

    nan_inf_filter only influences the logging of loss values; it does not change how the gradient is computed or applied to the model.

    </Tip>

  • on_each_node (bool, optional, defaults to False) – In multi-node distributed training, whether to log using level once per node, or only on the main node.

  • replica_level (str, optional, defaults to “passive”) – Logger log level to use on replicas. Same choices as level.

Example:

```py
>>> from transformers import TrainingArguments

>>> args = TrainingArguments("working_dir")
>>> args = args.set_logging(strategy="steps", steps=100)
>>> args.logging_steps
100
```
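The nan_inf_filter behavior can be approximated as follows: a non-finite loss is replaced with the mean of the losses already recorded in the current logging window, so a single bad step does not turn the logged average into nan. This is a simplified, hypothetical sketch (logged_window_loss is not a library function); in particular, a non-finite loss at the very start of a window is simply skipped here because there is no average to substitute yet. As with the real option, only the logged value is affected.

```python
import math


def logged_window_loss(losses: list, nan_inf_filter: bool = True) -> float:
    """Sketch: average loss reported for one logging window."""
    total, count = 0.0, 0
    for loss in losses:
        if nan_inf_filter and not math.isfinite(loss):
            if count == 0:
                continue  # nothing to average yet: skip the bad step
            loss = total / count  # substitute the window's running mean
        total += loss
        count += 1
    return total / count if count else float("nan")


print(logged_window_loss([1.0, float("nan"), 3.0]))         # (1 + 1 + 3) / 3, about 1.667
print(logged_window_loss([1.0, float("nan"), 3.0], False))  # nan
```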
set_lr_scheduler(name: Union[str, transformers.trainer_utils.SchedulerType] = 'linear', num_epochs: float = 3.0, max_steps: int = -1, warmup_ratio: float = 0, warmup_steps: int = 0)

A method that regroups all arguments linked to the learning rate scheduler and its hyperparameters.

Parameters
  • name (str or [SchedulerType], optional, defaults to “linear”) – The scheduler type to use. See the documentation of [SchedulerType] for all possible values.

  • num_epochs (float, optional, defaults to 3.0) – Total number of training epochs to perform. If not an integer, the fractional part of the last epoch is performed before training stops.

  • max_steps (int, optional, defaults to -1) – If set to a positive number, the total number of training steps to perform; overrides num_epochs. For a finite dataset, training iterates through the dataset repeatedly (if all data is exhausted) until max_steps is reached.

  • warmup_ratio (float, optional, defaults to 0.0) – Ratio of total training steps used for a linear warmup from 0 to learning_rate.

  • warmup_steps (int, optional, defaults to 0) – Number of steps used for a linear warmup from 0 to learning_rate. Overrides any effect of warmup_ratio.

Example:

```py
>>> from transformers import TrainingArguments

>>> args = TrainingArguments("working_dir")
>>> args = args.set_lr_scheduler(name="cosine", warmup_ratio=0.05)
>>> args.warmup_ratio
0.05
```
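To illustrate what name, warmup_steps and the total step count control, here is a sketch of two common schedules written as learning-rate multipliers: linear warmup followed by linear decay to zero, and linear warmup followed by a half-period cosine decay. These hypothetical helpers follow the textbook formulas and are not taken from the library.

```python
import math


def linear_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup from 0 to 1, then linear decay from 1 to 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


def cosine_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup, then a half cosine wave from 1 down to 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


base_lr, warmup, total = 5e-5, 10, 100
for step in (0, 5, 10, 55, 100):
    # The effective learning rate is base_lr times the schedule's multiplier.
    print(step, base_lr * linear_multiplier(step, warmup, total),
          base_lr * cosine_multiplier(step, warmup, total))
```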
set_optimizer(name: Union[str, transformers.training_args.OptimizerNames] = 'adamw_torch', learning_rate: float = 5e-05, weight_decay: float = 0, beta1: float = 0.9, beta2: float = 0.999, epsilon: float = 1e-08, args: Optional[str] = None)

A method that regroups all arguments linked to the optimizer and its hyperparameters.

Parameters
  • name (str or [training_args.OptimizerNames], optional, defaults to “adamw_torch”) – The optimizer to use: “adamw_hf”, “adamw_torch”, “adamw_torch_fused”, “adamw_apex_fused”, “adamw_anyprecision” or “adafactor”.

  • learning_rate (float, optional, defaults to 5e-5) – The initial learning rate.

  • weight_decay (float, optional, defaults to 0) – The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights.

  • beta1 (float, optional, defaults to 0.9) – The beta1 hyperparameter for the Adam optimizer or its variants.

  • beta2 (float, optional, defaults to 0.999) – The beta2 hyperparameter for the Adam optimizer or its variants.

  • epsilon (float, optional, defaults to 1e-8) – The epsilon hyperparameter for the Adam optimizer or its variants.

  • args (str, optional) – Optional arguments that are supplied to AnyPrecisionAdamW (only useful when name=“adamw_anyprecision”).

Example:

```py
>>> from transformers import TrainingArguments

>>> args = TrainingArguments("working_dir")
>>> args = args.set_optimizer(name="adamw_torch", beta1=0.8)
>>> args.optim
'adamw_torch'
```
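The roles of beta1, beta2, epsilon, learning_rate and weight_decay can be read off a single AdamW update. The scalar sketch below follows the standard AdamW formulation (decoupled weight decay applied before the bias-corrected Adam step). It is a hypothetical illustration of the math, not any of the implementations selected by name.

```python
import math


def adamw_step(param: float, grad: float, m: float, v: float, t: int,
               lr: float = 5e-5, beta1: float = 0.9, beta2: float = 0.999,
               eps: float = 1e-8, weight_decay: float = 0.0):
    """One AdamW update for a scalar parameter; t is the 1-based step count."""
    # Decoupled weight decay: shrink the parameter directly, not via the gradient.
    param = param - lr * weight_decay * param
    # Exponential moving averages of the gradient and squared gradient.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    # Bias correction compensates for m and v starting at zero.
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # epsilon guards against division by a tiny second-moment estimate.
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


p, m, v = adamw_step(1.0, grad=0.5, m=0.0, v=0.0, t=1, lr=0.1)
print(p)  # about 0.9: the first step moves by roughly lr in the gradient's direction
```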
set_push_to_hub(model_id: str, strategy: Union[str, transformers.trainer_utils.HubStrategy] = 'every_save', token: Optional[str] = None, private_repo: bool = False, always_push: bool = False)

A method that regroups all arguments linked to synchronizing checkpoints with the Hub.

<Tip>

Calling this method will set self.push_to_hub to True, which means the output_dir will become a git directory synced with the repo (determined by model_id) and the content will be pushed each time a save is triggered (depending on self.save_strategy). Calling [~Trainer.save_model] will also trigger a push.

</Tip>

Parameters
  • model_id (str) – The name of the repository to keep in sync with the local output_dir. It can be a simple model ID, in which case the model will be pushed to your namespace. Otherwise it should be the full repository name, for instance “user_name/model”; this also lets you push to an organization you are a member of with “organization_name/model”.

  • strategy (str or [~trainer_utils.HubStrategy], optional, defaults to “every_save”) –

    Defines the scope of what is pushed to the Hub and when. Possible values are:

    • “end”: push the model, its configuration, the tokenizer (if passed along to the [Trainer]) and a draft of a model card when the [~Trainer.save_model] method is called.

    • “every_save”: push the model, its configuration, the tokenizer (if passed along to the [Trainer]) and a draft of a model card each time there is a model save. The pushes are asynchronous so as not to block training; if saves are very frequent, a new push is only attempted once the previous one has finished. A last push is made with the final model at the end of training.

    • “checkpoint”: like “every_save”, but the latest checkpoint is also pushed to a subfolder named last-checkpoint, allowing you to resume training easily with trainer.train(resume_from_checkpoint="last-checkpoint").

    • “all_checkpoints”: like “checkpoint”, but all checkpoints are pushed as they appear in the output folder (so you will get one checkpoint folder per folder in your final repository).

  • token (str, optional) – The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with huggingface-cli login.

  • private_repo (bool, optional, defaults to False) – If True, the Hub repo will be set to private.

  • always_push (bool, optional, defaults to False) – Unless this is True, the Trainer will skip pushing a checkpoint when the previous push is not finished.

Example:

```py
>>> from transformers import TrainingArguments

>>> args = TrainingArguments("working_dir")
>>> args = args.set_push_to_hub("me/awesome-model")
>>> args.hub_model_id
'me/awesome-model'
```
set_save(strategy: Union[str, transformers.trainer_utils.IntervalStrategy] = 'steps', steps: int = 500, total_limit: Optional[int] = None, on_each_node: bool = False)

A method that regroups all arguments linked to checkpoint saving.

Parameters
  • strategy (str or [~trainer_utils.IntervalStrategy], optional, defaults to “steps”) –

    The checkpoint save strategy to adopt during training. Possible values are:

    • “no”: No save is done during training.

    • “epoch”: Save is done at the end of each epoch.

    • “steps”: Save is done every save_steps.

  • steps (int, optional, defaults to 500) – Number of update steps between two checkpoint saves if strategy=“steps”.

  • total_limit (int, optional) – If a value is passed, limits the total number of checkpoints kept; older checkpoints in output_dir are deleted.

  • on_each_node (bool, optional, defaults to False) –

    When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on the main one.

    This should not be activated when the different nodes use the same storage, as the files would be saved with the same names for each node.

Example:

```py
>>> from transformers import TrainingArguments

>>> args = TrainingArguments("working_dir")
>>> args = args.set_save(strategy="steps", steps=100)
>>> args.save_steps
100
```
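The total_limit rotation can be sketched as selecting the oldest checkpoints for deletion. The checkpoints_to_delete helper is hypothetical and assumes checkpoint folders are named checkpoint-<global_step>; it sorts them numerically (a plain string sort would put checkpoint-1000 before checkpoint-500) and does not model any additional rules the Trainer may apply, for example around the best checkpoint.

```python
import re


def checkpoints_to_delete(checkpoint_dirs: list, total_limit: int) -> list:
    """Sketch: return the oldest checkpoint folders beyond total_limit."""

    def step_of(name: str) -> int:
        # Extract the trailing step number from "checkpoint-<step>".
        match = re.search(r"checkpoint-(\d+)$", name)
        return int(match.group(1)) if match else -1

    ordered = sorted(checkpoint_dirs, key=step_of)  # oldest first
    excess = max(0, len(ordered) - total_limit)
    return ordered[:excess]


dirs = ["checkpoint-1500", "checkpoint-500", "checkpoint-1000"]
print(checkpoints_to_delete(dirs, total_limit=2))  # ['checkpoint-500']
```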
set_testing(batch_size: int = 8, loss_only: bool = False, jit_mode: bool = False)

A method that regroups all basic arguments linked to testing on a held-out dataset.

<Tip>

Calling this method will automatically set self.do_predict to True.

</Tip>

Parameters
  • batch_size (int, optional, defaults to 8) – The batch size per device (GPU/TPU core/CPU…) used for testing.

  • loss_only (bool, optional, defaults to False) – Ignores all outputs except the loss.

  • jit_mode (bool, optional) – Whether or not to use PyTorch jit trace for inference.

Example:

```py
>>> from transformers import TrainingArguments

>>> args = TrainingArguments("working_dir")
>>> args = args.set_testing(batch_size=32)
>>> args.per_device_eval_batch_size
32
```
set_training(learning_rate: float = 5e-05, batch_size: int = 8, weight_decay: float = 0, num_epochs: float = 3, max_steps: int = -1, gradient_accumulation_steps: int = 1, seed: int = 42, gradient_checkpointing: bool = False)

A method that regroups all basic arguments linked to the training.

<Tip>

Calling this method will automatically set self.do_train to True.

</Tip>

Parameters
  • learning_rate (float, optional, defaults to 5e-5) – The initial learning rate for the optimizer.

  • batch_size (int, optional, defaults to 8) – The batch size per device (GPU/TPU core/CPU…) used for training.

  • weight_decay (float, optional, defaults to 0) – The weight decay to apply (if not zero) to all layers except all bias and LayerNorm weights in the optimizer.

  • num_epochs (float, optional, defaults to 3.0) – Total number of training epochs to perform. If not an integer, the fractional part of the last epoch is performed before training stops.

  • max_steps (int, optional, defaults to -1) – If set to a positive number, the total number of training steps to perform; overrides num_epochs. For a finite dataset, training iterates through the dataset repeatedly (if all data is exhausted) until max_steps is reached.

  • gradient_accumulation_steps (int, optional, defaults to 1) –

    Number of forward/backward passes (batches) over which to accumulate gradients before performing an optimizer update.

    <Tip warning={true}>

    When using gradient accumulation, one step is counted as one optimizer update, which spans gradient_accumulation_steps forward/backward passes. Therefore, logging, evaluation, and saving are conducted every gradient_accumulation_steps * xxx_steps training batches.

    </Tip>

  • seed (int, optional, defaults to 42) – Random seed that will be set at the beginning of training. To ensure reproducibility across runs, use the [~Trainer.model_init] function to instantiate the model if it has some randomly initialized parameters.

  • gradient_checkpointing (bool, optional, defaults to False) – If True, use gradient checkpointing to save memory at the expense of a slower backward pass.

Example:

```py
>>> from transformers import TrainingArguments

>>> args = TrainingArguments("working_dir")
>>> args = args.set_training(learning_rate=1e-4, batch_size=32)
>>> args.learning_rate
0.0001
```
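Why gradient accumulation approximates a larger batch can be shown with a scalar least-squares example: summing the gradients of equal-sized micro-batches, each scaled by 1 / gradient_accumulation_steps, gives the same gradient as one pass over the combined batch. The helpers and data below are hypothetical; with unequal micro-batch sizes the match is only approximate.

```python
def mse_grad(w: float, batch: list) -> float:
    """Gradient of mean((w * x - y) ** 2) with respect to w over one batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w: float, micro_batches: list) -> float:
    """Accumulate scaled micro-batch gradients before a single optimizer update."""
    accumulation_steps = len(micro_batches)
    grad = 0.0
    for batch in micro_batches:
        # One forward/backward pass per micro-batch; no parameter update yet.
        grad += mse_grad(w, batch) / accumulation_steps
    return grad


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 5.0), (4.0, 9.0)]
full = mse_grad(1.0, data)                            # one batch of 4
accum = accumulated_grad(1.0, [data[:2], data[2:]])   # 2 micro-batches of 2
print(full, accum)  # -15.5 -15.5
```

In other words, each optimizer update sees batch_size * gradient_accumulation_steps examples per device, while peak memory stays close to that of a single batch_size batch.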
property should_log

Whether or not the current process should produce logs.

property should_save

Whether or not the current process should write to disk, e.g., to save models and checkpoints.

to_dict()

Serializes this instance while replacing Enum members with their values (for JSON serialization support). Token values are obfuscated so that the tokens themselves are not serialized.
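A minimal sketch of that kind of serialization, using a hypothetical ToyArgs dataclass rather than the real argument class: Enum members are replaced with their values, and any field whose name ends in _token is replaced with a placeholder. Both the field-name rule and the placeholder format are assumptions made for illustration.

```python
import json
from dataclasses import dataclass, fields
from enum import Enum
from typing import Optional


class Strategy(Enum):
    NO = "no"
    STEPS = "steps"
    EPOCH = "epoch"


@dataclass
class ToyArgs:
    output_dir: str
    save_strategy: Strategy = Strategy.STEPS
    hub_token: Optional[str] = None


def to_dict(args) -> dict:
    """Sketch: JSON-friendly dict with Enums unwrapped and tokens obfuscated."""
    d = {f.name: getattr(args, f.name) for f in fields(args)}
    for key, value in list(d.items()):
        if isinstance(value, Enum):
            d[key] = value.value  # "steps" instead of Strategy.STEPS
        if key.endswith("_token"):
            d[key] = f"<{key.upper()}>"  # never serialize the secret itself
    return d


print(json.dumps(to_dict(ToyArgs("out", hub_token="hf_secret"))))
```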

to_json_string()

Serializes this instance to a JSON string.

to_sanitized_dict() → Dict[str, Any]

Sanitized serialization to use with TensorBoard’s hparams.

property train_batch_size

The actual batch size for training (may differ from per_gpu_train_batch_size in distributed training).
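As a worked example of how this can differ from the per-device setting: in single-process multi-GPU (DataParallel) mode, one process feeds every GPU, so the batch it draws is the per-device size times the number of GPUs; in distributed training n_gpu is 1 per process (see the n_gpu note above), so the two coincide. The helper below is a hypothetical sketch of that arithmetic.

```python
def actual_train_batch_size(per_device_train_batch_size: int, n_gpu: int) -> int:
    """Sketch: batch size drawn per step by one process."""
    # n_gpu can be 0 on CPU; a process always handles at least one "device".
    return per_device_train_batch_size * max(1, n_gpu)


print(actual_train_batch_size(8, n_gpu=4))  # 32: DataParallel splits it across 4 GPUs
print(actual_train_batch_size(8, n_gpu=1))  # 8: e.g. each distributed process
print(actual_train_batch_size(8, n_gpu=0))  # 8: CPU
```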

property world_size

The number of processes used in parallel.

BatchSamplers

class sentence_transformers.training_args.BatchSamplers(value)

Stores the acceptable string identifiers for batch samplers.

The batch sampler is responsible for determining how samples are grouped into batches during training. Valid options are:

  • BatchSamplers.BATCH_SAMPLER: The default PyTorch batch sampler.

  • BatchSamplers.NO_DUPLICATES: Ensures no duplicate samples in a batch.

  • BatchSamplers.GROUP_BY_LABEL: Ensures each batch has 2+ samples from the same label.
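The idea behind BatchSamplers.NO_DUPLICATES can be sketched as a greedy pass: a sample is added to the current batch only if none of its text values already appear in that batch, and skipped samples are deferred to later batches. This matters for losses that use the other samples in a batch as negatives, where a duplicate would act as a false negative. The no_duplicate_batches function is a simplified, hypothetical illustration over tuples of strings, not the library's sampler (which, for example, also shuffles).

```python
def no_duplicate_batches(samples: list, batch_size: int) -> list:
    """Sketch: group sample indices so no text value repeats within a batch."""
    remaining = list(range(len(samples)))
    batches = []
    while remaining:
        batch, seen, deferred = [], set(), []
        for idx in remaining:
            values = set(samples[idx])
            if len(batch) < batch_size and not (values & seen):
                batch.append(idx)
                seen |= values
            else:
                deferred.append(idx)  # try again in a later batch
        batches.append(batch)
        remaining = deferred
    return batches


pairs = [("a", "b"), ("a", "c"), ("d", "e"), ("f", "g")]
print(no_duplicate_batches(pairs, batch_size=2))  # [[0, 2], [1, 3]]
```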

MultiDatasetBatchSamplers

class sentence_transformers.training_args.MultiDatasetBatchSamplers(value)

Stores the acceptable string identifiers for multi-dataset batch samplers.

The multi-dataset batch sampler is responsible for determining in what order batches are sampled from multiple datasets during training. Valid options are:

  • MultiDatasetBatchSamplers.ROUND_ROBIN: Round-robin sampling from each dataset until one is exhausted. With this strategy, it’s likely that not all samples from each dataset are used, but each dataset is sampled from equally.

  • MultiDatasetBatchSamplers.PROPORTIONAL: Sample from each dataset in proportion to its size [default]. With this strategy, all samples from each dataset are used and larger datasets are sampled from more frequently.
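The two strategies can be compared by the order in which dataset indices are visited, given how many batches each dataset yields. The sketch below is a hypothetical illustration that uses the standard library's seeded random module in place of the library's own shuffling, and it assumes round-robin stops at the end of the last complete round. Round-robin cycles through the datasets until the smallest one runs out; proportional lists each dataset once per batch it has and shuffles that list.

```python
import random


def round_robin_order(num_batches: list) -> list:
    """Cycle over datasets; stop when the smallest dataset is exhausted."""
    rounds = min(num_batches)
    return [idx for _ in range(rounds) for idx in range(len(num_batches))]


def proportional_order(num_batches: list, seed: int = 42) -> list:
    """Visit each dataset once per batch it has, in a shuffled order."""
    order = [idx for idx, n in enumerate(num_batches) for _ in range(n)]
    random.Random(seed).shuffle(order)
    return order


sizes = [3, 1, 2]  # batches available in datasets 0, 1 and 2
print(round_robin_order(sizes))           # [0, 1, 2]: only 3 of 6 batches used
print(sorted(proportional_order(sizes)))  # [0, 0, 0, 1, 2, 2]: all 6 batches used
```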