Part 6 — Reinforcement Learning Basics

Imitation learning got us a network that copies the heuristic. That’s useful, but the heuristic is a ceiling. The network can never discover strategies the heuristic doesn’t know about.

Reinforcement learning breaks that ceiling. Instead of learning to copy a teacher, the snake learns from its own experience. It plays, gets rewarded for good outcomes and punished for bad ones, and gradually figures out better strategies on its own.

The RL framework

In reinforcement learning, an agent (our snake) interacts with an environment (the game). At each step:

  1. The agent observes the state (the board encoding)
  2. The agent takes an action (a direction)
  3. The environment emits a reward (a number)
  4. The environment transitions to the next state

The agent’s goal is to maximize the total reward it collects over an episode (one game).

For BattleSnake, we use the following reward signal:

Eat food: +1.0. This is the strongest positive signal. In BattleSnake, food restores health and lets the snake grow, so reaching it is what keeps the snake in the game. Getting +1 for eating tells the network “moving to this square is worth doing.”

Survive a turn: +0.01. A small positive reward for staying alive. Without it, every uneventful turn scores 0, so staying alive earns nothing by itself and the only push toward survival is the distant death penalty. With the survival reward, each extra turn alive adds directly to the return, so the network prefers to stay alive long enough to find food.

Die: -1.0. A strong negative signal. Why -1 and not -10? Because the scale matters. If food gives +1 and dying gives -10, the network tends to focus almost entirely on not dying and avoids all risk. If dying gives -1, the network is willing to take calculated risks when the food reward is high enough. The ratio between reward values shapes how risk-averse the snake is.

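Nothing else happens: 0. Beyond the survival bonus, an uneventful turn adds no immediate reward. The step isn’t wasted, though: it still gets credit or blame through the return (defined below), which includes everything that happens later in the episode.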
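As a minimal sketch, the reward table above fits in a single function. The `TurnOutcome` enum and the constant names are hypothetical (the tutorial’s `GameEnv` may represent outcomes differently), and the sketch assumes that a turn where the snake eats pays only the food reward, without the survival bonus stacked on top.

```rust
// Hypothetical sketch: names are illustrative, not from the tutorial's GameEnv.

/// What happened to our snake on one turn.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TurnOutcome {
    Died,
    AteFood,
    Survived, // stayed alive without eating
}

pub const FOOD_REWARD: f32 = 1.0;
pub const SURVIVAL_REWARD: f32 = 0.01;
pub const DEATH_REWARD: f32 = -1.0;

/// Map one turn's outcome to a scalar reward.
/// Assumption: eating pays FOOD_REWARD only (no survival bonus added on top).
pub fn turn_reward(outcome: TurnOutcome) -> f32 {
    match outcome {
        TurnOutcome::Died => DEATH_REWARD,
        TurnOutcome::AteFood => FOOD_REWARD,
        TurnOutcome::Survived => SURVIVAL_REWARD,
    }
}
```

Keeping the values as named constants makes it easy to experiment with the food-to-death ratio that controls how risk-averse the snake becomes.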

The policy

In reinforcement learning, the function that maps states to action probabilities is called a policy. We’ve been building one since Part 3: our network’s forward pass followed by softmax is already a policy. We just didn’t call it that.

The return G_t for step t is the discounted sum of rewards from step t to the end of the episode, which ends n steps later:

G_t = r_t + γ * r_{t+1} + γ² * r_{t+2} + ... + γⁿ * r_{t+n}

γ (gamma) is the discount factor, a number between 0 and 1; 0.99 is a common choice and is the value we use here. Here is the intuition: a reward you get now is more valuable than a reward you might get later, because the snake might die before reaching it. We multiply future rewards by γ each step to account for that uncertainty.

Each step further into the future gets multiplied by γ one more time:

  • After 10 steps: the reward is worth 0.99¹⁰ ≈ 0.9 times its face value
  • After 100 steps: 0.99¹⁰⁰ ≈ 0.37 times its face value

Rewards 100 turns from now matter about a third as much as rewards right now. This gives the snake a reason to prefer food that’s nearby over food that’s far — even when both are reachable.
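To check the arithmetic by hand, here is a small sketch that computes G_t straight from the definition. `discounted_return` is a hypothetical helper, not part of the tutorial’s code; it assumes `rewards` holds one episode’s per-step rewards and that `t` is at most `rewards.len()`.

```rust
/// Hypothetical helper: G_t computed directly from the definition,
/// G_t = r_t + γ·r_{t+1} + γ²·r_{t+2} + … up to the episode's last reward.
/// Panics if t > rewards.len().
pub fn discounted_return(rewards: &[f32], gamma: f32, t: usize) -> f32 {
    let mut g = 0.0_f32;
    let mut weight = 1.0_f32; // γ^k for the k-th reward after step t
    for &r in &rewards[t..] {
        g += weight * r;
        weight *= gamma;
    }
    g
}
```

For an episode with rewards [0.01, 1.0, -1.0] (survive, eat, die) and γ = 0.99, this gives G_2 = -1.0, G_1 = 1.0 - 0.99 = 0.01, and G_0 = 0.01 + 0.99 - 0.9801 = 0.0199. Summing forward like this costs O(n) per step; the `compute_returns` function in the training loop gets every G_t in a single backward pass using G_t = r_t + γ · G_{t+1}.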

Policy gradient: REINFORCE

How do we adjust the weights to make good actions more likely?

The REINFORCE algorithm says: for each action the agent took, nudge the weights in the direction that would make that action more likely, proportional to how good the outcome was.

Before we look at the equation, here is what each symbol means — in terms you’ve already seen in this tutorial:

  • θ (theta): the network’s weights. Where you’ve seen it: the same weights we’ve been updating since Part 5.
  • α (alpha): the learning rate. Where you’ve seen it: the lr = 1e-3 from Part 5.
  • π(a_t | s_t; θ): the probability the network assigns to action a_t in state s_t, given weights θ. Where you’ve seen it: the softmax output from Part 3, same thing.
  • ∇_θ: “the gradient with respect to the weights,” the direction to nudge the weights. Where you’ve seen it: the same backpropagation from Part 5.
  • G_t: the return from step t onward, i.e. how good the episode turned out from that point. Where you’ve seen it: the discounted sum defined above.

Now the update rule:

θ ← θ + α * ∇_θ log π(a_t | s_t; θ) * G_t

Intuition: If G_t is positive (the episode went well), we want to make the action we took more likely. The gradient ∇_θ log π tells us which direction to nudge the weights to increase the probability of that action. Multiplying by G_t scales the nudge — a big positive return means a big nudge; a negative return flips the direction and makes that action less likely next time.

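The log comes from the log-derivative trick: ∇_θ π = π · ∇_θ log π. In REINFORCE, actions are drawn from the policy itself, so frequently chosen actions show up in the data more often; dividing by π (which is what the log does) corrects for that, so an action isn’t reinforced just because it’s common. Per sample, ∇_θ log π points in the same direction as ∇_θ π, only rescaled by 1/π. As a practical bonus, log_softmax is more numerically stable than computing the softmax and then taking its log.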
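To make the update rule concrete, here is a sketch that drops the network entirely and treats the four logits for a single state as the weights θ. For a softmax policy, the gradient of log π(a) with respect to logit z_k is 1[k = a] - π_k, so a positive G_t raises the chosen action’s logit and lowers the others. All three functions are hypothetical pure-Rust helpers for illustration, not Candle APIs.

```rust
// Hypothetical pure-Rust helpers (not Candle APIs) for a single state.

/// Softmax over logits, subtracting the max first for numerical stability.
pub fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&z| (z - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// ∂ log π(a) / ∂ z_k for a softmax policy: 1[k == a] - π_k.
pub fn grad_log_softmax(logits: &[f32], action: usize) -> Vec<f32> {
    softmax(logits)
        .iter()
        .enumerate()
        .map(|(k, &p)| if k == action { 1.0 - p } else { -p })
        .collect()
}

/// One REINFORCE update applied directly to the logits (treated as θ):
/// z ← z + α · ∇ log π(a) · G
pub fn reinforce_update(logits: &mut [f32], action: usize, ret: f32, lr: f32) {
    let grad = grad_log_softmax(logits, action);
    for (z, g) in logits.iter_mut().zip(grad) {
        *z += lr * g * ret;
    }
}
```

With all-zero logits each action has probability 0.25, so for action 2 the gradient is [-0.25, -0.25, 0.75, -0.25]. A positive return pushes π(2) above 0.25; a negative return pushes it below. In the real network, backpropagation carries this same per-logit gradient back into every weight.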

In Candle, we compute this by gathering the log probability of each action the snake took, multiplying it by that step’s return, and minimizing the negated mean:

use candle_core::{Device, Result, Tensor};
// Module is imported in case SnakeNet's forward() comes from the Module trait.
use candle_nn::{AdamW, Module, Optimizer};

/// REINFORCE training step: one gradient update from a batch of steps
pub fn reinforce_step(
    net: &SnakeNet,
    opt: &mut AdamW,
    states: &Tensor,       // (batch, 968)
    actions: &[u32],       // index of the action taken at each step
    returns: &[f32],       // G_t for each step
    device: &Device,
) -> Result<f32> {
    // Forward pass → log probabilities over the 4 directions
    let logits = net.forward(states)?;
    let log_probs = candle_nn::ops::log_softmax(&logits, 1)?;

    // log_probs shape: (batch, 4), actions: (batch,)
    let actions_t = Tensor::from_vec(actions.to_vec(), (actions.len(),), device)?;
    let returns_t = Tensor::from_vec(returns.to_vec(), (returns.len(),), device)?;

    // Gather log π(a_t | s_t) for the action actually taken at each step:
    // the (batch, 1) index picks one column per row, then squeeze drops that dim.
    let chosen_log_probs = log_probs.gather(&actions_t.unsqueeze(1)?, 1)?
        .squeeze(1)?;

    // Loss = -mean(log_prob * return). The optimizer minimizes, so we negate
    // the product to maximize log π(a) * G. Without the negation, the optimizer
    // would *decrease* the probability of good actions, which is backwards.
    let loss = (chosen_log_probs * returns_t)?.neg()?.mean(0)?;

    opt.backward_step(&loss)?;

    loss.to_scalar::<f32>()
}

The negative sign is important. We want to maximize log π(a) * G. Since the optimizer minimizes loss, we minimize -log π(a) * G, which is equivalent to maximizing log π(a) * G.

Variance reduction with a baseline. Pure REINFORCE gradients can have high variance: a few lucky or unlucky episodes can dominate the update. A simple fix is to subtract a baseline (often a moving average of recent episode returns) from the return before taking the gradient. This doesn’t change the expected gradient, because the expected value of ∇_θ log π(a_t | s_t; θ) under the policy is zero, so any baseline that doesn’t depend on the action contributes nothing on average. It does reduce variance, often substantially. Implementation: keep a running mean V of past total_reward values, and use G - V instead of G directly. This isn’t required for a working tutorial, but it can meaningfully speed up convergence on BattleSnake boards.
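Here is a sketch of that baseline, using an exponential moving average as the running mean (one common choice; a plain arithmetic mean over all past episodes would also be valid). `RunningBaseline` and its methods are hypothetical, and the decay value is an assumption you would tune.

```rust
/// Hypothetical helper: an exponential moving average of episode
/// total_reward values, used as the baseline V in G - V.
pub struct RunningBaseline {
    value: f32,
    decay: f32, // fraction of the old average kept per update, e.g. 0.99
    initialized: bool,
}

impl RunningBaseline {
    pub fn new(decay: f32) -> Self {
        Self { value: 0.0, decay, initialized: false }
    }

    /// Current baseline V (0.0 before any episode has been recorded).
    pub fn value(&self) -> f32 {
        self.value
    }

    /// Fold one finished episode's total_reward into the average.
    pub fn update(&mut self, total_reward: f32) {
        if self.initialized {
            self.value = self.decay * self.value + (1.0 - self.decay) * total_reward;
        } else {
            self.value = total_reward; // first episode: start from its score
            self.initialized = true;
        }
    }

    /// Subtract V from every return before the REINFORCE step.
    pub fn apply(&self, returns: &[f32]) -> Vec<f32> {
        returns.iter().map(|&g| g - self.value).collect()
    }
}
```

In a training loop, you would call `apply` on the episode’s returns before passing them to `reinforce_step`, then call `update` with that episode’s `total_reward`, so each episode is compared against the episodes that came before it.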

The RL training loop

Here’s how it fits together — one episode (one game) at a time:

use candle_core::{Device, Result, Tensor};
use candle_nn::AdamW;

/// One step during a game
#[derive(Clone)]
struct Experience {
    state: Vec<f32>,       // flattened board encoding (968 values)
    next_state: Vec<f32>,  // encoding after the action (empty here; used by the replay buffer in Part 9)
    action: u32,           // direction index (0–3)
    reward: f32,           // reward received
}

pub fn train_episode(
    net: &SnakeNet,
    opt: &mut AdamW,    // passed in so Adam momentum persists across episodes
    env: &mut GameEnv,  // mutable: reset() and step() advance the simulator's state
) -> Result<f32> {
    let device = Device::Cpu;
    let gamma = 0.99_f32;

    let mut experiences: Vec<Experience> = Vec::new();
    let mut state = env.reset();
    let mut total_reward = 0.0_f32; // episode score, handy for logging or a baseline

    loop {
        // Encode the board state as a (1, 968) batch of one
        let tensor = encode_board(&state.board, &state.my_snake)?;
        let flat = tensor.flatten_all()?.unsqueeze(0)?;

        // Choose an action (ε-greedy with ε = 0.1: explore sometimes)
        let action: u32 = if rand::random::<f32>() < 0.1 {
            rand::random::<u32>() % 4  // explore: uniformly random direction
        } else {
            net.pick_direction(&flat)?.index() as u32  // Direction → index (Up=0, Down=1, Left=2, Right=3)
        };

        // Take the action, get reward and next state
        let (next_state, reward) = env.step(action);
        total_reward += reward;

        let state_vec = flat
            .squeeze(0)?  // (1, 968) → (968,)
            .to_vec1::<f32>()?;

        experiences.push(Experience {
            state: state_vec,
            // Basic REINFORCE doesn't use next_state in the update, so it stays
            // empty here. The field exists so the same Experience struct works
            // with the replay buffer in Part 9; Part 7 shows an encoding helper
            // (encode_board_and_flat) that makes filling it in easier.
            next_state: vec![],
            action,
            reward,
        });

        if env.is_done() {
            break;
        }
        state = next_state;
    }

    // Compute returns (G_t) using discounted rewards
    let returns = compute_returns(&experiences, gamma);

    // Stack every step's encoding into one (steps, 968) tensor
    let states_t = Tensor::from_vec(
        experiences.iter().flat_map(|e| e.state.iter().copied()).collect::<Vec<f32>>(),
        (experiences.len(), 968),
        &device,
    )?;
    let actions: Vec<u32> = experiences.iter().map(|e| e.action).collect();

    // REINFORCE update. The optimizer is passed in from the training loop
    // so Adam's momentum buffers persist across episodes.
    reinforce_step(net, opt, &states_t, &actions, &returns, &device)
}

/// G_t for every step in one backward pass: G_t = r_t + γ * G_{t+1}
fn compute_returns(experiences: &[Experience], gamma: f32) -> Vec<f32> {
    let mut returns = Vec::with_capacity(experiences.len());
    let mut g = 0.0_f32;

    for exp in experiences.iter().rev() {
        g = exp.reward + gamma * g;
        returns.push(g);
    }

    returns.reverse();
    returns
}

The GameEnv is a simplified game simulator — it holds the board state, advances the game when step() is called, and returns the reward. We don’t need the full BattleSnake server for training; we can simulate games locally.
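As a sketch of the contract `train_episode` relies on, here is a hypothetical `Environment` trait with a deliberately trivial implementation. Neither the trait nor `FixedLengthEnv` comes from the tutorial’s code; the method names simply mirror the calls `train_episode` makes (`reset`, `step`, `is_done`). Note that `reset` and `step` take `&mut self`, since advancing the game changes the environment’s internal state.

```rust
/// Hypothetical trait: the interface train_episode uses on GameEnv.
pub trait Environment {
    type State;
    /// Start a new game and return the initial state.
    fn reset(&mut self) -> Self::State;
    /// Apply an action (0–3), advance one turn, return (next_state, reward).
    fn step(&mut self, action: u32) -> (Self::State, f32);
    /// True once the game has ended.
    fn is_done(&self) -> bool;
}

/// Hypothetical toy stand-in: the game lasts `length` turns, each surviving
/// turn pays the +0.01 survival reward, and the final turn is a death (-1.0).
pub struct FixedLengthEnv {
    pub length: u32,
    turn: u32,
}

impl FixedLengthEnv {
    pub fn new(length: u32) -> Self {
        Self { length, turn: 0 }
    }
}

impl Environment for FixedLengthEnv {
    type State = u32; // the "state" is just the turn counter

    fn reset(&mut self) -> u32 {
        self.turn = 0;
        0
    }

    fn step(&mut self, _action: u32) -> (u32, f32) {
        self.turn += 1;
        let reward = if self.turn >= self.length { -1.0 } else { 0.01 };
        (self.turn, reward)
    }

    fn is_done(&self) -> bool {
        self.turn >= self.length
    }
}
```

A stub like this lets you check the loop plumbing (episodes terminate, one reward per step, the last reward is the death penalty) before wiring in a real board simulator.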

What the network learns

With REINFORCE, the network starts to discover patterns that go beyond the heuristic. If training is working well, you might observe the network doing the following:

  • It might learn to wait: looping in a safe area when food is far or contested, rather than racing toward it (a BattleSnake has to move every turn, so “waiting” means circling, not standing still)
  • It might learn to trap opponents — positioning itself to force the opponent into its own body
  • It might learn risk assessment — accepting a small chance of death for a high-reward path

None of these strategies are explicitly programmed. They emerge from the reward signal. If you don’t see any of these behaviors after reasonable training time, the reward signal or training time may need adjustment.

The exploration-exploitation tradeoff

One subtlety: if the network always picks the highest-probability action, it never tries anything new. ε-greedy exploration fixes this: with probability ε, pick a random action. Typical values are 0.1 (10% random) early in training, decreasing to 0.01 (1%) later.
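The ε schedule described above can be sketched as a linear decay. The 0.1 and 0.01 endpoints come from the text; the number of episodes to decay over is an assumption you would tune. Because the REINFORCE network already outputs a probability for each direction, a common alternative to ε-greedy is to sample the action from those softmax probabilities instead of taking the argmax: the policy then explores on its own, and the log π(a_t | s_t; θ) used in the update matches how the action was actually chosen. `sample_action` sketches that. Both functions are hypothetical, and `sample_action` takes the uniform random number `u` as a parameter (for example from `rand::random::<f32>()`) so the function itself stays deterministic.

```rust
/// Hypothetical helper: linearly decay ε from `start` to `end` over
/// `decay_episodes` episodes, then hold it at `end`.
pub fn epsilon_at(episode: usize, start: f32, end: f32, decay_episodes: usize) -> f32 {
    if episode >= decay_episodes {
        return end;
    }
    let frac = episode as f32 / decay_episodes as f32;
    start + (end - start) * frac
}

/// Hypothetical helper: pick an action index from a probability vector
/// (e.g. the softmax output) given a uniform sample u in [0, 1).
pub fn sample_action(probs: &[f32], u: f32) -> usize {
    let mut cumulative = 0.0_f32;
    for (i, &p) in probs.iter().enumerate() {
        cumulative += p;
        if u < cumulative {
            return i;
        }
    }
    // Guard against rounding leaving the cumulative sum slightly below 1.0
    probs.len() - 1
}
```

With `epsilon_at(episode, 0.1, 0.01, 1000)`, ε starts at 0.1, reaches 0.055 halfway through, and stays at 0.01 from episode 1000 on.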

Part 7 combines REINFORCE with self-play, where the snake gets stronger opponents to train against as it improves.
