MLIR for Lox: A Practical Guide

This guide reimagines the MLIR tutorial for building a Lox compiler instead of the tensor-based Toy language. If you know Crafting Interpreters, this should feel familiar.

Why MLIR for Lox?

LLVM is powerful but low-level. It doesn't know about:

  • Variable scoping rules
  • Closure captures
  • Dynamic typing
  • Lox-specific optimizations

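MLIR lets you define a dialect that represents Lox semantics directly, then progressively lower it to LLVM IR. This mirrors what modern compilers such as Swift (SIL), Rust (MIR), and Julia already do with their own high-level IRs before handing code to LLVM; MLIR gives you the infrastructure for that layer instead of making you build it from scratch.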


Part 1: The Lox Dialect

Instead of tensors, our dialect models Lox's types and operations.

Lox Types

// include/lox/Types.td (conceptual - type list, not code)
lox.nil      - The nil type
lox.bool     - Boolean type
lox.number   - 64-bit float (Lox uses doubles)
lox.string   - String type
lox.object   - Instance/class type (dynamic)

A Sample Lox Program

// test/add.lox
// Returns a + b. The parameters carry no declared types.
fun add(a, b) {
  return a + b;
}

print add(3, 4);  // 7

The MLIR Representation

// output after: loxc test/add.lox --emit-mlir
module {
  // Lox `fun add(a, b)`; this listing shows both parameters and the result as !lox.number.
  lox.func @add(%arg0: !lox.number, %arg1: !lox.number) -> !lox.number {
    %result = lox.add %arg0, %arg1 : !lox.number
    lox.return %result : !lox.number
  }
  
  // Holds the top-level statement `print add(3, 4);`.
  lox.func @main() {
    // Numeric literals become lox.constant ops.
    %three = lox.constant 3.0 : !lox.number
    %four = lox.constant 4.0 : !lox.number
    // The callee is referenced by symbol name.
    %sum = lox.call @add(%three, %four) : (!lox.number, !lox.number) -> !lox.number
    lox.print %sum : !lox.number
    // Bare return: no operand.
    lox.return
  }
}

Part 2: Defining the Lox Dialect in TableGen

MLIR uses TableGen (.td files) to declaratively define dialects. This generates C++ boilerplate for you.

Dialect Definition

// include/lox/Lox.td
// The Lox dialect definition
// Dialect record: every operation defined against it gets the "lox." prefix.
def Lox_Dialect : Dialect {
  // Dialect namespace as it appears in textual IR (e.g. `lox.add`).
  let name = "lox";
  let summary = "A dialect for the Lox programming language";
  let description = [{
    This dialect represents the Lox language at a high level,
    enabling Lox-specific optimizations before lowering to LLVM.
  }];
  // C++ namespace the generated dialect and op classes are placed in.
  let cppNamespace = "lox";
}

Operation Definitions

// include/lox/Ops.td
// Base class for Lox operations
// Common base for all Lox ops: binds the op to Lox_Dialect so each definition
// only supplies its mnemonic and (optionally) a trait list.
class Lox_Op<string mnemonic, list<Trait> traits = []> :
  Op<Lox_Dialect, mnemonic, traits>;

// Arithmetic: lox.add
// lox.add: two number operands, one number result.
// `Pure` declares the op free of side effects; `SameOperandsAndResultType`
// requires both operands and the result to share one type.
def AddOp : Lox_Op<"add", [Pure, SameOperandsAndResultType]> {
  let summary = "Add two Lox numbers";
  let arguments = (ins Lox_NumberType:$lhs, Lox_NumberType:$rhs);
  let results = (outs Lox_NumberType:$result);
  
  // Textual form: `lox.add %a, %b : !lox.number`. Only the result type is
  // spelled out in the format.
  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
}

// Subtraction, multiplication, division follow the same pattern
// Shared shape of the remaining binary arithmetic ops. It mirrors AddOp: two
// number operands, one number result, no side effects, and the same textual
// form (`lox.<mnemonic> %a, %b : !lox.number`).
class Lox_BinaryArithOp<string mnemonic, string verb> :
    Lox_Op<mnemonic, [Pure, SameOperandsAndResultType]> {
  let summary = verb # " two Lox numbers";
  let arguments = (ins Lox_NumberType:$lhs, Lox_NumberType:$rhs);
  let results = (outs Lox_NumberType:$result);

  let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)";
}

// Subtraction, multiplication, division follow the same pattern
def SubOp : Lox_BinaryArithOp<"sub", "Subtract">;
def MulOp : Lox_BinaryArithOp<"mul", "Multiply">;
def DivOp : Lox_BinaryArithOp<"div", "Divide">;

// Comparison: lox.less, lox.equal, etc.
// lox.less: compares two numbers and produces a Lox boolean.
// No assemblyFormat is given, so the op has no custom textual syntax.
def LessOp : Lox_Op<"less", [Pure]> {
  let arguments = (ins Lox_NumberType:$lhs, Lox_NumberType:$rhs);
  let results = (outs Lox_BoolType:$result);
}

// Variables: lox.declare, lox.assign, lox.load
// lox.declare: introduces the variable `name` with initial value `init`.
// The variable is identified by a string attribute, not by an SSA value.
// No traits are attached, so the op is not marked side-effect free.
def DeclareOp : Lox_Op<"declare"> {
  let arguments = (ins StrAttr:$name, AnyLoxType:$init);
  let results = (outs AnyLoxType:$value);
}

// lox.load: reads the current value of the variable `name`.
//
// This op must NOT be `Pure`. The variable can be changed by lox.assign, so a
// load is a memory read. Marking it Pure would allow CSE to merge loads that
// sit on either side of an assignment, and LICM to hoist a load out of a loop
// that assigns the variable. Declaring a read effect keeps the op removable
// when its result is unused while preventing it from being reordered across
// writes.
def LoadOp : Lox_Op<"load", [MemoryEffects<[MemRead]>]> {
  let arguments = (ins StrAttr:$name);
  let results = (outs AnyLoxType:$value);
}

// lox.assign: stores `value` into the variable `name`. Produces no result.
// No traits are attached, so the op is not marked side-effect free.
def AssignOp : Lox_Op<"assign"> {
  let arguments = (ins StrAttr:$name, AnyLoxType:$value);
}

// Control flow: lox.if, lox.while, lox.for
// lox.if: conditional statement with a mandatory `then` region and an
// optional `else` region.
//
// * The op is NOT a `Terminator`: an `if` is an ordinary statement that may be
//   followed by more statements in the same block.
// * `NoTerminator` lets the branch regions consist of plain statement lists
//   with no trailing terminator op.
// * The else region uses `MaxSizedRegion<1>` (zero or one block), the same
//   constraint scf.if uses for its optional else branch.
def IfOp : Lox_Op<"if", [NoTerminator]> {
  let arguments = (ins Lox_BoolType:$condition);
  let regions = (region SizedRegion<1>:$thenRegion,
                        MaxSizedRegion<1>:$elseRegion);
}

// lox.while: loop statement with a single-block body region.
//
// * The op is NOT a `Terminator`: statements may follow a loop in the same
//   block.
// * `NoTerminator` lets the body be a plain statement list, matching lox.if.
//
// NOTE(review): `condition` is an SSA operand, so it is computed once before
// the op rather than once per iteration. A faithful `while` needs the
// condition inside a region that is re-evaluated each time around (as
// scf.while does with its "before" region). The operand form is kept here so
// existing builders keep working.
def WhileOp : Lox_Op<"while", [NoTerminator]> {
  let arguments = (ins Lox_BoolType:$condition);
  let regions = (region SizedRegion<1>:$body);
}

// Functions
// Lox semantics that differ from func.func:
// 1. Parameters have no declared types (dynamic typing)
// 2. Return type is optional (implicitly nil if missing)
// 3. Functions are first-class values (can be passed, returned, assigned)
// 4. May capture variables from enclosing scopes (closures)
// 5. Arity is checked at runtime, not compile time
// 6. No function overloading
// lox.func: a named Lox function.
//
// `Symbol` (from mlir/IR/SymbolInterfaces.td) registers the op's `sym_name`
// in the enclosing symbol table. Without it the `@callee` reference held by
// lox.call cannot be resolved. `IsolatedFromAbove` means the body may not use
// SSA values defined outside the function.
def FuncOp : Lox_Op<"func", [AffineScope, IsolatedFromAbove, Symbol]> {
  let summary = "A Lox function definition";
  let description = [{
    Declares a Lox function. Unlike func.func, parameters have no declared
    types (Lox is dynamically typed). The return type is also optional —
    functions that don't explicitly return implicitly return nil.
    
    Functions are first-class values: they can be passed as arguments,
    returned from other functions, and assigned to variables.
  }];
  
  let arguments = (ins 
    SymbolNameAttr:$sym_name,         // Function name
    ArrayAttr:$param_names,           // Parameter names (not types!)
    OptionalAttr<TypeAttr>:$return_type  // Optional return type
  );
  
  // Any number of blocks is allowed in the body.
  let regions = (region AnyRegion:$body);
  
  // Textual form: `lox.func @name(<param_names>) [-> <type>] { ... }`.
  // The `-> type` group is printed only when return_type is present.
  let assemblyFormat = [{
    $sym_name `(` $param_names `)` (`->` $return_type^)? attr-dict $body
  }];
  
  // Only the custom builder below is generated.
  let skipDefaultBuilders = 1;
  
  // Builder for functions with no explicit return type (returns nil)
  let builders = [
    OpBuilder<(ins "StringRef":$name, "ArrayRef<StringRef>":$params)>
  ];
}

// lox.call: calls the function named by the `callee` symbol with `operands`.
//
// The op is deliberately NOT `Pure`. A Lox function body may print, assign
// variables, or mutate captured state, so a call has unknown side effects.
// With Pure, DCE would delete calls whose result is unused (for example a
// bare `log("hi");` statement) and CSE would merge repeated calls.
def CallOp : Lox_Op<"call"> {
  let arguments = (ins FlatSymbolRefAttr:$callee, 
                       Variadic<AnyLoxType>:$operands);
  let results = (outs Variadic<AnyLoxType>:$results);
}

// lox.return: block terminator that leaves the function.
// The operand is optional so that a bare `return;` can be expressed.
def ReturnOp : Lox_Op<"return", [Terminator]> {
  let arguments = (ins Optional<AnyLoxType>:$value);
}

// Print builtin
// lox.print: the Lox `print` statement. Takes one value of any Lox type and
// produces no result. No traits are attached, so the op is not marked
// side-effect free.
def PrintOp : Lox_Op<"print"> {
  let arguments = (ins AnyLoxType:$value);
}

Part 3: From AST to MLIR

This is like Chapter 2 of the Toy tutorial, but for Lox.

The Lox AST

Your parser (from Crafting Interpreters) produces an AST. Here's the full structure:

// include/lox/AST.h
// AST node types (used for dispatching in the visitor)
// Tag stored in every Stmt; the MLIR generator switches on it to pick the
// matching visit method. Enumerator order is preserved.
enum class StmtType {
  FUNCTION,
  RETURN,
  VAR,
  IF,
  WHILE,
  PRINT,
  BLOCK,
  EXPRESSION
};

// Tag stored in every Expr; the MLIR generator switches on it to pick the
// matching visit method. Enumerator order is preserved.
enum class ExprType {
  BINARY,
  UNARY,
  LITERAL,
  GROUPING,
  VARIABLE,
  ASSIGN,
  CALL,
  LOGICAL
};

// ========================================================================
// Base classes
// ========================================================================
// Source position attached to expressions.
// Both fields default to 0 so a Location that nobody fills in has a defined
// value instead of indeterminate ints. The struct stays an aggregate, so
// `Location{line, column}` keeps working.
struct Location {
  int line = 0;
  int column = 0;
};

struct Expr {
  ExprType type;
  Location loc;
  virtual ~Expr() = default;
};

struct Stmt {
  StmtType type;
  virtual ~Stmt() = default;
};

// ========================================================================
// Statements
// ========================================================================

// A whole program is just a list of top-level statements
// Root of the AST: the list of top-level statements.
// Holds raw pointers; nothing in this struct deletes them.
struct Program {
  std::vector<Stmt*> statements;
};

struct FunctionStmt : Stmt {
  std::string name;
  std::vector<std::string> params;
  std::vector<Stmt*> body;
  
  FunctionStmt(std::string n, std::vector<std::string> p, std::vector<Stmt*> b)
    : Stmt{StmtType::FUNCTION}, name(std::move(n)), params(std::move(p)), body(std::move(b)) {}
};

// `return;` or `return value;`
struct ReturnStmt : Stmt {
  Expr* value;  // null for a bare "return;"

  ReturnStmt(Expr* returned)
    : Stmt{StmtType::RETURN},
      value(returned) {}
};

// `var name = init;`
struct VarStmt : Stmt {
  std::string name;  // variable being declared
  Expr* init;        // initializer expression

  VarStmt(std::string varName, Expr* initializer)
    : Stmt{StmtType::VAR},
      name(std::move(varName)),
      init(initializer) {}
};

// `if (condition) thenBranch else elseBranch`
struct IfStmt : Stmt {
  Expr* condition;
  std::vector<Stmt*> thenBranch;
  std::vector<Stmt*> elseBranch;  // left empty when there is no else clause

  IfStmt(Expr* cond,
         std::vector<Stmt*> thenStmts,
         std::vector<Stmt*> elseStmts)
    : Stmt{StmtType::IF},
      condition(cond),
      thenBranch(std::move(thenStmts)),
      elseBranch(std::move(elseStmts)) {}
};

// `while (condition) body`
struct WhileStmt : Stmt {
  Expr* condition;
  std::vector<Stmt*> body;

  WhileStmt(Expr* cond, std::vector<Stmt*> bodyStmts)
    : Stmt{StmtType::WHILE},
      condition(cond),
      body(std::move(bodyStmts)) {}
};

// `print expression;`
struct PrintStmt : Stmt {
  Expr* expression;  // value to print

  PrintStmt(Expr* printed)
    : Stmt{StmtType::PRINT},
      expression(printed) {}
};

// `{ statements... }`
struct BlockStmt : Stmt {
  std::vector<Stmt*> statements;  // contents of the braces, in order

  BlockStmt(std::vector<Stmt*> contents)
    : Stmt{StmtType::BLOCK},
      statements(std::move(contents)) {}
};

// An expression used as a statement: `expression;`
struct ExpressionStmt : Stmt {
  Expr* expression;

  ExpressionStmt(Expr* wrapped)
    : Stmt{StmtType::EXPRESSION},
      expression(wrapped) {}
};

// ========================================================================
// Expressions
// ========================================================================

// `left op right` for arithmetic and comparison operators.
struct BinaryExpr : Expr {
  Expr* left;
  TokenType op;   // PLUS, MINUS, STAR, SLASH, etc.
  Expr* right;

  BinaryExpr(Expr* lhs, TokenType oper, Expr* rhs)
    : Expr{ExprType::BINARY},
      left(lhs),
      op(oper),
      right(rhs) {}
};

struct UnaryExpr : Expr {
  TokenType op;   // MINUS, BANG
  Expr* right;
  
  UnaryExpr(TokenType o, Expr* r) : Expr{ExprType::UNARY}, op(o), right(r) {}
};

struct LiteralExpr : Expr {
  LoxValue value;  // Your value type from Crafting Interpreters
  
  LiteralExpr(LoxValue v) : Expr{ExprType::LITERAL}, value(std::move(v)) {}
};

// Parenthesized expression: `(expression)`.
struct GroupingExpr : Expr {
  Expr* expression;  // the expression inside the parentheses

  GroupingExpr(Expr* inner)
    : Expr{ExprType::GROUPING},
      expression(inner) {}
};

// Reference to a variable by name.
struct VarExpr : Expr {
  std::string name;

  VarExpr(std::string varName)
    : Expr{ExprType::VARIABLE},
      name(std::move(varName)) {}
};

// `name = value`
struct AssignExpr : Expr {
  std::string name;  // assignment target
  Expr* value;       // expression being assigned

  AssignExpr(std::string target, Expr* assigned)
    : Expr{ExprType::ASSIGN},
      name(std::move(target)),
      value(assigned) {}
};

// `callee(arguments...)`
struct CallExpr : Expr {
  Expr* callee;                  // expression being called
  std::vector<Expr*> arguments;  // argument expressions, in order

  CallExpr(Expr* target, std::vector<Expr*> argList)
    : Expr{ExprType::CALL},
      callee(target),
      arguments(std::move(argList)) {}
};

// `left and right` / `left or right`
struct LogicalExpr : Expr {
  Expr* left;
  TokenType op;  // AND, OR
  Expr* right;

  LogicalExpr(Expr* lhs, TokenType oper, Expr* rhs)
    : Expr{ExprType::LOGICAL},
      left(lhs),
      op(oper),
      right(rhs) {}
};

How the AST is Built

Your parser (from Crafting Interpreters) constructs these nodes while parsing:

// lib/Parser.cpp (simplified excerpt from Crafting Interpreters)
// Recursive-descent parser that turns the token list into AST nodes.
// Nodes are allocated with `new` and handed to the caller as raw pointers;
// nothing in this excerpt frees them.
class Parser {
  std::vector<Token> tokens;
  size_t current = 0;  // index of the next token to consume
  
public:
  // Parses declarations until the end of input and returns the whole program.
  Program* parse() {
    Program* program = new Program();
    while (!isAtEnd()) {
      program->statements.push_back(declaration());
    }
    return program;
  }
  
private:
  // Top-level declarations (functions, variables, statements)
  Stmt* declaration() {
    if (match(TokenType::FUN))   return functionDeclaration();
    if (match(TokenType::VAR))   return varDeclaration();
    return statement();
  }
  
  // Parses `name(params) { body }`; the `fun` keyword is already consumed.
  Stmt* functionDeclaration() {
    // consume() yields the matched token (the parameter loop below reads its
    // `.lexeme`), so the name text has to be pulled out the same way.
    std::string name =
        consume(TokenType::IDENTIFIER, "Expect function name.").lexeme;
    consume(TokenType::LEFT_PAREN, "Expect '(' after function name.");
    
    std::vector<std::string> params;
    if (!check(TokenType::RIGHT_PAREN)) {
      do {
        params.push_back(consume(TokenType::IDENTIFIER, "Expect parameter name.").lexeme);
      } while (match(TokenType::COMMA));
    }
    consume(TokenType::RIGHT_PAREN, "Expect ')' after parameters.");
    
    consume(TokenType::LEFT_BRACE, "Expect '{' before function body.");
    std::vector<Stmt*> body = block();
    
    return new FunctionStmt(name, params, body);
  }
  
  // Dispatches on the leading keyword; anything else is an expression statement.
  Stmt* statement() {
    if (match(TokenType::PRINT))      return printStatement();
    if (match(TokenType::IF))         return ifStatement();
    if (match(TokenType::WHILE))      return whileStatement();
    if (match(TokenType::FOR))        return forStatement();
    if (match(TokenType::RETURN))     return returnStatement();
    if (match(TokenType::LEFT_BRACE)) return new BlockStmt(block());
    return expressionStatement();
  }
  
  // Parses `expression ;` after the `print` keyword.
  Stmt* printStatement() {
    Expr* value = expression();
    consume(TokenType::SEMICOLON, "Expect ';' after value.");
    return new PrintStmt(value);
  }
  
  // Parses `( condition ) statement [else statement]` after the `if` keyword.
  // Each branch is stored as a one-element list; the else list stays empty
  // when there is no else clause.
  Stmt* ifStatement() {
    consume(TokenType::LEFT_PAREN, "Expect '(' after 'if'.");
    Expr* condition = expression();
    consume(TokenType::RIGHT_PAREN, "Expect ')' after if condition.");
    
    std::vector<Stmt*> thenBranch = { statement() };
    std::vector<Stmt*> elseBranch;
    if (match(TokenType::ELSE)) {
      elseBranch.push_back(statement());
    }
    return new IfStmt(condition, thenBranch, elseBranch);
  }
  
  // Parses declarations up to the closing '}'; the opening '{' is already
  // consumed by the caller.
  std::vector<Stmt*> block() {
    std::vector<Stmt*> statements;
    while (!check(TokenType::RIGHT_BRACE) && !isAtEnd()) {
      statements.push_back(declaration());
    }
    consume(TokenType::RIGHT_BRACE, "Expect '}' after block.");
    return statements;
  }
  
  // ... rest of parser ...
  
  // Expression parsing (precedence climbing)
  Expr* expression() { return assignment(); }
  
  // Parses `target = value` (right-associative) or falls through to orExpr().
  Expr* assignment() {
    Expr* expr = orExpr();
    
    if (match(TokenType::EQUAL)) {
      Expr* value = assignment();
      
      // Check if left side is a variable
      if (VarExpr* var = dynamic_cast<VarExpr*>(expr)) {
        return new AssignExpr(var->name, value);
      }
      // Report the problem, then hand back the left-hand expression.
      error("Invalid assignment target.");
    }
    return expr;
  }
  
  // Parses a left-associative chain of `or` operators.
  Expr* orExpr() {
    Expr* expr = andExpr();
    
    while (match(TokenType::OR)) {
      Expr* right = andExpr();
      expr = new LogicalExpr(expr, TokenType::OR, right);
    }
    return expr;
  }
  
  // ... more expression parsing methods ...
  
  // Parses a left-associative chain of `+` / `-` operators.
  Expr* term() {
    Expr* expr = factor();
    
    while (match(TokenType::MINUS, TokenType::PLUS)) {
      // Read the operator NOW: once factor() has consumed the right operand,
      // previous() refers to the operand's last token, not to the operator.
      TokenType op = previous().type;
      Expr* right = factor();
      expr = new BinaryExpr(expr, op, right);
    }
    return expr;
  }
};

Example: Parsing print 1 + 2;

Parser::parse()
  └── declaration()
        └── statement()
              └── printStatement()
                    ├── expression()
                    │     └── ... (assignment → orExpr → ...) → term()
                    │           ├── factor()
                    │           │     └── ... → primary() → LiteralExpr(1)
                    │           ├── match(PLUS) → true
                    │           ├── factor()
                    │           │     └── ... → primary() → LiteralExpr(2)
                    │           └── → BinaryExpr(Literal(1), PLUS, Literal(2))
                    ├── consume(SEMICOLON)
                    └── → PrintStmt(BinaryExpr(...))

The MLIR Generator

You write a visitor that walks the AST and emits MLIR. The visitor pattern dispatches based on AST node type — each visit* method handles one node type.

// lib/IR/MLIRGen.cpp
#include "lox/Dialect.h"
#include "lox/Ops.h"
#include "lox/AST.h"  // Your AST classes from Crafting Interpreters

class MLIRGenerator : public ExprVisitor, public StmtVisitor {
private:
  mlir::OpBuilder builder;
  mlir::ModuleOp module;
  mlir::Location currentLoc;
  
public:
  // ========================================================================
  // Entry point: Convert a whole program (list of statements) to MLIR
  // ========================================================================
  mlir::ModuleOp generateModule(Program *program) {
    module = mlir::ModuleOp::create(builder.getUnknownLoc());
    
    // Visit each top-level statement (function declarations, etc.)
    for (Stmt *stmt : program->statements) {
      visitStmt(stmt);
    }
    
    return module;
  }
  
  // ========================================================================
  // Statement visitors (called by visitStmt, which dispatches by type)
  // ========================================================================
  
  // The dispatcher: called for every statement, routes to the right visit method
  void visitStmt(Stmt *stmt) {
    // Each Stmt subclass has a type tag we switch on
    switch (stmt->type) {
      case StmtType::FUNCTION:   return visitFunctionStmt((FunctionStmt*)stmt);
      case StmtType::RETURN:     return visitReturnStmt((ReturnStmt*)stmt);
      case StmtType::VAR:        return visitVarStmt((VarStmt*)stmt);
      case StmtType::IF:         return visitIfStmt((IfStmt*)stmt);
      case StmtType::WHILE:      return visitWhileStmt((WhileStmt*)stmt);
      case StmtType::PRINT:      return visitPrintStmt((PrintStmt*)stmt);
      case StmtType::BLOCK:      return visitBlockStmt((BlockStmt*)stmt);
      case StmtType::EXPRESSION: return visitExpressionStmt((ExpressionStmt*)stmt);
    }
  }
  
  void visitFunctionStmt(FunctionStmt *stmt) {
    // Create a new function in the module
    auto func = builder.create<FuncOp>(currentLoc, stmt->name, stmt->params);
    
    // Push a new block for the function body
    auto *entryBlock = func.getBody().emplaceBlock();
    builder.setInsertionPointToStart(entryBlock);
    
    // Add block arguments for each parameter
    for (const std::string &param : stmt->params) {
      entryBlock->addArgument(builder.getType<LoxDynamicType>());
    }
    
    // Visit all statements in the function body
    for (Stmt *bodyStmt : stmt->body) {
      visitStmt(bodyStmt);
    }
    
    // Add implicit return nil if the function doesn't end with a return
    if (!endsWithReturn(func)) {
      builder.create<ReturnOp>(currentLoc);
    }
  }
  
  void visitReturnStmt(ReturnStmt *stmt) {
    mlir::Value value = stmt->value ? visitExpr(stmt->value) : nullptr;
    builder.create<ReturnOp>(currentLoc, value);
  }
  
  void visitVarStmt(VarStmt *stmt) {
    mlir::Value init = visitExpr(stmt->init);
    builder.create<DeclareOp>(currentLoc, stmt->name, init);
  }
  
  void visitIfStmt(IfStmt *stmt) {
    mlir::Value cond = visitExpr(stmt->condition);
    
    auto ifOp = builder.create<IfOp>(currentLoc, cond);
    
    // Then branch
    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
    for (Stmt *thenStmt : stmt->thenBranch) {
      visitStmt(thenStmt);
    }
    
    // Else branch (optional)
    if (stmt->elseBranch) {
      builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
      for (Stmt *elseStmt : stmt->elseBranch) {
        visitStmt(elseStmt);
      }
    }
  }
  
  void visitWhileStmt(WhileStmt *stmt) {
    mlir::Value cond = visitExpr(stmt->condition);
    auto whileOp = builder.create<WhileOp>(currentLoc, cond);
    
    builder.setInsertionPointToStart(&whileOp.getBody().front());
    for (Stmt *bodyStmt : stmt->body) {
      visitStmt(bodyStmt);
    }
  }
  
  void visitPrintStmt(PrintStmt *stmt) {
    mlir::Value value = visitExpr(stmt->expression);
    builder.create<PrintOp>(currentLoc, value);
  }
  
  void visitExpressionStmt(ExpressionStmt *stmt) {
    visitExpr(stmt->expression);
  }
  
  void visitBlockStmt(BlockStmt *stmt) {
    for (Stmt *blockStmt : stmt->statements) {
      visitStmt(blockStmt);
    }
  }
  
  // ========================================================================
  // Expression visitors (called by visitExpr, which dispatches by type)
  // ========================================================================
  
  // The dispatcher: called for every expression, routes to the right visit method
  // This is what calls visitBinaryExpr, visitLiteralExpr, etc.
  mlir::Value visitExpr(Expr *expr) {
    currentLoc = convertLocation(expr->loc);
    
    switch (expr->type) {
      case ExprType::BINARY:   return visitBinaryExpr((BinaryExpr*)expr);
      case ExprType::UNARY:    return visitUnaryExpr((UnaryExpr*)expr);
      case ExprType::LITERAL:  return visitLiteralExpr((LiteralExpr*)expr);
      case ExprType::GROUPING: return visitGroupingExpr((GroupingExpr*)expr);
      case ExprType::VARIABLE: return visitVarExpr((VarExpr*)expr);
      case ExprType::ASSIGN:   return visitAssignExpr((AssignExpr*)expr);
      case ExprType::CALL:     return visitCallExpr((CallExpr*)expr);
      case ExprType::LOGICAL:  return visitLogicalExpr((LogicalExpr*)expr);
    }
    return nullptr;
  }
  
  mlir::Value visitBinaryExpr(BinaryExpr *expr) {
    // Recursively visit left and right operands
    // These calls go through visitExpr, which dispatches to the right visitor
    mlir::Value lhs = visitExpr(expr->left);
    mlir::Value rhs = visitExpr(expr->right);
    
    switch (expr->op) {
      case TokenType::PLUS:
        return builder.create<AddOp>(currentLoc, lhs, rhs);
      case TokenType::MINUS:
        return builder.create<SubOp>(currentLoc, lhs, rhs);
      case TokenType::STAR:
        return builder.create<MulOp>(currentLoc, lhs, rhs);
      case TokenType::SLASH:
        return builder.create<DivOp>(currentLoc, lhs, rhs);
      case TokenType::LESS:
        return builder.create<LessOp>(currentLoc, lhs, rhs);
      case TokenType::LESS_EQUAL:
        return builder.create<LessEqualOp>(currentLoc, lhs, rhs);
      case TokenType::GREATER:
        return builder.create<GreaterOp>(currentLoc, lhs, rhs);
      case TokenType::GREATER_EQUAL:
        return builder.create<GreaterEqualOp>(currentLoc, lhs, rhs);
      case TokenType::EQUAL_EQUAL:
        return builder.create<EqualOp>(currentLoc, lhs, rhs);
      case TokenType::BANG_EQUAL:
        return builder.create<NotEqualOp>(currentLoc, lhs, rhs);
      default:
        emitError(currentLoc, "Unknown binary operator");
        return nullptr;
    }
  }
  
  mlir::Value visitUnaryExpr(UnaryExpr *expr) {
    mlir::Value operand = visitExpr(expr->right);
    
    switch (expr->op) {
      case TokenType::MINUS:
        return builder.create<NegateOp>(currentLoc, operand);
      case TokenType::BANG:
        return builder.create<NotOp>(currentLoc, operand);
      default:
        emitError(currentLoc, "Unknown unary operator");
        return nullptr;
    }
  }
  
  mlir::Value visitLiteralExpr(LiteralExpr *expr) {
    if (expr->value.isNumber()) {
      return builder.create<ConstantOp>(currentLoc, expr->value.asNumber());
    }
    if (expr->value.isBool()) {
      return builder.create<ConstantOp>(currentLoc, expr->value.asBool());
    }
    if (expr->value.isString()) {
      return builder.create<ConstantOp>(currentLoc, expr->value.asString());
    }
    // nil
    return builder.create<NilOp>(currentLoc);
  }
  
  mlir::Value visitGroupingExpr(GroupingExpr *expr) {
    // Grouping is just parentheses - emit the inner expression directly
    return visitExpr(expr->expression);
  }
  
  mlir::Value visitVarExpr(VarExpr *expr) {
    return builder.create<LoadOp>(currentLoc, expr->name);
  }
  
  mlir::Value visitAssignExpr(AssignExpr *expr) {
    mlir::Value value = visitExpr(expr->value);
    builder.create<AssignOp>(currentLoc, expr->name, value);
    return value;
  }
  
  mlir::Value visitCallExpr(CallExpr *expr) {
    mlir::Value callee = visitExpr(expr->callee);
    
    // Visit all arguments
    llvm::SmallVector<mlir::Value, 4> args;
    for (Expr *arg : expr->arguments) {
      args.push_back(visitExpr(arg));
    }
    
    return builder.create<CallOp>(currentLoc, callee, args);
  }
  
  mlir::Value visitLogicalExpr(LogicalExpr *expr) {
    mlir::Value left = visitExpr(expr->left);
    
    // Logical AND/OR short-circuit, so they need control flow
    // lox.and and lox.or operations handle this
    switch (expr->op) {
      case TokenType::AND:
        return builder.create<AndOp>(currentLoc, left, 
          /* right is a lazy region */ expr->right);
      case TokenType::OR:
        return builder.create<OrOp>(currentLoc, left, 
          /* right is a lazy region */ expr->right);
      default:
        return nullptr;
    }
  }
};

How the visitor works:

generateModule()
    │
    ├── for each statement:
    │       │
    │       visitStmt(stmt)  ───── dispatches based on stmt->type
    │           │
    │           ├── visitFunctionStmt()   if StmtType::FUNCTION
    │           ├── visitIfStmt()         if StmtType::IF
    │           ├── visitWhileStmt()      if StmtType::WHILE
    │           └── ...
    │
    └── statements call visitExpr() for their expressions:
            │
            visitExpr(expr)  ───── dispatches based on expr->type
                │
                ├── visitBinaryExpr()   if ExprType::BINARY
                │       │
                │       ├── visitExpr(left)   ── recursive!
                │       └── visitExpr(right)  ── recursive!
                │
                ├── visitLiteralExpr()  if ExprType::LITERAL
                ├── visitCallExpr()     if ExprType::CALL
                └── ...

Part 4: Lox-Specific Optimizations

This is where MLIR shines. You can write passes that understand Lox semantics.

Example: Constant Folding

// lib/Transforms/LoxOpt.cpp
#include "lox/Ops.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Fold lox.add(constant, constant) -> constant
// Rewrites a lox.add whose operands are both produced by ConstantOp into a
// single ConstantOp holding the sum.
struct FoldConstantAdd : public OpRewritePattern<AddOp> {
  using OpRewritePattern<AddOp>::OpRewritePattern;
  
  LogicalResult matchAndRewrite(AddOp op, PatternRewriter &rewriter) const override {
    // An operand that is not the result of a ConstantOp yields a null handle.
    ConstantOp lhsConst = op.getLhs().getDefiningOp<ConstantOp>();
    ConstantOp rhsConst = op.getRhs().getDefiningOp<ConstantOp>();
    
    const bool bothConstant = lhsConst && rhsConst;
    if (!bothConstant)
      return failure();
    
    // Compute the sum at compile time and swap the add for that constant.
    double folded = lhsConst.getValue() + rhsConst.getValue();
    rewriter.replaceOpWithNewOp<ConstantOp>(op, folded);
    return success();
  }
};

Example: Dead Variable Elimination

// lib/Transforms/LoxOpt.cpp (continued)
// Remove variables that are declared but never used
// Removes a lox.declare whose variable is never used.
struct EliminateDeadVariable : public OpRewritePattern<DeclareOp> {
  using OpRewritePattern<DeclareOp>::OpRewritePattern;
  
  LogicalResult matchAndRewrite(DeclareOp op, PatternRewriter &rewriter) const override {
    // lox.declare also produces an SSA result. Erasing an op whose result is
    // still referenced is invalid, so bail out if anything uses it.
    if (!op->use_empty()) return failure();
    
    // Check if the variable is ever loaded
    // NOTE(review): hasUses() is defined outside this excerpt; it should also
    // count lox.assign to the same name, otherwise an assignment to a removed
    // variable is left behind.
    if (hasUses(op.getName())) return failure();
    
    // Variable is never used, remove it
    rewriter.eraseOp(op);
    return success();
  }
};

Example: Inline Simple Functions

// lib/Transforms/LoxOpt.cpp (continued)
// Inline functions that are small (e.g., just return an expression)
// Replaces a lox.call with the callee's body when the callee is small.
struct InlineSimpleFunctions : public OpRewritePattern<CallOp> {
  using OpRewritePattern<CallOp>::OpRewritePattern;
  
  LogicalResult matchAndRewrite(CallOp call, PatternRewriter &rewriter) const override {
    FuncOp func = lookupFunction(call.getCallee());
    
    // The callee symbol may not name a lox.func (undefined or not a
    // function); leave such calls alone rather than inspecting a null op.
    if (!func) return failure();
    
    // Only inline small functions
    if (!isSimpleEnoughToInline(func)) return failure();
    
    // Clone the function body and substitute arguments
    inlineFunction(call, func, rewriter);
    return success();
  }
};

Part 5: Lowering to Standard Dialects

Now we progressively transform lox.* operations into lower-level MLIR dialects.

Step 1: Lox Dialect → SCF + Arith Dialects

Control flow becomes structured control flow (scf.if, scf.while), arithmetic becomes arith.addf, etc.

// lib/Transforms/LowerToLLVM.cpp
#include "lox/Ops.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Transforms/DialectConversion.h"

// Convert lox.add to arith.addf
// Lowers lox.add to arith.addf, using the already-converted operands supplied
// by the adaptor.
struct AddOpLowering : public OpConversionPattern<AddOp> {
  using OpConversionPattern<AddOp>::OpConversionPattern;
  
  LogicalResult matchAndRewrite(AddOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const override {
    auto convertedLhs = adaptor.getLhs();
    auto convertedRhs = adaptor.getRhs();
    rewriter.replaceOpWithNewOp<mlir::arith::AddFOp>(op, convertedLhs, convertedRhs);
    return success();
  }
};

// Convert lox.while to scf.while
// Lowers lox.while to scf.while.
//
// scf.while has two regions: "before" decides whether to continue and must
// end in scf.condition; "after" is the loop body and must end in scf.yield.
// The loop carries no values, so both regions have no block arguments.
struct WhileOpLowering : public OpConversionPattern<WhileOp> {
  // Inherit the base constructors so the pattern can be created with
  // patterns.add<WhileOpLowering>(...), like AddOpLowering.
  using OpConversionPattern<WhileOp>::OpConversionPattern;
  
  LogicalResult matchAndRewrite(WhileOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const override {
    mlir::Value condition = adaptor.getCondition();
    
    auto whileOp = rewriter.create<mlir::scf::WhileOp>(
        op.getLoc(), mlir::TypeRange{}, mlir::ValueRange{},
        // "before" region: continue while the (converted) condition holds.
        [&](mlir::OpBuilder &b, mlir::Location loc, mlir::ValueRange) {
          b.create<mlir::scf::ConditionOp>(loc, condition, mlir::ValueRange{});
        },
        // "after" region: starts with just the terminator; the body is moved
        // in front of it below.
        [&](mlir::OpBuilder &b, mlir::Location loc, mlir::ValueRange) {
          b.create<mlir::scf::YieldOp>(loc);
        });
    
    // Move the lox.while body ahead of the scf.yield.
    // NOTE(review): assumes the body block has no terminator of its own —
    // confirm against the lox.while definition in use.
    mlir::Block *body = &op.getBody().front();
    rewriter.inlineBlockBefore(body, whileOp.getAfterBody()->getTerminator());
    
    rewriter.eraseOp(op);
    return success();
  }
};

Step 2: SCF + Arith → LLVM Dialect

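The first pass, --convert-lox-to-scf, is the one you wrote in Step 1; everything after it is a built-in MLIR pass. Stock mlir-opt does not know the lox dialect or your pass, so in practice you run this pipeline through your own opt-style tool (or inside loxc) that registers both. The command below shows the pass order: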

# Run from build directory or project root
# Reads lox.mlir, applies the conversion passes listed below, and writes the
# result to lox_llvm.mlir.
mlir-opt lox.mlir \
  --convert-lox-to-scf \
  --convert-scf-to-cf \
  --convert-cf-to-llvm \
  --convert-arith-to-llvm \
  --convert-func-to-llvm \
  -o lox_llvm.mlir

Step 3: LLVM Dialect → LLVM IR → Machine Code

# Run from build directory or project root
# 1. Translate the LLVM-dialect MLIR file into textual LLVM IR (lox.ll).
mlir-translate lox_llvm.mlir --mlir-to-llvmir -o lox.ll
# 2. Compile the LLVM IR to assembly (lox.s).
llc lox.ll -o lox.s
# 3. Assemble and link the executable (lox).
clang lox.s -o lox

Part 6: Handling Dynamic Typing

Lox is dynamically typed, but MLIR is typed. Options:

Option A: Tagged Values

Every Lox value is a struct with a type tag:

// include/lox/Types.td (type definition)
!lox.value = struct<{
  tag: i8,        // 0=nil, 1=bool, 2=number, 3=string, 4=object
  data: union<i1, f64, !lox.string, !lox.object>
}>

Operations check the tag at runtime:

// lib/Transforms/RuntimeChecks.cpp (generated MLIR for runtime type checking)
// lox.add checks that both operands are numbers.
// Each tag is compared against the number tag (2). Comparing the two tags
// with each other is not enough: two strings (3 == 3) or two nils (0 == 0)
// would pass and their payloads would be added as f64.
%lhs_tag = llvm.extractvalue %lhs[0] : !lox.value
%rhs_tag = llvm.extractvalue %rhs[0] : !lox.value
%number_tag = arith.constant 2 : i8
%lhs_is_num = arith.cmpi eq, %lhs_tag, %number_tag : i8
%rhs_is_num = arith.cmpi eq, %rhs_tag, %number_tag : i8
%both_numbers = arith.andi %lhs_is_num, %rhs_is_num : i1
cf.cond_br %both_numbers, ^ok, ^type_error

^ok:
  %lhs_num = llvm.extractvalue %lhs[1] : !lox.value
  %rhs_num = llvm.extractvalue %rhs[1] : !lox.value
  %result = arith.addf %lhs_num, %rhs_num : f64
  // ... pack result back into tagged value ...

Option B: Specialize by Type

Generate specialized versions of functions for each type combination:

// lib/Transforms/TypeSpecialization.cpp (generated after type specialization)
// Generic call: the operand types are not known at this point
lox.call @add(%a, %b)

// After specialization (if both are numbers): call the number-only variant
lox.call @add_numbers(%a, %b)

// If strings, call string concatenation
lox.call @add_strings(%a, %b)

Part 7: Closures

Closures are the trickiest part of Lox. MLIR's regions help.

Captured Variables

// test/closure.lox
// Returns the inner function `counter`, which refers to the local `count`
// declared in makeCounter.
fun makeCounter() {
  var count = 0;
  // Adds 1 to `count` and returns the new value.
  fun counter() {
    count = count + 1;
    return count;
  }
  return counter;
}

MLIR Representation

// output after: loxc test/closure.lox --emit-mlir
// Allocates a one-slot environment for `count`, stores 0.0 in it, and returns
// a closure pairing @counter with that environment.
lox.func @makeCounter() -> !lox.closure {
  // Create a closure environment
  %env = lox.alloc_env { size = 1 }
  
  // Store count in the environment (slot 0)
  %zero = lox.constant 0.0 : !lox.number
  lox.env_store %env[0] = %zero : !lox.number
  
  // Create the closure, capturing the environment
  %closure = lox.make_closure @counter(%env)
  lox.return %closure : !lox.closure
}

// The inner function, taking the environment as an explicit argument.
// `count` lives in slot 0 of %env.
lox.func @counter(%env: !lox.env) -> !lox.number {
  // Load count from the environment
  %count = lox.env_load %env[0] : !lox.number
  
  // Increment
  %one = lox.constant 1.0 : !lox.number
  %new_count = lox.add %count, %one : !lox.number
  
  // Store back into slot 0
  lox.env_store %env[0] = %new_count : !lox.number
  
  lox.return %new_count : !lox.number
}

Part 8: Project Structure

lox-mlir/
├── include/
│   └── lox/
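│       ├── Lox.td           # Dialect definition (TableGen)
│       ├── Ops.td           # Operation definitions
│       ├── Types.td         # Type definitions
│       ├── AST.h            # AST node definitions
│       ├── Dialect.h        # Dialect declaration (includes the generated LoxDialect.h.inc)
│       └── Ops.h            # Op declarations (includes the generated LoxOps.h.inc)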
├── lib/
│   ├── Dialect.cpp          # Dialect registration
│   ├── IR/
│   │   └── MLIRGen.cpp      # AST → MLIR conversion
│   └── Transforms/
│       ├── LoxOpt.cpp       # Lox-specific optimizations
│       └── LowerToLLVM.cpp  # Lowering passes
├── tools/
│   └── loxc/
│       └── main.cpp         # Compiler driver
└── test/
    └── *.lox                # Test programs

Part 9: Build System (CMake)

# CMakeLists.txt (project root)
cmake_minimum_required(VERSION 3.16)
project(lox-mlir LANGUAGES C CXX)

# MLIR headers require C++17.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(MLIR REQUIRED CONFIG)

# mlir_tablegen() and add_public_tablegen_target() come from the CMake modules
# shipped with LLVM/MLIR; they are not available until these are included.
list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}" "${LLVM_CMAKE_DIR}")
include(TableGen)
include(AddLLVM)
include(AddMLIR)

include_directories(
  ${LLVM_INCLUDE_DIRS}
  ${MLIR_INCLUDE_DIRS}
  ${PROJECT_SOURCE_DIR}/include   # lox/*.h
  ${CMAKE_CURRENT_BINARY_DIR}     # generated *.inc files
)

# Generate dialect files from TableGen
# mlir_tablegen() reads its input from LLVM_TARGET_DEFINITIONS.
# Assumes Ops.td includes Lox.td and Types.td, so one input covers the
# dialect and the ops.
set(LLVM_TARGET_DEFINITIONS include/lox/Ops.td)
mlir_tablegen(LoxDialect.h.inc -gen-dialect-decls)
mlir_tablegen(LoxDialect.cpp.inc -gen-dialect-defs)
mlir_tablegen(LoxOps.h.inc -gen-op-decls)
mlir_tablegen(LoxOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(LoxIncGen)

# Source paths follow the project layout: implementation files live in lib/.
add_library(LoxDialect
  lib/Dialect.cpp
)
# The generated .inc files must exist before the dialect sources compile.
add_dependencies(LoxDialect LoxIncGen)
target_link_libraries(LoxDialect MLIRIR)

add_library(LoxTransforms
  lib/Transforms/LoxOpt.cpp
  lib/Transforms/LowerToLLVM.cpp
)
# The lowering creates arith and scf ops, so those dialect libraries are needed.
target_link_libraries(LoxTransforms
  LoxDialect
  MLIRTransformUtils
  MLIRArithDialect
  MLIRSCFDialect
)

# MLIRGen.cpp (AST -> MLIR) is built into the driver.
add_executable(loxc
  tools/loxc/main.cpp
  lib/IR/MLIRGen.cpp
)
target_link_libraries(loxc LoxDialect LoxTransforms MLIRParser)

Quick Reference: Lox → MLIR Mapping

Lox Construct       MLIR Operation   Lowered To
a + b               lox.add          arith.addf
a - b               lox.sub          arith.subf
a * b               lox.mul          arith.mulf
a / b               lox.div          arith.divf
a < b               lox.less         arith.cmpf
a == b              lox.equal        runtime check
var x = v           lox.declare      llvm.alloca + store
x                   lox.load         llvm.load
x = v               lox.assign       llvm.store
if (c) {...}        lox.if           scf.if → cf.cond_br
while (c) {...}     lox.while        scf.while → cf.br/cf.cond_br
fun f(...) {...}    lox.func         func.func → llvm.func
f(args)             lox.call         func.call → llvm.call
return v            lox.return       func.return → llvm.return
print v             lox.print        runtime printf call

Next Steps

  1. Start small: Just numbers and arithmetic. Get print 1 + 2; working.
  2. Add variables: Implement local variables with lox.declare/lox.load/lox.assign.
  3. Add control flow: if and while with scf dialect.
  4. Add functions: lox.func and lox.call.
  5. Add closures: This is where it gets interesting.
  6. Add classes/objects: The full Lox experience.

The MLIR infrastructure handles the boring parts (IR data structures, verification and dominance checking, pass management, and the standard lowerings down to LLVM), letting you focus on Lox-specific concerns.