MAX Graph API Tutorial

May 14, 2024

Ehsan M. Kermani

AI DevRel

MAX Engine is a next-generation compiler and runtime library for running AI inference. With support for PyTorch (TorchScript), ONNX, and native Mojo models, it delivers low-latency, high-throughput inference on a wide range of hardware to accelerate your entire AI workload. As highlighted in the recent MAX version 24.3 release, the MAX platform enables users to fully leverage the capabilities of the MAX Engine by creating bespoke inference models using the MAX Graph APIs. The Graph API offers a low-level programming interface for constructing high-performance symbolic computation graphs in Mojo. This interface provides a uniform representation of symbolic values and a suite of operators that process these symbols to construct the entire graph.

In this blog post, we guide you step by step through using the MAX Graph API. In a nutshell, working with the MAX Graph API involves three main steps:

  1. Building and verifying the graph.
  2. Creating an inference session and compiling the graph.
  3. Executing the graph with input(s) and retrieving the output(s).

We begin by creating two straightforward graphs in Mojo, one for addition and one for matrix multiplication, and demonstrate how to compile and execute them. We then implement a two-layer feedforward neural network with ReLU activation for inference on MNIST data and compare its accuracy to a PyTorch implementation. Finally, we implement ReLU6 as a custom operator, use the MAX Graph Custom Operator API to substitute it for ReLU, and confirm that the accuracy matches the PyTorch model.

To install MAX, please check out Get started with MAX Engine. If you are also new to Mojo, you can start with the Mojo Manual. To get involved and ask questions, you can join our Discord community and contribute to discussions on the Mojo and MAX GitHub. Should you encounter any issues, we recommend checking the roadmap and known issues first.

The code for this tutorial is in our GitHub repository. The MAX version for this tutorial is max 24.3.0 (9882e19d).

We also have a video walkthrough of all the code featured in this blog post.

Hello, world!

To begin familiarizing ourselves with the Graph API, we start by constructing a simple addition graph. We will verify and compile this graph, and then proceed to execute it as demonstrated below.

Addition graph

Our first graph is a straightforward one that takes two inputs, input0 and input1, adds them together, and produces output0 as the output.

Step 1: Build the graph

To construct the addition graph, we start by importing the necessary modules. We then instantiate the Graph by specifying two input types, each a float32 tensor with a fixed static dimension of 1 (later we will see other supported kinds of dimensions, such as symbolic dimensions). Next, we create a symbolic representation of the addition with the expression out = graph[0] + graph[1], where graph[0] refers to the first input, input0, and graph[1] to the second input, input1. Finally, we designate out as the output of the graph by calling graph.output(out).

from max.graph import Graph, TensorType, Type

graph = Graph(in_types=List[Type](TensorType(DType.float32, 1), TensorType(DType.float32, 1)))
out = graph[0] + graph[1]
graph.output(out)
print(graph)

We can print the graph to visually confirm its structure. The output should include the following line, where rmo and mo are Modular’s internal intermediate representations:

%0 = rmo.add(%arg0, %arg1) : !mo.tensor<[1], f32>, !mo.tensor<[1], f32>

This line corresponds to the symbolic addition operation out = graph[0] + graph[1].

The subsequent line

mo.output %0 : !mo.tensor<[1], f32>

indicates that %0 has been set as the output of the graph, matching graph.output(out) in our code.

The complete graph representation looks like this:

graph: module {
  mo.graph @graph(%arg0: !mo.tensor<[1], f32>, %arg1: !mo.tensor<[1], f32>) -> !mo.tensor<[1], f32> no_inline {
    %0 = rmo.add(%arg0, %arg1) : !mo.tensor<[1], f32>, !mo.tensor<[1], f32>
    mo.output %0 : !mo.tensor<[1], f32>
  }
}
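To make the build-then-print flow concrete, here is a toy Python sketch of a symbolic graph. The classes Symbol and ToyGraph are hypothetical and unrelated to the MAX API; they only show that building the graph records nodes (and checks shapes) without computing anything, and the printed text merely imitates the shape of the IR above, not its exact syntax.

```python
class Symbol:
    """A symbolic value: it has a name and a static shape, but no data."""

    def __init__(self, name, shape, graph):
        self.name, self.shape, self.graph = name, shape, graph

    def __add__(self, other):
        # Record an "add" node in the graph instead of computing anything.
        return self.graph.add_node("add", self, other)


class ToyGraph:
    """Hypothetical stand-in for a symbolic graph (not the MAX API)."""

    def __init__(self, in_shapes):
        # Graph arguments are named %arg0, %arg1, ... like in the printed IR.
        self.args = [Symbol(f"%arg{i}", s, self) for i, s in enumerate(in_shapes)]
        self.nodes = []  # (result, op, lhs, rhs) tuples in creation order
        self.out = None

    def __getitem__(self, i):
        return self.args[i]

    def add_node(self, op, lhs, rhs):
        # Shapes are checked at build time, before any data exists.
        if lhs.shape != rhs.shape:
            raise ValueError(f"shape mismatch: {lhs.shape} vs {rhs.shape}")
        result = Symbol(f"%{len(self.nodes)}", lhs.shape, self)
        self.nodes.append((result, op, lhs, rhs))
        return result

    def output(self, sym):
        self.out = sym

    def __str__(self):
        lines = [f"{r.name} = {op}({a.name}, {b.name})" for r, op, a, b in self.nodes]
        lines.append(f"output {self.out.name}")
        return "\n".join(lines)


graph = ToyGraph([(1,), (1,)])
out = graph[0] + graph[1]
graph.output(out)
print(graph)
# %0 = add(%arg0, %arg1)
# output %0
```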

To verify the constructed graph programmatically, we call the graph.verify() method. It checks various structural integrity criteria, such as ensuring there are no cycles within the graph (acyclicity); a cycle would indicate recursion or a feedback loop, which cannot be part of a dataflow graph. For more details, check out the official documentation on the verify method.
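As a conceptual illustration of the acyclicity criterion only (not how MAX implements verify), the Python sketch below runs Kahn's topological sort over a graph given as a hypothetical adjacency dict: nodes with no remaining incoming edges are removed one by one, and any node that can never be removed lies on a cycle.

```python
from collections import deque


def is_acyclic(edges):
    """Return True if the directed graph {node: [successors]} has no cycle."""
    nodes = set(edges) | {s for succs in edges.values() for s in succs}
    indegree = {n: 0 for n in nodes}
    for succs in edges.values():
        for s in succs:
            indegree[s] += 1

    # Start from nodes that nothing feeds into (e.g. graph arguments).
    ready = deque(n for n in nodes if indegree[n] == 0)
    visited = 0
    while ready:
        n = ready.popleft()
        visited += 1
        for s in edges.get(n, []):
            indegree[s] -= 1
            if indegree[s] == 0:
                ready.append(s)
    # Nodes on a cycle never reach indegree 0, so they are never visited.
    return visited == len(nodes)


# The addition graph: both arguments feed the add node, which feeds the output.
add_graph = {"arg0": ["add"], "arg1": ["add"], "add": ["output"]}
print(is_acyclic(add_graph))  # True

# A feedback loop (add -> mul -> add) is not a valid dataflow graph.
loop_graph = {"arg0": ["add"], "add": ["mul"], "mul": ["add"]}
print(is_acyclic(loop_graph))  # False
```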

Step 2: Create inference session, load and compile the graph

With our graph verified, the next step is to create an inference session, load the graph into this session, and compile it into a model instance. We also print the input names, which we need when executing the model.

from max import engine

session = engine.InferenceSession()
model = session.load(graph)

print("input names are:")
for input_name in model.get_model_input_names():
    # Mojo lesson: `[]` de-references in Mojo as `input_name` is of `Reference` type
    print(input_name[])

which outputs

input names are:
input0
input1

Knowing the input names input0 and input1 is crucial, since we must pass the inputs under exactly these names when executing the model in the next step.

Step 3: Execute the graph/model with inputs

To execute the graph, we first create two input tensors in Mojo and pass each one to the execute method along with its input name. The result of the execution is returned as a TensorMap, from which we can retrieve the value of output0 via the get method as follows:

from tensor import Tensor

print("set some input values:")
input0 = Tensor[DType.float32](List[Float32](1.0))
print("input0:", input0)
input1 = Tensor[DType.float32](List[Float32](1.0))
print("input1:", input1)

print("obtain the result using `get`:")
# Mojo lesson: here the `^` in `input0^` passes the ownership and ends the lifetime of `input0`
ret = model.execute("input0", input0^, "input1", input1^)
print("result:", ret.get[DType.float32]("output0"))

The outputs are printed as follows

set some input values:
input0: Tensor([[1.0]], dtype=float32, shape=1)
input1: Tensor([[1.0]], dtype=float32, shape=1)
obtain the result using `get`:
result: Tensor([[2.0]], dtype=float32, shape=1)
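The contract of execute is: tensors go in bound to input names, and results come out in a map keyed by output name. The Python/NumPy sketch below mimics that contract for the addition graph; execute_add and the plain dict standing in for TensorMap are hypothetical, and the point is only to show why the input names printed earlier must match exactly.

```python
import numpy as np


def execute_add(**inputs):
    """Hypothetical stand-in for model.execute on the addition graph."""
    expected = {"input0", "input1"}
    if set(inputs) != expected:
        # A misspelled or missing name cannot be bound to a graph argument.
        raise KeyError(f"expected inputs {sorted(expected)}, got {sorted(inputs)}")
    # Results are keyed by output name, like the TensorMap returned by execute.
    return {"output0": inputs["input0"] + inputs["input1"]}


ret = execute_add(input0=np.array([1.0], dtype=np.float32),
                  input1=np.array([1.0], dtype=np.float32))
print("result:", ret["output0"])  # result: [2.]
```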

Now, let’s explore our second example.

Matmul graph

In this example, we create a graph that performs matrix multiplication (matmul) with a constant symbol. This kind of graph is particularly important because it demonstrates how constant symbols, which represent trained and fixed weights in a neural network, can be used; we build on this concept in the next section.

The setup for this matmul graph follows the same foundational steps as our initial example but includes some critical additions:

  1. We introduce a symbolic dimension m so that the input has shape m x 2, where the number of rows m can vary at runtime.
  2. We use graph.constant to create a constant symbol, which holds static values such as fixed weights.
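A symbolic dimension such as "m" is a named placeholder that gets bound to a concrete size when the model runs, while integer dimensions must match exactly. The Python sketch below illustrates that idea; check_shape is a hypothetical helper, not part of MAX, and it also assumes that a name reused across inputs must bind to the same size every time.

```python
def check_shape(spec, shape, bindings=None):
    """Match a concrete shape against a spec of ints (static) and strs (symbolic).

    Returns the updated {symbol: size} bindings, or raises ValueError.
    """
    bindings = dict(bindings or {})
    if len(spec) != len(shape):
        raise ValueError(f"rank mismatch: expected {len(spec)}, got {len(shape)}")
    for dim_spec, size in zip(spec, shape):
        if isinstance(dim_spec, int):
            if dim_spec != size:
                raise ValueError(f"static dim {dim_spec} does not match {size}")
        elif bindings.setdefault(dim_spec, size) != size:
            # The same symbol must mean the same size everywhere.
            raise ValueError(f"{dim_spec!r} already bound to {bindings[dim_spec]}, got {size}")
    return bindings


spec = ("m", 2)  # mirrors TensorType(DType.float32, "m", 2)
print(check_shape(spec, (2, 2)))  # {'m': 2}
print(check_shape(spec, (3, 2)))  # {'m': 3}
```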

Here's how we build, compile, and execute the graph to accommodate varying input tensor sizes at runtime:

from max.graph import Graph, TensorType
from tensor import Tensor, TensorShape, randn
from random import seed
from max.engine import InferenceSession

graph = Graph(TensorType(DType.float32, "m", 2))
# create a constant tensor value to later create a graph constant symbol
constant_value = Tensor[DType.float32](TensorShape(2, 2), 42.0)
print("constant value:", constant_value)
# create a constant symbol
constant_symbol = graph.constant(constant_value)
# create a matmul node
mm = graph[0] @ constant_symbol
graph.output(mm)
# verify
graph.verify()

# create session, load and compile the graph
session = InferenceSession()
model = session.load(graph)

# generate random input
seed(42)
input0 = randn[DType.float32]((2, 2))
print("random 2x2 input0:", input0)
ret = model.execute("input0", input0^)
print("matmul 2x2 result:", ret.get[DType.float32]("output0"))

# with 3 x 2 matrix input
input0 = randn[DType.float32]((3, 2))
print("random 3x2 input0:", input0)
ret = model.execute("input0", input0^)
print("matmul 3x2 result:", ret.get[DType.float32]("output0"))

Here are the results of the matmul graph, using a 2 x 2 constant symbol and random input tensors of shapes 2 x 2 and 3 x 2 for demonstration:

constant value: Tensor([[42.0, 42.0], [42.0, 42.0]], dtype=float32, shape=2x2)
random 2x2 input0: Tensor([[-1.7141127586364746, 0.057178866118192673], [0.75628399848937988, -1.6024507284164429]], dtype=float32, shape=2x2)
matmul 2x2 result: Tensor([[-69.591224670410156, -69.591224670410156], [-35.53900146484375, -35.53900146484375]], dtype=float32, shape=2x2)
random 3x2 input0: Tensor([[1.0167152881622314, -0.10449378937482834], [-0.27936717867851257, -0.69003057479858398], [0.80745488405227661, -0.48231619596481323]], dtype=float32, shape=3x2)
matmul 3x2 result: Tensor([[38.313301086425781, 38.313301086425781], [-40.714706420898438, -40.714706420898438], [13.655824661254883, 13.655824661254883]], dtype=float32, shape=3x2)
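Because every entry of the constant is 42, each row of the result is 42 times the sum of the corresponding input row, repeated in both columns. The NumPy check below recomputes the 2 x 2 result from the printed input values (copied from the output above), so the agreement holds only to float32 precision.

```python
import numpy as np

constant = np.full((2, 2), 42.0, dtype=np.float32)
input0 = np.array([[-1.7141127586364746, 0.057178866118192673],
                   [0.75628399848937988, -1.6024507284164429]], dtype=np.float32)

result = input0 @ constant

# Each row equals 42 * (row sum), duplicated across both output columns.
expected = np.repeat(42.0 * input0.sum(axis=1, keepdims=True), 2, axis=1)
print(result)
```

The same reasoning explains why both columns of the 3 x 2 result are identical as well.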

With this foundation, we are ready to explore more advanced applications in the next section of the tutorial.

Inference with MAX Graph API

In this section, we demonstrate how to build a two-layer neural network with ReLU activation using PyTorch, train it on the well-known MNIST dataset of 28 x 28 pixel grayscale images of handwritten digits (0 to 9, i.e., 10 classes in total), and then test its accuracy.

Subsequently, we will implement the same model using the MAX Graph API for inference to ensure the accuracy remains consistent.

Train, test and save a model on MNIST using PyTorch

First, to set up, let’s define our neural network in PyTorch:

import torch.nn as nn


class Model(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

We can train and test the network as follows:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# assumes `model`, `train_loader`, `test_loader`, `learning_rate` and `num_epochs`
# have been defined earlier in the script
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.RMSprop(model.parameters(), lr=learning_rate)

# train
total_steps = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, 28 * 28)
        outputs = model(images)
        loss = loss_fn(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_steps}], Loss: {loss.item():.4f}")

# test
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, 28 * 28)
        outputs = model(images)
        probs = F.softmax(outputs, dim=1)
        predicted = torch.argmax(probs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f"Accuracy of the network on the 10000 test images: {100 * correct / total} %")

# save weights in numpy binary format
weights = {}
for name, param in model.named_parameters():
    weights[name] = param.detach().cpu().numpy()
np.save("model_weights.npy", weights)

After training and testing the network, the model achieves an accuracy of 96.99% on the test dataset:

Epoch [1/5], Step [100/469], Loss: 0.5380
Epoch [1/5], Step [200/469], Loss: 0.2307
Epoch [1/5], Step [300/469], Loss: 0.3197
Epoch [1/5], Step [400/469], Loss: 0.2660
Epoch [2/5], Step [100/469], Loss: 0.1091
Epoch [2/5], Step [200/469], Loss: 0.2256
Epoch [2/5], Step [300/469], Loss: 0.2351
Epoch [2/5], Step [400/469], Loss: 0.1577
Epoch [3/5], Step [100/469], Loss: 0.1253
Epoch [3/5], Step [200/469], Loss: 0.1233
Epoch [3/5], Step [300/469], Loss: 0.0751
Epoch [3/5], Step [400/469], Loss: 0.1261
Epoch [4/5], Step [100/469], Loss: 0.0613
Epoch [4/5], Step [200/469], Loss: 0.0792
Epoch [4/5], Step [300/469], Loss: 0.1482
Epoch [4/5], Step [400/469], Loss: 0.1186
Epoch [5/5], Step [100/469], Loss: 0.0231
Epoch [5/5], Step [200/469], Loss: 0.0303
Epoch [5/5], Step [300/469], Loss: 0.1051
Epoch [5/5], Step [400/469], Loss: 0.0442
Accuracy of the network on the 10000 test images: 96.99 %
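One detail matters when reading model_weights.npy back (presumably what the tutorial's weight-loading helper does through Mojo-Python interop): np.save stores a dict as a pickled 0-d object array, so loading it needs allow_pickle=True plus .item() to unwrap the dict. The sketch below shows the round trip with small dummy arrays named like PyTorch's parameters; the hidden size of 4 and the temporary path are illustrative only.

```python
import os
import tempfile

import numpy as np

hidden_size = 4  # illustrative only; the trained model uses its own hidden size
weights = {
    "fc1.weight": np.zeros((hidden_size, 28 * 28), dtype=np.float32),  # nn.Linear stores (out, in)
    "fc1.bias": np.zeros(hidden_size, dtype=np.float32),
    "fc2.weight": np.zeros((10, hidden_size), dtype=np.float32),
    "fc2.bias": np.zeros(10, dtype=np.float32),
}

path = os.path.join(tempfile.mkdtemp(), "model_weights.npy")
np.save(path, weights)  # the dict is pickled into a 0-d object array

# np.load refuses pickled objects by default; .item() unwraps the 0-d array.
loaded = np.load(path, allow_pickle=True).item()
for name, array in loaded.items():
    print(name, array.shape, array.dtype)
```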

Next, we implement the same model with the MAX Graph API for inference.

Define inference graph with MAX Graph API in Mojo

After training the model and saving its weights, we construct an inference graph and load the weights as constant symbols. The graph input has a symbolic "batch" dimension and a static dimension of 28 * 28 = 784, representing flattened, preprocessed images. We also include a softmax operation via ops.softmax to compute probabilities directly within the inference graph.

from max.graph import Graph, TensorType, ops
from max import engine
from tensor import Tensor


def build_mnist_graph(
    fc1w: Tensor[DType.float32],
    fc1b: Tensor[DType.float32],
    fc2w: Tensor[DType.float32],
    fc2b: Tensor[DType.float32],
) -> Graph:
    # Note: "batch" is a symbolic dim which is known ahead of time vs dynamic dim
    graph = Graph(TensorType(DType.float32, "batch", 28 * 28))
    # PyTorch linear is defined as: x W^T + b so we need to transpose the weights
    fc1 = (graph[0] @ ops.transpose(graph.constant(fc1w), 1, 0)) + graph.constant(fc1b)

    relu = ops.relu(fc1)

    fc2 = (relu @ ops.transpose(graph.constant(fc2w), 1, 0)) + graph.constant(fc2b)
    out = ops.softmax(fc2)  # adding explicit softmax for inference prob
    graph.output(out)
    graph.verify()
    return graph
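To see why the transposes are needed: nn.Linear stores its weight as (out_features, in_features) and computes x @ W.T + b. The NumPy sketch below mirrors the structure of build_mnist_graph on random weights, flattening 28 x 28 images to 784 features the way the PyTorch loop does with reshape(-1, 28 * 28). The names mnist_forward and softmax, the hidden size of 8 and the batch of 3 are arbitrary choices for illustration; it is a reference for checking shapes and math, not a replacement for the graph.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_size, batch = 8, 3  # arbitrary sizes for illustration

# Same layouts as the PyTorch parameters: weight is (out, in), bias is (out,).
fc1w = rng.standard_normal((hidden_size, 28 * 28)).astype(np.float32)
fc1b = rng.standard_normal(hidden_size).astype(np.float32)
fc2w = rng.standard_normal((10, hidden_size)).astype(np.float32)
fc2b = rng.standard_normal(10).astype(np.float32)


def softmax(logits):
    # Subtracting the row max avoids overflow and leaves the result unchanged.
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def mnist_forward(images):
    x = images.reshape(-1, 28 * 28)  # (batch, 784), like the "batch" x 784 input
    fc1 = x @ fc1w.T + fc1b          # transpose, as ops.transpose(..., 1, 0) does
    relu = np.maximum(fc1, 0.0)
    fc2 = relu @ fc2w.T + fc2b
    return softmax(fc2)              # explicit softmax, as in the graph


images = rng.random((batch, 28, 28), dtype=np.float32)
probs = mnist_forward(images)
print(probs.shape)  # (3, 10)
```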

With the inference graph defined, we can now execute it with test images. 

Run inference and check accuracy

To execute the graph, we first convert the model weights from NumPy format to Mojo tensors, then build the graph, compile it, and run inference. To check the accuracy, we iterate over the test images, preprocess each one, obtain the result, call argmax to find the predicted class among the 10 classes, and count how many predictions match the ground-truth label.

# load_model_weights, numpy_to_tensor, load_mnist_test_data and preprocess are
# helper functions from the tutorial's code (not shown here)
weights_dict = load_model_weights()
fc1w = numpy_to_tensor[DType.float32](weights_dict["fc1.weight"])
fc1b = numpy_to_tensor[DType.float32](weights_dict["fc1.bias"])
fc2w = numpy_to_tensor[DType.float32](weights_dict["fc2.weight"])
fc2b = numpy_to_tensor[DType.float32](weights_dict["fc2.bias"])

mnist_graph = build_mnist_graph(fc1w^, fc1b^, fc2w^, fc2b^)
session = engine.InferenceSession()
model = session.load(mnist_graph)

correct = 0
total = 0
# use batch size of 1 in this example
test_dataset = load_mnist_test_data()
for i in range(len(test_dataset)):
    item = test_dataset[i]
    image = item[0]
    label = item[1]

    preprocessed_image = preprocess(image)

    output = model.execute("input0", preprocessed_image)
    probs = output.get[DType.float32]("output0")

    predicted = probs.argmax(axis=1)

    label_ = Tensor[DType.index](TensorShape(1), int(label))
    correct += int(predicted == label_)
    total += 1

print("Accuracy of the network on the 10000 test images:", 100 * correct / total, "%")
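Two details of this loop are easy to check in isolation: softmax divides monotonically increasing exp values by a shared row sum, so argmax over the probabilities picks the same class as argmax over the raw logits, and accuracy is simply the fraction of matching predictions. The NumPy sketch below uses made-up logits and labels purely to illustrate both points.

```python
import numpy as np

logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.3, 3.0],
                   [1.5, 1.7, 0.2]])
labels = np.array([0, 2, 0])  # made-up ground-truth classes

probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
predicted = probs.argmax(axis=1)

# Softmax preserves the ordering within each row, so the predicted classes agree.
assert (predicted == logits.argmax(axis=1)).all()

correct = int((predicted == labels).sum())
total = len(labels)
print(f"Accuracy: {100 * correct / total:.2f} %")  # Accuracy: 66.67 %
```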

The output of mojo mnist.mojo is

Accuracy of the network on the 10000 test images: 96.989999999999995 %

This matches the accuracy we observed in the PyTorch test (96.99%, up to floating-point formatting), confirming that our MAX Graph API implementation performs equivalently.
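The two printed numbers differ only in formatting. Assuming 9699 of the 10000 test images are classified correctly (96.99%), the float64 value closest to 96.99 lies slightly below it. Python's repr prints the shortest string that round-trips, while printing 17 significant digits exposes the binary approximation, which appears to be what the Mojo output shows.

```python
correct, total = 9699, 10000
accuracy = 100 * correct / total

print(repr(accuracy))      # shortest round-trip form: 96.99
print(f"{accuracy:.17g}")  # 17 significant digits: 96.989999999999995
```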

MAX Graph custom operator API

In this final section of the tutorial, we demonstrate how to create and register a custom operator to use inside a MAX graph. Building on the previous two-layer neural network, we first retrain the model with ReLU6 activation by running python --use-relu6; this flag replaces ReLU with ReLU6, and the script then checks the test accuracy and saves the model weights as before.

Creating custom operator in Mojo

To create a custom operator in Mojo, we follow these steps:

  • Create a dedicated subdirectory named custom_ops
  • Inside it, create an __init__.mojo file containing the import from .relu6 import relu6
  • Create the custom op Mojo file, relu6.mojo, with the following code:

import math
from max.extensibility import Tensor, empty_tensor
from max import register


@register.op("relu6")
fn relu6[type: DType, rank: Int](x: Tensor[type, rank]) -> Tensor[type, rank]:
    var output = empty_tensor[type](x.shape)

    @always_inline
    @parameter
    fn _relu6[width: Int](i: StaticIntTuple[rank]) -> SIMD[type, width]:
        var val = x.simd_load[width](i)
        return math.min(math.max(0, val), 6)

    output.for_each[_relu6]()
    return output^

The code above uses the @register.op("relu6") decorator to register the wrapped relu6 function as a custom operator named "relu6". The wrapped function can only take max.extensibility tensors, must have exactly one output of the same type, and cannot raise an Error. We create an empty_tensor to store the output.
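Conceptually, registering an op by name lets the runtime later find the kernel from the string used in the graph. The Python sketch below is only an analogy for that name-based lookup: register_op, CUSTOM_OPS and run_custom are hypothetical names, not MAX APIs, and the sketch does not enforce the tensor-type, single-output or non-raising constraints described above.

```python
import numpy as np

CUSTOM_OPS = {}  # hypothetical registry: op name -> kernel function


def register_op(name):
    """Decorator that records a kernel under `name`, loosely like @register.op."""
    def decorator(fn):
        if name in CUSTOM_OPS:
            raise ValueError(f"op {name!r} is already registered")
        CUSTOM_OPS[name] = fn
        return fn
    return decorator


@register_op("relu6")
def relu6(x):
    # One input tensor in, one tensor of the same shape and dtype out.
    return np.minimum(np.maximum(x, 0), 6).astype(x.dtype)


def run_custom(name, x):
    # Look the kernel up by the same string the graph refers to.
    if name not in CUSTOM_OPS:
        raise KeyError(f"no custom op named {name!r}; was the package loaded?")
    return CUSTOM_OPS[name](x)


print(run_custom("relu6", np.array([-2.0, 3.0, 9.0], dtype=np.float32)))  # [0. 3. 6.]
```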

To compute the output, we define a closure _relu6 decorated with @parameter and pass it to output.for_each, which evaluates it across the tensor's indices and stores the returned values in output. For each call, _relu6 loads a SIMD vector of width elements starting at index i and applies the ReLU6 formula math.min(math.max(0, val), 6). Finally, we return output^, which transfers ownership of the result tensor to the caller.
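Because _relu6 receives a SIMD width and a starting index, each call processes a vector of width contiguous elements rather than a single element. The NumPy sketch below imitates that chunking on a flat array, using an arbitrarily chosen width of 4 and an element-by-element tail for leftovers, and compares the result against the plain element-wise formula; the actual vectorization strategy inside for_each may differ.

```python
import numpy as np


def relu6_chunked(x, width=4):
    """Apply min(max(0, v), 6) to a 1-D array in chunks of `width` elements."""
    out = np.empty_like(x)
    n = x.size
    i = 0
    # Full-width chunks: each slice plays the role of one SIMD vector load.
    while i + width <= n:
        val = x[i:i + width]
        out[i:i + width] = np.minimum(np.maximum(val, 0), 6)
        i += width
    # Leftover elements that do not fill a whole vector.
    for j in range(i, n):
        out[j] = min(max(0.0, x[j]), 6.0)
    return out


x = np.linspace(-3.0, 9.0, 11, dtype=np.float32)
print(relu6_chunked(x))
```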

Using Custom Operator for Inference

Once the custom operator is defined, we package it as a .mojopkg file via mojo package custom_ops.

In our graph definition, we are now ready to replace ops.relu with our custom operator, changing

relu = ops.relu(fc1)

to

relu = ops.custom["relu6"](fc1, fc1.type())

Here ops.custom takes the custom operator name "relu6" as a parameter, fc1 as the input, and fc1.type() as the output type. The rest of the code stays the same.

def build_mnist_graph(
    fc1w: Tensor[DType.float32],
    fc1b: Tensor[DType.float32],
    fc2w: Tensor[DType.float32],
    fc2b: Tensor[DType.float32],
    use_relu6: Bool
) -> Graph:
    # Note: "batch" is a symbolic dim which is known ahead of time vs dynamic dim
    graph = Graph(TensorType(DType.float32, "batch", 28 * 28))
    # PyTorch linear is defined as: x W^T + b so we need to transpose the weights
    fc1 = (graph[0] @ ops.transpose(graph.constant(fc1w), 1, 0)) + graph.constant(fc1b)

    if use_relu6:
        # custom op registered under the name "relu6"
        relu = ops.custom["relu6"](fc1, fc1.type())
    else:
        relu = ops.relu(fc1)

    fc2 = (relu @ ops.transpose(graph.constant(fc2w), 1, 0)) + graph.constant(fc2b)
    out = ops.softmax(fc2)  # adding explicit softmax for inference prob
    graph.output(out)
    graph.verify()
    return graph

The last change is to let the inference session know about the custom operator at runtime:

from pathlib import Path

model = session.load(mnist_graph, custom_ops_paths=Path("custom_ops.mojopkg"))

Final verification

As a final check, we train and test the model that uses ReLU6 via python --use-relu6, which outputs

Accuracy of the network on the 10000 test images: 96.0 %

Then we run the inference code via mojo mnist.mojo --use-relu6, which shows

Accuracy of the network on the 10000 test images: 96.0 %

The matching accuracy between the PyTorch version and the Mojo implementation confirms the effective integration of the custom operator.

Deploying the MAX Graph binary

For deployment, we can build the mnist binary via mojo build mnist.mojo. Because the code uses Mojo-Python interop, we must set the MOJO_PYTHON_LIBRARY environment variable before running the binary:

export MOJO_PYTHON_LIBRARY=$(modular config mojo-max.python_lib)
./mnist # or ./mnist --use-relu6

Next Steps

Here are a few potential next steps for you:

  • Explore other neural network architectures beyond a simple two-layer feedforward network and implement them using the MAX Graph API
  • Experiment with other custom operators
  • Test and assess correctness and contribute to the community 🚀


In this blog post, we demonstrated step by step how to use the MAX Graph API to create symbolic graphs and how to compile and execute them. We also showed how to replicate a two-layer neural network trained in PyTorch with the MAX Graph API and saw that the test accuracy remained intact. We concluded by showing how to create and register a custom operator to use for inference, and verified that the test accuracy also remained intact when using it. We hope this blog post has given you a better understanding of the inner workings of the MAX Graph API.

Additional resources:

Report feedback, including issues on our Mojo and MAX GitHub tracker.

Until next time!🔥
