Multilingual Models

The issue with multilingual BERT (mBERT) and XLM-RoBERTa is that they produce rather poor sentence representations out of the box. Further, their vector spaces are not aligned across languages, i.e., sentences with the same content in different languages are mapped to different locations in the vector space.

In my publication Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation, I describe an easy approach to extend sentence embeddings to further languages.

Chien Vu also wrote a nice blog article on this technique: A complete guide to transfer learning from English to other Languages using Sentence Embeddings BERT Models

Available Pre-trained Models

For a list of available models, see Pretrained Models.

Usage

You can use the models in the following way:

from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("model-name")  # replace "model-name" with one of the multilingual models from Pretrained Models
embeddings = embedder.encode(["Hello World", "Hallo Welt", "Hola mundo"])
print(embeddings)
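Because the vector spaces of these models are aligned across languages, a sentence and its translations should receive a high cosine similarity. The sketch below computes a pairwise cosine-similarity matrix in plain Python; the three short vectors are made-up stand-ins for the rows of `embeddings`, not real model output. With the library itself, `sentence_transformers.util.cos_sim(embeddings, embeddings)` computes the same kind of matrix.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two vectors: dot(a, b) / (|a| * |b|)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def similarity_matrix(embeddings: list[list[float]]) -> list[list[float]]:
    """Pairwise cosine similarities between all embeddings."""
    return [[cosine_similarity(a, b) for b in embeddings] for a in embeddings]


# Toy stand-ins for the embeddings of "Hello World", "Hallo Welt", "Hola mundo"
# (made-up numbers, NOT produced by any model)
toy_embeddings = [
    [0.9, 0.1, 0.3],
    [0.8, 0.2, 0.3],
    [0.85, 0.15, 0.25],
]

for row in similarity_matrix(toy_embeddings):
    print([round(score, 3) for score in row])
```

The diagonal is always 1.0 (each sentence compared with itself); with a well-aligned multilingual model, the off-diagonal entries for translations should also be close to 1.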

Performance

The performance was evaluated on the Semantic Textual Similarity (STS) 2017 dataset. The task is to predict the semantic similarity (on a scale from 0 to 5) of two given sentences. STS2017 has monolingual test data for English, Arabic, and Spanish, and cross-lingual test data for English-Arabic, English-Spanish, and English-Turkish.

We extended STS2017 with cross-lingual test data for English-German, French-English, Italian-English, and Dutch-English (STS2017-extended.zip). Performance is measured as the Spearman correlation between the predicted similarity scores and the gold scores.
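Spearman correlation is the Pearson correlation of the ranks of the two score lists, so it only rewards ordering the sentence pairs correctly, not matching the 0-5 scale. A minimal pure-Python sketch, assuming tied values receive the average of their ranks (the usual convention); the score lists are invented for illustration, and `scipy.stats.spearmanr` computes the same statistic.

```python
import math


def average_ranks(values: list[float]) -> list[float]:
    """1-based ranks; tied values share the mean of the ranks they occupy."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of tied values starting at sorted position i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        mean_rank = (i + j) / 2 + 1  # sorted positions i..j hold ranks i+1..j+1
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    return ranks


def pearson(x: list[float], y: list[float]) -> float:
    """Pearson correlation coefficient of two equally long lists."""
    mean_x, mean_y = sum(x) / len(x), sum(y) / len(y)
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    return cov / math.sqrt(var_x * var_y)


def spearman(predicted: list[float], gold: list[float]) -> float:
    """Spearman correlation = Pearson correlation of the rank vectors."""
    return pearson(average_ranks(predicted), average_ranks(gold))


# Invented example: cosine similarities from a model vs. gold STS scores (0-5)
predicted = [0.91, 0.35, 0.62, 0.10]
gold = [4.8, 1.5, 3.0, 0.2]
print(spearman(predicted, gold))  # same ordering -> 1.0
```

The numbers in the table below are these correlations multiplied by 100.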

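| Model | AR-AR | AR-EN | ES-ES | ES-EN | EN-EN | TR-EN | EN-DE | FR-EN | IT-EN | NL-EN | Average |
|---|---|---|---|---|---|---|---|---|---|---|---|
| XLM-RoBERTa mean pooling | 25.7 | 17.4 | 51.8 | 10.9 | 50.7 | 9.2 | 21.3 | 16.6 | 22.9 | 26.0 | 25.2 |
| mBERT mean pooling | 50.9 | 16.7 | 56.7 | 21.5 | 54.4 | 16.0 | 33.9 | 33.0 | 34.0 | 35.6 | 35.3 |
| LASER | 68.9 | 66.5 | 79.7 | 57.9 | 77.6 | 72.0 | 64.2 | 69.1 | 70.8 | 68.5 | 69.5 |
| Sentence Transformer Models | | | | | | | | | | | |
| distiluse-base-multilingual-cased | 75.9 | 77.6 | 85.3 | 78.7 | 85.4 | 75.5 | 80.3 | 80.2 | 80.5 | 81.7 | 80.1 |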

Extend your own models

Multilingual Knowledge Distillation

The idea is based on a fixed (monolingual) teacher model that produces sentence embeddings with the desired properties in one language. The student model is trained to mimic the teacher model, i.e., the same English sentence should be mapped to the same vector by the teacher and by the student model. To make the student model work for further languages, we train it on parallel (translated) sentences: the translation of each sentence should also be mapped to the same vector as the original sentence.

In the figure above, the student model should map "Hello World" and its German translation "Hallo Welt" to the vector teacher_model("Hello World"). We achieve this by training the student model with a mean squared error (MSE) loss.
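The objective described above can be written as follows for a mini-batch B of parallel sentence pairs (s_j, t_j), where M is the teacher and M̂ (M-hat) is the student. This sketch uses squared Euclidean distances: the first term keeps the student close to the teacher on the source sentence, and the second term pulls the translation to the same point. An MSE implementation additionally averages over the embedding dimensions, which only rescales the loss.

```latex
\mathcal{L}(\mathcal{B}) = \frac{1}{|\mathcal{B}|} \sum_{j \in \mathcal{B}}
  \left[ \left\lVert M(s_j) - \hat{M}(s_j) \right\rVert_2^2
       + \left\lVert M(s_j) - \hat{M}(t_j) \right\rVert_2^2 \right]
```

Because M is fixed, its embeddings M(s_j) act as regression targets; only the parameters of M̂ are updated.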

In our experiments we initialized the student model with the multilingual XLM-RoBERTa model.

Training

For a fully automatic code example, see make_multilingual.py.

This script downloads the parallel sentences corpus, a corpus with transcripts and translations of talks. It then extends a monolingual model to several languages (en, de, es, it, fr, ar, tr). The corpus contains parallel data for more than 100 languages, so you can simply change the script to train a multilingual model in your favorite languages.

Data Format

As training data, we require parallel sentences, i.e., sentences translated into various languages. As the data format, we use a tab-separated .tsv file: the first column contains the source sentence, for example, an English sentence, and the following columns contain its translations. If you have multiple translations per source sentence, you can put them on the same line or on different lines.

Source_sentence	Target_lang1	Target_lang2	Target_lang3
Source_sentence	Target_lang1	Target_lang2

An example file could look like this (EN DE ES):

Hello World	Hallo Welt	Hola Mundo
Sentences are separated with a tab character.	Die Sätze sind per Tab getrennt.	Las oraciones se separan con un carácter de tabulación.

The order of the translations is not important; what matters is that the first column contains a sentence in a language the teacher model understands.
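To make the format concrete, here is a minimal parser sketch (not the library's own loader) that turns such lines into (source, translation) pairs: the first tab-separated column is the source sentence, and every further non-empty column is one of its translations. The function name read_parallel_pairs is hypothetical.

```python
import io
from typing import Iterable, Iterator


def read_parallel_pairs(lines: Iterable[str]) -> Iterator[tuple[str, str]]:
    """Yield (source, translation) pairs from tab-separated lines.

    Hypothetical helper: column 0 is the source sentence, every further
    non-empty column is a translation of it.
    """
    for line in lines:
        columns = [col.strip() for col in line.rstrip("\n").split("\t")]
        source, translations = columns[0], columns[1:]
        if not source:
            continue  # skip blank lines
        for translation in translations:
            if translation:
                yield source, translation


example = io.StringIO(
    "Hello World\tHallo Welt\tHola Mundo\n"
    "Sentences are separated with a tab character.\tDie Sätze sind per Tab getrennt.\n"
)
for pair in read_parallel_pairs(example):
    print(pair)
```

For a .gz file, pass gzip.open(path, "rt", encoding="utf8") instead of the StringIO object; it also yields text lines.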

Loading Training Datasets

You can load such a training file using the ParallelSentencesDataset class:

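from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.datasets import ParallelSentencesDataset

# Teacher: monolingual model whose embedding space the student should copy
teacher_model = SentenceTransformer("teacher-model-name")
# Student: multilingual model that learns to imitate the teacher
student_model = SentenceTransformer("student-model-name")
train_batch_size = 64  # example value

train_data = ParallelSentencesDataset(student_model=student_model, teacher_model=teacher_model)
train_data.load_data("path/to/tab/separated/train-en-de.tsv")
train_data.load_data("path/to/tab/separated/train-en-es.tsv.gz")
train_data.load_data("path/to/tab/separated/train-en-fr.tsv.gz")

train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size)
train_loss = losses.MSELoss(model=student_model)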

Each call to the load_data() method loads one file, so you can load multiple files by calling load_data() multiple times. Both regular files and .gz-compressed files are supported.

By default, all datasets are weighted equally: in the example above, (source, translation) pairs are sampled equally from all three datasets. If you pass a weight parameter (an integer), you can weight some datasets higher or lower than others.
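One simple way to realize integer weights, shown here as an illustrative sketch rather than the library's internals, is to build each round by taking weight consecutive pairs from every dataset, cycling through a dataset once it is exhausted. With equal weights, every file contributes the same number of pairs per round; doubling one weight doubles that file's share. The function interleave and its arguments are hypothetical.

```python
from itertools import cycle, islice


def interleave(
    datasets: list[list[tuple[str, str]]],
    weights: list[int],
    num_rounds: int,
) -> list[tuple[str, str]]:
    """Take weights[k] pairs from dataset k in every round.

    Each dataset is cycled, so small datasets are repeated rather than exhausted.
    """
    iterators = [cycle(data) for data in datasets]
    samples = []
    for _ in range(num_rounds):
        for it, weight in zip(iterators, weights):
            samples.extend(islice(it, weight))  # next `weight` pairs of this dataset
    return samples


en_de = [("Hello World", "Hallo Welt"), ("Good morning", "Guten Morgen")]
en_es = [("Hello World", "Hola mundo")]

mixed = interleave([en_de, en_es], weights=[2, 1], num_rounds=2)
print(len(mixed))  # 2 rounds * (2 + 1) pairs = 6
```

The order within the resulting list does not matter much in practice, since the DataLoader with shuffle=True mixes the pairs anyway.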

Sources for Training Data

A great website for a vast number of parallel (translated) datasets is OPUS. There, you find parallel datasets for more than 400 languages.

The examples/training/multilingual folder contains scripts that download parallel training data and convert it into the right format.

Evaluation

Training can be evaluated in different ways. For an example of how to use these evaluation methods, see make_multilingual.py.

MSE Evaluation

You can measure the mean squared error (MSE) between the student embeddings and the teacher embeddings with the MSEEvaluator:

from sentence_transformers import evaluation

# src_sentences and trg_sentences are lists of translated sentences, such that trg_sentences[i] is the translation of src_sentences[i]
dev_mse = evaluation.MSEEvaluator(src_sentences, trg_sentences, teacher_model=teacher_model)

This evaluator computes the teacher embeddings for the src_sentences, for example, English sentences. During training, the student model computes embeddings for the trg_sentences, for example, Spanish sentences. The evaluator then measures the distance between the teacher and student embeddings; lower scores indicate better performance.
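Conceptually, the evaluator compares the teacher embedding of src_sentences[i] with the student embedding of trg_sentences[i] for every i and averages the squared differences over all entries. A plain-Python sketch with made-up vectors in place of real model outputs; note that the library may report a scaled version of this number (for example, multiplied by 100), so only compare scores computed the same way.

```python
def mean_squared_error(teacher_embs: list[list[float]], student_embs: list[list[float]]) -> float:
    """MSE over all embedding entries: mean of (teacher - student)^2."""
    if len(teacher_embs) != len(student_embs):
        raise ValueError("need one student embedding per teacher embedding")
    squared_errors = []
    for t_vec, s_vec in zip(teacher_embs, student_embs):
        if len(t_vec) != len(s_vec):
            raise ValueError("embedding dimensions differ")
        squared_errors.extend((t - s) ** 2 for t, s in zip(t_vec, s_vec))
    return sum(squared_errors) / len(squared_errors)


# Made-up values: teacher(src_sentences[i]) and student(trg_sentences[i])
teacher_src = [[0.5, 0.5], [1.0, 0.0]]
student_trg = [[0.5, 0.3], [0.9, 0.0]]
print(mean_squared_error(teacher_src, student_trg))  # roughly 0.0125
```

A student that reproduces the teacher's vectors exactly scores 0, which is why lower is better.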

Translation Accuracy

You can also measure translation accuracy. Assume you have a list of source sentences, for example, 1,000 English sentences, and a list of matching target (translated) sentences, for example, 1,000 Spanish sentences.

For each source sentence, we check whether its translation is the closest target sentence under cosine similarity, i.e., whether trg_sentences[i] has the highest similarity to src_sentences[i] among all target sentences. If so, we count a hit, otherwise an error. The evaluator reports the accuracy (higher is better).

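import os
from sentence_transformers import evaluation

dev_file = "path/to/tab/separated/dev-en-es.tsv"  # hypothetical path of the parallel dev file
inference_batch_size = 64  # example batch size used for encoding

# src_sentences and trg_sentences are lists of translated sentences, such that trg_sentences[i] is the translation of src_sentences[i]
dev_trans_acc = evaluation.TranslationEvaluator(
    src_sentences,
    trg_sentences,
    name=os.path.basename(dev_file),
    batch_size=inference_batch_size,
)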
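The accuracy computation described above can be sketched in plain Python: compare each source embedding with every target embedding and count how often the most similar target for src_sentences[i] is trg_sentences[i]. The vectors are made up, and the library's TranslationEvaluator may additionally score the reverse direction (target to source), which this sketch leaves out.

```python
import math


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def translation_accuracy(src_embs: list[list[float]], trg_embs: list[list[float]]) -> float:
    """Fraction of i for which trg_embs[i] is the most similar target to src_embs[i]."""
    hits = 0
    for i, src in enumerate(src_embs):
        scores = [cosine(src, trg) for trg in trg_embs]
        best = max(range(len(scores)), key=scores.__getitem__)  # first index on ties
        if best == i:
            hits += 1
    return hits / len(src_embs)


# Made-up embeddings: pairs 0 and 1 are aligned, pair 2 is not
src = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
trg = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
print(translation_accuracy(src, trg))  # 2 of 3 pairs are hits
```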

Multi-Lingual Semantic Textual Similarity

You can also measure the semantic textual similarity (STS) between sentence pairs in different languages:

sts_evaluator = evaluation.EmbeddingSimilarityEvaluatorFromList(sentences1, sentences2, scores)

Here, sentences1 and sentences2 are lists of sentences, and scores is a list of numeric values, where scores[i] indicates the semantic similarity between sentences1[i] and sentences2[i].

Citation

If you use the code for multilingual models, feel free to cite our publication Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation:

@article{reimers-2020-multilingual-sentence-bert,
    title = "Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation",
    author = "Reimers, Nils and Gurevych, Iryna",
    journal= "arXiv preprint arXiv:2004.09813",
    month = "04",
    year = "2020",
    url = "http://arxiv.org/abs/2004.09813",
}