Semantic Search with MAX Engine

March 21, 2024

Ehsan M. Kermani

AI DevRel

In natural language processing (NLP), semantic search focuses on understanding the context and intent behind a query. It goes beyond keyword matching to return more relevant and contextually appropriate results, and it relies on embedding models that convert text into high-dimensional vectors capturing the semantics of language. In this blog post, we use the Amazon Multilingual Counterfactual Dataset (AMCD), which contains sentences from Amazon customer reviews annotated for counterfactual detection (CFD), a binary classification task. Counterfactual statements describe hypothetical scenarios that did not occur or could not occur. They typically have the structure "If p were true, q would also be true", where both the antecedent (p) and the consequent (q) are understood or presumed to be false. For instance, the review "If this camera had a better lens, my photos would be perfect" describes a desired improvement (a better lens) that is currently absent, and the outcome it would affect (perfect photos).

Our classifier uses the bge-base-en-v1.5 embedding model, which produces 768-dimensional embeddings, running on MAX Engine. BGE is one of the leading text embedding models on the MTEB leaderboard, notable for its small disk size of 416MB, and it is available in variants with 768 and 1024 embedding dimensions. We also use a vector database to store the embeddings of the training dataset, simulating real-world conditions for batched inference. At inference time, we retrieve the 10 training reviews most similar to a test query (by cosine similarity) and use their labels to assign the query a counterfactual probability. We then evaluate the classifier with accuracy, F1 score, precision, and recall, using a 0.5 cutoff threshold. Finally, we compare the performance of MAX Engine with PyTorch and ONNX runtime across various batch sizes, showing that:

  • For small batch sizes on CPU, MAX Engine outperforms PyTorch and ONNX runtime by up to 1.6 and 2.8 times, respectively.
  • For large batch sizes on CPU, MAX Engine outperforms PyTorch and ONNX runtime by up to 2 and 1.8 times, respectively.

To install MAX, please check out Get started with MAX Engine. Also have a look at Getting Started with MAX Developer Edition in case you missed it.

The code for this blog post is available in our GitHub repository. The MAX version for this blog is 24.1.1 (0ab415f7).

Dataset and input tokenizer

Let’s first examine the data in the Amazon Multilingual Counterfactual Dataset (AMCD):

import pandas as pd

data = pd.read_csv("amazon-multilingual-counterfactual-dataset/data/EN_train.tsv", sep="\t")
data.head()

The DataFrame has two columns, sentence (a sentence from an Amazon customer review) and is_counterfactual (the label), and contains 4018 samples in total.

For example:

Next, we tokenize all the input sentences in data:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-base-en-v1.5")
# Pad to the longest sentence in the batch and truncate anything beyond 512 tokens
inputs = tokenizer(list(data['sentence']), return_tensors="pt", max_length=512, padding=True, truncation=True)
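To make the effect of padding=True and truncation=True concrete, here is a minimal hand-rolled sketch of BERT-style batch padding in NumPy. The helper pad_batch and all token ids in it are hypothetical and for illustration only; in this post the Hugging Face tokenizer does this work. The sketch shows why three tensors come back, and why, with right padding, the [CLS] token sits at position 0 of every row, which the CLS pooling used later relies on.

```python
import numpy as np

# Hypothetical, illustrative token ids; real ids come from the BGE tokenizer.
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0


def pad_batch(token_id_lists, max_length=512):
    """Right-pad (and truncate) a batch of token id lists, BERT-style.

    Returns input_ids, attention_mask and token_type_ids as int64 arrays,
    mirroring the three tensors the tokenizer call produces.
    """
    # Wrap each sequence in [CLS] ... [SEP], truncating the body so the total fits max_length.
    seqs = [[CLS_ID] + ids[: max_length - 2] + [SEP_ID] for ids in token_id_lists]
    longest = max(len(s) for s in seqs)  # padding=True pads to the longest sequence in the batch
    input_ids = np.full((len(seqs), longest), PAD_ID, dtype=np.int64)
    attention_mask = np.zeros((len(seqs), longest), dtype=np.int64)
    for row, seq in enumerate(seqs):
        input_ids[row, : len(seq)] = seq
        attention_mask[row, : len(seq)] = 1  # 1 = real token, 0 = padding
    token_type_ids = np.zeros_like(input_ids)  # single-sentence inputs use segment 0 everywhere
    return input_ids, attention_mask, token_type_ids


ids, mask, types = pad_batch([[7592, 2088], [2023, 2003, 1037, 3231]])
print(ids)
print(mask)
```

Every row starts with the [CLS] id, and the attention mask marks which positions are real tokens, so padding never shifts the [CLS] position.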

With the inputs tokenized, we are now ready to proceed to inference and create sentence embeddings.

MAX Engine inference

In this blog post, we use the ONNX version of the model, available on Hugging Face. You can download it with the following commands (make sure Git LFS is installed):

git lfs install
git clone

The ONNX model is located at bge-base-en-v1.5/onnx/model.onnx.

Below, we create a session object and load the model into maxmodel. We also examine the input and output tensors, noting their names, shapes, and data types:

from max import engine

session = engine.InferenceSession()
maxmodel = session.load("bge-base-en-v1.5/onnx/model.onnx")

# Print the name, shape, and dtype of every input and output tensor
for tensor in maxmodel.input_metadata:
    print(f'input name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}')
for tensor in maxmodel.output_metadata:
    print(f'output name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}')

The model has three input tensors — input_ids, attention_mask, and token_type_ids — and one output tensor, last_hidden_state:

input name: input_ids, shape: [None, None], dtype: DType.int64
input name: attention_mask, shape: [None, None], dtype: DType.int64
input name: token_type_ids, shape: [None, None], dtype: DType.int64
output name: last_hidden_state, shape: [None, None, 768], dtype: DType.float32

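The model's pooling configuration file, shown below, sets pooling_mode_cls_token to true: the sentence embedding is the hidden state of the first ([CLS]) token rather than an average over tokens. We will use this later to obtain our sentence embeddings correctly.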

{
  "word_embedding_dimension": 768,
  "pooling_mode_cls_token": true,
  "pooling_mode_mean_tokens": false,
  "pooling_mode_max_tokens": false,
  "pooling_mode_mean_sqrt_len_tokens": false
}
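As a sketch of what pooling_mode_cls_token means in practice, the NumPy functions below contrast CLS pooling with masked mean pooling (pooling_mode_mean_tokens, which this model does not use). The random hidden states are placeholders standing in for the model's last_hidden_state output of shape (batch, seq_len, 768); they are not real model outputs.

```python
import numpy as np


def cls_pooling(last_hidden_state):
    """pooling_mode_cls_token: take the hidden state of the first ([CLS]) token."""
    return last_hidden_state[:, 0, :]


def mean_pooling(last_hidden_state, attention_mask):
    """pooling_mode_mean_tokens (shown for contrast): average over non-padding tokens only."""
    mask = attention_mask[:, :, None].astype(last_hidden_state.dtype)
    summed = (last_hidden_state * mask).sum(axis=1)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)  # avoid division by zero for empty rows
    return summed / counts


rng = np.random.default_rng(0)
hidden = rng.standard_normal((2, 6, 768)).astype(np.float32)  # (batch, seq_len, hidden_dim)
mask = np.array([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]])

cls_emb = cls_pooling(hidden)
mean_emb = mean_pooling(hidden, mask)
print(cls_emb.shape, mean_emb.shape)
```

CLS pooling ignores the attention mask entirely, which is why extracting last_hidden_state[:, 0, :] is all that is needed for this model.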

Optional: Convert to ONNX using optimum

Alternatively, you can convert a model to ONNX yourself with the optimum package's command-line interface (CLI).

Converting to ONNX also brings benefits such as framework interoperability across platforms. For instance, to convert the BAAI/bge-base-en-v1.5 model to ONNX:

optimum-cli export onnx --model "BAAI/bge-base-en-v1.5" "./onnx/bge-base-en-v1.5"

Sentence embeddings

To enhance efficiency, especially with large datasets, we batch the input sentences before embedding. This approach not only accelerates the processing but also helps manage memory usage more effectively.

In each batch we call maxmodel.execute, keep the CLS token embedding as the sentence embedding, and iterate over the training data until every sentence is embedded.

import numpy as np
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"])
data_loader = DataLoader(ds, batch_size=128, shuffle=False)

output_embeddings = []
for batch in data_loader:
    batch_input_ids, batch_token_type_ids, batch_attention_mask = batch
    batch_outputs = maxmodel.execute(input_ids=batch_input_ids,
                                     token_type_ids=batch_token_type_ids,
                                     attention_mask=batch_attention_mask)
    last_hidden_state = batch_outputs["last_hidden_state"]
    # Extract the CLS token embedding
    sentence_embeddings = last_hidden_state[:, 0, :]
    output_embeddings.append(sentence_embeddings)

# Concatenate all batches into one array
all_embeddings = np.concatenate(output_embeddings, axis=0)
print(f"All embeddings dimensions: {all_embeddings.shape}")

which outputs

All embeddings dimensions: (4018, 768)
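If you prefer not to depend on PyTorch's DataLoader for batching, the same loop can be written with plain NumPy slicing; the query code later in this post already passes NumPy arrays (return_tensors="np") to maxmodel.execute. The sketch below assumes that. iter_batches and embed_all are hypothetical helper names, and fake_execute is a hypothetical stand-in for maxmodel.execute that only mimics the output layout, so the example runs on its own.

```python
import numpy as np


def iter_batches(arrays, batch_size):
    """Yield aligned slices of equally long arrays, like DataLoader(shuffle=False)."""
    n = len(arrays[0])
    for start in range(0, n, batch_size):
        yield [a[start:start + batch_size] for a in arrays]


def embed_all(execute_fn, input_ids, token_type_ids, attention_mask, batch_size=128):
    """Run execute_fn batch by batch and stack the CLS embeddings.

    execute_fn stands in for maxmodel.execute: it must accept the three keyword
    arguments and return a dict holding "last_hidden_state".
    """
    chunks = []
    for ids, types, mask in iter_batches([input_ids, token_type_ids, attention_mask], batch_size):
        out = execute_fn(input_ids=ids, token_type_ids=types, attention_mask=mask)
        chunks.append(out["last_hidden_state"][:, 0, :])  # CLS pooling
    return np.concatenate(chunks, axis=0)


# Hypothetical stand-in for the model so the sketch runs without MAX Engine.
def fake_execute(input_ids, token_type_ids, attention_mask):
    batch, seq_len = input_ids.shape
    hidden = np.zeros((batch, seq_len, 768), dtype=np.float32)
    hidden[:, 0, 0] = input_ids[:, 0]  # tag each row so the output order can be checked
    return {"last_hidden_state": hidden}


n, seq_len = 4018, 16
ids = np.arange(n * seq_len, dtype=np.int64).reshape(n, seq_len)
embeddings = embed_all(fake_execute, ids, np.zeros_like(ids), np.ones_like(ids))
print(embeddings.shape)
```

With the real model you would pass maxmodel.execute and the NumPy arrays from a tokenizer call with return_tensors="np"; the last batch is simply shorter (50 rows here), just as with DataLoader.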

After obtaining the embeddings, they can be utilized for various NLP tasks, such as semantic similarity, clustering, or as input features for machine learning models. In the next section, we will store them in a vector database for semantic search.

Using a Vector Database

Vector databases excel at managing and querying high-dimensional data, which makes them ideal for storing embeddings. We chose ChromaDB, an embedded vector database that is lightweight, fast to query, and straightforward to use, which makes it a good fit for small to medium-sized applications.

To start, we create a client and a collection that ranks neighbors by cosine distance, then upsert each training sentence together with its embedding and label:

import chromadb

chroma_client = chromadb.Client()
# "hnsw:space": "cosine" makes the HNSW index use cosine distance
collection = chroma_client.create_collection(name="counterfactual_collection",
                                             metadata={"hnsw:space": "cosine"})

for i, (document, embedding, label) in enumerate(zip(list(data['sentence']),
                                                    all_embeddings.tolist(),
                                                    list(data['is_counterfactual']))):
    collection.upsert(ids=[str(i)], documents=[document], embeddings=[embedding],
                      metadatas=[{"is_counterfactual": label}])
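For intuition about what the collection computes, the sketch below is an exact, brute-force NumPy version of a cosine-distance top-k search. It assumes, as in hnswlib (the index library behind Chroma's "hnsw:space" setting), that cosine distance is 1 minus cosine similarity. HNSW itself is an approximate index, so on large collections its results can differ slightly from this exact search. cosine_top_k is a hypothetical helper, and the random corpus is illustrative only.

```python
import numpy as np


def cosine_top_k(query, corpus, k=10):
    """Exact top-k search by cosine distance (1 - cosine similarity).

    Returns (indices, distances) sorted from most to least similar.
    """
    q = query / np.linalg.norm(query)
    c = corpus / np.linalg.norm(corpus, axis=1, keepdims=True)
    distances = 1.0 - c @ q  # cosine similarity becomes a dot product after normalization
    order = np.argsort(distances, kind="stable")[:k]
    return order, distances[order]


rng = np.random.default_rng(42)
corpus = rng.standard_normal((100, 768))
query = corpus[17] * 5.0  # same direction as row 17, five times the magnitude

idx, dist = cosine_top_k(query, corpus, k=3)
print(idx[0])  # 17
```

Scaling the query by 5 leaves its nearest neighbor and its distance (about 0) unchanged, illustrating that cosine similarity depends on the direction of the vectors, not their magnitude.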

Search in Vector Database collection

To demonstrate the practical application, we query the database with a test sentence. After tokenizing the sentence and generating its embedding, we search the vector database for the 10 most similar entries by cosine similarity. Cosine similarity is well suited to embeddings because it depends on the orientation of the vectors rather than their magnitude. Finally, we estimate the counterfactual probability as the fraction of these 10 neighbors whose is_counterfactual label is 1.

query = "I've worn my boots a couple times without proper socks and I can definitely tell the difference!"
query_inputs = tokenizer(query, return_tensors="np", max_length=512, padding=True, truncation=True)
query_output = maxmodel.execute(input_ids=query_inputs["input_ids"],
                                token_type_ids=query_inputs["token_type_ids"],
                                attention_mask=query_inputs["attention_mask"])
# Extract the CLS token embedding
query_embeddings = query_output["last_hidden_state"][:, 0, :]

# Retrieve the 10 nearest training sentences by cosine distance
results = collection.query(query_embeddings=query_embeddings.tolist(), n_results=10)
# Probability = fraction of neighbors labeled counterfactual
counterfactual_prob = sum([r["is_counterfactual"] for r in results["metadatas"][0]]) / len(results["metadatas"][0])
print(f"counterfactual probability is {counterfactual_prob * 100}%")

which outputs

counterfactual probability is 10.0%
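Since the probability is the fraction of the 10 retrieved neighbors labeled counterfactual, 10.0% means exactly one of the ten neighbors has is_counterfactual equal to 1. The minimal sketch below isolates that voting step together with the 0.5 cutoff applied in the next section; the helper names and neighbor labels are hypothetical. Because the comparison is a strict >, a query needs at least 6 of its 10 neighbors to be counterfactual to be classified as positive, and a 5-5 tie is classified as 0.

```python
def knn_counterfactual_prob(neighbor_labels):
    """Fraction of neighbors labeled counterfactual (label 1)."""
    return sum(neighbor_labels) / len(neighbor_labels)


def predict(neighbor_labels, cutoff_threshold=0.5):
    """Binary prediction; the strict '>' sends a 5-of-10 tie to class 0."""
    return 1 if knn_counterfactual_prob(neighbor_labels) > cutoff_threshold else 0


# Hypothetical neighbor metadata, shaped like results["metadatas"][0].
metadatas = [{"is_counterfactual": v} for v in [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
labels = [m["is_counterfactual"] for m in metadatas]
print(knn_counterfactual_prob(labels), predict(labels))  # 0.1 0
```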

Assess test accuracy, F1-score, precision and recall

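We evaluate our model on the test dataset using common metrics:

  • Accuracy: the fraction of all predictions that are correct, giving a general sense of performance
  • Precision: the fraction of predicted counterfactuals that are truly counterfactual, measuring the model's exactness
  • Recall: the fraction of true counterfactuals that the model finds, measuring its completeness
  • F1 score: the harmonic mean of precision and recall, balancing the two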
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

test_data = pd.read_csv("amazon-multilingual-counterfactual-dataset/data/EN_test.tsv", sep="\t")
cutoff_threshold = 0.5

def get_counterfactual_prob(sentence):
    query_inputs = tokenizer(sentence, return_tensors="np", max_length=512, padding=True, truncation=True)
    query_output = maxmodel.execute(input_ids=query_inputs["input_ids"],
                                    token_type_ids=query_inputs["token_type_ids"],
                                    attention_mask=query_inputs["attention_mask"])
    # Extract the CLS token embedding
    query_embeddings = query_output["last_hidden_state"][:, 0, :]
    results = collection.query(query_embeddings=query_embeddings.tolist(), n_results=10)
    # Fraction of the 10 nearest neighbors labeled counterfactual
    counterfactual_prob = sum([r["is_counterfactual"] for r in results["metadatas"][0]]) / len(results["metadatas"][0])
    return counterfactual_prob

predictions = []
for index, row in test_data.iterrows():
    counterfactual_prob = get_counterfactual_prob(row['sentence'])
    prediction = 1 if counterfactual_prob > cutoff_threshold else 0
    predictions.append(prediction)

accuracy = accuracy_score(test_data['is_counterfactual'], predictions)
f1 = f1_score(test_data['is_counterfactual'], predictions)
precision = precision_score(test_data['is_counterfactual'], predictions)
recall = recall_score(test_data['is_counterfactual'], predictions)

print(f"Accuracy: {accuracy:.2f}")
print(f"F1 Score: {f1:.2f}")
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")

which outputs

Accuracy: 0.88
F1 Score: 0.64
Precision: 0.75
Recall: 0.56
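To make the four metrics concrete, the sketch below computes them by hand from confusion-matrix counts on toy labels (hypothetical, not the test set); binary_metrics is an illustrative helper, not part of scikit-learn. It also checks the reported numbers for consistency: F1 is the harmonic mean of precision and recall, and 2 * 0.75 * 0.56 / (0.75 + 0.56) rounds to 0.64, matching the reported F1 up to the rounding of precision and recall.

```python
def binary_metrics(y_true, y_pred):
    """Accuracy, precision, recall and F1 for labels in {0, 1}, from confusion-matrix counts."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    accuracy = (tp + tn) / len(y_true)
    precision = tp / (tp + fp) if tp + fp else 0.0  # exactness: share of predicted positives that are right
    recall = tp / (tp + fn) if tp + fn else 0.0  # completeness: share of true positives that are found
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return accuracy, precision, recall, f1


# Toy labels for illustration only.
y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 1, 0, 0, 0, 0, 0]
print(binary_metrics(y_true, y_pred))

# Consistency check on the reported precision and recall.
print(round(2 * 0.75 * 0.56 / (0.75 + 0.56), 2))  # 0.64
```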

Comparing MAX Engine performance against PyTorch eager and ONNX runtime

Recall that we computed all sentence embeddings in batches of 128. Batching matters in data-intensive scenarios because it improves resource utilization and processing speed, so we compare the performance of MAX Engine against PyTorch and ONNX runtime across various batch sizes to understand how efficiently each handles batched data.

To compare MAX Engine, PyTorch, and ONNX runtime, we selected a range of batch sizes (powers of two) and, for better visualization, split them into two groups:

  • smaller batch sizes: 1 up to 32
  • larger batch sizes: 64 up to 4096

This wide range lets us observe how each framework scales under different loads. We measure the runtime of every batch at each batch size, which shows how each framework handles varying volumes of data. Such an evaluation helps developers and engineers choose the tools best suited to their NLP tasks, especially in resource-intensive scenarios such as working with large datasets. All runtime measurements were done on an AWS c5.12xlarge (CPU) instance.

import gc
import time

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoModel
from optimum.onnxruntime import ORTModelForFeatureExtraction

# PyTorch (eager) model
model = AutoModel.from_pretrained("BAAI/bge-base-en-v1.5")
model.eval()

# ONNX runtime model, loaded from the ONNX file in the model repository
ortmodel = ORTModelForFeatureExtraction.from_pretrained("BAAI/bge-base-en-v1.5", revision="refs/pr/6", file_name="onnx/model.onnx")

def measure_runtime(inputs, model_fn, batch_sizes, is_pytorch=True):
    results = {}
    ds = TensorDataset(inputs["input_ids"], inputs["token_type_ids"], inputs["attention_mask"])
    for batch_size in batch_sizes:
        data_loader = DataLoader(ds, batch_size=batch_size, shuffle=False)
        times = []
        for batch in data_loader:
            start_time = time.time()
            batch_input_ids, batch_token_type_ids, batch_attention_mask = batch
            if is_pytorch:
                # Disable gradient tracking for inference
                with torch.no_grad():
                    _ = model_fn(input_ids=batch_input_ids, token_type_ids=batch_token_type_ids, attention_mask=batch_attention_mask)
            else:
                _ = model_fn(input_ids=batch_input_ids, token_type_ids=batch_token_type_ids, attention_mask=batch_attention_mask)
            end_time = time.time()
            times.append(end_time - start_time)
            gc.collect()  # free memory between batches, outside the timed region
        times = np.array(times)
        mean_time = np.mean(times)
        std_time = np.std(times)
        # 95% confidence interval, assuming the time measurements are normally distributed
        confidence_interval = 1.96 * (std_time / np.sqrt(len(times)))
        results[batch_size] = {'mean_time': mean_time, 'std_time': std_time, 'confidence_interval': confidence_interval}
    return results
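measure_runtime times each batch with time.time(). If you adapt it, a common refinement, not used for the results in this post, is to time with time.perf_counter(), a monotonic high-resolution clock intended for measuring short intervals, and to run a few untimed warm-up calls first so one-off costs do not land in the first measurement. The sketch below shows both, plus the same mean, standard deviation, and normal-approximation 95% confidence interval (1.96 * std / sqrt(n)) that measure_runtime computes. time_calls and summarize are hypothetical helper names.

```python
import math
import statistics
import time


def time_calls(fn, n_runs=20, n_warmup=3):
    """Time fn() n_runs times after n_warmup untimed calls, using a monotonic high-resolution clock."""
    for _ in range(n_warmup):
        fn()  # warm-up absorbs one-off costs such as lazy initialization
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return times


def summarize(times):
    """Mean, standard deviation and normal-approximation 95% CI half-width, as in measure_runtime."""
    mean = statistics.fmean(times)
    std = statistics.pstdev(times)  # population std, matching np.std's default ddof=0
    return {"mean_time": mean, "std_time": std,
            "confidence_interval": 1.96 * std / math.sqrt(len(times))}


stats = summarize(time_calls(lambda: sum(range(10_000))))
print(stats)
```

The returned dictionary has the same keys as the per-batch-size entries of measure_runtime, so it can feed the same plotting code.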

Now we benchmark the PyTorch, ONNX runtime, and MAX Engine models individually on the smaller batch sizes and plot their performance:

import matplotlib.pyplot as plt

small_batch_sizes = [2 ** i for i in range(6)]  # 1, 2, 4, ..., 32
small_results = measure_runtime(inputs, model, small_batch_sizes)
small_maxresults = measure_runtime(inputs, lambda **kwargs: maxmodel.execute(**kwargs), small_batch_sizes, is_pytorch=False)
small_ortresults = measure_runtime(inputs, ortmodel, small_batch_sizes, is_pytorch=False)

def plot_performance_comparison(batch_sizes, results_with_labels, title):
    plt.figure(figsize=(10, 6))
    for label, res in results_with_labels:
        mean_times = [res[bs]['mean_time'] for bs in batch_sizes]
        conf_intervals = [res[bs]['confidence_interval'] for bs in batch_sizes]
        plt.errorbar(batch_sizes, mean_times, yerr=conf_intervals, fmt='-o', capsize=5, label=label)
    plt.xlabel('Batch Size')
    plt.ylabel('Mean Time (seconds)')
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.show()

plot_performance_comparison(small_batch_sizes,
                            [("PyTorch", small_results), ("MAX Engine", small_maxresults), ("ONNX runtime", small_ortresults)],
                            title="Batch Size (1 up to 32) vs Mean Processing Time with 95% Confidence Intervals")

This analysis shows that for smaller batch sizes (1 up to 32), MAX Engine is up to 1.6 times faster than PyTorch and up to 2.8 times faster than ONNX runtime for batch inference.
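The post does not spell out how the "up to" speedups are computed. A natural reading, which we assume here, is the ratio of mean batch times (baseline divided by MAX Engine) at each batch size, maximized over the batch sizes in the range. The sketch below applies that to dictionaries in the layout measure_runtime returns; speedups is a hypothetical helper, and the timings are made up for illustration, not measurements from this post.

```python
def speedups(baseline, candidate, batch_sizes):
    """Per-batch-size speedup of `candidate` over `baseline` (ratio of mean times).

    Both arguments use the {batch_size: {"mean_time": ...}} layout returned by measure_runtime.
    """
    return {bs: baseline[bs]["mean_time"] / candidate[bs]["mean_time"] for bs in batch_sizes}


# Hypothetical timings (seconds per batch), not measurements from this post.
pytorch = {1: {"mean_time": 0.030}, 2: {"mean_time": 0.048}, 4: {"mean_time": 0.090}}
max_engine = {1: {"mean_time": 0.020}, 2: {"mean_time": 0.040}, 4: {"mean_time": 0.075}}

ratios = speedups(pytorch, max_engine, [1, 2, 4])
best = max(ratios, key=ratios.get)
print(f"up to {ratios[best]:.1f}x faster (batch size {best})")
```

With real results you would call, for example, speedups(small_results, small_maxresults, small_batch_sizes).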

For larger batch sizes (64 up to 4096), MAX Engine is up to 2 and 1.8 times faster than PyTorch and ONNX runtime, respectively, showcasing its efficiency in high-volume data processing tasks.

large_batch_sizes = [2 ** i for i in range(6, 13)]  # 64, 128, ..., 4096
large_results = measure_runtime(inputs, model, large_batch_sizes)
large_maxresults = measure_runtime(inputs, lambda **kwargs: maxmodel.execute(**kwargs), large_batch_sizes, is_pytorch=False)
large_ortresults = measure_runtime(inputs, ortmodel, large_batch_sizes, is_pytorch=False)

plot_performance_comparison(large_batch_sizes,
                            [("PyTorch", large_results), ("MAX Engine", large_maxresults), ("ONNX runtime", large_ortresults)],
                            title="Batch Size (64 up to 4096) vs Mean Processing Time with 95% Confidence Intervals")

which shows


We have illustrated how to use MAX Engine with a pre-trained embedding model for binary counterfactual classification, storing the training embeddings in a vector database for inference. Our comparison of MAX Engine with PyTorch and ONNX runtime across various batch sizes on CPU showed that MAX Engine is up to 1.6 and 2.8 times faster than PyTorch and ONNX runtime, respectively, for small batch sizes, and up to 2 and 1.8 times faster for large batch sizes. These gains highlight MAX Engine's potential to improve processing speed and resource utilization in large-scale NLP tasks.

Additional resources:

Report feedback, including issues on our Mojo and MAX GitHub tracker.

Until next time!🔥
