
Part 8 — Self-Play: The Training Loop

Part 7 gave us the game simulator and the NetworkPlayer that plugs into it. Now we wire it all together: the opponent pool, the full self-play loop, and what training actually looks like.

The opponent pool

Here’s the self-play twist: we don’t always train against the current network. If we did, the opponent would be exactly as strong as the training snake, which means every game is a coin flip. The reward signal gets noisy.

Instead, we maintain a pool of checkpoints — saved snapshots of the network weights from different points in training. Each training episode, we sample an opponent from this pool. Sometimes the opponent is the current network. Sometimes it’s a weaker checkpoint from 100 training iterations ago. The variety stabilizes training.

use std::path::PathBuf;

use rand::Rng;

/// A pool of saved model checkpoints for opponent sampling
#[derive(Default)]
pub struct OpponentPool {
    checkpoints: Vec<PathBuf>,
}

impl OpponentPool {
    pub fn new() -> Self {
        Self { checkpoints: Vec::new() }
    }

    /// Add a checkpoint to the pool
    pub fn add(&mut self, path: PathBuf) {
        self.checkpoints.push(path);
    }

    /// Sample a random opponent from the pool.
    /// If the pool is empty, returns None (use a random player instead).
    pub fn sample(&self) -> Option<PathBuf> {
        if self.checkpoints.is_empty() {
            return None;
        }
        let idx = rand::thread_rng().gen_range(0..self.checkpoints.len());
        Some(self.checkpoints[idx].clone())
    }

    /// Weighted sampling: prefer recent opponents, but occasionally
    /// pick older ones. This balances challenge (recent = harder)
    /// with diversity (older = different strategies).
    pub fn sample_weighted(&self) -> Option<PathBuf> {
        // Returns None early if the pool is empty.
        let newest = self.checkpoints.last()?;
        let mut rng = rand::thread_rng();

        // 50% chance: most recent checkpoint (hardest opponent)
        // 50% chance: uniformly random from the full pool
        if rng.gen_bool(0.5) {
            Some(newest.clone())
        } else {
            let idx = rng.gen_range(0..self.checkpoints.len());
            Some(self.checkpoints[idx].clone())
        }
    }

    pub fn len(&self) -> usize {
        self.checkpoints.len()
    }

    pub fn is_empty(&self) -> bool {
        self.checkpoints.is_empty()
    }
}

The weighted sampling is a practical detail that matters. Training only against the strongest opponent can be too hard: the training snake loses every game, the reward signal is uniformly negative, and it learns little. Mixing in weaker opponents gives the training snake some wins, some positive reward, and a gradient that points in a useful direction.
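It helps to work out how strongly sample_weighted favors the newest checkpoint. With n checkpoints, the newest one can be picked by either branch, so its probability is 0.5 + 0.5/n; every older checkpoint gets only 0.5/n. The helper below (a hypothetical name, not part of snake-ml) computes that distribution without any dependencies:

```rust
/// Probability that `sample_weighted` returns each checkpoint, assuming the
/// 50/50 split used above. Index `n - 1` is the most recent checkpoint.
fn weighted_sampling_probs(n: usize) -> Vec<f64> {
    if n == 0 {
        return Vec::new();
    }
    // Uniform branch: every checkpoint gets an equal share of that 50%.
    let uniform_share = 0.5 / n as f64;
    let mut probs = vec![uniform_share; n];
    // "Most recent" branch: the newest checkpoint also gets the other 50%.
    probs[n - 1] += 0.5;
    probs
}
```

With 10 checkpoints, the newest is chosen 55% of the time and each older one 5%. As the pool grows, each individual old checkpoint becomes rare (0.5% each at 100 checkpoints), though together they still account for roughly half of all games.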

The full training loop

Everything comes together here. The loop runs for a fixed number of iterations. Each iteration plays a batch of episodes, each against an opponent sampled from the pool, then updates the training snake; every 10th iteration also saves a checkpoint and evaluates it.

use std::fs;
use std::path::PathBuf;

use candle_core::{DType, Device, Tensor};
use candle_nn::{AdamW, Optimizer, VarBuilder, VarMap};
use rand::Rng;

pub fn train_self_play(
    var_map: &VarMap,
    num_iterations: usize,
    episodes_per_iteration: usize,
    checkpoint_dir: &str,
) -> Result<(), Box<dyn std::error::Error>> {
    let device = Device::Cpu;
    let learning_rate = 1e-3_f64;
    let gamma = 0.99_f32;
    let initial_epsilon = 0.2_f32;
    let min_epsilon = 0.01_f32;

    // Create the optimizer once so AdamW's moment estimates persist
    // across training iterations.
    let mut opt = AdamW::new(var_map.all_vars(), candle_nn::ParamsAdamW {
        lr: learning_rate,
        ..Default::default()
    })?;

    // Build the network once. The VarBuilder hands out the VarMap's own
    // variables, so this `net` sees every update the optimizer makes.
    let vs = VarBuilder::from_varmap(var_map, DType::F32, &device);
    let net = SnakeNet::new(vs)?;

    let mut pool = OpponentPool::new();

    // Save the initial weights as the first checkpoint, so the pool is
    // never empty.
    fs::create_dir_all(checkpoint_dir)?;
    let initial_path = format!("{checkpoint_dir}/checkpoint-0.safetensors");
    var_map.save(&initial_path)?;
    pool.add(PathBuf::from(&initial_path));

    for iteration in 1..=num_iterations {
        // Epsilon decays linearly: less exploration as training progresses
        let epsilon = (initial_epsilon
            * (1.0 - iteration as f32 / num_iterations as f32))
            .max(min_epsilon);

        let mut all_experiences: Vec<Experience> = Vec::new();
        let mut all_returns: Vec<f32> = Vec::new();
        let mut total_reward = 0.0_f32;
        let mut wins = 0usize;
        let mut losses = 0usize;

        for _ in 0..episodes_per_iteration {
            // 1. Sample an opponent. The pool always contains checkpoint-0,
            //    so the fallback arm (a greedy copy of the current network)
            //    is only a safeguard.
            let opponent = match pool.sample_weighted() {
                Some(path) => NetworkPlayer::from_checkpoint(&path.to_string_lossy(), 0.0),
                None => NetworkPlayer::from_var_map(var_map, 0.0),
            };

            // The training snake (with exploration)
            let training_player = NetworkPlayer::from_var_map(var_map, epsilon);

            // 2. Play an episode
            let board = random_board(11, 11, 2);
            let mut env = GameEnv::new(
                board,
                vec![Box::new(training_player), Box::new(opponent)],
            );

            env.reset();
            let mut experiences: Vec<Experience> = Vec::new();

            while !env.is_done() {
                let state = encode_board(&env.board, &env.board.snakes[0])?
                    .flatten_all()?
                    .to_vec1::<f32>()?;

                // Decide the training player's action *before* calling
                // env.step_with_action(). We can't re-derive it afterwards:
                // the training player uses ε-greedy, so the action it took
                // may differ from the greedy argmax. Recording the argmax
                // instead would train the network to reinforce actions it
                // didn't take.
                let action = if rand::random::<f32>() < epsilon {
                    rand::thread_rng().gen_range(0..4u32) // explore
                } else {
                    // Exploit: one forward pass, then the best direction.
                    // argmax(1) on the (1, 4) output has shape (1,), so
                    // squeeze it to a scalar before reading the value.
                    let state_t = Tensor::from_vec(state.clone(), (1, 968), &device)?;
                    net.forward(&state_t)?.argmax(1)?.squeeze(0)?.to_scalar::<u32>()?
                };

                let result = env.step_with_action(0, action);

                // Encode the next board state. REINFORCE itself only needs
                // (state, action, reward), but Experience also records it.
                let next_state = encode_board(&result.board, &result.board.snakes[0])?
                    .flatten_all()?
                    .to_vec1::<f32>()?;

                experiences.push(Experience {
                    state,
                    next_state,
                    action,
                    reward: result.rewards[0],
                });

                total_reward += result.rewards[0];

                if result.done {
                    match result.winner {
                        Some(0) => wins += 1,
                        Some(_) => losses += 1, // the other snake won
                        None => {}              // draw
                    }
                }
            }

            // 3. Compute returns one episode at a time. compute_returns
            //    (Part 6) walks backward through the experiences,
            //    accumulating discounted rewards into G_t. Running it per
            //    episode keeps rewards from the next game from leaking into
            //    the final moves of this one.
            all_returns.extend(compute_returns(&experiences, gamma));
            all_experiences.extend(experiences);
        }

        // Normalize returns (reduces variance in REINFORCE)
        let n = all_returns.len() as f32;
        let mean = all_returns.iter().sum::<f32>() / n;
        let std = (all_returns.iter()
            .map(|r| (r - mean).powi(2))
            .sum::<f32>() / n)
            .sqrt()
            .max(1e-8);
        let normalized: Vec<f32> = all_returns.iter()
            .map(|r| (r - mean) / std)
            .collect();

        // 4. REINFORCE update
        let states_t = Tensor::from_vec(
            all_experiences.iter()
                .flat_map(|e| e.state.iter().copied())
                .collect::<Vec<f32>>(),
            (all_experiences.len(), 968),
            &device,
        )?;
        let actions: Vec<u32> = all_experiences.iter().map(|e| e.action).collect();

        // Reuse the network and optimizer built before the loop.
        reinforce_step(&net, &mut opt, &states_t, &actions, &normalized, &device)?;

        // 5. Checkpoint and evaluate every 10 iterations
        if iteration % 10 == 0 {
            let win_rate = wins as f32 / (wins + losses).max(1) as f32;
            println!(
                "Iteration {iteration}: reward={total_reward:.1} \
                 win_rate={win_rate:.2} epsilon={epsilon:.3} pool={}",
                pool.len()
            );

            // Save a checkpoint and add it to the opponent pool
            let path = format!("{checkpoint_dir}/checkpoint-{iteration}.safetensors");
            var_map.save(&path)?;
            pool.add(PathBuf::from(&path));

            // Evaluate against the previous checkpoint
            if pool.len() > 1 {
                let prev = pool.checkpoints[pool.len() - 2].to_string_lossy().to_string();
                let eval_wr = evaluate_against(var_map, &prev, 50)?;
                println!("  eval vs previous: {eval_wr:.2} win rate");

                if eval_wr > 0.55 {
                    println!("  ✓ Improvement over the previous checkpoint.");
                }
            }
        }
    }

    // Save the final model
    let final_path = format!("{checkpoint_dir}/model-final.safetensors");
    var_map.save(&final_path)?;
    println!("Training complete. Final model saved to {final_path}");

    Ok(())
}

Entropy and beta decay. If you add entropy regularization, a bonus of beta * H(π) added to the objective (equivalently, subtracted from the loss), where H(π) measures how spread-out the policy distribution is, make sure beta decays over training. The bonus rewards spread-out distributions, so a constant high beta keeps the snake perpetually random; it never commits to learned behavior. A common schedule: beta = max(0.01, 0.5 * exp(-step / 1000)), which starts exploratory and gradually locks in the policy.
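A minimal sketch of both pieces in plain Rust, working on a single action-probability vector rather than Candle tensors. The function names are hypothetical, the schedule constants are the ones quoted above, and the loss assumes the usual per-sample REINFORCE form -log π(a|s) · G. The entropy bonus enters that loss with a minus sign, so minimizing the loss pushes entropy up:

```rust
/// Shannon entropy H(π) = -Σ p·ln p of an action distribution (in nats).
/// Zero-probability entries contribute nothing (the limit of p·ln p as p → 0).
fn entropy(probs: &[f32]) -> f32 {
    -probs
        .iter()
        .filter(|&&p| p > 0.0)
        .map(|&p| p * p.ln())
        .sum::<f32>()
}

/// Decaying entropy coefficient: beta = max(0.01, 0.5 * exp(-step / 1000)).
fn entropy_beta(step: u32) -> f32 {
    (0.5 * (-(step as f32) / 1000.0).exp()).max(0.01)
}

/// Per-sample REINFORCE loss with an entropy bonus. Minimizing it raises the
/// log-probability of actions with positive (normalized) return, and the
/// `- beta * H` term rewards keeping the distribution spread out.
fn loss_with_entropy(probs: &[f32], action: usize, ret: f32, beta: f32) -> f32 {
    -probs[action].ln() * ret - beta * entropy(probs)
}
```

For four moves, a uniform policy has the maximum entropy, ln 4 ≈ 1.386, and a fully committed policy has zero. With this schedule, beta falls to about 0.18 after 1,000 steps and hits the 0.01 floor after roughly 3,900 steps (1000 · ln 50).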

A few things worth calling out in this code:

Return normalization. This connects back to the variance problem from Part 6. Raw REINFORCE returns can vary wildly: one episode might give +10 (the snake ate everything and won), the next -5 (died on turn 3). Without normalization, the gradient points in one direction after the first episode and the opposite direction after the second. Training is noisy — the snake can’t tell which actions were actually good.

Normalizing fixes this: subtract the mean, divide by the standard deviation. Every batch of returns now centers on zero with unit variance, so actions from better-than-average steps in the batch are reinforced, actions from worse-than-average steps are discouraged, and the size of the update stays stable whether the snake got lucky or unlucky in a given episode. This is the same variance-reduction idea that Part 6 mentioned: normalization makes the signal reliable enough for training to converge.
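Here is the normalization step from the training loop pulled out into a standalone function, so you can check it on the two episodes just described (a +10 return and a -5 return). The function name is ours; the math matches the loop: population standard deviation, floored at 1e-8 to avoid dividing by zero.

```rust
/// Rescale returns to zero mean and unit variance, as in the training loop.
fn normalize_returns(returns: &[f32]) -> Vec<f32> {
    if returns.is_empty() {
        return Vec::new();
    }
    let n = returns.len() as f32;
    let mean = returns.iter().sum::<f32>() / n;
    let var = returns.iter().map(|r| (r - mean).powi(2)).sum::<f32>() / n;
    // Floor the standard deviation so a batch of identical returns
    // doesn't divide by zero.
    let std = var.sqrt().max(1e-8);
    returns.iter().map(|r| (r - mean) / std).collect()
}
```

The mean of +10 and -5 is 2.5 and the standard deviation is 7.5, so the two returns become exactly +1 and -1. Whatever the raw scale of a batch, the update it produces has a comparable magnitude.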

Epsilon decay. Early in training, the snake needs to explore (epsilon = 0.2). Late in training, it mostly exploits what it’s learned (epsilon = 0.01). The linear decay schedule is simple but effective. More sophisticated schedules (exponential, cosine) can help, but linear decay works well enough for BattleSnake.
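To compare the options, here is the linear schedule from the loop next to the exponential and cosine alternatives. These are generic textbook forms under our own function names, not anything snake-ml provides; each starts at `initial` and ends at (or is clamped to) `min` after `total` iterations.

```rust
use std::f32::consts::PI;

/// Linear decay, as in train_self_play: initial * (1 - t/total), floored at min.
fn linear_epsilon(t: usize, total: usize, initial: f32, min: f32) -> f32 {
    (initial * (1.0 - t as f32 / total as f32)).max(min)
}

/// Exponential decay from `initial` to `min` over `total` steps.
fn exponential_epsilon(t: usize, total: usize, initial: f32, min: f32) -> f32 {
    let frac = (t as f32 / total as f32).min(1.0);
    initial * (min / initial).powf(frac)
}

/// Cosine annealing from `initial` down to `min` over `total` steps.
fn cosine_epsilon(t: usize, total: usize, initial: f32, min: f32) -> f32 {
    let frac = (t as f32 / total as f32).min(1.0);
    min + 0.5 * (initial - min) * (1.0 + (PI * frac).cos())
}
```

With initial = 0.2 and min = 0.01, all three agree at the endpoints, but at the halfway mark the linear schedule sits at 0.1, the exponential one at about 0.045, and the cosine one at 0.105. The linear schedule only reaches its 0.01 floor at 95% of the run; the exponential one spends far more of training at low exploration.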

Checkpoint evaluation. Every 10 iterations, we save a checkpoint, add it to the opponent pool, and pit the current network against the previous checkpoint. A win rate above 55% suggests the current network is slightly better. We're looking for incremental progress, not dominance.
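One caveat is worth quantifying: evaluate_against plays 50 games, and a win rate measured over 50 games is itself noisy. If the two networks were exactly equal, the measured win rate would have a standard error of sqrt(0.5 · 0.5 / 50) ≈ 0.071. The helper below (our own, assuming draws are either excluded or counted as losses) measures how far a result sits from a coin flip:

```rust
/// How many standard errors a measured win rate sits above 50%,
/// under the null hypothesis that both players are equally strong.
fn z_score_vs_even(win_rate: f64, games: u32) -> f64 {
    // Binomial standard error sqrt(p(1-p)/n) with p = 0.5.
    let std_err = (0.25 / games as f64).sqrt();
    (win_rate - 0.5) / std_err
}
```

A 55% result over 50 games is only about 0.71 standard errors above even; the same 55% would need roughly 400 games to reach two standard errors. A single evaluation is best read as a hint, with a consistent trend across several evaluations as the stronger signal.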

What the network learns through self-play

Imitation learning taught the network to copy a heuristic. Basic REINFORCE taught it to eat food and avoid dying. Self-play teaches it something neither of those could: how to compete.

If training is working well, you might observe the network discovering the following strategies:

  • Space control. Claiming territory near the center of the board gives more room to maneuver. The heuristic knows to go for food, but it doesn’t know that position matters when food isn’t nearby.
  • Opponent tracking. The network learns to check where the enemy head is — not only where its body is. A nearby enemy head means a contested food or a head-to-head collision risk.
  • Patience. Sometimes the best move is to wait. If the opponent is between you and food, burning health to race them is worse than waiting for a better opportunity.
  • Endgame awareness. In the late game, when both snakes are long and the board is mostly body segments, the network learns to value space more than food.

None of these are explicitly in the reward function. They emerge from the pressure of playing against an opponent that’s also learning. The strategies are discovered, not designed. If you don’t see any of these behaviors after a reasonable number of iterations, the opponent pool may need adjustment — too-strong opponents can make learning impossible, while too-weak opponents don’t provide useful pressure.

The variance problem (and how to think about it)

REINFORCE is notoriously high-variance. One episode might produce a total return of +15 (the snake ate everything, won the game). The next might produce -5 (the snake died on turn 3). The gradient flips direction between episodes. Training is noisy.

Return normalization helps. So does averaging over many episodes per iteration. But the fundamental problem remains: in BattleSnake, one bad move can kill you, and the network has to figure out which move out of 20 was the bad one.
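To make the credit-assignment problem concrete: with discounted returns, a single death penalty at the end of an episode is spread back over every earlier move, shrinking by a factor of γ per step. The sketch below reimplements the backward accumulation that Part 6's compute_returns performs, but over bare rewards; it is our illustration, not that function's actual signature.

```rust
/// Discounted return G_t = r_t + gamma * G_{t+1}, computed backward
/// over a single episode's rewards.
fn discounted_returns(rewards: &[f32], gamma: f32) -> Vec<f32> {
    let mut returns = vec![0.0; rewards.len()];
    let mut running = 0.0;
    for t in (0..rewards.len()).rev() {
        running = rewards[t] + gamma * running;
        returns[t] = running;
    }
    returns
}
```

For a 20-move game whose only reward is -1 on the final move, with γ = 0.99, the first move gets G_0 = -0.99^19 ≈ -0.83 and the last gets -1. Every move receives nearly the same blame, so the gradient alone cannot single out the one move that caused the death.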

There are more sophisticated algorithms that address this — PPO (Proximal Policy Optimization) and A2C (Advantage Actor-Critic) are the standard ones. They use a baseline to reduce variance: instead of asking “was this action good?” they ask “was this action better than expected?” The “expected” part is a value function — a second network that predicts how good a state is.
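The "better than expected" question reduces to a subtraction. Given the return G_t and a value estimate V(s_t) for the same state, the advantage is A_t = G_t - V(s_t), and it replaces the raw return as the weight on log π(a_t|s_t). In this sketch the value estimates are plain numbers passed in; in A2C or PPO they would come from the second network, which this tutorial doesn't build.

```rust
/// Advantage A_t = G_t - V(s_t): positive when the outcome beat the
/// value estimate for that state, negative when it fell short.
fn advantages(returns: &[f32], values: &[f32]) -> Vec<f32> {
    assert_eq!(returns.len(), values.len(), "one value estimate per step");
    returns.iter().zip(values).map(|(g, v)| g - v).collect()
}
```

A +1 return from a state already valued at 0.8 earns only +0.2 of credit, while a -1 return from a state valued at -1.5 (nearly hopeless) earns +0.5: the move did better than expected, even though the episode was lost.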

For this tutorial, REINFORCE is the right starting point. It’s simple, it works, and it teaches the core idea (policy gradient) without the overhead of a value function. If you want to go further, swapping REINFORCE for PPO is the natural next step.

A note on training time

Self-play is slow. Each episode runs a full game (50–200 turns). Each turn requires a forward pass through the network for each player. A training run of 1000 iterations with 10 episodes per iteration is 10,000 games. On CPU, that might take hours. On GPU, it’s faster, but Candle’s GPU support varies by platform.
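The arithmetic behind that estimate fits in a back-of-the-envelope helper (a hypothetical function, not part of snake-ml). It counts only the per-turn action-selection passes described above, one per player per turn, and ignores the batched update, so treat it as a lower bound:

```rust
/// Rough count of action-selection forward passes for a self-play run.
fn forward_passes(iterations: u64, episodes_per_iter: u64, turns_per_game: u64, players: u64) -> u64 {
    iterations * episodes_per_iter * turns_per_game * players
}
```

For 1,000 iterations of 10 two-player games lasting 50 to 200 turns, that comes to between 1 million and 4 million forward passes.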

The practical approach: start small. Train for 100 iterations, check the win rate, adjust the hyperparameters, repeat. Don’t run a 10,000-iteration training job on your first try. The network’s behavior after 100 iterations will tell you whether the reward signal is working.

Part 9 takes the trained model and wires it into the web server from Part 1 — making the snake live on BattleSnake.
