Speeding up Inference

Sentence Transformers supports 3 backends for computing embeddings, each with its own optimizations for speeding up inference:

  • PyTorch
  • ONNX
  • OpenVINO

PyTorch

The PyTorch backend is the default backend for Sentence Transformers. If you don’t specify a device, it will use the strongest available option across “cuda”, “mps”, and “cpu”. Its default usage looks like this:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")

sentences = ["This is an example sentence", "Each sentence is converted"]
embeddings = model.encode(sentences)
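The device fallback described above can be sketched as a small pure function. `pick_device` is a hypothetical helper, not part of Sentence Transformers. It only mirrors the documented "cuda", then "mps", then "cpu" order, so you can make the choice explicit if you prefer.

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    """Return the strongest device in the documented order: cuda > mps > cpu.

    Hypothetical helper: it mirrors the documented fallback order and is
    not the library's own implementation.
    """
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


# Without PyTorch we can only pass plain booleans here.
print(pick_device(cuda_available=False, mps_available=True))
```

With PyTorch installed, you could pass the real checks and hand the result to the constructor, e.g. `SentenceTransformer("all-MiniLM-L6-v2", device=pick_device(torch.cuda.is_available(), torch.backends.mps.is_available()))`.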

If you’re using a GPU, then you can use the following options to speed up your inference:

Float32 (fp32, full precision) is the default floating-point format in torch, whereas float16 (fp16, half precision) is a reduced-precision floating-point format that can speed up inference on GPUs at a minimal loss of model accuracy. To use it, you can specify the torch_dtype during initialization or call model.half() on the initialized model:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2", model_kwargs={"torch_dtype": "float16"})
# or: model.half()

sentences = ["This is an example sentence", "Each sentence is converted"]
embeddings = model.encode(sentences)

Bfloat16 (bf16) is similar to fp16, but keeps the same 8-bit exponent as fp32, so it preserves fp32's numeric range (making overflow much less likely) at the cost of fewer mantissa bits. To use it, you can specify the torch_dtype during initialization or call model.bfloat16() on the initialized model:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2", model_kwargs={"torch_dtype": "bfloat16"})
# or: model.bfloat16()

sentences = ["This is an example sentence", "Each sentence is converted"]
embeddings = model.encode(sentences)
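To make the fp16/bf16 trade-off concrete, here is a sketch that uses only the standard library (no PyTorch required). fp16 has 5 exponent and 10 mantissa bits, while bf16 has 8 exponent bits (like fp32) and 7 mantissa bits. The helper names are hypothetical, and `to_bf16` only simulates bf16 by rounding a float32 bit pattern to its upper 16 bits; it is an illustration, not how PyTorch casts tensors.

```python
import struct


def fits_in_fp16(x: float) -> bool:
    """True if x can be packed as IEEE half precision without overflowing."""
    try:
        struct.pack("<e", x)
    except OverflowError:
        return False
    return True


def to_fp16(x: float) -> float:
    """Round to IEEE half precision (1 sign, 5 exponent, 10 mantissa bits).

    Raises OverflowError if x is too large for fp16.
    """
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Simulate bfloat16 (1 sign, 8 exponent, 7 mantissa bits) for finite x.

    bfloat16 is the upper half of a float32, so we round the float32 bit
    pattern to nearest-even on the 16 bits that get dropped.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Precision: fp16 keeps more mantissa bits than bf16 near 1.0.
print(to_fp16(1.001))  # 1.0009765625
print(to_bf16(1.001))  # 1.0
# Range: bf16 shares fp32's exponent, while fp16 overflows above 65504.
print(fits_in_fp16(1e5), to_bf16(1e5))  # False 99840.0
```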

ONNX

ONNX can be used to speed up inference by converting the model to ONNX format and using ONNX Runtime to run the model. To use the ONNX backend, you must install Sentence Transformers with the onnx or onnx-gpu extra for CPU or GPU acceleration, respectively:

pip install sentence-transformers[onnx-gpu]
# or
pip install sentence-transformers[onnx]

To convert a model to ONNX format, you can use the following code:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2", backend="onnx")

sentences = ["This is an example sentence", "Each sentence is converted"]
embeddings = model.encode(sentences)

If the model path or repository already contains a model in ONNX format, Sentence Transformers will automatically use it. Otherwise, it will convert the model to the ONNX format.

All keyword arguments passed via model_kwargs will be passed on to ORTModel.from_pretrained. Some notable arguments include:

  • provider: ONNX Runtime provider to use for loading the model, e.g. "CPUExecutionProvider". See https://onnxruntime.ai/docs/execution-providers/ for possible providers. If not specified, the strongest available provider (e.g. "CUDAExecutionProvider") will be used.

  • file_name: The name of the ONNX file to load. If not specified, it defaults to "model.onnx", falling back to "onnx/model.onnx". This argument is useful for specifying optimized or quantized models.

  • export: A boolean flag specifying whether the model will be exported. If not provided, export will be set to True if the model repository or directory does not already contain an ONNX model.
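If you prefer to choose the provider yourself rather than rely on the automatic choice, a small helper can pick the first available provider from a priority list. `pick_provider` and its priority order are hypothetical, not part of Sentence Transformers; the list of available providers would come from ONNX Runtime.

```python
from typing import Iterable, Optional, Sequence

# Hypothetical priority order, strongest first; adjust it for your hardware.
DEFAULT_PRIORITY = ("CUDAExecutionProvider", "CPUExecutionProvider")


def pick_provider(
    available: Iterable[str], priority: Sequence[str] = DEFAULT_PRIORITY
) -> Optional[str]:
    """Return the first provider from `priority` that is available, else None."""
    available_set = set(available)
    for provider in priority:
        if provider in available_set:
            return provider
    return None


print(pick_provider(["CPUExecutionProvider", "CUDAExecutionProvider"]))
```

For example, you could pass `model_kwargs={"provider": pick_provider(onnxruntime.get_available_providers())}` together with `backend="onnx"`; if the helper returns None, simply leave the provider argument out.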

Tip

It’s strongly recommended to save the exported model so that you don’t have to re-export it every time you run your code. You can do this by calling model.save_pretrained() if your model was local:

model = SentenceTransformer("path/to/my/model", backend="onnx")
model.save_pretrained("path/to/my/model")

or with model.push_to_hub() if your model was from the Hugging Face Hub:

model = SentenceTransformer("intfloat/multilingual-e5-small", backend="onnx")
model.push_to_hub("intfloat/multilingual-e5-small", create_pr=True)

Optimizing ONNX Models

ONNX models can be optimized using Optimum, allowing for speedups on CPUs and GPUs alike. To do this, you can use the export_optimized_onnx_model() function, which saves the optimized model in a directory or model repository that you specify. It expects:

  • model: a Sentence Transformer model loaded with the ONNX backend.

  • optimization_config: "O1", "O2", "O3", or "O4" representing optimization levels from AutoOptimizationConfig, or an OptimizationConfig instance.

  • model_name_or_path: a path to save the optimized model file, or the repository name if you want to push it to the Hugging Face Hub.

  • push_to_hub: (Optional) a boolean to push the optimized model to the Hugging Face Hub.

  • create_pr: (Optional) a boolean to create a pull request when pushing to the Hugging Face Hub. Useful when you don’t have write access to the repository.

  • file_suffix: (Optional) a string to append to the model name when saving it. If not specified, the optimization level string is used (e.g. "O3"), or "optimized" if optimization_config is an OptimizationConfig instance rather than a string.
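Because you later need the exported file name for model_kwargs={"file_name": ...}, it can help to predict it up front. The sketch below assumes the "onnx/model_<suffix>.onnx" layout seen in the examples in this section; `optimized_onnx_file_name` is a hypothetical helper and does not call Sentence Transformers or Optimum.

```python
from typing import Optional, Union

OPTIMIZATION_LEVELS = ("O1", "O2", "O3", "O4")


def optimized_onnx_file_name(
    optimization_config: Union[str, object], file_suffix: Optional[str] = None
) -> str:
    """Predict the file_name to pass via model_kwargs after exporting.

    Assumes the "onnx/model_<suffix>.onnx" layout used in the examples and
    the file_suffix defaults described above.
    """
    if file_suffix is None:
        if isinstance(optimization_config, str):
            if optimization_config not in OPTIMIZATION_LEVELS:
                raise ValueError(f"Unknown optimization level: {optimization_config!r}")
            file_suffix = optimization_config
        else:
            # e.g. a custom OptimizationConfig instance
            file_suffix = "optimized"
    return f"onnx/model_{file_suffix}.onnx"


print(optimized_onnx_file_name("O3"))
```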

See this example for exporting a model with optimization level 3 (basic and extended general optimizations, transformers-specific fusions, fast Gelu approximation):

For a model on the Hugging Face Hub, you only need to optimize once:

from sentence_transformers import SentenceTransformer, export_optimized_onnx_model

model = SentenceTransformer("all-MiniLM-L6-v2", backend="onnx")
export_optimized_onnx_model(model, "O3", "all-MiniLM-L6-v2", push_to_hub=True, create_pr=True)

Before the pull request gets merged:

from sentence_transformers import SentenceTransformer

pull_request_nr = 2 # TODO: Update this to the number of your pull request
model = SentenceTransformer(
   "all-MiniLM-L6-v2",
   backend="onnx",
   model_kwargs={"file_name": "onnx/model_O3.onnx"},
   revision=f"refs/pr/{pull_request_nr}"
)

Once the pull request gets merged:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer(
   "all-MiniLM-L6-v2",
   backend="onnx",
   model_kwargs={"file_name": "onnx/model_O3.onnx"},
)

For a local model, you only need to optimize once:

from sentence_transformers import SentenceTransformer, export_optimized_onnx_model

model = SentenceTransformer("path/to/my/mpnet-legal-finetuned", backend="onnx")
export_optimized_onnx_model(model, "O3", "path/to/my/mpnet-legal-finetuned")

After optimizing:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer(
   "path/to/my/mpnet-legal-finetuned",
   backend="onnx",
   model_kwargs={"file_name": "onnx/model_O3.onnx"},
)

Quantizing ONNX Models

ONNX models can be quantized to int8 precision using Optimum, allowing for faster inference on CPUs. To do this, you can use the export_dynamic_quantized_onnx_model() function, which saves the quantized model in a directory or model repository that you specify. Dynamic quantization, unlike static quantization, does not require a calibration dataset. It expects:

  • model: a Sentence Transformer model loaded with the ONNX backend.

  • quantization_config: "arm64", "avx2", "avx512", or "avx512_vnni" representing quantization configurations from AutoQuantizationConfig, or a QuantizationConfig instance.

  • model_name_or_path: a path to save the quantized model file, or the repository name if you want to push it to the Hugging Face Hub.

  • push_to_hub: (Optional) a boolean to push the quantized model to the Hugging Face Hub.

  • create_pr: (Optional) a boolean to create a pull request when pushing to the Hugging Face Hub. Useful when you don’t have write access to the repository.

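  • file_suffix: (Optional) a string to append to the model name when saving it. If not specified, a suffix derived from the quantization configuration is used (e.g. "qint8_avx512_vnni", matching the file names in the examples), with "qint8_quantized" as the fallback for a QuantizationConfig instance.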
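To illustrate why dynamic quantization needs no calibration data, here is a pure-Python sketch of symmetric int8 quantization in which the scale is computed from the values themselves. This is a simplified illustration with hypothetical helper names; ONNX Runtime's kernels and the Optimum configurations may differ in details such as per-channel scales or unsigned activations with a zero point.

```python
from typing import List, Sequence, Tuple


def quantize_int8(values: Sequence[float]) -> Tuple[List[int], float]:
    """Symmetric per-tensor int8 quantization with a data-derived scale.

    The scale comes from the values themselves at runtime, which is why
    dynamic quantization needs no calibration dataset.
    """
    max_abs = max((abs(v) for v in values), default=0.0)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    codes = [max(-127, min(127, round(v / scale))) for v in values]
    return codes, scale


def dequantize_int8(codes: Sequence[int], scale: float) -> List[float]:
    """Map int8 codes back to approximate floats."""
    return [c * scale for c in codes]


values = [0.5, -1.27, 0.0, 1.27]
codes, scale = quantize_int8(values)
restored = dequantize_int8(codes, scale)
# Each restored value is within half a quantization step of the original.
print(codes, scale, restored)
```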

On my CPU, each of the default quantization configurations ("arm64", "avx2", "avx512", "avx512_vnni") resulted in roughly equivalent speedups.

See this example for quantizing a model to int8 with avx512_vnni:

For a model on the Hugging Face Hub, you only need to quantize once:

from sentence_transformers import SentenceTransformer, export_dynamic_quantized_onnx_model

model = SentenceTransformer("all-MiniLM-L6-v2", backend="onnx")
export_dynamic_quantized_onnx_model(model, "avx512_vnni", "all-MiniLM-L6-v2", push_to_hub=True, create_pr=True)

Before the pull request gets merged:

from sentence_transformers import SentenceTransformer

pull_request_nr = 2 # TODO: Update this to the number of your pull request
model = SentenceTransformer(
   "all-MiniLM-L6-v2",
   backend="onnx",
   model_kwargs={"file_name": "onnx/model_qint8_avx512_vnni.onnx"},
   revision=f"refs/pr/{pull_request_nr}"
)

Once the pull request gets merged:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer(
   "all-MiniLM-L6-v2",
   backend="onnx",
   model_kwargs={"file_name": "onnx/model_qint8_avx512_vnni.onnx"},
)

For a local model, you only need to quantize once:

from sentence_transformers import SentenceTransformer, export_dynamic_quantized_onnx_model

model = SentenceTransformer("path/to/my/mpnet-legal-finetuned", backend="onnx")
export_dynamic_quantized_onnx_model(model, "avx512_vnni", "path/to/my/mpnet-legal-finetuned")

After quantizing:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer(
   "path/to/my/mpnet-legal-finetuned",
   backend="onnx",
   model_kwargs={"file_name": "onnx/model_qint8_avx512_vnni.onnx"},
)
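Quantization (and, to a lesser degree, optimization) can change the embeddings slightly, so it is worth checking how closely the new model agrees with the original on your own sentences. The sketch below is a hypothetical, dependency-free check based on cosine similarity between paired embeddings; it is not the evaluation used in the benchmarks on this page.

```python
import math
from typing import Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two vectors (0.0 if either has zero norm)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def min_cosine_agreement(reference, candidate) -> float:
    """Lowest cosine similarity between paired embeddings from two backends."""
    if len(reference) != len(candidate):
        raise ValueError("Both backends must embed the same sentences")
    return min(cosine(r, c) for r, c in zip(reference, candidate))


print(min_cosine_agreement([[1.0, 0.0], [0.0, 2.0]], [[2.0, 0.0], [0.0, 1.0]]))
```

For example, encode the same sentences with `SentenceTransformer("path/to/my/mpnet-legal-finetuned")` and with the quantized model loaded above, then call `min_cosine_agreement(reference.tolist(), quantized.tolist())`. What counts as close enough is your call; no threshold is prescribed here.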

OpenVINO

OpenVINO allows for accelerated inference on CPUs by exporting the model to the OpenVINO format. To use the OpenVINO backend, you must install Sentence Transformers with the openvino extra:

pip install sentence-transformers[openvino]

To convert a model to OpenVINO format, you can use the following code:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2", backend="openvino")

sentences = ["This is an example sentence", "Each sentence is converted"]
embeddings = model.encode(sentences)

All keyword arguments passed via model_kwargs will be passed on to OVBaseModel.from_pretrained(). Some notable arguments include:

  • file_name: The name of the OpenVINO file to load. If not specified, it defaults to "openvino_model.xml", falling back to "openvino/openvino_model.xml". This argument is useful for specifying optimized or quantized models.

  • export: A boolean flag specifying whether the model will be exported. If not provided, export will be set to True if the model repository or directory does not already contain an OpenVINO model.
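The default file_name lookup described above can be mirrored with a small helper, for instance to decide whether a local directory still needs an export. `find_openvino_file` is hypothetical and only checks for the .xml file; an OpenVINO model also ships a matching .bin weights file, which this sketch does not verify.

```python
from pathlib import Path
from typing import Optional, Union

# Candidate locations, in the order described above.
OPENVINO_CANDIDATES = ("openvino_model.xml", "openvino/openvino_model.xml")


def find_openvino_file(model_dir: Union[str, Path]) -> Optional[str]:
    """Return the first existing OpenVINO .xml file (relative path), or None.

    None means the directory holds no OpenVINO model, i.e. the case in
    which the model would be exported.
    """
    root = Path(model_dir)
    for candidate in OPENVINO_CANDIDATES:
        if (root / candidate).is_file():
            return candidate
    return None
```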

Tip

It’s strongly recommended to save the exported model so that you don’t have to re-export it every time you run your code. You can do this by calling model.save_pretrained() if your model was local:

model = SentenceTransformer("path/to/my/model", backend="openvino")
model.save_pretrained("path/to/my/model")

or with model.push_to_hub() if your model was from the Hugging Face Hub:

model = SentenceTransformer("intfloat/multilingual-e5-small", backend="openvino")
model.push_to_hub("intfloat/multilingual-e5-small", create_pr=True)

Benchmarks

The following images show the benchmark results for the different backends on GPUs and CPUs. The results are averaged across 4 models of various sizes, 3 datasets, and numerous batch sizes.

Benchmark details:

  • Speedup ratio: the inference speed of each backend relative to PyTorch with fp32.
  • Performance ratio: The same models and hardware were used. We compare the performance against the performance of PyTorch with fp32, i.e. the default backend and precision.
  • Evaluation:
    • Semantic Textual Similarity: Spearman rank correlation based on cosine similarity on the sentence-transformers/stsb test set, computed via the EmbeddingSimilarityEvaluator.
    • Information Retrieval: NDCG@10 based on cosine similarity on the entire NanoBEIR collection of datasets, computed via the InformationRetrievalEvaluator.
  • Backends:
    • torch-fp32: PyTorch with float32 precision (default).
    • torch-fp16: PyTorch with float16 precision, via model_kwargs={"torch_dtype": "float16"}.
    • torch-bf16: PyTorch with bfloat16 precision, via model_kwargs={"torch_dtype": "bfloat16"}.
    • onnx: ONNX with float32 precision, via backend="onnx".
    • onnx-O1: ONNX with float32 precision and O1 optimization, via export_optimized_onnx_model(..., "O1", ...) and backend="onnx".
    • onnx-O2: ONNX with float32 precision and O2 optimization, via export_optimized_onnx_model(..., "O2", ...) and backend="onnx".
    • onnx-O3: ONNX with float32 precision and O3 optimization, via export_optimized_onnx_model(..., "O3", ...) and backend="onnx".
    • onnx-O4: ONNX with float16 precision and O4 optimization, via export_optimized_onnx_model(..., "O4", ...) and backend="onnx".
    • onnx-qint8: ONNX quantized to int8 with "avx512_vnni", via export_dynamic_quantized_onnx_model(..., "avx512_vnni", ...) and backend="onnx". The different quantization configurations resulted in roughly equivalent speedups.
    • openvino: OpenVINO, via backend="openvino".
    • openvino-igpu: OpenVINO, via backend="openvino" and model_kwargs={"device": "GPU"}, to use the integrated GPU (iGPU) of my CPU.
Note that the aggressive averaging across models, datasets, and batch sizes hides some more intricate patterns. For example, on GPUs, if we only consider the stsb dataset, which has the shortest texts, ONNX becomes more competitive: plain ONNX reaches 1.46x and ONNX-O4 reaches 1.83x, whereas fp16 and bf16 reach 1.54x and 1.53x, respectively. So, for shorter texts we recommend ONNX-O4 on GPU.

For CPU, ONNX is also stronger for the stsb dataset with the shortest texts: 1.39x for ONNX, outperforming 1.29x for OpenVINO. ONNX with int8 quantization is even stronger with a 3.08x speedup. For longer texts, ONNX and OpenVINO can even perform slightly worse than PyTorch, so we recommend testing the different backends with your specific model and data to find the best one for your use case.

[Figures: Benchmark for GPUs; Benchmark for CPUs]
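Since results vary with the model and the data, a simple way to follow the advice to test backends yourself is to time them on your own sentences. The following sketch is a hypothetical helper (not the script behind these benchmarks) that reports each backend's speed relative to a baseline; ratios above 1.0 mean faster than the baseline.

```python
import time
from typing import Callable, Dict, Sequence


def speedup_ratios(
    encoders: Dict[str, Callable[[Sequence[str]], object]],
    sentences: Sequence[str],
    baseline: str,
    repeats: int = 3,
) -> Dict[str, float]:
    """Time each encoder on the same sentences; return speed relative to `baseline`.

    The best of `repeats` runs is kept to reduce noise, after one warm-up call.
    """
    if baseline not in encoders:
        raise KeyError(f"Baseline {baseline!r} is not among the encoders")
    best: Dict[str, float] = {}
    for name, encode in encoders.items():
        encode(sentences)  # warm-up: lazy initialization, caches, etc.
        timings = []
        for _ in range(repeats):
            start = time.perf_counter()
            encode(sentences)
            timings.append(time.perf_counter() - start)
        best[name] = min(timings)
    # Guard against a zero reading from a low-resolution clock.
    base = max(best[baseline], 1e-9)
    return {name: base / max(seconds, 1e-9) for name, seconds in best.items()}
```

For example, pass `{"torch-fp32": SentenceTransformer("all-MiniLM-L6-v2").encode, "onnx": SentenceTransformer("all-MiniLM-L6-v2", backend="onnx").encode}` with `baseline="torch-fp32"`. Wall-clock timings are noisy, so use a realistic number of sentences and batch size.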

Recommendations

Based on the benchmarks, this flowchart should help you decide which backend to use for your model:

What is your hardware?
  • GPU: Is your text usually shorter than 500 characters?
    • Yes: onnx-O4 (see Optimizing ONNX Models)
    • No: float16 (see PyTorch)
  • CPU: Is a 0.4% accuracy loss acceptable?
    • Yes: onnx-int8 (see Quantizing ONNX Models)
    • No: Do you have an Intel CPU?
      • Yes: openvino (see OpenVINO)
      • No: onnx (see ONNX)
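The same decision flow can be expressed as a small function, for example to log which backend a script picked. `recommend_backend` is hypothetical; it simply encodes the questions and answers of the flowchart and knows nothing about your actual hardware.

```python
from typing import Optional


def recommend_backend(
    hardware: str,
    short_texts: Optional[bool] = None,
    small_loss_ok: Optional[bool] = None,
    intel_cpu: Optional[bool] = None,
) -> str:
    """Encode the backend flowchart as a function (hypothetical helper)."""
    if hardware == "gpu":
        # Texts usually shorter than ~500 characters favour onnx-O4.
        return "onnx-O4" if short_texts else "float16"
    if hardware == "cpu":
        if small_loss_ok:  # roughly a 0.4% accuracy loss
            return "onnx-int8"
        return "openvino" if intel_cpu else "onnx"
    raise ValueError("hardware must be 'gpu' or 'cpu'")


print(recommend_backend("cpu", small_loss_ok=False, intel_cpu=True))
```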

Note

Your mileage may vary, and you should always test the different backends with your specific model and data to find the best one for your use case.