Leveraging MAX Engine's Dynamic Shape Capabilities

March 28, 2024

Mikhail Zolotukhin

AI Compiler Engineer

Ehsan M. Kermani

AI DevRel

In this blog post, we dive into dynamic shape support in MAX Engine's 24.2 release. We first define what dynamic shapes mean in machine learning and discuss their types and use cases. We then show how to use dynamic shapes in MAX Engine. Finally, for demonstration purposes, we compare the average latency of dynamic and static shapes for the BERT model on the GLUE dataset.

What are dynamic shapes?

In machine learning, dynamic shapes refer to a model's ability to handle inputs of varying sizes. This capability is essential for real-world data, which often varies in size and dimension: for example, input text of unknown length (common in applications such as chatbots), or batches of inputs whose sizes are not known in advance (dynamic batching).

What are the types of dynamism in AI models?

In AI models, dynamic shapes can manifest in several ways depending on the degree of “dynamism” involved:

  1. At the most basic level, tensor shapes can be entirely static, meaning their rank (the number of dimensions), data type, and dimension sizes are all specified in advance.
  2. Next, tensors might have a known rank, data type, and sizes along some dimensions, but not all. For example, if a model processes images, the input might be a 4-D tensor in NCHW layout (batch size, channels, height, width). Here, N can be dynamic to allow varying batch sizes, and H and W can be dynamic too if the model should work on images of arbitrary sizes. However, C is often static, since all input images are expected to have the same number of channels (e.g. R, G, B for color images).
  3. Moving towards greater dynamism, tensors might have a known rank and data type but unknown sizes along all dimensions. Continuing the previous example, consider a model that works with both color and black-and-white images, which makes the C (channels) dimension variable as well.
  4. Increasing the degree of dynamism further, we get rank-dynamic tensors: tensors whose rank is unknown but whose data type is still known.
  5. Finally, in the most dynamic scenario, we know nothing about the tensors, not even their data type.
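To make the five levels concrete, here is a minimal Python sketch that classifies a tensor description by how much is known about it ahead of time. The `TensorDesc` class and the `dynamism_level` function are hypothetical illustrations, not part of MAX Engine or any framework API; `None` stands for "unknown".

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class TensorDesc:
    """Hypothetical record of what is known about a tensor before execution.

    dtype: data type name, or None if the data type is unknown.
    dims:  tuple of sizes, or None if even the rank is unknown;
           an individual None entry marks a dynamic dimension.
    """
    dtype: Optional[str]
    dims: Optional[Tuple[Optional[int], ...]]


def dynamism_level(desc: TensorDesc) -> int:
    """Return 1 (fully static) to 5 (nothing known), following the list above."""
    if desc.dtype is None:
        return 5  # not even the data type is known (treated as the most dynamic case)
    if desc.dims is None:
        return 4  # data type known, rank unknown
    n_dynamic = sum(d is None for d in desc.dims)
    if n_dynamic == 0:
        return 1  # rank, data type and all sizes known
    if n_dynamic < len(desc.dims):
        return 2  # some sizes known, some dynamic
    return 3  # rank and data type known, every size dynamic


# NCHW image batch: dynamic N, H, W with a fixed C of 3 channels
print(dynamism_level(TensorDesc("float32", (None, 3, None, None))))  # 2
```

In this model, levels 1 to 3 are exactly the cases where both `dtype` and the rank (the length of `dims`) are known.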

As the level of dynamism increases, optimizing the model's performance becomes more and more difficult. This is why most existing frameworks focus on the first three cases, where at least the rank and data type of each tensor are known. Let's examine that in more detail.

What type of dynamism does MAX Engine support in the 24.2 release?

Of the types of dynamism described above, MAX Engine 24.2 supports dynamism in any dimension, but the rank and data type of every tensor must be known. When MAX runs TensorFlow or ONNX models, this information is extracted directly from the model. For TorchScript models, the user must provide it explicitly, because the model itself doesn't contain it. Similarly, the MAX Graph API lets you specify which dimensions of the model are dynamic.

MAX Engine dynamic shapes support for PyTorch models

In this section, we explain how to use MAX Engine for applications that require dynamic shapes. For simplicity, we use the BERT model, but the same technique applies to other models, including state-of-the-art large language models.

Here, our model processes an input text and tries to guess a masked word in it. Because we don't know in advance how many words (more precisely, tokens) the input will contain, the model must be able to handle inputs of unknown shape.

For PyTorch models, MAX Engine requires the model in TorchScript format. Here is how to convert the BERT model to TorchScript by tracing it with dummy inputs:

```python
import torch
from transformers import BertForMaskedLM

# Example output path for the traced model; any writable location works
MODEL_PATH = "bert.torchscript"

model = BertForMaskedLM.from_pretrained("bert-base-uncased").eval()

# Dummy inputs used only to trace the model
batch = 1
seqlen = 512
inputs = {
    "input_ids": torch.zeros((batch, seqlen), dtype=torch.int64),
    "attention_mask": torch.zeros((batch, seqlen), dtype=torch.int64),
    "token_type_ids": torch.zeros((batch, seqlen), dtype=torch.int64),
}

with torch.no_grad():
    torchscript_model = torch.jit.trace(
        model, example_kwarg_inputs=dict(inputs), strict=False
    )

torch.jit.save(torchscript_model, MODEL_PATH)
```

Next, we create a MAX Engine inference session and load our TorchScript model into it, passing the input specifications along with the model.

```python
from max import engine

session = engine.InferenceSession()

# One spec per model input: a rank-2 int64 tensor whose
# batch and sequence-length dimensions are both dynamic (None)
input_specs = [
    engine.TorchInputSpec(shape=[None, None], dtype=engine.DType.int64)
    for _ in inputs.values()
]
options = engine.TorchLoadOptions(input_specs)
maxmodel = session.load(MODEL_PATH, options)
```

Note that instead of passing concrete tensor sizes, we pass None for each dimension that should be treated as dynamic:

```python
input_specs = [
    engine.TorchInputSpec(shape=[None, None], dtype=engine.DType.int64)
    for _ in inputs.values()
]
```
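The meaning of None in a spec can be illustrated with a small helper. `matches_spec` below is hypothetical and is not MAX Engine's actual validation code; it only sketches the rule described above, assuming that the rank must match exactly, static dimensions must equal their declared size, and None dimensions accept any size.

```python
from typing import Optional, Sequence


def matches_spec(spec: Sequence[Optional[int]], shape: Sequence[int]) -> bool:
    """Hypothetical check: does a concrete shape fit a spec with dynamic (None) dims?"""
    if len(spec) != len(shape):
        return False  # the rank is fixed even when every dimension is dynamic
    # A static dimension must match exactly; a None dimension accepts any size
    return all(s is None or s == d for s, d in zip(spec, shape))


dynamic_spec = [None, None]  # same pattern as shape=[None, None] above
static_spec = [1, 512]       # a fully static spec matching the traced dummy input

print(matches_spec(dynamic_spec, (1, 9)))       # True: one sequence of 9 tokens
print(matches_spec(dynamic_spec, (8, 128)))     # True: 8 sequences of 128 tokens
print(matches_spec(static_spec, (1, 9)))        # False: must be exactly (1, 512)
print(matches_spec(dynamic_spec, (1, 9, 768)))  # False: rank 3 instead of rank 2
```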

Side note: if you use MAX through the C API or the Mojo API, please refer to the relevant documentation to learn how to specify dynamic dimensions.

The maxmodel is now ready for inference. As an example, we tokenize the input "Paris is the [MASK] of France.", run it with maxmodel.execute, and post-process the output to find the predicted token:

```python
import torch
from transformers import AutoTokenizer

INPUT_EXAMPLE = "Paris is the [MASK] of France."

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs = tokenizer(INPUT_EXAMPLE, return_tensors="pt")

# Run inference with the MAX model loaded above
output = maxmodel.execute(**inputs)["result0"]

# Position(s) of the [MASK] token, as a NumPy index array
masked_index = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1].numpy()

# Vocabulary scores at the masked position(s)
logits = torch.from_numpy(output[0, masked_index, :])
predicted_token_id = logits.argmax(dim=-1)
predicted_token = tokenizer.decode(
    predicted_token_id,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
)

# Replace [MASK] with the predicted token
print(INPUT_EXAMPLE.replace("[MASK]", predicted_token))
```

which outputs:

Paris is the capital of France.
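The post-processing above boils down to two operations: find the position of the [MASK] token, then take the argmax of the vocabulary scores at that position. The sketch below shows the same logic with plain Python lists; `fill_masks`, the toy five-id vocabulary, and all ids and scores are made up for illustration and do not come from BERT.

```python
from typing import List, Sequence


def fill_masks(token_ids: Sequence[int],
               logits: Sequence[Sequence[float]],
               mask_id: int) -> List[int]:
    """Replace each mask position with the id that has the highest score there.

    token_ids: ids of a single sequence (batch dimension already dropped).
    logits:    one row of vocabulary scores per position in token_ids.
    """
    filled = list(token_ids)
    for pos, tok in enumerate(token_ids):
        if tok == mask_id:
            row = logits[pos]
            # argmax over the vocabulary dimension
            filled[pos] = max(range(len(row)), key=row.__getitem__)
    return filled


# Toy vocabulary of 5 ids where id 4 plays the role of [MASK]
ids = [0, 1, 4, 2]
scores = [
    [0.9, 0.0, 0.0, 0.0, 0.1],
    [0.0, 0.8, 0.1, 0.0, 0.1],
    [0.1, 0.2, 0.05, 3.0, 0.0],  # scores at the masked position
    [0.0, 0.1, 0.7, 0.1, 0.1],
]
print(fill_masks(ids, scores, mask_id=4))  # [0, 1, 3, 2]
```

Note that the logic never depends on the sequence length, which is why the same post-processing works unchanged with dynamic shapes.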

Latency comparison of dynamic vs static shapes

For demonstration purposes, we compare the average latency of dynamic and static shapes for the BERT model on the GLUE dataset. With dynamic shapes, the model processes only the actual input tokens, which is beneficial for datasets where most samples are short. With static shapes, every input is padded to the maximum length (512 tokens in this BERT example), which results in unnecessary computation and wasted resources.
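A quick way to see where static shapes lose time is to count padding. The hypothetical helper below computes the fraction of token positions that are padding when every input is padded to 512 tokens; the sequence lengths in the example are invented and are not GLUE statistics.

```python
from typing import Iterable


def padding_fraction(lengths: Iterable[int], max_len: int = 512) -> float:
    """Fraction of token positions that are padding when every input is padded to max_len."""
    lengths = list(lengths)
    if not lengths:
        return 0.0
    if any(n > max_len for n in lengths):
        raise ValueError("inputs longer than max_len would need truncation")
    real_tokens = sum(lengths)              # what a dynamic-shape model processes
    padded_tokens = max_len * len(lengths)  # what a static-shape model processes
    return 1 - real_tokens / padded_tokens


# Invented example lengths, not statistics from GLUE
example_lengths = [12, 25, 40, 64, 128]
print(f"{padding_fraction(example_lengths):.1%} of positions are padding")
# 89.5% of positions are padding
```

For the attention layers the waste is even larger than this token count suggests, because their cost grows quadratically with sequence length.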


Throughout this post, we explored the concept of dynamic shapes and the types of dynamism in AI models. Dynamic shapes allow models to handle a wide range of input sizes, a common requirement in real-world scenarios ranging from text processing to image analysis. We showed how to load the BERT model with dynamic shapes in MAX Engine and how to run inference on a sample input. Finally, we compared the average latency of dynamic and static shapes for BERT on the GLUE dataset: on average, dynamic shapes have lower latency, because each input keeps its actual, typically shorter, length instead of being padded to the static maximum.

Additional resources:

Report feedback, including issues, on our Mojo and MAX GitHub tracker.

Until next time!🔥
