Becoming a Backprop Ninja

Goal: Manually derive and implement the backward pass through every operation in the MLP language model — cross-entropy, tanh, batch normalization, linear layers, embedding lookup. Verify each against PyTorch autograd. Inspired by Karpathy’s makemore Part 4.

Prerequisites: Backpropagation, Chain Rule, Batch Normalization, 07 - Backpropagation Step by Step, 17 - MLP Language Model


Why Do This?

Tutorial 07 traced gradients through 5 weights. Real networks have thousands. This tutorial bridges that gap — you’ll backprop through cross-entropy, batch normalization, and matrix operations on real-sized tensors. After this, autograd is no longer magic.


Setup: The Forward Pass

Build the MLP from tutorial 17, capture every intermediate value:

import torch
import torch.nn.functional as F

# Minimal dataset of names to train on.
words = ["emma", "olivia", "ava", "sophia", "isabella", "mia", "charlotte",
         "amelia", "harper", "evelyn", "abigail", "emily", "ella", "grace"]

# Character vocabulary; '.' is the start/end token and gets index 0.
chars = sorted(set(''.join(words)))
stoi = {ch: idx + 1 for idx, ch in enumerate(chars)}
stoi['.'] = 0
itos = {idx: ch for ch, idx in stoi.items()}
vocab_size = len(stoi)
block_size = 3  # context length: how many previous characters predict the next one


def build_dataset(words):
    """Turn each word into (context, next-char) training pairs.

    Returns X of shape (num_examples, block_size) holding character indices
    and Y of shape (num_examples,) holding the index of the next character.
    """
    X, Y = [], []
    for word in words:
        context = [0] * block_size  # start from all-'.' padding
        for ch in word + '.':
            target = stoi[ch]
            X.append(list(context))
            Y.append(target)
            context = context[1:] + [target]  # slide the window forward
    return torch.tensor(X), torch.tensor(Y)


X, Y = build_dataset(words)

# Model parameters. Fixed seed + fixed call order makes the init reproducible.
n_embed, n_hidden = 10, 64
torch.manual_seed(42)

C  = torch.randn(vocab_size, n_embed)
# Tanh gain (5/3) over sqrt(fan_in) keeps pre-activations reasonably scaled.
W1 = torch.randn(block_size * n_embed, n_hidden) * (5/3) / (block_size * n_embed)**0.5
b1 = torch.randn(n_hidden) * 0.01
# BatchNorm parameters: identity transform at init.
bn_gain = torch.ones(n_hidden)
bn_bias = torch.zeros(n_hidden)
W2 = torch.randn(n_hidden, vocab_size) * 0.01  # small so initial logits are near zero
b2 = torch.zeros(vocab_size)

parameters = [C, W1, b1, bn_gain, bn_bias, W2, b2]
for p in parameters:
    p.requires_grad_(True)

Complete forward pass with all intermediates

# Forward pass, deliberately broken into primitive steps so that each one
# gets its own hand-written backward below.
n = len(X)  # batch size; we train on the full dataset as one batch
 
# 1. Embedding lookup: pick one row of C per context character
emb = C[X]                                          # (n, block_size, n_embed)
emb_cat = emb.view(n, -1)                           # (n, block_size * n_embed)
 
# 2. First linear layer
h_prebn = emb_cat @ W1 + b1                         # (n, n_hidden)
 
# 3. Batch normalization (keepdim=True keeps reduced dims as size 1 so
#    the (1, n_hidden) statistics broadcast against the (n, n_hidden) batch)
bn_mean = h_prebn.mean(dim=0, keepdim=True)          # (1, n_hidden)
bn_diff = h_prebn - bn_mean                          # (n, n_hidden)
bn_diff_sq = bn_diff ** 2                             # (n, n_hidden)
bn_var = bn_diff_sq.mean(dim=0, keepdim=True)        # (1, n_hidden)  biased variance (1/n * sum)
bn_var_inv = (bn_var + 1e-5) ** -0.5                  # (1, n_hidden)  1/sqrt(var + eps)
bn_raw = bn_diff * bn_var_inv                         # (n, n_hidden)  normalized
h_preact = bn_gain * bn_raw + bn_bias                 # (n, n_hidden)  scale & shift
 
# 4. Activation
h = torch.tanh(h_preact)                              # (n, n_hidden)
 
# 5. Output layer
logits = h @ W2 + b2                                  # (n, vocab_size)
 
# 6. Cross-entropy loss, written out manually (not F.cross_entropy) so we
#    can backprop through every intermediate
logit_maxes = logits.max(dim=1, keepdim=True).values
norm_logits = logits - logit_maxes                     # subtract row max for numerical stability
counts = norm_logits.exp()
counts_sum = counts.sum(dim=1, keepdim=True)
counts_sum_inv = counts_sum ** -1
probs = counts * counts_sum_inv                        # softmax
logprobs = probs.log()
loss = -logprobs[range(n), Y].mean()                   # NLL of the correct characters
 
