Part 5 — Imitation Learning
We have a working heuristic. We have a network. Now we connect them: generate training data from the heuristic, then train the network to predict the same moves.
This approach is called imitation learning; the simple form used here, where the teacher’s decisions become training labels, is known as behavioral cloning. It’s supervised learning: for each board state, we want the network to predict the direction the heuristic chose. We’re training it to clone a teacher.
Collecting training data
We need a function that runs a game with the heuristic and records each state:
use serde::{Deserialize, Serialize};
use snake_ml::{encode_board, Board, Snake};

use crate::heuristic::HeuristicSnake;

/// One training example: a board encoding and the teacher's chosen direction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingExample {
    /// Flat encoded board state: 8 * height * width values of type f32.
    pub board: Vec<f32>,
    /// Direction chosen by the teacher (0=up, 1=down, 2=left, 3=right).
    pub label: u32,
}

/// Run one game with the heuristic teacher and collect all examples.
///
/// Note: `advance_snakes` is a placeholder, not a game simulation. A full
/// implementation needs to handle movement, food eating, body growth,
/// starvation, and collision detection. The `GameEnv` in Part 6 shows a
/// complete implementation you can use instead.
pub fn collect_game(board: &Board, heuristic: &HeuristicSnake) -> Vec<TrainingExample> {
    let mut examples = Vec::new();
    let mut snakes = board.snakes.clone();
    let mut turn = 0u32;

    // Simulate up to 200 turns; stop early once any snake runs out of health
    // or no snakes are left.
    while turn < 200 && !snakes.is_empty() && snakes.iter().all(|s| s.health > 0) {
        // All snakes decide from the same snapshot, so build it once per turn.
        let board_snapshot = Board {
            width: board.width,
            height: board.height,
            food: board.food.clone(),
            snakes: snakes.clone(),
            hazards: board.hazards.clone(),
        };
        for my_snake in &snakes {
            let dir = heuristic.decide(my_snake, &board_snapshot);
            examples.push(TrainingExample {
                board: encode_board_flat(&board_snapshot, my_snake),
                // Relies on `Direction` discriminants matching the label order above.
                label: dir as u32,
            });
        }

        // Advance the game one step (placeholder; see `advance_snakes`).
        snakes = advance_snakes(snakes, board);
        turn += 1;
    }
    examples
}

/// Encode + flatten in one step for training data.
fn encode_board_flat(board: &Board, my_snake: &Snake) -> Vec<f32> {
    let tensor = encode_board(board, my_snake).unwrap();
    let flat = tensor.flatten_all().unwrap();
    flat.to_vec1::<f32>().unwrap()
}

/// Advance snakes by one turn. This is a minimal placeholder: it only
/// decrements health and ignores the board. For real training data, use the
/// `GameEnv` from Part 6, which handles movement, food eating, body growth,
/// wall collisions, and snake-on-snake collisions correctly.
fn advance_snakes(snakes: Vec<Snake>, _board: &Board) -> Vec<Snake> {
    snakes
        .into_iter()
        .filter_map(|mut s| {
            if s.health <= 0 {
                return None;
            }
            s.health = s.health.saturating_sub(1);
            Some(s)
        })
        .collect()
}
In practice you’d generate thousands of games across different board configurations. The data goes into a JSONL file (one JSON object per line) for efficient reading during training.
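Before training on thousands of games, it can be worth checking how the teacher’s labels are distributed. If the heuristic almost never picks one direction, the network sees very few examples of it. The sketch below is not part of the companion code: `Example`, `label_counts`, and `label_fractions` are hypothetical names, and `Example` is a stand-in that keeps only the `label` field of `TrainingExample`.

```rust
/// Hypothetical stand-in for `TrainingExample`; only the label matters here.
struct Example {
    label: u32,
}

/// Count how often each direction (0=up, 1=down, 2=left, 3=right) is the label.
/// Labels outside 0..4 are skipped rather than treated as an error.
fn label_counts(examples: &[Example]) -> [usize; 4] {
    let mut counts = [0usize; 4];
    for ex in examples {
        if let Some(slot) = counts.get_mut(ex.label as usize) {
            *slot += 1;
        }
    }
    counts
}

/// Convert counts to fractions of the total; all zeros for an empty dataset.
fn label_fractions(counts: [usize; 4]) -> [f64; 4] {
    let total: usize = counts.iter().sum();
    if total == 0 {
        return [0.0; 4];
    }
    counts.map(|c| c as f64 / total as f64)
}

fn main() {
    // Made-up labels for illustration, not real teacher output.
    let examples: Vec<Example> = vec![0u32, 0, 3, 1, 0, 2, 3, 0]
        .into_iter()
        .map(|label| Example { label })
        .collect();

    let counts = label_counts(&examples);
    let fractions = label_fractions(counts);
    let names = ["up", "down", "left", "right"];
    for (i, name) in names.iter().enumerate() {
        println!("{:<5} {:>3} examples ({:.1}%)", name, counts[i], fractions[i] * 100.0);
    }
}
```

A heavily skewed distribution is not necessarily a bug (snakes on some boards really do move one way more often), but it is useful context when reading the agreement numbers later.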
The training loop
Here’s what training means at a high level before we look at the code:
We show the network many (board state, correct direction) pairs. For each pair, we ask the network: “what direction would you pick?” Then we compare its answer to the correct answer (which we got from the heuristic). If it’s wrong, we adjust the weights to make it more likely to be right next time. We repeat this for 20 full passes over the data. Each pass is called an epoch.
We use a loss function called cross-entropy to measure how far off the network’s prediction is from the correct answer. Minimizing cross-entropy is equivalent to maximizing the probability the network assigns to the heuristic’s choice — so the network learns to pick the same direction the heuristic would.
Here’s the code:
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::loss::cross_entropy;
use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};
use rand::seq::SliceRandom;
use snake_ml::{SnakeNet, TrainingExample};

pub fn train() -> Result<()> {
    let device = Device::Cpu;
    let lr = 1e-3;

    // Load training data (JSONL: one JSON object per line).
    // `candle_core::Result` takes a single type parameter, so the collect
    // target is spelled `std::result::Result`, and the serde_json error is
    // wrapped into a candle error before `?`.
    let examples: Vec<TrainingExample> = std::fs::read_to_string("training_data.jsonl")?
        .lines()
        .filter(|line| !line.is_empty())
        .map(|line| serde_json::from_str::<TrainingExample>(line))
        .collect::<std::result::Result<_, _>>()
        .map_err(candle_core::Error::wrap)?;
    println!("Loaded {} training examples", examples.len());

    // Initialize model
    let var_map = VarMap::new();
    let vs = VarBuilder::from_varmap(&var_map, DType::F32, &device);
    let net = SnakeNet::new(vs)?;
    let mut opt = AdamW::new(
        var_map.all_vars(),
        ParamsAdamW { lr, ..Default::default() },
    )?;
    // Note: the companion code in snake-ml/src/lib.rs uses plain SGD
    // (stochastic gradient descent) instead of AdamW and packs the training
    // step into Model::train_batch(). The math is the same (cross-entropy +
    // gradient descent); the differences are the optimizer and the API
    // surface. AdamW adapts the step size per parameter and often converges
    // in fewer epochs; SGD keeps the companion's single-method API simpler.

    let batch_size = 64;
    let epochs = 20;
    // Integer division: a trailing partial batch is dropped each epoch.
    let num_batches = examples.len() / batch_size;
    if num_batches == 0 {
        // Without this check the average loss below would be 0/0.
        return Err(candle_core::Error::Msg(format!(
            "need at least {} training examples, found {}",
            batch_size,
            examples.len()
        )));
    }

    for epoch in 0..epochs {
        // Shuffle the dataset each epoch. Consecutive examples come from
        // consecutive turns of the same game, so unshuffled batches would be
        // highly correlated and every epoch would see them in the same order.
        let mut indices: Vec<usize> = (0..examples.len()).collect();
        indices.shuffle(&mut rand::thread_rng());

        let mut total_loss = 0.0_f64;
        for batch_idx in 0..num_batches {
            let start = batch_idx * batch_size;
            let batch: &[usize] = &indices[start..start + batch_size];
            let xs: Vec<f32> = batch
                .iter()
                .flat_map(|&i| examples[i].board.iter().copied())
                .collect();
            let targets: Vec<u32> = batch.iter().map(|&i| examples[i].label).collect();

            // 968 = 8 channels × 11 × 11 (the default board size from Part 2)
            let xs_t = Tensor::from_vec(xs, (batch.len(), 968), &device)?;
            let targets_t = Tensor::from_vec(targets, (batch.len(),), &device)?;

            // Forward pass → raw logits (cross_entropy applies softmax itself)
            let logits = net.forward(&xs_t)?;
            let loss = cross_entropy(&logits, &targets_t)?;

            // Backward pass and optimizer update in one call
            opt.backward_step(&loss)?;
            total_loss += loss.to_scalar::<f32>()? as f64;
        }
        let avg = total_loss / num_batches as f64;
        println!("Epoch {}: avg loss = {:.4}", epoch, avg);
    }

    // Save the trained weights
    var_map.save("snake_model.safetensors")?;
    println!("Model saved to snake_model.safetensors");
    Ok(())
}
The loss function intuition
Cross-entropy loss measures how far the predicted probabilities are from the target distribution. When the teacher picks “up” for a given board, the target distribution for that example is (1.0, 0.0, 0.0, 0.0) for (up, down, left, right). The loss is smaller when the network also predicts “up” with high probability.
Minimizing cross-entropy = maximizing the probability assigned to the teacher’s choice. That’s exactly what we want: the network learns to pick the same direction the teacher would.
cross_entropy from candle_nn::loss handles the softmax internally — we pass in raw logits, it applies softmax, then computes cross-entropy against the target labels. For 4-class classification, this is the standard loss function.
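To make the numbers concrete, here is a small standalone sketch of the same calculation for a single example, using only the standard library. It is an illustration, not candle’s implementation: `softmax` and `cross_entropy_one` are hypothetical helper names, and the logits are made up. With a one-hot target, cross-entropy reduces to the negative log of the probability assigned to the teacher’s choice.

```rust
/// Numerically stable softmax over four logits (subtract the max before exp).
fn softmax(logits: [f32; 4]) -> [f32; 4] {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps = logits.map(|z| (z - max).exp());
    let sum: f32 = exps.iter().sum();
    exps.map(|e| e / sum)
}

/// Cross-entropy for one example with a one-hot target: -ln(p[label]).
fn cross_entropy_one(logits: [f32; 4], label: usize) -> f32 {
    -softmax(logits)[label].ln()
}

fn main() {
    let label = 0; // the teacher picked "up"

    // Made-up logits for three situations.
    let confident_right = [4.0, 0.0, 0.0, 0.0]; // network strongly prefers "up"
    let unsure = [0.0, 0.0, 0.0, 0.0]; // uniform: 25% per direction
    let confident_wrong = [0.0, 4.0, 0.0, 0.0]; // network strongly prefers "down"

    println!("agrees with teacher:    {:.4}", cross_entropy_one(confident_right, label));
    println!("uniform guess:          {:.4}", cross_entropy_one(unsure, label));
    println!("disagrees with teacher: {:.4}", cross_entropy_one(confident_wrong, label));
}
```

The uniform case is a handy reference point: a network that has learned nothing scores ln 4 ≈ 1.386 per example, so an average loss near that value in the first epoch is expected, and the loss should fall below it as training progresses.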
The training signal is strong because the heuristic makes consistent choices. The network learns the heuristic’s patterns.
What the network learns
After training, the network’s forward pass on a new board state produces four logits, which softmax turns into a probability distribution over directions. If training worked, the network has generalized: rather than memorizing what the heuristic did in specific situations, it has learned patterns that transfer to boards it has never seen.
If training is working well, you might observe the network doing the following:
- Food in a particular direction → higher probability for that direction
- Enemy nearby → avoid it
- Narrow corridor → prefer the wider path
Some of these patterns are explicit rules in the heuristic. Others are never written down as rules but still show up in the heuristic’s choices, and the network can pick them up from the examples. If the network isn’t learning (loss stays flat, agreement with the heuristic doesn’t improve), try generating more training data or training for more epochs.
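Generalization can only be judged on data the network never trained on. One simple way to get such data is to hold out whole games rather than individual turns, because consecutive turns of one game are nearly identical and would otherwise leak across the split. The sketch below assumes examples are kept grouped per game (for instance, one `Vec` per `collect_game` call); `split_by_game` and `holdout_every` are hypothetical names, not part of the companion code.

```rust
/// Split per-game example lists into (train, validation).
/// Every `holdout_every`-th game goes to validation, so turns from a single
/// game never end up on both sides of the split.
fn split_by_game<T>(games: Vec<Vec<T>>, holdout_every: usize) -> (Vec<T>, Vec<T>) {
    assert!(holdout_every >= 2, "holdout_every must be at least 2");
    let mut train = Vec::new();
    let mut validation = Vec::new();
    for (game_idx, examples) in games.into_iter().enumerate() {
        if game_idx % holdout_every == holdout_every - 1 {
            validation.extend(examples);
        } else {
            train.extend(examples);
        }
    }
    (train, validation)
}

fn main() {
    // Ten fake games of three "examples" each; the numbers stand in for
    // real training examples.
    let games: Vec<Vec<u32>> = (0..10)
        .map(|g| vec![g * 10, g * 10 + 1, g * 10 + 2])
        .collect();

    // Hold out every fifth game (games 4 and 9).
    let (train, validation) = split_by_game(games, 5);
    println!(
        "{} training examples, {} validation examples",
        train.len(),
        validation.len()
    );
}
```

Training on the first set and computing the loss on the second after each epoch gives a direct signal: if training loss keeps falling while validation loss stalls or rises, the network is memorizing rather than generalizing.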
Evaluation
Compare the trained model with the heuristic on boards it did not see during training, and measure how often the two pick the same direction:
use snake_ml::{encode_board, Model};

use crate::heuristic::HeuristicSnake;

/// Evaluate how often the model agrees with the heuristic.
/// Agreement rate of 80–90% is good: the network has learned the teacher's patterns.
pub fn evaluate(model: &Model, heuristic: &HeuristicSnake, games: usize) -> f32 {
    if games == 0 {
        return 0.0; // nothing to compare; avoids 0/0 below
    }
    let mut agreement = 0;
    for _ in 0..games {
        // `random_board` generates a test board (see Part 6); bring it into
        // scope from wherever you define it.
        let board = random_board(11, 11, 1);
        let my_snake = &board.snakes[0];

        // Encode the board as a batch of one: shape (1, 968).
        let tensor = encode_board(&board, my_snake).unwrap();
        let flat = tensor.flatten_all().unwrap().unsqueeze(0).unwrap();

        let model_dir = model.pick_direction(&flat).unwrap();
        let heur_dir = heuristic.decide(my_snake, &board);

        // Direction derives PartialEq: direct comparison, no string matching.
        if model_dir == heur_dir {
            agreement += 1;
        }
    }
    agreement as f32 / games as f32
}
At this stage, the network should agree with the heuristic 80–90% of the time. Lower agreement means the network hasn’t learned the heuristic’s patterns yet. Higher agreement means its choices are nearly indistinguishable from the teacher’s.
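A single agreement percentage hides where the disagreements happen. One option, sketched below under the assumption that you record each comparison as a `(teacher, model)` pair of direction indices, is to tally them in a 4×4 table. `confusion` and `agreement_for` are hypothetical helpers, and the pairs in `main` are made up.

```rust
/// Tally (teacher, model) direction pairs into a 4x4 table.
/// Rows are the teacher's choice, columns the model's
/// (0=up, 1=down, 2=left, 3=right). Panics if an index is outside 0..4.
fn confusion(pairs: &[(usize, usize)]) -> [[usize; 4]; 4] {
    let mut table = [[0usize; 4]; 4];
    for &(teacher, model) in pairs {
        table[teacher][model] += 1;
    }
    table
}

/// Agreement rate on the turns where the teacher chose `dir`,
/// or `None` if the teacher never chose it.
fn agreement_for(table: &[[usize; 4]; 4], dir: usize) -> Option<f32> {
    let total: usize = table[dir].iter().sum();
    if total == 0 {
        None
    } else {
        Some(table[dir][dir] as f32 / total as f32)
    }
}

fn main() {
    // Made-up comparison results: (teacher direction, model direction).
    let pairs = [(0, 0), (0, 0), (0, 1), (3, 3), (2, 3)];
    let table = confusion(&pairs);

    let names = ["up", "down", "left", "right"];
    for (dir, name) in names.iter().enumerate() {
        match agreement_for(&table, dir) {
            Some(rate) => println!("teacher={:<5} agreement {:.0}%", name, rate * 100.0),
            None => println!("teacher={:<5} no examples", name),
        }
    }
}
```

If one row shows much lower agreement than the others, the network is probably short on examples where the teacher chose that direction, which points back at data collection rather than at the training loop.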
Next: we stop using the heuristic as a teacher and let the snake learn on its own.