Part 10 — Scaling Up
The snake works. It beats a random player, holds its own against the heuristic, and survives more than a few turns against real opponents on BattleSnake. That’s a good start. But “good start” isn’t the same as “competitive.”
This part isn’t a tutorial in the same way Parts 1–8 were. It’s a map of what to try next — the techniques, architectures, and training tricks that turn a decent snake into a strong one. Each direction is a rabbit hole. Pick one, go deep, measure the results.
The bottleneck: our architecture
The MLP from Part 3 works, but it has a structural limitation. Flattening the 8×11×11 tensor into a 968-element vector destroys spatial information. The first linear layer has to re-learn that adjacent cells in the input are adjacent on the board. That’s wasted capacity — the network is spending parameters to discover something the encoding already knows.
MLP → CNN reshape. Candle’s Conv2d expects (N, C, H, W) input. When you replace the MLP with the CNN, remember that Part 3’s flattened MLP input (N, 8*H*W) has no spatial structure, while the CNN needs (N, 8, H, W). Part 2’s encode_board returns (8, H, W); the CNN’s forward handles the reshape internally so callers can pass the raw encoding.
Convolutional layers fix this. A convolutional layer slides a small filter (typically 3×3) across the spatial dimensions of the input, applying the same weights at every position. The filter naturally captures local patterns: food two cells to the right, a body segment blocking the path, an enemy head approaching. These are exactly the patterns that matter in BattleSnake.
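Here is the weight sharing in concrete form. This is a minimal std-only Rust sketch of one 3×3 filter sliding over a single 11×11 plane with zero padding, which is what padding: 1 means in Candle’s Conv2dConfig. The function name conv3x3_same is illustrative. As in the convolution layers of common deep-learning libraries, the code computes a cross-correlation, so the kernel is not flipped. A real Conv2d layer does this for every input/output channel pair and adds a bias.

```rust
/// Slide one 3×3 kernel over an H×W grid with zero padding of 1, so the
/// output has the same H×W shape as the input. The same nine weights are
/// applied at every position: that is the weight sharing.
fn conv3x3_same(input: &[Vec<f32>], kernel: &[[f32; 3]; 3]) -> Vec<Vec<f32>> {
    let h = input.len();
    let w = input[0].len();
    let mut out = vec![vec![0.0f32; w]; h];
    for r in 0..h {
        for c in 0..w {
            let mut acc = 0.0f32;
            for kr in 0..3 {
                for kc in 0..3 {
                    // Neighbor offset is -1, 0 or +1 in each direction.
                    let rr = r as isize + kr as isize - 1;
                    let cc = c as isize + kc as isize - 1;
                    // Cells outside the grid count as zero (the padding).
                    if rr >= 0 && rr < h as isize && cc >= 0 && cc < w as isize {
                        acc += input[rr as usize][cc as usize] * kernel[kr][kc];
                    }
                }
            }
            out[r][c] = acc;
        }
    }
    out
}
```

Suppose the kernel's only nonzero weight is the middle-right cell. Then the filter copies each cell's right-hand neighbour into place. Applied to the food plane, it lights up every cell that has food immediately to its right, wherever that happens on the board.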
A simple CNN
Replace the MLP with a small convolutional neural network (CNN). A CNN uses convolution layers instead of fully-connected ones, so it preserves the spatial structure of its input:
```
Input: 8 feature planes, shape (8, 11, 11)
  │
  ▼ Conv2d(8 → 32, kernel=3, padding=1) → ReLU
  │ shape: (32, 11, 11)
  ▼ Conv2d(32 → 64, kernel=3, padding=1) → ReLU
  │ shape: (64, 11, 11)
  ▼ Flatten
  │ shape: (64 * 11 * 11,) = 7744
  ▼ Linear(7744 → 256) → ReLU
  │
  ▼ Linear(256 → 4) → logits
```
In Candle:
```rust
use candle_core::{Result, Tensor};
use candle_nn::{conv2d, linear, Conv2d, Conv2dConfig, Linear, Module, VarBuilder};

pub struct SnakeCnn {
    conv1: Conv2d,
    conv2: Conv2d,
    fc1: Linear,
    fc2: Linear,
}

impl SnakeCnn {
    pub fn new(vs: VarBuilder) -> Result<Self> {
        let conv_cfg = Conv2dConfig {
            padding: 1, // same-padding: output spatial dims match input
            ..Default::default()
        };
        let conv1 = conv2d(8, 32, 3, conv_cfg, vs.pp("conv1"))?;
        let conv2 = conv2d(32, 64, 3, conv_cfg, vs.pp("conv2"))?;
        let fc1 = linear(64 * 11 * 11, 256, vs.pp("fc1"))?;
        let fc2 = linear(256, 4, vs.pp("fc2"))?;
        Ok(Self { conv1, conv2, fc1, fc2 })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // encode_board produces (8, 11, 11); Conv2d needs (batch, 8, 11, 11).
        // A rank-3 input is a single board, so give it a batch dimension first.
        // Batched inputs, (batch, 8, 11, 11) or flattened (batch, 968), are
        // reshaped to the 4-D layout.
        let x = if x.rank() == 3 { x.unsqueeze(0)? } else { x.clone() };
        let b = x.dim(0)?;
        let x = x.reshape((b, 8, 11, 11))?;
        let x = self.conv1.forward(&x)?.relu()?; // (batch, 32, 11, 11)
        let x = self.conv2.forward(&x)?.relu()?; // (batch, 64, 11, 11)
        let x = x.reshape((b, 64 * 11 * 11))?;
        let x = self.fc1.forward(&x)?.relu()?;
        self.fc2.forward(&x) // (batch, 4) logits
    }
}
```
MLP vs CNN input shapes. The MLP from Part 3 expects a flattened tensor (batch, 8*H*W) with no spatial structure. The CNN expects (batch, 8, H, W) with explicit channel and spatial dimensions. This is a common tripping point when switching architectures: the reshape is part of the network’s input contract, not an optional convenience. The code above handles it at the start of forward. Callers can pass a single (8, 11, 11) encoding, a (batch, 8, 11, 11) batch, or a flattened (batch, 968) batch.
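The CNN has more parameters in total than the MLP because the fully-connected layer after the convolutions is large: fc1 alone has 7,744 × 256 = 1,982,464 weights. The convolutional layers themselves are cheap because their filters share weights across positions. The first conv layer has 8 × 32 × 3 × 3 = 2,304 weights regardless of board size. The MLP’s first layer has 968 × 256 = 247,808. Convolution is the more parameter-efficient way to extract spatial features; most of the CNN’s parameter budget goes to the fully-connected head.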
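The parameter arithmetic is easy to check. This std-only sketch counts weights plus biases for each layer of the SnakeCnn above, using the layer sizes from the diagram. Candle’s conv2d and linear both create a bias by default. The helper names are illustrative.

```rust
/// Weights + biases for a Conv2d layer with a square k×k kernel.
fn conv2d_params(c_in: usize, c_out: usize, k: usize) -> usize {
    c_in * c_out * k * k + c_out
}

/// Weights + biases for a Linear layer.
fn linear_params(d_in: usize, d_out: usize) -> usize {
    d_in * d_out + d_out
}

/// Per-layer parameter counts for the SnakeCnn layout.
fn snake_cnn_params() -> [(&'static str, usize); 4] {
    [
        ("conv1", conv2d_params(8, 32, 3)),
        ("conv2", conv2d_params(32, 64, 3)),
        ("fc1", linear_params(64 * 11 * 11, 256)),
        ("fc2", linear_params(256, 4)),
    ]
}
```

The total comes to 2,004,580 parameters, and fc1 accounts for 1,982,720 of them, about 99%. The 2,304 figure in the text counts conv1’s weights only; its 32 biases bring it to 2,336.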
Dueling networks
A standard policy network outputs 4 logits directly. A dueling network splits the final layers into two heads:
- V(s) — the value: how good is this position overall? One scalar. A position with food nearby and no enemies is good. One where you’re surrounded is bad.
- A(s, a) — the advantage of each action: how much better is this specific move compared to the average move? Four values (one per direction). Taking a food-adjacent cell has high advantage; taking a safe-but-useless cell has low advantage.
- Q(s, a), the combined quality: the overall value of taking action a in state s. This is what we want to maximize.
The key insight: it’s easier to learn two simple things than one complicated thing. Instead of learning Q directly, the network learns V and A separately, then combines them:
Q(s, a) = V(s) + A(s, a) - mean(A(s, ·))
The - mean(A) term needs explaining. Without it, V and A would be ambiguous. V could be 10 with A = [1, 2, 3, 4], or V could be 13 with A = [-2, -1, 0, 1]. Both give the same Q values (11, 12, 13, 14). Nothing in the training signal prefers one split over the other, so the network can drift between them. Subtracting the mean forces the advantages to average zero, so A only carries the differences between actions and V has to carry the base value (here 12.5, the average of the Q values). Now V and A each have a single, stable job, and training is more reliable.
Both heads sit on the same shared layers and are trained together. V is shared by all four actions, so it gets a learning signal from every update, whichever move was taken.
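Here is a quick numeric check of the identifiability argument, as a std-only sketch for a single state. The names dueling_q and naive_q are illustrative. Without the mean term, both splits from the example produce Q = (11, 12, 13, 14). With it, shifting every advantage by the same constant changes nothing, so only V can set the overall level. The only V that reproduces those Q values is their average, 12.5.

```rust
/// Dueling aggregation for one state: Q(s, a) = V(s) + A(s, a) - mean(A(s, ·)).
fn dueling_q(v: f32, a: &[f32]) -> Vec<f32> {
    let mean = a.iter().sum::<f32>() / a.len() as f32;
    a.iter().map(|&adv| v + adv - mean).collect()
}

/// Naive aggregation without the mean term: Q(s, a) = V(s) + A(s, a).
fn naive_q(v: f32, a: &[f32]) -> Vec<f32> {
    a.iter().map(|&adv| v + adv).collect()
}
```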
```rust
use candle_core::{Result, Tensor};
use candle_nn::{conv2d, linear, Conv2d, Conv2dConfig, Linear, Module, VarBuilder};

/// A dueling variant of SnakeCnn: the same convolutional trunk, but the final
/// layers split into a value head and an advantage head.
pub struct SnakeCnnDueling {
    conv1: Conv2d,
    conv2: Conv2d,
    shared_fc: Linear,
    v_head: Linear, // value: (batch, 256) → (batch, 1)
    a_head: Linear, // advantage: (batch, 256) → (batch, 4)
}

impl SnakeCnnDueling {
    pub fn new(vs: VarBuilder) -> Result<Self> {
        let conv_cfg = Conv2dConfig {
            padding: 1,
            ..Default::default()
        };
        let conv1 = conv2d(8, 32, 3, conv_cfg, vs.pp("conv1"))?;
        let conv2 = conv2d(32, 64, 3, conv_cfg, vs.pp("conv2"))?;
        let shared_fc = linear(64 * 11 * 11, 256, vs.pp("shared_fc"))?;
        let v_head = linear(256, 1, vs.pp("v_head"))?;
        let a_head = linear(256, 4, vs.pp("a_head"))?;
        Ok(Self { conv1, conv2, shared_fc, v_head, a_head })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Same input handling as SnakeCnn::forward: a single (8, 11, 11)
        // board gets a batch dimension; batched inputs are reshaped.
        let x = if x.rank() == 3 { x.unsqueeze(0)? } else { x.clone() };
        let b = x.dim(0)?;
        let x = x.reshape((b, 8, 11, 11))?;
        let x = self.conv1.forward(&x)?.relu()?;
        let x = self.conv2.forward(&x)?.relu()?;
        let x = x.reshape((b, 64 * 11 * 11))?;
        // Shared feature extraction
        let shared = self.shared_fc.forward(&x)?.relu()?;
        // Split into value and advantage heads
        let v = self.v_head.forward(&shared)?; // (batch, 1)
        let a = self.a_head.forward(&shared)?; // (batch, 4)
        // Q = V + A - mean(A). mean_keepdim keeps shape (batch, 1); Candle's
        // `+` and `-` need identical shapes, so broadcast_* does the expansion.
        let a_centered = a.broadcast_sub(&a.mean_keepdim(1)?)?;
        a_centered.broadcast_add(&v) // (batch, 4)
    }
}
```
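Dueling heads come from value-based RL (the Dueling DQN architecture), where the network’s outputs are read as Q-values. Actor-critic methods such as PPO use a closely related layout: a shared trunk with separate policy and value heads. In both cases the added complexity is minimal, two small linear heads instead of one, and the network gains an explicit estimate of how good the position is.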
Residual connections
A deeper CNN can learn more complex patterns, but deeper networks are harder to train. During backpropagation, gradients flow backward through each layer, and each layer multiplies the gradient by its local derivative. Multiply enough small numbers together and the result is tiny: the gradient “vanishes” and the early layers barely update at all. The problem grows with depth. A 4-layer CNN might show it, and a 10-layer CNN is much more likely to.
Residual connections fix this by giving the gradient a shortcut path:
```
x ──┬──▶ Conv → ReLU → Conv ──▶ + ──▶ ReLU ──▶ out
    │                           ▲
    └───────────────────────────┘
```
The + is an element-wise addition. If the layer doesn’t need to change anything, it learns to output zeros and the shortcut passes the input through unchanged. The layer only needs to learn the difference between the input and the desired output (the “residual”), which is often a much smaller adjustment than learning the full transformation from scratch. This makes training deeper networks (4+ conv layers) practical.
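```rust
use candle_core::{Result, Tensor};
use candle_nn::{Conv2d, Module};

/// conv1 → ReLU → conv2, add the block input back, then ReLU.
/// Both convs must preserve the input shape (same channels, padding = 1).
fn residual_block(x: &Tensor, conv1: &Conv2d, conv2: &Conv2d) -> Result<Tensor> {
    let out = conv1.forward(x)?.relu()?;
    let out = conv2.forward(&out)?;
    // Skip connection: element-wise add of the input to the output
    (out + x)?.relu()
}
```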
Residual connections require that the input and output have the same shape (same channels, same spatial dimensions). If you want to change the channel count, add a 1×1 convolution on the skip path.
Board size and the fully-connected layer
The CNN’s convolutional layers work on any board size — a 3×3 filter slides the same way on 11×11 and 19×19. But the fully-connected layer after flattening is tied to a specific spatial size. 64 × 11 × 11 = 7,744 for an 11×11 board. 64 × 19 × 19 = 23,104 for a 19×19 board.
Two ways to handle this:
- Global average pooling: instead of flattening, average each channel across the spatial dimensions. The fc layer input becomes (64,) regardless of board size. This loses spatial precision but makes the network board-size-agnostic.
- Train on a fixed board size: most competitive BattleSnake snakes train on 11×11. Tournament games occasionally use other sizes, but 11×11 is the default. If you’re optimizing for competition, fixed-size training is fine.
The bottleneck: our training algorithm
REINFORCE works, but it’s the simplest policy gradient method. It has two well-known problems:
- High variance. The gradient estimate is noisy. One episode might produce a total return of +15, the next might produce -5. The update direction flips between episodes. Training is slow and unstable.
- No baseline. REINFORCE judges an action by the absolute return. But every action in a good episode looks good (the return is high), and every action in a bad episode looks bad (the return is low). The algorithm can’t distinguish “this was a good action in a bad episode” from “this was a bad action in a bad episode.”
PPO: the standard fix
Proximal Policy Optimization (PPO) addresses both problems. It uses a value function — a second network (or a second head on the same network) that predicts the expected return from a given state. The value function serves as the baseline: instead of scaling the gradient by the raw return, PPO scales it by the advantage — how much better this action was than expected.
advantage = return - value_function(state)
If the action led to a return of +5 but the value function expected +3, the advantage is +2. The action was better than expected — make it more likely.
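Here is the baseline arithmetic as a std-only sketch. It computes discounted returns over one episode and subtracts per-step value predictions, which is the plain return-minus-value form given in the text. Many PPO implementations use generalized advantage estimation (GAE) instead, which blends in multi-step bootstrapped value estimates. The value predictions in this sketch are plain numbers standing in for a value head’s output.

```rust
/// Discounted return from each step: G_t = r_t + gamma * G_{t+1}.
fn discounted_returns(rewards: &[f32], gamma: f32) -> Vec<f32> {
    let mut returns = vec![0.0f32; rewards.len()];
    let mut running = 0.0f32;
    // Walk backwards so each step can reuse the return of the next one.
    for t in (0..rewards.len()).rev() {
        running = rewards[t] + gamma * running;
        returns[t] = running;
    }
    returns
}

/// advantage = return - value_function(state), one entry per step.
fn advantages(returns: &[f32], values: &[f32]) -> Vec<f32> {
    returns.iter().zip(values).map(|(g, v)| g - v).collect()
}
```

For example, advantages(&[5.0], &[3.0]) gives [2.0], matching the +5 versus +3 case.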
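PPO also limits how much the policy can change per update (the “proximal” part). It clips the ratio between the new and old policy’s probability of each taken action to a narrow band, commonly [0.8, 1.2]. An update therefore gains nothing from moving the policy further than that. This prevents catastrophic updates that destroy learned behavior.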
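The clipping rule itself is small. Below is a std-only sketch of PPO’s clipped surrogate objective for a single sample. Here ratio is the new policy’s probability of the taken action divided by the old policy’s, and eps = 0.2 gives the [0.8, 1.2] band. In training, the loss is the negative mean of this objective over a minibatch.

```rust
/// PPO's clipped surrogate objective for one (state, action) sample.
/// The policy is trained to MAXIMIZE this value.
fn ppo_clip_objective(ratio: f32, advantage: f32, eps: f32) -> f32 {
    let unclipped = ratio * advantage;
    let clipped = ratio.clamp(1.0 - eps, 1.0 + eps) * advantage;
    // Taking the minimum is pessimistic: once the ratio has moved past the
    // clip boundary in the direction the advantage favors, the clipped term
    // is constant, so there is no gradient pushing the policy further.
    unclipped.min(clipped)
}
```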
Implementing PPO in Candle is a larger project than REINFORCE — you need the value function, the advantage calculation, the clipped objective, and the epoch-based minibatch updates. The architecture change is straightforward (add a value head to the network). The training loop is where the complexity lives.
A simpler improvement: reward shaping
Before jumping to PPO, try improving the reward signal. REINFORCE with a better reward function can outperform REINFORCE with a bad reward function and a fancier algorithm.
- Distance to food. A small positive reward for getting closer to the nearest food. This gives the network signal on every turn, not only when it eats.
- Space count. Count the number of cells reachable from the head (flood fill). More space means more room to maneuver, so reward the network for staying in open areas.
- Opponent distance. A small penalty for being close to the opponent’s head when the opponent is larger, encouraging avoidance.
- Turns alive. Instead of a flat +0.01 per turn, scale the survival bonus inversely with current health, so a snake at health 5 gets a bigger survival bonus per turn than a snake at health 95.
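Here is a std-only sketch of three of these signals. It uses (x, y) board coordinates and takes a set of cells occupied by snake bodies as input. The function names and every coefficient in shaped_reward are placeholders to tune, not recommended values. The opponent-distance penalty is left out for brevity.

```rust
use std::collections::{HashSet, VecDeque};

/// Board coordinate (x, y).
type Pos = (i32, i32);

/// Manhattan distance from `head` to the nearest food, or None if there is no food.
fn nearest_food_dist(head: Pos, food: &[Pos]) -> Option<i32> {
    food.iter()
        .map(|&(fx, fy)| (fx - head.0).abs() + (fy - head.1).abs())
        .min()
}

/// Number of free cells reachable from `head` (breadth-first flood fill).
/// `blocked` holds every cell occupied by any snake body.
fn reachable_cells(head: Pos, blocked: &HashSet<Pos>, width: i32, height: i32) -> usize {
    let mut seen: HashSet<Pos> = HashSet::new();
    let mut queue = VecDeque::from([head]);
    seen.insert(head);
    while let Some((x, y)) = queue.pop_front() {
        for (dx, dy) in [(0, 1), (0, -1), (1, 0), (-1, 0)] {
            let next = (x + dx, y + dy);
            let on_board = next.0 >= 0 && next.0 < width && next.1 >= 0 && next.1 < height;
            // `seen.insert` returns true only the first time a cell is visited.
            if on_board && !blocked.contains(&next) && seen.insert(next) {
                queue.push_back(next);
            }
        }
    }
    seen.len() - 1 // do not count the head's own cell
}

/// Hypothetical per-turn shaped reward; all coefficients are placeholders.
fn shaped_reward(
    prev_food_dist: Option<i32>,
    food_dist: Option<i32>,
    reachable: usize,
    board_cells: usize,
    health: u32, // 0..=100
) -> f32 {
    let mut r = 0.0f32;
    // Distance to food: bonus for moving closer, penalty for moving away.
    if let (Some(before), Some(after)) = (prev_food_dist, food_dist) {
        r += 0.01 * (before - after) as f32;
    }
    // Space count: fraction of the board still reachable.
    r += 0.01 * reachable as f32 / board_cells as f32;
    // Turns alive: 0.01 at full health, growing as health drops.
    r += 0.01 * (2.0 - health as f32 / 100.0);
    r
}
```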
The key principle: the more informative the reward, the faster the network learns. If the reward only fires at the end of the episode (win/lose), learning is slow because the credit assignment problem is hard. If the reward fires on every turn with a clear gradient (this move was slightly better than that move), learning is fast.
The bottleneck: training speed
CPU training is fine for development. For production-quality results, GPU training is the difference between “trained overnight” and “trained over the weekend.”
GPU training with Candle
Candle supports CUDA via the cuda feature flag:
```toml
[dependencies]
candle-core = { version = "0.6", features = ["cuda"] }
candle-nn = { version = "0.6", features = ["cuda"] }
```
Then create a CUDA device:
```rust
let device = candle_core::Device::new_cuda(0)?;
```
Most tensor operations work the same on CUDA as on CPU. The model code doesn’t change. The training loop doesn’t change. The only difference is which device the tensors live on.
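For self-play, the speedup can be significant. As a rough illustration, suppose a forward pass takes ~1ms on CPU and ~0.01ms on a modern GPU. Then 10,000 training episodes at about 1,000 forward passes each (10 million passes) take roughly 3 hours on CPU and under 2 minutes on GPU. Treat those per-pass figures as best-case. For a network this small, evaluating one board at a time on a GPU is dominated by kernel-launch and transfer overhead, and the gap only opens up once boards are batched.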
Vectorized environments
The GameEnv from Part 8 runs one game at a time. On GPU, we can run many games in parallel by stacking the board states into a batch tensor. Instead of encoding one board into (1, 8, 11, 11), encode 64 boards into (64, 8, 11, 11) and run a single forward pass.
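```rust
// Instead of one forward pass per board:
for board in &boards {
    let tensor = encode_board(board, my_snake)?; // (8, 11, 11)
    let logits = net.forward(&tensor.unsqueeze(0)?)?; // (1, 4)
    // ...
}

// Do one forward pass for the whole batch:
let n = boards.len(); // e.g. 64 games running in parallel
// encode_board_flat is assumed to work like encode_board but return the
// 8 × 11 × 11 values as a flat Vec<f32> in (channel, row, column) order.
let batch: Vec<f32> = boards.iter()
    .flat_map(|b| encode_board_flat(b))
    .collect();
let batch_t = Tensor::from_vec(batch, (n, 8, 11, 11), &device)?;
let logits = net.forward(&batch_t)?; // (n, 4): one row of logits per board
```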
Running many environments in parallel and batching their inference is a standard way for large-scale RL systems (AlphaGo, AlphaStar) to reach their training throughput. For BattleSnake, vectorizing 64 environments can give on the order of a 50× speedup over sequential play. On a GPU, one batched forward pass over 64 boards costs about as much as a single-board pass and replaces 64 of them. The game simulation stays sequential, but it is cheap next to inference.
The bottleneck: representation
The 8-channel encoding from Part 2 is a reasonable starting point, but it loses information. Here are some directions to explore:
Relative encoding
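Instead of encoding the board from a global perspective (food at (5, 3), my head at (2, 1)), encode it relative to the snake’s head. The head is always at the center of the tensor, and the rest of the board is shifted accordingly. The tensor grows to (2W - 1) × (2H - 1), 21×21 for an 11×11 board, so that every cell stays in view wherever the head sits.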
This has two benefits. First, the network doesn’t need to work out the food’s position relative to the head from two absolute coordinates: the offset (3, 2) is visible directly in the input. Second, the same local pattern (food one cell to the right) always looks the same regardless of the head’s absolute position. This is translational invariance, and convolutional networks exploit it naturally.
A further step is rotating the board so that the direction the snake is facing always points “up” in the tensor. The chosen move then has to be rotated back before it is sent to the server. This adds complexity to the encoder but can significantly improve the network’s ability to generalize.
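The centering step, without rotation, can be sketched in std-only Rust. The sketch stores one plane as rows indexed [y][x]. centered_index and center_plane are illustrative names. Filling off-board cells with a chosen value, such as 1.0 on an obstacle plane, is one reasonable choice, not the only one.

```rust
/// Map a board cell to its index in a head-centered (2W-1)×(2H-1) grid.
/// The head always lands at (W-1, H-1), i.e. (10, 10) for an 11×11 board.
fn centered_index(cell: (i32, i32), head: (i32, i32), width: i32, height: i32) -> (usize, usize) {
    let cx = cell.0 - head.0 + (width - 1);
    let cy = cell.1 - head.1 + (height - 1);
    (cx as usize, cy as usize)
}

/// Re-encode one H×W feature plane around the head. Grid cells that fall
/// off the real board are filled with `off_board`.
fn center_plane(plane: &[Vec<f32>], head: (i32, i32), off_board: f32) -> Vec<Vec<f32>> {
    let h = plane.len() as i32;
    let w = plane[0].len() as i32;
    let mut out = vec![vec![off_board; (2 * w - 1) as usize]; (2 * h - 1) as usize];
    for y in 0..h {
        for x in 0..w {
            let (cx, cy) = centered_index((x, y), head, w, h);
            out[cy][cx] = plane[y as usize][x as usize];
        }
    }
    out
}
```

Food three cells to the right of the head lands at row 10, column 13 of the 21×21 output, whether the head is at (3, 2) or at (7, 5).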
Temporal stacking
Feed the last 3–5 board states as additional channels. Instead of 8 channels, the input becomes 8 × 5 = 40 channels (current state + 4 previous states). This gives the network motion information — it can see that the enemy is moving toward it, or that it’s circling the same area.
The cost is 5× more input channels, which means 5× more parameters in the first convolutional layer. For a small CNN, this is manageable.
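Here is a std-only sketch of the history buffer. It expects each turn’s encoding to arrive as a flat Vec<f32> of 8 × 11 × 11 values; FrameStack is an illustrative name. Repeating the first frame at the start of an episode is an assumption; zero-filling the missing history is another common choice.

```rust
use std::collections::VecDeque;

/// Keeps the last `depth` encoded boards and concatenates them newest-first.
struct FrameStack {
    depth: usize,
    frames: VecDeque<Vec<f32>>,
}

impl FrameStack {
    fn new(depth: usize) -> Self {
        Self { depth, frames: VecDeque::with_capacity(depth) }
    }

    /// Call once per turn with the current encoding.
    fn push(&mut self, frame: Vec<f32>) {
        if self.frames.is_empty() {
            // Start of an episode: no history yet, so repeat the first frame.
            for _ in 1..self.depth {
                self.frames.push_front(frame.clone());
            }
        }
        self.frames.push_front(frame);
        self.frames.truncate(self.depth); // drop the oldest frame
    }

    /// Channels 0..C are "now", C..2C are one turn ago, and so on.
    fn stacked(&self) -> Vec<f32> {
        self.frames.iter().flatten().copied().collect()
    }

    /// Call between episodes.
    fn reset(&mut self) {
        self.frames.clear();
    }
}
```

With depth 5, stacked() returns 40 × 11 × 11 values, ready for Tensor::from_vec(..., (40, 11, 11), &device). The first convolution then becomes conv2d(40, 32, 3, ...), and the hard-coded 8 in the CNN’s input reshape changes to 40.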
Full board state
Our encoding drops some information: snake IDs, snake lengths, the game ruleset, the turn number. Some of these are useful. Snake length matters for head-to-head collisions (the longer snake wins). The ruleset matters because “wrapped” mode has no walls — the edges connect. The turn number matters because food spawning behavior changes as the game progresses.
Add channels for:
- Snake length (normalized, one channel per snake)
- Turn number (repeated across the board)
- Game mode indicator (0 = standard, 1 = wrapped)
Each addition costs one channel (one per snake for length). The network gets more signal for minimal overhead.
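Here is a std-only sketch of the extra planes. The normalizers (length divided by the number of board cells, turn divided by a max_turns cap) and the function name are hypothetical choices. Each plane is one constant value repeated over the board, flattened row-major so it can be appended after the existing 8 channels.

```rust
/// Extra feature planes, each width×height and flattened row-major:
/// one length plane per snake, then a turn plane, then a game-mode plane.
fn extra_planes(
    snake_lengths: &[usize], // one entry per snake, in a fixed order (e.g. "me" first)
    turn: u32,
    max_turns: u32, // hypothetical normalization cap
    wrapped: bool,
    width: usize,
    height: usize,
) -> Vec<f32> {
    let cells = width * height;
    let mut out: Vec<f32> = Vec::with_capacity((snake_lengths.len() + 2) * cells);
    // Snake length: one constant plane per snake.
    for &len in snake_lengths {
        out.extend(std::iter::repeat(len as f32 / cells as f32).take(cells));
    }
    // Turn number, clamped to [0, 1] and repeated across the board.
    let t = (turn as f32 / max_turns as f32).min(1.0);
    out.extend(std::iter::repeat(t).take(cells));
    // Game mode indicator: 0 = standard, 1 = wrapped.
    let mode = if wrapped { 1.0f32 } else { 0.0f32 };
    out.extend(std::iter::repeat(mode).take(cells));
    out
}
```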
Where to go from here
The techniques in this part are ordered roughly by effort-to-reward ratio:
- Reward shaping — easiest to try, often the biggest improvement. Add a distance-to-food bonus and see if training converges faster.
- CNN architecture — swap the MLP for a CNN. More parameter-efficient for spatial data, faster convergence.
- GPU training: unlocks larger batch sizes and more training iterations. Most valuable for PPO, which makes several passes over each batch of experience.
- PPO — the standard RL algorithm. More stable training, better sample efficiency. Requires a value function.
- Relative encoding — significant representational improvement, but requires rewriting the encoder.
- Temporal stacking — gives the network motion information. Easy to implement once you have the CNN working.
- Vectorized environments — large-scale training throughput. Requires GPU.
Start at the top. Measure after each change. If reward shaping gets you from 40% win rate to 55%, that’s real progress. If switching to PPO gets you from 55% to 60%, that’s also real progress. But if PPO gets you from 40% to 42% while reward shaping gets you to 55%, the reward signal was the real problem — fix that first.
The snake is live. The foundation is solid. Everything from here is iteration.