Part 6 — Reinforcement Learning Basics

Imitation learning got us a network that copies the heuristic. That’s useful, but the heuristic is a ceiling. The network can never discover strategies the heuristic doesn’t know about.

Reinforcement learning breaks that ceiling. Instead of learning to copy a teacher, the snake learns from its own experience. It plays, gets rewarded for good outcomes and punished for bad ones, and gradually figures out better strategies on its own.

The RL framework

In reinforcement learning, an agent (our snake) interacts with an environment (the game). At each step:

  1. The agent observes the state (the board encoding)
  2. The agent takes an action (a direction)
  3. The environment emits a reward (a number)
  4. The environment transitions to the next state

The agent’s goal is to maximize the total reward it collects over an episode (one game).

For BattleSnake, the reward signal looks like this:

Event                          Reward
Eat food                       +1.0
Survive a turn (not starve)    +0.01
Die                            -1.0
Nothing happens                0

Eating food is the strongest positive signal, since food restores health and grows the snake. Surviving a turn earns a small positive reward so the network doesn't learn to kamikaze toward food, and dying gets a strong negative reward so the network avoids it.
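The table maps directly onto a small reward function. This is a minimal sketch that assumes a hypothetical `Event` enum, reported by the simulator after each move. The enum and its variant names are illustrative and do not come from the earlier parts:

```rust
/// Hypothetical outcome of one move, as a simulator might report it.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Event {
    AteFood,
    Survived,
    Died,
    Nothing,
}

/// Reward values taken from the table above.
fn reward(event: Event) -> f32 {
    match event {
        Event::AteFood => 1.0,
        Event::Survived => 0.01,
        Event::Died => -1.0,
        Event::Nothing => 0.0,
    }
}

fn main() {
    // A short episode: a quiet turn, an uneventful turn, a meal, then a collision.
    let episode = [Event::Survived, Event::Nothing, Event::AteFood, Event::Died];
    let total: f32 = episode.iter().map(|&e| reward(e)).sum();
    println!("total reward = {total:.2}"); // 0.01 + 0.0 + 1.0 - 1.0
}
```

Keeping the values in one function makes it easy to tune the shaping (for example, the size of the survival bonus) without touching the training loop.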

The policy

The policy is what the network represents: π(state) → probabilities over actions. In our case, that is the network's forward pass followed by a softmax over the four directions.

The return for a step is the discounted sum of rewards from that step to the end of the episode:

G_t = r_t + γ * r_{t+1} + γ² * r_{t+2} + ... + γⁿ * r_{t+n}

where γ (gamma) is the discount factor, usually 0.99. Rewards further in the future are worth less than immediate rewards — a reward 10 steps away is scaled by 0.99¹⁰ ≈ 0.90, almost full weight, but a reward 100 steps away is scaled by 0.99¹⁰⁰ ≈ 0.37, less than half. The agent cares about what happens soon almost as much as what happens right now, but distant outcomes fade.
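To make the arithmetic concrete, here is a short sketch that evaluates the sum directly from its definition. The helper name `discounted_return` is ours, used only for illustration:

```rust
/// Discounted return for the rewards from step t to the end of the episode:
/// r_t + γ·r_{t+1} + γ²·r_{t+2} + ...
fn discounted_return(rewards: &[f32], gamma: f32) -> f32 {
    rewards
        .iter()
        .enumerate()
        .map(|(k, &r)| gamma.powi(k as i32) * r)
        .sum()
}

fn main() {
    let gamma = 0.99_f32;
    // Weight of a reward k steps away, relative to an immediate one.
    println!("0.99^10  = {:.3}", gamma.powi(10));  // ≈ 0.904
    println!("0.99^100 = {:.3}", gamma.powi(100)); // ≈ 0.366
    // Two survival ticks, then food: 0.01 + 0.99·0.01 + 0.99²·1.0
    let g = discounted_return(&[0.01, 0.01, 1.0], gamma);
    println!("G_t = {g:.4}"); // 1.0000
}
```

The food two steps away still counts for about 98% of its face value. That is why γ = 0.99 keeps short-term planning intact.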

Policy gradient: REINFORCE

How do we adjust the weights to make good actions more likely?

The REINFORCE algorithm says: for each action the agent took, nudge the weights in the direction that would make that action more likely, proportional to how good the return was.

The update rule:

θ ← θ + α * ∇_θ log π(a_t | s_t; θ) * G_t

∇_θ log π(a_t | s_t; θ) is the gradient of the log probability assigned to the action we took. G_t is the return from that point onward. If G_t is positive, we make that action more likely. If G_t is negative, we make it less likely.
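For a softmax policy, that gradient has a simple closed form. If z are the logits and π = softmax(z), then ∂ log π(a)/∂ z_j = 1[j = a] − π_j. It follows that the per-step loss −G · log π(a) has gradient −G · (onehot(a) − π) with respect to the logits. The sketch below checks this numerically with plain `f32` slices and no Burn. The function names are illustrative:

```rust
/// Numerically stable softmax over a small logit vector.
fn softmax(z: &[f32]) -> Vec<f32> {
    let max = z.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = z.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Gradient of the per-step REINFORCE loss  -G * log π(a)  with respect to
/// the logits z, using d log π(a) / d z_j = 1[j == a] - π_j.
fn reinforce_logit_grad(z: &[f32], action: usize, g: f32) -> Vec<f32> {
    let probs = softmax(z);
    probs
        .iter()
        .enumerate()
        .map(|(j, &p)| {
            let indicator = if j == action { 1.0 } else { 0.0 };
            -g * (indicator - p)
        })
        .collect()
}

fn main() {
    let logits = [0.0_f32, 0.0, 0.0, 0.0]; // uniform policy: π = 0.25 each
    let grad = reinforce_logit_grad(&logits, 2, 1.0);
    // Chosen action: -(1 - 0.25) = -0.75; others: -(0 - 0.25) = +0.25.
    // Gradient descent subtracts this, so logit 2 rises and the rest fall.
    println!("{grad:?}"); // [0.25, 0.25, -0.75, 0.25]
}
```

With a negative return, the signs flip and the chosen logit is pushed down instead. The components always sum to zero, because the probabilities must still add up to one.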

In Burn, we express the update as a loss. We take the log-softmax of the network's output, pick out the log-probability of each action that was actually taken, and weight it by that step's return:

use burn::backend::Autodiff;
use burn::optim::{GradientsParams, Optimizer};
use burn::tensor::activation::log_softmax;
use burn::tensor::{Int, Tensor, TensorData};

type Backend = Autodiff<burn::backend::Cpu>;

/// REINFORCE training step. Takes ownership of `net` and returns the
/// updated network: Burn's `Optimizer::step` consumes the module and returns
/// a new one with updated weights. The optimizer is borrowed mutably so its
/// momentum/velocity buffers persist across steps. Any optimizer built for
/// `SnakeNet`, such as the one returned by `AdamWConfig::new().init()`, fits.
pub fn reinforce_step(
    net: SnakeNet<Backend>,
    optim: &mut impl Optimizer<SnakeNet<Backend>, Backend>,
    lr: f64,
    states: &Tensor<Backend, 2>,  // (batch, 8*11*11)
    actions: &[i64],              // actions taken (class indices 0..4)
    returns: &[f32],              // G_t for each step
    device: &<Backend as burn::tensor::backend::Backend>::Device,
) -> (SnakeNet<Backend>, f32) {
    // Forward pass -> log probabilities. log_softmax is more numerically
    // stable than softmax(..).log(). forward takes the tensor by value;
    // cloning a Burn tensor only copies a handle.
    let logits = net.forward(states.clone());
    let log_probs: Tensor<Backend, 2> = log_softmax(logits, 1);

    // log_probs shape: (batch, 4), actions: (batch,)
    let n = actions.len();
    // gather needs Int indices with the same rank as log_probs: (batch, 1)
    let actions_t: Tensor<Backend, 2, Int> =
        Tensor::<Backend, 1, Int>::from_data(TensorData::new(actions.to_vec(), [n]), device)
            .reshape([n, 1]);
    let returns_t: Tensor<Backend, 1> =
        Tensor::from_data(TensorData::new(returns.to_vec(), [n]), device);

    // Pick log π(a_t | s_t) for the action actually taken in each row,
    // then flatten (batch, 1) -> (batch,)
    let chosen_log_probs: Tensor<Backend, 1> = log_probs.gather(1, actions_t).reshape([n]);

    // Loss = -mean(log_prob * return): we minimize, so flip the sign
    let loss = (chosen_log_probs * returns_t).mean().neg();

    // Backward pass, then the optimizer produces the updated network
    let grads = loss.backward();
    let grads = GradientsParams::from_grads(grads, &net);
    let net = optim.step(lr, net, grads);

    let loss_val = loss.into_data().to_vec::<f32>().unwrap()[0];
    (net, loss_val)
}

The negative sign is important. We want to maximize log π(a) * G. Since the optimizer minimizes loss, we minimize -log π(a) * G, which is equivalent to maximizing log π(a) * G.

Variance reduction with a baseline. Pure REINFORCE gradients can have high variance: a few lucky or unlucky episodes can dominate the update. A simple fix is to subtract a baseline (often a moving average of recent episode returns) from the return before taking the gradient. This doesn't change the expected gradient. The baseline doesn't depend on the action, and E[∇_θ log π(a | s; θ)] = 0 under the policy, so the subtracted term averages out to zero. It does usually reduce variance substantially. Implementation: keep a running mean V of past total_reward values, and use G_t - V instead of G_t. This isn't required for a working tutorial, but it can noticeably speed up convergence on BattleSnake boards.
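The running-mean baseline described in the paragraph above fits in a few lines. This is a sketch only: the `RunningMean` name and its methods are hypothetical, and it uses the incremental form of the mean so no history has to be stored:

```rust
/// Hypothetical running-mean baseline over past episode totals.
struct RunningMean {
    mean: f32,
    count: u32,
}

impl RunningMean {
    fn new() -> Self {
        RunningMean { mean: 0.0, count: 0 }
    }

    /// Fold one episode's total_reward into the mean:
    /// V_n = V_{n-1} + (x - V_{n-1}) / n
    fn update(&mut self, total_reward: f32) {
        self.count += 1;
        self.mean += (total_reward - self.mean) / self.count as f32;
    }

    /// Advantages G_t - V, used in place of the raw returns.
    fn advantages(&self, returns: &[f32]) -> Vec<f32> {
        returns.iter().map(|&g| g - self.mean).collect()
    }
}

fn main() {
    let mut baseline = RunningMean::new();
    // Totals from three earlier episodes.
    for total in [0.5_f32, 1.5, 1.0] {
        baseline.update(total);
    }
    let adv = baseline.advantages(&[1.2, 0.8, 1.0]);
    println!("V = {:.2}, advantages = {adv:?}", baseline.mean); // V = 1.00
}
```

In the training loop, you would compute `baseline.advantages(&returns)` with V from previous episodes and pass the result to `reinforce_step` instead of `returns`. Only after that update would you call `baseline.update(total_reward)` for the episode just played.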

The RL training loop

Here’s how it fits together — one episode (one game) at a time:

use burn::backend::Autodiff;
use burn::optim::Optimizer;
use burn::tensor::{Tensor, TensorData};

type Backend = Autodiff<burn::backend::Cpu>;

/// One step during a game
#[derive(Clone)]
struct Experience {
    state: Vec<f32>,       // flattened board encoding
    next_state: Vec<f32>,  // encoding after the action (used by the replay buffer in Part 9)
    action: usize,         // direction index (0–3)
    reward: f32,           // reward received
}

pub fn train_episode(
    net: SnakeNet<Backend>,
    optim: &mut impl Optimizer<SnakeNet<Backend>, Backend>, // mut so optimizer state persists across episodes
    lr: f64,
    env: &mut GameEnv,
) -> (SnakeNet<Backend>, f32) {
    let device = <Backend as burn::tensor::backend::Backend>::Device::default();
    let gamma = 0.99_f32;
    let epsilon = 0.1_f32;

    let mut experiences: Vec<Experience> = Vec::new();
    let mut state = env.reset();
    let mut total_reward = 0.0_f32; // episode total, e.g. for logging or a baseline

    loop {
        // Encode the board state: (8, 11, 11) -> flat (968,) -> batch of one (1, 968)
        let tensor = encode_board::<Backend>(&state.board, &state.my_snake, &device);
        let flat: Tensor<Backend, 1> = tensor.reshape([968]);
        // reshape consumes its input, so clone: `flat` is still needed below
        let batched: Tensor<Backend, 2> = flat.clone().reshape([1, 968]);

        // Choose an action (ε-greedy: explore with probability ε)
        let action: usize = if rand::random::<f32>() < epsilon {
            (rand::random::<u32>() % 4) as usize  // explore: random direction
        } else {
            net.pick_direction(batched).index()  // exploit: the network's best direction
        };

        // Take the action, get reward and next state
        let (next_state, reward) = env.step(action);
        total_reward += reward;

        experiences.push(Experience {
            state: flat.into_data().to_vec::<f32>().unwrap(),
            // Basic REINFORCE doesn't use next_state. It stays empty here so the
            // same struct works with the replay buffer in Part 9. Part 7 shows a
            // cleaner helper (encode_board_and_flat) for the encoding step.
            next_state: Vec::new(),
            action,
            reward,
        });

        if env.is_done() {
            break;
        }
        state = next_state;
    }

    // Compute returns (G_t) using discounted rewards
    let returns = compute_returns(&experiences, gamma);

    // Stack the episode's states into one (steps, 968) batch
    let n = experiences.len();
    let states_t: Tensor<Backend, 2> = Tensor::from_data(
        TensorData::new(
            experiences.iter().flat_map(|e| e.state.iter().copied()).collect::<Vec<_>>(),
            [n, 8 * 11 * 11],
        ),
        &device,
    );
    let actions: Vec<i64> = experiences.iter().map(|e| e.action as i64).collect();

    // REINFORCE update. The optimizer is borrowed mutably from the training loop
    // so its state (e.g. AdamW's moment estimates) persists across episodes.
    // reinforce_step returns the updated network and the loss value.
    reinforce_step(net, optim, lr, &states_t, &actions, &returns, &device)
}

/// G_t = r_t + γ·G_{t+1}, computed backwards from the end of the episode.
fn compute_returns(experiences: &[Experience], gamma: f32) -> Vec<f32> {
    let mut returns = Vec::with_capacity(experiences.len());
    let mut g = 0.0_f32;

    for exp in experiences.iter().rev() {
        g = exp.reward + gamma * g;
        returns.push(g);
    }

    returns.reverse();
    returns
}

The GameEnv is a simplified game simulator — it holds the board state, advances the game when step() is called, and returns the reward. We don’t need the full BattleSnake server for training; we can simulate games locally.
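For experimenting without the full server, here is a minimal, std-only sketch of such a simulator. Everything in it is hypothetical: the `ToyEnv` name, the single-snake setup, the fixed starting position, and the linear congruential generator used to place food are illustrative choices. It differs from the book's `GameEnv` in that `step` returns only the reward, with the state kept inside the struct. It does apply the reward table from earlier in this part and the standard BattleSnake rules for walls, body collisions, and health (100 at start, minus 1 per turn, reset to 100 on eating):

```rust
/// Hypothetical single-snake simulator on an 11x11 board (std only).
const SIZE: i32 = 11;

struct ToyEnv {
    body: Vec<(i32, i32)>, // head first
    food: (i32, i32),
    health: i32,
    seed: u32, // state of a tiny LCG used to place food without external crates
    done: bool,
}

impl ToyEnv {
    fn new(seed: u32) -> Self {
        let mut env = ToyEnv { body: vec![(5, 5), (5, 4), (5, 3)], food: (0, 0), health: 100, seed, done: false };
        env.food = env.spawn_food();
        env
    }

    /// Pick a free cell with a linear congruential generator.
    fn spawn_food(&mut self) -> (i32, i32) {
        loop {
            self.seed = self.seed.wrapping_mul(1_103_515_245).wrapping_add(12_345);
            let cell = ((self.seed >> 16) % (SIZE * SIZE) as u32) as i32;
            let pos = (cell % SIZE, cell / SIZE);
            if !self.body.contains(&pos) {
                return pos;
            }
        }
    }

    /// Apply a move (0 = up, 1 = down, 2 = left, 3 = right; panics on > 3) and
    /// return the reward: +1.0 food, -1.0 death, +0.01 for surviving the turn.
    fn step(&mut self, action: usize) -> f32 {
        if self.done {
            return 0.0;
        }
        let (dx, dy) = [(0, 1), (0, -1), (-1, 0), (1, 0)][action];
        let (hx, hy) = self.body[0];
        let head = (hx + dx, hy + dy);
        self.health -= 1;

        let out_of_bounds = head.0 < 0 || head.0 >= SIZE || head.1 < 0 || head.1 >= SIZE;
        // The tail moves away this turn (the head can't land on food there),
        // so only the rest of the body can be hit.
        let hits_body = self.body[..self.body.len() - 1].contains(&head);
        if out_of_bounds || hits_body {
            self.done = true;
            return -1.0;
        }

        self.body.insert(0, head);
        if head == self.food {
            self.health = 100; // eating restores health and the snake grows
            self.food = self.spawn_food();
            return 1.0;
        }
        self.body.pop(); // no growth: the tail follows
        if self.health <= 0 {
            self.done = true; // starved
            return -1.0;
        }
        0.01
    }

    fn is_done(&self) -> bool {
        self.done
    }
}

fn main() {
    let mut env = ToyEnv::new(42);
    let mut total = 0.0;
    let mut turns = 0;
    // Always move up: the snake survives a few turns, then hits the top wall.
    while !env.is_done() {
        total += env.step(0);
        turns += 1;
    }
    println!("turns = {turns}, total reward = {total:.2}");
}
```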

What the network learns

With REINFORCE, the network starts to discover patterns that go beyond the heuristic:

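  • It might learn to bide its time: circling through open space when food is far away or contested, rather than racing toward it (a BattleSnake snake must move every turn, so it can't literally hold position)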
  • It might learn to trap opponents — positioning itself to force the opponent into its own body
  • It might learn risk assessment — accepting a small chance of death for a high-reward path

None of these strategies are explicitly programmed. They emerge from the reward signal.

The exploration-exploitation tradeoff

One subtlety: if the network always picks the highest-probability action, it never tries anything new. ε-greedy exploration fixes this: with probability ε, pick a random action. Typical values are 0.1 (10% random) early in training, decreasing to 0.01 (1%) later.
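A sketch of a decaying schedule follows. Only the 0.1 and 0.01 endpoints come from the text; the linear shape and the 5,000-episode horizon are assumptions. The block also shows the textbook REINFORCE alternative, which samples each action from π itself by inverse CDF over the softmax probabilities. That choice explores in proportion to the network's own uncertainty and keeps the update on-policy. ε-greedy actions, by contrast, are not drawn from π:

```rust
/// Linearly decay ε from `start` to `end` over `decay_episodes`, then hold it.
fn epsilon_at(episode: u32, start: f32, end: f32, decay_episodes: u32) -> f32 {
    let frac = (episode as f32 / decay_episodes as f32).min(1.0);
    start + (end - start) * frac
}

/// On-policy alternative: sample an action from the policy's probabilities,
/// given a uniform draw `u` in [0, 1). Passing `u` in keeps this testable.
fn sample_from_policy(probs: &[f32], u: f32) -> usize {
    let mut cumulative = 0.0;
    for (i, &p) in probs.iter().enumerate() {
        cumulative += p;
        if u < cumulative {
            return i;
        }
    }
    probs.len() - 1 // guard against rounding when u is close to 1
}

fn main() {
    for ep in [0, 2_500, 5_000, 10_000] {
        println!("episode {ep:>6}: eps = {:.3}", epsilon_at(ep, 0.1, 0.01, 5_000));
    }
    let probs = [0.1_f32, 0.6, 0.2, 0.1];
    println!("{}", sample_from_policy(&probs, 0.05)); // 0
    println!("{}", sample_from_policy(&probs, 0.50)); // 1
    println!("{}", sample_from_policy(&probs, 0.95)); // 3
}
```

In the training loop, `u` would come from `rand::random::<f32>()`, and `probs` from a softmax over the network's logits.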

Part 7 combines REINFORCE with self-play, where the snake gets stronger opponents to train against as it improves.
