Samplers

BatchSamplers

class sentence_transformers.training_args.BatchSamplers(value)[source]

Stores the acceptable string identifiers for batch samplers.

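The batch sampler determines how samples are grouped into batches during training. Valid options are:

  • BatchSamplers.BATCH_SAMPLER: [default] Uses DefaultBatchSampler, the default PyTorch batch sampler.

  • BatchSamplers.NO_DUPLICATES: Uses NoDuplicatesBatchSampler, which ensures that no batch contains duplicate samples. Recommended for losses that use in-batch negatives, such as MultipleNegativesRankingLoss.

  • BatchSamplers.GROUP_BY_LABEL: Uses GroupByLabelBatchSampler, which ensures that each batch contains at least 2 samples per label. Recommended for the Batch...TripletLoss classes.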

To use a custom batch sampler, create a new Trainer class that inherits from SentenceTransformerTrainer and overrides the get_batch_sampler() method. The method must return an object that implements __iter__ and __len__: __iter__ should yield a list of indices for each batch, and __len__ should return the number of batches.
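As a rough illustration of the __iter__/__len__ contract described above, here is a minimal sketch of a sampler object. The class name ShuffledChunkBatchSampler and its constructor arguments are hypothetical and not part of the library; it only uses the Python standard library.

```python
import random


class ShuffledChunkBatchSampler:
    """Hypothetical custom batch sampler: shuffles indices, then cuts them into batches."""

    def __init__(self, dataset_size, batch_size, drop_last=False, seed=0):
        self.dataset_size = dataset_size
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.seed = seed

    def __iter__(self):
        # Yield one list of dataset indices per batch.
        indices = list(range(self.dataset_size))
        random.Random(self.seed).shuffle(indices)
        for start in range(0, self.dataset_size, self.batch_size):
            batch = indices[start:start + self.batch_size]
            if len(batch) < self.batch_size and self.drop_last:
                break
            yield batch

    def __len__(self):
        # Number of batches, consistent with what __iter__ yields.
        if self.drop_last:
            return self.dataset_size // self.batch_size
        return (self.dataset_size + self.batch_size - 1) // self.batch_size


sampler = ShuffledChunkBatchSampler(dataset_size=10, batch_size=4)
batches = list(sampler)
print(len(sampler), [len(batch) for batch in batches])  # 3 [4, 4, 2]
```

A trainer subclass would build and return an object like this from its overridden get_batch_sampler(). The arguments that method receives may differ between library versions, so check its signature in the installed release before overriding it.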

Usage:
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.losses import MultipleNegativesRankingLoss
from datasets import Dataset

model = SentenceTransformer("microsoft/mpnet-base")
train_dataset = Dataset.from_dict({
    "anchor": ["It's nice weather outside today.", "He drove to work."],
    "positive": ["It's so sunny.", "He took the car to the office."],
})
loss = MultipleNegativesRankingLoss(model)
args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints",
    batch_sampler=BatchSamplers.NO_DUPLICATES,
)
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()

class sentence_transformers.sampler.DefaultBatchSampler(*args, **kwargs)[source]

This sampler is the default batch sampler used in the SentenceTransformer library. It is equivalent to the PyTorch BatchSampler.

Parameters
  • sampler (Sampler or Iterable) – The sampler used for sampling elements from the dataset, such as SubsetRandomSampler.

  • batch_size (int) – Number of samples per batch.

  • drop_last (bool) – If True, drop the last incomplete batch if the dataset size is not divisible by the batch size.
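The effect of batch_size and drop_last can be sketched in plain Python. This is an illustration of the batching behavior only, not the library's or PyTorch's code; the helper name batch_indices is hypothetical.

```python
def batch_indices(sampler, batch_size, drop_last):
    """Group indices from `sampler` into lists of `batch_size`."""
    batch = []
    for index in sampler:
        batch.append(index)
        if len(batch) == batch_size:
            yield batch
            batch = []
    # Emit the trailing partial batch unless drop_last is set.
    if batch and not drop_last:
        yield batch


kept = list(batch_indices(range(7), batch_size=3, drop_last=False))
dropped = list(batch_indices(range(7), batch_size=3, drop_last=True))
print(kept)     # [[0, 1, 2], [3, 4, 5], [6]]
print(dropped)  # [[0, 1, 2], [3, 4, 5]]
```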

class sentence_transformers.sampler.NoDuplicatesBatchSampler(dataset: Dataset, batch_size: int, drop_last: bool, valid_label_columns: list[str] = [], generator: torch.Generator = None, seed: int = 0)[source]

This sampler creates batches in which no value appears more than once, even across columns. This is useful when a loss treats the other samples in a batch as in-batch negatives and you want to ensure that those negatives are not duplicates of the anchor or positive sample.

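Recommended for: losses that treat other samples in the batch as in-batch negatives, such as MultipleNegativesRankingLoss.

Parameters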
  • dataset (Dataset) – The dataset to sample from.

  • batch_size (int) – Number of samples per batch.

  • drop_last (bool) – If True, drop the last incomplete batch if the dataset size is not divisible by the batch size.

  • valid_label_columns (List[str]) – List of column names to check for labels. The first column name from valid_label_columns found in the dataset will be used as the label column.

  • generator (torch.Generator, optional) – Optional random number generator for shuffling the indices.

  • seed (int, optional) – Seed for the random number generator to ensure reproducibility.
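The idea of avoiding duplicates within a batch can be sketched with a simple greedy pass. This is a simplified illustration under stated assumptions, not the library's implementation: rows are assumed to be plain dicts of strings, and the function name no_duplicate_batches is hypothetical.

```python
import random


def no_duplicate_batches(rows, batch_size, drop_last=False, seed=0):
    """Yield lists of row indices such that no text value repeats within a batch."""
    remaining = list(range(len(rows)))
    random.Random(seed).shuffle(remaining)
    while remaining:
        batch, seen, skipped = [], set(), []
        for index in remaining:
            values = set(rows[index].values())
            if len(batch) < batch_size and not (values & seen):
                batch.append(index)
                seen |= values
            else:
                skipped.append(index)  # retry this row in a later batch
        remaining = skipped
        # Incomplete batches are only emitted when drop_last is False.
        if len(batch) == batch_size or not drop_last:
            yield batch


rows = [
    {"anchor": "a1", "positive": "p1"},
    {"anchor": "a2", "positive": "p1"},  # shares its positive with row 0
    {"anchor": "a3", "positive": "p3"},
    {"anchor": "a4", "positive": "p4"},
]
batches = list(no_duplicate_batches(rows, batch_size=2))
print(batches)
```

Rows 0 and 1 share a positive, so they never land in the same batch; every other row is still used exactly once.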

class sentence_transformers.sampler.GroupByLabelBatchSampler(dataset: Dataset, batch_size: int, drop_last: bool, valid_label_columns: list[str] = None, generator: torch.Generator = None, seed: int = 0)[source]

This sampler groups samples by label and aims to make the labels within each batch as homogeneous as possible. It is meant to be used alongside the Batch...TripletLoss classes, which require that each batch contains at least 2 examples per label class.

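Recommended for: losses that require multiple samples per label in each batch, such as BatchAllTripletLoss, BatchHardTripletLoss, BatchSemiHardTripletLoss, and BatchHardSoftMarginTripletLoss.

Parameters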
  • dataset (Dataset) – The dataset to sample from.

  • batch_size (int) – Number of samples per batch. Must be divisible by 2.

  • drop_last (bool) – If True, drop the last incomplete batch if the dataset size is not divisible by the batch size.

  • valid_label_columns (List[str]) – List of column names to check for labels. The first column name from valid_label_columns found in the dataset will be used as the label column.

  • generator (torch.Generator, optional) – Optional random number generator for shuffling the indices.

  • seed (int, optional) – Seed for the random number generator to ensure reproducibility.
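One way to obtain batches with at least 2 examples per label is sketched below. It is a simplified illustration, not the library's implementation: labels are assumed to be a plain list, and the function name group_by_label_batches is hypothetical.

```python
import random
from collections import defaultdict


def group_by_label_batches(labels, batch_size, drop_last=False, seed=0):
    """Yield index batches in which same-label samples sit next to each other."""
    if batch_size % 2:
        raise ValueError("batch_size must be divisible by 2")
    groups = defaultdict(list)
    for index, label in enumerate(labels):
        groups[label].append(index)
    # Keep an even number of samples per label and drop labels with one sample,
    # so every label that is kept contributes at least one pair.
    groups = {
        label: indices[: len(indices) // 2 * 2]
        for label, indices in groups.items()
        if len(indices) >= 2
    }

    label_order = list(groups)
    random.Random(seed).shuffle(label_order)
    ordered = [index for label in label_order for index in groups[label]]

    for start in range(0, len(ordered), batch_size):
        batch = ordered[start:start + batch_size]
        if len(batch) < batch_size and drop_last:
            break
        yield batch


labels = ["a", "b", "a", "c", "b", "a", "b", "b"]
batches = list(group_by_label_batches(labels, batch_size=4))
print(batches)
```

Because each label contributes an even number of indices and the batch size is even, a label that appears in a batch appears there at least twice. In this sketch the single "c" sample and the third "a" sample are left out.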

MultiDatasetBatchSamplers

class sentence_transformers.training_args.MultiDatasetBatchSamplers(value)[source]

Stores the acceptable string identifiers for multi-dataset batch samplers.

The multi-dataset batch sampler is responsible for determining in what order batches are sampled from multiple datasets during training. Valid options are:

  • MultiDatasetBatchSamplers.ROUND_ROBIN: Uses RoundRobinBatchSampler, which uses round-robin sampling from each dataset until one is exhausted. With this strategy, it’s likely that not all samples from each dataset are used, but each dataset is sampled from equally.

  • MultiDatasetBatchSamplers.PROPORTIONAL: [default] Uses ProportionalBatchSampler, which samples from each dataset in proportion to its size. With this strategy, all samples from each dataset are used and larger datasets are sampled from more frequently.

Usage:
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.training_args import MultiDatasetBatchSamplers
from sentence_transformers.losses import CoSENTLoss
from datasets import Dataset, DatasetDict

model = SentenceTransformer("microsoft/mpnet-base")
train_general = Dataset.from_dict({
    "sentence_A": ["It's nice weather outside today.", "He drove to work."],
    "sentence_B": ["It's so sunny.", "He took the car to the bank."],
    "score": [0.9, 0.4],
})
train_medical = Dataset.from_dict({
    "sentence_A": ["The patient has a fever.", "The doctor prescribed medication.", "The patient is sweating."],
    "sentence_B": ["The patient feels hot.", "The medication was given to the patient.", "The patient is perspiring."],
    "score": [0.8, 0.6, 0.7],
})
train_legal = Dataset.from_dict({
    "sentence_A": ["This contract is legally binding.", "The parties agree to the terms and conditions."],
    "sentence_B": ["Both parties acknowledge their obligations.", "By signing this agreement, the parties enter into a legal relationship."],
    "score": [0.7, 0.8],
})
train_dataset = DatasetDict({
    "general": train_general,
    "medical": train_medical,
    "legal": train_legal,
})

loss = CoSENTLoss(model)
args = SentenceTransformerTrainingArguments(
    output_dir="checkpoints",
    multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
)
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    loss=loss,
)
trainer.train()

class sentence_transformers.sampler.RoundRobinBatchSampler(dataset: ConcatDataset, batch_samplers: list[BatchSampler], generator: torch.Generator = None, seed: int = None)[source]

Batch sampler that yields batches in a round-robin fashion from multiple batch samplers, until one is exhausted. With this sampler, it’s unlikely that all samples from each dataset are used, but we do ensure that each dataset is sampled from equally.

Parameters
  • dataset (ConcatDataset) – A concatenation of multiple datasets.

  • batch_samplers (List[BatchSampler]) – A list of batch samplers, one for each dataset in the ConcatDataset.

  • generator (torch.Generator, optional) – A generator for reproducible sampling. Defaults to None.

  • seed (int, optional) – A seed for the generator. Defaults to None.
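The round-robin order can be sketched as follows. This is an illustration only: each "batch sampler" is modeled as a plain list of batches, the function name round_robin is hypothetical, and the shuffling controlled by generator and seed is left out.

```python
def round_robin(batch_samplers):
    """Yield (dataset_index, batch) pairs, cycling over datasets until one runs out."""
    iterators = [iter(sampler) for sampler in batch_samplers]
    while True:
        for dataset_index, iterator in enumerate(iterators):
            try:
                batch = next(iterator)
            except StopIteration:
                return  # stop as soon as any dataset is exhausted
            yield dataset_index, batch


# Three datasets with 2, 3, and 2 samples, batched with batch size 2.
general = [[0, 1]]
medical = [[0, 1], [2]]
legal = [[0, 1]]
order = list(round_robin([general, medical, legal]))
print(order)  # [(0, [0, 1]), (1, [0, 1]), (2, [0, 1])]
```

The second batch of the medical dataset is never reached, which is the trade-off described above: each dataset is sampled equally often, but not every sample is used.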

class sentence_transformers.sampler.ProportionalBatchSampler(dataset: ConcatDataset, batch_samplers: list[BatchSampler], generator: torch.Generator, seed: int)[source]

Batch sampler that samples from each dataset in proportion to its size, until all are exhausted simultaneously. With this sampler, all samples from each dataset are used and larger datasets are sampled from more frequently.

Parameters
  • dataset (ConcatDataset) – A concatenation of multiple datasets.

  • batch_samplers (List[BatchSampler]) – A list of batch samplers, one for each dataset in the ConcatDataset.

  • generator (torch.Generator) – A generator for reproducible sampling.

  • seed (int) – A seed for the generator.
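Proportional sampling can be sketched by giving each dataset one slot per batch it holds and shuffling the slots. This is an illustration only: each "batch sampler" is modeled as a plain list of batches, the function name proportional is hypothetical, and Python's random module stands in for the torch generator.

```python
import random


def proportional(batch_samplers, seed=0):
    """Yield (dataset_index, batch) pairs; each dataset appears once per batch it holds."""
    # One slot per batch, so larger datasets occupy more slots in the schedule.
    schedule = [
        dataset_index
        for dataset_index, sampler in enumerate(batch_samplers)
        for _ in range(len(sampler))
    ]
    random.Random(seed).shuffle(schedule)
    iterators = [iter(sampler) for sampler in batch_samplers]
    for dataset_index in schedule:
        yield dataset_index, next(iterators[dataset_index])


# Three datasets with 2, 3, and 2 samples, batched with batch size 2.
general = [[0, 1]]
medical = [[0, 1], [2]]
legal = [[0, 1]]
order = list(proportional([general, medical, legal]))
print(order)
```

Every batch from every dataset is yielded exactly once, and the medical dataset, having two batches, appears twice in the schedule.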