Activations and Initialization Deep Dive
Goal: Diagnose unhealthy activations, dead neurons, and saturated gradients inside a neural network. Fix them with proper initialization and batch normalization. Inspired by Karpathy’s makemore Part 3.
Prerequisites: Vanishing and Exploding Gradients, Batch Normalization, Neurons and Activation Functions, 17 - MLP Language Model
The Problem You Can’t See
A network trains, loss decreases… slowly. You add layers, loss gets worse. Why? The activations inside the network are sick — but you never look at them.
This tutorial is about opening the hood and diagnosing what’s happening inside.
Setup: A Deep MLP
We’ll build the character-level MLP from tutorial 17, but deeper, and instrument it:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
# Dataset (reuse from tutorial 17)
words = open('names.txt', 'r').read().splitlines() if __import__('os').path.exists('names.txt') else [
"emma", "olivia", "ava", "sophia", "isabella", "mia", "charlotte", "amelia",
"harper", "evelyn", "abigail", "emily", "elizabeth", "sofia", "avery",
"ella", "scarlett", "grace", "chloe", "victoria", "riley", "aria", "lily",
"aurora", "zoey", "nora", "luna", "hannah", "penelope", "layla",
]
chars = sorted(set(''.join(words)))
stoi = {c: i+1 for i, c in enumerate(chars)}; stoi['.'] = 0
itos = {i: c for c, i in stoi.items()}
vocab_size = len(stoi)
block_size = 3
def build_dataset(words):
    X, Y = [], []
    for w in words:
        ctx = [0] * block_size
        for ch in w + '.':
            X.append(ctx[:]); Y.append(stoi[ch]); ctx = ctx[1:] + [stoi[ch]]
    return torch.tensor(X), torch.tensor(Y)
X_train, Y_train = build_dataset(words)
Bad Initialization: Watch It Break
torch.manual_seed(42)
n_embed, n_hidden = 10, 100
C = torch.randn(vocab_size, n_embed)
# Deliberately bad: large random weights
W1 = torch.randn(block_size * n_embed, n_hidden) * 1.0 # too large
b1 = torch.randn(n_hidden) * 1.0
W2 = torch.randn(n_hidden, n_hidden) * 1.0
b2 = torch.randn(n_hidden) * 1.0
W3 = torch.randn(n_hidden, vocab_size) * 1.0
b3 = torch.randn(vocab_size) * 0
# Forward pass and capture activations
emb = C[X_train[:1000]]
x = emb.view(-1, block_size * n_embed)
h1_pre = x @ W1 + b1; h1 = torch.tanh(h1_pre)
h2_pre = h1 @ W2 + b2; h2 = torch.tanh(h2_pre)
logits = h2 @ W3 + b3
# Visualize activation distributions
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, data, title in [
    (axes[0], h1_pre.detach(), "Layer 1 pre-activation"),
    (axes[1], h1.detach(), "Layer 1 post-tanh"),
    (axes[2], h2.detach(), "Layer 2 post-tanh"),
]:
    ax.hist(data.numpy().ravel(), bins=50, density=True)
    ax.set_title(title)
    ax.set_xlabel("value")
plt.suptitle("BAD INIT: activations saturated at ±1", fontsize=14, color='red')
plt.tight_layout()
plt.show()
# What fraction of neurons are saturated?
print(f"Layer 1: {(h1.abs() > 0.99).float().mean():.1%} saturated")
print(f"Layer 2: {(h2.abs() > 0.99).float().mean():.1%} saturated")Why this kills training
When tanh outputs are near ±1, its gradient is near 0. The gradient signal can’t flow backward. This is the vanishing gradient problem in action.
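The arithmetic behind this is simple: tanh's local gradient is 1 - tanh(x)^2, which collapses toward zero as the output approaches ±1. A quick standalone check (toy input values, not taken from the model above) makes it concrete:

# tanh's local gradient is 1 - tanh(z)^2; it is tiny once the unit saturates
z = torch.tensor([0.0, 1.0, 2.5, 5.0], requires_grad=True)
out = torch.tanh(z)
out.sum().backward()
for zi, oi, gi in zip(z.tolist(), out.tolist(), z.grad.tolist()):
    print(f"z={zi:4.1f}  tanh(z)={oi:+.3f}  grad={gi:.4f}")
# at z=5.0 the local gradient is about 2e-4, so almost no signal flows backward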
Fix 1: Kaiming Initialization
Scale weights so that the variance of activations stays ~1 across layers:
For tanh: $W \sim \mathcal{N}\!\left(0,\ \tfrac{\text{gain}^2}{\text{fan\_in}}\right)$, i.e. multiply standard-normal weights by $\text{gain}/\sqrt{\text{fan\_in}}$, where gain = 5/3.
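Before applying it to the model, here is a quick standalone check (toy tensors, separate from the network above) of why dividing by the square root of the fan-in keeps activations in range:

# Toy check of the scaling rule: unit-variance inputs through an unscaled vs. a scaled layer
torch.manual_seed(0)
fan_in = block_size * n_embed            # 30 in this network
xs = torch.randn(1000, fan_in)           # stand-in activations with std ~1
W_bad = torch.randn(fan_in, n_hidden)                        # std-1 weights
W_good = torch.randn(fan_in, n_hidden) * (5/3) / fan_in**0.5  # tanh gain / sqrt(fan_in)
print((xs @ W_bad).std())    # ~sqrt(fan_in) ≈ 5.5: tanh will saturate
print((xs @ W_good).std())   # ~5/3: tanh stays in its active range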
torch.manual_seed(42)
gain = 5/3 # recommended gain for tanh
C = torch.randn(vocab_size, n_embed)
W1 = torch.randn(block_size * n_embed, n_hidden) * gain / (block_size * n_embed)**0.5
b1 = torch.randn(n_hidden) * 0.01
W2 = torch.randn(n_hidden, n_hidden) * gain / n_hidden**0.5
b2 = torch.randn(n_hidden) * 0.01
W3 = torch.randn(n_hidden, vocab_size) * 0.01
b3 = torch.randn(vocab_size) * 0
# Forward pass
emb = C[X_train[:1000]]
x = emb.view(-1, block_size * n_embed)
h1_pre = x @ W1 + b1; h1 = torch.tanh(h1_pre)
h2_pre = h1 @ W2 + b2; h2 = torch.tanh(h2_pre)
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, data, title in [
    (axes[0], h1_pre.detach(), "Layer 1 pre-activation"),
    (axes[1], h1.detach(), "Layer 1 post-tanh"),
    (axes[2], h2.detach(), "Layer 2 post-tanh"),
]:
    ax.hist(data.numpy().ravel(), bins=50, density=True)
    ax.set_title(title)
plt.suptitle("KAIMING INIT: activations nicely distributed", fontsize=14, color='green')
plt.tight_layout()
plt.show()
print(f"Layer 1: {(h1.abs() > 0.99).float().mean():.1%} saturated")
print(f"Layer 2: {(h2.abs() > 0.99).float().mean():.1%} saturated")Fix 2: Batch Normalization
Force each layer’s pre-activations to be zero-mean, unit-variance — then let the network learn the optimal scale and shift:
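In equations, with per-feature batch mean $\mu_B$ and variance $\sigma_B^2$ (the mean and var computed in the code below), a small constant $\epsilon$ for numerical stability, and learnable $\gamma$ and $\beta$:

$$\hat{x} = \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}, \qquad y = \gamma\,\hat{x} + \beta$$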
class BatchNorm1d:
    def __init__(self, dim):
        self.gamma = torch.ones(dim)   # learnable scale
        self.beta = torch.zeros(dim)   # learnable shift
        # Running stats for inference
        self.running_mean = torch.zeros(dim)
        self.running_var = torch.ones(dim)
        self.training = True
    def __call__(self, x):
        if self.training:
            mean = x.mean(dim=0)
            var = x.var(dim=0)
            # Update running stats
            with torch.no_grad():
                self.running_mean = 0.999 * self.running_mean + 0.001 * mean
                self.running_var = 0.999 * self.running_var + 0.001 * var
        else:
            mean = self.running_mean
            var = self.running_var
        x_hat = (x - mean) / torch.sqrt(var + 1e-5)
        return self.gamma * x_hat + self.beta
    def parameters(self):
        return [self.gamma, self.beta]
# Apply to our network
bn1 = BatchNorm1d(n_hidden)
bn2 = BatchNorm1d(n_hidden)
emb = C[X_train[:1000]]
x = emb.view(-1, block_size * n_embed)
h1_pre = x @ W1 + b1; h1_bn = bn1(h1_pre); h1 = torch.tanh(h1_bn)
h2_pre = h1 @ W2 + b2; h2_bn = bn2(h2_pre); h2 = torch.tanh(h2_bn)
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, data, title in [
    (axes[0], h1_bn.detach(), "Layer 1 after BatchNorm"),
    (axes[1], h1.detach(), "Layer 1 post-tanh"),
    (axes[2], h2.detach(), "Layer 2 post-tanh"),
]:
    ax.hist(data.numpy().ravel(), bins=50, density=True)
    ax.set_title(title)
plt.suptitle("WITH BATCHNORM: perfectly conditioned", fontsize=14, color='green')
plt.tight_layout()
plt.show()
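One practical detail before moving on: at inference time you would set training = False on these layers so they normalize with the stored running statistics instead of the current batch (a single example has no meaningful batch mean). A minimal illustration with the hand-rolled class above (illustrative only):

# Switch the BatchNorm layers to inference mode and run a single example
bn1.training = False
bn2.training = False
h1_eval = torch.tanh(bn1(x[:1] @ W1 + b1))    # one example, normalized with running stats
h2_eval = torch.tanh(bn2(h1_eval @ W2 + b2))
bn1.training = True
bn2.training = True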
Diagnostic: Gradient Distributions
Healthy gradients have similar magnitude across layers; in an unhealthy network they shrink or grow as you move through the stack:
# Full model with both fixes
torch.manual_seed(42)
C = torch.randn(vocab_size, n_embed)
layers = [
(torch.randn(block_size * n_embed, n_hidden) * gain / (block_size * n_embed)**0.5,
torch.zeros(n_hidden)),
(torch.randn(n_hidden, n_hidden) * gain / n_hidden**0.5,
torch.zeros(n_hidden)),
(torch.randn(n_hidden, n_hidden) * gain / n_hidden**0.5,
torch.zeros(n_hidden)),
(torch.randn(n_hidden, vocab_size) * 0.01,
torch.zeros(vocab_size)),
]
parameters = [C]
for W, b in layers:
    W.requires_grad = True; b.requires_grad = True
    parameters.extend([W, b])
# Forward
emb = C[X_train[:500]]
x = emb.view(-1, block_size * n_embed)
activations = [x]
for W, b in layers[:-1]:
    x = torch.tanh(x @ W + b)
    activations.append(x)
W_last, b_last = layers[-1]
logits = x @ W_last + b_last
loss = F.cross_entropy(logits, Y_train[:500])
# Backward
for p in parameters:
    p.grad = None
loss.backward()
# Plot gradient statistics per layer
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Activation means and stds
means = [a.detach().mean().item() for a in activations]
stds = [a.detach().std().item() for a in activations]
axes[0].bar(range(len(means)), stds)
axes[0].set_xlabel("Layer"); axes[0].set_ylabel("Std of activations")
axes[0].set_title("Activation std per layer (should be ~constant)")
# Gradient magnitudes
grad_means = []
for W, b in layers:
    grad_means.append(W.grad.abs().mean().item())
axes[1].bar(range(len(grad_means)), grad_means)
axes[1].set_xlabel("Layer"); axes[1].set_ylabel("Mean |gradient|")
axes[1].set_title("Gradient magnitude per layer")
axes[1].set_yscale("log")
plt.tight_layout()
plt.show()
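A related check worth bolting onto the snippet above (a common rule-of-thumb diagnostic, not something plotted here): compare each gradient's scale to the scale of the parameter it updates. Roughly similar ratios across layers mean one learning rate can serve the whole network.

# Grad-to-data ratio per weight matrix; wildly different values across layers are a warning sign
for i, (W, b) in enumerate(layers):
    ratio = (W.grad.std() / W.data.std()).item()
    print(f"layer {i}: grad/data ratio = {ratio:.2e}")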
Diagnostic: Dead Neuron Check
With ReLU, neurons that always output 0 are "dead" — they never recover, because ReLU's gradient is zero for negative inputs, so nothing ever pushes their weights back toward the active region:
# Replace tanh with ReLU and check
torch.manual_seed(42)
W1_relu = torch.randn(block_size * n_embed, n_hidden) * (2 / (block_size * n_embed))**0.5
b1_relu = torch.zeros(n_hidden)
emb = C[X_train[:1000]]
x = emb.view(-1, block_size * n_embed)
h = F.relu(x @ W1_relu + b1_relu)
dead_fraction = (h == 0).all(dim=0).float().mean()
print(f"Dead neurons: {dead_fraction:.1%}")
# Visualize: which neurons are active
plt.figure(figsize=(14, 3))
plt.imshow(h[:100].detach().numpy().T, aspect='auto', cmap='viridis')
plt.xlabel("Sample"); plt.ylabel("Neuron")
plt.title("ReLU activations — black columns = dead neurons")
plt.colorbar()
plt.show()
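If the dead fraction is high, one common mitigation (not used elsewhere in this tutorial) is a leaky ReLU, which keeps a small slope for negative inputs so every neuron continues to receive some gradient:

# Same layer through a leaky ReLU: negative pre-activations keep a small slope,
# so a neuron can never get permanently stuck at exactly zero output
h_leaky = F.leaky_relu(x @ W1_relu + b1_relu, negative_slope=0.01)
dead_leaky = (h_leaky == 0).all(dim=0).float().mean()
print(f"Dead neurons with leaky ReLU: {dead_leaky:.1%}")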
Training Comparison: Bad Init vs Good Init vs BatchNorm
def train_model(init_type, steps=5000):
    torch.manual_seed(42)
    C = torch.randn(vocab_size, n_embed)
    if init_type == "bad":
        W1 = torch.randn(block_size * n_embed, n_hidden) * 1.0
        W2 = torch.randn(n_hidden, vocab_size) * 1.0
    elif init_type == "kaiming":
        W1 = torch.randn(block_size * n_embed, n_hidden) * gain / (block_size * n_embed)**0.5
        W2 = torch.randn(n_hidden, vocab_size) * 0.01
    elif init_type == "batchnorm":
        W1 = torch.randn(block_size * n_embed, n_hidden) * 1.0  # bad init, but BN fixes it
        W2 = torch.randn(n_hidden, vocab_size) * 0.01
    b1 = torch.zeros(n_hidden)
    b2 = torch.zeros(vocab_size)
    bn = BatchNorm1d(n_hidden) if init_type == "batchnorm" else None
    params = [C, W1, b1, W2, b2]
    if bn: params += bn.parameters()
    for p in params: p.requires_grad = True
    losses = []
    for step in range(steps):
        ix = torch.randint(0, len(X_train), (32,))
        emb = C[X_train[ix]].view(-1, block_size * n_embed)
        h_pre = emb @ W1 + b1
        if bn: h_pre = bn(h_pre)
        h = torch.tanh(h_pre)
        logits = h @ W2 + b2
        loss = F.cross_entropy(logits, Y_train[ix])
        for p in params: p.grad = None
        loss.backward()
        lr = 0.1 if step < 3000 else 0.01
        for p in params: p.data -= lr * p.grad
        losses.append(loss.item())
    return losses
fig, ax = plt.subplots(figsize=(10, 5))
for init_type in ["bad", "kaiming", "batchnorm"]:
    losses = train_model(init_type)
    # Smooth for plotting
    smooth = [sum(losses[max(0, i-50):i+1]) / min(i+1, 51) for i in range(len(losses))]
    ax.plot(smooth, label=init_type)
ax.set_xlabel("Step"); ax.set_ylabel("Loss")
ax.legend(); ax.set_title("Impact of initialization and normalization")
plt.show()
The Diagnostic Checklist
Run these checks whenever a network trains slowly or poorly:
| Check | Healthy | Sick |
|---|---|---|
| Pre-activation std | ~1.0 per layer | Grows or shrinks across layers |
| Post-tanh distribution | Spread between -1 and 1 | Piled up at ±1 (saturated) |
| ReLU dead fraction | <5% | >20% (dead neurons) |
| Gradient magnitude | Similar across layers | Shrinks 10x+ per layer |
| Initial loss | ≈ ln(N) for N classes (about 3.3 for 27; see the check below) | Much higher (confident wrong predictions) |
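The last row is easy to verify: a model that spreads probability uniformly over N classes has cross-entropy -ln(1/N) = ln(N). A quick standalone check, with hypothetical over-confident logits standing in for the sick case:

# Expected loss for an untrained classifier that predicts uniformly
import math
print(f"Expected initial loss: {math.log(vocab_size):.3f}")   # ln(27) ≈ 3.30 with the full names.txt vocab
# Large random logits are confidently wrong, so the measured loss comes out much higher
n = min(1000, len(Y_train))
wild_logits = torch.randn(n, vocab_size) * 10
print(f"Loss with large random logits: {F.cross_entropy(wild_logits, Y_train[:n]).item():.3f}")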
Exercises
- 10-layer network: Stack 10 hidden layers with bad init. Visualize activations — they'll be all ±1 by layer 3. Add batchnorm to each and show the fix.
- LayerNorm vs BatchNorm: Implement LayerNorm (normalize across features, not across the batch). Compare training curves. LayerNorm doesn't depend on batch size. (A starting sketch follows this list.)
- Weight histograms over training: Every 1000 steps, save a histogram of each weight matrix. Plot them as a grid. Healthy weights change smoothly; sick weights barely move.
- The gain matters: Try gain values of 0.5, 1.0, 5/3, 2.0, 5.0 for tanh init. For each, plot activation distributions. Find the sweet spot.
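For the LayerNorm exercise, a minimal starting sketch, mirroring the BatchNorm1d interface above (details and training integration are left to you):

class LayerNorm1d:
    def __init__(self, dim):
        self.gamma = torch.ones(dim)   # learnable scale
        self.beta = torch.zeros(dim)   # learnable shift
    def __call__(self, x):
        # Normalize each sample over its own features, so no batch statistics are needed
        mean = x.mean(dim=1, keepdim=True)
        var = x.var(dim=1, keepdim=True)
        x_hat = (x - mean) / torch.sqrt(var + 1e-5)
        return self.gamma * x_hat + self.beta
    def parameters(self):
        return [self.gamma, self.beta]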
Next: 19 - Becoming a Backprop Ninja — compute every gradient by hand through this network.