Part 5 — Imitation Learning
We have a working heuristic. We have a network. Now we connect them: generate training data from the heuristic, then train the network to predict the same moves.
The idea is called imitation learning or behavioral cloning. It’s supervised learning: for each board state, we want the network to predict the direction the heuristic chose. We’re training it to clone a teacher.
Collecting training data
We need a function that plays a game with the heuristic and records every state together with the move the heuristic chose:
use snake_ml::{encode_board, Board, Snake};
use crate::heuristic::HeuristicSnake;
use burn::tensor::{backend::Backend, Tensor};
use serde::{Deserialize, Serialize};

/// One training example: a board encoding and the teacher's chosen direction
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingExample {
    /// Flat encoded board state: (8 * height * width,) f32
    pub board: Vec<f32>,
    /// Direction chosen by the teacher (0=up, 1=down, 2=left, 3=right)
    pub label: u32,
}

/// Run one game with the heuristic teacher and collect all examples.
///
/// Note: `advance_snakes` is a simplified game simulation. A full implementation
/// needs to handle movement, food eating, body growth, starvation, and collision
/// detection. The `GameEnv` in Part 6 shows a complete implementation you can use
/// instead of this placeholder.
pub fn collect_game<B: Backend>(
    board: &Board,
    heuristic: &HeuristicSnake,
    device: &B::Device,
) -> Vec<TrainingExample> {
    let mut examples = Vec::new();
    let mut snakes = board.snakes.clone();
    let mut turn = 0u32;
    // Simulate up to 200 turns, stopping once any snake runs out of health
    while turn < 200 && snakes.iter().all(|s| s.health > 0) {
        // Snapshot the current state once per turn; every snake decides from it
        let board_snapshot = Board {
            width: board.width,
            height: board.height,
            food: board.food.clone(),
            snakes: snakes.clone(),
            hazards: board.hazards.clone(),
        };
        // One example per snake, encoded from that snake's perspective
        for my_snake in &board_snapshot.snakes {
            let dir = heuristic.decide(my_snake, &board_snapshot);
            examples.push(TrainingExample {
                board: encode_board_flat::<B>(&board_snapshot, my_snake, device),
                label: dir.index() as u32,
            });
        }
        // Advance the game one step (placeholder, see below)
        snakes = advance_snakes(snakes, board);
        turn += 1;
    }
    examples
}

/// Encode + flatten in one step for training data
fn encode_board_flat<B: Backend>(board: &Board, my_snake: &Snake, device: &B::Device) -> Vec<f32> {
    let tensor = encode_board::<B>(board, my_snake, device);
    // Collapse the 8 feature planes of height x width into a single row
    let len = 8 * board.height as usize * board.width as usize;
    let flat: Tensor<B, 1> = tensor.reshape([len]);
    flat.to_data().to_vec::<f32>().unwrap()
}

/// Advance snakes by one turn. This is a minimal placeholder: it only
/// decrements health. For real training data, use the `GameEnv` from Part 6,
/// which handles movement, food eating, body growth, wall collisions, and
/// snake-on-snake collisions correctly.
fn advance_snakes(snakes: Vec<Snake>, _board: &Board) -> Vec<Snake> {
    snakes.into_iter().filter_map(|mut s| {
        if s.health <= 0 { return None; } // drop snakes that are already dead
        s.health = s.health.saturating_sub(1);
        Some(s)
    }).collect()
}
In practice you’d generate thousands of games across different board configurations and write every example to a JSONL file (one JSON object per line). JSONL is easy to append to while games are being generated and can be parsed one line at a time during training.
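To make the file format concrete, here is a minimal, dependency-free sketch of writing and streaming JSONL. The helpers `to_jsonl_line`, `write_jsonl`, and `count_examples` are hypothetical, not part of `snake_ml`; with the serde derives on `TrainingExample` you would normally call `serde_json::to_string(&example)` instead of formatting by hand. The field names match the struct, so the lines below should deserialize back into `TrainingExample`.

```rust
use std::io::{self, BufRead, Write};

/// Format one example as a single JSON object on one line.
/// Hand-formatted to keep the sketch dependency-free; field names match
/// `TrainingExample` so serde_json can read the lines back.
pub fn to_jsonl_line(board: &[f32], label: u32) -> String {
    let cells: Vec<String> = board.iter().map(|v| format!("{:?}", v)).collect();
    format!("{{\"board\":[{}],\"label\":{}}}", cells.join(","), label)
}

/// Append (board, label) pairs to any writer, e.g. a BufWriter<File>, one line each.
pub fn write_jsonl<W: Write>(out: &mut W, examples: &[(Vec<f32>, u32)]) -> io::Result<()> {
    for (board, label) in examples {
        writeln!(out, "{}", to_jsonl_line(board, *label))?;
    }
    Ok(())
}

/// Stream a JSONL source line by line, skipping blank lines, and count examples.
pub fn count_examples<R: BufRead>(input: R) -> io::Result<usize> {
    let mut n = 0;
    for line in input.lines() {
        if !line?.trim().is_empty() {
            n += 1;
        }
    }
    Ok(n)
}
```

Because each line is self-contained, a generator can keep appending while a separate process counts or reads what has been written so far.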
The training loop
Here’s the core training loop. We read the training data, initialize the model, and optimize it with AdamW (Adam with decoupled weight decay, a stochastic gradient method):
use burn::backend::Autodiff;
use burn::module::Module;
use burn::optim::{AdamWConfig, GradientsParams, Optimizer};
use burn::record::{DefaultFileRecorder, FullPrecisionSettings};
use burn::tensor::loss::cross_entropy_with_logits;
use burn::tensor::{Tensor, TensorData};
use rand::seq::SliceRandom;
use snake_ml::{SnakeNet, TrainingExample};

type Backend = Autodiff<burn::backend::Cpu>;

/// Input size for an 11x11 board with 8 feature planes
const INPUT_SIZE: usize = 8 * 11 * 11;
/// Output classes: up, down, left, right
const NUM_DIRECTIONS: usize = 4;

pub fn train() {
    let lr = 1e-3;

    // Load training data (JSONL: one JSON object per line)
    let examples: Vec<TrainingExample> = std::fs::read_to_string("training_data.jsonl")
        .expect("failed to read training data")
        .lines()
        .filter(|line| !line.trim().is_empty())
        .map(|line| serde_json::from_str(line).expect("malformed training example"))
        .collect();
    println!("Loaded {} training examples", examples.len());

    // Initialize model (mut: the optimizer returns an updated network each step)
    let device = <Backend as burn::tensor::backend::Backend>::Device::default();
    let mut net = SnakeNet::<Backend>::new(&device);
    let mut optim = AdamWConfig::new().init::<Backend, SnakeNet<Backend>>();

    let batch_size = 64;
    let epochs = 20;
    // Round up so the final, smaller batch is not silently dropped
    let num_batches = examples.len().div_ceil(batch_size);

    for epoch in 0..epochs {
        // Shuffle every epoch. Consecutive examples come from the same game and
        // are highly correlated; shuffling mixes positions from many games into
        // each batch and changes the batch order from one epoch to the next.
        let mut indices: Vec<usize> = (0..examples.len()).collect();
        indices.shuffle(&mut rand::thread_rng());

        let mut total_loss = 0.0_f64;
        for batch_idx in 0..num_batches {
            let start = batch_idx * batch_size;
            let end = (start + batch_size).min(examples.len());
            let batch = &indices[start..end];
            let bs = batch.len();

            // Inputs: (bs, INPUT_SIZE), row-major
            let xs: Vec<f32> = batch.iter()
                .flat_map(|&i| examples[i].board.iter().copied())
                .collect();
            // Targets: one-hot probabilities (bs, 4), 1.0 at the teacher's label
            let mut targets = vec![0.0_f32; bs * NUM_DIRECTIONS];
            for (row, &i) in batch.iter().enumerate() {
                targets[row * NUM_DIRECTIONS + examples[i].label as usize] = 1.0;
            }
            let xs_t = Tensor::<Backend, 2>::from_data(TensorData::new(xs, [bs, INPUT_SIZE]), &device);
            let targets_t = Tensor::<Backend, 2>::from_data(
                TensorData::new(targets, [bs, NUM_DIRECTIONS]),
                &device,
            );

            // Forward pass → cross-entropy loss
            // cross_entropy_with_logits applies log-softmax internally
            let logits = net.forward(xs_t);
            let loss = cross_entropy_with_logits(logits, targets_t);

            // Read the scalar loss value for logging
            let loss_val = loss.to_data().to_vec::<f32>().unwrap()[0] as f64;

            // Backward pass
            let grads = loss.backward();
            let grads = GradientsParams::from_grads(grads, &net);
            // optim.step returns the updated network. Reassigning `net`
            // makes the next batch use the updated weights.
            net = optim.step(lr, net, grads);
            total_loss += loss_val;
        }
        let avg = total_loss / num_batches.max(1) as f64;
        println!("Epoch {}: avg loss = {:.4}", epoch, avg);
    }

    // Save the trained model using the recorder pattern
    let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
    net.save_file("snake_model", &recorder)
        .expect("failed to save model");
    println!("Model saved to snake_model");
}
cross_entropy_with_logits is Burn’s free-function form of cross-entropy loss: it takes raw logits and a target probability tensor of the same shape, applies log-softmax internally, and returns the cross-entropy averaged over the batch. Because our labels are class indices, we pass the network’s logits together with a one-hot probability tensor built from the labels. For 4-class classification, this is the standard loss function.
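Turning integer labels into that one-hot probability tensor is a small, dependency-free step. The sketch below (the function name `one_hot_targets` is hypothetical) produces a flat, row-major buffer of shape (batch, 4) with the label order 0=up, 1=down, 2=left, 3=right, which could then be wrapped in a tensor of shape `[batch, 4]`.

```rust
/// Number of move classes: up, down, left, right.
pub const NUM_DIRECTIONS: usize = 4;

/// Turn class-index labels into a flat, row-major (batch, 4) one-hot matrix.
/// Row i has 1.0 in column labels[i] and 0.0 everywhere else.
pub fn one_hot_targets(labels: &[u32]) -> Vec<f32> {
    let mut probs = vec![0.0_f32; labels.len() * NUM_DIRECTIONS];
    for (row, &label) in labels.iter().enumerate() {
        assert!((label as usize) < NUM_DIRECTIONS, "label out of range: {}", label);
        probs[row * NUM_DIRECTIONS + label as usize] = 1.0;
    }
    probs
}
```

Each row sums to 1, so it is a valid probability distribution that puts all of its mass on the teacher’s move.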
The loss function intuition
Cross-entropy loss measures how far the predicted probabilities are from the target distribution. For an example where the teacher picked “up”, the target distribution is (1.0, 0.0, 0.0, 0.0) over (up, down, left, right), so that example’s loss reduces to −log p(up). The loss shrinks as the network assigns more probability to “up”.
Minimizing cross-entropy = maximizing the probability assigned to the teacher’s choice. That’s exactly what we want: the network learns to pick the same direction the teacher would.
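The same computation for a single example can be written out in plain Rust: a log-softmax over the four logits, then the negative log-probability of the teacher’s direction. This is a sketch for intuition (the name `cross_entropy_one_hot` is hypothetical), using the usual max-subtraction trick for numerical stability.

```rust
/// Cross-entropy for one example with a one-hot target, from raw logits.
/// Equivalent to -log(softmax(logits)[label]).
pub fn cross_entropy_one_hot(logits: &[f32], label: usize) -> f32 {
    // Subtract the largest logit so exp() cannot overflow
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let sum_exp: f32 = logits.iter().map(|&z| (z - max).exp()).sum();
    // log p(label) = z_label - logsumexp(z)
    let log_p = logits[label] - max - sum_exp.ln();
    -log_p
}
```

With uniform logits every direction gets probability 0.25, so the loss is ln 4 ≈ 1.386. That is a useful reference point for the `avg loss` printed by the training loop: a freshly initialized network typically starts near that value, and a falling loss means the network is concentrating probability on the teacher’s moves.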
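The training signal is strong because the heuristic makes consistent choices: similar boards receive the same label, so the targets rarely contradict each other. The network can therefore fit the heuristic’s patterns directly.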
What the network learns
After training, the network’s forward pass on a new board state produces four logits, one per direction, which a softmax turns into a probability distribution. If training went well, the network has generalized: rather than memorizing what the heuristic did in specific situations, it has learned patterns that carry over to boards it never saw.
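Going from logits to a move can be sketched as a softmax followed by an argmax. The `Move` enum and `pick_move` below are hypothetical stand-ins for the crate’s `Direction` and `pick_direction`; the source does not show how `pick_direction` is implemented, so this assumes it greedily picks the most probable direction.

```rust
/// Hypothetical stand-in for `Direction`, in training label order.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Move {
    Up,
    Down,
    Left,
    Right,
}

const MOVES: [Move; 4] = [Move::Up, Move::Down, Move::Left, Move::Right];

/// Numerically stable softmax: subtract the largest logit before exponentiating.
pub fn softmax(logits: &[f32; 4]) -> [f32; 4] {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut probs = [0.0_f32; 4];
    for (p, &z) in probs.iter_mut().zip(logits) {
        *p = (z - max).exp();
    }
    let sum: f32 = probs.iter().sum();
    for p in probs.iter_mut() {
        *p /= sum;
    }
    probs
}

/// Greedy policy: return the most probable move and its probability.
pub fn pick_move(logits: &[f32; 4]) -> (Move, f32) {
    let probs = softmax(logits);
    let (best, &p) = probs
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .expect("softmax always returns four values");
    (MOVES[best], p)
}
```

Returning the probability alongside the move makes it easy to spot boards where the network is unsure, which is often where it disagrees with the teacher.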
The network might learn that:
- Food in a particular direction → higher probability for that direction
- Enemy nearby → avoid it
- Narrow corridor → prefer the wider path
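Some of these patterns correspond to rules the heuristic encodes explicitly. Others may be implicit: regularities in the heuristic’s choices that no single rule states, but that the network picks up from the data anyway.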
Evaluation
Measure how often the trained model picks the same move as the heuristic on freshly generated boards:
use burn::backend::Autodiff;
use burn::tensor::Tensor;
use snake_ml::{encode_board, SnakeNet};
use crate::heuristic::HeuristicSnake;

type Backend = Autodiff<burn::backend::Cpu>;

/// Evaluate how often the model agrees with the heuristic.
/// Each of the `games` trials draws a fresh board and compares one decision.
/// Agreement rate of 80–90% is good: the network has learned the teacher's patterns.
pub fn evaluate(model: &SnakeNet<Backend>, heuristic: &HeuristicSnake, games: usize) -> f32 {
    let device = <Backend as burn::tensor::backend::Backend>::Device::default();
    let mut agreement = 0;
    for _ in 0..games {
        let board = random_board(11, 11, 1); // generate a test board (see Part 6)
        let my_snake = &board.snakes[0];
        let tensor = encode_board::<Backend>(&board, my_snake, &device);
        // Add a batch dimension: (1, 8 * 11 * 11)
        let flat: Tensor<Backend, 2> = tensor.reshape([1, 8 * 11 * 11]);
        let model_dir = model.pick_direction(flat); // SnakeNet<Backend>::pick_direction returns a Direction
        let heur_dir = heuristic.decide(my_snake, &board);
        // Direction derives PartialEq: direct comparison, no string matching.
        if model_dir == heur_dir {
            agreement += 1;
        }
    }
    agreement as f32 / games as f32
}
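At this stage, the network should agree with the heuristic 80–90% of the time. Lower agreement means the network hasn’t learned the heuristic’s patterns yet. Higher agreement means it behaves almost exactly like the teacher. Keep in mind that agreement measures how closely the network copies the heuristic, not how well it plays.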
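A single agreement number hides which moves the network gets wrong. One possible extension, not part of the original `evaluate`, is to record `(teacher, model)` index pairs (for example via the `index()` method the collection code already calls on `Direction`) and tally them in a 4×4 confusion matrix. The helpers below are hypothetical and use only the standard library.

```rust
/// 4x4 confusion counts: rows = teacher's move, columns = model's move,
/// using the label order 0=up, 1=down, 2=left, 3=right.
pub fn confusion_matrix(pairs: &[(usize, usize)]) -> [[u32; 4]; 4] {
    let mut m = [[0u32; 4]; 4];
    for &(teacher, model) in pairs {
        m[teacher][model] += 1;
    }
    m
}

/// Overall agreement: diagonal total divided by all decisions.
pub fn agreement(m: &[[u32; 4]; 4]) -> f32 {
    let total: u32 = m.iter().flatten().sum();
    let diag: u32 = (0..4).map(|i| m[i][i]).sum();
    if total == 0 { 0.0 } else { diag as f32 / total as f32 }
}

/// Per-direction agreement: of the times the teacher chose direction d,
/// how often did the model choose it too? None if the teacher never chose d.
pub fn per_direction_agreement(m: &[[u32; 4]; 4]) -> [Option<f32>; 4] {
    let mut out = [None; 4];
    for d in 0..4 {
        let row: u32 = m[d].iter().sum();
        if row > 0 {
            out[d] = Some(m[d][d] as f32 / row as f32);
        }
    }
    out
}
```

If overall agreement looks fine but one direction’s row is weak, that usually points to a gap in the training data (too few examples of that move) rather than a problem with the network as a whole.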
Next: we stop using the heuristic as a teacher and let the snake learn on its own.