MLIR for Lox: A Practical Guide
This guide reimagines the MLIR tutorial for building a Lox compiler instead of the tensor-based Toy language. If you know Crafting Interpreters, this should feel familiar.
Why MLIR for Lox?
LLVM is powerful but low-level. It doesn't know about:
- Variable scoping rules
- Closure captures
- Dynamic typing
- Lox-specific optimizations
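MLIR lets you define a dialect that represents Lox semantics directly, then progressively lower it to LLVM IR. This is the same idea behind the language-specific mid-level IRs that modern languages like Swift (SIL), Rust (MIR), and Julia place in front of LLVM; MLIR gives you the infrastructure to build such an IR without writing it from scratch.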
Part 1: The Lox Dialect
Instead of tensors, our dialect models Lox's types and operations.
Lox Types
// include/lox/Types.td (conceptual - type list, not code)
lox.nil - The nil type
lox.bool - Boolean type
lox.number - 64-bit float (Lox uses doubles)
lox.string - String type
lox.object - Instance/class type (dynamic)
A Sample Lox Program
// test/add.lox
fun add(a, b) {
return a + b;
}
print add(3, 4); // 7
The MLIR Representation
// output after: loxc test/add.lox --emit-mlir
module {
lox.func @add(%arg0: !lox.number, %arg1: !lox.number) -> !lox.number {
%result = lox.add %arg0, %arg1 : !lox.number
lox.return %result : !lox.number
}
lox.func @main() {
%three = lox.constant 3.0 : !lox.number
%four = lox.constant 4.0 : !lox.number
%sum = lox.call @add(%three, %four) : (!lox.number, !lox.number) -> !lox.number
lox.print %sum : !lox.number
lox.return
}
}
Part 2: Defining the Lox Dialect in TableGen
MLIR uses TableGen (.td files) to declaratively define dialects. This generates C++ boilerplate for you.
Dialect Definition
// include/lox/Lox.td
// The Lox dialect definition
// Declares the dialect itself. Every operation defined against Lox_Dialect
// is printed with the prefix given by `name`, e.g. `lox.add`.
def Lox_Dialect : Dialect {
let name = "lox";
let summary = "A dialect for the Lox programming language";
let description = [{
This dialect represents the Lox language at a high level,
enabling Lox-specific optimizations before lowering to LLVM.
}];
// C++ namespace the generated dialect and operation classes are placed in.
let cppNamespace = "lox";
}
Operation Definitions
// include/lox/Ops.td
// Base class for Lox operations
// Shorthand so each op below only has to give its mnemonic (the part after
// "lox.") and its trait list; the dialect is filled in here.
class Lox_Op<string mnemonic, list<Trait> traits = []> :
Op<Lox_Dialect, mnemonic, traits>;
// Arithmetic: lox.add
// Pure: the op has no side effects, so it may be folded, CSE'd or removed
// when unused. SameOperandsAndResultType: lhs, rhs and result share one type,
// which is why the assembly format only prints type($result).
def AddOp : Lox_Op<"add", [Pure, SameOperandsAndResultType]> {
let summary = "Add two Lox numbers";
let arguments = (ins Lox_NumberType:$lhs, Lox_NumberType:$rhs);
let results = (outs Lox_NumberType:$result);
// Textual form: lox.add %lhs, %rhs : !lox.number
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
}
// Subtraction, multiplication, division follow the same pattern
def SubOp : Lox_Op<"sub", [Pure, SameOperandsAndResultType]> { ... }
def MulOp : Lox_Op<"mul", [Pure, SameOperandsAndResultType]> { ... }
def DivOp : Lox_Op<"div", [Pure, SameOperandsAndResultType]> { ... }
// Comparison: lox.less, lox.equal, etc.
// Takes two numbers and yields a Lox bool. Unlike AddOp no assemblyFormat is
// declared here, and the operand and result types differ, so
// SameOperandsAndResultType is not used.
def LessOp : Lox_Op<"less", [Pure]> {
let arguments = (ins Lox_NumberType:$lhs, Lox_NumberType:$rhs);
let results = (outs Lox_BoolType:$result);
}
// Variables: lox.declare, lox.assign, lox.load
// Declares a variable: `name` is carried as a string attribute and `init` is
// the SSA value it starts with. No traits are listed, so the op is not
// treated as side-effect free.
def DeclareOp : Lox_Op<"declare"> {
let arguments = (ins StrAttr:$name, AnyLoxType:$init);
let results = (outs AnyLoxType:$value);
}
// Reads the current value of the variable called `name`.
//
// The op is NOT Pure: the value it yields depends on earlier lox.assign /
// lox.declare ops, so two loads of the same name are not interchangeable.
// Declaring only a read effect still lets unused loads be removed while
// preventing them from being merged or hoisted across writes.
def LoadOp : Lox_Op<"load", [MemoryEffects<[MemRead]>]> {
  let arguments = (ins StrAttr:$name);
  let results = (outs AnyLoxType:$value);
}
// Stores `value` into the variable called `name`. It has no results and no
// traits, so it is never considered side-effect free.
def AssignOp : Lox_Op<"assign"> {
let arguments = (ins StrAttr:$name, AnyLoxType:$value);
}
// Control flow: lox.if, lox.while, lox.for
// Conditional with a mandatory "then" region and an optional "else" region.
//
// lox.if is an ordinary statement that other ops may follow, so it must not
// carry the Terminator trait. NoTerminator lets the single block in each
// region end without a terminator op.
def IfOp : Lox_Op<"if", [NoTerminator]> {
  let arguments = (ins Lox_BoolType:$condition);
  // thenRegion: exactly one block.
  // elseRegion: zero blocks (no else clause) or one block.
  let regions = (region SizedRegion<1>:$thenRegion,
                        MaxSizedRegion<1>:$elseRegion);
}
// Loop with a single-block body region.
//
// Like lox.if this is a normal statement, not a block terminator, so the
// Terminator trait is dropped. NoTerminator allows the body block to end
// without a terminator op.
// NOTE(review): `condition` is an operand computed before the loop; a real
// while loop has to re-evaluate it on every iteration (e.g. via a second
// region) - confirm the intended design.
def WhileOp : Lox_Op<"while", [NoTerminator]> {
  let arguments = (ins Lox_BoolType:$condition);
  let regions = (region SizedRegion<1>:$body);
}
// Functions
// Lox semantics that differ from func.func:
// 1. Parameters have no declared types (dynamic typing)
// 2. Return type is optional (implicitly nil if missing)
// 3. Functions are first-class values (can be passed, returned, assigned)
// 4. May capture variables from enclosing scopes (closures)
// 5. Arity is checked at runtime, not compile time
// 6. No function overloading
// Function definition.
//
// Symbol: registers `sym_name` in the enclosing symbol table so that
// `lox.call @name` can be resolved. IsolatedFromAbove: the body may not use
// SSA values defined outside the function.
def FuncOp : Lox_Op<"func", [Symbol, IsolatedFromAbove]> {
  let summary = "A Lox function definition";
  let description = [{
    Declares a Lox function. Unlike func.func, parameters have no declared
    types (Lox is dynamically typed). The return type is also optional —
    functions that don't explicitly return implicitly return nil.
    Functions are first-class values: they can be passed as arguments,
    returned from other functions, and assigned to variables.
  }];
  let arguments = (ins
    SymbolNameAttr:$sym_name,            // Function name
    ArrayAttr:$param_names,              // Parameter names (not types!)
    OptionalAttr<TypeAttr>:$return_type  // Optional return type
  );
  let regions = (region AnyRegion:$body);
  // The `->` clause is only printed/parsed when return_type is present.
  let assemblyFormat = [{
    $sym_name `(` $param_names `)` (`->` $return_type^)? attr-dict $body
  }];
  // Only the hand-written builder below is generated.
  let skipDefaultBuilders = 1;
  // Builder for functions with no explicit return type (returns nil)
  let builders = [
    OpBuilder<(ins "StringRef":$name, "ArrayRef<StringRef>":$params)>
  ];
}
// Calls the function named by the `callee` symbol with `operands`.
//
// A call is deliberately not marked Pure: the callee may print, assign to
// variables, or never return, so calls must not be deleted when their result
// is unused, nor merged or reordered.
def CallOp : Lox_Op<"call"> {
  let arguments = (ins FlatSymbolRefAttr:$callee,
                       Variadic<AnyLoxType>:$operands);
  let results = (outs Variadic<AnyLoxType>:$results);
}
// Ends a function body (Terminator). The operand is optional so that a bare
// `lox.return` with no value can be written.
def ReturnOp : Lox_Op<"return", [Terminator]> {
let arguments = (ins Optional<AnyLoxType>:$value);
}
// Print builtin
// Lox's built-in `print` statement. It accepts a value of any Lox type and has
// no results; it carries no Pure trait, so it is never removed as dead code.
def PrintOp : Lox_Op<"print"> {
let arguments = (ins AnyLoxType:$value);
}
Part 3: From AST to MLIR
This is like Chapter 2 of the Toy tutorial, but for Lox.
The Lox AST
Your parser (from Crafting Interpreters) produces an AST. Here's the full structure:
// include/lox/AST.h
// AST node types (used for dispatching in the visitor)
enum class StmtType {
  FUNCTION, RETURN, VAR, IF, WHILE, PRINT, BLOCK, EXPRESSION
};
enum class ExprType {
  BINARY, UNARY, LITERAL, GROUPING, VARIABLE, ASSIGN, CALL, LOGICAL
};
// ========================================================================
// Base classes
// ========================================================================

// Source position of a node. Defaults to 0:0 so a node built without
// position information never carries indeterminate values.
struct Location {
  int line = 0;
  int column = 0;
};

// Base of every expression node.
//
// A class with a virtual destructor is not an aggregate, so subclasses cannot
// initialise it with `Expr{ExprType::X}` unless a constructor exists; this
// constructor makes that spelling valid.
struct Expr {
  ExprType type;  // Tag used by the visitor to pick the concrete node type
  Location loc;   // Where the expression appeared in the source

  explicit Expr(ExprType t, Location l = {}) : type(t), loc(l) {}
  virtual ~Expr() = default;
};

// Base of every statement node; see Expr for why a constructor is required.
struct Stmt {
  StmtType type;  // Tag used by the visitor to pick the concrete node type

  explicit Stmt(StmtType t) : type(t) {}
  virtual ~Stmt() = default;
};
// ========================================================================
// Statements
// ========================================================================
// A whole program is just a list of top-level statements
struct Program {
std::vector<Stmt*> statements;
};
struct FunctionStmt : Stmt {
std::string name;
std::vector<std::string> params;
std::vector<Stmt*> body;
FunctionStmt(std::string n, std::vector<std::string> p, std::vector<Stmt*> b)
: Stmt{StmtType::FUNCTION}, name(std::move(n)), params(std::move(p)), body(std::move(b)) {}
};
struct ReturnStmt : Stmt {
Expr* value; // Can be null for bare "return;"
ReturnStmt(Expr* v) : Stmt{StmtType::RETURN}, value(v) {}
};
struct VarStmt : Stmt {
std::string name;
Expr* init;
VarStmt(std::string n, Expr* i) : Stmt{StmtType::VAR}, name(std::move(n)), init(i) {}
};
struct IfStmt : Stmt {
Expr* condition;
std::vector<Stmt*> thenBranch;
std::vector<Stmt*> elseBranch; // Empty if no else clause
IfStmt(Expr* c, std::vector<Stmt*> then, std::vector<Stmt*> els)
: Stmt{StmtType::IF}, condition(c), thenBranch(std::move(then)), elseBranch(std::move(els)) {}
};
struct WhileStmt : Stmt {
Expr* condition;
std::vector<Stmt*> body;
WhileStmt(Expr* c, std::vector<Stmt*> b)
: Stmt{StmtType::WHILE}, condition(c), body(std::move(b)) {}
};
struct PrintStmt : Stmt {
Expr* expression;
PrintStmt(Expr* e) : Stmt{StmtType::PRINT}, expression(e) {}
};
struct BlockStmt : Stmt {
std::vector<Stmt*> statements;
BlockStmt(std::vector<Stmt*> stmts) : Stmt{StmtType::BLOCK}, statements(std::move(stmts)) {}
};
struct ExpressionStmt : Stmt {
Expr* expression;
ExpressionStmt(Expr* e) : Stmt{StmtType::EXPRESSION}, expression(e) {}
};
// ========================================================================
// Expressions
// ========================================================================
struct BinaryExpr : Expr {
Expr* left;
TokenType op; // PLUS, MINUS, STAR, SLASH, etc.
Expr* right;
BinaryExpr(Expr* l, TokenType o, Expr* r)
: Expr{ExprType::BINARY}, left(l), op(o), right(r) {}
};
struct UnaryExpr : Expr {
TokenType op; // MINUS, BANG
Expr* right;
UnaryExpr(TokenType o, Expr* r) : Expr{ExprType::UNARY}, op(o), right(r) {}
};
struct LiteralExpr : Expr {
LoxValue value; // Your value type from Crafting Interpreters
LiteralExpr(LoxValue v) : Expr{ExprType::LITERAL}, value(std::move(v)) {}
};
struct GroupingExpr : Expr {
Expr* expression;
GroupingExpr(Expr* e) : Expr{ExprType::GROUPING}, expression(e) {}
};
struct VarExpr : Expr {
std::string name;
VarExpr(std::string n) : Expr{ExprType::VARIABLE}, name(std::move(n)) {}
};
struct AssignExpr : Expr {
std::string name;
Expr* value;
AssignExpr(std::string n, Expr* v)
: Expr{ExprType::ASSIGN}, name(std::move(n)), value(v) {}
};
struct CallExpr : Expr {
Expr* callee;
std::vector<Expr*> arguments;
CallExpr(Expr* c, std::vector<Expr*> args)
: Expr{ExprType::CALL}, callee(c), arguments(std::move(args)) {}
};
struct LogicalExpr : Expr {
Expr* left;
TokenType op; // AND, OR
Expr* right;
LogicalExpr(Expr* l, TokenType o, Expr* r)
: Expr{ExprType::LOGICAL}, left(l), op(o), right(r) {}
};
How the AST is Built
Your parser (from Crafting Interpreters) constructs these nodes while parsing:
// lib/Parser.cpp (simplified excerpt from Crafting Interpreters)
// Recursive-descent parser that turns the token stream into AST nodes.
//
// Helpers used below but not shown in this excerpt (match, check, consume,
// isAtEnd, previous, error, and the remaining statement/expression methods)
// live in the elided parts of the class.
class Parser {
  std::vector<Token> tokens;
  size_t current = 0;

public:
  // Parses declarations until the end of input and returns them as a
  // heap-allocated Program.
  Program* parse() {
    Program* program = new Program();
    while (!isAtEnd()) {
      program->statements.push_back(declaration());
    }
    return program;
  }

private:
  // Top-level declarations (functions, variables, statements)
  Stmt* declaration() {
    if (match(TokenType::FUN)) return functionDeclaration();
    if (match(TokenType::VAR)) return varDeclaration();
    return statement();
  }

  // Parses `name(params) { body }`; the `fun` keyword was already consumed.
  Stmt* functionDeclaration() {
    // consume() yields a token; the identifier text is its lexeme, exactly as
    // for the parameter names below.
    std::string name =
        consume(TokenType::IDENTIFIER, "Expect function name.").lexeme;
    consume(TokenType::LEFT_PAREN, "Expect '(' after function name.");
    std::vector<std::string> params;
    if (!check(TokenType::RIGHT_PAREN)) {
      do {
        params.push_back(
            consume(TokenType::IDENTIFIER, "Expect parameter name.").lexeme);
      } while (match(TokenType::COMMA));
    }
    consume(TokenType::RIGHT_PAREN, "Expect ')' after parameters.");
    consume(TokenType::LEFT_BRACE, "Expect '{' before function body.");
    std::vector<Stmt*> body = block();
    // The locals are not used again, so hand them over instead of copying.
    return new FunctionStmt(std::move(name), std::move(params), std::move(body));
  }

  Stmt* statement() {
    if (match(TokenType::PRINT)) return printStatement();
    if (match(TokenType::IF)) return ifStatement();
    if (match(TokenType::WHILE)) return whileStatement();
    if (match(TokenType::FOR)) return forStatement();
    if (match(TokenType::RETURN)) return returnStatement();
    if (match(TokenType::LEFT_BRACE)) return new BlockStmt(block());
    return expressionStatement();
  }

  Stmt* printStatement() {
    Expr* value = expression();
    consume(TokenType::SEMICOLON, "Expect ';' after value.");
    return new PrintStmt(value);
  }

  // Each branch is a single statement; a braced body arrives as one BlockStmt.
  Stmt* ifStatement() {
    consume(TokenType::LEFT_PAREN, "Expect '(' after 'if'.");
    Expr* condition = expression();
    consume(TokenType::RIGHT_PAREN, "Expect ')' after if condition.");
    std::vector<Stmt*> thenBranch = { statement() };
    std::vector<Stmt*> elseBranch;
    if (match(TokenType::ELSE)) {
      elseBranch.push_back(statement());
    }
    return new IfStmt(condition, std::move(thenBranch), std::move(elseBranch));
  }

  // Parses declarations up to the closing brace; the opening brace was
  // already consumed by the caller.
  std::vector<Stmt*> block() {
    std::vector<Stmt*> statements;
    while (!check(TokenType::RIGHT_BRACE) && !isAtEnd()) {
      statements.push_back(declaration());
    }
    consume(TokenType::RIGHT_BRACE, "Expect '}' after block.");
    return statements;
  }

  // ... rest of parser ...

  // Expression parsing (precedence climbing)
  Expr* expression() { return assignment(); }

  // Assignment is right-associative, hence the recursive call for the value.
  Expr* assignment() {
    Expr* expr = orExpr();
    if (match(TokenType::EQUAL)) {
      Expr* value = assignment();
      // Check if left side is a variable
      if (VarExpr* var = dynamic_cast<VarExpr*>(expr)) {
        return new AssignExpr(var->name, value);
      }
      error("Invalid assignment target.");
    }
    return expr;
  }

  Expr* orExpr() {
    Expr* expr = andExpr();
    while (match(TokenType::OR)) {
      Expr* right = andExpr();
      expr = new LogicalExpr(expr, TokenType::OR, right);
    }
    return expr;
  }

  // ... more expression parsing methods ...

  // Left-associative chain of + and -.
  Expr* term() {
    Expr* expr = factor();
    while (match(TokenType::MINUS, TokenType::PLUS)) {
      // Capture the operator before parsing the right operand: previous()
      // refers to the just-matched token only until more tokens are consumed.
      TokenType op = previous().type;
      Expr* right = factor();
      expr = new BinaryExpr(expr, op, right);
    }
    return expr;
  }
};
Example: Parsing print 1 + 2;
Parser::parse()
└── declaration()
└── statement()
└── printStatement()
├── expression()
│ └── term()
│ └── factor()
│ ├── primary() → LiteralExpr(1)
│ ├── match(PLUS) → true
│ └── factor()
│ └── primary() → LiteralExpr(2)
│ → BinaryExpr(Literal(1), PLUS, Literal(2))
└── consume(SEMICOLON)
→ PrintStmt(BinaryExpr(...))
The MLIR Generator
You write a visitor that walks the AST and emits MLIR. The visitor pattern dispatches based on AST node type — each visit* method handles one node type.
// lib/IR/MLIRGen.cpp
#include "lox/Dialect.h"
#include "lox/Ops.h"
#include "lox/AST.h" // Your AST classes from Crafting Interpreters
class MLIRGenerator : public ExprVisitor, public StmtVisitor {
private:
mlir::OpBuilder builder;
mlir::ModuleOp module;
mlir::Location currentLoc;
public:
// ========================================================================
// Entry point: Convert a whole program (list of statements) to MLIR
// ========================================================================
mlir::ModuleOp generateModule(Program *program) {
module = mlir::ModuleOp::create(builder.getUnknownLoc());
// Visit each top-level statement (function declarations, etc.)
for (Stmt *stmt : program->statements) {
visitStmt(stmt);
}
return module;
}
// ========================================================================
// Statement visitors (called by visitStmt, which dispatches by type)
// ========================================================================
// The dispatcher: called for every statement, routes to the right visit method
void visitStmt(Stmt *stmt) {
// Each Stmt subclass has a type tag we switch on
switch (stmt->type) {
case StmtType::FUNCTION: return visitFunctionStmt((FunctionStmt*)stmt);
case StmtType::RETURN: return visitReturnStmt((ReturnStmt*)stmt);
case StmtType::VAR: return visitVarStmt((VarStmt*)stmt);
case StmtType::IF: return visitIfStmt((IfStmt*)stmt);
case StmtType::WHILE: return visitWhileStmt((WhileStmt*)stmt);
case StmtType::PRINT: return visitPrintStmt((PrintStmt*)stmt);
case StmtType::BLOCK: return visitBlockStmt((BlockStmt*)stmt);
case StmtType::EXPRESSION: return visitExpressionStmt((ExpressionStmt*)stmt);
}
}
void visitFunctionStmt(FunctionStmt *stmt) {
// Create a new function in the module
auto func = builder.create<FuncOp>(currentLoc, stmt->name, stmt->params);
// Push a new block for the function body
auto *entryBlock = func.getBody().emplaceBlock();
builder.setInsertionPointToStart(entryBlock);
// Add block arguments for each parameter
for (const std::string ¶m : stmt->params) {
entryBlock->addArgument(builder.getType<LoxDynamicType>());
}
// Visit all statements in the function body
for (Stmt *bodyStmt : stmt->body) {
visitStmt(bodyStmt);
}
// Add implicit return nil if the function doesn't end with a return
if (!endsWithReturn(func)) {
builder.create<ReturnOp>(currentLoc);
}
}
void visitReturnStmt(ReturnStmt *stmt) {
mlir::Value value = stmt->value ? visitExpr(stmt->value) : nullptr;
builder.create<ReturnOp>(currentLoc, value);
}
void visitVarStmt(VarStmt *stmt) {
mlir::Value init = visitExpr(stmt->init);
builder.create<DeclareOp>(currentLoc, stmt->name, init);
}
void visitIfStmt(IfStmt *stmt) {
mlir::Value cond = visitExpr(stmt->condition);
auto ifOp = builder.create<IfOp>(currentLoc, cond);
// Then branch
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
for (Stmt *thenStmt : stmt->thenBranch) {
visitStmt(thenStmt);
}
// Else branch (optional)
if (stmt->elseBranch) {
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
for (Stmt *elseStmt : stmt->elseBranch) {
visitStmt(elseStmt);
}
}
}
void visitWhileStmt(WhileStmt *stmt) {
mlir::Value cond = visitExpr(stmt->condition);
auto whileOp = builder.create<WhileOp>(currentLoc, cond);
builder.setInsertionPointToStart(&whileOp.getBody().front());
for (Stmt *bodyStmt : stmt->body) {
visitStmt(bodyStmt);
}
}
void visitPrintStmt(PrintStmt *stmt) {
mlir::Value value = visitExpr(stmt->expression);
builder.create<PrintOp>(currentLoc, value);
}
void visitExpressionStmt(ExpressionStmt *stmt) {
visitExpr(stmt->expression);
}
void visitBlockStmt(BlockStmt *stmt) {
for (Stmt *blockStmt : stmt->statements) {
visitStmt(blockStmt);
}
}
// ========================================================================
// Expression visitors (called by visitExpr, which dispatches by type)
// ========================================================================
// The dispatcher: called for every expression, routes to the right visit method
// This is what calls visitBinaryExpr, visitLiteralExpr, etc.
mlir::Value visitExpr(Expr *expr) {
currentLoc = convertLocation(expr->loc);
switch (expr->type) {
case ExprType::BINARY: return visitBinaryExpr((BinaryExpr*)expr);
case ExprType::UNARY: return visitUnaryExpr((UnaryExpr*)expr);
case ExprType::LITERAL: return visitLiteralExpr((LiteralExpr*)expr);
case ExprType::GROUPING: return visitGroupingExpr((GroupingExpr*)expr);
case ExprType::VARIABLE: return visitVarExpr((VarExpr*)expr);
case ExprType::ASSIGN: return visitAssignExpr((AssignExpr*)expr);
case ExprType::CALL: return visitCallExpr((CallExpr*)expr);
case ExprType::LOGICAL: return visitLogicalExpr((LogicalExpr*)expr);
}
return nullptr;
}
mlir::Value visitBinaryExpr(BinaryExpr *expr) {
// Recursively visit left and right operands
// These calls go through visitExpr, which dispatches to the right visitor
mlir::Value lhs = visitExpr(expr->left);
mlir::Value rhs = visitExpr(expr->right);
switch (expr->op) {
case TokenType::PLUS:
return builder.create<AddOp>(currentLoc, lhs, rhs);
case TokenType::MINUS:
return builder.create<SubOp>(currentLoc, lhs, rhs);
case TokenType::STAR:
return builder.create<MulOp>(currentLoc, lhs, rhs);
case TokenType::SLASH:
return builder.create<DivOp>(currentLoc, lhs, rhs);
case TokenType::LESS:
return builder.create<LessOp>(currentLoc, lhs, rhs);
case TokenType::LESS_EQUAL:
return builder.create<LessEqualOp>(currentLoc, lhs, rhs);
case TokenType::GREATER:
return builder.create<GreaterOp>(currentLoc, lhs, rhs);
case TokenType::GREATER_EQUAL:
return builder.create<GreaterEqualOp>(currentLoc, lhs, rhs);
case TokenType::EQUAL_EQUAL:
return builder.create<EqualOp>(currentLoc, lhs, rhs);
case TokenType::BANG_EQUAL:
return builder.create<NotEqualOp>(currentLoc, lhs, rhs);
default:
emitError(currentLoc, "Unknown binary operator");
return nullptr;
}
}
mlir::Value visitUnaryExpr(UnaryExpr *expr) {
mlir::Value operand = visitExpr(expr->right);
switch (expr->op) {
case TokenType::MINUS:
return builder.create<NegateOp>(currentLoc, operand);
case TokenType::BANG:
return builder.create<NotOp>(currentLoc, operand);
default:
emitError(currentLoc, "Unknown unary operator");
return nullptr;
}
}
mlir::Value visitLiteralExpr(LiteralExpr *expr) {
if (expr->value.isNumber()) {
return builder.create<ConstantOp>(currentLoc, expr->value.asNumber());
}
if (expr->value.isBool()) {
return builder.create<ConstantOp>(currentLoc, expr->value.asBool());
}
if (expr->value.isString()) {
return builder.create<ConstantOp>(currentLoc, expr->value.asString());
}
// nil
return builder.create<NilOp>(currentLoc);
}
mlir::Value visitGroupingExpr(GroupingExpr *expr) {
// Grouping is just parentheses - emit the inner expression directly
return visitExpr(expr->expression);
}
mlir::Value visitVarExpr(VarExpr *expr) {
return builder.create<LoadOp>(currentLoc, expr->name);
}
mlir::Value visitAssignExpr(AssignExpr *expr) {
mlir::Value value = visitExpr(expr->value);
builder.create<AssignOp>(currentLoc, expr->name, value);
return value;
}
mlir::Value visitCallExpr(CallExpr *expr) {
mlir::Value callee = visitExpr(expr->callee);
// Visit all arguments
llvm::SmallVector<mlir::Value, 4> args;
for (Expr *arg : expr->arguments) {
args.push_back(visitExpr(arg));
}
return builder.create<CallOp>(currentLoc, callee, args);
}
mlir::Value visitLogicalExpr(LogicalExpr *expr) {
mlir::Value left = visitExpr(expr->left);
// Logical AND/OR short-circuit, so they need control flow
// lox.and and lox.or operations handle this
switch (expr->op) {
case TokenType::AND:
return builder.create<AndOp>(currentLoc, left,
/* right is a lazy region */ expr->right);
case TokenType::OR:
return builder.create<OrOp>(currentLoc, left,
/* right is a lazy region */ expr->right);
default:
return nullptr;
}
}
};
How the visitor works:
generateModule()
│
├── for each statement:
│ │
│ visitStmt(stmt) ───── dispatches based on stmt->type
│ │
│ ├── visitFunctionStmt() if StmtType::FUNCTION
│ ├── visitIfStmt() if StmtType::IF
│ ├── visitWhileStmt() if StmtType::WHILE
│ └── ...
│
└── statements call visitExpr() for their expressions:
│
visitExpr(expr) ───── dispatches based on expr->type
│
├── visitBinaryExpr() if ExprType::BINARY
│ │
│ ├── visitExpr(left) ── recursive!
│ └── visitExpr(right) ── recursive!
│
├── visitLiteralExpr() if ExprType::LITERAL
├── visitCallExpr() if ExprType::CALL
└── ...
Part 4: Lox-Specific Optimizations
This is where MLIR shines. You can write passes that understand Lox semantics.
Example: Constant Folding
// lib/Transforms/LoxOpt.cpp
#include "lox/Ops.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
// Fold lox.add(constant, constant) -> constant
// Rewrites a lox.add whose two operands are both produced by lox.constant
// into a single lox.constant holding their sum.
struct FoldConstantAdd : public OpRewritePattern<AddOp> {
  using OpRewritePattern<AddOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOp op, PatternRewriter &rewriter) const override {
    // Bail out unless each operand is defined by a constant op
    ConstantOp left = op.getLhs().getDefiningOp<ConstantOp>();
    ConstantOp right = op.getRhs().getDefiningOp<ConstantOp>();
    if (!(left && right)) {
      return failure();
    }

    // Compute the sum now and put a constant in place of the add
    double folded = left.getValue() + right.getValue();
    rewriter.replaceOpWithNewOp<ConstantOp>(op, folded);
    return success();
  }
};
Example: Dead Variable Elimination
// lib/Transforms/LoxOpt.cpp (continued)
// Remove variables that are declared but never used
// Removes a lox.declare whose variable is never read.
struct EliminateDeadVariable : public OpRewritePattern<DeclareOp> {
  using OpRewritePattern<DeclareOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DeclareOp op, PatternRewriter &rewriter) const override {
    // DeclareOp also produces an SSA result; an op whose result is still
    // referenced must not be erased, whatever the by-name lookup says.
    if (!op->use_empty()) return failure();
    // Check if the variable is ever loaded
    if (hasUses(op.getName())) return failure();
    // Variable is never used, remove it
    rewriter.eraseOp(op);
    return success();
  }
};
Example: Inline Simple Functions
// lib/Transforms/LoxOpt.cpp (continued)
// Inline functions that are small (e.g., just return an expression)
// Replaces a call to a small function with a copy of that function's body.
struct InlineSimpleFunctions : public OpRewritePattern<CallOp> {
  using OpRewritePattern<CallOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CallOp call, PatternRewriter &rewriter) const override {
    FuncOp func = lookupFunction(call.getCallee());
    // The callee may not resolve to a known function; a null FuncOp must not
    // be inspected or inlined.
    if (!func) return failure();
    // Only inline small functions
    if (!isSimpleEnoughToInline(func)) return failure();
    // Clone the function body and substitute arguments
    inlineFunction(call, func, rewriter);
    return success();
  }
};
Part 5: Lowering to Standard Dialects
Now we progressively transform lox.* operations into lower-level MLIR dialects.
Step 1: Lox Dialect → SCF + Arith Dialects
Control flow becomes structured control flow (scf.if, scf.for), arithmetic becomes arith.addf, etc.
// lib/Transforms/LowerToLLVM.cpp
#include "lox/Ops.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
// Convert lox.add to arith.addf
// Lowers lox.add to arith.addf, using the already-converted operands that the
// conversion framework supplies through the adaptor.
struct AddOpLowering : public OpConversionPattern<AddOp> {
  using OpConversionPattern<AddOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(AddOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    mlir::Value convertedLhs = adaptor.getLhs();
    mlir::Value convertedRhs = adaptor.getRhs();
    rewriter.replaceOpWithNewOp<mlir::arith::AddFOp>(op, convertedLhs, convertedRhs);
    return success();
  }
};
// Convert lox.while to scf.while
// Lowers lox.while to scf.while by moving the loop body into the new op.
struct WhileOpLowering : public OpConversionPattern<WhileOp> {
  // Inherit the base-class constructors, as AddOpLowering does; without this
  // the pattern cannot be constructed from a context or type converter.
  using OpConversionPattern<WhileOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(WhileOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const override {
    // Convert lox.while body to scf.while
    // NOTE(review): scf.while is built from a "before" region (the condition)
    // and an "after" region (the body) - confirm that the accessors and
    // operands used below match the scf dialect version in use.
    auto whileOp = rewriter.create<scf::WhileOp>(
        op.getLoc(), TypeRange{}, adaptor.getCondition());
    // Convert the body region
    rewriter.inlineRegionBefore(op.getBody(), whileOp.getBody(),
                                whileOp.getBody().begin());
    rewriter.eraseOp(op);
    return success();
  }
};
Step 2: SCF + Arith → LLVM Dialect
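Every pass after the first one is a built-in MLIR conversion. `--convert-lox-to-scf` is the pass you wrote in Step 1, and the stock `mlir-opt` does not know the Lox dialect, so the `mlir-opt` below must be a build that has the Lox dialect and passes registered (typically your own `lox-opt` driver):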
# Run from build directory or project root
mlir-opt lox.mlir \
--convert-lox-to-scf \
--convert-scf-to-cf \
--convert-cf-to-llvm \
--convert-arith-to-llvm \
--convert-func-to-llvm \
-o lox_llvm.mlir
Step 3: LLVM Dialect → LLVM IR → Machine Code
# Run from build directory or project root
mlir-translate lox_llvm.mlir --mlir-to-llvmir -o lox.ll
llc lox.ll -o lox.s
clang lox.s -o lox
Part 6: Handling Dynamic Typing
Lox is dynamically typed, but MLIR is typed. Options:
Option A: Tagged Values
Every Lox value is a struct with a type tag:
// include/lox/Types.td (type definition)
!lox.value = struct<{
tag: i8, // 0=nil, 1=bool, 2=number, 3=string, 4=object
data: union<i1, f64, !lox.string, !lox.object>
}>
Operations check the tag at runtime:
// lib/Transforms/RuntimeChecks.cpp (generated MLIR for runtime type checking)
// lox.add checks that both operands are numbers
%lhs_tag = llvm.extractvalue %lhs[0] : !lox.value
%rhs_tag = llvm.extractvalue %rhs[0] : !lox.value
// Tag 2 means "number". Comparing the two tags with each other is not
// enough: two strings also have equal tags but must not be added as floats.
%number_tag = arith.constant 2 : i8
%lhs_is_number = arith.cmpi eq, %lhs_tag, %number_tag : i8
%rhs_is_number = arith.cmpi eq, %rhs_tag, %number_tag : i8
%both_numbers = arith.andi %lhs_is_number, %rhs_is_number : i1
cf.cond_br %both_numbers, ^ok, ^type_error
^ok:
%lhs_num = llvm.extractvalue %lhs[1] : !lox.value
%rhs_num = llvm.extractvalue %rhs[1] : !lox.value
%result = arith.addf %lhs_num, %rhs_num : f64
// ... pack result back into tagged value ...
Option B: Specialize by Type
Generate specialized versions of functions for each type combination:
// lib/Transforms/TypeSpecialization.cpp (generated after type specialization)
// Generic call
lox.call @add(%a, %b)
// After specialization (if both are numbers)
lox.call @add_numbers(%a, %b)
// If strings, call string concatenation
lox.call @add_strings(%a, %b)
Part 7: Closures
Closures are the trickiest part of Lox. MLIR's regions help.
Captured Variables
// test/closure.lox
fun makeCounter() {
var count = 0;
fun counter() {
count = count + 1;
return count;
}
return counter;
}
MLIR Representation
// output after: loxc test/closure.lox --emit-mlir
lox.func @makeCounter() -> !lox.closure {
// Create a closure environment
%env = lox.alloc_env { size = 1 }
// Store count in the environment
%zero = lox.constant 0.0 : !lox.number
lox.env_store %env[0] = %zero : !lox.number
// Create the closure, capturing the environment
%closure = lox.make_closure @counter(%env)
lox.return %closure : !lox.closure
}
lox.func @counter(%env: !lox.env) -> !lox.number {
// Load count from the environment
%count = lox.env_load %env[0] : !lox.number
// Increment
%one = lox.constant 1.0 : !lox.number
%new_count = lox.add %count, %one : !lox.number
// Store back
lox.env_store %env[0] = %new_count : !lox.number
lox.return %new_count : !lox.number
}
Part 8: Project Structure
lox-mlir/
├── include/
│ └── lox/
│ ├── Lox.td # Dialect definition (TableGen)
│ ├── Ops.td # Operation definitions
│ ├── Types.td # Type definitions
│ ├── Dialect.h # Generated C++ header
│ └── Dialect.cpp # Dialect implementation
├── lib/
│ ├── Dialect.cpp # Dialect registration
│ ├── IR/
│ │ └── MLIRGen.cpp # AST → MLIR conversion
│ └── Transforms/
│ ├── LoxOpt.cpp # Lox-specific optimizations
│ └── LowerToLLVM.cpp # Lowering passes
├── tools/
│ └── loxc/
│ └── main.cpp # Compiler driver
└── test/
└── *.lox # Test programs
Part 9: Build System (CMake)
# CMakeLists.txt (project root)
cmake_minimum_required(VERSION 3.16)
project(lox-mlir LANGUAGES C CXX)

find_package(MLIR REQUIRED CONFIG)

# mlir_tablegen() and add_public_tablegen_target() are defined by the CMake
# modules that ship with LLVM/MLIR; they must be included before use.
list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
include(TableGen)
include(AddLLVM)
include(AddMLIR)

# LLVM/MLIR headers, the project's own headers, and the generated .inc files
include_directories(${LLVM_INCLUDE_DIRS} ${MLIR_INCLUDE_DIRS})
include_directories(${PROJECT_SOURCE_DIR}/include ${PROJECT_BINARY_DIR})

# Generate dialect files from TableGen.
# mlir_tablegen() reads its input file from LLVM_TARGET_DEFINITIONS.
# NOTE(review): assumes Ops.td includes Lox.td so that one input provides both
# the dialect and the op records - confirm.
set(LLVM_TARGET_DEFINITIONS include/lox/Ops.td)
mlir_tablegen(LoxDialect.h.inc -gen-dialect-decls)
mlir_tablegen(LoxDialect.cpp.inc -gen-dialect-defs)
mlir_tablegen(LoxOps.h.inc -gen-op-decls)
mlir_tablegen(LoxOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(LoxOpsIncGen)

add_library(LoxDialect
  lib/Dialect.cpp
)
# The .inc files must exist before anything that includes them is compiled
add_dependencies(LoxDialect LoxOpsIncGen)
target_link_libraries(LoxDialect MLIRIR)

add_library(LoxTransforms
  lib/Transforms/LoxOpt.cpp
  lib/Transforms/LowerToLLVM.cpp
)
add_dependencies(LoxTransforms LoxOpsIncGen)
target_link_libraries(LoxTransforms LoxDialect MLIRTransformUtils)

# The AST -> MLIR generator is part of the compiler driver
add_executable(loxc
  tools/loxc/main.cpp
  lib/IR/MLIRGen.cpp
)
add_dependencies(loxc LoxOpsIncGen)
target_link_libraries(loxc LoxDialect LoxTransforms MLIRParser)
Quick Reference: Lox → MLIR Mapping
| Lox Construct | MLIR Operation | Lowered To |
|---|---|---|
| `a + b` | `lox.add` | `arith.addf` |
| `a - b` | `lox.sub` | `arith.subf` |
| `a * b` | `lox.mul` | `arith.mulf` |
| `a / b` | `lox.div` | `arith.divf` |
| `a < b` | `lox.less` | `arith.cmpf` |
| `a == b` | `lox.equal` | runtime check |
| `var x = v` | `lox.declare` | `llvm.alloca` + store |
| `x` | `lox.load` | `llvm.load` |
| `x = v` | `lox.assign` | `llvm.store` |
| `if (c) {...}` | `lox.if` | `scf.if` → `cf.cond_br` |
| `while (c) {...}` | `lox.while` | `scf.while` → `cf.br`/`cf.cond_br` |
| `fun f(...) {...}` | `lox.func` | `func.func` → `llvm.func` |
| `f(args)` | `lox.call` | `func.call` → `llvm.call` |
| `return v` | `lox.return` | `func.return` → `llvm.return` |
| `print v` | `lox.print` | runtime `printf` call |
Next Steps
- Start small: Just numbers and arithmetic. Get `print 1 + 2;` working.
- Add variables: Implement local variables with `lox.declare`/`lox.load`/`lox.assign`.
- Add control flow: `if` and `while` with the `scf` dialect.
- Add functions: `lox.func` and `lox.call`.
- Add closures: This is where it gets interesting.
- Add classes/objects: The full Lox experience.
The MLIR infrastructure handles the boring parts (SSA construction, dominance checking, pass management), letting you focus on Lox-specific concerns.