Part 6 — Reinforcement Learning Basics
Imitation learning got us a network that copies the heuristic. That’s useful, but the heuristic is a ceiling. The network can never discover strategies the heuristic doesn’t know about.
Reinforcement learning breaks that ceiling. Instead of learning to copy a teacher, the snake learns from its own experience. It plays, gets rewarded for good outcomes and punished for bad ones, and gradually figures out better strategies on its own.
The RL framework
In reinforcement learning, an agent (our snake) interacts with an environment (the game). At each step:
- The agent observes the state (the board encoding)
- The agent takes an action (a direction)
- The environment emits a reward (a number)
- The environment transitions to the next state
The agent’s goal is to maximize the total reward it collects over an episode (one game).
For our BattleSnake setup, we use this reward signal:
Eat food: +1.0. This is the strongest positive signal, because food is what keeps the snake's health up and lets it grow. Getting +1 for eating tells the network “moving to this square is worth doing.”
Survive a turn: +0.01. A small positive reward for staying alive. Without it, a turn survived without eating scores 0, so survival only pays off indirectly, by pushing the -1 death penalty further into the discounted future. The bonus makes every extra turn alive worth something directly, so the network prefers to stay alive long enough to find food.
Die: -1.0. A strong negative signal. Why -1 and not -10? Because the scale matters. If food gives +1 and dying gives -10, the network becomes heavily risk-averse and avoids almost any move with a chance of death, even when food is at stake. If dying gives -1, the network is willing to take calculated risks when the food reward is high enough. The ratio between reward values shapes how risk-averse the snake is.
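Nothing happens: 0. No immediate signal. The step can still receive credit or blame later, through the discounted return defined below, if something good or bad follows it.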
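To see how the ratio shapes risk-taking, here is a deliberately simplified one-step model. It is an assumption for illustration, not how the network actually reasons: a risky move reaches food with probability 1 - p and dies with probability p, while a safe move just survives the turn. The risky move wins on expected reward only while p stays below the break-even value p* = (food - survive) / (food - death).

```rust
/// Toy one-step model (ignores everything after this turn):
/// the risky move eats with probability (1 - p) and dies with probability p.
fn risky_expected_reward(p_death: f32, food: f32, death: f32) -> f32 {
    (1.0 - p_death) * food + p_death * death
}

/// Largest death probability at which the risky move still matches the
/// safe move. Solves (1 - p) * food + p * death = survive for p.
fn break_even_death_prob(food: f32, death: f32, survive: f32) -> f32 {
    (food - survive) / (food - death)
}
```

With the rewards above (food 1.0, death -1.0, survive 0.01), p* = 0.99 / 2 = 0.495, so a food grab is worth up to roughly a coin-flip chance of dying. With death at -10.0, p* = 0.99 / 11 = 0.09, and the snake should avoid almost any risk. The real network also weighs everything that happens after this turn, so treat these numbers as intuition only.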
The policy
In reinforcement learning, the function that maps states to action probabilities is called a policy. We’ve been building one since Part 3: our network’s forward pass followed by softmax is already a policy. We just hadn’t called it that.
The return for step t is the discounted sum of rewards from that step to the end of the episode (here step t + n is the episode’s final step):
G_t = r_t + γ * r_{t+1} + γ² * r_{t+2} + ... + γⁿ * r_{t+n}
γ (gamma) is the discount factor, usually 0.99. Here is the intuition: a reward you get now is more valuable than a reward you might get later, because the snake might die before reaching it. We multiply future rewards by γ each step to account for that uncertainty.
Each step further into the future gets multiplied by γ one more time:
- After 10 steps: the reward is worth 0.99¹⁰ ≈ 0.90 times its face value
- After 100 steps: the reward is worth 0.99¹⁰⁰ ≈ 0.37 times its face value
Rewards 100 turns from now matter about a third as much as rewards right now. This gives the snake a reason to prefer food that’s nearby over food that’s far — even when both are reachable.
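The definition and the decay numbers above can be checked with a few lines of plain Rust. This is a sketch, not the training code: discounted_return follows the formula literally (the training loop later in this part uses a faster backward recursion), and the three-turn reward sequence in the note below is made up for illustration.

```rust
/// G_t straight from the definition:
/// G_t = r_t + γ·r_{t+1} + γ²·r_{t+2} + ... up to the episode's last step.
fn discounted_return(rewards: &[f32], t: usize, gamma: f32) -> f32 {
    rewards[t..]
        .iter()
        .enumerate()
        .map(|(k, &r)| gamma.powi(k as i32) * r) // k steps after t → weight γ^k
        .sum()
}

/// How much one unit of reward `steps` turns in the future is worth now.
fn discount_weight(gamma: f32, steps: i32) -> f32 {
    gamma.powi(steps)
}
```

discount_weight(0.99, 10) ≈ 0.904 and discount_weight(0.99, 100) ≈ 0.366, matching the list above. Food 5 turns away is weighted about 0.951, food 40 turns away about 0.669, which is the preference for nearby food in numbers. For a hypothetical three-turn game with rewards [0.01, 1.0, -1.0] (survive, eat, die), discounted_return(&rewards, 0, 0.99) is 0.01 + 0.99 · 1.0 - 0.9801 · 1.0 ≈ 0.0199: the food and the death nearly cancel because they are only one turn apart.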
Policy gradient: REINFORCE
How do we adjust the weights to make good actions more likely?
The REINFORCE algorithm says: for each action the agent took, nudge the weights in the direction that would make that action more likely, proportional to how good the outcome was.
Before we look at the equation, here is what each symbol means — in terms you’ve already seen in this tutorial:
| Symbol | Meaning | Where you’ve seen it |
|---|---|---|
| θ (theta) | The network’s weights | Same as the weights we’ve been updating since Part 5 |
| α (alpha) | Learning rate | Same as the lr = 1e-3 from Part 5 |
| π(a_t \| s_t; θ) | The probability the network assigns to action a_t in state s_t, given weights θ | The softmax output from Part 3; same thing |
| ∇_θ | “The gradient with respect to the weights”: the direction to nudge the weights | Same backpropagation from Part 5 |
| G_t | The return from step t onward: how well things turned out from that step to the end of the episode | The discounted sum defined above |
Now the update rule:
θ ← θ + α * ∇_θ log π(a_t | s_t; θ) * G_t
Intuition: if G_t is positive (things went well from step t onward), we want to make the action we took more likely. The gradient ∇_θ log π tells us which direction to nudge the weights to increase the probability of that action. Multiplying by G_t scales the nudge: a big positive return means a big nudge, and a negative return flips the direction, making that action less likely next time.
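Here is the update rule applied by hand to the smallest possible policy, as a sketch. Instead of a network, θ is just four logits, one per direction. That assumption makes ∇_θ log π easy to write down, because for a softmax policy it equals onehot(a) - π. softmax and reinforce_update are illustrative names, not functions from earlier parts.

```rust
/// Softmax over a small logit vector (max subtracted for stability).
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&z| (z - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// One REINFORCE step for a toy policy whose parameters θ ARE the logits.
/// Since ∇_θ log π(a) = onehot(a) - π, the rule θ ← θ + α ∇_θ log π(a) G
/// becomes θ_i ← θ_i + α * (1[i == a] - π_i) * G.
fn reinforce_update(theta: &mut [f32], action: usize, g: f32, alpha: f32) {
    let probs = softmax(theta);
    for (i, (t, p)) in theta.iter_mut().zip(probs.iter()).enumerate() {
        let indicator = if i == action { 1.0 } else { 0.0 };
        *t += alpha * (indicator - p) * g;
    }
}
```

Starting from equal logits (π = 0.25 for every direction), one update on direction 2 with α = 0.5 and G = +1 raises π(2) to about 0.35; the same update with G = -1 lowers it to about 0.17. The other three directions move the opposite way, so the probabilities still sum to 1.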
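Why the log? It falls out of the math of the policy gradient itself (the “log-derivative trick”: ∇_θ π = π · ∇_θ log π), which is what lets us estimate the gradient from the actions we actually sampled. For a single action, ∇_θ log π = ∇_θ π / π points in the same direction as ∇_θ π, so it still means “make this action more likely”. As a bonus, log probabilities are more numerically stable than raw probabilities, which can underflow toward zero.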
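A quick illustration of the stability point, as a sketch with made-up logits: computing probabilities first and then taking ln underflows for an unlikely action, while the log-sum-exp form stays finite. This is one reason to call a dedicated log_softmax (as the Candle code below does) rather than taking the log of a softmax.

```rust
/// Naive version: compute probabilities first, then take the log.
fn log_softmax_naive(logits: &[f32]) -> Vec<f32> {
    let exps: Vec<f32> = logits.iter().map(|z| z.exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| (e / sum).ln()).collect()
}

/// Stable version (log-sum-exp trick):
/// log π_i = z_i - max - ln Σ_j exp(z_j - max).
fn log_softmax_stable(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let log_sum = logits.iter().map(|z| (z - max).exp()).sum::<f32>().ln();
    logits.iter().map(|z| z - max - log_sum).collect()
}
```

With logits [0.0, -200.0, 0.0, 0.0] in f32, exp(-200) rounds to zero, so the naive version returns -inf for action 1, and any gradient through it is useless. The stable version returns about -201.1 (that is, -200 - ln 3).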
In Candle, we implement this as a loss: take the log probabilities from the forward pass, pick out the one for each action actually taken, weight it by that step’s return, and let backpropagation compute ∇_θ:
use candle_core::{Device, Result, Tensor};
use candle_nn::{AdamW, Optimizer}; // Optimizer provides backward_step
// SnakeNet is the network built in Part 3.

/// REINFORCE training step: one gradient update from a batch of
/// (state, action, return) triples.
pub fn reinforce_step(
    net: &SnakeNet,
    opt: &mut AdamW,
    states: &Tensor, // (batch, 968)
    actions: &[u32], // actions taken
    returns: &[f32], // G_t for each step
    device: &Device,
) -> Result<f32> {
    // Forward pass → log probabilities over the 4 directions
    let logits = net.forward(states)?;
    let log_probs = candle_nn::ops::log_softmax(&logits, 1)?;

    // log_probs shape: (batch, 4), actions: (batch,)
    let actions_t = Tensor::from_vec(actions.to_vec(), (actions.len(),), device)?;
    let returns_t = Tensor::from_vec(returns.to_vec(), (returns.len(),), device)?;

    // Gather the log prob of the action actually taken at each step → (batch,)
    let chosen_log_probs = log_probs
        .gather(&actions_t.unsqueeze(1)?, 1)?
        .squeeze(1)?;

    // Loss = -mean(log_prob * return). We minimize, so negate the product
    // to maximize log π(a) * G. Without the negation, the optimizer would
    // *decrease* the probability of good actions, which is backwards.
    let loss = ((chosen_log_probs * returns_t)? * -1.0)?.mean(0)?;
    opt.backward_step(&loss)?;
    Ok(loss.to_scalar::<f32>()?)
}
The negative sign is important. We want to maximize log π(a) * G. Since the optimizer minimizes loss, we minimize -log π(a) * G, which is equivalent to maximizing log π(a) * G.
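To make the tensor operations concrete, here is the same loss written with plain loops. It is a sketch for intuition only (it computes the loss value, not gradients): indexing lp[a] plays the role of gather, and the final negation is the sign flip just described. reinforce_loss is an illustrative name, not part of Candle.

```rust
/// Plain-Rust equivalent of the Candle loss:
/// loss = -mean_i( log_probs[i][actions[i]] * returns[i] ).
fn reinforce_loss(log_probs: &[[f32; 4]], actions: &[usize], returns: &[f32]) -> f32 {
    assert!(log_probs.len() == actions.len() && actions.len() == returns.len());
    let n = returns.len() as f32;
    let total: f32 = log_probs
        .iter()
        .zip(actions)
        .zip(returns)
        .map(|((lp, &a), &g)| lp[a] * g) // "gather": the taken action's log prob, times G
        .sum();
    -total / n // negate: minimizing this maximizes mean(log π(a) * G)
}
```

For a made-up batch where the chosen actions have log probabilities -2.0 and -0.5 and returns 2.0 and -1.0, the loss is -((-2.0)(2.0) + (-0.5)(-1.0)) / 2 = 1.75.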
Variance reduction with a baseline. Pure REINFORCE gradients can be high-variance: a few lucky or unlucky rolls can dominate the update. A simple fix is to subtract a baseline (often a moving average of recent episode returns) from the return before taking the gradient. Because the baseline doesn’t depend on the action taken, subtracting it leaves the expected gradient unchanged (under the policy, the expected value of ∇_θ log π(a | s) is zero), but it can reduce variance substantially. Implementation: keep a running mean V of past total_reward values, and use G - V instead of G directly. This isn’t required for a working tutorial, but it can noticeably speed up convergence on BattleSnake boards.
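A minimal sketch of that baseline, assuming a plain cumulative mean (an exponential moving average is a common alternative that follows a changing policy more closely). RunningBaseline is a hypothetical helper, not part of the code in this part or of Candle.

```rust
/// Running mean V of past episode scores, used as a REINFORCE baseline.
struct RunningBaseline {
    mean: f32,
    count: u32,
}

impl RunningBaseline {
    fn new() -> Self {
        Self { mean: 0.0, count: 0 }
    }

    /// Subtract the current baseline from each return: G_t - V.
    fn advantages(&self, returns: &[f32]) -> Vec<f32> {
        returns.iter().map(|g| g - self.mean).collect()
    }

    /// Fold a finished episode's total reward into V (incremental mean).
    fn update(&mut self, total_reward: f32) {
        self.count += 1;
        self.mean += (total_reward - self.mean) / self.count as f32;
    }
}
```

One way to wire it in: compute baseline.advantages(&returns) using the V from previous episodes, pass those values to reinforce_step in place of the raw returns, and only then call baseline.update(total_reward). Updating afterwards keeps the current episode’s own outcome out of its baseline, so the baseline stays independent of the actions being scored.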
The RL training loop
Here’s how it fits together — one episode (one game) at a time:
use candle_core::{Device, Result, Tensor};
use candle_nn::AdamW;
// SnakeNet, GameEnv and encode_board come from earlier parts;
// reinforce_step is the function defined above.

/// One step during a game
#[derive(Clone)]
struct Experience {
    state: Vec<f32>,      // flattened board encoding
    next_state: Vec<f32>, // encoding after the action (empty here; needed by the replay buffer in Part 9)
    action: u32,          // direction index (0–3)
    reward: f32,          // reward received
}

pub fn train_episode(
    net: &SnakeNet,
    opt: &mut AdamW,   // passed in so Adam momentum persists across episodes
    env: &mut GameEnv, // &mut because reset() and step() advance the simulator
) -> Result<f32> {
    let device = Device::Cpu;
    let gamma = 0.99_f32;
    let mut experiences: Vec<Experience> = Vec::new();
    let mut state = env.reset();
    // Undiscounted episode score: handy for logging, and the quantity a
    // running-mean baseline would track.
    let mut total_reward = 0.0_f32;

    loop {
        // Encode the board state and flatten it to (1, 968)
        let tensor = encode_board(&state.board, &state.my_snake)?;
        let flat = tensor.flatten_all()?.unsqueeze(0)?;

        // Pick an action (ε-greedy with ε = 0.1: explore sometimes)
        let action: u32 = if rand::random::<f32>() < 0.1 {
            rand::random::<u32>() % 4 // explore: uniformly random direction
        } else {
            // Direction → index (Up=0, Down=1, Left=2, Right=3)
            net.pick_direction(&flat)?.index() as u32
        };

        // Take the action, get reward and next state
        let (next_state, reward) = env.step(action);
        total_reward += reward;

        let state_vec = flat
            .squeeze(0)? // (1, 968) → (968,)
            .to_vec1::<f32>()?;
        experiences.push(Experience {
            state: state_vec,
            // Basic REINFORCE never reads next_state, so it is left empty here.
            // The field exists so the same Experience struct works with the
            // replay buffer in Part 9. Part 7 shows a cleaner encoding helper
            // (encode_board_and_flat) that produces (flat, ()) in one call.
            next_state: vec![],
            action,
            reward,
        });

        if env.is_done() {
            break;
        }
        state = next_state;
    }

    // Compute returns (G_t) using discounted rewards
    let returns = compute_returns(&experiences, gamma);

    // Stack the recorded states into one (steps, 968) batch
    let states_t = Tensor::from_vec(
        experiences
            .iter()
            .flat_map(|e| e.state.iter().copied())
            .collect::<Vec<f32>>(),
        (experiences.len(), 968),
        &device,
    )?;
    let actions: Vec<u32> = experiences.iter().map(|e| e.action).collect();

    // REINFORCE update: the optimizer is passed in from the training loop
    // so Adam's momentum buffers persist across episodes.
    let loss = reinforce_step(net, opt, &states_t, &actions, &returns, &device)?;
    Ok(loss)
}

/// Walk the episode backwards: G_t = r_t + γ * G_{t+1}, with G = 0 after the last step.
fn compute_returns(experiences: &[Experience], gamma: f32) -> Vec<f32> {
    let mut returns = Vec::with_capacity(experiences.len());
    let mut g = 0.0_f32;
    for exp in experiences.iter().rev() {
        g = exp.reward + gamma * g;
        returns.push(g);
    }
    returns.reverse();
    returns
}
The GameEnv is a simplified game simulator — it holds the board state, advances the game when step() is called, and returns the reward. We don’t need the full BattleSnake server for training; we can simulate games locally.
What the network learns
With REINFORCE, the network starts to discover patterns that go beyond the heuristic. If training is working well, you might observe the network doing the following:
- It might learn to wait, circling in a safe area when food is far instead of racing toward it (a BattleSnake must move every turn, so it cannot literally hold position)
- It might learn to trap opponents, positioning itself to cut off an opponent’s escape routes so the opponent is forced to collide with a body or a wall
- It might learn risk assessment — accepting a small chance of death for a high-reward path
None of these strategies are explicitly programmed. They emerge from the reward signal. If you don’t see any of these behaviors after reasonable training time, the reward signal or training time may need adjustment.
The exploration-exploitation tradeoff
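One subtlety: if the network always picks the highest-probability action, it never tries anything new. ε-greedy exploration fixes this: with probability ε, pick a random action. Typical values are 0.1 (10% random) early in training, decreasing to 0.01 (1%) later. Strictly speaking, REINFORCE’s update assumes the actions were drawn from the policy itself, so ε-greedy makes the update slightly off-policy. The textbook alternative is to sample each action from the softmax probabilities, which explores on its own: in uncertain states the choices are close to uniform, and in confident states the snake rarely deviates.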
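Both exploration styles in a sketch. The linear decay from 0.1 to 0.01 is an assumption (the text only gives the start and end values), and sample_action takes the uniform random number u as a parameter (for example from rand::random::<f32>()) so the function stays deterministic and testable. Neither function is part of the training code earlier in this part.

```rust
/// Linearly decay ε from `start` to `end` over `decay_episodes`, then hold at `end`.
fn epsilon(episode: u32, start: f32, end: f32, decay_episodes: u32) -> f32 {
    if episode >= decay_episodes {
        return end;
    }
    let frac = episode as f32 / decay_episodes as f32;
    start + (end - start) * frac
}

/// Draw an action index from a probability vector, given a uniform u in [0, 1).
/// Sampling from the softmax like this is the on-policy alternative to ε-greedy.
fn sample_action(probs: &[f32], u: f32) -> usize {
    assert!(!probs.is_empty());
    let mut cumulative = 0.0_f32;
    for (i, &p) in probs.iter().enumerate() {
        cumulative += p;
        if u < cumulative {
            return i;
        }
    }
    // Guard against rounding when the probabilities sum to just under 1.
    probs.len() - 1
}
```

With a 1,000-episode decay, ε is 0.1 at episode 0, 0.055 halfway through, and 0.01 from episode 1,000 on. For probabilities [0.1, 0.2, 0.3, 0.4], u = 0.25 falls in the second bucket and returns direction 1.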
Part 7 combines REINFORCE with self-play, where the snake gets stronger opponents to train against as it improves.