
Battlesnake AI with Candle

The game is simple: a snake moves around an 11×11 board, eating food while avoiding walls, its own body, and other snakes. Your server receives the board state and responds with a move. That’s it.

The interesting part is how to decide which move to make. A naive approach works okay — always go toward the nearest food, avoid the nearest wall. But a naive approach dies fast against anything competitive. This tutorial is about building a snake that learns.

We’ll use two tools for this:

Candle is a Rust machine learning library from Hugging Face. It’s lightweight, no Python required, and runs on CPU or GPU. We’ll use it to build the neural network that takes the board state and outputs a move decision.

Battlesnake is the game — it’s a hosted competitive programming game where your server is a web endpoint that gets called every turn. The docs are at docs.battlesnake.com. You don’t need to understand the game deeply to start; we’ll cover the parts that matter.

The plan:

  • Parts 1–3: The board as data, your first simple network, and what a neural net is actually doing in this context
  • Parts 4–5: A heuristic baseline (rule-based snake) and imitation learning — use the heuristic to generate training data, then train the net to copy it
  • Parts 6–7: Reinforcement learning — rewards, policy gradients, and self-play training loops
  • Part 8: Deploying — wire the trained model to a web server that talks to the Battlesnake engine
  • Part 9: Scaling up — GPU training, larger boards, better architectures

We touch on both imitation learning and reinforcement learning, but always grounded in the concrete problem rather than abstract ML theory. The heuristic from the early parts becomes the foundation for everything that follows.

Version note. This tutorial uses candle-core 0.6 and candle-nn 0.6. Candle is pre-1.0 and the API has changed significantly since 0.6 (the current version is 0.10). The concepts — tensor operations, gradient tracking, training loops, policy gradients — are the same across versions. The specific method names and module paths may differ. If you’re on a newer version, check docs.rs for the current API.

Prerequisites

  • Rust (1.70+) — rustup.rs
  • mdBook (cargo install mdbook)

What we’re building

The project is a Rust workspace with two crates:

battlesnake/           ← workspace root
├── Cargo.toml         ← workspace definition
├── snake-ml/          ← ML library (encoding, network, training)
│   ├── Cargo.toml
│   └── src/
│       └── lib.rs
└── snake-server/       ← HTTP server (serves /move, runs the snake)
    ├── Cargo.toml
    └── src/
        └── main.rs

Every file is created in this tutorial. No external repositories.


Next: Part 1 — Your Snake’s First Move — We start with what Battlesnake sends you and what a move response looks like. No ML yet — enough to have a working web server that the game engine can call.

Part 1 — Your Snake’s First Move

Before we build anything smart, we need a server that the game engine can actually talk to. In Battlesnake, your snake is a web server: the game engine sends you an HTTP request every turn, and you send back a direction. That’s the whole contract.

Let’s start there.

Creating the project

We’ll use a Rust workspace with two crates:

  • snake-ml — the ML library (board encoding, network, training)
  • snake-server — the HTTP server that responds to the game engine

Create the directory structure and all files from scratch. No git required.

Root Cargo.toml (workspace definition):

[workspace]
resolver = "2"
members = [
  "snake-server",
  "snake-ml",
]

snake-ml/Cargo.toml:

[package]
name = "snake-ml"
version = "0.1.0"
edition = "2021"

[dependencies]
candle-core = "0.6"
candle-nn = "0.6"
# Pinned to avoid a rand version conflict with candle-core's internals.
half = "=2.4.0"
serde = { version = "1", features = ["derive"] }
serde_json = "1"

Version note. This tutorial targets Candle 0.6. Candle is a fast-moving library — the latest version is 0.10 as of June 2026, and the API has changed significantly between versions (tensor methods, training loop API, optimizer signatures). The code examples in this tutorial compile with 0.6. If you use a newer version, expect method name changes and signature differences. The concepts (tensor operations, forward passes, backpropagation, training loops) are the same across versions — it’s the API surface that shifts.

snake-ml/src/lib.rs, starting with the game types that the server and the ML code share:

//! ML components for the Battlesnake AI.

use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Point {
    pub x: i32,
    pub y: i32,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Snake {
    pub id: String,
    pub body: Vec<Point>,
    pub health: u32,
    pub head: Point,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Board {
    pub width: u32,
    pub height: u32,
    #[serde(default)]
    pub food: Vec<Point>,
    #[serde(default)]
    pub snakes: Vec<Snake>,
    #[serde(default)]
    pub hazards: Vec<Point>,
}

snake-server/Cargo.toml:

[package]
name = "snake-server"
version = "0.1.0"
edition = "2021"

[dependencies]
actix-web = "4"
log = "0.4"
env_logger = "0.11"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
snake-ml = { path = "../snake-ml" }

snake-server/src/main.rs: we’ll build this out in the sections below. For now, give it a placeholder fn main() {} so Cargo has a binary target to build.

To verify everything compiles:

cd battlesnake
cargo build

If it builds, the project structure is correct. Now let’s write the server.

What the game sends you

Every turn, the engine POSTs JSON to your /move endpoint. It looks like this:

{
  "game": {
    "id": "totally-unique-game-id",
    "ruleset": { "name": "standard", "version": "v1.1.15" },
    "timeout": 500
  },
  "turn": 14,
  "board": {
    "height": 11,
    "width": 11,
    "food":    [{ "x": 5, "y": 5 }],
    "hazards": [{ "x": 3, "y": 2 }],
    "snakes":  [
      {
        "id": "snake-508e96ac-94ad-11ea-bb37",
        "name": "My Snake",
        "health": 54,
        "body": [{ "x": 0, "y": 0 }, { "x": 1, "y": 0 }, { "x": 2, "y": 0 }],
        "head": { "x": 0, "y": 0 },
        "length": 3
      },
      {
        "id": "snake-b67f4906-94ae-11ea-bb37",
        "name": "Enemy",
        "health": 16,
        "body": [{ "x": 5, "y": 4 }, { "x": 5, "y": 3 }, { "x": 6, "y": 3 }, { "x": 6, "y": 2 }],
        "head": { "x": 5, "y": 4 },
        "length": 4
      }
    ]
  },
  "you": {
    "id": "snake-508e96ac-94ad-11ea-bb37",
    "name": "My Snake",
    "health": 54,
    "body": [{ "x": 0, "y": 0 }, { "x": 1, "y": 0 }, { "x": 2, "y": 0 }],
    "head": { "x": 0, "y": 0 },
    "length": 3
  }
}

That’s a lot. The important parts:

  • board — the whole game state. width and height are the board dimensions (usually 11×11, though smaller and larger boards exist). food is where the food pieces are. snakes is every snake in the game, including yours.
  • you — your snake specifically. It has a head (the front square), a body (all the squares), and health (goes down every turn unless you eat).
  • turn — what turn the game is on. Starts at 0.

The game runs on a 0-indexed grid. (0, 0) is the bottom-left corner. (10, 10) would be the top-right on an 11×11 board. The x-axis goes right; the y-axis goes up. This is a standard Cartesian coordinate system, not screen coordinates — “up” adds to y because y increases upward. When you move up, your snake’s head goes from (3, 2) to (3, 3). When you move down, it goes from (3, 2) to (3, 1).
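To make the coordinate convention concrete, here is a minimal std-only sketch of the head moving one square. The Move enum, step, and on_board are hypothetical helpers for illustration; the tutorial’s own Direction type arrives in Part 3. The only thing this sketch assumes is the convention described above: (0, 0) is the bottom-left corner and y increases upward.

```rust
/// The four moves the engine accepts (hypothetical helper type).
#[derive(Debug, Clone, Copy, PartialEq)]
enum Move {
    Up,
    Down,
    Left,
    Right,
}

/// Square the head lands on after `mv`, in Battlesnake coordinates:
/// x grows to the right, y grows upward.
fn step(x: i32, y: i32, mv: Move) -> (i32, i32) {
    match mv {
        Move::Up => (x, y + 1),
        Move::Down => (x, y - 1),
        Move::Left => (x - 1, y),
        Move::Right => (x + 1, y),
    }
}

/// True if (x, y) lies on a `width` x `height` board.
fn on_board(x: i32, y: i32, width: i32, height: i32) -> bool {
    x >= 0 && x < width && y >= 0 && y < height
}

fn main() {
    // The examples from the text: up and down from (3, 2).
    println!("{:?}", step(3, 2, Move::Up)); // (3, 3)
    println!("{:?}", step(3, 2, Move::Down)); // (3, 1)
    // Leaving the board from either side edge of an 11x11 board.
    let (x, y) = step(10, 5, Move::Right);
    println!("{}", on_board(x, y, 11, 11)); // false
    let (x, y) = step(0, 5, Move::Left);
    println!("{}", on_board(x, y, 11, 11)); // false
}
```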

What you send back

Simple:

{ "move": "up" }

The move is one of "up", "down", "left", or "right". It’s the only field you have to send. The API also accepts an optional "shout" string, which we won’t use.

A minimal server

We’ll use Actix Web — it’s lightweight and straightforward. We already added it to snake-server/Cargo.toml above.

Here’s a working server that always moves up:

use actix_web::{web, App, HttpServer, HttpResponse};
use serde::{Deserialize, Serialize};
use snake_ml::{Board, Point, Snake};

// The request body the game sends you every turn
#[derive(Deserialize)]
struct MoveRequest {
    board: Board,
    you: Snake,
}

// What you send back — the direction
#[derive(Serialize)]
struct MoveResponse {
    #[serde(rename = "move")]
    direction: String,
}

// The POST /move endpoint
async fn move_handler(req: web::Json<MoveRequest>) -> HttpResponse {
    let _my_head = &req.you.head;

    // For now, always move up. The snake will walk into the wall
    // eventually, but we have a working server.
    let direction = "up";

    HttpResponse::Ok().json(MoveResponse {
        direction: direction.to_string(),
    })
}

#[actix_web::main]
async fn main() -> std::io::Result<()> {
    HttpServer::new(|| {
        App::new()
            .route("/move", web::post().to(move_handler))
    })
    .bind(("0.0.0.0", 8080))?
    .run()
    .await
}

The server imports Board, Point, and Snake from the snake-ml crate we set up above — that’s the whole point of the workspace. Those types are shared between the server and the ML library. The server-specific types (MoveRequest, MoveResponse) stay here because they’re HTTP concerns, not game logic.

Run it:

cargo run

You now have a server on port 8080. The Battlesnake engine can call it if it’s publicly reachable.

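For local development, install the official Battlesnake rules CLI (it’s written in Go). It lets you run games against your local server without any tunnel or public URL:

# Install the rules CLI
go install github.com/BattlesnakeOfficial/rules/cli/battlesnake@latest

# Start your server in one terminal
cargo run -p snake-server

# Run a local game in another terminal: two copies of your snake play each other
battlesnake play --width 11 --height 11 \
  --name snake-a --url http://localhost:8080 \
  --name snake-b --url http://localhost:8080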

The rules engine is also what Parts 6–7 use for the self-play training loop, so installing it early means you’re ready when we get there.

Either way: that’s a working server. Not a good one, but a working one.

The game loop

Here’s what the full game loop looks like across your snake’s lifetime:

  1. Engine starts the game — POST to /start with game info. You can acknowledge this; we don’t need to do much here.
  2. Each turn, engine calls /move — you receive the board state, decide on a direction, send it back. The engine moves the snakes, spawns food, checks for collisions.
  3. Engine ends the game — POST to /end. Clean up if you need to.

/start and /end receive the same JSON as /move: the game, turn, board, and you fields shown above. /start fires once before the first move, and /end fires once after the game is over.

The engine ignores the response to /start and /end, so an empty {} is fine. Return it and move on.

For this tutorial we only need /move, but the engine calls all three endpoints, so we wire them up. Because the request bodies match, /start and /end can reuse MoveRequest:

// /start and /end get the same body as /move, so they reuse MoveRequest.
// Nothing here needs the data yet, hence the leading underscore.
async fn start_handler(_req: web::Json<MoveRequest>) -> HttpResponse {
    HttpResponse::Ok().json(serde_json::json!({}))
}

async fn end_handler(_req: web::Json<MoveRequest>) -> HttpResponse {
    HttpResponse::Ok().json(serde_json::json!({}))
}

// main() now registers all three routes:
#[actix_web::main]
async fn main() -> std::io::Result<()> {
    HttpServer::new(|| {
        App::new()
            .route("/start", web::post().to(start_handler))
            .route("/move", web::post().to(move_handler))
            .route("/end", web::post().to(end_handler))
    })
    .bind(("0.0.0.0", 8080))?
    .run()
    .await
}

Making the snake a little smarter

Moving straight up is a fast way to die. Let’s at least look at where the head is and pick a direction that doesn’t immediately hit a wall:

async fn move_handler(req: web::Json<MoveRequest>) -> HttpResponse {
    let head = &req.you.head;
    let width = req.board.width as i32;
    let height = req.board.height as i32;

    // Check which directions are safe
    let mut safe_moves = Vec::new();

    if head.y + 1 < height {
        safe_moves.push("up");
    }
    if head.y - 1 >= 0 {
        safe_moves.push("down");
    }
    if head.x - 1 >= 0 {
        safe_moves.push("left");
    }
    if head.x + 1 < width {
        safe_moves.push("right");
    }

    // Pick the first safe direction (up, then down, left, right).
    // This isn't random — it's deterministic and predictable. Good enough
    // to avoid immediate walls, bad enough to motivate the neural network.
    let direction = safe_moves
        .into_iter()
        .next()
        .unwrap_or("up");

    HttpResponse::Ok().json(MoveResponse {
        direction: direction.to_string(),
    })
}

This snake still dies fast. It runs straight up to the top wall. There "up" is no longer safe, so it picks "down", which turns it straight back into its own neck. The check only knows about walls: it doesn’t know what food is, doesn’t avoid its own body or other snakes, and doesn’t plan ahead.
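The most obvious gap is the snake’s own body. Here is a minimal sketch of a stricter safety check. It assumes body[0] is the head, as in the game JSON, and treats every segment except the tail as blocked, because the tail normally moves out of the way this turn. Point and safe_moves are hypothetical, std-only stand-ins rather than the snake_ml types, and other snakes are still ignored.

```rust
/// A board square in Battlesnake coordinates (hypothetical stand-in type).
#[derive(Debug, Clone, Copy, PartialEq)]
struct Point {
    x: i32,
    y: i32,
}

/// Moves that stay on the board and don't run into our own body.
/// `body[0]` is the head. The last segment is left out of the obstacles
/// because the tail usually moves away this turn. (Assumption: after eating,
/// the game stacks a duplicate tail segment, so the tail square then still
/// appears as the second-to-last segment and stays blocked.)
fn safe_moves(body: &[Point], width: i32, height: i32) -> Vec<&'static str> {
    let head = body[0];
    let candidates = [
        ("up", Point { x: head.x, y: head.y + 1 }),
        ("down", Point { x: head.x, y: head.y - 1 }),
        ("left", Point { x: head.x - 1, y: head.y }),
        ("right", Point { x: head.x + 1, y: head.y }),
    ];
    // Every segment except the tail is an obstacle.
    let obstacles = &body[..body.len() - 1];
    candidates
        .iter()
        .filter(|(_, p)| p.x >= 0 && p.x < width && p.y >= 0 && p.y < height)
        .filter(|(_, p)| !obstacles.contains(p))
        .map(|(name, _)| *name)
        .collect()
}

fn main() {
    // Head at the top wall of an 11x11 board after moving up: "down" is the neck.
    let body = [Point { x: 3, y: 10 }, Point { x: 3, y: 9 }, Point { x: 3, y: 8 }];
    println!("{:?}", safe_moves(&body, 11, 11)); // ["left", "right"]
}
```

Skipping the tail is a deliberate choice. Blocking it too would be safer but would also forbid the common move of chasing your own tail.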

That’s what the neural network is for. Before we get there, we need to figure out how to represent the board state in a way the network can understand — which is what the next part is about.


Part 2 — The Board as Numbers

Let’s talk about what happens between “the game sends you JSON” and “the network decides a move.” The bridge between those two things is the board encoding — converting the structured game state into a grid of numbers that a neural network can read.

If Part 1 was about the API contract, this part is about the data contract. What format does the network actually need?

What a tensor is

The game sends you a JSON object with lists, nested structs, strings — rich structure. The network wants numbers. Specifically, it wants a tensor.

A tensor is a multi-dimensional array of numbers. A 1D tensor is a vector. A 2D tensor is a matrix. A 3D tensor is a stack of matrices. That’s all.

In Candle, a tensor is candle_core::Tensor. Under the hood it’s a flat buffer of numbers (f32 everywhere in this tutorial) plus a shape. The shape (3, 11, 11) means “3 channels, each 11 rows by 11 columns”: a stack of three 11×11 grids.

That’s our board representation.

Feature planes

The trick is deciding what goes in each channel. Each channel is a grid the same size as the board, where each cell is either 0 or 1 (or sometimes a number like health / 100).

Here’s what we’ll use:

Channel | Contents
--------|---------------------------------------------------------------------
0       | Open space (1 for every valid board square)
1       | Food (1 on food squares)
2       | My body excluding head (1 on body squares that aren’t the head)
3       | My head (1 on my head square)
4       | Enemy bodies (1 on all enemy body squares)
5       | Enemy heads (1 on all enemy head squares)
6       | Hazard (dangerous squares: 1 where damage applies)
7       | My health normalized (my_health / 100, repeated across all squares)

That’s eight channels in total: seven spatial feature planes, plus channel 7, which spreads the health scalar uniformly across the board. The network gets an (8, height, width) tensor.

Why this layout? Three things worth explaining:

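Channel 0 — open space. This plane marks which cells are real board squares. An empty valid square is 0 in every other channel, and so is a cell outside the board, so only channel 0 can tell them apart. In the encoder below the tensor is exactly board-sized, so channel 0 is all ones. It starts to matter if you pad smaller boards into a fixed-size input, because padded cells are then 0 in every channel. It does not, on its own, stop the network from choosing a move into a wall. That still has to be learned from where the head sits relative to the board edge.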

Channels 2 and 3 — body without head, head alone. The head square is part of the body in the game data — it appears in body and as head. But we split them across channels. Channel 2 holds the body excluding the head. Channel 3 holds only the head. This gives the network a clean signal: channel 2 is “squares I can’t move to” and channel 3 is “where I am right now.” If we included the head in both channels, the network would need to learn that the overlap is the same position — splitting them makes the representation unambiguous.

Channel 7 — health everywhere. Health is a single scalar on the Snake struct, but the network needs it at every input position. We broadcast it across the entire channel — every square gets the same value. This is a little wasteful but keeps the tensor shape regular, and it works: the network learns “low health → take fewer risks” the same way it learns spatial patterns.

A concrete example

Let’s encode a tiny 5×5 board to see it in action:

Board (5×5, Battlesnake coordinate system — y increases upward):

     x:  0  1  2  3  4
y=4      .  .  .  ⭐ .    (food at (3,4))
y=3      .  .  .  .  .
y=2      .  🐍 🐍 .  .    (body at (1,2) and (2,2))
y=1      .  .  🐍 .  .    (head at (2,1))
y=0      .  .  .  .  .

My snake: head at (2,1), body: (2,1), (2,2), (1,2)
No enemies. No hazards.

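Encoded (one channel at a time). Each grid is printed top-down with y=4 first, to match the picture. Inside the tensor the row index is y, so row 0 holds y=0: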

Channel 0 (open space):
1 1 1 1 1    ← y=4
1 1 1 1 1
1 1 1 1 1
1 1 1 1 1
1 1 1 1 1    ← y=0, every valid square is 1.0

Channel 1 (food):
0 0 0 1 0    ← y=4, (3,4) has food
0 0 0 0 0
0 0 0 0 0
0 0 0 0 0
0 0 0 0 0    ← y=0

Channel 2 (my body, excluding head):
0 0 0 0 0    ← y=4
0 0 0 0 0
0 1 1 0 0    ← y=2, (1,2) and (2,2) — body segments only
0 0 0 0 0    ← y=1, head at (2,1) is NOT in this channel
0 0 0 0 0    ← y=0

Channel 3 (my head only):
0 0 0 0 0    ← y=4
0 0 0 0 0
0 0 0 0 0
0 0 1 0 0    ← y=1, (2,1) — head gets its own channel
0 0 0 0 0    ← y=0

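Channels 4–6 (enemies, hazards): all zeros
Channel 7 (health): my_health / 100 in every cell (the example doesn’t fix a health value, so it isn’t shown)

The network sees an 8×5×5 blob of numbers (8 channels, 5 rows, 5 columns). It learns to associate patterns in these channels with good and bad moves.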

(Real games use 11×11 boards. The same encoding works — more zeros around the edges.)
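The sketch below rebuilds channels 2 and 3 of the 5×5 example without Candle. It shows the layout on its own: row index equals y, and the head is split from the rest of the body. plane and print_plane are hypothetical helpers, not part of snake-ml. print_plane prints the highest y first, so its output matches the grids above.

```rust
/// Fill one feature plane (w * h cells, row index = y) with 1.0 on the given squares.
fn plane(squares: &[(usize, usize)], w: usize, h: usize) -> Vec<f32> {
    let mut p = vec![0.0_f32; w * h];
    for &(x, y) in squares {
        p[y * w + x] = 1.0;
    }
    p
}

/// Print a plane top-down (highest y first), like the diagrams in the text.
fn print_plane(p: &[f32], w: usize, h: usize) {
    for y in (0..h).rev() {
        let row: Vec<String> = (0..w).map(|x| p[y * w + x].to_string()).collect();
        println!("{}", row.join(" "));
    }
}

fn main() {
    let (w, h) = (5, 5);
    // Body in game order, head first: (2,1), (2,2), (1,2).
    let body = [(2, 1), (2, 2), (1, 2)];
    println!("Channel 2 (my body, excluding head):");
    print_plane(&plane(&body[1..], w, h), w, h);
    println!("Channel 3 (my head only):");
    print_plane(&plane(&body[..1], w, h), w, h);
}
```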

The implementation

We’ll put the encoder in snake-ml/src/lib.rs. It takes the deserialized Board and Snake structs and produces a Tensor:

use candle_core::{Device, Result, Tensor};

// Point, Snake, and Board are the structs already defined at the top of this
// file (Part 1). Redefining them here would be a duplicate-definition error.

/// Convert a board + my_snake into a feature-plane tensor.
/// Shape: (8, height, width); row index = y, column index = x.
/// Channels: open_space, food, my_body, my_head, enemy_bodies, enemy_heads, hazards, health
pub fn encode_board(board: &Board, my_snake: &Snake) -> Result<Tensor> {
    let h = board.height as usize;
    let w = board.width as usize;
    let n = h * w;

    // Flat buffer: 8 channels × h × w. Cell (c, x, y) lives at c * n + y * w + x.
    let mut data = vec![0.0_f32; 8 * n];

    for y in 0..h {
        for x in 0..w {
            let idx = y * w + x;
            let (xi, yi) = (x as i32, y as i32);

            // Channel 0: open space, 1.0 for every valid square
            data[idx] = 1.0;

            // Channel 1: food
            if board.food.iter().any(|f| f.x == xi && f.y == yi) {
                data[n + idx] = 1.0;
            }

            // Channel 2: my body excluding the head. The head gets its own
            // channel (3), so channel 2 is pure "squares I occupy but am not at."
            let is_head = my_snake.head.x == xi && my_snake.head.y == yi;
            if !is_head && my_snake.body.iter().any(|p| p.x == xi && p.y == yi) {
                data[2 * n + idx] = 1.0;
            }

            // Channel 3: my head
            if is_head {
                data[3 * n + idx] = 1.0;
            }

            // Channels 4 and 5: enemy bodies and enemy heads
            for snake in board.snakes.iter().filter(|s| s.id != my_snake.id) {
                if snake.body.iter().any(|p| p.x == xi && p.y == yi) {
                    data[4 * n + idx] = 1.0;
                }
                if snake.head.x == xi && snake.head.y == yi {
                    data[5 * n + idx] = 1.0;
                }
            }

            // Channel 6: hazards (named `hz` so it doesn't shadow the height `h`)
            if board.hazards.iter().any(|hz| hz.x == xi && hz.y == yi) {
                data[6 * n + idx] = 1.0;
            }
        }
    }

    // Channel 7: my health normalized to [0, 1], broadcast to every cell
    let health = my_snake.health as f32 / 100.0;
    data[7 * n..].fill(health);

    // Build the tensor: shape (8, h, w), on CPU
    Tensor::from_vec(data, (8, h, w), &Device::Cpu)
}

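Run it and you’ll get a tensor with shape [8, height, width]. Candle stores it row-major, so the last axis (x) varies fastest and the channel axis slowest. All of channel 0 comes first in memory, then all of channel 1, and so on. That’s why the encoder writes channel c at offset c * n + y * w + x.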

One thing worth calling out: we encode hazards even though the standard Battlesnake ruleset doesn’t place any. Some game modes (Royale, for example) do, and handling the general case costs us almost nothing.

Flattening for the MLP

The network in Part 3 is a multi-layer perceptron (MLP) — it takes a flat input vector, not a 3D tensor. So after encoding, we’ll flatten:

let flat = tensor.flatten_all()?;  // (8*h*w,) = (968,) for an 11×11 board
// To add a batch dimension for the network:
// let batched = flat.unsqueeze(0)?;  // (1, 968)

For an 11×11 board with 8 channels, that’s 8 * 11 * 11 = 968 input values. Manageable.
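The encoder fills its buffer in (channel, y, x) order, and the tensor is row-major. As a result, every one of those 968 positions maps back to a specific channel and square. Here is a small std-only sketch of that arithmetic. flat_index and unflatten are hypothetical helpers that mirror the c * n + y * w + x offsets used in encode_board.

```rust
/// Offset of cell (c, x, y) in a flat (channels, height, width) row-major buffer.
fn flat_index(c: usize, x: usize, y: usize, w: usize, h: usize) -> usize {
    c * (h * w) + y * w + x
}

/// Inverse of `flat_index`: recover (c, x, y) from an offset.
fn unflatten(idx: usize, w: usize, h: usize) -> (usize, usize, usize) {
    let n = h * w;
    let c = idx / n;
    let rem = idx % n;
    (c, rem % w, rem / w)
}

fn main() {
    let (w, h) = (11, 11);
    // The top-right cell of the health plane is the last of all 968 values.
    println!("{}", flat_index(7, 10, 10, w, h)); // 967
    // Round-trip: my-head channel (3) at square (2, 1).
    println!("{:?}", unflatten(flat_index(3, 2, 1, w, h), w, h)); // (3, 2, 1)
}
```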

We’ll use this flatten-and-encode step often enough that it’s worth a helper:

/// Encode and flatten. Shape: (8 * H * W,) = (968,) for an 11×11 board.
pub fn encode_board_flat(board: &Board, my_snake: &Snake) -> Result<Vec<f32>> {
    let tensor = encode_board(board, my_snake)?;
    tensor.flatten_all()?.to_vec1()
}

/// Encode and return flat encoding with a unit marker for tuple destructuring.
/// The () is a convenience — some call sites destructure with `let (flat, _) = ...`
/// when they only need the flat vector but the API shape allows future extension.
pub fn encode_board_and_flat(
    board: &Board,
    my_snake: &Snake,
) -> Result<(Vec<f32>, ())> {
    encode_board_flat(board, my_snake).map(|v| (v, ()))
}

These are the three encoding functions you’ll see throughout the tutorial: encode_board (returns a Tensor), encode_board_flat (returns a Vec<f32>), and encode_board_and_flat (returns a (Vec<f32>, ()) for destructuring convenience in the self-play training loop).

If this feels wasteful — we’re duplicating the health channel across all squares — you’re right. Part 9 will show a better architecture using convolutional layers that share spatial weights. But for now, the flat representation works fine and keeps the code simple.

What we have

JSON game state  ──encode_board()──▶  Tensor(8, 11, 11)
                                          │
                                          ▼ flatten
                                     Vec<f32> (968 values)
                                          │
                                          ▼ Part 3 network
                                     Move probabilities

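The encoding is deterministic: the same board state always produces the same tensor. That matters because the network can only learn from signal in the data, and signal requires consistency. (The encoding isn’t fully reversible, though. Segment order and individual snake identities are dropped, which is fine for choosing a move.)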

Next up: we feed that 968-element vector into a neural network and see what comes out.


Part 3 — Your First Network

Last time we turned the board into 968 numbers. Now we feed those numbers into a neural network and get a move out.

Let’s demystify what a neural network is actually doing in this context.

What a network is

A neural network is a function. It takes a list of numbers in and produces a list of numbers out. That’s all. The “neural” part is a particular way of composing simple mathematical functions (linear transforms followed by nonlinear activations) so that the whole thing can be trained to approximate a very wide range of mappings.

For us: input is 968 numbers (the flattened board encoding). Output is 4 numbers — the “score” for each of the four directions. We pick the highest score.
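“Pick the highest score” is an argmax. Here is a std-only sketch of it. It assumes the index order Up=0, Down=1, Left=2, Right=3 that Direction::from_index uses in this part. pick and DIRECTIONS are hypothetical names.

```rust
/// Direction names in index order: Up=0, Down=1, Left=2, Right=3.
const DIRECTIONS: [&str; 4] = ["up", "down", "left", "right"];

/// Return the direction with the largest score. Ties go to the lowest index.
fn pick(scores: &[f32; 4]) -> &'static str {
    let mut best = 0;
    for (i, &s) in scores.iter().enumerate().skip(1) {
        if s > scores[best] {
            best = i;
        }
    }
    DIRECTIONS[best]
}

fn main() {
    // Four raw scores, one per direction.
    println!("{}", pick(&[0.1, -1.2, 2.5, 0.3])); // left
}
```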

The intermediate layers are where the magic happens. Each layer is a set of “units” (neurons) that look for patterns in the previous layer’s output. Early layers tend to learn simple spatial patterns (edges, food nearby). Later layers combine those into higher-level concepts (a path to food, a trap closing in).

The architecture

We’ll build a simple multi-layer perceptron (MLP): three linear layers with ReLU activations between them. ReLU (Rectified Linear Unit) is the activation function that outputs max(0, x):

Input (968) → Linear(968 → 256) → ReLU
           → Linear(256 → 128) → ReLU
           → Linear(128 → 4)   → logits

The final layer has 4 outputs (one per direction) without an activation — we apply softmax afterward to turn logits into probabilities.
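Softmax turns the 4 logits into probabilities that sum to 1. Here is a std-only sketch that includes the temperature scaling SnakeNet::probs applies, where the logits are divided by the temperature before the softmax. The softmax function here is a hypothetical stand-in written for illustration. It is not candle_nn::ops::softmax’s implementation.

```rust
/// Softmax over logits scaled by 1/temperature.
/// Subtracting the max before exp() avoids overflow and doesn't change the result.
fn softmax(logits: &[f32], temperature: f32) -> Vec<f32> {
    let scaled: Vec<f32> = logits.iter().map(|l| l / temperature).collect();
    let max = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scaled.iter().map(|s| (s - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn main() {
    let logits = [2.0, 1.0, 0.5, -1.0];
    // Lower temperature sharpens the distribution; higher flattens it.
    for t in [0.5, 1.0, 2.0] {
        let p = softmax(&logits, t);
        println!("T = {t}: {p:.3?}");
    }
}
```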

In Candle:

// Add to snake-ml/src/lib.rs. candle_core::{Device, Result, Tensor} is already
// imported there for encode_board, so only the candle_nn items are new.
use candle_nn::{linear, Linear, Module, VarBuilder};

pub struct SnakeNet {
    // Three linear layers
    l1: Linear,
    l2: Linear,
    l3: Linear,
}

impl SnakeNet {
    pub fn new(vs: VarBuilder) -> Result<Self> {
        // 8 channels * 11 * 11 = 968 input features (so 11×11 boards only).
        // vs.pp("l1") scopes each layer's variables (l1.weight, l1.bias, ...).
        // Passing `vs` itself three times would move it on the first call, and
        // all three layers would ask for the same "weight"/"bias" names.
        let l1 = linear(968, 256, vs.pp("l1"))?;
        let l2 = linear(256, 128, vs.pp("l2"))?;
        let l3 = linear(128, 4, vs.pp("l3"))?;  // 4 directions
        Ok(Self { l1, l2, l3 })
    }

    /// Forward pass: flat_input has shape (batch, 968); returns logits (batch, 4)
    pub fn forward(&self, flat_input: &Tensor) -> Result<Tensor> {
        let x = self.l1.forward(flat_input)?.relu()?;
        let x = self.l2.forward(&x)?.relu()?;
        self.l3.forward(&x)
    }

    /// Apply softmax with a temperature.
    /// Higher temperature → softer, more exploratory distributions.
    /// Lower temperature → sharper, more greedy.
    /// A few useful reference points:
    ///   - temperature = 1.0  → standard softmax (default)
    ///   - temperature = 0.5  → sharper, more confident predictions
    ///   - temperature = 2.0  → flatter, more random-looking distributions
    /// This matters in Part 6 when we add entropy regularization.
    pub fn probs(&self, flat_input: &Tensor, temperature: f32) -> Result<Tensor> {
        let logits = self.forward(flat_input)?;
        // Divide by the temperature: affine(mul, add) computes x * mul + add
        // elementwise. (The tensor-by-tensor `*` operator needs matching shapes
        // and does not broadcast; a scale tensor would need broadcast_mul.)
        let scaled = logits.affine(1.0 / temperature as f64, 0.0)?;
        candle_nn::ops::softmax(&scaled, 1)
    }

    /// Pick the direction with the highest logit score (greedy).
    pub fn pick_direction(&self, flat_input: &Tensor) -> Result<Direction> {
        let logits = self.forward(flat_input)?;
        // argmax(1) on (batch, 4) → (batch,). squeeze(0) → scalar.
        // Without squeeze, to_scalar() fails — it expects a 0-D tensor, not (1,).
        let idx = logits.argmax(1)?.squeeze(0)?.to_scalar::<u32>()? as usize;
        Direction::from_index(idx).ok_or_else(|| {
            candle_core::Error::Msg(format!("invalid direction index: {idx}"))
        })
    }
}

#[derive(Debug, Clone, Copy)]
pub enum Direction {
    Up, Down, Left, Right,
}

impl Direction {
    pub fn from_index(idx: usize) -> Option<Self> {
        match idx {
            0 => Some(Self::Up),
            1 => Some(Self::Down),
            2 => Some(Self::Left),
            3 => Some(Self::Right),
            _ => None,
        }
    }

    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Up    => "up",
            Self::Down  => "down",
            Self::Left  => "left",
            Self::Right => "right",
        }
    }

    /// Direction as a 0-based index (Up=0, Down=1, Left=2, Right=3).
    /// Used by the training loop to record actions.
    pub fn index(&self) -> usize {
        match self {
            Self::Up    => 0,
            Self::Down  => 1,
            Self::Left  => 2,
            Self::Right => 3,
        }
    }
}

Next, a small Model wrapper that owns the VarMap. It can start from fresh random weights now and load saved ones later:

use candle_core::DType;
use candle_nn::VarMap;
use std::sync::Arc;
// Device, Result, Tensor, and VarBuilder are already imported above.

pub struct Model {
    device: Arc<Device>,
    var_map: VarMap,
    net: SnakeNet,
}

impl Model {
    /// Build a new model with fresh random weights
    pub fn new() -> Result<Self> {
        let device = Device::Cpu;
        let var_map = VarMap::new();
        let vs = VarBuilder::from_varmap(&var_map, DType::F32, &device);
        let net = SnakeNet::new(vs)?;
        Ok(Self { device: Arc::new(device), var_map, net })
    }

    /// Load weights from a safetensors file, e.g. one written by VarMap::save (for Part 8)
    pub fn load(path: &str) -> Result<Self> {
        let device = Device::Cpu;
        let mut var_map = VarMap::new();
        let vs = VarBuilder::from_varmap(&var_map, DType::F32, &device);
        // Build the network first: that registers l1.weight, l1.bias, ... in
        // the VarMap. VarMap::load only overwrites variables that already
        // exist, so loading into an empty VarMap would silently load nothing.
        let net = SnakeNet::new(vs)?;
        var_map.load(path)?; // replace the random weights in place
        Ok(Self { device: Arc::new(device), var_map, net })
    }

    pub fn pick_direction(&self, flat_input: &Tensor) -> Result<Direction> {
        self.net.pick_direction(flat_input)
    }
}

Run it on the encoded board:

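// snake-server depends only on snake-ml, so calling this from the server also
// needs candle-core = "0.6" in snake-server/Cargo.toml (for candle_core::Error).
use snake_ml::{encode_board, Board, Direction, Model, Snake};

fn decide_move(board: &Board, my_snake: &Snake) -> Result<Direction, candle_core::Error> {
    // Fresh random weights, so the choice is arbitrary until we train.
    // A real server would build (or load) the model once at startup.
    let model = Model::new()?;
    let tensor = encode_board(board, my_snake)?;
    let flat = tensor.flatten_all()?;
    // Add batch dimension: (1, 968)
    let flat = flat.unsqueeze(0)?;
    model.pick_direction(&flat)
}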

And that’s the forward pass working end-to-end.

What the network learns (and what it doesn’t)

Here’s the thing: a freshly initialized network is worthless. The weights start random. The output is gibberish. A random network will pick moves about as well as a dart-throwing monkey.

The network needs two things before it becomes useful:

  1. Architecture — we’ve built that. The shape of the function.
  2. Weights — the numbers inside. Right now they’re random. They need to be trained.
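To put a number on “the numbers inside”: each Linear layer holds an inputs × outputs weight matrix plus one bias per output, since candle_nn::linear includes a bias. Here is a std-only sketch of the count for the 968 → 256 → 128 → 4 network. linear_params is a hypothetical helper.

```rust
/// Parameters in one fully connected layer: weight matrix plus bias vector.
fn linear_params(inputs: usize, outputs: usize) -> usize {
    inputs * outputs + outputs
}

fn main() {
    let layers = [(968, 256), (256, 128), (128, 4)];
    let total: usize = layers.iter().map(|&(i, o)| linear_params(i, o)).sum();
    println!("{total}"); // 281476
}
```

Almost 90% of those weights sit in the first layer, which is why the input encoding dominates the model’s size.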

Training is the process of adjusting the weights to make the network output good moves. There are two ways to train this network:

  • Imitation learning (Part 5): generate data from a teacher (the heuristic from Part 4), then train the network to copy the teacher. This gets us to “decent” fast.
  • Reinforcement learning (Parts 6–7): let the snake play, learn from wins and losses, discover strategies the teacher never thought of.

We’ll tackle imitation learning first — it’s simpler, more predictable, and produces a solid baseline to build on.

But first, Part 4 builds the teacher: the heuristic snake that will feed us training data. That’s where A* pathfinding comes in.


Part 4 — The Heuristic Baseline

Here’s the situation: we have a neural network that takes a board encoding and outputs move probabilities. But the weights are random. It’s useless.

Before we can train the network to play well, we need a teacher — a snake that makes smart enough decisions to generate training data. That’s what we build in this part.

The teacher doesn’t have to be brilliant. It needs to be better than random. We’ll use a simple rule-based snake that:

  1. Never moves into a wall, itself, or another snake (survival)
  2. Uses A* pathfinding to find a path to the nearest food

This becomes the “ground truth” for imitation learning in Part 5.

A* pathfinding

A* (pronounced “A-star”) is a graph search algorithm that finds the shortest path between two points. It’s the same algorithm used in video games for character navigation and in GPS systems for route planning.

The core idea: start at the current position and keep a queue of positions to explore, ordered by f = g + h. Here g is the cost so far and h is the estimated distance to the goal. Always expand the position with the lowest f next.

open_set: priority queue sorted by (g + h)
closed_set: positions we've already explored

start at current position
while open_set not empty:
    pop the position with lowest priority
    if it's the goal: reconstruct path, done
    for each neighbor:
        if neighbor not in closed_set and not blocked:
            tentative_g = g(current) + 1
            if this path to neighbor is better than any previous:
                update parent, update g score
                push neighbor to open_set

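A*'s guarantee depends on the heuristic h (estimated distance to goal). A heuristic is admissible if it never overestimates the true remaining cost. With an admissible heuristic, A* is guaranteed to return an optimal (shortest) path. On a grid where each move goes one square up, down, left, or right, Manhattan distance is admissible: |x1 - x2| + |y1 - y2|.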
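The small standalone sketch below makes "never overestimates" concrete. It uses (x, y) tuples instead of the Pos struct, and breadth-first search as the ground-truth distance. Both are choices for illustration only, not part of the snake code. With a wall in the way, the true path is far longer than the Manhattan estimate, but it is never shorter.

```rust
use std::collections::{HashSet, VecDeque};

/// Manhattan distance: the A* heuristic h on a 4-connected grid.
fn manhattan(a: (i32, i32), b: (i32, i32)) -> i32 {
    (a.0 - b.0).abs() + (a.1 - b.1).abs()
}

/// True shortest-path length by breadth-first search (every move costs 1).
/// Returns None if `goal` is unreachable.
fn bfs_distance(w: i32, h: i32, blocked: &HashSet<(i32, i32)>,
                start: (i32, i32), goal: (i32, i32)) -> Option<i32> {
    let mut seen = HashSet::from([start]);
    let mut queue = VecDeque::from([(start, 0)]);
    while let Some((pos, d)) = queue.pop_front() {
        if pos == goal {
            return Some(d);
        }
        for (dx, dy) in [(0, 1), (0, -1), (-1, 0), (1, 0)] {
            let next = (pos.0 + dx, pos.1 + dy);
            let in_bounds = next.0 >= 0 && next.0 < w && next.1 >= 0 && next.1 < h;
            // `insert` returns false if we've already queued this square
            if in_bounds && !blocked.contains(&next) && seen.insert(next) {
                queue.push_back((next, d + 1));
            }
        }
    }
    None
}

fn main() {
    // A wall at x = 5 covering y = 0..=9 on an 11x11 board; the only gap is at y = 10.
    let blocked: HashSet<(i32, i32)> = (0..10).map(|y| (5, y)).collect();
    let (start, goal) = ((2, 2), (8, 2));
    let h = manhattan(start, goal);
    let true_cost = bfs_distance(11, 11, &blocked, start, goal).unwrap();
    println!("h = {}, true cost = {}", h, true_cost); // h = 6, true cost = 22
    assert!(h <= true_cost); // admissible: the estimate never exceeds the real cost
}
```

A single move changes Manhattan distance by at most 1, so this heuristic is also consistent. That stronger property is what lets an A* implementation skip nodes it has already closed without losing optimality.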

Here’s the implementation in pure Rust, no external crates:

#![allow(unused)]
fn main() {
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::cmp::Ordering;

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Pos {
    pub x: i32,
    pub y: i32,
}

impl Pos {
    pub fn manhattan(&self, other: &Pos) -> i32 {
        (self.x - other.x).abs() + (self.y - other.y).abs()
    }
}

// Node in the A* search
#[derive(Clone)]
struct Node {
    pos: Pos,
    g: i32,      // cost from start
    f: i32,      // g + h
}

impl PartialEq for Node {
    fn eq(&self, other: &Self) -> bool { self.f == other.f }
}

impl Eq for Node {}

impl Ord for Node {
    fn cmp(&self, other: &Self) -> Ordering {
        // BinaryHeap is a max-heap; we want min-f, so flip the comparison
        other.f.cmp(&self.f)
    }
}

impl PartialOrd for Node {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

/// A* pathfinder for a rectangular grid.
/// Returns the first step toward the goal, or None if no path exists.
pub struct Astar {
    width: i32,
    height: i32,
    blocked: HashSet<Pos>,
}

impl Astar {
    pub fn new(width: i32, height: i32) -> Self {
        Self {
            width,
            height,
            blocked: HashSet::new(),
        }
    }

    /// Mark a position as impassable (walls, other snakes' bodies)
    pub fn block(&mut self, pos: Pos) {
        self.blocked.insert(pos);
    }

    /// Find the first step toward `goal` from `start`.
    /// Returns `None` if no path exists.
    pub fn step(&self, start: &Pos, goal: &Pos) -> Option<Pos> {
        let mut open_set = BinaryHeap::new();
        let mut came_from: HashMap<Pos, Pos> = HashMap::new();
        let mut g_score: HashMap<Pos, i32> = HashMap::new();
        let mut closed: HashSet<Pos> = HashSet::new();

        g_score.insert(start.clone(), 0);
        open_set.push(Node {
            pos: start.clone(),
            g: 0,
            f: start.manhattan(goal),
        });

        while let Some(Node { pos, g, f: _ }) = open_set.pop() {
            if pos == *goal {
                // Reconstruct: find the step right after `start`
                return self.reconstruct_first_step(&came_from, start, goal);
            }
            if closed.contains(&pos) {
                continue;
            }
            closed.insert(pos.clone());

            for neighbor in self.neighbors(&pos) {
                if closed.contains(&neighbor) || self.blocked.contains(&neighbor) {
                    continue;
                }
                let tentative_g = g + 1;
                let is_better = g_score
                    .get(&neighbor)
                    .map(|current| tentative_g < *current)
                    .unwrap_or(true);

                if is_better {
                    came_from.insert(neighbor.clone(), pos.clone());
                    g_score.insert(neighbor.clone(), tentative_g);
                    // Compute f before `neighbor` is moved into the node
                    let f = tentative_g + neighbor.manhattan(goal);
                    open_set.push(Node {
                        pos: neighbor,
                        g: tentative_g,
                        f,
                    });
                }
            }
        }

        None // no path found
    }

    fn neighbors(&self, pos: &Pos) -> Vec<Pos> {
        let mut n = Vec::with_capacity(4);
        if pos.y + 1 < self.height { n.push(Pos { x: pos.x, y: pos.y + 1 }); } // up
        if pos.y - 1 >= 0         { n.push(Pos { x: pos.x, y: pos.y - 1 }); } // down
        if pos.x - 1 >= 0         { n.push(Pos { x: pos.x - 1, y: pos.y }); } // left
        if pos.x + 1 < self.width { n.push(Pos { x: pos.x + 1, y: pos.y }); } // right
        n
    }

    fn reconstruct_first_step(&self, came_from: &HashMap<Pos, Pos>, start: &Pos, goal: &Pos) -> Option<Pos> {
        // Walk backward from goal to start, then return the step after start
        let mut current = goal.clone();
        while let Some(parent) = came_from.get(&current) {
            if parent == start {
                return Some(current);
            }
            current = parent.clone();
        }
        None
    }
}
}

Note that we store blocked positions inside Astar. In practice you’ll create a fresh Astar for each query and populate it from the current game state. That’s fine — the search space is small (11×11 = 121 nodes).

Survival rules

One subtle edge case: all snakes move at the same time, and each tail leaves its square on the same tick that the heads move. The exception is a tail whose last two body segments are stacked on the same square. That is how the API represents a snake that just ate (and every snake at the start of a game), and such a tail stays put for one more turn. So moving into a square that currently holds a tail is usually legal. The safe_directions check below does not model this: it treats every body segment, tails included, as blocked. That choice is conservative. It never allows a move into a body segment, but it can reject a legal move onto a vacating tail, which is acceptable for a teacher. A more precise version would simulate the next tick with a next_state() method on the board, or drop each snake's last segment from the blocked set unless that tail is stacked.
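Here is a minimal sketch of leaving vacating tails out of the blocked set. It relies on Battlesnake's convention that a snake which just ate has its last two body segments on the same square. Tuples and the hypothetical SnakeBody type stand in for the tutorial's Pos and Snake.

```rust
use std::collections::HashSet;

/// Hypothetical stand-in for the tutorial's Snake type: only the body
/// matters here. body[0] is the head, the last element is the tail.
struct SnakeBody {
    body: Vec<(i32, i32)>,
}

/// Squares that will still be occupied after every snake moves this turn.
/// Each tail square is freed, unless the last two segments are stacked on
/// the same square (the snake just ate, or the game just started). In that
/// case the tail stays put.
fn blocked_next_turn(snakes: &[SnakeBody]) -> HashSet<(i32, i32)> {
    let mut blocked = HashSet::new();
    for s in snakes {
        let n = s.body.len();
        let stacked_tail = n >= 2 && s.body[n - 1] == s.body[n - 2];
        // Keep every segment except a tail that is about to move away
        let keep = if stacked_tail || n == 0 { n } else { n - 1 };
        blocked.extend(s.body[..keep].iter().copied());
    }
    blocked
}

fn main() {
    // Moving normally: its tail at (3, 5) will be free next turn.
    let a = SnakeBody { body: vec![(5, 5), (4, 5), (3, 5)] };
    // Just ate: tail stacked at (1, 1), so that square stays occupied.
    let b = SnakeBody { body: vec![(1, 3), (1, 2), (1, 1), (1, 1)] };
    let blocked = blocked_next_turn(&[a, b]);
    println!("(3,5) blocked: {}", blocked.contains(&(3, 5))); // false
    println!("(1,1) blocked: {}", blocked.contains(&(1, 1))); // true
}
```

A check like safe_directions could test each candidate square against this set instead of against every body segment.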

A* gets us to food, but it’s not enough on its own. The snake needs to also avoid immediately dying. Before following the A* path, we check which directions are immediately safe:

#![allow(unused)]
fn main() {
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Direction { Up, Down, Left, Right }

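impl Direction {
    pub fn delta(&self) -> (i32, i32) {
        match self {
            Self::Up    => (0, 1),
            Self::Down  => (0, -1),
            Self::Left  => (-1, 0),
            Self::Right => (1, 0),
        }
    }

    /// Lowercase name the Battlesnake API expects in the move response
    /// (used by the server handler later in this part)
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Up    => "up",
            Self::Down  => "down",
            Self::Left  => "left",
            Self::Right => "right",
        }
    }
}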

pub struct HeuristicSnake {
    width: i32,
    height: i32,
}

impl HeuristicSnake {
    pub fn new(width: i32, height: i32) -> Self {
        Self { width, height }
    }

    /// Pick the best direction using survival + A* pathfinding to nearest food
    pub fn decide(&self, my_snake: &Snake, board: &Board) -> Direction {
        let my_head = Pos { x: my_snake.head.x, y: my_snake.head.y };

        // 1. Find immediate safe moves (not walls, not bodies, not enemy heads)
        let safe = self.safe_directions(&my_head, board, my_snake);

        // 2. If no safe moves, we have to pick something (we'll die)
        if safe.is_empty() {
            return Direction::Up;
        }

        // 3. Find nearest food
        let nearest_food = board.food.iter()
            .map(|f| Pos { x: f.x, y: f.y })
            .min_by_key(|f| my_head.manhattan(f))
            .unwrap_or(Pos { x: self.width / 2, y: self.height / 2 });

        // 4. Build A* pathfinder with current snake bodies blocked
        let mut astar = Astar::new(self.width, self.height);

        // Block all body squares (including our own)
        for snake in &board.snakes {
            for segment in &snake.body {
                astar.block(Pos { x: segment.x, y: segment.y });
            }
        }

        // 5. Find A* path to food
        let target = astar.step(&my_head, &nearest_food);

        // 6. Choose the safe direction that gets us closest to food
        let best = target.and_then(|goal| {
            // Direction that moves toward the first step
            self.step_toward(&my_head, &goal)
        });

        match best {
            Some(dir) if safe.contains(&dir) => dir,
            _ => {
                // Fall back to the first safe direction
                safe.into_iter().next().unwrap()
            }
        }
    }

    fn safe_directions(&self, head: &Pos, board: &Board, my_snake: &Snake) -> Vec<Direction> {
        let mut safe = Vec::with_capacity(4);

        for dir in [Direction::Up, Direction::Down, Direction::Left, Direction::Right] {
            let (dx, dy) = dir.delta();
            let next = Pos { x: head.x + dx, y: head.y + dy };

            // Out of bounds
            if next.x < 0 || next.x >= self.width || next.y < 0 || next.y >= self.height {
                continue;
            }

            // Own body (any body segment)
            if my_snake.body.iter().any(|s| s.x == next.x && s.y == next.y) {
                continue;
            }

            // Any enemy body
            let body_hit = board.snakes.iter()
                .filter(|s| s.id != my_snake.id)
                .any(|s| s.body.iter().any(|b| b.x == next.x && b.y == next.y));

            if body_hit {
                continue;
            }

            safe.push(dir);
        }

        safe
    }

    fn step_toward(&self, from: &Pos, goal: &Pos) -> Option<Direction> {
        if goal.x > from.x { Some(Direction::Right) }
        else if goal.x < from.x { Some(Direction::Left) }
        else if goal.y > from.y { Some(Direction::Up) }
        else if goal.y < from.y { Some(Direction::Down) }
        else { None }
    }
}
}

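Notice what we're blocking: every body segment, not only the head. In Battlesnake, moving into any snake's body kills you. That includes the square an enemy head occupies right now. body[0] is the head, and since all snakes move at once, that square becomes the enemy's neck when the turn resolves. The rule does not block the empty squares next to an enemy head. Those are legal moves, but the enemy may move there too (see head-to-head collisions below).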

This heuristic is a solid baseline. It should beat a random snake in nearly every game: it knows where food is, it avoids dying immediately, and it plans paths rather than stumbling toward food.

What it’s missing

The heuristic is good, but it’s not perfect. It has blind spots:

  • Head-to-head collisions: when two heads are two squares apart, both snakes can pick the empty square between them. When heads collide, the shorter snake dies, and if both are the same length, both die. Neither safe_directions nor A* accounts for this.
  • Long snakes closing off paths: A* finds a path to food on the board as it is now, but that path might lead into a pocket that the snakes' bodies seal off on later turns.
  • Health management: it always heads for the nearest food, whatever its health. At full health, chasing food isn't urgent, and the turns spent on it could go to safer positioning instead.
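One way to close the first gap, sketched under a few assumptions: tuple positions, a hypothetical Enemy record, and Battlesnake's rule that the shorter snake dies in a head-to-head and equal lengths both die. The idea is to flag any candidate square next to an enemy head that is at least as long as ours. decide could then prefer safe moves that aren't flagged, and fall back to flagged ones only when nothing else is left.

```rust
/// Hypothetical minimal enemy record: head position and body length.
struct Enemy {
    head: (i32, i32),
    length: usize,
}

/// True if moving onto `square` could end in a head-to-head we lose or tie:
/// some enemy at least as long as us could also move there this turn
/// (its head is exactly one step away).
fn head_to_head_risk(square: (i32, i32), my_length: usize, enemies: &[Enemy]) -> bool {
    enemies.iter().any(|e| {
        let dist = (e.head.0 - square.0).abs() + (e.head.1 - square.1).abs();
        dist == 1 && e.length >= my_length
    })
}

fn main() {
    let enemies = [Enemy { head: (5, 5), length: 4 }];
    // Same length: a collision at (5, 6) would kill both snakes.
    println!("{}", head_to_head_risk((5, 6), 4, &enemies)); // true
    // We're longer: the enemy would lose that collision.
    println!("{}", head_to_head_risk((5, 6), 5, &enemies)); // false
    // Two steps from the enemy head: it can't reach this square this turn.
    println!("{}", head_to_head_risk((5, 7), 3, &enemies)); // false
}
```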

Part 5 uses this heuristic to generate training data anyway. It's good enough to teach the network something useful. Don't expect the network to fix those blind spots, though. It is trained to copy the heuristic, so it will mostly inherit the heuristic's mistakes. Getting past them is the job of reinforcement learning in Parts 6–7.

Wiring it into the server

Back in snake-server/src/main.rs, swap out the wall-check for the heuristic:

#![allow(unused)]
fn main() {
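use actix_web::{web, HttpResponse};
use snake_server::HeuristicSnake;

async fn move_handler(req: web::Json<MoveRequest>) -> HttpResponse {
    // A fresh heuristic per request; it only stores the board dimensions
    let snake = HeuristicSnake::new(req.board.width, req.board.height);
    let direction = snake.decide(&req.you, &req.board);

    // The engine reads the chosen move from the "move" key of the JSON body,
    // so MoveResponse's field must serialize under that name
    // (for example with #[serde(rename = "move")]).
    HttpResponse::Ok().json(MoveResponse {
        direction: direction.as_str().to_string(),
    })
}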
}

Run it, point it at BattleSnake, and watch it play. It won’t be world-class, but it’ll survive longer than anything random.

That’s the teacher. Next: we use it to generate training data and train the network to copy it.


Part 5 — Imitation Learning

We have a working heuristic. We have a network. Now we connect them: generate training data from the heuristic, then train the network to predict the same moves.

The idea is called imitation learning or behavioral cloning. It’s supervised learning: for each board state, we want the network to predict the direction the heuristic chose. We’re training it to clone a teacher.

Collecting training data

We need a function that runs a game with the heuristic and records each state:

#![allow(unused)]
fn main() {
use snake_ml::{Board, Snake, Direction, encode_board};
use crate::heuristic::HeuristicSnake;
use serde::{Deserialize, Serialize};

/// One training example: a board encoding and the teacher's chosen direction
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingExample {
    /// Flat encoded board state: (8 * height * width,) f32
    pub board: Vec<f32>,
    /// Direction chosen by the teacher (0=up, 1=down, 2=left, 3=right)
    pub label: u32,
}

/// Run one game with the heuristic teacher and collect all examples.
///
/// Note: `advance_snakes` is a simplified game simulation. A full implementation
/// needs to handle movement, food eating, body growth, starvation, and collision
/// detection. The `GameEnv` in Part 6 shows a complete implementation you can use
/// instead of this placeholder.
pub fn collect_game(board: &Board, heuristic: &HeuristicSnake) -> Vec<TrainingExample> {
    let mut examples = Vec::new();
    let mut snakes = board.snakes.clone();
    let mut turn = 0u32;

    // Simulate up to 200 turns
    while turn < 200 && snakes.iter().all(|s| s.health > 0) {
        for my_idx in 0..snakes.len() {
            let board_snapshot = Board {
                width: board.width,
                height: board.height,
                food: board.food.clone(),
                snakes: snakes.clone(),
                hazards: board.hazards.clone(),
            };
            let my_snake = &snakes[my_idx];
            let dir = heuristic.decide(my_snake, &board_snapshot);
            examples.push(TrainingExample {
                board: encode_board_flat(&board_snapshot, my_snake),
                label: dir as u32,
            });
        }

        // Advance the game one step (simplified)
        snakes = advance_snakes(snakes, board);
        turn += 1;
    }

    examples
}

/// Encode + flatten in one step for training data
fn encode_board_flat(board: &Board, my_snake: &Snake) -> Vec<f32> {
    let tensor = encode_board(board, my_snake).unwrap();
    let flat = tensor.flatten_all().unwrap();
    flat.to_vec1::<f32>().unwrap()
}

/// Advance snakes by one turn. This is a minimal placeholder: it only
/// decrements health, so the snakes never move and every turn records the
/// same boards. For real training data, use the `GameEnv` from Part 6, which
/// handles movement, food eating, body growth, wall collisions, and
/// snake-on-snake collisions correctly.
fn advance_snakes(snakes: Vec<Snake>, board: &Board) -> Vec<Snake> {
    snakes.into_iter().filter_map(|mut s| {
        if s.health <= 0 { return None; }
        s.health = s.health.saturating_sub(1);
        Some(s)
    }).collect()
}
}

In practice you’d generate thousands of games across different board configurations. The data goes into a JSONL file (one JSON object per line) for efficient reading during training.

The training loop

Here’s what training means at a high level before we look at the code:

We show the network many (board state, correct direction) pairs. For each pair, we ask the network: “what direction would you pick?” Then we compare its answer to the correct answer (which we got from the heuristic) and nudge the weights to make the correct answer more likely next time. The further off the network was, the bigger the nudge. We repeat this for 20 full passes over the data. Each pass is called an epoch.

We use a loss function called cross-entropy to measure how far off the network’s prediction is from the correct answer. Minimizing cross-entropy is equivalent to maximizing the probability the network assigns to the heuristic’s choice — so the network learns to pick the same direction the heuristic would.

Here’s the code:

#![allow(unused)]
fn main() {
use candle_core::{Device, DType, Result, Tensor};
use candle_nn::{AdamW, Optimizer, VarBuilder, VarMap};
use candle_nn::loss::cross_entropy;
use snake_ml::{Model, SnakeNet, TrainingExample};

pub fn train() -> Result<()> {
    let device = Device::Cpu;
    let lr = 1e-3;

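    // Load training data (JSONL: one JSON object per line).
    // `Result` here is candle's one-parameter alias, so spell out std's
    // Result for the collect, then wrap serde_json errors into a candle error.
    let examples: Vec<TrainingExample> = std::fs::read_to_string("training_data.jsonl")?
        .lines()
        .filter(|line| !line.is_empty())
        .map(|line| serde_json::from_str(line))
        .collect::<std::result::Result<Vec<TrainingExample>, _>>()
        .map_err(candle_core::Error::wrap)?;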

    println!("Loaded {} training examples", examples.len());

    // Initialize model
    let mut var_map = VarMap::new();
    let vs = VarBuilder::from_varmap(&var_map, DType::F32, &device);
    let net = SnakeNet::new(vs)?;
    let mut opt = AdamW::new(var_map.all_vars(), candle_nn::ParamsAdamW {
        lr,
        ..Default::default()
    })?;
    // Note: the companion code in snake-ml/src/lib.rs uses plain SGD
    // (stochastic gradient descent: each step moves every weight by
    // lr × its gradient) instead of AdamW, and packs the training step into
    // Model::train_batch(). The loss is the same cross-entropy either way;
    // the difference is the optimizer and the API surface. AdamW adapts a
    // per-weight step size and usually converges faster; SGD keeps the
    // companion's single-method API simpler.

    let batch_size = 64;
    let epochs = 20;

    for epoch in 0..epochs {
        // Shuffle the dataset each epoch. Consecutive examples come from the
        // same game and are highly correlated. Without shuffling, each batch
        // would cover one stretch of one game, and the gradient steps would be
        // noisy and biased toward it.
        let mut indices: Vec<usize> = (0..examples.len()).collect();
        use rand::seq::SliceRandom;
        indices.shuffle(&mut rand::thread_rng());

        let mut total_loss = 0.0_f64;
        let num_batches = examples.len() / batch_size;

        for batch_idx in 0..num_batches {
            let start = batch_idx * batch_size;
            let batch: &[usize] = &indices[start..start + batch_size];

            let xs: Vec<f32> = batch.iter()
                .flat_map(|&i| examples[i].board.iter().copied())
                .collect();
            let targets: Vec<u32> = batch.iter()
                .map(|&i| examples[i].label)
                .collect();

            // 968 = 8 channels × 11 × 11 (the default board size from Part 2)
            let xs_t = Tensor::from_vec(xs, (batch.len(), 968), &device)?;
            let targets_t = Tensor::from_vec(targets, (batch.len(),), &device)?;

            // Forward pass → raw logits (cross_entropy applies log-softmax itself)
            let logits = net.forward(&xs_t)?;
            let loss = cross_entropy(&logits, &targets_t)?;

            // Backward pass
            opt.backward_step(&loss)?;

            total_loss += loss.to_scalar::<f32>()? as f64;
        }

        let avg = total_loss / num_batches as f64;
        println!("Epoch {}: avg loss = {:.4}", epoch, avg);
    }

    // Save the trained weights
    var_map.save("snake_model.safetensors")?;
    println!("Model saved to snake_model.safetensors");

    Ok(())
}
}

The loss function intuition

Cross-entropy loss measures how far the predicted probabilities are from the target distribution. When the teacher always picks “up”, the target distribution is (1.0, 0.0, 0.0, 0.0) for (up, down, left, right). The loss is smaller when the network also predicts “up” with high probability.

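Concretely, with a one-hot target the loss for one example is -ln(p), where p is the probability the network gives the teacher's direction. A confident, correct prediction (p = 0.9) costs about 0.105. A uniform guess (p = 0.25) costs about 1.386. A confident, wrong one (p = 0.01) costs about 4.6. The loss punishes confident mistakes hardest.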

cross_entropy from candle_nn::loss handles the softmax internally. We pass in raw logits; it applies log-softmax, then averages the negative log-probability of each target label over the batch. For 4-class classification, this is the standard loss function.
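Here is the same computation written out by hand, as a plain-Rust sketch with no Candle types. The function name and the example logits are illustrative. It computes log-softmax with the usual max-subtraction trick, then takes the negative log-probability of the target.

```rust
/// Cross-entropy for one example: log-softmax over the logits, then the
/// negative log-probability of the target class.
fn cross_entropy(logits: &[f32], target: usize) -> f32 {
    // Subtract the max logit before exponentiating, for numerical stability
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let log_sum_exp = logits.iter().map(|l| (l - max).exp()).sum::<f32>().ln() + max;
    // log p(target) = logit[target] - logsumexp(logits)
    -(logits[target] - log_sum_exp)
}

fn main() {
    // Equal logits → uniform 0.25 each → loss = ln 4 ≈ 1.386
    println!("{:.3}", cross_entropy(&[0.0, 0.0, 0.0, 0.0], 0));
    // Confident and right (target 0 = up) → ≈ 0.053
    println!("{:.3}", cross_entropy(&[4.0, 0.0, 0.0, 0.0], 0));
    // Confident and wrong (target 3 = right) → ≈ 4.053
    println!("{:.3}", cross_entropy(&[4.0, 0.0, 0.0, 0.0], 3));
}
```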

The training signal is strong because the heuristic makes consistent choices. The network learns the heuristic’s patterns.

What the network learns

After training, the network’s forward pass on a new board state produces a probability distribution over directions. Crucially, it should have generalized — if it’s working well, it’s not merely memorizing what the heuristic did in specific situations, it’s learned patterns that transfer.

If training is working well, you might observe the network doing the following:

  • Food in a particular direction → higher probability for that direction
  • Enemy nearby → avoid it
  • Narrow corridor → prefer the wider path

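The first two follow directly from rules the heuristic applies (A* toward food, avoiding bodies). The third isn't something the heuristic does on purpose, so if it shows up, treat it as a side effect of the teacher's choices rather than a new strategy. If the loss stays flat, the network isn't fitting the data: try more training data or more epochs. If the training loss falls but agreement on boards the network hasn't seen stays low, it is memorizing rather than generalizing, and more varied training data is the usual remedy.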

Evaluation

Measure how often the trained model picks the same move as the heuristic on boards it didn't see during training:

#![allow(unused)]
fn main() {
use snake_ml::{Model, Direction, encode_board};
use crate::heuristic::HeuristicSnake;

/// Evaluate how often the model agrees with the heuristic.
/// Agreement rate of 80–90% is good — the network has learned the teacher's patterns.
pub fn evaluate(model: &Model, heuristic: &HeuristicSnake, games: usize) -> f32 {
    let mut agreement = 0;

    for _ in 0..games {
        let board = random_board(11, 11, 1); // generate a test board (see Part 6)
        let my_snake = &board.snakes[0];

        // Both functions take the board by reference
        let tensor = encode_board(&board, my_snake).unwrap();
        let flat = tensor.flatten_all().unwrap().unsqueeze(0).unwrap();

        let model_dir = model.pick_direction(&flat).unwrap();
        let heur_dir = heuristic.decide(my_snake, &board);

        // Direction derives PartialEq — direct comparison, no string matching.
        if model_dir == heur_dir {
            agreement += 1;
        }
    }

    agreement as f32 / games as f32
}
}

At this stage, the network should agree with the heuristic 80–90% of the time. Lower agreement means the network hasn’t learned the heuristic’s patterns yet. Higher means it’s close to the teacher’s performance.

Next: we stop using the heuristic as a teacher and let the snake learn on its own.


Part 6 — Reinforcement Learning Basics

Imitation learning got us a network that copies the heuristic. That’s useful, but the heuristic is a ceiling. The network can never discover strategies the heuristic doesn’t know about.

Reinforcement learning breaks that ceiling. Instead of learning to copy a teacher, the snake learns from its own experience. It plays, gets rewarded for good outcomes and punished for bad ones, and gradually figures out better strategies on its own.

The RL framework

In reinforcement learning, an agent (our snake) interacts with an environment (the game). At each step:

  1. The agent observes the state (the board encoding)
  2. The agent takes an action (a direction)
  3. The environment emits a reward (a number)
  4. The environment transitions to the next state

The agent’s goal is to maximize the total reward it collects over an episode (one game).

For Battlesnake, the reward signal looks like this:

Eat food: +1.0. This is the strongest positive signal, because food is what keeps the snake alive. Getting +1 for eating tells the network “moving to this square is worth doing.”

Survive a turn: +0.01. A small positive reward for staying alive. Without it, a turn that involves neither food nor death is worth 0, and the only pressure to survive comes from pushing the death penalty further into the future. The survival reward adds a direct, per-turn incentive to stay alive long enough to find food.

Die: -1.0. A strong negative signal. Why -1 and not -10? Because the scale matters. If food gives +1 and dying gives -10, the network tends to focus almost entirely on not dying and avoids every risk. If dying gives -1, the network is willing to take calculated risks when the food reward is high enough. The ratio between reward values shapes how risk-averse the snake is.

Nothing happens: 0. The step itself adds nothing, though it still shares credit or blame for later rewards through the return defined below.

The policy

In reinforcement learning, the function that maps states to action probabilities is called a policy. We’ve been building one since Part 3 — our network’s forward pass → softmax is already a policy. We didn’t call it that.

The return for a step is the discounted sum of rewards from that step to the end of the episode:

G_t = r_t + γ * r_{t+1} + γ² * r_{t+2} + ... + γⁿ * r_{t+n}

γ (gamma) is the discount factor, usually 0.99. Here is the intuition: a reward you get now is more valuable than a reward you might get later, because the snake might die before reaching it. We multiply future rewards by γ each step to account for that uncertainty.

Each step further into the future gets multiplied by γ one more time:

  • After 10 steps: the reward is worth 0.99¹⁰ ≈ 0.9 times its face value
  • After 100 steps: 0.99¹⁰⁰ ≈ 0.37 times its face value

Rewards 100 turns from now matter about a third as much as rewards right now. This gives the snake a reason to prefer food that’s nearby over food that’s far — even when both are reachable.
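Written backward, the return satisfies G_t = r_t + γ·G_{t+1}, with G = 0 after the last step. That means one reverse pass produces every return in an episode. The sketch below works over a plain slice of rewards; the function name and reward sequence are illustrative. A snake that survives two turns and then dies gets a negative return at every step, so every move in that episode is pushed toward being less likely.

```rust
/// Discounted returns via the backward recursion G_t = r_t + γ·G_{t+1}.
fn discounted_returns(rewards: &[f32], gamma: f32) -> Vec<f32> {
    let mut returns = vec![0.0; rewards.len()];
    let mut g = 0.0; // return after the final step
    for t in (0..rewards.len()).rev() {
        g = rewards[t] + gamma * g;
        returns[t] = g;
    }
    returns
}

fn main() {
    // Survive two turns (+0.01 each), then die (-1.0)
    let g = discounted_returns(&[0.01, 0.01, -1.0], 0.99);
    println!("{:?}", g); // ≈ [-0.9602, -0.98, -1.0]
    // The discount factors quoted above
    println!("{:.3} {:.3}", 0.99f32.powi(10), 0.99f32.powi(100)); // 0.904 0.366
}
```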

Policy gradient: REINFORCE

How do we adjust the weights to make good actions more likely?

The REINFORCE algorithm says: for each action the agent took, nudge the weights in the direction that would make that action more likely, proportional to how good the outcome was.

Before we look at the equation, here is what each symbol means — in terms you’ve already seen in this tutorial:

| Symbol | Meaning | Where you've seen it |
|---|---|---|
| θ (theta) | The network's weights | Same as the weights we've been updating since Part 5 |
| α (alpha) | Learning rate | Same as the lr = 1e-3 from Part 5 |
| π(a_t \| s_t; θ) | The probability the network assigns to action a_t in state s_t, given weights θ | The softmax output from Part 3 |
| ∇_θ | “The gradient with respect to the weights”: the direction to nudge the weights | Same backpropagation as in Part 5 |
| G_t | The return from step t onward: how good the episode turned out | The discounted sum defined above |

Now the update rule:

θ ← θ + α * ∇_θ log π(a_t | s_t; θ) * G_t

Intuition: If G_t is positive (the episode went well), we want to make the action we took more likely. The gradient ∇_θ log π tells us which direction to nudge the weights to increase the probability of that action. Multiplying by G_t scales the nudge — a big positive return means a big nudge; a negative return flips the direction and makes that action less likely next time.

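The log isn't only a convenience; it falls out of the derivation. The gradient of the expected return works out to an average of ∇_θ log π(a_t | s_t; θ) * G_t over the actions actually taken, because ∇π / π = ∇ log π (the log-derivative trick). The division by π matters. Actions the policy already favors show up in more samples, and without that division they would collect extra credit just for being common. Working in log space is also more numerically stable, which is why the code below uses log_softmax instead of taking the log of a softmax.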
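For a softmax policy, the gradient in the update rule has a simple closed form with respect to the logits. ∂ log π(a) / ∂z_i is 1 − π(i) for the chosen action (i = a) and −π(i) for every other action. Backpropagation carries that on into the weights. The sketch below is plain f64 Rust with illustrative function names, and checks the formula against finite differences. It shows why a positive G_t raises the chosen move's logit and lowers the others.

```rust
/// Softmax over raw logits (max-subtracted for stability).
fn softmax(logits: &[f64]) -> Vec<f64> {
    let max = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = logits.iter().map(|l| (l - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Analytic gradient of log π(action) with respect to the logits: onehot − π.
fn grad_log_pi(logits: &[f64], action: usize) -> Vec<f64> {
    softmax(logits)
        .iter()
        .enumerate()
        .map(|(i, p)| if i == action { 1.0 - p } else { -p })
        .collect()
}

fn main() {
    let logits = [0.5, -1.0, 2.0, 0.0]; // up, down, left, right
    let action = 0; // the move we took: up
    let analytic = grad_log_pi(&logits, action);

    // Check each component against a central finite difference of log π(action)
    let eps = 1e-6;
    for i in 0..logits.len() {
        let (mut plus, mut minus) = (logits, logits);
        plus[i] += eps;
        minus[i] -= eps;
        let numeric = (softmax(&plus)[action].ln() - softmax(&minus)[action].ln()) / (2.0 * eps);
        assert!((numeric - analytic[i]).abs() < 1e-6);
    }
    // Positive for the chosen logit, negative for the rest:
    // G_t > 0 pushes "up" higher, G_t < 0 pushes it lower.
    println!("{:?}", analytic);
}
```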

In Candle, we compute this by taking the log probabilities from the forward pass and scaling the gradient by the return:

#![allow(unused)]
fn main() {
use candle_nn::{AdamW, Optimizer, VarMap, VarBuilder};
use candle_core::{Device, Result, Tensor};

/// REINFORCE training step
pub fn reinforce_step(
    net: &SnakeNet,
    opt: &mut AdamW,
    states: &Tensor,       // (batch, 968)
    actions: &[u32],       // actions taken
    returns: &[f32],       // G_t for each step
    device: &Device,
) -> Result<f32> {
    // Forward pass → log probabilities
    let logits = net.forward(states)?;
    let log_probs = candle_nn::ops::log_softmax(&logits, 1)?;

    // Extract log prob for each action taken
    // log_probs shape: (batch, 4), actions: (batch,)
    let actions_t = Tensor::from_vec(actions.to_vec(), (actions.len(),), device)?;
    let returns_t = Tensor::from_vec(returns.to_vec(), (returns.len(),), device)?;

    // Gather the log prob for each specific action
    let chosen_log_probs = log_probs.gather(&actions_t.unsqueeze(1)?, 1)?
        .squeeze(1)?;

    // Loss = -(log_prob * return) — we minimize, so negate the product
    // to maximize log π(a) * G. Without the negation, the optimizer would
    // *decrease* the probability of good actions, which is backwards.
    let loss = ((chosen_log_probs * returns_t)? * -1.0)?.mean(0)?;

    opt.backward_step(&loss)?;

    Ok(loss.to_scalar::<f32>()?)
}
}

The negative sign is important. We want to maximize log π(a) * G. Since the optimizer minimizes loss, we minimize -log π(a) * G, which is equivalent to maximizing log π(a) * G.

Variance reduction with a baseline. Pure REINFORCE gradients can have high variance: a few lucky or unlucky episodes can dominate the update. A simple fix is to subtract a baseline (often a moving average of recent episode returns) from the return before taking the gradient. This doesn't change the expected gradient. Under the policy, the expected value of ∇_θ log π(a | s; θ) is zero, so any baseline that doesn't depend on the chosen action contributes nothing on average. It does reduce variance, often substantially. Implementation: keep a running mean V of past returns and use G_t - V instead of G_t directly. This isn't required for a working tutorial, but it usually speeds up convergence on Battlesnake boards.
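Below is a minimal sketch of such a baseline, as a hypothetical helper that is not part of the tutorial's code. Two choices here are ours. First, V is an exponential moving average, and the decay factor sets how quickly old episodes are forgotten. Second, V tracks the mean of each episode's per-step returns G_t, so it is on the same scale as the values it is subtracted from.

```rust
/// Hypothetical REINFORCE baseline: an exponential moving average V of past returns.
pub struct Baseline {
    value: f32, // current estimate V
    decay: f32, // share of the old estimate kept after each episode, e.g. 0.9
}

impl Baseline {
    pub fn new(decay: f32) -> Self {
        Self { value: 0.0, decay }
    }

    /// Returns G_t - V for every step of one episode, then folds the
    /// episode's mean return into V. V is read before the update, so the
    /// baseline never depends on the actions it is applied to.
    pub fn advantages(&mut self, returns: &[f32]) -> Vec<f32> {
        let v = self.value;
        let adv: Vec<f32> = returns.iter().map(|g| g - v).collect();
        if !returns.is_empty() {
            let mean = returns.iter().sum::<f32>() / returns.len() as f32;
            self.value = self.decay * self.value + (1.0 - self.decay) * mean;
        }
        adv
    }
}

fn main() {
    let mut baseline = Baseline::new(0.9);
    // First episode: V = 0, so the advantages are the raw returns.
    println!("{:?}", baseline.advantages(&[1.0, 0.5]));
    // V is now 0.9 * 0 + 0.1 * 0.75 = 0.075, and gets subtracted.
    println!("{:?}", baseline.advantages(&[1.0, 0.5]));
}
```

In train_episode, you would pass baseline.advantages(&returns) to reinforce_step in place of the raw returns.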

The RL training loop

Here’s how it fits together — one episode (one game) at a time:

#![allow(unused)]
fn main() {
use std::collections::HashMap;

/// One step during a game
#[derive(Clone)]
struct Experience {
    state: Vec<f32>,       // flattened board encoding
    next_state: Vec<f32>,  // encoding after the action (used by the replay buffer in Part 9)
    action: u32,           // direction index (0–3)
    reward: f32,           // reward received
}

pub fn train_episode(
    net: &SnakeNet,
    opt: &mut AdamW,   // passed in so Adam momentum persists across episodes
    env: &mut GameEnv, // &mut: reset() and step() change the game state
) -> Result<f32> {
    let device = Device::Cpu;
    let gamma = 0.99_f32;

    let mut experiences: Vec<Experience> = Vec::new();
    let mut state = env.reset();
    let mut total_reward = 0.0_f32;

    loop {
        // Encode the board state
        let tensor = encode_board(&state.board, &state.my_snake)?;
        let flat = tensor.flatten_all()?.unsqueeze(0)?;

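        // Pick an action (ε-greedy: explore sometimes). Strictly, REINFORCE
        // assumes actions are drawn from the softmax policy itself; if
        // pick_direction returns the highest-probability move, this ε-greedy
        // mix only approximates that.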
        let action: u32 = if rand::random::<f32>() < 0.1 {
            rand::random::<u32>() % 4  // explore: random
        } else {
            net.pick_direction(&flat)?.index() as u32  // Direction → index (Up=0, Down=1, Left=2, Right=3)
        };

        // Take the action, get reward and next state
        let (next_state, reward) = env.step(action);
        total_reward += reward;

        let state_vec = flat
            .squeeze(0)?  // (1, 968) → (968,)
            .to_vec1::<f32>()?;

        experiences.push(Experience {
            state: state_vec,
            next_state: {
                // Left empty: basic REINFORCE doesn't use next_state. The field
                // exists so the same Experience struct can feed the replay
                // buffer in Part 9. Fill it in with the encoding of next_state
                // (Part 7's encode_board_and_flat helper) when you need it.
                vec![]
            },
            action,
            reward,
        });

        if env.is_done() {
            break;
        }
        state = next_state;
    }

    // Compute returns (G_t) using discounted rewards
    let returns = compute_returns(&experiences, gamma);

    // REINFORCE update
    let states_t = Tensor::from_vec(
        experiences.iter().flat_map(|e| e.state.iter().copied()).collect(),
        (experiences.len(), 968),
        &device,
    )?;
    let actions: Vec<u32> = experiences.iter().map(|e| e.action).collect();
    let returns_v: Vec<f32> = returns;

    // REINFORCE update — the optimizer is passed in from the training loop
    // so Adam's momentum buffers persist across episodes.
    let loss = reinforce_step(net, opt, &states_t, &actions, &returns_v, &device)?;

    Ok(loss)
}

fn compute_returns(experiences: &[Experience], gamma: f32) -> Vec<f32> {
    let mut returns = Vec::with_capacity(experiences.len());
    let mut g = 0.0_f32;

    for exp in experiences.iter().rev() {
        g = exp.reward + gamma * g;
        returns.push(g);
    }

    returns.reverse();
    returns
}
}

The GameEnv is a simplified game simulator. It holds the board state, advances the game when step() is called, and returns the next state along with the reward. We don't need the full Battlesnake server for training; we can simulate games locally.
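Here is a hypothetical single-snake MiniEnv that shows roughly what a simulator like GameEnv involves. It rests on simplifying assumptions:
  • one snake, with no opponents or hazards
  • a tiny built-in random number generator for food placement
  • immediate growth on eating
  • a step() that returns only the reward (the tutorial's step() also returns the next state)
The rewards follow the scheme above.

```rust
use std::collections::VecDeque;

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Pos { pub x: i32, pub y: i32 }

/// Hypothetical single-snake simulator (not the tutorial's GameEnv).
/// Rewards: +1 eat, +0.01 survive, -1 die.
pub struct MiniEnv {
    pub width: i32,
    pub height: i32,
    pub body: VecDeque<Pos>, // body[0] is the head
    pub food: Pos,
    pub health: i32,
    pub done: bool,
    rng: u64, // tiny LCG state, so the sketch needs no crates
}

impl MiniEnv {
    pub fn new(seed: u64) -> Self {
        let mut env = Self {
            width: 11,
            height: 11,
            // Battlesnake starts snakes with all three segments stacked
            body: VecDeque::from(vec![Pos { x: 5, y: 5 }; 3]),
            food: Pos { x: 0, y: 0 },
            health: 100,
            done: false,
            rng: seed,
        };
        env.food = env.random_free_cell();
        env
    }

    fn next_rand(&mut self) -> u64 {
        // 64-bit LCG step (Knuth's MMIX constants); good enough for placing food
        self.rng = self.rng.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
        self.rng >> 33
    }

    fn random_free_cell(&mut self) -> Pos {
        loop {
            let x = (self.next_rand() % self.width as u64) as i32;
            let y = (self.next_rand() % self.height as u64) as i32;
            let p = Pos { x, y };
            if !self.body.contains(&p) {
                return p;
            }
        }
    }

    /// action: 0=up, 1=down, 2=left, 3=right. Returns this step's reward.
    pub fn step(&mut self, action: u32) -> f32 {
        assert!(!self.done, "episode already finished");
        let (dx, dy) = match action { 0 => (0, 1), 1 => (0, -1), 2 => (-1, 0), _ => (1, 0) };
        let head = self.body[0];
        let next = Pos { x: head.x + dx, y: head.y + dy };
        let ate = next == self.food;
        self.health -= 1;

        // Simplification: the tail vacates unless we eat this turn, in which
        // case it stays and the snake is one segment longer right away.
        if !ate {
            self.body.pop_back();
        }
        let off_board = next.x < 0 || next.x >= self.width || next.y < 0 || next.y >= self.height;
        if off_board || self.body.contains(&next) || (self.health <= 0 && !ate) {
            self.done = true;
            return -1.0;
        }

        self.body.push_front(next);
        if ate {
            self.health = 100;
            self.food = self.random_free_cell();
            return 1.0;
        }
        0.01
    }
}

fn main() {
    // Roll out a snake that always moves up until the episode ends
    let mut env = MiniEnv::new(42);
    let mut total = 0.0;
    while !env.done {
        total += env.step(0);
    }
    println!("episode reward: {:.2}, length: {}", total, env.body.len());
}
```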

What the network learns

With REINFORCE, the network starts to discover patterns that go beyond the heuristic. If training is working well, you might observe the network doing the following:

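  • It might learn patience: circling in open space when food is far or contested, rather than racing toward it and running into trouble (a snake must move every turn, and health drops every turn regardless)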
  • It might learn to trap opponents — positioning itself to force the opponent into its own body
  • It might learn risk assessment — accepting a small chance of death for a high-reward path

None of these strategies are explicitly programmed. They emerge from the reward signal. If you don’t see any of these behaviors after reasonable training time, the reward signal or training time may need adjustment.

The exploration-exploitation tradeoff

One subtlety: if the network always picks the highest-probability action, it never tries anything new. ε-greedy exploration fixes this: with probability ε, pick a random action. Typical values are 0.1 (10% random) early in training, decreasing to 0.01 (1%) later.
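Here is a minimal std-only sketch of the ε-greedy rule, applied to the network's four output scores. The small xorshift generator is an assumption that stands in for the `rand` crate, so the example has no dependencies. Part 8's loop uses `rand::random::<f32>()` for the same coin flip.

```rust
/// Minimal xorshift64 PRNG (Marsaglia's 13/7/17 variant); the seed must be non-zero.
struct XorShift(u64);

impl XorShift {
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }

    /// Uniform float in [0, 1), built from the top 24 bits.
    fn next_f32(&mut self) -> f32 {
        (self.next_u64() >> 40) as f32 / (1u64 << 24) as f32
    }
}

/// Index of the largest score (ties go to the lowest index).
fn argmax(scores: &[f32]) -> u32 {
    let mut best = 0;
    for (i, &s) in scores.iter().enumerate() {
        if s > scores[best] {
            best = i;
        }
    }
    best as u32
}

/// With probability `epsilon`, pick a uniformly random action; otherwise
/// pick the greedy (highest-scoring) one. `scores` must be non-empty.
fn epsilon_greedy(scores: &[f32], epsilon: f32, rng: &mut XorShift) -> u32 {
    if rng.next_f32() < epsilon {
        (rng.next_u64() % scores.len() as u64) as u32
    } else {
        argmax(scores)
    }
}
```

With `epsilon = 0.0` this always returns the argmax. With `epsilon = 1.0` every move is random.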

Part 7 combines REINFORCE with self-play, where the snake gets stronger opponents to train against as it improves.


Part 7 — Self-Play: The Game Simulator

REINFORCE works. A single snake can learn from its own experience — eat food, get a reward, make that move more likely next time. But there’s a problem: the snake is playing against an environment that doesn’t fight back.

Real BattleSnake has opponents. Opponents that block your path, steal your food, and chase you into corners. Training against an empty board teaches you to eat. It doesn’t teach you to compete.

Self-play fixes this. The snake trains against itself — or more precisely, against past versions of itself. Every time it gets stronger, its opponent gets stronger too. The ceiling keeps rising.

Why self-play works

Here’s the intuition. Imagine you’re learning chess. Playing against a wall (someone who never moves) teaches you the rules. Playing against a weak opponent teaches you basic tactics. Playing against someone slightly better than you teaches you to see your own mistakes.

Self-play is that last one, automated. The key insight: a snake is always training against the current best version of itself. When it discovers a new strategy, the opponent (an older copy of the same network) doesn’t know that strategy yet. The new strategy wins. The network updates. Now the next training run faces a network that does know that strategy — so it has to discover a counter. And so on.

The same idea drives AlphaGo and AlphaStar, which both refined their play through self-play after first learning from human games. It’s how every self-play system works: you are your own curriculum. The difficulty adjusts automatically.

The training loop

The self-play loop has four steps, repeated many times:

  1. Sample an opponent. Pick a past checkpoint — the current network, or a saved snapshot from an earlier training run.
  2. Play episodes. Run games where the training snake faces the opponent snake. Record the experience: states, actions, rewards.
  3. Update the training snake. Run REINFORCE on the collected experience. The training snake’s weights change; the opponent’s weights stay frozen.
  4. Evaluate and checkpoint. Periodically pit the current snake against the previous best. If it wins more than 55% of the games, save a checkpoint. That checkpoint becomes a new candidate opponent.

The loop looks like this:

┌───────────────────────────────────────────┐
│  1. Sample opponent from checkpoint pool  │
│                   │                       │
│                   ▼                       │
│  2. Play N episodes (training vs opponent)│
│                   │                       │
│                   ▼                       │
│  3. REINFORCE update (training snake only)│
│                   │                       │
│                   ▼                       │
│  4. Evaluate: current vs best checkpoint  │
│     If win rate > 55%: save checkpoint    │
│     Go to 1                               │
└───────────────────────────────────────────┘

The 55% win rate threshold is deliberate. We don’t require a supermajority — we want to save checkpoints that are slightly better, because slightly-better opponents create the steady pressure that drives improvement. If we waited for 80% wins, the network would stagnate between checkpoints.
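The threshold is easier to reason about as a concrete number of games. The sketch below computes the fewest wins that strictly beat a given win-rate threshold, using the same `>` comparison as the training loop. The helper name is hypothetical.

```rust
/// Smallest number of wins out of `games` whose win rate is strictly above `threshold`.
fn min_wins_to_promote(games: usize, threshold: f64) -> usize {
    // Start at floor(games * threshold) and step up until the strict
    // comparison holds; stepping avoids floating-point edge cases.
    let mut wins = (games as f64 * threshold).floor() as usize;
    while wins <= games && (wins as f64 / games as f64) <= threshold {
        wins += 1;
    }
    wins
}
```

Part 8 passes 50 to `evaluate_against`, presumably the number of evaluation games. Over 50 games, 55% means at least 28 wins, while an 80% bar would demand 41.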

The game simulator

We need a game engine that runs locally — no HTTP, no web server, a function that advances the board state. This is the GameEnv that Part 6 referenced.

The simulator doesn’t need every BattleSnake rule. For training, we need: snakes move, food spawns, snakes die if they hit walls or bodies, health goes down, eating food restores health. That’s the core loop.

The key idea: GameEnv::step() advances the board by one turn, returns each player’s reward, and reports whether the game is over. step_with_action() does the same but lets the training loop supply one player’s move. reset() starts a fresh game. Here’s the implementation:

#![allow(unused)]
fn main() {
use snake_ml::{Board, Point, Snake};

/// A player in the simulation — either a neural network or a heuristic
pub trait Player {
    fn decide(&self, board: &Board, my_snake: &Snake) -> u32;
}

/// The game simulator
pub struct GameEnv {
    board: Board,
    players: Vec<Box<dyn Player>>,
    done: bool,
    turn: u32,
    max_turns: u32,
}

#[derive(Debug, Clone)]
pub struct StepResult {
    pub board: Board,
    pub rewards: Vec<f32>,    // one reward per player
    pub done: bool,
    pub winner: Option<usize>, // Some(index) or None for draw
}

impl GameEnv {
    pub fn new(board: Board, players: Vec<Box<dyn Player>>) -> Self {
        Self {
            board,
            players,
            done: false,
            turn: 0,
            max_turns: 200,
        }
    }

    /// Reset the environment with a fresh board and return the initial state
    pub fn reset(&mut self) -> Board {
        self.board = random_board(11, 11, self.players.len());
        self.done = false;
        self.turn = 0;
        self.board.clone()
    }

    /// Advance the game one turn: each player decides, then we resolve
    pub fn step(&mut self) -> StepResult {
        // 1. Each player picks a direction
        let directions: Vec<u32> = self.players.iter()
            .enumerate()
            .map(|(i, p)| p.decide(&self.board, &self.board.snakes[i]))
            .collect();

        self.resolve_step(directions)
    }

    /// Advance the game one turn, where player `player_idx` takes `action`
    /// and all other players decide via their own `Player::decide()`. This
    /// variant is needed for REINFORCE training: the training loop must
    /// record the *actual* action the training player took (including
    /// ε-greedy exploration), not re-derive it after the fact.
    pub fn step_with_action(&mut self, player_idx: usize, action: u32) -> StepResult {
        let mut directions = Vec::with_capacity(self.players.len());
        for (i, player) in self.players.iter().enumerate() {
            if i == player_idx {
                directions.push(action);
            } else {
                directions.push(player.decide(&self.board, &self.board.snakes[i]));
            }
        }
        self.resolve_step(directions)
    }

    /// Internal: apply the movement, collision, and reward logic for a set of
    /// chosen directions. Called by both `step()` (all players decide) and
    /// `step_with_action()` (one player's action is specified externally).
    fn resolve_step(&mut self, directions: Vec<u32>) -> StepResult {
        // 2. Move snakes
        for (i, &dir) in directions.iter().enumerate() {
            let (dx, dy) = match dir {
                0 => (0, 1),   // up
                1 => (0, -1),  // down
                2 => (-1, 0),  // left
                3 => (1, 0),   // right
                _ => (0, 0),
            };

            let snake = &mut self.board.snakes[i];
            let new_head = Point {
                x: snake.head.x + dx,
                y: snake.head.y + dy,
            };

            // Insert new head at the front of the body
            snake.body.insert(0, new_head.clone());
            snake.head = new_head;

            // If the head isn't on food, remove the tail (snake doesn't grow).
            // `new_head` has been moved into `snake.head`, so compare against that.
            let on_food = self.board.food.iter()
                .any(|f| f.x == snake.head.x && f.y == snake.head.y);

            if !on_food {
                snake.body.pop();
            }

            // Decrement health every turn. If the snake ate food this
            // turn, reset to 100 *after* the decrement. Order matters:
            // if we set health=100 then decrement, the reward check
            // (health == 100) below never fires — health would be 99.
            snake.health = snake.health.saturating_sub(1);
            if on_food {
                snake.health = 100;
            }
        }

        // 3. Check for deaths
        let mut dead = vec![false; self.players.len()];

        for (i, snake) in self.board.snakes.iter().enumerate() {
            // Wall collision
            if snake.head.x < 0 || snake.head.x >= self.board.width as i32
                || snake.head.y < 0 || snake.head.y >= self.board.height as i32
            {
                dead[i] = true;
                continue;
            }

            // Self collision (head hit own body, starting from index 1)
            for (j, seg) in snake.body.iter().enumerate() {
                if j > 0 && seg.x == snake.head.x && seg.y == snake.head.y {
                    dead[i] = true;
                    break;
                }
            }

            // Body collision with other snakes
            if dead[i] { continue; }
            for (j, other) in self.board.snakes.iter().enumerate() {
                if i == j { continue; }
                // Did I hit the other snake's body?
                for seg in &other.body {
                    if seg.x == snake.head.x && seg.y == snake.head.y {
                        // Check for head-to-head: if both heads just moved here
                        if snake.head.x == other.head.x && snake.head.y == other.head.y {
                            // Shorter snake dies; if same length, both die
                            if snake.body.len() <= other.body.len() {
                                dead[i] = true;
                            }
                        } else {
                            dead[i] = true;
                        }
                        break;
                    }
                }
            }
        }

        // 4. Starvation check
        for (i, snake) in self.board.snakes.iter().enumerate() {
            if snake.health == 0 {
                dead[i] = true;
            }
        }

        // 5. Compute rewards
        let mut rewards = vec![0.0_f32; self.players.len()];
        let mut winner = None;
        let alive_count = dead.iter().filter(|&&d| !d).count();

        for (i, is_dead) in dead.iter().enumerate() {
            if *is_dead {
                rewards[i] = -1.0;
            } else {
                // Survived this turn
                rewards[i] = 0.01;

                // Ate food (health was reset to 100)
                if self.board.snakes[i].health == 100 {
                    rewards[i] += 1.0;
                }
            }
        }

        // If only one snake is alive, it wins
        if alive_count == 1 {
            winner = dead.iter().position(|&d| !d);
            if let Some(w) = winner {
                rewards[w] += 5.0; // bonus for winning
            }
            self.done = true;
        } else if alive_count == 0 {
            self.done = true;
        }

        self.turn += 1;
        if self.turn >= self.max_turns {
            self.done = true;
        }

        // 6. Remove eaten food
        self.board.food.retain(|f| {
            !self.board.snakes.iter().any(|s| s.head.x == f.x && s.head.y == f.y)
        });

        // 7. Spawn new food (simplified: a single piece, only once none is left)
        if self.board.food.is_empty() {
            if let Some(pos) = random_empty_cell(&self.board) {
                self.board.food.push(pos);
            }
        }

        StepResult {
            board: self.board.clone(),
            rewards,
            done: self.done,
            winner,
        }
    }

    pub fn is_done(&self) -> bool {
        self.done
    }
}
fn random_board(width: u32, height: u32, num_snakes: usize) -> Board {
    // Place snakes in opposite corners, place some food in the middle
    let mut snakes = Vec::with_capacity(num_snakes);
    let start_positions = [
        (1, 1),
        (width as i32 - 2, height as i32 - 2),
    ];

    for i in 0..num_snakes.min(start_positions.len()) {
        let (sx, sy) = start_positions[i];
        // Stack all three segments on the start cell (as in the official rules),
        // so every segment starts inside the board.
        let body: Vec<Point> = vec![Point { x: sx, y: sy }; 3];
        snakes.push(Snake {
            id: format!("snake-{i}"),
            body: body.clone(),
            head: body[0].clone(),
            health: 100,
        });
    }

    let mid_x = (width / 2) as i32;
    let mid_y = (height / 2) as i32;

    Board {
        width,
        height,
        food: vec![
            Point { x: mid_x, y: mid_y },
            Point { x: mid_x - 2, y: mid_y },
            Point { x: mid_x + 2, y: mid_y },
        ],
        snakes,
        hazards: vec![],
    }
}

fn random_empty_cell(board: &Board) -> Option<Point> {
    // Simplified: try a few random positions, return the first empty one
    use rand::Rng;
    let mut rng = rand::thread_rng();
    for _ in 0..20 {
        let x = rng.gen_range(0..board.width as i32);
        let y = rng.gen_range(0..board.height as i32);
        let occupied = board.snakes.iter()
            .any(|s| s.body.iter().any(|b| b.x == x && b.y == y));
        if !occupied {
            return Some(Point { x, y });
        }
    }
    None
}
}

This simulator handles the core mechanics. It’s not a perfect BattleSnake engine — for instance, it doesn’t handle simultaneous body-collision resolution the way the real engine does, and food spawning is simplified. But it’s good enough to generate training signal. The reward structure matches Part 6: eat food (+1), survive (+0.01), die (-1), win (+5).

The network as a player

The GameEnv simulator calls player.decide() every turn to get the next direction. It doesn’t care whether it’s talking to a neural network or a heuristic; it only needs a direction. That’s why Player is a trait: it abstracts over both, so the same GameEnv can run games against the A* heuristic, against a random network, or against a trained one.

This is where the trained weights come in — a player with trained weights plays differently than one with random weights.

What we’re building: a struct that holds a SnakeNet and epsilon. When decide() is called, it encodes the board state, runs it through the network, and returns the direction. Here’s the full implementation:

#![allow(unused)]
fn main() {
use candle_core::{Device, DType, Tensor};
use candle_nn::{VarBuilder, VarMap};
use rand::Rng; // provides gen_range for the ε-greedy branch
use snake_ml::{SnakeNet, encode_board, Board, Snake};

/// A neural network player. Uses the trained (or random) weights
/// to pick a direction each turn.
pub struct NetworkPlayer {
    net: SnakeNet,
    device: Device,
    epsilon: f32, // exploration rate for ε-greedy
}

impl NetworkPlayer {
    pub fn from_var_map(var_map: &VarMap, epsilon: f32) -> Self {
        let device = Device::Cpu;
        let vs = VarBuilder::from_varmap(var_map, DType::F32, &device);
        let net = SnakeNet::new(vs).expect("failed to build network");
        Self { net, device, epsilon }
    }

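    pub fn from_checkpoint(path: &str, epsilon: f32) -> Self {
        let device = Device::Cpu;
        let mut var_map = VarMap::new();
        // Build the network first so its variables are registered in the map:
        // VarMap::load only overwrites variables that already exist, so loading
        // into an empty map would leave the weights random.
        let vs = VarBuilder::from_varmap(&var_map, DType::F32, &device);
        let net = SnakeNet::new(vs).expect("failed to build network");
        var_map.load(path).expect("failed to load checkpoint");
        Self { net, device, epsilon }
    }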
}

impl Player for NetworkPlayer {
    fn decide(&self, board: &Board, my_snake: &Snake) -> u32 {
        // ε-greedy: explore sometimes
        if rand::random::<f32>() < self.epsilon {
            return rand::thread_rng().gen_range(0..4) as u32;
        }

        // Encode the board and pick the best direction
        let tensor = encode_board(board, my_snake)
            .expect("encoding failed");
        let flat = tensor.flatten_all()
            .expect("flatten failed")
            .unsqueeze(0)
            .expect("batch dim failed");

        let logits = self.net.forward(&flat)
            .expect("forward pass failed");

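        // argmax(1) on the (1, 4) logits has shape (1), so squeeze the batch
        // dimension before reading it out as a scalar.
        logits.argmax(1)
            .expect("argmax failed")
            .squeeze(0)
            .expect("squeeze failed")
            .to_scalar::<u32>()
            .expect("scalar failed")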
    }
}
}

The epsilon parameter controls exploration. During training, the training snake uses a higher epsilon (0.1 — try random moves 10% of the time). The opponent uses epsilon = 0 (always pick the best move it knows) so it plays at full strength.
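Because GameEnv only talks to the Player trait, a non-network opponent takes just a few lines. The sketch below is a simple wall-and-body-avoiding player, not the tutorial's A* heuristic. It is written against simplified stand-in types: the real snake_ml Board and Snake carry more fields, and the trait is repeated here only so the example compiles on its own. It keeps the simulator's direction encoding: 0 = up (+y), 1 = down, 2 = left, 3 = right.

```rust
// Simplified stand-ins for the snake_ml types (assumption, not the real definitions).
#[derive(Clone, Debug, PartialEq)]
pub struct Point { pub x: i32, pub y: i32 }

pub struct Snake { pub head: Point, pub body: Vec<Point> }

pub struct Board { pub width: u32, pub height: u32, pub snakes: Vec<Snake> }

pub trait Player {
    fn decide(&self, board: &Board, my_snake: &Snake) -> u32;
}

/// Picks the first direction whose target cell is on the board and not covered
/// by any snake body (tails that are about to move are still treated as blocked).
/// Falls back to "up" when every move is fatal.
pub struct SafeMovePlayer;

impl Player for SafeMovePlayer {
    fn decide(&self, board: &Board, my_snake: &Snake) -> u32 {
        // Same (dx, dy) table as GameEnv::resolve_step
        let moves = [(0u32, 0, 1), (1, 0, -1), (2, -1, 0), (3, 1, 0)];
        for (dir, dx, dy) in moves {
            let (x, y) = (my_snake.head.x + dx, my_snake.head.y + dy);
            let in_bounds = x >= 0 && y >= 0
                && x < board.width as i32 && y < board.height as i32;
            let occupied = board.snakes.iter()
                .any(|s| s.body.iter().any(|p| p.x == x && p.y == y));
            if in_bounds && !occupied {
                return dir;
            }
        }
        0
    }
}
```

Because it implements the same trait, `Box::new(SafeMovePlayer)` can sit in GameEnv's player list next to a NetworkPlayer. That gives the training snake a fixed, predictable opponent for sanity checks.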

What comes next

We have the simulator and a way to plug the network in as a player. Part 8 puts the pieces together: the opponent pool, the full self-play training loop, and what it looks like when the training actually works.


Part 8 — Self-Play: The Training Loop

Part 7 gave us the game simulator and the NetworkPlayer that plugs into it. Now we wire it all together: the opponent pool, the full self-play loop, and what training actually looks like.

The opponent pool

Here’s the self-play twist: we don’t always train against the current network. If we did, the opponent would be exactly as strong as the training snake, which means every game is a coin flip. The reward signal gets noisy.

Instead, we maintain a pool of checkpoints: saved snapshots of the network weights from different points in training. Each training episode, we sample an opponent from this pool. Sometimes the opponent is the most recent checkpoint, close to the current network’s strength. Sometimes it’s a weaker checkpoint from 100 training iterations ago. The variety stabilizes training.

#![allow(unused)]
fn main() {
use std::path::{Path, PathBuf};

/// A pool of saved model checkpoints for opponent sampling
pub struct OpponentPool {
    checkpoints: Vec<PathBuf>,
}

impl OpponentPool {
    pub fn new() -> Self {
        Self { checkpoints: Vec::new() }
    }

    /// Add a checkpoint to the pool
    pub fn add(&mut self, path: PathBuf) {
        self.checkpoints.push(path);
    }

    /// Sample a random opponent from the pool.
    /// If the pool is empty, returns None (use a random player instead).
    pub fn sample(&self) -> Option<PathBuf> {
        if self.checkpoints.is_empty() {
            return None;
        }
        use rand::Rng;
        let idx = rand::thread_rng().gen_range(0..self.checkpoints.len());
        Some(self.checkpoints[idx].clone())
    }

    /// Weighted sampling: prefer recent opponents, but occasionally
    /// pick older ones. This balances challenge (recent = harder)
    /// with diversity (older = different strategies).
    pub fn sample_weighted(&self) -> Option<PathBuf> {
        if self.checkpoints.is_empty() {
            return None;
        }

        // 50% chance: most recent checkpoint (hardest opponent)
        // 50% chance: uniformly random from the full pool
        use rand::Rng;
        let mut rng = rand::thread_rng();
        if rng.gen_range(0.0..1.0) < 0.5 {
            Some(self.checkpoints.last().unwrap().clone())
        } else {
            let idx = rng.gen_range(0..self.checkpoints.len());
            Some(self.checkpoints[idx].clone())
        }
    }

    pub fn len(&self) -> usize {
        self.checkpoints.len()
    }
}
}

The weighted sampling is a practical detail that matters. Training only against the strongest opponent is a curriculum, but it can be too hard — the training snake loses every game, the reward signal is uniformly negative, and nothing is learned. Mixing in weaker opponents gives the training snake some wins, some positive reward, and a gradient that points in a useful direction.
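It helps to see what the 50/50 rule in `sample_weighted` works out to. The newest checkpoint can be chosen by either branch. With a pool of n checkpoints, it is therefore picked with probability 0.5 + 0.5/n, and every older checkpoint with probability 0.5/n. The sketch below computes these probabilities; the helper is hypothetical and not part of OpponentPool.

```rust
/// Probability that `sample_weighted` returns each checkpoint, oldest first,
/// for a pool of `n` checkpoints (n > 0).
fn sampling_probabilities(n: usize) -> Vec<f64> {
    // Uniform branch: taken half the time, spreads 0.5 evenly over the pool.
    let mut probs = vec![0.5 / n as f64; n];
    // "Most recent" branch: the other half always lands on the last entry.
    probs[n - 1] += 0.5;
    probs
}
```

With four checkpoints, the newest one is the opponent 62.5% of the time. As the pool grows, that share approaches 50%, and each individual older checkpoint becomes rarer.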

The full training loop

Everything comes together here. The loop runs for a fixed number of iterations. Each iteration: sample an opponent, play episodes, update the training snake, occasionally evaluate.

#![allow(unused)]
fn main() {
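use std::fs;
use std::path::PathBuf;
use rand::Rng;
use candle_core::{DType, Device, Tensor};
use candle_nn::{AdamW, Optimizer, VarBuilder, VarMap};
use snake_ml::{encode_board, encode_board_and_flat, SnakeNet};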

pub fn train_self_play(
    var_map: &VarMap,
    num_iterations: usize,
    episodes_per_iteration: usize,
    checkpoint_dir: &str,
) -> Result<(), Box<dyn std::error::Error>> {
    let device = Device::Cpu;
    let learning_rate = 1e-3_f64;
    let gamma = 0.99_f32;
    let initial_epsilon = 0.2_f32;
    let min_epsilon = 0.01_f32;

    // Create the optimizer once so Adam's momentum buffers persist
    // across training iterations.
    let mut opt = AdamW::new(var_map.all_vars(), candle_nn::ParamsAdamW {
        lr: learning_rate,
        ..Default::default()
    })?;

    let mut pool = OpponentPool::new();

    // Save the initial random weights as the first checkpoint
    fs::create_dir_all(checkpoint_dir)?;
    let initial_path = format!("{checkpoint_dir}/checkpoint-0.safetensors");
    var_map.save(&initial_path)?;
    pool.add(PathBuf::from(&initial_path));

    for iteration in 1..=num_iterations {
        // Epsilon decays over time: less exploration as training progresses
        let epsilon = (initial_epsilon
            * (1.0 - iteration as f32 / num_iterations as f32))
            .max(min_epsilon);

        let mut all_experiences: Vec<Experience> = Vec::new();
        // Discounted returns, filled in one episode at a time so that
        // returns[i] lines up with all_experiences[i].
        let mut returns: Vec<f32> = Vec::new();
        let mut total_reward = 0.0_f32;
        let mut wins = 0usize;
        let mut losses = 0usize;

        for _ in 0..episodes_per_iteration {
            // 1. Sample an opponent
            let opponent = match pool.sample_weighted() {
                Some(path) => NetworkPlayer::from_checkpoint(&path.to_string_lossy(), 0.0),
                // The pool always holds checkpoint-0 (saved before this loop), so this
                // branch is only a fallback. It builds the opponent from the live
                // training weights, which are shared rather than frozen.
                None => NetworkPlayer::from_var_map(var_map, 0.0),
            };

            // The training snake (with exploration)
            let training_player = NetworkPlayer::from_var_map(var_map, epsilon);

            // 2. Play an episode
            let board = random_board(11, 11, 2);
            let mut env = GameEnv::new(
                board,
                vec![Box::new(training_player), Box::new(opponent)],
            );

            env.reset();
            let mut experiences: Vec<Experience> = Vec::new();

            while !env.is_done() {
                let my_snake = &env.board.snakes[0];
                let state = encode_board(&env.board, my_snake)?
                    .flatten_all()?
                    .to_vec1::<f32>()?;

                // Decide the training player's action *before* calling env.step().
                // We can't re-derive the action after the fact: the training player
                // uses ε-greedy, so the actual action might differ from the
                // greedy argmax. If we called argmax after env.step(),
                // we'd always record the greedy action, not the action the player
                // actually took. That would train the network to reinforce actions
                // it didn't take.
                let state_t = Tensor::from_vec(
                    state.clone(),
                    (1, 968),
                    &device,
                )?;
                let action = if rand::random::<f32>() < epsilon {
                    rand::thread_rng().gen_range(0..4) as u32  // explore
                } else {
                    // Exploit: run a forward pass and pick the best direction.
                    // argmax(1) on the (1, 4) logits has shape (1), so squeeze
                    // the batch dimension before reading the scalar.
                    let vs = VarBuilder::from_varmap(var_map, DType::F32, &device);
                    let net = SnakeNet::new(vs)?;
                    net.forward(&state_t)?.argmax(1)?.squeeze(0)?.to_scalar::<u32>()?
                };

                let result = env.step_with_action(0, action);

                // Encode the next board state for the replay buffer.
                // We need (state, action, reward, next_state) for every step
                // so the replay buffer in a later part can do prioritized sampling.
                // `encode_board_and_flat` is a small wrapper in snake-ml that calls
                // encode_board, flattens, and returns (Vec<f32>, ()). The () is a
                // unit marker for destructuring convenience.
                let (next_state_flat, _) = encode_board_and_flat(&result.board, &result.board.snakes[0])?;

                experiences.push(Experience {
                    state,
                    next_state: next_state_flat,
                    action,
                    reward: result.rewards[0],
                });

                total_reward += result.rewards[0];

                if result.done {
                    match result.winner {
                        Some(0) => wins += 1,
                        Some(_) => losses += 1, // another snake won
                        None => {} // draw
                    }
                }
            }

            // 3. Compute returns for this episode only.
            // We defined compute_returns in Part 6: it iterates backward through
            // the experience list, accumulating discounted rewards. The result is
            // G_t for each step: how good the outcome was from that point onward.
            // Calling it per episode keeps the start of one game from leaking
            // into the returns of the previous game's final moves.
            returns.extend(compute_returns(&experiences, gamma));
            all_experiences.extend(experiences);
        }

        // Normalize returns (reduces variance in REINFORCE)
        let mean = returns.iter().sum::<f32>() / returns.len() as f32;
        let std = (returns.iter()
            .map(|r| (r - mean).powi(2))
            .sum::<f32>() / returns.len() as f32)
            .sqrt()
            .max(1e-8);
        let normalized: Vec<f32> = returns.iter()
            .map(|r| (r - mean) / std)
            .collect();

        // 4. REINFORCE update
        let states_t = Tensor::from_vec(
            all_experiences.iter()
                .flat_map(|e| e.state.iter().copied())
                .collect(),
            (all_experiences.len(), 968),
            &device,
        )?;
        let actions: Vec<u32> = all_experiences.iter().map(|e| e.action).collect();

        // The optimizer is created once outside the iteration loop so
        // Adam's momentum persists across training steps.
        reinforce_step(&{
            let vs = VarBuilder::from_varmap(var_map, DType::F32, &device);
            SnakeNet::new(vs)?
        }, &mut opt, &states_t, &actions, &normalized, &device)?;

        // 5. Evaluate and checkpoint
        if iteration % 10 == 0 {
            let win_rate = wins as f32 / (wins + losses).max(1) as f32;
            println!(
                "Iteration {iteration}: reward={total_reward:.1} \
                 win_rate={win_rate:.2} epsilon={epsilon:.3} pool={}",
                pool.len()
            );

            // Evaluate against the previous best: the most recent checkpoint in
            // the pool. The pool is never empty, since checkpoint-0 was added
            // before the loop.
            let prev = pool.checkpoints.last().unwrap().to_string_lossy().to_string();
            let eval_wr = evaluate_against(var_map, &prev, 50)?;
            println!("  eval vs previous best: {eval_wr:.2} win rate");

            // Promote the current weights into the opponent pool only if they
            // are slightly better than the previous best (> 55% wins).
            if eval_wr > 0.55 {
                let path = format!("{checkpoint_dir}/checkpoint-{iteration}.safetensors");
                var_map.save(&path)?;
                pool.add(PathBuf::from(&path));
                println!("  ✓ New best! Checkpoint saved.");
            }
        }
    }

    // Save the final model
    let final_path = format!("{checkpoint_dir}/model-final.safetensors");
    var_map.save(&final_path)?;
    println!("Training complete. Final model saved to {final_path}");

    Ok(())
}
}

Entropy and beta decay. You may add entropy regularization: beta * H(π) added to the objective (equivalently, subtracted from the loss), where H(π) measures how spread out the policy distribution is. If you do, make sure beta decays over training. A constant high beta keeps the snake perpetually random; it never commits to learned behavior. One workable schedule is beta = max(0.01, 0.5 * exp(-step / 1000)), which starts exploratory and gradually locks in the policy.
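The sketch below covers both pieces on plain f32 slices rather than Candle tensors, as a simplification. It computes the entropy of the softmax over the four move logits, and it implements the beta schedule from the paragraph above. In a training step, the regularized loss would be the policy-gradient loss minus beta times the mean entropy over the batch.

```rust
/// Softmax over raw logits, shifted by the max for numerical stability.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Entropy H(π) = -Σ p ln p, in nats. Zero-probability actions contribute nothing.
fn entropy(logits: &[f32]) -> f32 {
    softmax(logits)
        .iter()
        .filter(|&&p| p > 0.0)
        .map(|&p| -p * p.ln())
        .sum()
}

/// beta = max(0.01, 0.5 * exp(-step / 1000))
fn entropy_beta(step: u64) -> f32 {
    (0.5 * (-(step as f32) / 1000.0).exp()).max(0.01)
}
```

A uniform policy over four moves has the maximum entropy, ln 4 ≈ 1.386. A policy that always picks one move has entropy near zero. The bonus therefore rewards keeping options open, and the decaying beta lets that pressure fade.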

A few things worth calling out in the training loop above:

Return normalization. This connects back to the variance problem from Part 6. Raw REINFORCE returns can vary wildly: one episode might give +10 (the snake ate everything and won), the next -5 (died on turn 3). Without normalization, the gradient points in one direction after the first episode and the opposite direction after the second. Training is noisy — the snake can’t tell which actions were actually good.

Normalizing fixes this: subtract the mean, divide by standard deviation. Now every batch of returns centers on zero with unit variance. The gradient consistently pushes toward better actions, regardless of whether the snake got lucky or unlucky in a given episode. This is the same variance-reduction idea that Part 6 mentioned — normalization makes the signal reliable enough for training to converge.
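The two episodes from the paragraph above make a good worked example. The sketch pulls the normalization out of the training loop into a standalone function, using the same formula: the population standard deviation, clamped to at least 1e-8. Returns of +10 and -5 have mean 2.5 and standard deviation 7.5, so they normalize to exactly +1 and -1.

```rust
/// Subtract the mean and divide by the (population) standard deviation,
/// as in the training loop. The 1e-8 floor avoids dividing by zero
/// when every return in the batch is identical.
fn normalize_returns(returns: &[f32]) -> Vec<f32> {
    let n = returns.len() as f32;
    let mean = returns.iter().sum::<f32>() / n;
    let std = (returns.iter().map(|r| (r - mean).powi(2)).sum::<f32>() / n)
        .sqrt()
        .max(1e-8);
    returns.iter().map(|r| (r - mean) / std).collect()
}
```

The raw returns differ by a factor of two in magnitude. After normalization they have equal weight and opposite sign, so neither episode dominates the update.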

Epsilon decay. Early in training, the snake needs to explore (epsilon = 0.2). Late in training, it mostly exploits what it’s learned (epsilon = 0.01). The linear decay schedule is simple but effective. More sophisticated schedules (exponential, cosine) can help, but linear decay works well enough for BattleSnake.
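For comparison, the sketch below puts the linear schedule from the training loop next to a cosine alternative. The cosine form is only an illustration; this tutorial's code does not use it. Both start at the initial value and end at the floor.

```rust
use std::f32::consts::PI;

/// Linear decay, same formula as train_self_play: initial * (1 - t/T), floored at `min`.
fn linear_epsilon(iteration: usize, total: usize, initial: f32, min: f32) -> f32 {
    (initial * (1.0 - iteration as f32 / total as f32)).max(min)
}

/// Cosine decay from `initial` down to `min` over `total` iterations.
/// It stays above the linear schedule early on and drops below it late.
fn cosine_epsilon(iteration: usize, total: usize, initial: f32, min: f32) -> f32 {
    let progress = (iteration as f32 / total as f32).min(1.0);
    min + (initial - min) * 0.5 * (1.0 + (PI * progress).cos())
}
```

With 1000 iterations starting from 0.2, the linear schedule gives 0.15 at iteration 250. The cosine schedule gives about 0.17 at that point, so it spends a little longer exploring before it settles.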

Checkpoint evaluation. Every 10 iterations, we pit the current network against the previous checkpoint. A 55% win rate means the current network is slightly better — good enough to justify saving. We’re looking for incremental progress, not dominance.

What the network learns through self-play

Imitation learning taught the network to copy a heuristic. Basic REINFORCE taught it to eat food and avoid dying. Self-play teaches it something neither of those could: how to compete.

If training is working well, you might observe the network discovering the following strategies:

  • Space control. Claiming territory near the center of the board gives more room to maneuver. The heuristic knows to go for food, but it doesn’t know that position matters when food isn’t nearby.
  • Opponent tracking. The network learns to check where the enemy head is — not only where its body is. A nearby enemy head means a contested food or a head-to-head collision risk.
  • Patience. Sometimes the best move is to wait. If the opponent is between you and food, burning health to race them is worse than waiting for a better opportunity.
  • Endgame awareness. In the late game, when both snakes are long and the board is mostly body segments, the network learns to value space more than food.

None of these are explicitly in the reward function. They emerge from the pressure of playing against an opponent that’s also learning. The strategies are discovered, not designed. If you don’t see any of these behaviors after a reasonable number of iterations, the opponent pool may need adjustment — too-strong opponents can make learning impossible, while too-weak opponents don’t provide useful pressure.

The variance problem (and how to think about it)

REINFORCE is notoriously high-variance. One episode might produce a total return of +15 (the snake ate everything, won the game). The next might produce -5 (the snake died on turn 3). The gradient flips direction between episodes. Training is noisy.

Return normalization helps. So does averaging over many episodes per iteration. But the fundamental problem remains: in BattleSnake, one bad move can kill you, and the network has to figure out which move out of 20 was the bad one.

There are more sophisticated algorithms that address this — PPO (Proximal Policy Optimization) and A2C (Advantage Actor-Critic) are the standard ones. They use a baseline to reduce variance: instead of asking “was this action good?” they ask “was this action better than expected?” The “expected” part is a value function — a second network that predicts how good a state is.
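The baseline idea reduces to one subtraction per step. The sketch below assumes a hypothetical value estimate V(s_t) for each visited state; in A2C or PPO, that estimate would come from the value network. The advantage A_t = G_t - V(s_t) then replaces G_t as the weight on each action's log-probability.

```rust
/// Advantage A_t = G_t - V(s_t): positive when the outcome beat the critic's
/// expectation for that state, negative when it fell short.
fn advantages(returns: &[f32], values: &[f32]) -> Vec<f32> {
    assert_eq!(returns.len(), values.len(), "one value estimate per step");
    returns.iter().zip(values).map(|(g, v)| g - v).collect()
}

/// Policy-gradient loss: -mean(log π(a_t | s_t) * weight_t).
/// Plain REINFORCE uses G_t as the weight; the baseline version uses A_t.
fn policy_gradient_loss(log_probs: &[f32], weights: &[f32]) -> f32 {
    assert_eq!(log_probs.len(), weights.len(), "one weight per action");
    let total: f32 = log_probs.iter().zip(weights).map(|(lp, w)| lp * w).sum();
    -total / log_probs.len() as f32
}
```

Plain REINFORCE is the special case where every V(s_t) is zero. The return normalization in the training loop is a cruder version of the same idea: the batch mean stands in for a per-state baseline.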

For this tutorial, REINFORCE is the right starting point. It’s simple, it works, and it teaches the core idea (policy gradient) without the overhead of a value function. If you want to go further, swapping REINFORCE for PPO is the natural next step.

A note on training time

Self-play is slow. Each episode runs a full game (50–200 turns). Each turn requires a forward pass through the network for each player. A training run of 1000 iterations with 10 episodes per iteration is 10,000 games. On CPU, that might take hours. On GPU, it’s faster, but Candle’s GPU support varies by platform.

The practical approach: start small. Train for 100 iterations, check the win rate, adjust the hyperparameters, repeat. Don’t run a 10,000-iteration training job on your first try. The network’s behavior after 100 iterations will tell you whether the reward signal is working.

Part 9 takes the trained model and wires it into the web server from Part 1, making the snake live on Battlesnake.


Part 9 — Wiring It to the Web

The snake is trained. The weights are saved to a .safetensors file on disk. Right now, that file is useless — it’s numbers on disk. We need to load those numbers into a network, feed it a board state, and return a move over HTTP.

This part takes everything we’ve built — the server from Part 1, the encoder from Part 2, the network from Part 3, the trained weights from Parts 5–7 — and connects it. The snake goes live.

What we’re building

The server from Part 1 was simple: receive a move request, return a direction. The trained snake server is the same shape, but the decision comes from the network instead of a heuristic or a random pick.

The pipeline for each move request:

HTTP request (JSON)
    │
    ▼  deserialize
Board + my_snake
    │
    ▼  encode_board()
Tensor(8, 11, 11)
    │
    ▼  flatten + batch dim
Tensor(1, 968)
    │
    ▼  net.forward() → argmax
direction index
    │
    ▼  map to string
"up" / "down" / "left" / "right"

Every turn, the server has to run this whole pipeline and respond before the game’s timeout. The timeout is 500 milliseconds by default, and it includes the network round trip. On CPU, inference for our small MLP takes about 1 ms, which leaves plenty of headroom.

Loading the model at startup

The model loads once, when the server starts. We don’t want to reload the weights on every request — that would be slow and the server would miss its timeout.

use actix_web::{web, App, HttpResponse, HttpServer};
use candle_core::{DType, Device};
use candle_nn::{VarBuilder, VarMap};
use serde::{Deserialize, Serialize};
use snake_ml::{encode_board, Board, Point, Snake, SnakeNet};
use std::sync::Mutex;

// Shared model state: loaded once, used by every request
struct AppState {
    net: SnakeNet,
    // The VarMap that owns the network's variables
    #[allow(dead_code)]
    var_map: VarMap,
}

#[derive(Deserialize)]
struct MoveRequest {
    board: Board,
    you: Snake,
    #[allow(dead_code)]
    game: GameInfo,
    #[allow(dead_code)]
    turn: u32,
}

#[derive(Deserialize)]
struct GameInfo {
    #[allow(dead_code)]
    id: String,
    #[allow(dead_code)]
    timeout: u32,
}

#[derive(Serialize)]
struct MoveResponse {
    #[serde(rename = "move")]
    direction: String,
    shout: Option<String>,
}

#[actix_web::main]
async fn main() -> std::io::Result<()> {
    let model_path = std::env::args()
        .nth(1)
        .unwrap_or_else(|| "model-final.safetensors".to_string());

    println!("Loading model from {model_path}...");

    let device = Device::Cpu;

    // Build the network first: this registers its variables in the VarMap.
    // VarMap::load then overwrites those variables with the saved weights.
    // (On an empty VarMap, load has nothing to fill and the weights stay untrained.)
    let mut var_map = VarMap::new();
    let vs = VarBuilder::from_varmap(&var_map, DType::F32, &device);
    let net = SnakeNet::new(vs).expect("failed to build network");
    var_map
        .load(&model_path)
        .expect("failed to load model weights");

    println!("Model loaded. Starting server on port 8080...");

    let state = web::Data::new(Mutex::new(AppState { net, var_map }));

    HttpServer::new(move || {
        App::new()
            .app_data(state.clone())
            // API v1: the engine fetches the snake's metadata with GET /
            .route("/", web::get().to(info_handler))
            .route("/start", web::post().to(start_handler))
            .route("/move", web::post().to(move_handler))
            .route("/end", web::post().to(end_handler))
    })
    .bind(("0.0.0.0", 8080))?
    .run()
    .await
}

async fn info_handler() -> HttpResponse {
    HttpResponse::Ok().json(serde_json::json!({
        "apiversion": "1",
        "author": "your-name",
        "color": "#7B2D8E",
        "head": "default",
        "tail": "default"
    }))
}

// Called when a game begins; the engine ignores the response body
async fn start_handler() -> HttpResponse {
    HttpResponse::Ok().finish()
}

async fn move_handler(
    state: web::Data<Mutex<AppState>>,
    req: web::Json<MoveRequest>,
) -> HttpResponse {
    let direction = decide_move(&state, &req.board, &req.you)
        .unwrap_or_else(|e| {
            eprintln!("Inference error: {e}");
            "up".to_string()
        });

    HttpResponse::Ok().json(MoveResponse {
        direction,
        shout: None,
    })
}

async fn end_handler() -> HttpResponse {
    HttpResponse::Ok().finish()
}

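The Mutex wraps the model so that every Actix worker shares one instance. Inference only reads the weights, so each request holds the lock for about one forward pass. The lock still serializes requests, though: if two games ask for a move at the same moment, one waits for the other.

A std::sync::RwLock would let readers proceed in parallel. Because the weights never change while serving, you can often drop the lock altogether and share web::Data<AppState> directly, provided AppState is Sync (a network built only from Candle tensors and layers generally is). Mutex is simply the easiest version to get right.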
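Inference never mutates the weights, so read-only sharing is all the server needs. In plain Rust, data that is only read can be shared across threads through an Arc, and the compiler allows it only when the type is Sync. Below is a std-only sketch of that idea. FrozenModel is a hypothetical stand-in for SnakeNet, not the Candle type, and each spawned thread plays the role of a concurrent /move request.

```rust
use std::sync::Arc;
use std::thread;

/// Hypothetical stand-in for a trained network: fixed weights, read-only inference.
struct FrozenModel {
    weights: Vec<f32>, // one score per direction: up, down, left, right
}

impl FrozenModel {
    /// Takes &self, so any number of threads can call it at the same time.
    fn best_direction(&self) -> usize {
        self.weights
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(i, _)| i)
            .unwrap_or(0)
    }
}

/// Runs `n` concurrent "requests" against one shared model, with no lock at all.
fn serve_concurrently(model: Arc<FrozenModel>, n: usize) -> Vec<usize> {
    let handles: Vec<_> = (0..n)
        .map(|_| {
            let model = Arc::clone(&model); // cheap: bumps a reference count
            thread::spawn(move || model.best_direction())
        })
        .collect();
    handles.into_iter().map(|h| h.join().unwrap()).collect()
}
```

In the Actix server, web::Data is already an Arc. The analogous change would be web::Data<AppState> in place of web::Data<Mutex<AppState>>, as long as AppState is Sync.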

Blocking in async handlers. Candle’s tensor operations are CPU-bound and block the thread they run on. Actix runs each worker on a single-threaded Tokio runtime, so while one handler is busy with inference, every other request assigned to that worker waits.

The fix is to move the inference call onto a dedicated thread pool with tokio::task::spawn_blocking. This needs tokio as a direct dependency; Actix’s own actix_web::web::block does the same job. For our small MLP, inference takes about a millisecond, so this barely matters, but it’s the right pattern for production:

async fn move_handler(state: web::Data<Mutex<AppState>>, req: web::Json<MoveRequest>) -> HttpResponse {
    // spawn_blocking needs owned, 'static data: clone the inputs
    // (assumes Board and Snake derive Clone) and the Arc-backed state
    let board = req.board.clone();
    let my_snake = req.you.clone();
    let state = state.clone();

    // Run inference on the blocking thread pool, off the async worker thread
    let result = tokio::task::spawn_blocking(move || decide_move(&state, &board, &my_snake)).await;

    let direction = match result {
        Ok(Ok(direction)) => direction,
        Ok(Err(e)) => {
            eprintln!("Inference error: {e}");
            "up".to_string()
        }
        Err(e) => {
            eprintln!("Inference task failed: {e}");
            "up".to_string()
        }
    };

    HttpResponse::Ok().json(MoveResponse { direction, shout: None })
}

Loading the model synchronously in main(), as above, means the server doesn’t accept connections until the weights are in memory. For a model this small, that takes a fraction of a second, so it’s the simplest choice. If loading ever becomes slow, start the server first, load the weights in the background, and answer /move with a fallback direction until the model is ready.
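One way to start serving before the weights are in memory is sketched below with std::sync::OnceLock (stable since Rust 1.70). A background thread fills the cell once, and the move logic falls back to a default direction until it has. LoadedModel, start_background_load, and choose_move are hypothetical names. In the real server, the loader would build SnakeNet and call VarMap::load, and the fallback would be the same "up" the handlers already use.

```rust
use std::sync::OnceLock;
use std::thread;

/// Hypothetical stand-in for the loaded network.
struct LoadedModel {
    preferred: &'static str,
}

/// Filled in exactly once, by the background loader.
static MODEL: OnceLock<LoadedModel> = OnceLock::new();

/// Start loading in the background and return immediately.
fn start_background_load() -> thread::JoinHandle<()> {
    thread::spawn(|| {
        // Stand-in for the slow part: building the network and loading weights.
        let model = LoadedModel { preferred: "left" };
        let _ = MODEL.set(model); // ignore the error if the cell was already set
    })
}

/// What /move would call: use the model if it's ready, otherwise fall back.
fn choose_move() -> &'static str {
    match MODEL.get() {
        Some(model) => model.preferred,
        None => "up", // model still loading
    }
}
```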

The decision function

This is the heart of the live snake. It’s the full pipeline in one function: encode → forward pass → pick a direction.

fn decide_move(
    state: &web::Data<Mutex<AppState>>,
    board: &Board,
    my_snake: &Snake,
) -> Result<String, candle_core::Error> {
    let app = state.lock().expect("state lock poisoned");

    // 1. Encode the board into a feature-plane tensor: (8, 11, 11)
    let tensor = encode_board(board, my_snake)?;

    // 2. Flatten to (1, 968): a batch dimension of 1 for a single board
    let flat = tensor.flatten_all()?.unsqueeze(0)?;

    // 3. Forward pass → logits (1, 4) → index of the largest logit.
    //    argmax(1) returns shape (1,), so squeeze it to a scalar before reading it.
    let logits = app.net.forward(&flat)?;
    let best = logits.argmax(1)?.squeeze(0)?.to_scalar::<u32>()? as usize;

    // 4. Map index to direction string (must match the order used in training)
    let direction = match best {
        0 => "up",
        1 => "down",
        2 => "left",
        3 => "right",
        _ => "up", // unreachable with a 4-way output layer
    };

    Ok(direction.to_string())
}

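That’s the entire inference pipeline. On CPU, it takes about 1 ms, against a default Battlesnake timeout of 500 ms. The remaining margin isn’t all spare, though: the timeout also has to cover the network round trip between the game engine and your server.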
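The index-to-direction mapping is a contract. Index 0 must mean the same move at inference time as it did when the training labels were generated, and any other code that enumerates moves has to agree. A small enum gives that order a single definition.

Direction is a hypothetical type, not part of snake_ml. Its order (up, down, left, right) is the one used by the match in decide_move. The `delta` method follows Battlesnake’s convention that (0, 0) is the bottom-left corner, so "up" increases y.

```rust
/// One place that defines the order of the network's four outputs.
/// The order must match the labels used in training.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Direction {
    Up,
    Down,
    Left,
    Right,
}

impl Direction {
    const ALL: [Direction; 4] = [Direction::Up, Direction::Down, Direction::Left, Direction::Right];

    /// Map a network output index to a direction; None for anything out of range.
    fn from_index(i: usize) -> Option<Direction> {
        Self::ALL.get(i).copied()
    }

    /// The string the Battlesnake API expects in the "move" field.
    fn as_str(self) -> &'static str {
        match self {
            Direction::Up => "up",
            Direction::Down => "down",
            Direction::Left => "left",
            Direction::Right => "right",
        }
    }

    /// (dx, dy) step on the board; the y axis points up.
    fn delta(self) -> (i32, i32) {
        match self {
            Direction::Up => (0, 1),
            Direction::Down => (0, -1),
            Direction::Left => (-1, 0),
            Direction::Right => (1, 0),
        }
    }
}
```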

Handling the edge case: what if the network picks a deadly move?

The network can pick a direction that walks into a wall or another snake. During training, that’s fine — the reward signal teaches it not to. During a live game, one bad move and the snake is dead.

The practical fix is to validate the network’s choice. Rank the four directions by the network’s probabilities and take the highest-ranked one that doesn’t lead to immediate death. This is the same survival check the heuristic used in Part 4, applied as a safety net.

fn decide_move_safe(
    state: &web::Data<Mutex<AppState>>,
    board: &Board,
    my_snake: &Snake,
) -> Result<String, candle_core::Error> {
    // Compute safe directions (not walls, not bodies) before taking the lock
    let safe = safe_directions(board, my_snake);

    if safe.is_empty() {
        // No safe move: we're going to die regardless, so let the network choose.
        // decide_move takes the lock itself, so we must not hold it here:
        // std::sync::Mutex is not re-entrant and locking twice would deadlock.
        return decide_move(state, board, my_snake);
    }

    let app = state.lock().expect("state lock poisoned");

    // Get the network's preference (probability for each direction)
    let tensor = encode_board(board, my_snake)?;
    let flat = tensor.flatten_all()?.unsqueeze(0)?;
    let logits = app.net.forward(&flat)?;
    let probs = candle_nn::ops::softmax(&logits, 1)?;

    // Rank directions by probability, highest first
    let probs_vec = probs.squeeze(0)?.to_vec1::<f32>()?; // (1, 4) → (4,)
    let mut ranked: Vec<(usize, f32)> = probs_vec.into_iter().enumerate().collect();
    // total_cmp is a total order, so a NaN can't make the sort panic
    ranked.sort_by(|a, b| b.1.total_cmp(&a.1));

    let direction_names = ["up", "down", "left", "right"];

    // Pick the highest-probability direction that is also safe
    for (idx, _prob) in ranked {
        if safe.contains(&idx) {
            return Ok(direction_names[idx].to_string());
        }
    }

    // Defensive fallback (unreachable while `safe` is non-empty)
    Ok(direction_names[safe[0]].to_string())
}

fn safe_directions(board: &Board, my_snake: &Snake) -> Vec<usize> {
    let head = &my_snake.head;
    let width = board.width as i32;
    let height = board.height as i32;

    // Collect all body positions (our own + enemies).
    // Conservative: tails count as blocked even though they usually move away,
    // and possible head-to-head collisions are not checked.
    let bodies: Vec<&Point> = board.snakes.iter()
        .flat_map(|s| s.body.iter())
        .collect();

    let mut safe = Vec::with_capacity(4);
    // (index, dx, dy), in the same index order as the network's outputs
    let candidates = [
        (0, 0, 1),   // up
        (1, 0, -1),  // down
        (2, -1, 0),  // left
        (3, 1, 0),   // right
    ];

    for (idx, dx, dy) in candidates {
        let nx = head.x + dx;
        let ny = head.y + dy;

        // Wall check
        if nx < 0 || nx >= width || ny < 0 || ny >= height {
            continue;
        }

        // Body check
        let hit_body = bodies.iter().any(|b| b.x == nx && b.y == ny);
        if hit_body {
            continue;
        }

        safe.push(idx);
    }

    safe
}

This is the “network with a safety net” pattern. The network makes the strategic decision. The safety net prevents tactical blunders — moves that kill the snake on the very next turn. Between the two, the snake plays well and doesn’t do obviously suicidal things.

This isn’t cheating. The network should learn to avoid walls and bodies on its own — the reward signal punishes death. But the network is probabilistic, and on any given turn it might make a mistake. The safety net catches the most obvious ones.

Snake metadata (GET /)

In API version 1, Battlesnake reads your snake’s metadata from a GET request to the root path (/), and the engine ignores whatever /start returns. The metadata response looks like this:

{
  "apiversion": "1",
  "author": "your-name",
  "color": "#7B2D8E",
  "head": "smart-caterpillar",
  "tail": "pixel",
  "version": "0.1.0"
}

The color, head, and tail fields control the snake’s appearance on the game board. They’re cosmetic, but picking a distinctive color makes it easier to spot your snake during playback.

Running the live snake

Build and run with the path to your trained model:

cargo run -- model-final.safetensors

The server starts on port 8080. For local testing, tunnel it so the Battlesnake engine can reach it:

# Using cloudflared
cloudflared tunnel --url http://localhost:8080

# Or using ngrok
ngrok http 8080

Then go to play.battlesnake.com, register your snake with your tunnel URL, and start a game. Watch it play.

What to look for

A trained snake should:

  • Navigate toward food consistently, not only when it happens to be in the right direction
  • Avoid walls and its own body even in tight spaces
  • React to the opponent — not walk straight into the enemy’s body
  • Survive past turn 10 — a random snake typically dies within 5–10 turns on an 11×11 board

A well-trained self-play snake should also:

  • Control space — stay near the center when there’s no food pressure
  • Win head-to-head — avoid walking into a larger snake, but challenge a smaller one
  • Manage health — not starve because it was doing something else

If the snake is dying on turn 3, suspect the model or the encoding. Either training didn’t converge, the reward signal isn’t working, or the live server encodes boards differently from the training code. Go back to training, check the loss and win rate curves, and make sure the network is learning at all before worrying about strategy.

Monitoring in production

BattleSnake games happen in real time. If something goes wrong — the server is slow, the model produces garbage, the encoding is wrong — you need to know.

Add logging to the move handler:

// Replaces the earlier move_handler: safety net plus one log line per turn
async fn move_handler(
    state: web::Data<Mutex<AppState>>,
    req: web::Json<MoveRequest>,
) -> HttpResponse {
    let turn = req.turn;
    let health = req.you.health;

    let direction = decide_move_safe(&state, &req.board, &req.you)
        .unwrap_or_else(|e| {
            eprintln!("Turn {turn}: inference error: {e}");
            "up".to_string()
        });

    println!("Turn {turn}: health={health} → {direction}");

    HttpResponse::Ok().json(MoveResponse {
        direction,
        shout: None,
    })
}

The log line gives you the turn number, the snake’s health, and the chosen direction. Health drops by one each turn and the snake is eliminated when it reaches zero, so if a game’s last line shows health=1, the snake most likely starved. If you see the same direction repeated every turn, the network is stuck. If you see inference error, something is wrong with the model or the encoding.

The complete server implementation

The full working server is in snake-server/src/main.rs. It includes:

  • A* pathfinding to the nearest food (pure Rust, BinaryHeap + HashMap, no external crates)
  • Survival check — immediately safe directions checked before A* runs
  • Fallback chain — A* failure drops to the first safe direction, then "up"
  • spawn_blocking wrapping on every request — the decision function runs on the blocking thread pool, not the async IO threads

You can run it right now without any trained weights. The A* heuristic is competitive enough to play a decent game against a random opponent on an 11×11 board.

Using the model: when you have trained weights, load them into a Model via Model::load("weights.safetensors") at startup and add the forward pass to decide(). The fallback chain means the server still works if the model is wrong or missing.

Part 10 covers scaling up: GPU training, convolutional architectures, and what to try next if you want a genuinely competitive snake.


Part 10 — Scaling Up

The snake works. It beats a random player, holds its own against the heuristic, and survives more than a few turns against real opponents on Battlesnake. That’s a good start. But “good start” isn’t the same as “competitive.”

This part isn’t a tutorial in the same way Parts 1–9 were. It’s a map of what to try next: the techniques, architectures, and training tricks that turn a decent snake into a strong one. Each direction is a rabbit hole. Pick one, go deep, measure the results.

The bottleneck: our architecture

The MLP from Part 3 works, but it has a structural limitation. Flattening the 8×11×11 tensor into a 968-element vector throws away its spatial structure. The first linear layer treats the 968 inputs as unrelated numbers, so it has to learn from data which cells are adjacent on the board. That’s wasted capacity: the network spends parameters discovering something the encoding already knows.


Convolutional layers fix this. A convolutional layer slides a small filter (typically 3×3) across the spatial dimensions of the input, applying the same weights at every position. The filter naturally captures local patterns: food two cells to the right, a body segment blocking the path, an enemy head approaching. These are exactly the patterns that matter in Battlesnake.

A simple CNN

Replace the MLP with a small convolutional neural network (CNN). A CNN uses convolutional layers instead of fully connected ones, so it keeps the spatial structure of the input:

Input: (8, 11, 11) — 8 feature planes
    │
    ▼ Conv2d(8 → 32, kernel=3, padding=1) → ReLU
    │  shape: (32, 11, 11)
    ▼ Conv2d(32 → 64, kernel=3, padding=1) → ReLU
    │  shape: (64, 11, 11)
    ▼ Flatten
    │  shape: (64 * 11 * 11,) = 7744
    ▼ Linear(7744 → 256) → ReLU
    │
    ▼ Linear(256 → 4) → logits

In Candle:

use candle_core::{Result, Tensor};
use candle_nn::{conv2d, linear, Conv2d, Conv2dConfig, Linear, Module, VarBuilder};

pub struct SnakeCnn {
    conv1: Conv2d,
    conv2: Conv2d,
    fc1: Linear,
    fc2: Linear,
}

impl SnakeCnn {
    pub fn new(vs: VarBuilder) -> Result<Self> {
        let conv_cfg = Conv2dConfig {
            padding: 1, // same-padding: output spatial dims match input
            ..Default::default()
        };

        let conv1 = conv2d(8, 32, 3, conv_cfg, vs.pp("conv1"))?;
        let conv2 = conv2d(32, 64, 3, conv_cfg, vs.pp("conv2"))?;
        let fc1 = linear(64 * 11 * 11, 256, vs.pp("fc1"))?;
        let fc2 = linear(256, 4, vs.pp("fc2"))?;

        Ok(Self { conv1, conv2, fc1, fc2 })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Accept a single encoding (8, 11, 11) or a batch: (batch, 8, 11, 11) or (batch, 968).
        // A lone encoding gets a batch dimension of 1 first; reading dim(0) on it
        // directly would return 8 (the channel count), not the batch size.
        let x = if x.rank() == 3 { x.unsqueeze(0)? } else { x.clone() };
        let b = x.dim(0)?;
        let x = x.reshape((b, 8, 11, 11))?;
        let x = self.conv1.forward(&x)?.relu()?;
        // (batch, 32, 11, 11)
        let x = self.conv2.forward(&x)?.relu()?;
        // (batch, 64, 11, 11)
        let x = x.reshape((b, 64 * 11 * 11))?;
        let x = self.fc1.forward(&x)?.relu()?;
        self.fc2.forward(&x)
    }
}

MLP vs CNN input shapes. The MLP from Part 3 expects a flattened tensor (batch, 8*H*W) with no spatial structure. The CNN expects (batch, 8, H, W) with explicit channel and spatial dimensions. This is a common trip point when switching architectures — the reshape is part of the contract, not optional padding. The code above handles it by reshaping at the start of forward, so callers pass (8, 11, 11) directly and the network reshapes internally.

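Parameter counts cut both ways here. Overall, the CNN has more parameters than the MLP, and almost all of them sit in fc1: 7,744 × 256 = 1,982,464 weights. The convolutional layers themselves are cheap, because each filter’s weights are shared across every position on the board. The first conv layer has 8 × 32 × 3 × 3 = 2,304 weights regardless of board size, while the MLP’s first layer has 968 × 256 = 247,808. That’s where the CNN is more parameter-efficient for spatial data: its feature extractor learns each local pattern once instead of separately for every cell.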
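The arithmetic is easy to check. This sketch counts weights plus biases for the layer sizes in the diagram above. The MLP comparison assumes Part 3’s first layer maps 968 inputs to 256 units, as the text does. The function names are illustrative.

```rust
/// Weights + biases of a conv layer: each filter is shared across every board cell,
/// so the board size does not appear here at all.
fn conv_params(in_ch: usize, out_ch: usize, k: usize) -> usize {
    in_ch * out_ch * k * k + out_ch
}

/// Weights + biases of a fully connected (linear) layer.
fn linear_params(inputs: usize, outputs: usize) -> usize {
    inputs * outputs + outputs
}

/// Total parameters of the CNN above for an n × n board.
/// Only the first linear layer depends on the board size.
fn cnn_params(n: usize) -> usize {
    conv_params(8, 32, 3)
        + conv_params(32, 64, 3)
        + linear_params(64 * n * n, 256)
        + linear_params(256, 4)
}
```

For an 11×11 board, cnn_params(11) comes to 2,004,580. Of that, conv1 and conv2 together contribute only 20,832. Everything else is in the two linear layers, and only fc1 grows with the board.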

Dueling networks

A standard policy network outputs 4 logits directly. A dueling network splits the final layers into two heads:

  • V(s) — the value: how good is this position overall? One scalar. A position with food nearby and no enemies is good. One where you’re surrounded is bad.
  • A(s, a) — the advantage of each action: how much better is this specific move compared to the average move? Four values (one per direction). Taking a food-adjacent cell has high advantage; taking a safe-but-useless cell has low advantage.
  • Q(s, a), the quality of taking action a in state s, is not a third head. It’s what the two heads are combined into, and it’s the quantity the agent tries to maximize.

The key insight: it’s easier to learn two simple things than one complicated thing. Instead of learning Q directly, the network learns V and A separately, then combines them:

Q(s, a) = V(s) + A(s, a) - mean(A(s, ·))

The - mean(A) term needs explaining. Without it, V and A would be ambiguous. V = 10 with A = [1, 2, 3, 4] and V = 13 with A = [-2, -1, 0, 1] give exactly the same Q values (11, 12, 13, 14). The network could learn either split, and nothing in the training signal prefers one over the other. Subtracting the mean forces the advantages to average zero, so A carries only the differences between actions and V has to carry the base value. Now V and A each have a single, stable job, and training is more reliable.
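A few lines of plain Rust make the ambiguity concrete, using the two splits from the paragraph above. `dueling_q` implements the combination formula for a single state, and `naive_q` drops the mean term for comparison. Both are illustrative helpers, separate from the Candle network below.

```rust
/// Q(s, a) = V(s) + A(s, a) - mean(A(s, ·)) for one state.
fn dueling_q(v: f32, a: &[f32]) -> Vec<f32> {
    let mean = a.iter().sum::<f32>() / a.len() as f32;
    a.iter().map(|&adv| v + adv - mean).collect()
}

/// The same combination without the mean term: Q(s, a) = V(s) + A(s, a).
fn naive_q(v: f32, a: &[f32]) -> Vec<f32> {
    a.iter().map(|&adv| v + adv).collect()
}
```

naive_q gives [11, 12, 13, 14] for both splits, so training can’t tell them apart. dueling_q gives [8.5, 9.5, 10.5, 11.5] for the first split and [11.5, 12.5, 13.5, 14.5] for the second. Once the mean is subtracted, each set of Q values corresponds to exactly one V.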

The network learns V and A jointly, which in practice tends to make training more stable.

// Same imports as SnakeCnn above.

/// A dueling variant of the CNN with separate value and advantage heads.
/// The conv layers match SnakeCnn; fc1/fc2 are replaced by a shared layer and two heads.
pub struct SnakeCnnDueling {
    conv1: Conv2d,
    conv2: Conv2d,
    shared_fc: Linear,
    v_head: Linear, // value: (batch, 256) → (batch, 1)
    a_head: Linear, // advantage: (batch, 256) → (batch, 4)
}

impl SnakeCnnDueling {
    pub fn new(vs: VarBuilder) -> Result<Self> {
        let conv_cfg = Conv2dConfig {
            padding: 1,
            ..Default::default()
        };

        let conv1 = conv2d(8, 32, 3, conv_cfg, vs.pp("conv1"))?;
        let conv2 = conv2d(32, 64, 3, conv_cfg, vs.pp("conv2"))?;
        let shared_fc = linear(64 * 11 * 11, 256, vs.pp("shared_fc"))?;
        let v_head = linear(256, 1, vs.pp("v_head"))?;
        let a_head = linear(256, 4, vs.pp("a_head"))?;

        Ok(Self { conv1, conv2, shared_fc, v_head, a_head })
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Same input handling as SnakeCnn::forward: a lone (8, 11, 11) encoding
        // gets a batch dimension before the batch size is read.
        let x = if x.rank() == 3 { x.unsqueeze(0)? } else { x.clone() };
        let b = x.dim(0)?;
        let x = x.reshape((b, 8, 11, 11))?;

        let x = self.conv1.forward(&x)?.relu()?;
        let x = self.conv2.forward(&x)?.relu()?;
        let x = x.reshape((b, 64 * 11 * 11))?;

        // Shared feature extraction
        let shared = self.shared_fc.forward(&x)?.relu()?;

        // Split into value and advantage heads
        let v = self.v_head.forward(&shared)?; // (batch, 1)
        let a = self.a_head.forward(&shared)?; // (batch, 4)

        // Q = V + A - mean(A), subtracting the mean for identifiability.
        // Candle's + and - operators need identical shapes, so broadcast explicitly.
        let a_centered = a.broadcast_sub(&a.mean_keepdim(1)?)?; // (batch, 4)
        v.broadcast_add(&a_centered) // (batch, 1) + (batch, 4) → (batch, 4)
    }
}

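Dueling heads come from value-based RL: they were introduced for DQN, where the outputs are trained directly as Q values. Keep that in mind before plugging this network into the REINFORCE loop from the earlier parts. Softmax ignores a constant added to every logit, and V(s) is exactly such a constant, so it cancels out and the value head receives no gradient.

For policy-gradient methods, the usual design is instead a shared trunk with a policy head and a separate value head trained to predict returns. That is the setup PPO uses. Either way, the added complexity is small: a couple of extra linear layers on top of the shared features.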

Residual connections

A deeper CNN can learn more complex patterns, but training deeper networks is harder. During backpropagation, gradients flow backward through each layer. Each layer multiplies the gradient by its local derivative. Multiply enough small numbers together and the result is tiny — the gradient “vanishes” and the early layers barely update at all. A 4-layer CNN might have this problem; a 10-layer CNN almost certainly will.

Residual connections fix this by giving the gradient a shortcut path:

x ──▶ Conv → ReLU → Conv ──▶ (+) ──▶ ReLU
│                             ▲
└─────────────────────────────┘

The + is an element-wise addition. If the layer doesn’t need to change anything, it learns to output zeros and the shortcut passes the input through unchanged. The layer only needs to learn the difference between the input and the desired output (the “residual”), which is often a much smaller adjustment than learning the full transformation from scratch. This makes training deeper networks (4+ conv layers) practical.

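// conv1 and conv2 must both keep the channel count and spatial size unchanged
// (e.g. C → C channels, kernel 3, padding 1) so that `out` and `x` can be added.
fn residual_block(x: &Tensor, conv1: &Conv2d, conv2: &Conv2d) -> Result<Tensor> {
    let out = conv1.forward(x)?.relu()?;
    let out = conv2.forward(&out)?;
    // Skip connection: add input to output, then apply the final ReLU
    (out + x)?.relu()
}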

Residual connections require that the input and output have the same shape (same channels, same spatial dimensions). If you want to change the channel count, add a 1×1 convolution on the skip path.
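A scalar toy model shows why the shortcut keeps gradients alive. It assumes every layer has the same local derivative and ignores ReLUs and weight matrices entirely. In a plain stack, the gradient reaching the first layer is the product of the local derivatives. With a skip connection around each block, y = x + f(x), so each block contributes 1 + f'(x) instead of f'(x), and the product no longer collapses toward zero. This sketch illustrates the intuition; it does not model a real network.

```rust
/// Gradient scale after backpropagating through `depth` plain layers,
/// each with local derivative `local_grad`: dy/dx = f'^depth.
fn plain_gradient(local_grad: f64, depth: i32) -> f64 {
    local_grad.powi(depth)
}

/// Same toy model with a skip connection around each layer: y = x + f(x),
/// so each block contributes (1 + f') to the product.
fn residual_gradient(local_grad: f64, depth: i32) -> f64 {
    (1.0 + local_grad).powi(depth)
}
```

With a local derivative of 0.1 and 10 layers, plain_gradient returns about 1e-10 while residual_gradient returns about 2.6. Even with a local derivative of 0, the residual version still passes a gradient of exactly 1.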

Board size and the fully-connected layer

The CNN’s convolutional layers work on any board size — a 3×3 filter slides the same way on 11×11 and 19×19. But the fully-connected layer after flattening is tied to a specific spatial size. 64 × 11 × 11 = 7,744 for an 11×11 board. 64 × 19 × 19 = 23,104 for a 19×19 board.

Two ways to handle this:

  1. Global average pooling — instead of flattening, average each channel across the spatial dimensions. The fc layer input becomes (64,) regardless of board size. This loses spatial precision but makes the network board-size-agnostic.

  2. Train on a fixed board size — most competitive BattleSnake snakes train on 11×11. Tournament games occasionally use other sizes, but 11×11 is the default. If you’re optimizing for competition, fixed-size training is fine.
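Global average pooling is just one mean per channel. Below is a std-only sketch over a flat, row-major (channels, height, width) buffer, the same layout as the flattened encodings in this tutorial; `global_avg_pool` is a hypothetical helper. In a Candle network, the equivalent is averaging the (batch, 64, H, W) activations over their last two dimensions before fc1. fc1 would then take 64 inputs instead of 7,744.

```rust
/// Global average pooling over a flat (channels × height × width) feature map.
/// Returns one average per channel, so the output length is `channels`
/// no matter how large the board is.
fn global_avg_pool(features: &[f32], channels: usize, height: usize, width: usize) -> Vec<f32> {
    let plane = height * width;
    assert!(plane > 0, "empty spatial dimensions");
    assert_eq!(features.len(), channels * plane, "shape mismatch");
    features
        .chunks(plane) // each chunk is one channel's H × W plane
        .map(|ch| ch.iter().sum::<f32>() / plane as f32)
        .collect()
}
```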

The bottleneck: our training algorithm

REINFORCE works, but it’s the simplest policy gradient method. It has two well-known problems:

  1. High variance. The gradient estimate is noisy. One episode might produce a total return of +15, the next might produce -5. The update direction flips between episodes. Training is slow and unstable.

  2. No baseline. REINFORCE judges an action by the absolute return. But every action in a good episode looks good (the return is high), and every action in a bad episode looks bad (the return is low). The algorithm can’t distinguish “this was a good action in a bad episode” from “this was a bad action in a bad episode.”

PPO: the standard fix

Proximal Policy Optimization (PPO) addresses both problems. It uses a value function — a second network (or a second head on the same network) that predicts the expected return from a given state. The value function serves as the baseline: instead of scaling the gradient by the raw return, PPO scales it by the advantage — how much better this action was than expected.

advantage = return - value_function(state)

If the action led to a return of +5 but the value function expected +3, the advantage is +2. The action was better than expected — make it more likely.

PPO also constrains the size of each update. The policy can only change by a small amount per step (the “proximal” part). This prevents catastrophic updates that destroy learned behavior.

Implementing PPO in Candle is a larger project than REINFORCE — you need the value function, the advantage calculation, the clipped objective, and the epoch-based minibatch updates. The architecture change is straightforward (add a value head to the network). The training loop is where the complexity lives.
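The clipped objective mentioned above is compact enough to write out for a single sample. In this sketch:

  • `ratio` is the new policy’s probability of the taken action divided by the old policy’s.
  • `advantage` is the return minus the value estimate.
  • `eps` is the clip range. 0.2 is a commonly used value, not something tuned for Battlesnake.

PPO maximizes the average of this quantity over a minibatch, so a training loss would be its negative.

```rust
/// PPO's clipped surrogate objective for one (state, action) sample, to be maximized.
fn ppo_clipped_objective(ratio: f32, advantage: f32, eps: f32) -> f32 {
    let unclipped = ratio * advantage;
    let clipped = ratio.clamp(1.0 - eps, 1.0 + eps) * advantage;
    // The minimum is the more pessimistic of the two estimates.
    unclipped.min(clipped)
}
```

In every case, the minimum picks the more pessimistic value. The objective never rewards pushing the ratio outside [1 - eps, 1 + eps], but it still fully penalizes a change that made things worse. For example, raising the probability of an action with negative advantage keeps the unclipped, larger penalty.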

A simpler improvement: reward shaping

Before jumping to PPO, try improving the reward signal. REINFORCE with a better reward function can outperform REINFORCE with a bad reward function and a fancier algorithm.

Some shaping terms to try:

  • Distance to food. A small positive reward for getting closer to the nearest food. This gives the network signal on every turn, not only when it eats.
  • Space count. Count the number of cells reachable from the head (flood fill). More space means more room to maneuver, so reward the network for staying in open areas.
  • Opponent distance. A small penalty for being close to the opponent’s head when the opponent is larger, encouraging avoidance.
  • Turns alive. Instead of a flat +0.01 per turn, scale the bonus inversely with current health, so a snake at health 5 gets a bigger survival bonus per turn than a snake at health 95.
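Here is a sketch of the space-count term: a breadth-first flood fill from the head over cells that are on the board and not occupied. A few things in it are choices made for this sketch:

  • It uses plain (x, y) tuples rather than the tutorial’s Point type.
  • It treats every body segment as blocked, including tails that are about to move.
  • `reachable_cells` is a hypothetical helper name.

How to weight the count in the reward, for example a small multiple of reachable / (width × height), is a tuning choice.

```rust
use std::collections::{HashSet, VecDeque};

/// Count the cells reachable from `head` without leaving the board or entering
/// a blocked cell (breadth-first flood fill). The head itself is not counted.
fn reachable_cells(head: (i32, i32), blocked: &HashSet<(i32, i32)>, width: i32, height: i32) -> usize {
    let mut seen: HashSet<(i32, i32)> = HashSet::new();
    let mut queue = VecDeque::new();
    seen.insert(head);
    queue.push_back(head);

    while let Some((x, y)) = queue.pop_front() {
        for (dx, dy) in [(0, 1), (0, -1), (-1, 0), (1, 0)] {
            let next = (x + dx, y + dy);
            let in_bounds = next.0 >= 0 && next.0 < width && next.1 >= 0 && next.1 < height;
            // seen.insert returns true only the first time a cell is visited
            if in_bounds && !blocked.contains(&next) && seen.insert(next) {
                queue.push_back(next);
            }
        }
    }
    seen.len() - 1 // exclude the head
}
```

On an empty 11×11 board, every other cell is reachable (120). A full column of body segments at x = 3 cuts a head at (0, 0) off to 32 cells.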

The key principle: the more informative the reward, the faster the network learns. If the reward only fires at the end of the episode (win/lose), learning is slow because the credit assignment problem is hard. If the reward fires on every turn with a clear gradient (this move was slightly better than that move), learning is fast.

The bottleneck: training speed

CPU training is fine for development. For production-quality results, GPU training is the difference between “trained overnight” and “trained over the weekend.”

GPU training with Candle

Candle supports CUDA via the cuda feature flag:

[dependencies]
candle-core = { version = "0.6", features = ["cuda"] }
candle-nn = { version = "0.6", features = ["cuda"] }

Then create a CUDA device:

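// Needs the "cuda" feature and a CUDA-capable GPU; returns an error otherwise.
// Device::cuda_if_available(0)? falls back to the CPU when CUDA isn't available.
let device = candle_core::Device::new_cuda(0)?;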

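Most tensor operations work the same on CUDA as on CPU. The model code doesn’t change, and neither does the training loop. The only difference is which device the tensors live on, and that has to be consistent. An operation that mixes a CPU tensor with a CUDA tensor fails, so the VarBuilder, the encoded boards, and any constant tensors all need to be created on the CUDA device (or moved there with to_device).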

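For self-play, how much the GPU helps depends on batching. A forward pass of our small network takes ~1 ms on CPU. On a GPU, a single board is dominated by fixed costs (launching kernels and copying the input to the device), so evaluating one board at a time may be barely faster, or even slower. The big speedups come from pushing many boards through one forward pass.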

Vectorized environments

The GameEnv from Part 8 runs one game at a time. On GPU, we can run many games in parallel by stacking the board states into a batch tensor. Instead of encoding one board into (1, 8, 11, 11), encode 64 boards into (64, 8, 11, 11) and run a single forward pass.

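// `games` is a hypothetical Vec<(Board, Snake)>: one board per running game,
// paired with the snake whose move we're choosing.

// Instead of one forward pass per board:
for (board, snake) in &games {
    let tensor = encode_board(board, snake)?;          // (8, 11, 11)
    let logits = net.forward(&tensor.unsqueeze(0)?)?;  // (1, 4)
    // ...
}

// Stack every encoding into one batch and run a single forward pass:
let encoded = games
    .iter()
    .map(|(board, snake)| encode_board(board, snake))
    .collect::<candle_core::Result<Vec<Tensor>>>()?;  // n tensors of (8, 11, 11)
let batch_t = Tensor::stack(&encoded, 0)?;            // (n, 8, 11, 11)
let batch_t = batch_t.to_device(&device)?;            // move to the GPU once
let logits = net.forward(&batch_t)?;                  // (n, 4)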

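Running many games in parallel and batching their inference is how large-scale RL systems such as AlphaGo and AlphaStar reach their training throughput. For Battlesnake, batching 64 environments turns 64 small forward passes into one larger pass that, on a GPU, costs little more than a single one. That can add up to a speedup on the order of 50× over sequential play: the game simulation stays sequential, but it’s cheap compared to inference. Measure on your own hardware before counting on a particular number.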

The bottleneck: representation

The 8-channel encoding from Part 2 is a reasonable starting point, but it loses information. Here are some directions to explore:

Relative encoding

Instead of encoding the board from a global perspective (food at (5, 3), my head at (2, 1)), encode it relative to the snake’s head. The head is always at the center of the tensor. The rest of the board is shifted accordingly.

This has two benefits. First, the network doesn’t need to learn that (5, 3) and (2, 1) are close — it can see it directly in the input. Second, the same local pattern (food one cell to the right) always looks the same regardless of the head’s absolute position. This is translational invariance, and convolutional networks exploit it naturally.

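You can go one step further and rotate the board so the snake’s facing direction always points the same way in the tensor (say, toward the top). The network’s four outputs then have to be read relative to that heading and mapped back to absolute directions before answering the engine. This adds complexity to the encoder, but it can significantly improve the network’s ability to generalize.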
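The coordinate shift behind head-centered encoding is a single subtraction, but the grid has to grow. To keep the whole 11×11 board in view wherever the head is, the encoded grid becomes 21×21, with the head always at cell (10, 10). Cells that fall off the real board would need a value of their own, for instance an extra "outside the board" channel. That choice and the helper names below are assumptions, not part of Part 2’s encoder.

```rust
/// Size of a head-centered grid that keeps every board cell in view:
/// (2 * width - 1) × (2 * height - 1).
fn centered_grid_size(width: i32, height: i32) -> (i32, i32) {
    (2 * width - 1, 2 * height - 1)
}

/// Map an absolute board cell to its position in the head-centered grid.
/// The head itself always lands at the center, (width - 1, height - 1).
fn to_head_centered(cell: (i32, i32), head: (i32, i32), width: i32, height: i32) -> (i32, i32) {
    (cell.0 - head.0 + (width - 1), cell.1 - head.1 + (height - 1))
}
```

With the head at (2, 1), food at (5, 3) lands at (13, 12), three cells right of and two cells above the center. The 21×21 grid has 441 cells against the board’s 121. Convolution weights don’t care, but a linear layer after a flatten grows accordingly.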

Temporal stacking

Feed the last 3–5 board states as additional channels. Instead of 8 channels, the input becomes 8 × 5 = 40 channels (current state + 4 previous states). This gives the network motion information — it can see that the enemy is moving toward it, or that it’s circling the same area.

The cost is 5× more input channels, which means 5× more parameters in the first convolutional layer. For a small CNN, this is manageable.
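Below is a sketch of temporal stacking. It keeps a rolling window of flat encodings and concatenates them along the channel axis, newest frame first. FrameStack is hypothetical, and two of its choices are assumptions:

  • the channel order (newest frame first);
  • padding the first turns of a game with copies of the oldest available frame, so the input size never changes.

Whatever you choose, training and inference have to build the stack the same way.

```rust
use std::collections::VecDeque;

/// Keeps the last `depth` encoded frames and concatenates them channel-wise.
/// Each frame is a flat (channels × H × W) encoding, e.g. 8 × 11 × 11 = 968 floats.
struct FrameStack {
    depth: usize,
    frames: VecDeque<Vec<f32>>, // oldest at the front, newest at the back
}

impl FrameStack {
    fn new(depth: usize) -> Self {
        assert!(depth > 0, "need at least one frame");
        FrameStack { depth, frames: VecDeque::with_capacity(depth) }
    }

    /// Push the newest frame and return the stacked input, newest frame first.
    /// Early in a game, missing history is filled with copies of the oldest frame.
    fn push(&mut self, frame: Vec<f32>) -> Vec<f32> {
        if self.frames.len() == self.depth {
            self.frames.pop_front(); // drop the oldest frame
        }
        self.frames.push_back(frame);

        let oldest = self.frames.front().expect("just pushed").clone();
        let mut stacked = Vec::with_capacity(self.depth * oldest.len());
        for f in self.frames.iter().rev() {
            stacked.extend_from_slice(f);
        }
        for _ in self.frames.len()..self.depth {
            stacked.extend_from_slice(&oldest);
        }
        stacked
    }
}
```

With depth 5 and 968-float frames, every call returns 4,840 floats, which is 40 channels of 11×11.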

Full board state

Our encoding drops some information: snake IDs, snake lengths, the game ruleset, and the turn number. Some of these are useful. Snake length matters for head-to-head collisions (the longer snake wins). The ruleset matters because “wrapped” mode has no walls: the edges connect. The turn number matters in rulesets where the board changes as the game progresses, such as royale, where hazards close in over time.

Add channels for:

  • Snake length (normalized, one channel per snake)
  • Turn number (repeated across the board)
  • Game mode indicator (0 = standard, 1 = wrapped)

Each addition costs one channel. The network gets more signal for minimal overhead.
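“Repeated across the board” means a constant plane: every cell of the extra channel holds the same scalar, so a convolutional filter sees the value no matter where it sits. Below is a std-only sketch that appends such planes to a flat encoding. `append_constant_planes` is a hypothetical helper, and how you scale each value into a small range (for example turn / 500 or length / 121) is up to you.

```rust
/// Append scalar features as extra constant planes to a flat (C × H × W) encoding.
/// Each value is repeated over every cell of its own H × W plane.
fn append_constant_planes(encoding: &mut Vec<f32>, values: &[f32], height: usize, width: usize) {
    for &v in values {
        encoding.extend(std::iter::repeat(v).take(height * width));
    }
}
```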

Where to go from here

The techniques in this part are ordered roughly by effort-to-reward ratio:

  1. Reward shaping — easiest to try, often the biggest improvement. Add a distance-to-food bonus and see if training converges faster.
  2. CNN architecture — swap the MLP for a CNN. More parameter-efficient for spatial data, faster convergence.
  3. GPU training — unlocks larger batch sizes and more training iterations. Necessary for PPO.
  4. PPO — the standard RL algorithm. More stable training, better sample efficiency. Requires a value function.
  5. Relative encoding — significant representational improvement, but requires rewriting the encoder.
  6. Temporal stacking — gives the network motion information. Easy to implement once you have the CNN working.
  7. Vectorized environments — large-scale training throughput. Requires GPU.

Start at the top. Measure after each change. If reward shaping gets you from 40% win rate to 55%, that’s real progress. If switching to PPO gets you from 55% to 60%, that’s also real progress. But if PPO gets you from 40% to 42% while reward shaping gets you to 55%, the reward signal was the real problem — fix that first.

The snake is live. The foundation is solid. Everything from here is iteration.
