Multimodal Training

Sentence Transformer models can handle multimodal inputs (text, images, audio, and video), enabling cross-modal retrieval tasks such as text-to-image search or audio-to-text matching. The key enabler is the Transformer module’s automatic modality detection: it inspects the underlying model’s processor to determine which modalities are supported, then handles preprocessing for each modality transparently.

This means multimodal training uses the exact same pipeline as text-only training: the same losses, the same trainer, and the same evaluation tools. The data collator handles multimodal preprocessing automatically.

Supported Input Types

You can check which modalities a model supports using the modalities property and the supports() method:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("Qwen/Qwen3-VL-Embedding-2B", revision="refs/pr/23")
print(model.modalities)
# => ['text', 'image', 'message']
print(model.supports("image"))
# => True
print(model.supports("audio"))
# => False

The encode() method accepts a variety of input types depending on the model’s supported modalities:

  • Text: strings.

  • Image: PIL images, file paths, URLs, or numpy/torch arrays.

  • Audio: file paths, numpy/torch arrays, dicts with "array" and "sampling_rate" keys, or torchcodec.AudioDecoder instances.

  • Video: file paths, numpy/torch arrays, dicts with "array" and "video_metadata" keys, or torchcodec.VideoDecoder instances.

  • Multimodal dicts: a dict mapping modality names to values, e.g. {"text": ..., "audio": ...}. The keys must be "text", "image", "audio", or "video".

  • Chat messages: a list of dicts with "role" and "content" keys, for multimodal models that use a custom chat template to combine text and non-text inputs.
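For illustration, here is how the two composite input shapes look as plain Python data. This is a sketch only: the nested content structure inside the chat message (the "type"/"text"/"image" keys) follows the Qwen-VL convention and is an assumption, as is the image file name.

```python
# Composite input shapes from the list above, as plain Python data.

# Multimodal dict: one item pairing a text key with an image path.
multimodal_item = {"text": "a chart of quarterly revenue", "image": "report_page_3.png"}

# Chat messages: a list of role/content dicts for models that combine
# modalities through a custom chat template. The content layout below
# follows the Qwen-VL convention (an assumption for this sketch).
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "report_page_3.png"},
            {"type": "text", "text": "Represent this document page."},
        ],
    }
]

# With a model that supports these modalities, either shape could be
# passed to model.encode(...).
print(sorted(multimodal_item))  # ['image', 'text']
```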

Training

Training a multimodal model follows the same steps as training a text-only model. You can use any compatible loss function, and the trainer and data collator handle multimodal inputs without any special configuration. Datasets can mix modalities across columns, for example a “query” column containing text strings and a “document” column containing PIL images.

Training Example: Document Screenshot Embedding

The training_visual_document_retrieval.py script finetunes Qwen/Qwen3-VL-Embedding-2B on query-document screenshot pairs for visual document retrieval. Here is how it works:

1. Load the model with efficient training settings:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer(
    "Qwen/Qwen3-VL-Embedding-2B",
    model_kwargs={"attn_implementation": "flash_attention_2", "torch_dtype": "bfloat16"},
    processor_kwargs={"min_pixels": 28 * 28, "max_pixels": 600 * 600},
    revision="refs/pr/23",
)

The model_kwargs enable Flash Attention 2 and bfloat16 precision for faster training. The processor_kwargs control image resolution bounds; smaller max_pixels reduces memory usage at the cost of image detail.

2. Load the train and eval subsets of the tomaarsen/llamaindex-vdr-en-train-preprocessed dataset, which pairs text queries with document screenshot images:

from datasets import load_dataset

train_dataset = load_dataset("tomaarsen/llamaindex-vdr-en-train-preprocessed", "train", split="train")
eval_dataset = load_dataset("tomaarsen/llamaindex-vdr-en-train-preprocessed", "eval", split="train")

3. Define the loss function using CachedMultipleNegativesRankingLoss wrapped in MatryoshkaLoss. This combination trains the model for retrieval with in-batch negatives while producing embeddings that remain effective after truncation to smaller dimensions:

from sentence_transformers.losses import CachedMultipleNegativesRankingLoss, MatryoshkaLoss

loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=1)
loss = MatryoshkaLoss(model, loss, matryoshka_dims=[2048, 1536, 1024, 512, 256, 128, 64])

4. Evaluate using InformationRetrievalEvaluator with text queries against an image corpus, measuring cross-modal retrieval performance:

from sentence_transformers.evaluation import InformationRetrievalEvaluator

eval_evaluator = InformationRetrievalEvaluator(
    queries=eval_queries,       # dict of text queries
    corpus=eval_corpus,         # dict of PIL images
    relevant_docs=eval_relevant_docs,
    name="vdr-eval-hard",
)

5. Train using the standard SentenceTransformerTrainer:

from sentence_transformers import SentenceTransformerTrainer

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=eval_evaluator,
)
trainer.train()

After training, the model can be evaluated at each Matryoshka dimension separately to measure the performance-efficiency tradeoff.
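Truncating to a smaller Matryoshka dimension amounts to keeping the first dim components of each embedding and re-normalizing; in practice you would pass truncate_dim to SentenceTransformer or the evaluator rather than doing it by hand. A plain numpy sketch, using random stand-ins for real embeddings:

```python
import numpy as np

# Random stand-ins for full-size (2048-dim) embeddings from the model.
rng = np.random.default_rng(0)
full = rng.standard_normal((4, 2048)).astype(np.float32)

def truncate(embeddings: np.ndarray, dim: int) -> np.ndarray:
    """Keep the first `dim` components, then L2-normalize each row."""
    cut = embeddings[:, :dim]
    return cut / np.linalg.norm(cut, axis=1, keepdims=True)

for dim in (2048, 1024, 256, 64):
    emb = truncate(full, dim)
    print(dim, emb.shape)  # embeddings shrink while staying unit-length
```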