Mojo 🔥 Traits Have Arrived!

December 3, 2023

Jack Clayton

AI Developer Advocate

Introduction

Strongly typed languages need some mechanism to express generic functions. Dynamically typed languages don't enforce types, which means every function is effectively generic. Many languages implement some form of this feature under a variety of different names:

Traits: Rust, Scala, Crystal, Mojo 🔥
Interfaces: Go, Java, CSharp, TypeScript, Kotlin, PHP, Dart
Protocols: Swift, Objective-C, Elixir, Python
Typeclasses: Haskell
Concepts: C++

Python

In dynamic languages like JavaScript and Python, you can pass any object to any function. But you must make sure the object implements the methods you call in that function, or you'll get a nasty runtime error:

Python
class Foo:
    pass

def cause_error(foo):
    foo.will_fail()

foo = Foo()
cause_error(foo)
Output
Traceback (most recent call last):
  File "main.py", line 8, in <module>
    cause_error(foo)
  File "main.py", line 5, in cause_error
    foo.will_fail()
    ^^^^^^^^^^^^^
AttributeError: 'Foo' object has no attribute 'will_fail'

Python 3.8 introduced a typing feature named protocols, which is related to traits. Here's an example of a protocol named Shape that defines an abstract method signature, area, but doesn't implement it:

Python
from typing import Protocol

class Shape(Protocol):
    def area(self) -> float: ...

This is similar to Mojo's traits in both syntax and concept, but it only gives you linter warnings and some quality-of-life improvements such as better completions. With Mojo, you also pay no performance penalty for writing generic functions across different types, and this kind of mistake can't reach production because the program simply won't compile.
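
To make the difference concrete, here is a minimal Python sketch of the protocol in use. The names Circle, NotAShape and describe_area are illustrative assumptions, not part of any library. A static checker such as mypy would be expected to flag the second call, but Python itself only fails once the missing method is actually invoked:

```python
from typing import Protocol


class Shape(Protocol):
    def area(self) -> float: ...


class Circle:
    """Satisfies Shape structurally; no inheritance from Shape is needed."""

    def __init__(self, radius: float) -> None:
        self.radius = radius

    def area(self) -> float:
        return 3.141592653589793 * self.radius ** 2


class NotAShape:
    """Hypothetical type that forgot to implement area()."""


def describe_area(shape: Shape) -> str:
    # The annotation is only a hint; nothing is enforced at run time.
    return f"area={shape.area()}"


print(describe_area(Circle(1.5)))

try:
    # A type checker should reject this argument, but the interpreter
    # happily makes the call and fails inside describe_area instead.
    describe_area(NotAShape())  # type: ignore[arg-type]
except AttributeError as err:
    print(f"runtime failure: {err}")
```

The protocol improves tooling, but the error still surfaces at run time unless a type checker is part of your workflow.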

Mojo 🔥

Let's take a look at Mojo traits:

Mojo
trait Shape:
    fn area(self) -> Float64: ...

We can now create a function that accepts anything implementing the Shape trait and run the abstract area method in the function:

Mojo
fn print_area[T: Shape](shape: T):
    print(shape.area())

The [T: Shape] at the start is common syntax across languages. You can think of T as declaring a generic type that is constrained by the Shape trait.

Now we'll create a concrete fn area implementation on Circle:

Mojo
@value
struct Circle(Shape):
    var radius: Float64

    fn area(self) -> Float64:
        return 3.141592653589793 * self.radius ** 2

Circle(Shape) means that the struct Circle must implement all the methods specified in the Shape trait. In this case, that's just the fn area(self) -> Float64 signature.

You can now run it through the generic function:

Mojo
let circle = Circle(radius=1.5)
print_area(circle)
Output
7.0685834705770345

If we remove fn area from Circle, the compiler won't allow us to build the program, so it's impossible to get a runtime error for this mistake. The other advantage of Mojo over Python is that the error tells us exactly which signature we need to implement:

/tmp/traits.mojo:11:1: error: struct 'Circle' does not implement all requirements for 'Shape'
struct Circle(Shape, Repr):
^
/tmp/traits.mojo:2:5: note: required function 'area' is not implemented
    fn area(self) -> Float64:
    ^
/tmp/traits.mojo:1:1: note: trait 'Shape' declared here

Let's create another type that implements the Shape trait:

Mojo
@value
struct Rectangle(Shape):
    var width: Float64
    var height: Float64

    fn area(self) -> Float64:
        return self.width * self.height

And run it through the same function:

Mojo
let rectangle = Rectangle(width=2, height=3)
print_area(rectangle)
Output
6.0

Truly Zero-Cost Generics

A popular pattern in C++ for achieving this behavior is inheritance with abstract classes. But when a method is called through an abstract base class, the compiler often can't tell which concrete type is in use, so the call is resolved at run time, which can have a significant performance impact. The pattern can also make a code base explode in complexity.
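
As a loose analogue of that inheritance-based pattern, here is a short sketch using Python's abc module. The names ShapeBase, Square, Disc and total_area are hypothetical. Python resolves every call dynamically anyway, so this only illustrates the shape of the design, not the cost C++ pays for it:

```python
from abc import ABC, abstractmethod
from typing import Iterable


class ShapeBase(ABC):
    """Abstract base class: declares area() but leaves it unimplemented."""

    @abstractmethod
    def area(self) -> float:
        ...


class Square(ShapeBase):
    def __init__(self, side: float) -> None:
        self.side = side

    def area(self) -> float:
        return self.side * self.side


class Disc(ShapeBase):
    def __init__(self, radius: float) -> None:
        self.radius = radius

    def area(self) -> float:
        return 3.141592653589793 * self.radius ** 2


def total_area(shapes: Iterable[ShapeBase]) -> float:
    # Only the base type is known here, so each area() call is looked up
    # on the concrete object at run time.
    return sum(shape.area() for shape in shapes)


print(total_area([Square(2.0), Disc(1.0)]))
```

Every type has to opt in by inheriting from ShapeBase, and total_area never learns which concrete types it is working with. Trait-constrained generics are meant to avoid both of those limitations.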

C++ added multiple features to address these problems, such as templates and concepts, but they still aren't completely zero-cost. Mojo, by contrast, can guarantee that values are register-passable when using traits, which gives you truly zero-cost generics.

Multiple Traits

The __str__ method comes from Python, where it determines what happens when you print() the type. We added a Stringable trait to the standard library, which you can implement on your type to make it printable:

Mojo
trait Stringable:
    fn __str__(self) -> String: ...

Let's add the Stringable trait to Circle:

Mojo
@value
struct Circle(Shape, Stringable):
    var radius: Float64

    fn __str__(self) -> String:
        var res = String("Circle(radius=")
        res += self.radius
        res += ", area="
        res += self.area()
        res += ")"
        return res

    fn area(self) -> Float64:
        return 3.141592653589793 * self.radius ** 2

Circle now takes the Shape and Stringable traits, so it must implement fn area and fn __str__ to compile.

This allows us to print the type just like Python:

Mojo
let circle = Circle(radius=1.5)
print(circle)
Output
Circle(radius=1.5, area=7.0685834705770345)

Trait Inheritance

A really cool feature of traits is that they allow users to compose their types with your library. We've added a few simple traits to our standard library so you can inherit from them in your own types.

Here's an example of creating a new trait that inherits Shape, along with the standard library Stringable and CollectionElement so that you can push your type into a DynamicVector. The requirements for CollectionElement are implemented when using the @value decorator.

Mojo
trait VecPrintableShape(CollectionElement, Stringable, Shape):
    ...

@value
struct Circle(VecPrintableShape):
    var radius: Float64

    fn __str__(self) -> String:
        var res = String("Circle(radius=")
        res += self.radius
        res += ", area="
        res += self.area()
        res += ")"
        return res

    fn area(self) -> Float64:
        return 3.141592653589793 * self.radius ** 2

fn main():
    var vec = DynamicVector[Circle](3)
    # Push three circles with radius 0, 1 and 2
    for i in range(0, 3):
        vec.push_back(Circle(i))
        print(vec[i])
Output
Circle(radius=0.0, area=0.0)
Circle(radius=1.0, area=3.1415926535897931)
Circle(radius=2.0, area=12.566370614359172)

Now that we have a trait that composes our three traits together, we can create a function that makes use of all of them:

Mojo
fn print_vec[T: VecPrintableShape](vec: DynamicVector[T]):
    for i in range(len(vec)):
        print(vec[i])

print_vec[Circle](vec)
Output
Circle(radius=0.0, area=0.0)
Circle(radius=1.0, area=3.1415926535897931)
Circle(radius=2.0, area=12.566370614359172)

Try defining your own Rectangle type that implements VecPrintableShape!
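
For comparison, Python protocols can be composed in a similar way. The sketch below is an illustration only: Printable, PrintableShape and render_all are made-up names, and because every Python object already has __str__, the Printable requirement is much weaker than Mojo's Stringable:

```python
from typing import List, Protocol, Sequence


class Shape(Protocol):
    def area(self) -> float: ...


class Printable(Protocol):
    def __str__(self) -> str: ...


class PrintableShape(Shape, Printable, Protocol):
    """Combines both protocols without adding new requirements."""


class Rectangle:
    def __init__(self, width: float, height: float) -> None:
        self.width = width
        self.height = height

    def area(self) -> float:
        return self.width * self.height

    def __str__(self) -> str:
        return (
            f"Rectangle(width={self.width}, height={self.height}, "
            f"area={self.area()})"
        )


def render_all(shapes: Sequence[PrintableShape]) -> List[str]:
    # Relies on both halves of the combined protocol: str() and area().
    return [f"{shape} covers {shape.area()}" for shape in shapes]


for line in render_all([Rectangle(2.0, 3.0), Rectangle(1.0, 4.0)]):
    print(line)
```

As with the earlier protocol example, the combined protocol is checked only by static tooling, whereas the Mojo trait is enforced by the compiler.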

Database Trait

If you're still struggling to understand why traits are useful, a common example used to demonstrate the utility is a Database trait. We'll only define two methods to simplify the concept:

Mojo
trait Database:
    fn insert(self, key: String, value: String) raises: ...
    fn print_all(self) raises: ...

Now we can pass around an object that implements Database and use its abstract methods:

Mojo
fn insert_and_print[T: Database](db: T, key: String, value: String) raises:
    db.insert(key, value)
    db.print_all()

For example, imagine you have a function that runs inference on an image and stores the result somewhere. Maybe we just want to use SQLite for our local batch tests, but in production the results will be stored in a DynamoDB instance. Or maybe we just want the flexibility to swap out the database later without causing breaking changes for users of our library.

Let's import the sqlite3 Python package to implement a database that runs locally and conforms to the trait above:

Mojo
from python import Python

struct SQLite(Database):
    var con: PythonObject
    var cur: PythonObject

    fn __init__(inout self, path: StringLiteral) raises:
        let sqlite3 = Python.import_module("sqlite3")
        self.con = sqlite3.connect(path)
        self.cur = self.con.cursor()
        _ = self.cur.execute("CREATE TABLE IF NOT EXISTS test (key TEXT, value TEXT)")

    fn insert(self, key: String, value: String) raises:
        # String concatenation keeps the demo short; it is not safe for untrusted input
        var query = String("INSERT INTO test VALUES ('")
        query += key
        query += "', '"
        query += value
        query += "')"
        _ = self.cur.execute(query)

    fn print_all(self) raises:
        let query = String("SELECT * FROM test")
        var result = self.cur.execute(query)
        for row in result:
            print(row)

Now we can pass the database to our previous function:

Mojo
let db = SQLite("test.db")
insert_and_print(db, "name", "billy")

The sqlite3 module ships with Python's standard library, so you normally don't need to install anything extra.

Then run the program:

Output
('name', 'billy')

And now we can implement the Mongo version:

Mojo
# Uses Python from the python module (from python import Python)
struct Mongo(Database):
    var client: PythonObject
    var db: PythonObject
    var collection: PythonObject

    fn __init__(inout self, host: StringLiteral, port: Int, database: StringLiteral, collection: StringLiteral) raises:
        let MongoClient = Python.import_module("pymongo").MongoClient
        self.client = MongoClient(host, port)
        self.db = self.client[database]
        self.collection = self.db[collection]

    fn insert(self, key: String, value: String) raises:
        # Build the source text of a Python dict, e.g. {"name": "billy"}
        var document = String('{"')
        document += key
        document += '": "'
        document += value
        document += '"}'
        let my_document = Python.evaluate(document._strref_dangerous())
        _ = self.collection.insert_one(my_document)
        document._strref_keepalive()

    fn print_all(self) raises:
        var docs = self.collection.find()
        for doc in docs:
            print(doc)

You'll need to pip install pymongo if you want to run this, and follow the MongoDB installation instructions to start a local service.

Then run it:

Mojo
let db2 = Mongo("localhost", 27017, "test", "test")
insert_and_print(db2, "name", "billy")
Output
{'_id': ObjectId('655734c648d3297fa1b91b0b'), 'name': 'billy'}

This is simplified to demonstrate the functionality, but you could build an entire library on these principles by wrapping Python libraries, then introduce optimized Mojo implementations where you need better performance, all without changing the API.
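
The same swap-the-backend idea can be sketched on the Python side with a protocol. This is a minimal illustration under stated assumptions, not the post's implementation: SQLiteDatabase, InMemoryDatabase and insert_and_list are hypothetical names, all_rows returns the rows instead of printing them, and the SQLite backend uses an in-memory database with ? placeholders rather than string concatenation:

```python
import sqlite3
from typing import List, Protocol, Tuple

Row = Tuple[str, str]


class Database(Protocol):
    def insert(self, key: str, value: str) -> None: ...
    def all_rows(self) -> List[Row]: ...


class SQLiteDatabase:
    def __init__(self, path: str = ":memory:") -> None:
        self.con = sqlite3.connect(path)
        self.con.execute("CREATE TABLE IF NOT EXISTS test (key TEXT, value TEXT)")

    def insert(self, key: str, value: str) -> None:
        # Placeholders let sqlite3 bind the values safely, quotes included.
        self.con.execute("INSERT INTO test VALUES (?, ?)", (key, value))
        self.con.commit()

    def all_rows(self) -> List[Row]:
        return list(self.con.execute("SELECT key, value FROM test"))


class InMemoryDatabase:
    """Stand-in backend, e.g. for unit tests."""

    def __init__(self) -> None:
        self.rows: List[Row] = []

    def insert(self, key: str, value: str) -> None:
        self.rows.append((key, value))

    def all_rows(self) -> List[Row]:
        return list(self.rows)


def insert_and_list(db: Database, key: str, value: str) -> List[Row]:
    # Works with any backend that provides insert() and all_rows().
    db.insert(key, value)
    return db.all_rows()


for backend in (SQLiteDatabase(), InMemoryDatabase()):
    print(type(backend).__name__, insert_and_list(backend, "name", "billy"))
```

Callers depend only on the Database protocol, so either backend can be replaced later without touching insert_and_list.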

Conclusion

While navigating Python code bases, you may have seen abstract methods whose bodies are just ..., and not understood why they're there. The pattern is pervasive in ML libraries, where the authors want to provide correctness and nice tooling while still having multiple implementations for CUDA, CPU, and the many emerging hardware backends. You get an extra benefit in Mojo: you can write generic, reusable functionality across multiple types while still retaining full type safety and compiler optimizations by writing concrete implementations for each type.

There are more features to come for traits, such as default implementations, so make sure to check back on the docs.

We're excited to see what you build with traits. Please share your projects on the Discord and GitHub!
