Part 3 — Your First Network

Last time we turned the board into 968 numbers. Now we feed those numbers into a neural network and get a move out.

Let’s demystify what a neural network is actually doing in this context.

What a network is

A neural network is a function. It takes a list of numbers in and spits a list of numbers out. That’s all. The “neural” part is a particular way of composing simple mathematical functions (linear transforms followed by nonlinear activations) so that the whole thing can be trained to approximate a very wide range of mappings from input to output.

For us: input is 968 numbers (the flattened board encoding). Output is 4 numbers — the “score” for each of the four directions. We pick the highest score.
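Picking the highest score is an argmax over the four outputs. Here is a minimal sketch in plain Rust, with no Candle involved. The names `argmax`, `best_direction`, and `DIRECTIONS` are illustrative, and the index order is assumed to match the Up, Down, Left, Right order of the `Direction` enum defined later in this part:

```rust
/// Direction names in index order (assumed: 0 = up, 1 = down, 2 = left, 3 = right).
const DIRECTIONS: [&str; 4] = ["up", "down", "left", "right"];

/// Index of the largest score, or None for an empty slice.
/// On a tie, the first maximum wins.
fn argmax(scores: &[f32]) -> Option<usize> {
    let mut best: Option<(usize, f32)> = None;
    for (i, &s) in scores.iter().enumerate() {
        let better = match best {
            None => true,
            Some((_, b)) => s > b,
        };
        if better {
            best = Some((i, s));
        }
    }
    best.map(|(i, _)| i)
}

/// Name of the best-scoring direction for a 4-element score list.
fn best_direction(scores: &[f32; 4]) -> &'static str {
    // A 4-element array is never empty, so argmax always returns Some.
    DIRECTIONS[argmax(scores).expect("scores is non-empty")]
}
```

For the scores `[0.1, -1.3, 2.4, 0.7]`, `best_direction` returns `"left"`, because index 2 holds the largest value.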

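The intermediate layers are where the magic happens. Each layer is a set of “units” (neurons) that look for patterns in the previous layer’s output. Loosely speaking, early layers tend to pick up simple spatial patterns (edges, food nearby), and later layers combine those into higher-level concepts (a path to food, a trap closing in). Nothing forces that split, though: a plain MLP sees the board as one flat list of numbers, so whatever structure emerges comes from training.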
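To make “linear transform followed by a nonlinear activation” concrete, here is a toy sketch in plain Rust. The function names (`dense`, `relu`, `tiny_forward`) and the hand-picked weights are made up for illustration. A trained network would have learned weights and far more units:

```rust
/// One fully connected layer: y = W x + b.
/// `weights` holds one row per output unit; each row has one entry per input.
fn dense(weights: &[Vec<f32>], bias: &[f32], x: &[f32]) -> Vec<f32> {
    weights
        .iter()
        .zip(bias)
        .map(|(row, b)| row.iter().zip(x).map(|(w, xi)| w * xi).sum::<f32>() + b)
        .collect()
}

/// ReLU applied element-wise: negative values become 0.
fn relu(v: &[f32]) -> Vec<f32> {
    v.iter().map(|&a| a.max(0.0)).collect()
}

/// A toy 3 -> 2 -> 2 network with hand-picked (not trained) weights.
fn tiny_forward(x: &[f32]) -> Vec<f32> {
    let w1: Vec<Vec<f32>> = vec![vec![1.0, -1.0, 0.0], vec![0.5, 0.5, -2.0]];
    let b1: [f32; 2] = [0.0, 0.0];
    let w2: Vec<Vec<f32>> = vec![vec![1.0, 1.0], vec![-1.0, 2.0]];
    let b2: [f32; 2] = [0.25, 0.0];

    // Layer 1 plus activation, then the output layer with no activation.
    let hidden = relu(&dense(&w1, &b1, x));
    dense(&w2, &b2, &hidden)
}
```

With input `[1.0, 2.0, 0.5]`, the first layer produces `[-1.0, 0.5]`, ReLU clips that to `[0.0, 0.5]`, and the output layer returns `[0.75, 1.0]`. The real network does the same thing with 968 inputs and wider layers.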

The architecture

We’ll build a simple multi-layer perceptron (MLP): three linear layers with ReLU activations between them. ReLU (Rectified Linear Unit) is the activation function that outputs max(0, x).

Input (968) → Linear(968 → 256) → ReLU
           → Linear(256 → 128) → ReLU
           → Linear(128 → 4)   → logits
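For a sense of scale, the number of trainable weights follows directly from the layer sizes. This is a small sketch, assuming each linear layer also carries one bias per output. The helper names are illustrative:

```rust
/// Parameters in one fully connected layer: a weight per (input, output)
/// pair plus one bias per output (assumes the layer has a bias term).
const fn linear_params(inputs: usize, outputs: usize) -> usize {
    inputs * outputs + outputs
}

/// Total for the 968 -> 256 -> 128 -> 4 stack.
fn snake_net_params() -> usize {
    let sizes = [968, 256, 128, 4];
    // Each adjacent pair of sizes is one linear layer.
    sizes.windows(2).map(|w| linear_params(w[0], w[1])).sum()
}
```

That works out to 248,064 + 32,896 + 516 = 281,476 parameters, almost all of them in the first layer.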

The final layer has 4 outputs (one per direction) without an activation — we apply softmax afterward to turn logits into probabilities.

In Candle:

#![allow(unused)]
fn main() {
use candle_nn::{linear, Linear, Module, VarBuilder};
use candle_core::{Device, Result, Tensor};

pub struct SnakeNet {
    // Three linear layers
    l1: Linear,
    l2: Linear,
    l3: Linear,
}

impl SnakeNet {
    pub fn new(vs: VarBuilder) -> Result<Self> {
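        // 11 * 11 * 8 = 968 input features.
        // `linear` takes its VarBuilder by value, and each layer needs its own
        // name prefix so the three weight/bias pairs don't collide in the VarMap.
        let l1 = linear(968, 256, vs.pp("l1"))?;
        let l2 = linear(256, 128, vs.pp("l2"))?;
        let l3 = linear(128, 4, vs.pp("l3"))?; // 4 directions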
        Ok(Self { l1, l2, l3 })
    }

    /// Forward pass: flat_input has shape (batch, 968)
    pub fn forward(&self, flat_input: &Tensor) -> Result<Tensor> {
        let x = self.l1.forward(flat_input)?.relu()?;
        let x = self.l2.forward(&x)?.relu()?;
        self.l3.forward(&x)
    }

    /// Apply softmax with optional temperature.
    /// Higher temperature → softer, more exploratory distributions.
    /// Lower temperature → sharper, more greedy.
    /// A few useful reference points:
    ///   - temperature = 1.0  → standard softmax (default)
    ///   - temperature = 0.5  → sharper, more confident predictions
    ///   - temperature = 2.0  → flatter, more random-looking distributions
    /// This matters in Part 6 when we add entropy regularization.
    pub fn probs(&self, flat_input: &Tensor, temperature: f32) -> Result<Tensor> {
        let logits = self.forward(flat_input)?;
        let scaled = if (temperature - 1.0).abs() > 1e-5 {
            // Scale every logit by 1/temperature. Candle's `*` operator does not
            // broadcast, so multiplying (batch, 4) logits by a (1, 1) tensor
            // fails with a shape mismatch. `affine(mul, add)` applies a scalar
            // multiply directly, on whatever device the logits live on.
            logits.affine(1.0 / temperature as f64, 0.0)?
        } else {
            logits
        };
        candle_nn::ops::softmax(&scaled, 1)
    }

    /// Pick the direction with the highest logit score (greedy).
    /// Expects a single board, i.e. `flat_input` of shape (1, 968).
    pub fn pick_direction(&self, flat_input: &Tensor) -> Result<Direction> {
        let logits = self.forward(flat_input)?;
        // argmax(1) on (1, 4) → (1,). squeeze(0) → scalar.
        // Without squeeze, to_scalar() fails: it expects a 0-D tensor, not (1,).
        // With a batch larger than 1, squeeze(0) leaves (batch,) and to_scalar() fails too.
        let idx = logits.argmax(1)?.squeeze(0)?.to_scalar::<u32>()? as usize;
        Direction::from_index(idx).ok_or_else(|| {
            candle_core::Error::Msg(format!("invalid direction index: {idx}"))
        })
    }
}

#[derive(Debug, Clone, Copy)]
pub enum Direction {
    Up, Down, Left, Right,
}

impl Direction {
    pub fn from_index(idx: usize) -> Option<Self> {
        match idx {
            0 => Some(Self::Up),
            1 => Some(Self::Down),
            2 => Some(Self::Left),
            3 => Some(Self::Right),
            _ => None,
        }
    }

    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Up    => "up",
            Self::Down  => "down",
            Self::Left  => "left",
            Self::Right => "right",
        }
    }

    /// Direction as a 0-based index (Up=0, Down=1, Left=2, Right=3).
    /// Used by the training loop to record actions.
    pub fn index(&self) -> usize {
        match self {
            Self::Up    => 0,
            Self::Down  => 1,
            Self::Left  => 2,
            Self::Right => 3,
        }
    }
}
}
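The effect of temperature is easier to see with concrete numbers. Below is a sketch of the same computation on a plain slice, without Candle. `softmax_with_temperature` is a hypothetical helper written for this illustration, not part of the model code:

```rust
/// Softmax over `logits / temperature`.
/// Subtracting the maximum first keeps `exp` from overflowing and does not
/// change the result.
fn softmax_with_temperature(logits: &[f32], temperature: f32) -> Vec<f32> {
    assert!(temperature > 0.0, "temperature must be positive");
    let scaled: Vec<f32> = logits.iter().map(|&l| l / temperature).collect();
    let max = scaled.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scaled.iter().map(|&s| (s - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / total).collect()
}
```

For logits `[2.0, 1.0, 0.0, -1.0]`, the probability of the top direction is about 0.86 at temperature 0.5, 0.64 at 1.0, and 0.46 at 2.0. The ranking of the directions never changes. Only the confidence does.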

To load from saved weights later:

#![allow(unused)]
fn main() {
use candle_nn::{VarBuilder, VarMap};
use candle_core::{DType, Device, Result, Tensor};
use std::sync::Arc;

pub struct Model {
    device: Arc<Device>,
    var_map: VarMap,
    net: SnakeNet,
}

impl Model {
    /// Build a new model with fresh random weights
    pub fn new() -> Result<Self> {
        let device = Device::Cpu;
        let var_map = VarMap::new();
        let vs = VarBuilder::from_varmap(&var_map, DType::F32, &device);
        let net = SnakeNet::new(vs)?;
        Ok(Self { device: Arc::new(device), var_map, net })
    }

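    /// Load weights from a file (for Part 8)
    pub fn load(path: &str) -> Result<Self> {
        let device = Device::Cpu;
        let mut var_map = VarMap::new();
        let vs = VarBuilder::from_varmap(&var_map, DType::F32, &device);
        // Build the network first so its variables are registered in the VarMap.
        let net = SnakeNet::new(vs)?;
        // VarMap::load only overwrites variables that already exist in the map,
        // so calling it on an empty map would silently load nothing.
        var_map.load(path)?;
        Ok(Self { device: Arc::new(device), var_map, net })
    }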

    pub fn pick_direction(&self, flat_input: &Tensor) -> Result<Direction> {
        self.net.pick_direction(flat_input)
    }
}
}

Run it on the encoded board:

#![allow(unused)]
fn main() {
use snake_ml::{encode_board, Model, Direction};

fn decide_move(board: &Board, my_snake: &Snake) -> Result<Direction, candle_core::Error> {
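    // Fresh random weights on every call: fine for a smoke test, but a real
    // bot would build or load the model once and reuse it.
    let model = Model::new()?;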
    let tensor = encode_board(board, my_snake)?;
    let flat = tensor.flatten_all()?;
    // Add batch dimension: (1, 968)
    let flat = flat.unsqueeze(0)?;
    model.pick_direction(&flat)
}
}

And that’s the forward pass working end-to-end.

What the network learns (and what it doesn’t)

Here’s the thing: a freshly initialized network is worthless. The weights start random. The output is gibberish. A random network will pick moves about as well as a dart-throwing monkey.

The network needs two things before it becomes useful:

  1. Architecture — we’ve built that. The shape of the function.
  2. Weights — the numbers inside. Right now they’re random. They need to be trained.

Training is the process of adjusting the weights to make the network output good moves. There are two ways to train this network:

  • Imitation learning (Part 5): generate data from a teacher (the heuristic from Part 4), then train the network to copy the teacher. This gets us to “decent” fast.
  • Reinforcement learning (Parts 6–7): let the snake play, learn from wins and losses, discover strategies the teacher never thought of.

We’ll tackle imitation learning first — it’s simpler, more predictable, and produces a solid baseline to build on.

But before all that, Part 4 builds the teacher. The heuristic snake that will feed us training data. That’s where A* pathfinding comes in.

Previous: Part 2 — The Board as Numbers · Next: Part 4 — The Heuristic Baseline