Multimodal Search with Snowflake Embedding and MAX Engine

April 26, 2024

Ehsan M. Kermani

AI DevRel

In our previous blog post on semantic search with MAX Engine, we demonstrated that MAX Engine significantly outperformed both PyTorch eager and ONNX Runtime, by factors of 2 to 2.8 times, respectively, across different batch sizes. Today, we explore how a multimodal approach can further enhance semantic search by combining textual and visual data, and we discuss how MAX Engine can optimize multiple models for inference.

The integration of multiple data types in semantic search presents a unique challenge: how to effectively blend and understand information that comes in different forms, such as text and images. This approach is crucial because it mirrors the way humans process information, leveraging both visual cues and textual context. Pioneering models, such as OpenAI CLIP, have successfully merged text and image understanding through a single neural network trained on vast amounts of data.

For this blog, we use two lightweight models: snowflake-arctic-embed-m, a state-of-the-art pretrained text embedding model, and the MobileNetV2 image classification model, whose logits we use as image embeddings. Inspired by CLIP, we create a custom deep learning model that takes image and text embeddings and maps them into a common shared space. For training, we use a contrastive loss function, which helps align the text and image embeddings in the shared latent space. Additionally, we use the Flickr30k dataset, which contains 31,000 images collected from Flickr, each with 5 reference sentences (captions) provided by human annotators, for training, evaluating, and testing our multimodal model.

We will demonstrate how to convert the models to TorchScript and optimize them in MAX Engine for inference. Lastly, we compute the cosine similarity matrix of the image-caption pairs in the test dataset and conclude by visualizing the ground-truth captions alongside the top five most similar captions found using the similarity matrix.

To install MAX, please check out Get started with MAX Engine. Also have a look at Getting Started with MAX Developer Edition in case you missed it. 

The code for this blog post is available in our GitHub repository. The MAX version for this blog is max 24.2.1 (58157dc0).

Data inspection

We start by downloading the Flickr30k dataset, a comprehensive collection of 31,000 images sourced from Flickr, each accompanied by five captions written by human annotators. The dataset is publicly available on Hugging Face. To construct our (image, caption) pairs, we simplify the dataset by randomly selecting one of the five available captions for each image, so that each pair consists of one image and its corresponding caption.
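The pairing step can be sketched as follows. This is a minimal illustration on made-up records: the keys `image` and `captions` and the helper `make_pairs` are assumptions for the example, not the dataset's actual schema.

```python
import random

def make_pairs(records, seed=0):
    """Pick one caption at random for each image record.

    `records` is assumed to be a list of dicts with the hypothetical keys
    "image" and "captions"; the real dataset schema may differ.
    """
    rng = random.Random(seed)  # fixed seed keeps the selection reproducible
    return [(r["image"], rng.choice(r["captions"])) for r in records]

# Toy records standing in for Flickr30k entries (not real data).
records = [
    {"image": "img_0.jpg", "captions": [f"caption {i} for image 0" for i in range(5)]},
    {"image": "img_1.jpg", "captions": [f"caption {i} for image 1" for i in range(5)]},
]
pairs = make_pairs(records)
```

Seeding the random generator is a design choice that makes the (image, caption) pairs identical across runs, which helps when comparing experiments.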

The dataset is divided into three distinct parts, which serve different roles throughout our experiments:

  • Train: This subset is used for training our model, allowing it to learn and adapt to the task of associating images with textual descriptions. 
  • Validation: This subset helps in tuning the model and validating its performance during training, providing a way to monitor overfitting and make adjustments to model hyper-parameters. 
  • Test: Used for evaluating the model’s performance after training, this subset helps assess how well the model generalizes to new, unseen data.

For effective visualization and interaction with the dataset, we use the fiftyone Python package that is particularly suited for tasks like ours because it offers a variety of tools that facilitate the quick construction, visualization, and analysis of image datasets. 

Here is what our dashboard looks like:

Note that when using the fiftyone package over SSH, the server runs on port 5151 by default. To access it remotely, we need to set up port forwarding via ssh -L 5151:localhost:5151 username@server.

Building our own multimodal model

Our multimodal model consists of three parts, each of which we will explain in detail.

Text embedding

For text embedding, we use snowflake-arctic-embed-m, a state-of-the-art pretrained text embedding model with 768 embedding dimensions that supports up to 512 tokens and occupies 436MB of disk space. Its low runtime memory usage of 0.41GB (float32) makes it particularly suitable for CPU-only applications.

Below is the code snippet demonstrating how to load the model, compute the query embeddings, and normalize them for subsequent processing.

    import torch
    import torch.nn.functional as F
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('Snowflake/snowflake-arctic-embed-m')
    text_model = AutoModel.from_pretrained('Snowflake/snowflake-arctic-embed-m', add_pooling_layer=False)
    text_model.eval()

    def text_embeddings(texts):
        tokens = tokenizer(list(texts), padding=True, truncation=True, return_tensors="pt", max_length=512)
        with torch.no_grad():
            # Take the first-token ([CLS]) vector of the last hidden state
            query_embeddings = text_model(**tokens)[0][:, 0]
        normalized = F.normalize(query_embeddings, p=2, dim=1)
        return normalized

This function will be integral to our semantic search system, allowing us to accurately and efficiently process textual information alongside visual data, thereby enhancing the overall search capability.
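To make the pooling and normalization steps concrete, here is a small NumPy sketch that mirrors what the indexing `[:, 0]` and the L2 normalization do. The random array is only a stand-in for a model output, and the helper name `cls_pool_and_normalize` is hypothetical.

```python
import numpy as np

def cls_pool_and_normalize(last_hidden_state):
    """Take the first-token vector of each sequence and scale it to unit L2 norm."""
    cls = last_hidden_state[:, 0]                       # shape: (batch, hidden)
    norms = np.linalg.norm(cls, axis=1, keepdims=True)
    return cls / np.maximum(norms, 1e-12)               # guard against division by zero

# Random stand-in for a model output of shape (batch=2, tokens=4, hidden=768).
rng = np.random.default_rng(0)
hidden = rng.normal(size=(2, 4, 768))
emb = cls_pool_and_normalize(hidden)
```

After this step every embedding has unit length, so a dot product between two embeddings equals their cosine similarity.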

Image embedding

For the image component of our multimodal search system, we use MobileNetV2, a lightweight classification model pretrained on the ImageNet-1k dataset. This model expects images at 224 by 224 resolution, which keeps image processing efficient.

Here's how to load the model and compute the normalized logits, which we use as embedding values. This checkpoint outputs 1001 logits (the 1000 ImageNet classes plus a background class), so each image embedding has 1001 dimensions. Normalizing puts the embeddings on a consistent scale, which matters when comparing images in our search system.

    from transformers import MobileNetV2ForImageClassification

    image_model = MobileNetV2ForImageClassification.from_pretrained(
        "google/mobilenet_v2_1.0_224"
    )
    image_model.eval()

    def image_embeddings(images):
        with torch.no_grad():
            image_logits = image_model(images).logits
        normalized = F.normalize(image_logits, p=2, dim=1)
        return normalized

This method provides a streamlined approach to obtaining image embeddings, which we will use in the next section.

Custom multimodal model

The following is an example of a simple multimodal model, designed to integrate text and image embeddings effectively. This model aims to project both types of embeddings into a common space where their similarities can be compared directly, which is crucial for tasks like multimodal semantic search. 

Here's a breakdown of how the model is structured and functions:

    import torch
    import torch.nn as nn

    class CustomMultiModalModel(nn.Module):
        def __init__(self, text_embedding_dim=768, image_embedding_dim=1001, common_dim=64):
            super().__init__()
            self.text_transform = nn.Sequential(
                nn.Linear(text_embedding_dim, common_dim),
                nn.ReLU()
            )
            self.image_transform = nn.Sequential(
                nn.Linear(image_embedding_dim, common_dim),
                nn.ReLU()
            )
            # Shared final projection. The attribute name is missing in the source;
            # `final_transform` is a stand-in name.
            self.final_transform = nn.Linear(common_dim, common_dim)

        def forward(self, image_embeddings, text_embeddings):
            transformed_image = self.image_transform(image_embeddings)
            transformed_text = self.text_transform(text_embeddings)
            # Apply the same shared layer to both modalities
            project_image = self.final_transform(transformed_image)
            project_text = self.final_transform(transformed_text)
            return project_image, project_text

The model consists of:

  • Text and Image Transformation Modules: These are simple neural networks consisting of a linear layer followed by a ReLU activation function. The linear layer reduces the dimensionality of the text and image embeddings to common_dim, making the embeddings more manageable and focused on the most essential features for comparison. 
  • Final Linear Transformation: This is a shared linear layer applied to both the transformed text and image embeddings. By projecting both types of embeddings into the same space with the same transformation, we ensure that they are directly comparable. This is crucial for tasks that involve finding the similarity between text and images. 

Overall, this simple model facilitates the interaction between text and image data in a meaningful way, which enhances the capability of our multimodal semantic search. By ensuring that text and images are represented in the same vector space, we improve the system's ability to match text queries with relevant images and vice versa.
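The data flow through the two branches and the shared layer can be checked with a short NumPy sketch. Random matrices stand in for trained weights, biases are omitted for brevity, and the names `W_text`, `W_image`, `W_shared`, and `project` are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
text_dim, image_dim, common_dim = 768, 1001, 64

# Randomly initialised weights stand in for trained parameters.
W_text = rng.normal(scale=0.02, size=(text_dim, common_dim))
W_image = rng.normal(scale=0.02, size=(image_dim, common_dim))
W_shared = rng.normal(scale=0.02, size=(common_dim, common_dim))

def relu(x):
    return np.maximum(x, 0.0)

def project(image_emb, text_emb):
    """Modality-specific linear + ReLU, then one shared linear layer (biases omitted)."""
    image_out = relu(image_emb @ W_image) @ W_shared
    text_out = relu(text_emb @ W_text) @ W_shared
    return image_out, text_out

image_out, text_out = project(
    rng.normal(size=(4, image_dim)),
    rng.normal(size=(4, text_dim)),
)
```

Both outputs have shape (batch, 64), which is what makes a direct image-to-text comparison possible even though the inputs have 1001 and 768 dimensions.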

Contrastive loss function

For our model to be able to distinguish between similar and dissimilar pairs of data, we use a contrastive loss function. The core idea behind contrastive loss is to ensure that representations of similar items are pulled closer together in the embedding space, while representations of dissimilar items are pushed apart. This is typically achieved by comparing the distance between embeddings with a margin—a threshold beyond which the distance between dissimilar pairs should ideally lie. 

One foundational paper that discusses the use of contrastive loss in deep learning is "Dimensionality Reduction by Learning an Invariant Mapping" by Hadsell, Chopra, and LeCun (2006) which has since been widely adopted and adapted for various applications in machine learning, particularly in supervised and semi-supervised settings where relational knowledge between data points is crucial.

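Our ContrastiveLoss class is a module for calculating contrastive loss, which is effective for tasks where the goal is to learn similarities and differences between pairs of items (in our case, image and text embeddings). We use a simpler form (without hard negative mining) of the following formulation from the SimCLR paper, where sim is the cosine similarity function and tau is the temperature:

$$\ell_{i,j} = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_j)/\tau\right)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\mathrm{sim}(z_i, z_k)/\tau\right)}$$

Here z_i and z_j are the embeddings of a positive pair, and the sum in the denominator runs over all 2N embeddings in the batch except z_i itself.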

Here's a breakdown of its components and functionality:

    class ContrastiveLoss(nn.Module):
        def __init__(self, temperature=0.07):
            super().__init__()
            self.temperature = temperature
            self.cosine_similarity = nn.CosineSimilarity(dim=2)

        def forward(self, image_features, text_features):
            # Broadcasting (1, B, D) against (B, 1, D) gives a (B, B) matrix
            # with logits[i, j] = sim(text_i, image_j) / temperature
            logits = self.cosine_similarity(image_features.unsqueeze(0), text_features.unsqueeze(1)) / self.temperature
            batch_size = text_features.shape[0]
            # Labels are the main diagonal, as corresponding text and images are aligned
            labels = torch.arange(batch_size)
            loss = nn.CrossEntropyLoss()(logits, labels)
            return loss

  • Temperature: This scalar scales the cosine similarity outputs, which controls the sharpness of the output probability distribution. A lower temperature makes the distribution sharper, amplifying the differences between more and less similar pairs.
  • Cosine Similarity: This function measures the cosine of the angle between two vectors. In this model, it is used to calculate the similarity between the embeddings of images and texts. By treating the embeddings as vectors in a high-dimensional space, the cosine similarity serves as a proxy for how similar the actual content of the texts and images is.
  • Logits Calculation: The cosine similarities are computed between each text and image embedding pair, scaled by the temperature, and arranged in a matrix. With the unsqueeze calls used here, each row corresponds to a text and each column to an image, so every text is compared with every image.
  • Labels and Loss Calculation: The true matches between images and texts are on the main diagonal of the logits matrix (since each row and column index of the diagonal corresponds to paired image and text embeddings). The CrossEntropyLoss function then uses these labels to compute the loss, which encourages the model to correctly align the embeddings of corresponding images and texts.
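The effect of the diagonal labels can be seen in a small worked example. The NumPy sketch below re-implements the same computation (cosine similarities, temperature scaling, cross-entropy against the diagonal) on random features; the function name `contrastive_loss` and the toy inputs are assumptions for illustration.

```python
import numpy as np

def contrastive_loss(image_features, text_features, temperature=0.07):
    """Cross-entropy over temperature-scaled cosine similarities with diagonal targets."""
    img = image_features / np.linalg.norm(image_features, axis=1, keepdims=True)
    txt = text_features / np.linalg.norm(text_features, axis=1, keepdims=True)
    logits = (txt @ img.T) / temperature                   # logits[i, j] = sim(text_i, image_j) / tau
    logits = logits - logits.max(axis=1, keepdims=True)    # shift for numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))                    # target for row i is column i

rng = np.random.default_rng(0)
features = rng.normal(size=(8, 64))
aligned = contrastive_loss(features, features)             # every pair matches exactly
misaligned = contrastive_loss(features, features[::-1])    # pairs deliberately scrambled
```

When each text embedding matches its own image embedding, the diagonal dominates every row and the loss is close to zero; scrambling the pairs moves the largest similarities off the diagonal and the loss grows.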

Train and evaluation

To initiate training of our custom multimodal model, we will keep the setup straightforward by running the training loop for only five epochs without adjusting any hyper-parameters. This approach allows us to quickly validate our model's architecture and initial performance. Here is the Python code that handles the training process and saves a checkpoint at the end of each epoch to track the model's progress and potentially resume training if needed.

    from torch.utils.data import DataLoader

    lr = 3e-4
    epochs = 5
    batch_size = 64

    # train_dataset and val_dataset are the Flickr30k splits prepared earlier
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    model = CustomMultiModalModel()
    loss_fn = ContrastiveLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_loss_history = []
    val_loss_history = []
    for epoch in range(1, epochs + 1):
        model.train()
        train_losses = []
        for image_batch, text_batch in train_loader:
            image_features = image_embeddings(image_batch)
            text_features = text_embeddings(text_batch)
            optimizer.zero_grad()
            image_out, text_out = model(image_features, text_features)
            train_loss = loss_fn(image_out, text_out)
            train_loss.backward()
            optimizer.step()
            train_losses.append(train_loss.item())
            print(f"Epoch {epoch}, Loss: {train_loss.item()}")

        average_train_loss = sum(train_losses) / len(train_losses)
        train_loss_history.append(average_train_loss)
        print(f"Epoch {epoch}, Average train loss: {average_train_loss}")
        # Save a checkpoint per epoch. The start of this call is missing in the
        # source; saving the state dict matches the load_state_dict call used later.
        torch.save(model.state_dict(), f"custom_multimodal_model_epoch_{epoch}.pt")

        model.eval()
        val_losses = []
        with torch.no_grad():
            for image_batch, text_batch in val_loader:
                image_features = image_embeddings(image_batch)
                text_features = text_embeddings(text_batch)
                image_out, text_out = model(image_features, text_features)
                val_loss = loss_fn(image_out, text_out)
                val_losses.append(val_loss.item())

        average_val_loss = sum(val_losses) / len(val_losses)
        val_loss_history.append(average_val_loss)
        print(f"Epoch {epoch}, Average validation loss: {average_val_loss}")
        print("=" * 80)

Once the training is complete, we can visualize the training and validation losses. For the remainder of this blog post, we will use the checkpoint from epoch 2.
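The post does not say how the epoch 2 checkpoint was chosen. One common heuristic, sketched below as an assumption, is to pick the epoch with the lowest average validation loss. The loss values here are hypothetical and only illustrate the selection logic; `best_epoch` is not part of the original code.

```python
def best_epoch(val_loss_history):
    """Return the 1-based epoch with the lowest validation loss."""
    best_index = min(range(len(val_loss_history)), key=val_loss_history.__getitem__)
    return best_index + 1

# Hypothetical validation losses for five epochs (illustrative numbers only).
example_val_losses = [3.10, 2.95, 3.02, 3.08, 3.15]
epoch = best_epoch(example_val_losses)
checkpoint_path = f"custom_multimodal_model_epoch_{epoch}.pt"
```

With a real run, `val_loss_history` from the training loop would take the place of `example_val_losses`.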

Inference with MAX Engine

In this section, we will see how to prepare the models to be used with the MAX Engine.

Convert to TorchScript

MAX Engine requires PyTorch models to be in TorchScript format. Below, we demonstrate how to convert all three parts of our model (the image embedding, the text embedding, and the custom multimodal part) into TorchScript format.

Note that for PyTorch tracing, we only need to provide dummy inputs with the correct shapes, as follows:

    example_image_inputs = torch.zeros((1, 3, 224, 224))
    with torch.no_grad():
        traced_image_model = torch.jit.trace(image_model, example_image_inputs, strict=False)
    # The save calls are truncated in the source; torch.jit.save is the standard way
    # to write a traced module to the file names shown.
    torch.jit.save(traced_image_model, "image_embedding.torchscript")

    example_text_inputs = {
        "input_ids": torch.zeros((1, 512), dtype=torch.long),
        "token_type_ids": torch.zeros((1, 512), dtype=torch.long),
        "attention_mask": torch.zeros((1, 512), dtype=torch.long),
    }
    with torch.no_grad():
        traced_text_model = torch.jit.trace(text_model, example_kwarg_inputs=dict(example_text_inputs), strict=False)
    torch.jit.save(traced_text_model, "text_embedding.torchscript")

    model = CustomMultiModalModel()
    # The checkpoint path is elided in the source; the post uses the epoch 2 checkpoint.
    model.load_state_dict(torch.load(""))
    model.eval()
    example_image_embeddings = torch.rand(1, 1001)
    example_text_embeddings = torch.rand(1, 768)
    with torch.no_grad():
        traced_script_module = torch.jit.trace(model, (example_image_embeddings, example_text_embeddings))
    torch.jit.save(traced_script_module, "custom_multimodal_model.torchscript")

We are now ready to employ MAX Engine to optimize our models for fast inference.

Load and optimize models with MAX Engine

Next, we create a session object and provide input specifications for all three parts, loading them as optimized MAX models. Here, we use MAX Engine's dynamic shape capabilities to compile the models. For example, to compile the models for dynamic batch sizes, we use None as the first dimension of TorchInputSpec as follows:

    from max import engine

    session = engine.InferenceSession()
    max_image_embedding = session.load(
        "image_embedding.torchscript",
        engine.TorchLoadOptions([
            engine.TorchInputSpec(shape=[None, 3, 224, 224], dtype=engine.DType.float32)
        ]),
    )
    max_custom_multimodal = session.load(
        "custom_multimodal_model.torchscript",
        engine.TorchLoadOptions([
            engine.TorchInputSpec(shape=[None, 1001], dtype=engine.DType.float32),
            engine.TorchInputSpec(shape=[None, 768], dtype=engine.DType.float32),
        ]),
    )

For the text embedding, to account for variable sequence lengths, we also use None as the second dimension in TorchInputSpec:

    max_text_embedding = session.load(
        "text_embedding.torchscript",
        engine.TorchLoadOptions([
            engine.TorchInputSpec(shape=[None, None], dtype=engine.DType.int64),
            engine.TorchInputSpec(shape=[None, None], dtype=engine.DType.int64),
            engine.TorchInputSpec(shape=[None, None], dtype=engine.DType.int64),
        ]),
    )
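To see why the second dimension varies, consider how a batch of tokenized captions is padded to its longest member. The sketch below is a simplified stand-in for what a tokenizer does with padding enabled; `pad_batch` and the token ids are made up for illustration.

```python
import numpy as np

def pad_batch(token_id_lists, pad_id=0):
    """Right-pad token id lists to the longest sequence and build the attention mask."""
    max_len = max(len(ids) for ids in token_id_lists)
    input_ids = np.full((len(token_id_lists), max_len), pad_id, dtype=np.int64)
    attention_mask = np.zeros_like(input_ids)
    for row, ids in enumerate(token_id_lists):
        input_ids[row, : len(ids)] = ids
        attention_mask[row, : len(ids)] = 1
    return input_ids, attention_mask

# Made-up token ids; real ids come from the tokenizer.
short_ids, short_mask = pad_batch([[101, 7, 102], [101, 102]])
long_ids, long_mask = pad_batch([[101, 7, 8, 9, 10, 102]])
```

The first batch has shape (2, 3) and the second (1, 6), so both the batch and the sequence dimension change from call to call, which is what the two None entries allow for.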

Create similarity matrix and inspect the results

In this section, we use our optimized MAX models to generate embeddings for images and captions from our test dataset and then project those embeddings into a common space. This setup allows us to create a similarity matrix, so we can assess and visualize the alignment between image and text embeddings. For demonstration we use a fixed batch size of 16, but note that our model was compiled to support any feasible batch size. In real-world use cases where the batch size is fixed for inference, MAX Engine can optimize the model further. In that case we no longer need None in TorchInputSpec and can use a predetermined batch size in the compilation process.
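A model compiled for a fixed batch size needs every batch to have exactly that size, including the last one. One possible way to handle this, shown here as an assumption rather than something the post prescribes, is to zero-pad the final partial batch and remember how many rows are real. The helper `fixed_size_batches` is hypothetical.

```python
import numpy as np

def fixed_size_batches(array, batch_size):
    """Yield (batch, valid_count) pairs, zero-padding the final batch to batch_size."""
    for start in range(0, len(array), batch_size):
        chunk = array[start : start + batch_size]
        valid = len(chunk)
        if valid < batch_size:
            pad_shape = (batch_size - valid,) + chunk.shape[1:]
            chunk = np.concatenate([chunk, np.zeros(pad_shape, dtype=chunk.dtype)])
        yield chunk, valid

# 1000 dummy feature rows, the same count as the test split used in this post.
data = np.ones((1000, 8), dtype=np.float32)
batches = list(fixed_size_batches(data, 16))
```

With 1000 rows and a batch size of 16 this yields 63 batches, the last of which holds only 8 real rows; the outputs for the padded rows would be discarded using `valid_count`.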

    from torch.utils.data import DataLoader
    import torch.nn.functional as F

    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
    all_image_embeddings = []
    all_text_embeddings = []
    for image, caption in test_loader:
        img_emb = max_image_embedding.execute(pixel_values=image)["result0"]["logits"]
        # padding=True lets captions of different lengths share one batch
        tokenized_caption = tokenizer(list(caption), padding=True, return_tensors="pt")
        txt_emb = max_text_embedding.execute(**tokenized_caption)["result0"]["last_hidden_state"][:, 0, :]
        ret = max_custom_multimodal.execute(image_embeddings=img_emb, text_embeddings=txt_emb)
        projected_image_emb, projected_text_emb = ret["result0"], ret["result1"]
        all_image_embeddings.append(torch.from_numpy(projected_image_emb))
        all_text_embeddings.append(torch.from_numpy(projected_text_emb))

    assert len(all_image_embeddings) == len(all_text_embeddings)
    # Stack the per-batch results into single (num_samples, common_dim) tensors
    all_image_embeddings = torch.cat(all_image_embeddings, dim=0)
    all_text_embeddings = torch.cat(all_text_embeddings, dim=0)
    # Rows index images and columns index captions
    cosine_similarities = F.cosine_similarity(all_image_embeddings.unsqueeze(1), all_text_embeddings.unsqueeze(0), dim=2)

Given the cosine_similarities which is a matrix of shape 1000 by 1000 (since there are 1000 images in the test dataset), we can retrieve the k highest values and their corresponding indices from the cosine similarity matrix. In this case, k is set to 5, meaning the code identifies the top 5 most similar text embeddings for each image embedding. This is useful for tasks like retrieving the most relevant textual descriptions for given images or vice versa.

top_k_values, top_k_indices = torch.topk(cosine_similarities, k=5, dim=1)
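For readers who want to see what the top-k call returns, here is an equivalent NumPy sketch on a tiny hand-written similarity matrix. The helper `top_k` and the numbers are illustrative only.

```python
import numpy as np

def top_k(similarities, k):
    """Return the k largest values per row and their column indices, best first."""
    indices = np.argsort(-similarities, axis=1)[:, :k]
    values = np.take_along_axis(similarities, indices, axis=1)
    return values, indices

sims = np.array([
    [0.9, 0.1, 0.4],
    [0.2, 0.8, 0.3],
    [0.5, 0.6, 0.7],
])
values, indices = top_k(sims, k=2)
```

Each row of `indices` lists the caption positions ranked from most to least similar for that image, and `values` holds the matching similarity scores.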

Finally, we can visualize the top 5 captions for each image. This visualization not only lets us review the most relevant captions identified by our model but also serves as a qualitative assessment tool. As an example, here are the captions for one test image:

  • ground-truth caption: A man is an orange hat starring at something
  • top predicted caption: A man with glasses is wearing a beer can crocheted hat

Additionally, here is a view of the test dataset with both ground-truth and predicted captions. This side-by-side comparison provides a clear illustration of our model’s performance, allowing us to directly assess the accuracy of the predictions against the actual captions.

Next steps

Here are a few potential next steps to consider:

  • Experiment with different image and text embeddings to assess the impact on your results.
  • Modify the architecture of our custom model and perform hyper-parameter tuning.
  • Use evaluation metrics such as Precision@k and NDCG@k to obtain more reliable numerical results.
  • We are excited to see what you build! 🚀 Share your own end-to-end MAX pipelines with us.
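As a starting point for the evaluation metrics mentioned in the list above, the sketch below computes Recall@k and NDCG@k from a similarity matrix, assuming rows are images, columns are captions, and the ground-truth caption for image i sits at column i. The function names and the toy matrix are assumptions for illustration.

```python
import numpy as np

def rank_of_true_caption(similarities):
    """1-based rank of the diagonal (ground-truth) entry within each row."""
    true_scores = np.diag(similarities)[:, None]
    return 1 + (similarities > true_scores).sum(axis=1)

def recall_at_k(similarities, k):
    """Fraction of images whose own caption appears in the top k."""
    return float(np.mean(rank_of_true_caption(similarities) <= k))

def ndcg_at_k(similarities, k):
    """NDCG@k with a single relevant caption per image (ideal DCG is 1)."""
    ranks = rank_of_true_caption(similarities)
    gains = np.where(ranks <= k, 1.0 / np.log2(ranks + 1), 0.0)
    return float(np.mean(gains))

sims = np.array([
    [0.9, 0.1, 0.4],   # true caption ranked 1st
    [0.7, 0.5, 0.3],   # true caption ranked 2nd
    [0.8, 0.6, 0.2],   # true caption ranked 3rd
])
```

Because each image has exactly one relevant caption in this setup, Precision@k is simply Recall@k divided by k, so Recall@k is usually the more readable number to report.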


Conclusion

In this blog post, we have explored the advantages of employing a multimodal approach to semantic search using the MAX Engine. We demonstrated how to create and train a multimodal model for enhanced search capabilities, and we showed how to effectively use MAX Engine when dealing with multiple interconnected models. By visualizing the top predicted captions alongside the ground-truth data, we were able to conduct a direct assessment of our model’s applicability. This not only highlighted the potential of multimodal systems in understanding complex datasets but also underscored the importance of continuous refinement and testing to achieve high accuracy and relevance in search results.

Additional resources:

Report feedback, including issues on our Mojo and MAX GitHub tracker.

Until next time!🔥
