MLIR for Lox: Part 4 - MLIR Integration

Now we have a working GC runtime. Let's integrate it with our MLIR code generator.


Chapter 12: The Runtime Module

First, let's organize our runtime into a proper Rust module:

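// src/runtime/mod.rs

pub mod object;
pub mod shadow_stack;
pub mod gc;

// Re-export the public API.
// Anything generated code calls must be defined in its module as
// #[no_mangle] pub extern "C" fn, so the symbol in the library matches
// the name the MLIR lowering emits (the lowering calls the allocator as
// lox_runtime_alloc, so that is the name it must export).
pub use object::{ObjHeader, ObjType};
pub use gc::{alloc, gc_collect};
pub use shadow_stack::{gc_push_frame, gc_pop_frame, gc_set_root};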

And Cargo.toml needs to build it as a static library (the cdylib entry also produces a shared library):

# Cargo.toml

[package]
name = "lox-runtime"
version = "0.1.0"
edition = "2021"

[lib]
name = "lox_runtime"
crate-type = ["staticlib", "cdylib"]  # Both static and dynamic lib

[dependencies]
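Generated code reaches the runtime purely by symbol name, so every entry point it calls has to use the C ABI and keep an unmangled name. The following is a minimal, stand-alone sketch of how gc_push_frame and gc_pop_frame could look under that constraint. It is an approximation, not Part 3's shadow stack, and the for_each_root helper the collector would use is hypothetical.

```rust
use std::cell::RefCell;
use std::ptr;

/// One shadow-stack frame: a heap-allocated array of root slots.
/// Each slot holds a pointer to a heap object, or null if unused.
struct Frame {
    roots: *mut *mut u8,
    len: usize,
}

thread_local! {
    // The shadow stack for the current thread (innermost frame last).
    static SHADOW_STACK: RefCell<Vec<Frame>> = RefCell::new(Vec::new());
}

/// Push a frame with `root_count` null slots and return a pointer to slot 0.
/// Generated code stores roots through it (slot i at byte offset 8 * i on
/// a 64-bit target). Edition 2021 syntax; 2024 spells it #[unsafe(no_mangle)].
#[no_mangle]
pub extern "C" fn gc_push_frame(root_count: i64) -> *mut *mut u8 {
    assert!(root_count >= 0, "negative root count");
    let slots: Box<[*mut u8]> = vec![ptr::null_mut(); root_count as usize].into_boxed_slice();
    let len = slots.len();
    // Own the slots through a raw pointer so they stay put until popped,
    // even when the Vec of frames reallocates.
    let roots = Box::into_raw(slots) as *mut *mut u8;
    SHADOW_STACK.with(|stack| stack.borrow_mut().push(Frame { roots, len }));
    roots
}

/// Pop the innermost frame and free its slots.
#[no_mangle]
pub extern "C" fn gc_pop_frame() {
    let frame = SHADOW_STACK
        .with(|stack| stack.borrow_mut().pop())
        .expect("gc_pop_frame without a matching gc_push_frame");
    // Safety: `roots` and `len` came from Box::into_raw in gc_push_frame.
    unsafe {
        drop(Box::from_raw(ptr::slice_from_raw_parts_mut(frame.roots, frame.len)));
    }
}

/// Visit every non-null root in every live frame (what a mark phase needs).
pub fn for_each_root(mut visit: impl FnMut(*mut u8)) {
    SHADOW_STACK.with(|stack| {
        for frame in stack.borrow().iter() {
            for i in 0..frame.len {
                // Safety: a frame's slots stay valid until it is popped.
                let root = unsafe { *frame.roots.add(i) };
                if !root.is_null() {
                    visit(root);
                }
            }
        }
    });
}
```

Null-initializing the slots matters: a collection can run before generated code has filled every slot, and the scan must be able to skip the unused ones.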

Chapter 13: The Lox MLIR Dialect

We need MLIR operations that correspond to our GC operations:

| Operation      | Meaning                         | LLVM Lowering              |
|----------------|---------------------------------|----------------------------|
| lox.alloc      | Allocate a heap object          | Call lox_runtime_alloc     |
| lox.set_root   | Register a pointer as a GC root | Store in shadow stack slot |
| lox.push_frame | Push a stack frame              | Call gc_push_frame         |
| lox.pop_frame  | Pop a stack frame               | Call gc_pop_frame          |

Defining Operations in Melior

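Melior has no Rust counterpart to TableGen/ODS for defining a new dialect; registering real operations normally means writing C++. Instead, we build generic operations by name with OperationBuilder. MLIR accepts these as unregistered operations, provided the context allows unregistered dialects: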

// src/codegen/lox_dialect.rs
//
// Written against a recent Melior release: some signatures (for example
// the argument order of IntegerAttribute::new and the lifetimes on Value)
// have changed between versions.
//
// These ops are unregistered, so the Context must be set up with
// context.set_allow_unregistered_dialects(true).

use melior::{
    ir::{
        attribute::IntegerAttribute, r#type::IntegerType, Attribute, Identifier, Location,
        Operation, OperationBuilder, Type, Value,
    },
    Context,
};

/// The Lox dialect namespace
pub const DIALECT_NAME: &str = "lox";

/// Operation names
pub mod ops {
    pub const ALLOC: &str = "lox.alloc";
    pub const LOAD: &str = "lox.load";
    pub const STORE: &str = "lox.store";
    pub const PUSH_FRAME: &str = "lox.push_frame";
    pub const POP_FRAME: &str = "lox.pop_frame";
    pub const SET_ROOT: &str = "lox.set_root";
}

/// A named integer attribute of the given bit width.
fn int_attr<'c>(
    context: &'c Context,
    name: &str,
    bits: u32,
    value: i64,
) -> (Identifier<'c>, Attribute<'c>) {
    let ty = IntegerType::new(context, bits).into();
    (Identifier::new(context, name), IntegerAttribute::new(ty, value).into())
}

/// The opaque LLVM pointer type, !llvm.ptr.
fn ptr_type(context: &Context) -> Type<'_> {
    Type::parse(context, "!llvm.ptr").expect("valid type")
}

/// Create a lox.alloc operation
///
/// This allocates a heap object of the given type.
/// Attributes: obj_type (i8, the ObjType enum) and size (i64, bytes).
/// Returns a pointer to the object data (after the header).
pub fn create_alloc<'c>(
    context: &'c Context,
    obj_type: u8,  // ObjType enum value
    size: usize,   // Size of object data in bytes
    location: Location<'c>,
) -> Operation<'c> {
    OperationBuilder::new(ops::ALLOC, location)
        .add_attributes(&[
            int_attr(context, "obj_type", 8, obj_type as i64),
            int_attr(context, "size", 64, size as i64),
        ])
        .add_results(&[ptr_type(context)])
        .build()
        .expect("valid lox.alloc")
}

/// Create a lox.push_frame operation
///
/// This pushes a new shadow stack frame with the given number of root slots.
/// Returns a pointer to the roots array.
pub fn create_push_frame<'c>(
    context: &'c Context,
    root_count: usize,
    location: Location<'c>,
) -> Operation<'c> {
    OperationBuilder::new(ops::PUSH_FRAME, location)
        .add_attributes(&[int_attr(context, "root_count", 64, root_count as i64)])
        .add_results(&[ptr_type(context)])
        .build()
        .expect("valid lox.push_frame")
}

/// Create a lox.pop_frame operation
pub fn create_pop_frame<'c>(
    _context: &'c Context, // unused; kept so all builders share a signature
    location: Location<'c>,
) -> Operation<'c> {
    OperationBuilder::new(ops::POP_FRAME, location)
        .build()
        .expect("valid lox.pop_frame")
}

/// Create a lox.set_root operation
///
/// Sets a root in the current shadow stack frame.
pub fn create_set_root<'c>(
    context: &'c Context,
    root_index: usize,
    value: Value<'c, '_>,
    location: Location<'c>,
) -> Operation<'c> {
    OperationBuilder::new(ops::SET_ROOT, location)
        .add_attributes(&[int_attr(context, "index", 64, root_index as i64)])
        .add_operands(&[value])
        .build()
        .expect("valid lox.set_root")
}
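Because no dialect is registered for these names, MLIR has no custom syntax for them and prints them in its generic form: a quoted op name, the operand list, an attribute dictionary, and a function-style type. The helper below is a hypothetical illustration (not part of Melior) that produces that text for the lox ops, which can be handy when tests compare against printed IR.

```rust
/// Format an op in MLIR's generic syntax:
///   %res = "name"(%operands) {attr = value : type} : (operand types) -> result
fn generic_form(
    result: Option<&str>,
    name: &str,
    operands: &[(&str, &str)],   // (SSA name, type)
    attrs: &[(&str, i64, &str)], // (attribute name, value, integer type)
    result_type: Option<&str>,
) -> String {
    let mut text = String::new();
    if let Some(res) = result {
        text.push_str(&format!("{} = ", res));
    }
    let names: Vec<&str> = operands.iter().map(|(n, _)| *n).collect();
    text.push_str(&format!("\"{}\"({})", name, names.join(", ")));
    if !attrs.is_empty() {
        let entries: Vec<String> = attrs
            .iter()
            .map(|(k, v, ty)| format!("{} = {} : {}", k, v, ty))
            .collect();
        text.push_str(&format!(" {{{}}}", entries.join(", ")));
    }
    let types: Vec<&str> = operands.iter().map(|(_, t)| *t).collect();
    // Ops without results get an empty result list.
    let ret = result_type.unwrap_or("()");
    text.push_str(&format!(" : ({}) -> {}", types.join(", "), ret));
    text
}
```

For example, generic_form(Some("%0"), "lox.push_frame", &[], &[("root_count", 3, "i64")], Some("!llvm.ptr")) yields %0 = "lox.push_frame"() {root_count = 3 : i64} : () -> !llvm.ptr.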

Chapter 14: Lowering to LLVM

Our lox.* operations need to be rewritten into func and LLVM dialect operations, which MLIR can then translate to LLVM IR. Each rewrite is a small function:

// src/codegen/lowering.rs
//
// Per-op rewrites. Each function appends its replacement ops to `block`
// (a real pass would insert them at the op's position); the caller still
// has to redirect uses of the old result and erase the lox op.
// Written against a recent Melior release; see lox_dialect.rs.

use melior::{
    dialect::{arith, func},
    ir::{
        attribute::{FlatSymbolRefAttribute, IntegerAttribute},
        r#type::IntegerType,
        Block, Location, Operation, Type, Value,
    },
    Context,
};

/// The opaque LLVM pointer type, !llvm.ptr.
fn ptr_type(context: &Context) -> Type<'_> {
    Type::parse(context, "!llvm.ptr").expect("valid type")
}

/// Read an integer attribute (panics if it is missing or not an integer).
fn int_attr_value(op: &Operation, name: &str) -> i64 {
    let attr = op.attribute(name).expect("missing attribute");
    IntegerAttribute::try_from(attr).expect("not an integer").value()
}

/// Append `arith.constant <value> : i<bits>` and return its result.
fn const_int<'c, 'a>(
    context: &'c Context,
    block: &'a Block<'c>,
    bits: u32,
    value: i64,
    location: Location<'c>,
) -> Value<'c, 'a> {
    let ty = IntegerType::new(context, bits).into();
    let op = arith::constant(context, IntegerAttribute::new(ty, value).into(), location);
    block.append_operation(op).result(0).expect("one result").into()
}

/// Lower lox.alloc to a runtime call
fn lower_alloc<'c>(op: &Operation<'c>, block: &Block<'c>, context: &'c Context) {
    let location = op.location();

    // Materialize the attributes as SSA constants for the call
    let obj_type_val = const_int(context, block, 8, int_attr_value(op, "obj_type"), location);
    let size_val = const_int(context, block, 64, int_attr_value(op, "size"), location);

    // %ptr = func.call @lox_runtime_alloc(%type, %size) : (i8, i64) -> !llvm.ptr
    block.append_operation(func::call(
        context,
        FlatSymbolRefAttribute::new(context, "lox_runtime_alloc"),
        &[obj_type_val, size_val],
        &[ptr_type(context)],
        location,
    ));
    // The caller redirects uses of op.result(0) to this call's result.
}

/// Lower lox.push_frame to a runtime call
fn lower_push_frame<'c>(op: &Operation<'c>, block: &Block<'c>, context: &'c Context) {
    let location = op.location();
    let count_val = const_int(context, block, 64, int_attr_value(op, "root_count"), location);

    // %frame = func.call @gc_push_frame(%count) : (i64) -> !llvm.ptr
    block.append_operation(func::call(
        context,
        FlatSymbolRefAttribute::new(context, "gc_push_frame"),
        &[count_val],
        &[ptr_type(context)],
        location,
    ));
    // Uses of op.result(0) (the frame pointer) move to this call's result.
}

/// Lower lox.pop_frame to a runtime call
fn lower_pop_frame<'c>(op: &Operation<'c>, block: &Block<'c>, context: &'c Context) {
    // func.call @gc_pop_frame() : () -> ()
    block.append_operation(func::call(
        context,
        FlatSymbolRefAttribute::new(context, "gc_pop_frame"),
        &[],
        &[],
        op.location(),
    ));
}

/// Lower lox.set_root to a store instruction
fn lower_set_root<'c>(op: &Operation<'c>, _block: &Block<'c>, _context: &'c Context) {
    let _root_index = int_attr_value(op, "index");
    let _value = op.operand(0).expect("lox.set_root takes one operand");

    // Needs the frame pointer from the enclosing lox.push_frame, which the
    // pass must track per function (not shown). The ops to emit are:
    //   %slot = llvm.getelementptr %frame[index] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
    //   llvm.store %value, %slot : !llvm.ptr, !llvm.ptr
}
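Taken together, the four rewrites are a fixed mapping from each lox op to the LLVM-dialect ops it ends up as. The toy model below (plain Rust over strings, not Melior; the LoxOp enum and the SSA names are illustrative) spells that mapping out, including the frame pointer that set_root needs from the enclosing push_frame.

```rust
/// A stripped-down stand-in for the lox ops (attributes and operands only).
enum LoxOp {
    Alloc { obj_type: u8, size: u64 },
    PushFrame { root_count: u64 },
    PopFrame,
    SetRoot { index: u64, value: String },
}

/// Lower one op to LLVM-dialect text. `dest` names the op's result (or a
/// scratch value), and `frame` is the !llvm.ptr from the enclosing push_frame.
fn lower(op: &LoxOp, dest: &str, frame: &str) -> Vec<String> {
    match op {
        LoxOp::Alloc { obj_type, size } => vec![
            format!("{}_type = llvm.mlir.constant({} : i8) : i8", dest, obj_type),
            format!("{}_size = llvm.mlir.constant({} : i64) : i64", dest, size),
            format!("{0} = llvm.call @lox_runtime_alloc({0}_type, {0}_size) : (i8, i64) -> !llvm.ptr", dest),
        ],
        LoxOp::PushFrame { root_count } => vec![
            format!("{}_n = llvm.mlir.constant({} : i64) : i64", dest, root_count),
            format!("{0} = llvm.call @gc_push_frame({0}_n) : (i64) -> !llvm.ptr", dest),
        ],
        LoxOp::PopFrame => vec!["llvm.call @gc_pop_frame() : () -> ()".to_string()],
        // Not a call: address the slot, then store into it.
        LoxOp::SetRoot { index, value } => vec![
            format!("{} = llvm.getelementptr {}[{}] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr", dest, frame, index),
            format!("llvm.store {}, {} : !llvm.ptr, !llvm.ptr", value, dest),
        ],
    }
}
```

Only set_root depends on state outside the op itself, which is why the real lowering has to track the frame pointer per function.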

Chapter 15: Code Generation for Functions

Now we modify our function code generator to use the shadow stack:

// src/codegen/generator.rs (modified)

impl<'c> CodeGenerator<'c> {

    /// Compile a function with GC support
    fn compile_function(&mut self, func: &FunctionStmt) {
        let location = self.loc(func.location);

        // === STEP 1: Count roots ===
        // Roots = parameters + local variables
        let root_count = self.count_roots(func);

        // === STEP 2: Create function type ===
        let float_type = Type::float64(self.context);
        let param_types: Vec<Type> = func.params.iter().map(|_| float_type).collect();
        let function_type = FunctionType::new(self.context, &param_types, &[float_type]);

        // === STEP 3: Create function body ===
        let region = Region::new();
        let block = Block::new(
            &param_types.iter().map(|&t| (t, location)).collect::<Vec<_>>()
        );

        // === STEP 4: Push shadow stack frame ===
        // append_operation takes ownership of the op and returns a reference
        // to it; its result is the pointer to the roots array.
        let push_frame = block.append_operation(
            create_push_frame(self.context, root_count, location));
        let _roots_ptr: Value = push_frame.result(0).unwrap().into();

        // === STEP 5: Store parameters as roots ===
        for (i, param_name) in func.params.iter().enumerate() {
            let arg = block.argument(i).unwrap();

            // Store the parameter value in roots[i]. Only values that may
            // point into the heap belong in a root slot (the collector would
            // misread a raw f64 as a pointer), so a real implementation
            // decides this from the value representation.
            self.set_root(&block, i, arg.into(), location);

            // Also track in our local variables map
            self.variables.insert(param_name.clone(), arg.into());
        }

        // === STEP 6: Compile function body ===
        // Every explicit `return` must emit lox.pop_frame right before its
        // return op, or that path leaves the frame on the shadow stack.
        self.current_block = Some(block);
        self.current_root_index = func.params.len();  // Next free root slot

        for stmt in &func.body {
            self.compile_statement(stmt);
        }

        // === STEP 7: Pop frame and add implicit return if needed ===
        // Only when the body can fall off the end. The pop goes first:
        // nothing may follow the return terminator.
        if let Some(block) = &self.current_block {
            block.append_operation(create_pop_frame(self.context, location));
            // ... append the implicit return ...
        }

        // === STEP 8: Create the function ===
        // ... append block to region, add func to module ...
    }

    /// Count the total number of roots needed for a function
    fn count_roots(&self, func: &FunctionStmt) -> usize {
        // Parameters are roots
        let mut count = func.params.len();

        // Local variables are roots
        for stmt in &func.body {
            count += self.count_roots_in_stmt(stmt);
        }

        count
    }

    /// Count roots introduced by a statement
    fn count_roots_in_stmt(&self, stmt: &Stmt) -> usize {
        match stmt {
            Stmt::Var(_) => 1,  // Each var declaration is a root
            Stmt::Block(b) => b.statements.iter()
                .map(|s| self.count_roots_in_stmt(s))
                .sum(),
            // Count every statement in both branches, not just the first
            Stmt::If(i) => {
                i.then_branch.iter().map(|s| self.count_roots_in_stmt(s)).sum::<usize>() +
                i.else_branch.iter().map(|s| self.count_roots_in_stmt(s)).sum::<usize>()
            }
            Stmt::While(w) => w.body.iter().map(|s| self.count_roots_in_stmt(s)).sum(),
            _ => 0,
        }
    }

    /// Set a root in the shadow stack
    fn set_root(&mut self, block: &Block<'c>, index: usize, value: Value<'c, '_>, location: Location<'c>) {
        let set_root_op = create_set_root(self.context, index, value, location);
        block.append_operation(set_root_op);
    }
}
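The counting rule is easy to get subtly wrong for nested statements: any statement in a branch or loop body can declare a variable, not just the first one. Here is a self-contained version over a toy AST (this Stmt shape is an assumption, simpler than the real one) that shows the full traversal:

```rust
/// A toy statement type with just the shapes that matter for counting.
enum Stmt {
    Var(String),
    Expr,
    Block(Vec<Stmt>),
    If { then_branch: Vec<Stmt>, else_branch: Vec<Stmt> },
    While { body: Vec<Stmt> },
}

/// Roots introduced by a list of statements (one slot per `var`).
fn count_roots(stmts: &[Stmt]) -> usize {
    stmts.iter().map(count_roots_in_stmt).sum()
}

fn count_roots_in_stmt(stmt: &Stmt) -> usize {
    match stmt {
        Stmt::Var(_) => 1,
        Stmt::Block(stmts) => count_roots(stmts),
        // Both branches get their own slots, even though only one runs.
        Stmt::If { then_branch, else_branch } => count_roots(then_branch) + count_roots(else_branch),
        // The count is per declaration site, so a loop body's `var` reuses
        // the same slot on every iteration.
        Stmt::While { body } => count_roots(body),
        Stmt::Expr => 0,
    }
}
```

One slot per declaration site is the simplest scheme; slots could be shared between scopes that are never live at the same time, at the cost of a smarter allocator.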

Chapter 16: The Lowering Pass

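We need a pipeline that converts lox.* operations to LLVM calls and then lowers the standard dialects. Because the rewritten calls are ordinary func.call ops, the module also needs a declaration for each runtime function it calls, such as func.func private @gc_push_frame(i64) -> !llvm.ptr: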

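// src/codegen/lowering_pass.rs
//
// Written against a recent Melior release. The conversion passes come
// from melior::pass::conversion; our Lox rewrite is a plain function that
// runs first, because those passes only understand registered dialects.

use melior::{
    ir::Module,
    pass::{conversion, PassManager},
    Context,
};

/// Lower a module from the Lox + standard dialects to the LLVM dialect
pub fn lower_to_llvm(context: &Context, module: &mut Module) {
    // 1. Rewrite every lox.* op using the functions in lowering.rs
    convert_lox_to_llvm(module);

    // 2. Lower the standard dialects
    let pm = create_lowering_pass_manager(context);
    pm.run(module).expect("lowering to the LLVM dialect failed");
}

/// Create the pass manager for the standard dialects
pub fn create_lowering_pass_manager(context: &Context) -> PassManager<'_> {
    let pm = PassManager::new(context);

    // Lower standard dialects to LLVM
    pm.add_pass(conversion::create_scf_to_control_flow());
    pm.add_pass(conversion::create_control_flow_to_llvm());
    pm.add_pass(conversion::create_arith_to_llvm());
    pm.add_pass(conversion::create_func_to_llvm());
    // Remove the casts left between partially converted types
    pm.add_pass(conversion::create_reconcile_unrealized_casts());

    pm
}

/// Our custom Lox-to-LLVM rewrite
fn convert_lox_to_llvm(_module: &mut Module) {
    // Walk all operations: module.body() and, recursively, the regions of
    // each op (func.func bodies, nested blocks). For each lox.* op, emit
    // its replacement (lowering.rs), redirect uses of its result, then
    // erase it:
    //   "lox.alloc"      => call to lox_runtime_alloc  (lower_alloc)
    //   "lox.push_frame" => call to gc_push_frame      (lower_push_frame)
    //   "lox.pop_frame"  => call to gc_pop_frame       (lower_pop_frame)
    //   "lox.set_root"   => getelementptr + store      (lower_set_root)
}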
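The walk has to descend into nested regions, because the lox ops live inside func.func bodies rather than directly in the module body. A toy op tree (plain Rust, not Melior's IR; Op and its fields are illustrative) shows the recursive rewrite, and a second run confirming that nothing is left:

```rust
/// A toy op: a name, an optional callee, and nested regions of ops.
struct Op {
    name: String,
    callee: Option<String>,
    regions: Vec<Vec<Op>>,
}

/// Rewrite lox ops at every nesting level; returns how many were rewritten.
fn rewrite_lox_ops(ops: &mut [Op]) -> usize {
    let mut rewritten = 0;
    for op in ops.iter_mut() {
        let replacement = match op.name.as_str() {
            "lox.alloc" => Some(("func.call", Some("lox_runtime_alloc"))),
            "lox.push_frame" => Some(("func.call", Some("gc_push_frame"))),
            "lox.pop_frame" => Some(("func.call", Some("gc_pop_frame"))),
            // The real rewrite also inserts an llvm.getelementptr first.
            "lox.set_root" => Some(("llvm.store", None)),
            _ => None,
        };
        if let Some((name, callee)) = replacement {
            op.name = name.to_string();
            op.callee = callee.map(str::to_string);
            rewritten += 1;
        }
        // Descend into nested regions (e.g. the body of a func.func).
        for region in op.regions.iter_mut() {
            rewritten += rewrite_lox_ops(region);
        }
    }
    rewritten
}
```

A count of zero on a second run is a cheap sanity check before handing the module to the conversion passes, which would otherwise fail on the leftover unregistered ops.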

Chapter 17: Linking Everything Together

Now we need to link the MLIR-generated code with our Rust runtime:

Step 1: Compile the Runtime

cd lox-mlir
cargo build --release

This produces target/release/liblox_runtime.a (static library) and target/release/liblox_runtime.so (shared library; liblox_runtime.dylib on macOS).

Step 2: Generate MLIR

// Compile Lox source to MLIR (inside a function that returns a Result)
let mlir = compile_to_mlir(source)?;
println!("{}", mlir);

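Output (the lox.* ops are written in a readable shorthand here and in the rest of this part; because they are unregistered, MLIR itself prints and parses them only in generic form, such as %0 = "lox.push_frame"() {root_count = 3 : i64} : () -> !llvm.ptr):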

module {
  func.func @example() -> f64 {
    %0 = lox.push_frame root_count = 3 : !llvm.ptr
    
    // Allocate a string object
    %1 = lox.alloc obj_type = 1, size = 5 : !llvm.ptr
    
    // Store as root 0
    lox.set_root index = 0, %1 : !llvm.ptr
    
    // ... function body ...
    
    lox.pop_frame
    return %result : f64
  }
}

Step 3: Lower to LLVM IR

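# Lower MLIR to LLVM IR. mlir-translate runs no passes, so its input must
# already be in the LLVM dialect (the output of our lowering pipeline).
mlir-translate lowered.mlir --mlir-to-llvmir -o output.ll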

Step 4: Compile to Object File

# Compile LLVM IR to object file
llc output.ll -filetype=obj -o output.o

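# Link with the runtime. Name the .a explicitly: with both liblox_runtime.a
# and liblox_runtime.so in target/release, -llox_runtime would pick the .so.
# Rust static libraries also need some system libraries; on Linux,
# cargo rustc --release -- --print native-static-libs lists the exact set.
clang output.o target/release/liblox_runtime.a -lpthread -ldl -lm -o output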
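The same steps can be driven from the compiler itself with std::process::Command. This sketch assumes that mlir-translate, llc, and clang are on PATH, that the module has already been lowered to the LLVM dialect, and that the Linux system libraries shown are sufficient; the file names are placeholders.

```rust
use std::io;
use std::process::Command;

/// Argument lists for the three external tools, in order.
fn pipeline(mlir: &str, out: &str, runtime_lib: &str) -> Vec<Vec<String>> {
    let ll = format!("{}.ll", out);
    let obj = format!("{}.o", out);
    vec![
        vec!["mlir-translate".into(), "--mlir-to-llvmir".into(), mlir.into(), "-o".into(), ll.clone()],
        vec!["llc".into(), ll, "-filetype=obj".into(), "-o".into(), obj.clone()],
        // The runtime .a is passed by path, not with -l.
        vec![
            "clang".into(), obj, runtime_lib.into(),
            "-lpthread".into(), "-ldl".into(), "-lm".into(), "-o".into(), out.into(),
        ],
    ]
}

/// Run each tool in turn, stopping at the first failure.
fn run_pipeline(steps: &[Vec<String>]) -> io::Result<()> {
    for step in steps {
        let status = Command::new(&step[0]).args(&step[1..]).status()?;
        if !status.success() {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                format!("{} failed: {}", step[0], status),
            ));
        }
    }
    Ok(())
}
```

Building the argument lists separately from running them keeps the pipeline easy to print (for a --verbose flag, say) and to test without the tools installed.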

Step 5: Run

./output

Chapter 17.5: Complete MLIR Example

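Let's trace through the complete compilation of a simple Lox function. To keep the trace short, the parameters are registered as roots even though they are plain numbers; a real implementation would root only values that can refer to heap objects, because the collector treats every root slot as a pointer.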

Input: Lox Source

fun add(a, b) {
    return a + b;
}

print add(1, 2);

Stage 1: MLIR (Lox Dialect)

module {
  // The 'add' function
  func.func @add(%arg0: f64, %arg1: f64) -> f64 {
    // Push shadow stack frame with 2 roots (for parameters)
    %frame = lox.push_frame root_count = 2 : !llvm.ptr
    
    // Register parameters as roots
    lox.set_root index = 0, %arg0 : f64
    lox.set_root index = 1, %arg1 : f64
    
    // The addition
    %sum = arith.addf %arg0, %arg1 : f64
    
    // Pop frame before return
    lox.pop_frame
    
    // Return
    return %sum : f64
  }
  
  // The main entry point
  func.func @main() -> i32 {
    %frame = lox.push_frame root_count = 0 : !llvm.ptr
    
    // Call add(1, 2)
    %one = arith.constant 1.0 : f64
    %two = arith.constant 2.0 : f64
    %result = func.call @add(%one, %two) : (f64, f64) -> f64
    
    // Print the result
    // (simplified - would call a print runtime function)
    
    lox.pop_frame
    %zero = arith.constant 0 : i32
    return %zero : i32
  }
}
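In add, the only exit is the final return, so one lox.pop_frame suffices. A function with an early return needs a pop on that path too, or the shadow stack is left one frame too deep. One way to guarantee this is to emit the pop as part of every return rather than once at the end of the body. The sketch below works over a toy statement list; the Stmt type, the %nil placeholder, and the op strings are illustrative.

```rust
/// A toy statement list: only returns and nesting matter here.
enum Stmt {
    Return(String), // return <SSA value>
    Other(String),  // any other op, emitted verbatim
    If { then_branch: Vec<Stmt>, else_branch: Vec<Stmt> },
}

/// Every return goes through here, so it always pops first.
fn emit_return(value: &str, out: &mut Vec<String>) {
    out.push("lox.pop_frame".to_string());
    out.push(format!("return {} : f64", value));
}

fn emit_stmts(stmts: &[Stmt], out: &mut Vec<String>) {
    for stmt in stmts {
        match stmt {
            Stmt::Return(value) => emit_return(value, out),
            Stmt::Other(op) => out.push(op.clone()),
            // Branch structure is elided; each arm gets the same handling.
            Stmt::If { then_branch, else_branch } => {
                emit_stmts(then_branch, out);
                emit_stmts(else_branch, out);
            }
        }
    }
}

fn emit_function(body: &[Stmt]) -> Vec<String> {
    let mut out = vec!["%frame = lox.push_frame".to_string()];
    emit_stmts(body, &mut out);
    // Implicit `return nil` when the body may fall off the end. Checking
    // only the last statement is conservative but never skips a pop.
    if !matches!(body.last(), Some(Stmt::Return(_))) {
        emit_return("%nil", &mut out);
    }
    out
}
```

Routing every exit through one helper turns "pop before return" into a property of the code generator rather than something each statement has to remember.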

Stage 2: After Lowering (LLVM Dialect)

module {
  llvm.func @add(%arg0: f64, %arg1: f64) -> f64 {
    // gc_push_frame(2)
    %c2 = llvm.mlir.constant(2 : i64) : i64
    %frame = llvm.call @gc_push_frame(%c2) : (i64) -> !llvm.ptr
    
    // lox.set_root index = 0: store into the first slot
    llvm.store %arg0, %frame : f64, !llvm.ptr
    
    // lox.set_root index = 1: one pointer-sized slot further on
    %slot1 = llvm.getelementptr %frame[1] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
    llvm.store %arg1, %slot1 : f64, !llvm.ptr
    
    // Addition
    %sum = llvm.fadd %arg0, %arg1 : f64
    
    // gc_pop_frame()
    llvm.call @gc_pop_frame() : () -> ()
    
    llvm.return %sum : f64
  }
  
  llvm.func @main() -> i32 {
    // ... similar lowering ...
    %c0 = llvm.mlir.constant(0 : i32) : i32
    llvm.return %c0 : i32
  }
  
  // External declarations for runtime
  llvm.func @gc_push_frame(i64) -> !llvm.ptr
  llvm.func @gc_pop_frame()
  llvm.func @lox_runtime_alloc(i8, i64) -> !llvm.ptr
}

Stage 3: LLVM IR

define double @add(double %0, double %1) {
entry:
  ; Push shadow stack frame
  %frame = call ptr @gc_push_frame(i64 2)
  
  ; Store parameters as roots (each slot is pointer-sized)
  store double %0, ptr %frame
  
  %root1_ptr = getelementptr ptr, ptr %frame, i64 1
  store double %1, ptr %root1_ptr
  
  ; Addition
  %sum = fadd double %0, %1
  
  ; Pop frame
  call void @gc_pop_frame()
  
  ret double %sum
}

define i32 @main() {
entry:
  %result = call double @add(double 1.0, double 2.0)
  ; ... print result ...
  ret i32 0
}

; External runtime functions
declare ptr @gc_push_frame(i64)
declare void @gc_pop_frame()
declare ptr @lox_runtime_alloc(i8, i64)

Stage 4: Assembly (x86-64, simplified)

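add:
    push   rbp
    mov    rbp, rsp
    sub    rsp, 16
    
    ; xmm0/xmm1 are caller-saved, so spill the arguments across the call
    movsd  [rbp - 8], xmm0       ; save %0
    movsd  [rbp - 16], xmm1      ; save %1
    
    ; gc_push_frame(2)
    mov    rdi, 2
    call   gc_push_frame         ; rax = pointer to the roots array
    
    ; Store roots
    movsd  xmm0, [rbp - 8]
    movsd  xmm1, [rbp - 16]
    movsd  [rax], xmm0           ; roots[0] = %0
    movsd  [rax + 8], xmm1       ; roots[1] = %1
    
    ; Addition
    addsd  xmm0, xmm1            ; %sum = %0 + %1
    movsd  [rbp - 8], xmm0       ; spill %sum across the next call
    
    ; gc_pop_frame()
    call   gc_pop_frame
    
    ; Return %sum in xmm0
    movsd  xmm0, [rbp - 8]
    mov    rsp, rbp
    pop    rbp
    ret

main:
    ; ...
    movsd  xmm0, [rip + one]     ; 1.0, loaded from a constant (label illustrative)
    movsd  xmm1, [rip + two]     ; 2.0
    call   add
    ; ...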

Chapter 17.6: Handling Different Object Types

Let's see how we generate code for different object types:

String Allocation

var s = "hello";

Generated MLIR:

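// Allocate string object (ObjType::String = 1, size = 5 character bytes)
%str = lox.alloc obj_type = 1, size = 5 : !llvm.ptr

// Initialize string data
// (would fill in length, hash, and characters; if length and hash are
// stored inline, size must cover them as well as the 5 characters)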

// Store as root
lox.set_root index = 0, %str : !llvm.ptr
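If the string object keeps its length and hash inline in front of the characters (an assumption here; Part 2's actual layout may differ), the size passed to lox.alloc covers more than the five characters. The sketch below computes that size for a hypothetical layout and includes 32-bit FNV-1a, the string hash clox uses, as one possible choice for the hash field.

```rust
use std::mem::size_of;

/// Hypothetical inline data for a Lox string. The ObjHeader sits in front
/// of this and is not counted, since lox.alloc's size excludes the header.
#[repr(C)]
struct StringData {
    length: u64,
    hash: u32,
    // followed by `length` bytes of characters
}

/// Bytes to request from lox.alloc for `s` under this layout.
fn string_alloc_size(s: &str) -> usize {
    size_of::<StringData>() + s.len()
}

/// 32-bit FNV-1a: XOR in each byte, then multiply by the FNV prime.
fn fnv1a(bytes: &[u8]) -> u32 {
    let mut hash: u32 = 2166136261; // FNV offset basis
    for &b in bytes {
        hash ^= b as u32;
        hash = hash.wrapping_mul(16777619); // FNV prime
    }
    hash
}
```

With repr(C), the u64 plus u32 header is padded to 16 bytes, so "hello" needs 21 bytes under this layout.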

Closure Allocation

fun makeCounter() {
    var count = 0;
    fun counter() {
        count = count + 1;
        return count;
    }
    return counter;
}

Generated MLIR (simplified):

func.func @makeCounter() -> !llvm.ptr {
    %frame = lox.push_frame root_count = 1 : !llvm.ptr

    // 1. Allocate environment for captured 'count'
    %env = lox.alloc obj_type = 2, size = 16 : !llvm.ptr
    // Environment layout: [enclosing, count_slots...]

    // Root %env now: the next lox.alloc may run a collection
    lox.set_root index = 0, %env : !llvm.ptr
    
    // 2. Initialize count = 0 in environment
    // (store at offset)
    
    // 3. Allocate closure
    %closure = lox.alloc obj_type = 3, size = 16 : !llvm.ptr
    // Closure layout: [function_index, environment_ptr]
    
    // 4. Link closure to environment
    // (store function index and env pointer)
    
    lox.pop_frame
    return %closure : !llvm.ptr
}
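The ordering in makeCounter hides a GC hazard: lox.alloc may trigger a collection, so the environment has to be reachable from a root before the closure is allocated. The toy heap below makes the failure concrete. Its index-based objects, root list, and "collect before every allocation" stress policy are all illustrative, not the Part 3 collector.

```rust
/// A toy heap: objects are indices; each live object lists its children.
struct Heap {
    objects: Vec<Option<Vec<usize>>>, // None = freed
    roots: Vec<usize>,
}

impl Heap {
    fn new() -> Self {
        Heap { objects: Vec::new(), roots: Vec::new() }
    }

    /// Allocate, running a full collection first (a stress-test policy).
    fn alloc(&mut self) -> usize {
        self.collect();
        self.objects.push(Some(Vec::new()));
        self.objects.len() - 1
    }

    fn is_live(&self, obj: usize) -> bool {
        self.objects[obj].is_some()
    }

    /// Mark everything reachable from the roots, then free the rest.
    fn collect(&mut self) {
        let mut marked = vec![false; self.objects.len()];
        let mut work = self.roots.clone();
        while let Some(obj) = work.pop() {
            if marked[obj] {
                continue;
            }
            marked[obj] = true;
            if let Some(children) = &self.objects[obj] {
                work.extend(children.iter().copied());
            }
        }
        for (obj, slot) in self.objects.iter_mut().enumerate() {
            if !marked[obj] {
                *slot = None;
            }
        }
    }
}

/// Mirror of makeCounter's allocations; returns (env, closure).
fn make_counter(heap: &mut Heap, root_env: bool) -> (usize, usize) {
    let env = heap.alloc(); // %env = lox.alloc ...
    if root_env {
        heap.roots.push(env); // lox.set_root index = 0, %env
    }
    let closure = heap.alloc(); // %closure = lox.alloc ... (collects first)
    heap.objects[closure] = Some(vec![env]); // link closure -> env
    if root_env {
        heap.roots.pop(); // lox.pop_frame
    }
    (env, closure)
}
```

Without the set_root, the collection inside the second allocation frees the environment, and the closure is then linked to a dead object. Collecting on every allocation is a common way to flush out exactly this kind of bug.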

Chapter 17.7: Practice Exercises

Exercise 1: Generate MLIR for a Function

Write the MLIR (Lox dialect) for this Lox function:

fun multiply(x, y) {
    var result = x * y;
    return result;
}
Answer:
func.func @multiply(%arg0: f64, %arg1: f64) -> f64 {
    // 3 roots: x, y, result
    %frame = lox.push_frame root_count = 3 : !llvm.ptr
    
    // Register parameters
    lox.set_root index = 0, %arg0 : f64
    lox.set_root index = 1, %arg1 : f64
    
    // Compute x * y
    %product = arith.mulf %arg0, %arg1 : f64
    
    // Allocate 'result' (if it's a heap object)
    // For simplicity, if result is just a number, we don't allocate
    // But if it's an object:
    %result_obj = lox.alloc obj_type = 0, size = 8 : !llvm.ptr
    lox.set_root index = 2, %result_obj : !llvm.ptr
    
    // Store product in result_obj
    // (would need additional operations)
    
    lox.pop_frame
    return %product : f64
}

Note: For simple numbers, we might not need heap allocation. The actual implementation would depend on your Lox value representation.

Exercise 2: Trace the Compilation Pipeline

Given this Lox code:

var x = 1;
var y = 2;
print x + y;

What does each stage produce?

Answer:

MLIR (Lox Dialect):

func.func @main() -> i32 {
    %frame = lox.push_frame root_count = 2 : !llvm.ptr
    
    // var x = 1
    %x_val = arith.constant 1.0 : f64
    lox.set_root index = 0, %x_val : f64
    
    // var y = 2
    %y_val = arith.constant 2.0 : f64
    lox.set_root index = 1, %y_val : f64
    
    // x + y
    %sum = arith.addf %x_val, %y_val : f64
    
    // print (simplified)
    // call print_runtime(%sum)
    
    lox.pop_frame
    %c0_i32 = arith.constant 0 : i32
    return %c0_i32 : i32
}

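After Lowering: lox.push_frame becomes a call to @gc_push_frame (a func.call after the Lox rewrite, an llvm.call once func-to-llvm has run), and likewise for the other lox.* ops.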

LLVM IR: Standard LLVM with calls to runtime functions.

Assembly: Native x86-64 or ARM code.

Exercise 3: Why Separate Dialects?

Why do we have a separate lox dialect instead of just generating LLVM IR directly?

Answer:
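  1. Abstraction Level: The lox dialect captures Lox semantics (allocation, GC roots) at a high level, so we can optimize at that level before lowering.

  2. Target Independence: The lox ops describe what the program needs, not how the runtime provides it, so they can be lowered to a different runtime interface, or through MLIR paths other than the LLVM dialect (for example SPIR-V for GPUs), without changing the front end.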

  3. Debugging: We can inspect the IR at each stage (Lox dialect → LLVM dialect → LLVM IR).

  4. Custom Optimizations: We can write passes that understand Lox semantics (e.g., eliminate unnecessary allocations).

  5. Incremental Lowering: We can do some optimizations in the Lox dialect, then lower to LLVM for target-specific work.


Checkpoint 3

We now have complete MLIR integration:

  • Lox dialect operations (lox.alloc, lox.push_frame, etc.)
  • Lowering to LLVM calls
  • Function code generation with shadow stack
  • Linking with Rust runtime
  • Complete MLIR example walkthrough
  • Different object types
  • Practice exercises

Files created:

  1. mlir-lox-guide-rust-part2.md - Concepts + allocation
  2. mlir-lox-guide-rust-part3.md - Roots + full GC
  3. mlir-lox-guide-rust-part4.md - MLIR integration (this file)

Next: Handling closures (the tricky part)