Part 7 — Self-Play Training

REINFORCE works. A single snake can learn from its own experience — eat food, get a reward, make that move more likely next time. But there’s a problem: the snake is playing against an environment that doesn’t fight back.

Real BattleSnake has opponents. Opponents that block your path, steal your food, and chase you into corners. Training against an empty board teaches you to eat. It doesn’t teach you to compete.

Self-play fixes this. The snake trains against itself — or more precisely, against past versions of itself. Every time it gets stronger, its opponent gets stronger too. The ceiling keeps rising.

Why self-play works

Here’s the intuition. Imagine you’re learning chess. Playing against a wall (someone who never moves) teaches you the rules. Playing against a weak opponent teaches you basic tactics. Playing against someone slightly better than you teaches you to see your own mistakes.

Self-play is that last one, automated. The key insight: the snake always trains against a recent version of itself. When it discovers a new strategy, the opponent (an older copy of the same network) doesn’t know that strategy yet. The new strategy wins. The network updates. Now the next round of training faces a network that does know that strategy, so it has to discover a counter. And so on.

This is how AlphaGo trained. It’s how AlphaStar trained. The idea these systems share: you are your own curriculum, and the difficulty adjusts automatically.

The training loop

The self-play loop has four steps, repeated many times:

  1. Sample an opponent. Pick a past checkpoint: the current network, or a saved snapshot from an earlier point in training.
  2. Play episodes. Run games where the training snake faces the opponent snake. Record the experience: states, actions, rewards.
  3. Update the training snake. Run REINFORCE on the collected experience. The training snake’s weights change; the opponent’s weights stay frozen.
  4. Evaluate and checkpoint. Periodically pit the current snake against the previous best. If it wins clearly more often than it loses (the threshold used here is a 55% win rate), save a checkpoint. That checkpoint becomes a new candidate opponent.

The loop looks like this:

┌───────────────────────────────────────────┐
│  1. Sample opponent from checkpoint pool  │
│                   │                       │
│                   ▼                       │
│  2. Play N episodes (training vs opponent)│
│                   │                       │
│                   ▼                       │
│  3. REINFORCE update (training snake only)│
│                   │                       │
│                   ▼                       │
│  4. Evaluate: current vs best checkpoint  │
│     If win rate > 55%: save checkpoint    │
│     Go to 1                               │
└───────────────────────────────────────────┘

The 55% win rate threshold is deliberate. We don’t require a supermajority — we want to save checkpoints that are slightly better, because slightly-better opponents create the steady pressure that drives improvement. If we waited for 80% wins, the network would stagnate between checkpoints.
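As a minimal sketch of that promotion rule, the check can be written as a small pure function. The name `should_promote` and the choice to ignore draws are assumptions made for illustration; they are not part of the snake_ml crate.

```rust
/// Decide whether a candidate network should be promoted to the opponent pool.
/// Draws are ignored. `threshold` is the minimum win rate (0.55 in this part).
/// Hypothetical helper, shown only to illustrate the rule.
pub fn should_promote(wins: u32, losses: u32, threshold: f32) -> bool {
    let decided = wins + losses;
    if decided == 0 {
        // No decisive games means no evidence either way.
        return false;
    }
    (wins as f32 / decided as f32) > threshold
}
```

With 50 decisive evaluation games, 28 wins (56%) clears the bar and 27 wins (54%) does not, so a single game can decide whether a checkpoint is saved. More evaluation games make the decision less noisy.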

The game simulator

We need a game engine that runs locally — no HTTP, no web server, a function that advances the board state. This is the GameEnv that Part 6 referenced.

The simulator doesn’t need every BattleSnake rule. For training, we need: snakes move, food spawns, snakes die if they hit walls or bodies, health goes down, eating food restores health. That’s the core loop.

#![allow(unused)]
fn main() {
use snake_ml::{Board, Point, Snake};

/// A player in the simulation — either a neural network or a heuristic
pub trait Player {
    fn decide(&self, board: &Board, my_snake: &Snake) -> u32;
}

/// The game simulator
pub struct GameEnv {
    board: Board,
    players: Vec<Box<dyn Player>>,
    done: bool,
    turn: u32,
    max_turns: u32,
}

#[derive(Debug, Clone)]
pub struct StepResult {
    pub board: Board,
    pub rewards: Vec<f32>,    // one reward per player
    pub done: bool,
    pub winner: Option<usize>, // Some(index) or None for draw
}

impl GameEnv {
    pub fn new(board: Board, players: Vec<Box<dyn Player>>) -> Self {
        Self {
            board,
            players,
            done: false,
            turn: 0,
            max_turns: 200,
        }
    }

    /// Reset the environment with a fresh board and return the initial state
    pub fn reset(&mut self) -> Board {
        self.board = random_board(11, 11, self.players.len());
        self.done = false;
        self.turn = 0;
        self.board.clone()
    }

    /// Advance the game one turn: each player decides, then we resolve
    pub fn step(&mut self) -> StepResult {
        // 1. Each player picks a direction
        let directions: Vec<u32> = self.players.iter()
            .enumerate()
            .map(|(i, p)| p.decide(&self.board, &self.board.snakes[i]))
            .collect();

        // 2. Move snakes
        for (i, &dir) in directions.iter().enumerate() {
            let (dx, dy) = match dir {
                0 => (0, 1),   // up
                1 => (0, -1),  // down
                2 => (-1, 0),  // left
                3 => (1, 0),   // right
                _ => (0, 0),
            };

            let snake = &mut self.board.snakes[i];
            let new_head = Point {
                x: snake.head.x + dx,
                y: snake.head.y + dy,
            };

            // Insert new head at the front of the body
            snake.body.insert(0, new_head.clone());
            snake.head = new_head;

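            // Moving costs one health point. Do this before the food check so
            // that eating leaves health at exactly 100 (the reward code in
            // step 5 relies on that value to detect a meal).
            snake.health = snake.health.saturating_sub(1);

            // If the head isn't on food, remove the tail (snake doesn't grow);
            // otherwise keep the tail and restore health to full
            let on_food = self.board.food.iter()
                .any(|f| f.x == new_head.x && f.y == new_head.y);

            if !on_food {
                snake.body.pop();
            } else {
                snake.health = 100;
            }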
        }

        // 3. Check for deaths
        let mut dead = vec![false; self.players.len()];

        for (i, snake) in self.board.snakes.iter().enumerate() {
            // Wall collision
            if snake.head.x < 0 || snake.head.x >= self.board.width as i32
                || snake.head.y < 0 || snake.head.y >= self.board.height as i32
            {
                dead[i] = true;
                continue;
            }

            // Self collision (head hit own body, starting from index 1)
            for (j, seg) in snake.body.iter().enumerate() {
                if j > 0 && seg.x == snake.head.x && seg.y == snake.head.y {
                    dead[i] = true;
                    break;
                }
            }

            // Body collision with other snakes
            if dead[i] { continue; }
            for (j, other) in self.board.snakes.iter().enumerate() {
                if i == j { continue; }
                // Did I hit the other snake's body?
                for seg in &other.body {
                    if seg.x == snake.head.x && seg.y == snake.head.y {
                        // Check for head-to-head: if both heads just moved here
                        if snake.head.x == other.head.x && snake.head.y == other.head.y {
                            // Shorter snake dies; if same length, both die
                            if snake.body.len() <= other.body.len() {
                                dead[i] = true;
                            }
                        } else {
                            dead[i] = true;
                        }
                        break;
                    }
                }
            }
        }

        // 4. Starvation check
        for (i, snake) in self.board.snakes.iter().enumerate() {
            if snake.health == 0 {
                dead[i] = true;
            }
        }

        // 5. Compute rewards
        let mut rewards = vec![0.0_f32; self.players.len()];
        let mut winner = None;
        let alive_count = dead.iter().filter(|&&d| !d).count();

        for (i, is_dead) in dead.iter().enumerate() {
            if *is_dead {
                rewards[i] = -1.0;
            } else {
                // Survived this turn
                rewards[i] = 0.01;

                // Ate food (health was reset to 100)
                if self.board.snakes[i].health == 100 {
                    rewards[i] += 1.0;
                }
            }
        }

        // If only one snake is alive, it wins
        if alive_count == 1 {
            winner = dead.iter().position(|&d| !d);
            if let Some(w) = winner {
                rewards[w] += 5.0; // bonus for winning
            }
            self.done = true;
        } else if alive_count == 0 {
            self.done = true;
        }

        self.turn += 1;
        if self.turn >= self.max_turns {
            self.done = true;
        }

        // 6. Remove eaten food
        self.board.food.retain(|f| {
            !self.board.snakes.iter().any(|s| s.head.x == f.x && s.head.y == f.y)
        });

        // 7. Spawn one new piece of food once the board has none left (simplified)
        if self.board.food.is_empty() {
            if let Some(pos) = random_empty_cell(&self.board) {
                self.board.food.push(pos);
            }
        }

        StepResult {
            board: self.board.clone(),
            rewards,
            done: self.done,
            winner,
        }
    }

    pub fn is_done(&self) -> bool {
        self.done
    }
}

fn random_board(width: u32, height: u32, num_snakes: usize) -> Board {
    // Place snakes in opposite corners, place some food in the middle
    let mut snakes = Vec::with_capacity(num_snakes);
    let start_positions = [
        (1, 1),
        (width as i32 - 2, height as i32 - 2),
    ];

    for i in 0..num_snakes.min(start_positions.len()) {
        let (sx, sy) = start_positions[i];
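        // Stack all three segments on the start cell so no segment lands
        // outside the board (sy + 2 would exceed the top edge for the
        // second start position on an 11x11 board).
        let body: Vec<Point> = (0..3).map(|_| Point { x: sx, y: sy }).collect();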
        snakes.push(Snake {
            id: format!("snake-{i}"),
            body: body.clone(),
            head: body[0].clone(),
            health: 100,
        });
    }

    let mid_x = (width / 2) as i32;
    let mid_y = (height / 2) as i32;

    Board {
        width,
        height,
        food: vec![
            Point { x: mid_x, y: mid_y },
            Point { x: mid_x - 2, y: mid_y },
            Point { x: mid_x + 2, y: mid_y },
        ],
        snakes,
        hazards: vec![],
    }
}

fn random_empty_cell(board: &Board) -> Option<Point> {
    // Simplified: try a few random positions, return the first empty one
    use rand::Rng;
    let mut rng = rand::thread_rng();
    for _ in 0..20 {
        let x = rng.gen_range(0..board.width as i32);
        let y = rng.gen_range(0..board.height as i32);
        let occupied = board.snakes.iter()
            .any(|s| s.body.iter().any(|b| b.x == x && b.y == y));
        if !occupied {
            return Some(Point { x, y });
        }
    }
    None
}
}

This simulator handles the core mechanics. It’s not a perfect BattleSnake engine — for instance, it doesn’t handle simultaneous body-collision resolution the way the real engine does, and food spawning is simplified. But it’s good enough to generate training signal. The reward structure matches Part 6: eat food (+1), survive (+0.01), die (-1), win (+5).
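The reward structure can also be written as a standalone function, which makes the four cases easy to test in isolation. This is a sketch under assumptions: `TurnOutcome` and `turn_reward` are hypothetical names, and the explicit `ate_food` flag is a design alternative to inferring a meal from the health value.

```rust
/// What happened to one snake during a single turn.
/// Hypothetical type, used only for this illustration.
#[derive(Debug, Clone, Copy, Default)]
pub struct TurnOutcome {
    pub died: bool,
    pub ate_food: bool,
    pub won: bool,
}

/// Reward for one snake on one turn:
/// die (-1), survive (+0.01), eat food (+1), win (+5).
pub fn turn_reward(outcome: TurnOutcome) -> f32 {
    if outcome.died {
        // Death overrides everything else that happened this turn.
        return -1.0;
    }
    let mut reward = 0.01; // survived the turn
    if outcome.ate_food {
        reward += 1.0;
    }
    if outcome.won {
        reward += 5.0;
    }
    reward
}
```

Passing the events in explicitly keeps the reward independent of how health bookkeeping is ordered inside the simulator.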

The network as a player

We need to wrap our SnakeNet so it implements the Player trait. This is where the trained weights come in — a player with trained weights plays differently than one with random weights.

#![allow(unused)]
fn main() {
use burn::backend::Autodiff;
use burn::tensor::Tensor;
use burn::tensor::activation::relu;
use burn::module::Module;
use burn::record::{DefaultFileRecorder, FullPrecisionSettings, Recorder};
use snake_ml::{SnakeNet, encode_board_flat, Board, Snake};

type Backend = Autodiff<burn::backend::Cpu>;


/// A neural network player. Uses the trained (or random) weights
/// to pick a direction each turn.
pub struct NetworkPlayer<B: burn::tensor::backend::Backend> {
    net: SnakeNet<B>,
    device: B::Device,
    epsilon: f32, // exploration rate for ε-greedy
}

impl<B: burn::tensor::backend::Backend> NetworkPlayer<B> {
    pub fn from_net(net: SnakeNet<B>, epsilon: f32) -> Self {
        let device = B::Device::default();
        Self { net, device, epsilon }
    }

    pub fn from_checkpoint(path: &str, epsilon: f32) -> Self {
        let device = B::Device::default();
        let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
        let record = recorder
            .load(path.into(), &device)
            .expect("failed to load checkpoint");
        let net = SnakeNet::new(&device).load_record(record);
        Self { net, device, epsilon }
    }
}

impl<B: burn::tensor::backend::Backend> Player for NetworkPlayer<B> {
    fn decide(&self, board: &Board, my_snake: &Snake) -> u32 {
        // ε-greedy: explore sometimes
        use rand::Rng; // brings `gen_range` into scope
        if rand::random::<f32>() < self.epsilon {
            return rand::thread_rng().gen_range(0..4u32);
        }

        // Encode the board and pick the best direction
        let flat = encode_board_flat::<B>(board, my_snake, &self.device);
        let state_t: Tensor<B, 2> = Tensor::from_data(
            burn::tensor::TensorData::new(flat, burn::tensor::Shape::new([1, 8 * 11 * 11])),
            &self.device,
        );

        self.net.pick_direction(state_t).index() as u32
    }
}
}

The epsilon parameter controls exploration. During training, the training snake uses a nonzero epsilon: at 0.1, for example, it tries a random move 10% of the time. The training loop later in this part starts at 0.2 and decays toward 0.01. The opponent uses epsilon = 0 (always pick the best move it knows) so it plays at full strength.
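To see the ε-greedy rule without the network or the rand crate in the way, here is a small sketch that takes the random draws as arguments. The function name and its parameters are assumptions for illustration; `scores` stands in for whatever per-direction values the network produces and must be non-empty.

```rust
/// Pick an action with ε-greedy exploration.
/// `roll` is a uniform sample in [0, 1); `random_action` is any random index.
/// Hypothetical helper: randomness is passed in so the logic is testable.
pub fn epsilon_greedy(scores: &[f32], epsilon: f32, roll: f32, random_action: usize) -> usize {
    if roll < epsilon {
        // Explore: ignore the scores and take the random action.
        return random_action % scores.len();
    }
    // Exploit: return the index of the highest score (first one wins ties).
    let mut best = 0;
    for (i, &s) in scores.iter().enumerate() {
        if s > scores[best] {
            best = i;
        }
    }
    best
}
```

With epsilon = 0 the `roll < epsilon` branch can never fire, which is why the frozen opponent always plays its best-known move.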

The opponent pool

Here’s the self-play twist: we don’t always train against the current network. If we did, the opponent would be exactly as strong as the training snake, which means every game is a coin flip. The reward signal gets noisy.

Instead, we maintain a pool of checkpoints — saved snapshots of the network weights from different points in training. Each training episode, we sample an opponent from this pool. Sometimes the opponent is the current network. Sometimes it’s a weaker checkpoint from 100 training iterations ago. The variety stabilizes training.

#![allow(unused)]
fn main() {
use std::path::{Path, PathBuf};

/// A pool of saved model checkpoints for opponent sampling
pub struct OpponentPool {
    checkpoints: Vec<PathBuf>,
}

impl OpponentPool {
    pub fn new() -> Self {
        Self { checkpoints: Vec::new() }
    }

    /// Add a checkpoint to the pool
    pub fn add(&mut self, path: PathBuf) {
        self.checkpoints.push(path);
    }

    /// Sample a random opponent from the pool.
    /// If the pool is empty, returns None (use a random player instead).
    pub fn sample(&self) -> Option<PathBuf> {
        if self.checkpoints.is_empty() {
            return None;
        }
        use rand::Rng;
        let idx = rand::thread_rng().gen_range(0..self.checkpoints.len());
        Some(self.checkpoints[idx].clone())
    }

    /// Weighted sampling: prefer recent opponents, but occasionally
    /// pick older ones. This balances challenge (recent = harder)
    /// with diversity (older = different strategies).
    pub fn sample_weighted(&self) -> Option<PathBuf> {
        if self.checkpoints.is_empty() {
            return None;
        }

        // 50% chance: most recent checkpoint (hardest opponent)
        // 50% chance: uniformly random from the full pool
        use rand::Rng;
        let mut rng = rand::thread_rng();
        if rng.gen_bool(0.5) {
            self.checkpoints.last().cloned()
        } else {
            let idx = rng.gen_range(0..self.checkpoints.len());
            Some(self.checkpoints[idx].clone())
        }
    }

    pub fn len(&self) -> usize {
        self.checkpoints.len()
    }
}
}

The weighted sampling is a practical detail that matters. Training only against the strongest opponent is a curriculum, but it can be too hard — the training snake loses every game, the reward signal is uniformly negative, and nothing is learned. Mixing in weaker opponents gives the training snake some wins, some positive reward, and a gradient that points in a useful direction.
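It can help to work out what the 50/50 rule means in practice. Because the uniform branch can also land on the newest checkpoint, the newest one is picked slightly more than half the time. A short sketch of that arithmetic, with hypothetical function names:

```rust
/// Probability that the sampler returns the most recent checkpoint:
/// 0.5 from the "most recent" branch plus 0.5 * (1 / n) from the uniform branch.
pub fn prob_most_recent(pool_size: usize) -> f64 {
    assert!(pool_size > 0, "pool must not be empty");
    0.5 + 0.5 / pool_size as f64
}

/// Probability of returning any one specific older checkpoint:
/// only the uniform branch can reach it.
pub fn prob_each_older(pool_size: usize) -> f64 {
    assert!(pool_size > 0, "pool must not be empty");
    0.5 / pool_size as f64
}
```

For a pool of 10 checkpoints this gives 55% for the newest and 5% for each of the nine older ones. As the pool grows, each older checkpoint is seen less often, so a very long run may want to cap the pool size.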

The full training loop

Everything comes together here. The loop runs for a fixed number of iterations. Each iteration: sample an opponent, play episodes, update the training snake, occasionally evaluate.

#![allow(unused)]
fn main() {
use std::fs;
use std::path::PathBuf;
use burn::backend::Autodiff;
use burn::optim::{AdamWConfig, AdamW};
use burn::tensor::{Tensor, TensorData, Shape};
use burn::module::Module;
use burn::record::{DefaultFileRecorder, FullPrecisionSettings, Recorder};

type Backend = Autodiff<burn::backend::Cpu>;

pub fn train_self_play(
    mut net: SnakeNet<Backend>,
    num_iterations: usize,
    episodes_per_iteration: usize,
    checkpoint_dir: &str,
) -> Result<SnakeNet<Backend>, Box<dyn std::error::Error>> {
    let learning_rate = 1e-3_f64;
    let gamma = 0.99_f32;
    let initial_epsilon = 0.2_f32;
    let min_epsilon = 0.01_f32;

    // Create the optimizer once so optimizer state persists
    // across training iterations.
    // `init` is generic over both the backend and the module type.
    let mut optim = AdamWConfig::new().init::<Backend, SnakeNet<Backend>>();

    let mut pool = OpponentPool::new();

    // Save the initial random weights as the first checkpoint
    fs::create_dir_all(checkpoint_dir)?;
    let initial_path = format!("{checkpoint_dir}/checkpoint-0");
    let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
    net.clone().save_file(&initial_path, &recorder)?;
    pool.add(PathBuf::from(&initial_path));

    for iteration in 1..=num_iterations {
        // Epsilon decays over time: less exploration as training progresses
        let epsilon = (initial_epsilon
            * (1.0 - iteration as f32 / num_iterations as f32))
            .max(min_epsilon);

        let mut all_experiences: Vec<Experience> = Vec::new();
        let mut total_reward = 0.0_f32;
        let mut wins = 0usize;
        let mut losses = 0usize;

        for _ in 0..episodes_per_iteration {
            // 1. Sample an opponent
            let opponent = match pool.sample_weighted() {
                Some(path) => NetworkPlayer::from_checkpoint(&path.to_string_lossy(), 0.0),
                None => NetworkPlayer::from_net(net.clone(), 0.0),
            };

            // The training snake (with exploration)
            let training_player = NetworkPlayer::from_net(net.clone(), epsilon);

            // 2. Play an episode
            let board = random_board(11, 11, 2);
            let mut env = GameEnv::new(
                board,
                vec![Box::new(training_player), Box::new(opponent)],
            );

            env.reset();
            let mut experiences: Vec<Experience> = Vec::new();

            while !env.is_done() {
                let my_snake = &env.board.snakes[0];
                let state = encode_board_flat(&env.board, my_snake, &<Backend as burn::tensor::backend::Backend>::Device::default());

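                // The training player decides inside GameEnv::step, which does
                // not report the action back, so we re-derive it with a second
                // forward pass. Caveat: this matches the move actually played
                // only when the player's ε-greedy branch did not fire and
                // pick_direction is deterministic.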
                let state_t: Tensor<Backend, 2> = Tensor::from_data(
                    burn::tensor::TensorData::new(state.clone(), burn::tensor::Shape::new([1, 8 * 11 * 11])),
                    &<Backend as burn::tensor::backend::Backend>::Device::default(),
                );
                let action = net.pick_direction(state_t).index() as u32;

                let result = env.step();

                // Encode the next board state for the replay buffer.
                // We need (state, action, reward, next_state) for every step
                // so the replay buffer in Part 9 can do prioritized sampling.
                // `encode_board_and_flat` is a small wrapper in snake-ml that calls
                // encode_board, flattens, and returns (Vec<f32>, ()). The () is a
                // unit marker for destructuring convenience.
                let (next_state_flat, _) = encode_board_and_flat(&result.board, &result.board.snakes[0], &<Backend as burn::tensor::backend::Backend>::Device::default());

                experiences.push(Experience {
                    state,
                    next_state: next_state_flat,
                    action: action as usize,
                    reward: result.rewards[0],
                });

                total_reward += result.rewards[0];

                if result.done {
                    match result.winner {
                        Some(0) => wins += 1,
                        Some(1) => losses += 1,
                        _ => {} // draw; the wildcard also makes the match exhaustive
                    }
                }
            }

            all_experiences.extend(experiences);
        }
}

Entropy and beta decay. If you add entropy regularization (adding beta * H(π) to the objective, or equivalently subtracting it from the loss, where H(π) measures how spread out the policy distribution is), make sure beta decays over training. A constant high beta keeps the snake perpetually random; it never commits to learned behavior. A common schedule: beta = max(0.01, 0.5 * exp(-step / 1000)), which starts exploratory and gradually locks in the policy.
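A minimal sketch of that schedule, together with the entropy term it scales. Both function names are hypothetical, and `probs` is assumed to be the policy's softmax output over the four directions.

```rust
/// Entropy coefficient: beta = max(0.01, 0.5 * exp(-step / 1000)).
pub fn entropy_beta(step: u32) -> f32 {
    (0.5 * (-(step as f32) / 1000.0).exp()).max(0.01)
}

/// Shannon entropy H(π) = -Σ p * ln(p) of a probability distribution.
/// Zero-probability entries are skipped because p * ln(p) tends to 0 there.
pub fn policy_entropy(probs: &[f32]) -> f32 {
    probs
        .iter()
        .filter(|&&p| p > 0.0)
        .map(|&p| -p * p.ln())
        .sum()
}
```

With these constants beta starts at 0.5 and reaches its 0.01 floor at roughly step 3,900 (where exp(-step / 1000) = 0.02). A uniform policy over four moves has the maximum entropy, ln 4 ≈ 1.386, and a policy that always picks one move has entropy 0.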

#![allow(unused)]
fn main() {
        // 3. Compute returns
        let returns = compute_returns(&all_experiences, gamma);

        // Normalize returns (reduces variance in REINFORCE)
        let mean = returns.iter().sum::<f32>() / returns.len() as f32;
        let std = (returns.iter()
            .map(|r| (r - mean).powi(2))
            .sum::<f32>() / returns.len() as f32)
            .sqrt()
            .max(1e-8);
        let normalized: Vec<f32> = returns.iter()
            .map(|r| (r - mean) / std)
            .collect();

        // 4. REINFORCE update
        let n = all_experiences.len(); // Shape::new takes usize dimensions
        let states_t: Tensor<Backend, 2> = Tensor::from_data(
            TensorData::new(
                all_experiences.iter().flat_map(|e| e.state.iter().copied()).collect::<Vec<_>>(),
                Shape::new([n, 8 * 11 * 11]),
            ),
            &<Backend as burn::tensor::backend::Backend>::Device::default(),
        );
        let actions: Vec<i64> = all_experiences.iter().map(|e| e.action as i64).collect();

        // reinforce_step returns (updated_net, loss). We reassign `net` so the
        // next iteration uses the updated weights.
        let (updated_net, _loss) = reinforce_step(net, &mut optim, learning_rate, &states_t, &actions, &normalized, &<Backend as burn::tensor::backend::Backend>::Device::default());
        net = updated_net;

        // 5. Evaluate and checkpoint
        if iteration % 10 == 0 {
            let win_rate = wins as f32 / (wins + losses).max(1) as f32;
            println!(
                "Iteration {iteration}: reward={total_reward:.1} \
                 win_rate={win_rate:.2} epsilon={epsilon:.3} pool={}",
                pool.len()
            );

            // Evaluate against the most recent checkpoint in the pool.
            // The pool always holds at least checkpoint-0, so `last()` is Some.
            // `net.clone()` keeps `net` usable in the next iteration.
            let prev = pool.checkpoints.last().unwrap().to_string_lossy().to_string();
            let eval_wr = evaluate_against(net.clone(), &prev, 50)?;
            println!("  eval vs previous: {eval_wr:.2} win rate");

            // Save a checkpoint only if the current network clears the threshold
            if eval_wr > 0.55 {
                let path = format!("{checkpoint_dir}/checkpoint-{iteration}");
                let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
                net.clone().save_file(&path, &recorder)?;
                pool.add(PathBuf::from(&path));
                println!("  ✓ New best! Checkpoint saved.");
            }
        }
    }

    // Save the final model
    let final_path = format!("{checkpoint_dir}/model-final");
    let recorder = DefaultFileRecorder::<FullPrecisionSettings>::new();
    net.clone().save_file(&final_path, &recorder)?;
    println!("Training complete. Final model saved to {final_path}");

    Ok(net)
}
}

A few things worth calling out in this code:

Return normalization. Raw REINFORCE returns can vary wildly — one episode might produce returns of +10, the next -5. The gradient magnitude depends on the return magnitude. Normalizing (subtract mean, divide by standard deviation) keeps the gradient stable. This is a standard trick that makes REINFORCE dramatically more practical.
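For reference, here is a self-contained sketch of the two steps involved: turning per-step rewards into discounted returns, then normalizing them. The `compute_returns` used in the loop comes from Part 6 and is not shown in this part, so the functions below are stand-ins with assumed names, not its actual implementation.

```rust
/// Discounted returns for ONE episode: G_t = r_t + gamma * G_{t+1}.
/// Computed back to front so each step reuses the running total.
pub fn discounted_returns(rewards: &[f32], gamma: f32) -> Vec<f32> {
    let mut returns = vec![0.0; rewards.len()];
    let mut running = 0.0;
    for t in (0..rewards.len()).rev() {
        running = rewards[t] + gamma * running;
        returns[t] = running;
    }
    returns
}

/// Subtract the mean and divide by the standard deviation.
/// The 1e-8 floor avoids division by zero when all values are equal.
pub fn normalize(values: &[f32]) -> Vec<f32> {
    if values.is_empty() {
        return Vec::new();
    }
    let n = values.len() as f32;
    let mean = values.iter().sum::<f32>() / n;
    let var = values.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let std = var.sqrt().max(1e-8);
    values.iter().map(|v| (v - mean) / std).collect()
}
```

Calling `discounted_returns` once per episode, then concatenating the results before normalizing, keeps a reward from one game from leaking into the returns of the game before it.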

Epsilon decay. Early in training, the snake needs to explore (epsilon = 0.2). Late in training, it mostly exploits what it’s learned (epsilon = 0.01). The linear decay schedule is simple but effective. More sophisticated schedules (exponential, cosine) can help, but linear decay works well enough for BattleSnake.
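The linear schedule is easy to inspect in isolation. The sketch below repeats the formula from the training loop and adds an exponential variant for comparison; both function names and the `decay` parameter are assumptions for illustration.

```rust
/// Linear decay from `initial` to zero over `total` iterations, floored at `min`.
/// Same formula as the training loop in this part.
pub fn linear_epsilon(iteration: usize, total: usize, initial: f32, min: f32) -> f32 {
    (initial * (1.0 - iteration as f32 / total as f32)).max(min)
}

/// Exponential alternative: multiply by `decay` each iteration, floored at `min`.
pub fn exponential_epsilon(iteration: usize, initial: f32, min: f32, decay: f32) -> f32 {
    (initial * decay.powi(iteration as i32)).max(min)
}
```

With initial = 0.2 and min = 0.01, the linear schedule hits its floor at 95% of the run, so only the last 5% of iterations train at the minimum exploration rate.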

Checkpoint evaluation. Every 10 iterations, we pit the current network against the previous checkpoint. A 55% win rate means the current network is slightly better — good enough to justify saving. We’re looking for incremental progress, not dominance.

What the network learns through self-play

Imitation learning taught the network to copy a heuristic. Basic REINFORCE taught it to eat food and avoid dying. Self-play teaches it something neither of those could: how to compete.

Through self-play, the network can discover:

  • Space control. Claiming territory near the center of the board gives more room to maneuver. The heuristic knows to go for food, but it doesn’t know that position matters when food isn’t nearby.
  • Opponent tracking. The network learns to check where the enemy head is — not only where its body is. A nearby enemy head means a contested food or a head-to-head collision risk.
  • Patience. Sometimes the best move is to wait. If the opponent is between you and food, burning health to race them is worse than waiting for a better opportunity.
  • Endgame awareness. In the late game, when both snakes are long and the board is mostly body segments, the network learns to value space more than food.

None of these are explicitly in the reward function. They emerge from the pressure of playing against an opponent that’s also learning. The strategies are discovered, not designed.

The variance problem (and how to think about it)

REINFORCE is notoriously high-variance. One episode might produce a total return of +15 (the snake ate everything, won the game). The next might produce -5 (the snake died on turn 3). The gradient flips direction between episodes. Training is noisy.

Return normalization helps. So does averaging over many episodes per iteration. But the fundamental problem remains: in BattleSnake, one bad move can kill you, and the network has to figure out which move out of 20 was the bad one.

There are more sophisticated algorithms that address this — PPO (Proximal Policy Optimization) and A2C (Advantage Actor-Critic) are the standard ones. They use a baseline to reduce variance: instead of asking “was this action good?” they ask “was this action better than expected?” The “expected” part is a value function — a second network that predicts how good a state is.
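The "better than expected" idea reduces to one subtraction per step. In this sketch the baseline values are passed in as a slice; in A2C or PPO they would come from the value network, which is not built in this tutorial. The function name is hypothetical.

```rust
/// Advantage A_t = G_t - V(s_t): how much better the observed return was
/// than the baseline's prediction for that state.
/// `baselines` is assumed to hold one value estimate per step.
pub fn advantages(returns: &[f32], baselines: &[f32]) -> Vec<f32> {
    assert_eq!(returns.len(), baselines.len(), "one baseline per return");
    returns
        .iter()
        .zip(baselines)
        .map(|(g, b)| g - b)
        .collect()
}
```

A return of +5 in a state where +3 was expected yields an advantage of +2, while a return of -1 where +0.5 was expected yields -1.5. The policy gradient is then weighted by these advantages instead of the raw returns.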

For this tutorial, REINFORCE is the right starting point. It’s simple, it works, and it teaches the core idea (policy gradient) without the overhead of a value function. If you want to go further, swapping REINFORCE for PPO is the natural next step.

A note on training time

Self-play is slow. Each episode runs a full game (50–200 turns). Each turn requires a forward pass through the network for each player. A training run of 1000 iterations with 10 episodes per iteration is 10,000 games. On CPU, that might take hours. On GPU, it’s faster, but Burn’s GPU support varies by platform.
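A rough upper bound on the work involved is a single multiplication. The helper below is a hypothetical back-of-the-envelope calculation, assuming every game runs to the turn limit and each player does one forward pass per turn.

```rust
/// Upper bound on forward passes for a training run, assuming every
/// game lasts `max_turns` and each player runs the network once per turn.
pub fn max_forward_passes(
    iterations: u64,
    episodes_per_iteration: u64,
    max_turns: u64,
    players: u64,
) -> u64 {
    iterations * episodes_per_iteration * max_turns * players
}
```

For 1000 iterations, 10 episodes each, 200 turns, and 2 players, that is 4,000,000 forward passes at most. Games that end early reduce the count, and the training loop in this part adds one more pass per turn to record the action.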

The practical approach: start small. Train for 100 iterations, check the win rate, adjust the hyperparameters, repeat. Don’t run a 10,000-iteration training job on your first try. The network’s behavior after 100 iterations will tell you whether the reward signal is working.

Part 8 takes the trained model and wires it into the web server from Part 1 — making the snake live on BattleSnake.
