Activations and Initialization Deep Dive

Goal: Diagnose unhealthy activations, dead neurons, and saturated gradients inside a neural network. Fix them with proper initialization and batch normalization. Inspired by Karpathy’s makemore Part 3.

Prerequisites: Vanishing and Exploding Gradients, Batch Normalization, Neurons and Activation Functions, 17 - MLP Language Model


The Problem You Can’t See

A network trains, loss decreases… slowly. You add layers, loss gets worse. Why? The activations inside the network are sick — but you never look at them.

This tutorial is about opening the hood and diagnosing what’s happening inside.


Setup: A Deep MLP

We’ll reuse the dataset from tutorial 17, then build deeper versions of its character-level MLP and instrument them:

import os
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
 
# Dataset (reuse from tutorial 17); fall back to a small built-in list if names.txt is missing
if os.path.exists('names.txt'):
    with open('names.txt', 'r') as f:
        words = f.read().splitlines()
else:
    words = [
        "emma", "olivia", "ava", "sophia", "isabella", "mia", "charlotte", "amelia",
        "harper", "evelyn", "abigail", "emily", "elizabeth", "sofia", "avery",
        "ella", "scarlett", "grace", "chloe", "victoria", "riley", "aria", "lily",
        "aurora", "zoey", "nora", "luna", "hannah", "penelope", "layla",
    ]
 
chars = sorted(set(''.join(words)))
stoi = {c: i+1 for i, c in enumerate(chars)}; stoi['.'] = 0
itos = {i: c for c, i in stoi.items()}
vocab_size = len(stoi)
 
block_size = 3
def build_dataset(words):
    X, Y = [], []
    for w in words:
        ctx = [0] * block_size
        for ch in w + '.':
            X.append(ctx[:]); Y.append(stoi[ch]); ctx = ctx[1:] + [stoi[ch]]
    return torch.tensor(X), torch.tensor(Y)
 
X_train, Y_train = build_dataset(words)
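To see what build_dataset produces, this sketch repeats the same logic on the single word "emma" (the one-word list is an illustrative choice) and decodes each (context, target) pair:

```python
import torch

block_size = 3
words = ["emma"]  # one word, just to inspect the (context, target) pairs

chars = sorted(set(''.join(words)))
stoi = {c: i + 1 for i, c in enumerate(chars)}
stoi['.'] = 0  # '.' is both the start padding and the end-of-word token
itos = {i: c for c, i in stoi.items()}

def build_dataset(words):
    X, Y = [], []
    for w in words:
        ctx = [0] * block_size  # start with '...' padding
        for ch in w + '.':
            X.append(ctx[:])
            Y.append(stoi[ch])
            ctx = ctx[1:] + [stoi[ch]]  # slide the window by one character
    return torch.tensor(X), torch.tensor(Y)

X, Y = build_dataset(words)
for ctx, target in zip(X.tolist(), Y.tolist()):
    print(''.join(itos[i] for i in ctx), '->', itos[target])
# ... -> e
# ..e -> m
# .em -> m
# emm -> a
# mma -> .
```

Each word of length n yields n + 1 training examples, one per character plus the final '.'.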

Bad Initialization: Watch It Break

torch.manual_seed(42)
n_embed, n_hidden = 10, 100
 
C = torch.randn(vocab_size, n_embed)
# Deliberately bad: large random weights
W1 = torch.randn(block_size * n_embed, n_hidden) * 1.0  # too large
b1 = torch.randn(n_hidden) * 1.0
W2 = torch.randn(n_hidden, n_hidden) * 1.0
b2 = torch.randn(n_hidden) * 1.0
W3 = torch.randn(n_hidden, vocab_size) * 1.0
b3 = torch.randn(vocab_size) * 0
 
# Forward pass and capture activations
emb = C[X_train[:1000]]
x = emb.view(-1, block_size * n_embed)
h1_pre = x @ W1 + b1;   h1 = torch.tanh(h1_pre)
h2_pre = h1 @ W2 + b2;  h2 = torch.tanh(h2_pre)
logits = h2 @ W3 + b3
 
# Visualize activation distributions
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, data, title in [
    (axes[0], h1_pre.detach(), "Layer 1 pre-activation"),
    (axes[1], h1.detach(), "Layer 1 post-tanh"),
    (axes[2], h2.detach(), "Layer 2 post-tanh"),
]:
    ax.hist(data.numpy().ravel(), bins=50, density=True)
    ax.set_title(title)
    ax.set_xlabel("value")
 
plt.suptitle("BAD INIT: activations saturated at ±1", fontsize=14, color='red')
plt.tight_layout()
plt.show()
 
# What fraction of neurons are saturated?
print(f"Layer 1: {(h1.abs() > 0.99).float().mean():.1%} saturated")
print(f"Layer 2: {(h2.abs() > 0.99).float().mean():.1%} saturated")

Why this kills training

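When a tanh unit’s output is near ±1, its local gradient, 1 − tanh²(x), is near 0, so almost none of the gradient arriving from above flows back through it to the weights below. A layer full of saturated units blocks the backward signal: this is the vanishing gradient problem in action.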
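The claim is easy to check with autograd. This sketch feeds a few hand-picked pre-activation values (illustrative, not taken from the network) through tanh and compares the gradient PyTorch computes with the analytic derivative 1 − tanh²(x):

```python
import torch

# Pre-activations from "healthy" (near 0) to "saturated" (|x| >= 3)
x = torch.tensor([0.0, 0.5, 1.0, 2.0, 3.0, 5.0], requires_grad=True)
y = torch.tanh(x)
y.sum().backward()  # x.grad now holds d tanh(x) / dx for each element

# Analytic local gradient of tanh
analytic = 1 - y.detach() ** 2

for xi, yi, gi in zip(x.tolist(), y.tolist(), x.grad.tolist()):
    print(f"x={xi:4.1f}  tanh={yi:+.4f}  grad={gi:.6f}")
```

At x = 3 the local gradient is already about 0.01, so a saturated unit passes back only about 1% of the gradient it receives, and the bad-init network above has many such units.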


Fix 1: Kaiming Initialization

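Scale each weight matrix so that the variance of the pre-activations stays roughly constant (~1) from layer to layer. Draw the weights from a unit Gaussian, divide by the square root of the layer’s fan-in (its number of inputs), and multiply by a gain that compensates for the nonlinearity’s squashing:

$W \sim \mathcal{N}(0, 1) \cdot \dfrac{\text{gain}}{\sqrt{\text{fan\_in}}}$

For tanh, gain = 5/3.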
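One way to see what the gain does is to push random data through a deep stack of tanh layers initialized with this rule and compare the last layer’s statistics. A sketch, assuming a 10-layer, 200-unit stack with no biases (the helper name final_layer_stats and all sizes are illustrative); torch.nn.init.calculate_gain is the real PyTorch helper behind the 5/3 figure:

```python
import torch

# PyTorch's recommended gain for tanh
print(torch.nn.init.calculate_gain('tanh'))  # 1.666..., i.e. 5/3

def final_layer_stats(gain, n_layers=10, width=200, batch=1000, seed=0):
    """Hypothetical helper: push random inputs through a deep tanh stack whose weights
    use std = gain / sqrt(fan_in). Returns (std, saturated fraction) of the last layer."""
    g = torch.Generator().manual_seed(seed)
    h = torch.randn(batch, width, generator=g)
    for _ in range(n_layers):
        W = torch.randn(width, width, generator=g) * gain / width**0.5  # fan_in = width
        h = torch.tanh(h @ W)
    return h.std().item(), (h.abs() > 0.99).float().mean().item()

for gain in [1.0, 5/3]:
    std, sat = final_layer_stats(gain)
    print(f"gain={gain:.2f}  last-layer std={std:.3f}  saturated={sat:.1%}")
```

With gain 1.0 the activations keep shrinking layer after layer, because tanh squashes every input a little; with 5/3 they settle at a stable spread instead of collapsing toward 0.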

torch.manual_seed(42)
gain = 5/3  # recommended gain for tanh
 
C = torch.randn(vocab_size, n_embed)
W1 = torch.randn(block_size * n_embed, n_hidden) * gain / (block_size * n_embed)**0.5
b1 = torch.randn(n_hidden) * 0.01
W2 = torch.randn(n_hidden, n_hidden) * gain / n_hidden**0.5
b2 = torch.randn(n_hidden) * 0.01
W3 = torch.randn(n_hidden, vocab_size) * 0.01  # small output weights -> near-uniform softmax at init
b3 = torch.randn(vocab_size) * 0
 
# Forward pass
emb = C[X_train[:1000]]
x = emb.view(-1, block_size * n_embed)
h1_pre = x @ W1 + b1;   h1 = torch.tanh(h1_pre)
h2_pre = h1 @ W2 + b2;  h2 = torch.tanh(h2_pre)
 
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, data, title in [
    (axes[0], h1_pre.detach(), "Layer 1 pre-activation"),
    (axes[1], h1.detach(), "Layer 1 post-tanh"),
    (axes[2], h2.detach(), "Layer 2 post-tanh"),
]:
    ax.hist(data.numpy().ravel(), bins=50, density=True)
    ax.set_title(title)
 
plt.suptitle("KAIMING INIT: activations nicely distributed", fontsize=14, color='green')
plt.tight_layout()
plt.show()
 
print(f"Layer 1: {(h1.abs() > 0.99).float().mean():.1%} saturated")
print(f"Layer 2: {(h2.abs() > 0.99).float().mean():.1%} saturated")
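The `* 0.01` on W3 is what makes the initial loss sensible. With near-zero logits the softmax is almost uniform, so the expected cross-entropy is −ln(1/vocab_size) = ln(vocab_size). A quick sketch comparing that baseline with large logits (the 27-way vocabulary and the random targets are illustrative assumptions):

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 27  # assumption: 26 letters + '.', as with names.txt
targets = torch.randint(0, vocab_size, (1000,))

# Uniform prediction: all logits equal, so the loss is exactly ln(vocab_size)
uniform_loss = F.cross_entropy(torch.zeros(1000, vocab_size), targets).item()

# Small random logits (like h @ (W3 * 0.01)): loss stays close to ln(vocab_size)
small_loss = F.cross_entropy(torch.randn(1000, vocab_size) * 0.01, targets).item()

# Large random logits (like the bad init): confidently wrong, much higher loss
large_loss = F.cross_entropy(torch.randn(1000, vocab_size) * 10, targets).item()

print(f"ln(vocab_size) = {math.log(vocab_size):.4f}")
print(f"uniform: {uniform_loss:.4f}  small: {small_loss:.4f}  large: {large_loss:.4f}")
```

A first-step loss far above ln(vocab_size) means the network starts out confidently wrong and spends its early steps just shrinking its logits.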

Fix 2: Batch Normalization

Force each layer’s pre-activations to be zero-mean and unit-variance (per feature, computed over the batch), then let the network learn the best scale (gamma) and shift (beta) on top:

class BatchNorm1d:
    def __init__(self, dim):
        self.gamma = torch.ones(dim)   # learnable scale
        self.beta = torch.zeros(dim)   # learnable shift
        # Running stats for inference
        self.running_mean = torch.zeros(dim)
        self.running_var = torch.ones(dim)
        self.training = True
 
    def __call__(self, x):
        if self.training:
            mean = x.mean(dim=0)
            var = x.var(dim=0)
            # Update running stats
            with torch.no_grad():
                self.running_mean = 0.999 * self.running_mean + 0.001 * mean
                self.running_var = 0.999 * self.running_var + 0.001 * var
        else:
            mean = self.running_mean
            var = self.running_var
 
        x_hat = (x - mean) / torch.sqrt(var + 1e-5)
        return self.gamma * x_hat + self.beta
 
    def parameters(self):
        return [self.gamma, self.beta]
 
# Apply to our network
bn1 = BatchNorm1d(n_hidden)
bn2 = BatchNorm1d(n_hidden)
 
emb = C[X_train[:1000]]
x = emb.view(-1, block_size * n_embed)
h1_pre = x @ W1 + b1;   h1_bn = bn1(h1_pre);  h1 = torch.tanh(h1_bn)
h2_pre = h1 @ W2 + b2;  h2_bn = bn2(h2_pre);  h2 = torch.tanh(h2_bn)
 
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, data, title in [
    (axes[0], h1_bn.detach(), "Layer 1 after BatchNorm"),
    (axes[1], h1.detach(), "Layer 1 post-tanh"),
    (axes[2], h2.detach(), "Layer 2 post-tanh"),
]:
    ax.hist(data.numpy().ravel(), bins=50, density=True)
    ax.set_title(title)
 
plt.suptitle("WITH BATCHNORM: unit-Gaussian pre-activations", fontsize=14, color='green')
plt.tight_layout()
plt.show()
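To check the normalization step against PyTorch’s own layer, here is a sketch that compares a hand-written training-mode batchnorm (with gamma = 1 and beta = 0, their initial values) to torch.nn.BatchNorm1d. It also shows that a bias added before batchnorm, like b1 above, is cancelled by the mean subtraction. The shapes and random data are illustrative.

```python
import torch

torch.manual_seed(0)
x = torch.randn(32, 100) * 5 + 3   # badly scaled, shifted pre-activations (batch of 32)
b = torch.randn(100)               # a per-feature bias, playing the role of b1

def batchnorm_train(x, eps=1e-5):
    # Training-mode normalization with gamma = 1, beta = 0.
    # Uses the biased variance, which is what torch.nn.BatchNorm1d normalizes with.
    mean = x.mean(dim=0)
    var = x.var(dim=0, unbiased=False)
    return (x - mean) / torch.sqrt(var + eps)

# momentum=0.001 gives the same running-stat update as the class above (0.999 * old + 0.001 * new)
ref = torch.nn.BatchNorm1d(100, eps=1e-5, momentum=0.001)
ref.train()
y_ref = ref(x)
y = batchnorm_train(x)

print(torch.allclose(y, y_ref, atol=1e-4))                   # True: matches PyTorch's layer
print(y.mean(dim=0).abs().max().item())                      # per-feature mean is ~0
print(torch.allclose(y, batchnorm_train(x + b), atol=1e-4))  # True: the bias is cancelled
```

The tutorial’s class uses x.var(dim=0), which is unbiased, so at batch size 32 its outputs are scaled by about √(31/32) ≈ 0.98 relative to PyTorch’s; the difference is small. Because the bias cancels, it is common to drop the bias of a linear layer that feeds straight into batchnorm.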

Diagnostic: Gradient Distributions

In a healthy network, gradient magnitudes are similar across layers. In an unhealthy one, they shrink or grow from layer to layer:

# Deeper model with Kaiming init (tanh, no BatchNorm in this diagnostic)
torch.manual_seed(42)
C = torch.randn(vocab_size, n_embed)
layers = [
    (torch.randn(block_size * n_embed, n_hidden) * gain / (block_size * n_embed)**0.5,
     torch.zeros(n_hidden)),
    (torch.randn(n_hidden, n_hidden) * gain / n_hidden**0.5,
     torch.zeros(n_hidden)),
    (torch.randn(n_hidden, n_hidden) * gain / n_hidden**0.5,
     torch.zeros(n_hidden)),
    (torch.randn(n_hidden, vocab_size) * 0.01,
     torch.zeros(vocab_size)),
]
 
C.requires_grad = True
parameters = [C]
for W, b in layers:
    W.requires_grad = True; b.requires_grad = True
    parameters.extend([W, b])
 
# Forward
emb = C[X_train[:500]]
x = emb.view(-1, block_size * n_embed)
activations = [x]
for i, (W, b) in enumerate(layers[:-1]):
    x = torch.tanh(x @ W + b)
    activations.append(x)
W_last, b_last = layers[-1]
logits = x @ W_last + b_last
loss = F.cross_entropy(logits, Y_train[:500])
 
# Backward
for p in parameters:
    p.grad = None
loss.backward()
 
# Plot activation and gradient statistics per layer
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
 
# Activation std per layer (index 0 is the input embedding)
stds = [a.detach().std().item() for a in activations]
axes[0].bar(range(len(means)), stds)
axes[0].set_xlabel("Layer"); axes[0].set_ylabel("Std of activations")
axes[0].set_title("Activation std per layer (should be ~constant)")
 
# Gradient magnitudes
grad_means = []
for W, b in layers:
    grad_means.append(W.grad.abs().mean().item())
axes[1].bar(range(len(grad_means)), grad_means)
axes[1].set_xlabel("Layer"); axes[1].set_ylabel("Mean |gradient|")
axes[1].set_title("Gradient magnitude per layer")
axes[1].set_yscale("log")
 
plt.tight_layout()
plt.show()

Diagnostic: Dead Neuron Check

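With ReLU, a neuron whose pre-activation is negative for every input always outputs 0. Its gradient is then 0 as well, so its weights stop updating and it usually never recovers. Such neurons are “dead”: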

# Replace tanh with ReLU and check
torch.manual_seed(42)
W1_relu = torch.randn(block_size * n_embed, n_hidden) * (2 / (block_size * n_embed))**0.5
b1_relu = torch.zeros(n_hidden)
 
emb = C[X_train[:1000]]
x = emb.view(-1, block_size * n_embed)
h = F.relu(x @ W1_relu + b1_relu)
 
dead_fraction = (h == 0).all(dim=0).float().mean()
print(f"Dead neurons: {dead_fraction:.1%}")
 
# Visualize: which neurons are active
plt.figure(figsize=(14, 3))
plt.imshow(h[:100].detach().numpy().T, aspect='auto', cmap='viridis')
plt.xlabel("Sample"); plt.ylabel("Neuron")
plt.title("ReLU activations (rows = neurons); all-dark rows = dead neurons")
plt.colorbar()
plt.show()
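To see why a dead neuron stays dead, here is a sketch that forces one of four ReLU units dead with a large negative bias and checks the gradient its weights receive (the bias of −100, the shapes, and the sum loss are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(256, 30)                                     # 256 samples, 30 input features
W = (torch.randn(30, 4) * (2 / 30) ** 0.5).requires_grad_()  # He-style init, as in W1_relu
b = torch.tensor([0.0, 0.0, 0.0, -100.0], requires_grad=True)  # neuron 3 forced dead

h = F.relu(x @ W + b)
loss = h.sum()  # any downstream loss works; a sum keeps it simple
loss.backward()

dead = (h == 0).all(dim=0)
print("dead neurons:", dead.tolist())
print("grad norm per neuron:", W.grad.norm(dim=0).tolist())
```

Because ReLU’s derivative is 0 wherever its input is negative, the dead unit’s column of W.grad (and its entry in b.grad) is exactly zero, so no update to its own weights or bias can move it back into the active region. Only changes in its inputs, coming from earlier layers or the embeddings, could revive it.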

Training Comparison: Bad Init vs Good Init vs BatchNorm

def train_model(init_type, steps=5000):
    torch.manual_seed(42)
    C = torch.randn(vocab_size, n_embed)
 
    if init_type == "bad":
        W1 = torch.randn(block_size * n_embed, n_hidden) * 1.0
        W2 = torch.randn(n_hidden, vocab_size) * 1.0
    elif init_type == "kaiming":
        W1 = torch.randn(block_size * n_embed, n_hidden) * gain / (block_size * n_embed)**0.5
        W2 = torch.randn(n_hidden, vocab_size) * 0.01
    elif init_type == "batchnorm":
        W1 = torch.randn(block_size * n_embed, n_hidden) * 1.0  # bad init, but BN fixes it
        W2 = torch.randn(n_hidden, vocab_size) * 0.01
 
    b1 = torch.zeros(n_hidden)
    b2 = torch.zeros(vocab_size)
 
    bn = BatchNorm1d(n_hidden) if init_type == "batchnorm" else None
    params = [C, W1, b1, W2, b2]
    if bn is not None:
        params += bn.parameters()
    for p in params: p.requires_grad = True
 
    losses = []
    for step in range(steps):
        ix = torch.randint(0, len(X_train), (32,))  # random minibatch of 32 examples
        emb = C[X_train[ix]].view(-1, block_size * n_embed)
        h_pre = emb @ W1 + b1
        if bn is not None:
            h_pre = bn(h_pre)
        h = torch.tanh(h_pre)
        logits = h @ W2 + b2
        loss = F.cross_entropy(logits, Y_train[ix])
 
        for p in params: p.grad = None
        loss.backward()
        lr = 0.1 if step < 3000 else 0.01
        for p in params: p.data -= lr * p.grad
        losses.append(loss.item())
 
    return losses
 
fig, ax = plt.subplots(figsize=(10, 5))
for init_type in ["bad", "kaiming", "batchnorm"]:
    losses = train_model(init_type)
    # Smooth for plotting
    smooth = [sum(losses[max(0,i-50):i+1])/min(i+1,51) for i in range(len(losses))]
    ax.plot(smooth, label=init_type)
ax.set_xlabel("Step"); ax.set_ylabel("Loss")
ax.legend(); ax.set_title("Impact of initialization and normalization")
plt.show()
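train_model only runs BatchNorm in training mode, but at inference you often feed one example at a time, where batch statistics are meaningless. The running_mean and running_var buffers exist for exactly this case. A self-contained sketch: it repeats the tutorial’s BatchNorm1d logic with the momentum made a parameter (an addition for this demo), and the synthetic batches stand in for real pre-activations.

```python
import torch

class BatchNorm1d:
    # Same algorithm as the tutorial's class; `momentum` is a parameter added for this demo
    def __init__(self, dim, momentum=0.001, eps=1e-5):
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)
        self.running_mean = torch.zeros(dim)
        self.running_var = torch.ones(dim)
        self.momentum, self.eps, self.training = momentum, eps, True

    def __call__(self, x):
        if self.training:
            mean, var = x.mean(dim=0), x.var(dim=0)  # unbiased var, as in the tutorial
            with torch.no_grad():
                m = self.momentum
                self.running_mean = (1 - m) * self.running_mean + m * mean
                self.running_var = (1 - m) * self.running_var + m * var
        else:
            mean, var = self.running_mean, self.running_var
        return self.gamma * (x - mean) / torch.sqrt(var + self.eps) + self.beta

torch.manual_seed(0)
bn = BatchNorm1d(4, momentum=0.1)   # larger momentum so 200 batches are enough to converge
for _ in range(200):                # simulated training batches with mean 3, std 2
    bn(torch.randn(32, 4) * 2 + 3)
learned_mean = bn.running_mean.clone()

single = torch.randn(1, 4) * 2 + 3  # one example, as at inference time

bn.training = False
y_eval = bn(single)                 # normalized with running stats: finite values

bn.training = True
y_train = bn(single)                # variance of a batch of 1 is NaN, so the output is NaN
                                    # (this call also writes NaN into running_var)
print("running mean:", learned_mean)
print("eval mode:   ", y_eval)
print("train mode:  ", y_train)
```

Even with a biased variance estimate, a batch of one would collapse every output to beta, because x equals its own batch mean. That is why `training` is switched off for evaluation and sampling.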

The Diagnostic Checklist

Run these checks whenever a network trains slowly or poorly:

| Check | Healthy | Sick |
|---|---|---|
| Pre-activation std | ~1.0 per layer | Grows or shrinks across layers |
| Post-tanh distribution | Spread between -1 and 1 | Piled up at ±1 (saturated) |
| ReLU dead fraction | <5% | >20% (dead neurons) |
| Gradient magnitude | Similar across layers | Shrinks 10x+ per layer |
| Initial loss | ≈ ln(n) for n classes (≈ 3.30 for the 27-character names.txt vocabulary) | Much higher (confident wrong predictions) |
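The checklist can be turned into a small helper you call after one forward and backward pass. This is a hypothetical sketch: health_report is not a library function, and the numeric cutoffs for "~1.0", "piled up at ±1" and "much higher" are assumptions chosen for illustration; only the 20% dead-ReLU and 10x gradient thresholds come from the table.

```python
import math
import torch

def health_report(pre_acts, post_acts, weight_grads, initial_loss, n_classes, relu=False):
    """Hypothetical helper applying the checklist thresholds.
    pre_acts / post_acts: one [batch, features] tensor per hidden layer.
    weight_grads: one .grad tensor per weight matrix, ordered input -> output."""
    issues = []
    for i, (pre, post) in enumerate(zip(pre_acts, post_acts)):
        std = pre.std().item()
        if not 0.5 < std < 2.0:  # assumption: "~1.0" means within a factor of 2
            issues.append(f"layer {i}: pre-activation std {std:.2f}")
        if relu:
            dead = (post == 0).all(dim=0).float().mean().item()
            if dead > 0.20:  # from the table
                issues.append(f"layer {i}: {dead:.0%} dead ReLU units")
        else:
            sat = (post.abs() > 0.99).float().mean().item()
            if sat > 0.20:  # assumption: cutoff for "piled up at ±1"
                issues.append(f"layer {i}: {sat:.0%} saturated tanh units")
    mags = [g.abs().mean().item() for g in weight_grads]
    for i in range(1, len(mags)):
        ratio = mags[i] / max(mags[i - 1], 1e-12)
        if ratio > 10 or ratio < 0.1:  # from the table: 10x+ change between layers
            issues.append(f"gradient magnitude changes {ratio:.1f}x from weight {i-1} to {i}")
    baseline = math.log(n_classes)
    if initial_loss > 1.5 * baseline:  # assumption: "much higher" = 50% above ln(n)
        issues.append(f"initial loss {initial_loss:.2f} vs ln({n_classes}) = {baseline:.2f}")
    return issues

torch.manual_seed(0)
n_classes = 27  # assumption: names.txt vocabulary
grads = [torch.ones(30, 100), torch.ones(100, 27)]  # stand-in gradients of equal magnitude

bad_pre = torch.randn(500, 100) * 5
sick = health_report([bad_pre], [torch.tanh(bad_pre)], grads, 20.0, n_classes)

good_pre = torch.randn(500, 100)
healthy = health_report([good_pre], [torch.tanh(good_pre)], grads, math.log(n_classes), n_classes)

print(sick)     # flags std, saturation and initial loss
print(healthy)  # []
```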

Exercises

  1. 10-layer network: Stack 10 hidden layers with bad init and visualize each layer’s activations; they should pile up at ±1 within the first few layers. Then add batchnorm to each layer and show the fix.

  2. LayerNorm vs BatchNorm: Implement LayerNorm (normalize across features, not batch). Compare training curves. LayerNorm doesn’t depend on batch size.

  3. Weight histograms over training: Every 1000 steps, save a histogram of each weight matrix. Plot them as a grid. Healthy weights change smoothly; sick weights barely move.

  4. The gain matters: Try gain values of 0.5, 1.0, 5/3, 2.0, 5.0 for tanh init. For each, plot activation distributions. Find the sweet spot.


Next: 19 - Becoming a Backprop Ninja — compute every gradient by hand through this network.