print(f"Forward pass loss: {loss.item():.4f}")

The Backward Pass: Layer by Layer

We’ll compute gradients for every intermediate tensor and verify each against PyTorch:

# Let PyTorch compute the reference gradients. Clear .grad first so nothing
# stale from a previous backward accumulates into the comparison.
for p in parameters:
    p.grad = None
loss.backward()
 
def cmp(name, dt, t):
    """Report how closely a hand-derived gradient `dt` matches autograd's `t.grad`.

    Prints three diagnostics: bitwise equality, closeness within atol=1e-5,
    and the largest absolute difference.
    """
    reference = t.grad
    exact = torch.all(dt == reference).item()
    close = torch.allclose(dt, reference, atol=1e-5)
    worst = (dt - reference).abs().max().item()
    print(f"{name:15s} | exact: {exact} | approx: {close} | maxdiff: {worst:.6e}")

6. Cross-entropy backward

# d(loss)/d(logprobs): loss = -mean of the n plucked-out log-probs, so only
# the (i, Y[i]) entries receive gradient, each -1/n; all others are 0.
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(n), Y] = -1.0 / n
 
# d(logprobs)/d(probs) — derivative of log(x) is 1/x
dprobs = (1.0 / probs) * dlogprobs
 
# probs = counts * counts_sum_inv. counts_sum_inv is (n, 1) broadcast across
# each row, so its gradient SUMS over the row (sum over the fan-out).
dcounts_sum_inv = (counts * dprobs).sum(dim=1, keepdim=True)
 
# d(probs)/d(counts) — two paths: this direct multiply, and the path through
# counts_sum, which is accumulated below with +=
dcounts = counts_sum_inv * dprobs
 
# d(counts_sum_inv)/d(counts_sum): d(x^-1)/dx = -x^-2
dcounts_sum = -counts_sum ** -2 * dcounts_sum_inv
 
# d(counts_sum)/d(counts) — a sum fans the gradient back out to every element
dcounts += torch.ones_like(counts) * dcounts_sum
 
# d(counts)/d(norm_logits) — derivative of exp(x) is exp(x) = counts itself
dnorm_logits = counts * dcounts
 
# norm_logits = logits - logit_maxes: the direct path is the identity
dlogits = dnorm_logits.clone()
# Gradient into the broadcast logit_maxes: minus the row-sum of dnorm_logits
dlogit_maxes = -dnorm_logits.sum(dim=1, keepdim=True)
# Softmax is invariant to shifting a row by a constant, so dlogit_maxes is
# ~0 (up to float round-off) and the max path can be dropped from dlogits.
# (PyTorch *does* route gradient through .max() — to the argmax position —
# but that contribution vanishes for the same shift-invariance reason.)
# For softmax + cross-entropy the whole chain collapses to the famous
# shortcut: dlogits = (probs - one_hot(Y)) / n.
 
# Verify against the shortcut
dlogits_shortcut = probs.clone()
dlogits_shortcut[range(n), Y] -= 1
dlogits_shortcut /= n
print(f"Softmax gradient matches shortcut: {torch.allclose(dlogits, dlogits_shortcut, atol=1e-5)}")

5. Output layer backward

# Backprop through the output projection, logits = h @ W2 + b2.
# Standard matmul rule for Y = X @ W + b:
#   dW = X.T @ dY,  db = dY.sum(0),  dX = dY @ W.T
dW2 = torch.matmul(h.t(), dlogits)    # (n_hidden, vocab_size)
db2 = dlogits.sum(0)                  # (vocab_size,) — b2 broadcasts over rows
dh = torch.matmul(dlogits, W2.t())    # (n, n_hidden)
 
cmp('W2', dW2, W2)
cmp('b2', db2, b2)

4. tanh backward

# h = tanh(h_preact)  →  dh_preact = dh * (1 - h^2)
# The tanh derivative 1 - tanh(x)^2 is expressed via the cached forward
# output h, so no extra tanh evaluation is needed.
dh_preact = dh * (1.0 - h ** 2)

3. Batch normalization backward (the hard one)

# h_preact = bn_gain * bn_raw + bn_bias
# bn_gain/bn_bias are (n_hidden,) broadcast over the batch, so their
# gradients sum over dim 0 (the fan-out of the broadcast).
dbn_gain = (bn_raw * dh_preact).sum(dim=0)
dbn_raw = bn_gain * dh_preact
dbn_bias = dh_preact.sum(dim=0)
 
cmp('bn_gain', dbn_gain, bn_gain)
cmp('bn_bias', dbn_bias, bn_bias)
 
# bn_raw = bn_diff * bn_var_inv — first of TWO paths into bn_diff; the
# second (through the variance) is accumulated below with +=
dbn_diff = bn_var_inv * dbn_raw
# bn_var_inv is (1, n_hidden) broadcast over the batch → sum over dim 0
dbn_var_inv = (bn_diff * dbn_raw).sum(dim=0, keepdim=True)
 
