Natural Language Inference

Given two sentence (premise and hypothesis), Natural Language Inference (NLI) is the task of deciding if the premise entails the hypothesis, if they are contradiction, or if they are neutral. Commonly used NLI dataset are SNLI and MultiNLI.

To train a CrossEncoder on NLI, see the following example file:

  • training_nli.py:

    This example uses CrossEntropyLoss to train the CrossEncoder model to predict the highest logit for the correct class out of “contradiction”, “entailment”, and “neutral”.

You can also train and use SentenceTransformer models for this task. See Sentence Transformer > Training Examples > Natural Language Inference for more details.

Data

We combine SNLI and MultiNLI into a dataset we call AllNLI. These two datasets contain sentence pairs and one of three labels: entailment, neutral, contradiction:

Sentence A (Premise) Sentence B (Hypothesis) Label
A soccer game with multiple males playing. Some men are playing a sport. entailment
An older and younger man smiling. Two men are smiling and laughing at the cats playing on the floor. neutral
A man inspects the uniform of a figure in some East Asian country. The man is sleeping. contradiction

We format AllNLI in a few different subsets, compatible with different loss functions. See for example the pair-class subset of AllNLI.

CrossEntropyLoss

The CrossEntropyLoss is a rather elementary loss that applies the common torch.nn.CrossEntropyLoss on the logits (a.k.a. outputs, raw predictions) produced after 1) passing the tokenized text pairs through the model and 2) applying the optional activation function over the logits. It’s very commonly used if the CrossEncoder model has to predict more than just 1 class.

Inference

You can perform inference using any of the pre-trained CrossEncoder models for NLI like so:

from sentence_transformers import CrossEncoder

model = CrossEncoder("cross-encoder/nli-deberta-v3-base")
scores = model.predict([
    ("A man is eating pizza", "A man eats something"),
    ("A black race car starts up in front of a crowd of people.", "A man is driving down a lonely road."),
])

# Convert scores to labels
label_mapping = ["contradiction", "entailment", "neutral"]
labels = [label_mapping[score_max] for score_max in scores.argmax(axis=1)]
# => ['entailment', 'contradiction']