# bn_var_inv = (bn_var + eps) ^ -0.5  →  d/dx x^-0.5 = -0.5 * x^-1.5
dbn_var = -0.5 * (bn_var + 1e-5) ** -1.5 * dbn_var_inv
 
# bn_var = 1/n * sum(bn_diff^2): the mean fans the gradient out, scaled by 1/n
dbn_diff_sq = (1.0 / n) * torch.ones_like(bn_diff_sq) * dbn_var
 
# bn_diff_sq = bn_diff ^ 2 — second path into bn_diff, hence the +=
dbn_diff += 2.0 * bn_diff * dbn_diff_sq
 
# bn_diff = h_prebn - bn_mean — direct path into h_prebn is the identity;
# bn_mean is broadcast, so its gradient sums over dim 0 (negated by the '-')
dh_prebn = dbn_diff.clone()
dbn_mean = -dbn_diff.sum(dim=0, keepdim=True)
 
# bn_mean = 1/n * sum(h_prebn) — second path into h_prebn, hence the +=
dh_prebn += (1.0 / n) * torch.ones_like(h_prebn) * dbn_mean

The BatchNorm backward shortcut

All of the above simplifies to one formula:

# The efficient BatchNorm backward (what frameworks actually use): the whole
# chain above collapses to one expression in dbn_raw. Since the forward used
# the biased 1/n variance, there is no n/(n-1) Bessel factor here.
dh_prebn_fast = (1.0 / n) * bn_var_inv * (
    n * dbn_raw
    - dbn_raw.sum(dim=0)
    - bn_raw * (dbn_raw * bn_raw).sum(dim=0)
)
print(f"BN backward shortcut matches: {torch.allclose(dh_prebn, dh_prebn_fast, atol=1e-5)}")

2. First linear layer backward

# Backprop through the first linear layer, h_prebn = emb_cat @ W1 + b1.
# Same matmul rule as the output layer: dW = X.T @ dY, db = dY.sum(0),
# dX = dY @ W.T.
dW1 = torch.matmul(emb_cat.t(), dh_prebn)
db1 = dh_prebn.sum(0)
demb_cat = torch.matmul(dh_prebn, W1.t())
 
cmp('W1', dW1, W1)
cmp('b1', db1, b1)

1. Embedding backward

# emb_cat = emb.view(n, -1)  →  a view is just a reshape, so the gradient
# simply reshapes back to the embedding's original shape
demb = demb_cat.view(emb.shape)

# emb = C[X]  →  scatter the gradients back to the embedding table.
# Each row of C may be looked up many times in the batch, so contributions
# must ACCUMULATE. This is the vectorized equivalent of the explicit loop
#     for i in range(n):
#         for j in range(block_size):
#             dC[X[i, j]] += demb[i, j]
# index_add_ sums the rows of demb into dC at the row indices given by X.
dC = torch.zeros_like(C)
dC.index_add_(0, X.view(-1), demb.view(n * block_size, -1))
 
cmp('C', dC, C)

All Gradients Verified

print("\n=== Final verification ===")
cmp('C ', dC, C)
cmp('W1', dW1, W1)
cmp('b1', db1, b1)
cmp('bn_gain', dbn_gain, bn_gain)
cmp('bn_bias', dbn_bias, bn_bias)
cmp('W2', dW2, W2)
cmp('b2', db2, b2)

Every gradient should show approx: True. If not, trace back through the chain rule — one of the intermediate derivatives has a bug.


Key Patterns

| Operation  | Forward                 | Backward                                          |
|------------|-------------------------|---------------------------------------------------|
| sum        | `c = a + b`             | fan out (copy gradient to both)                   |
| multiply   | `c = a * b`             | swap and multiply: `da = b * dc`, `db = a * dc`   |
| matmul     | `C = A @ B`             | `dA = dC @ B.T`, `dB = A.T @ dC`                  |
| reduce     | `c = a.sum(0)`          | broadcast (expand gradient back over the axis)    |
| power      | `c = a ** k`            | `da = k * a ** (k - 1) * dc`                      |
| tanh       | `h = tanh(a)`           | `da = (1 - h ** 2) * dh`                          |
| exp        | `c = a.exp()`           | `da = c * dc`                                     |
| softmax+CE | `probs` → `loss`        | complex (beautiful): `dlogits = (probs - one_hot(Y)) / n` |

Exercises

  1. Add dropout: Insert dropout after the tanh. Forward: zero random elements with probability p, scale the survivors by 1/(1-p). Backward: zero the same elements, scale the same way.

  2. Cross-attention backward: Implement the backward pass through the attention operation from 08 - Attention Mechanism from Scratch.

  3. Residual connection backward: Add a skip connection: h_out = h + f(h). What’s the backward? (Hint: it’s trivially simple — that’s why residuals work.)

  4. Speed comparison: Time your manual backward vs loss.backward(). How many orders of magnitude slower is the manual version?


Next: 20 - Build GPT from Scratch — put attention, embeddings, and layer norms together into a transformer.