Neural Network Training Recipes
Goal: A systematic, checklist-driven approach to training neural networks. Not theory — practical debugging steps that catch 90% of problems. Adapted from Karpathy’s “A Recipe for Training Neural Networks” blog post.
Prerequisites: Gradient Descent, Bias-Variance Tradeoff, Evaluation Metrics, Vanishing and Exploding Gradients
The Core Problem
Neural network training is not “apply algorithm, get result.” It’s a debugging process. The code runs, the loss goes down, but the model is silently broken in ways that only surface in production.
This recipe is a checklist. Follow it every time.
Step 1: Understand Your Data
Before writing any model code, stare at your data.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Load data
digits = load_digits()
X, y = digits.data.astype(np.float32), digits.target

# LOOK at it
print(f"Shape: {X.shape}, Labels: {np.unique(y)}")
print(f"Min: {X.min()}, Max: {X.max()}, Mean: {X.mean():.2f}")
print(f"Class distribution: {np.bincount(y)}")

# Visualize two samples per class
fig, axes = plt.subplots(2, 10, figsize=(15, 3))
for cls in range(10):
    idx = np.where(y == cls)[0][:2]
    for j, i in enumerate(idx):
        axes[j, cls].imshow(X[i].reshape(8, 8), cmap='gray')
        axes[j, cls].set_title(str(cls), fontsize=8)
        axes[j, cls].axis('off')
plt.suptitle("Check: do the labels match the images?")
plt.show()

# Check for duplicates
print(f"Duplicate rows: {len(X) - len(np.unique(X, axis=0))}")

# Check for NaN/Inf
print(f"NaNs: {np.isnan(X).sum()}, Infs: {np.isinf(X).sum()}")
```

Data checklist
- Can you visualize a random sample and verify the label?
- Are there class imbalance issues?
- Are features on similar scales?
- Any duplicates, NaNs, or outliers?
- Is train/val/test split stratified?
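The stratification item is easy to verify mechanically: with `stratify=y`, the per-class proportions in train and test should be nearly identical. A minimal sketch (it repeats the digits load from above so it runs standalone):

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

digits = load_digits()
X, y = digits.data, digits.target

X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42)

# With stratify=y, each class occupies (almost) the same fraction of each split
train_frac = np.bincount(y_tr) / len(y_tr)
test_frac = np.bincount(y_te) / len(y_te)
print(f"Max class-proportion gap: {np.abs(train_frac - test_frac).max():.4f}")
```

A large gap here (or a class missing from `np.bincount` entirely) means the split is skewed and your validation metrics will be misleading.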
Step 2: Sanity Checks Before Training
Check 1: Loss at initialization
```python
# Prepare data
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(
    X_scaled, y, test_size=0.2, stratify=y, random_state=42)

X_train_t = torch.tensor(X_train, dtype=torch.float32)
y_train_t = torch.tensor(y_train, dtype=torch.long)
X_test_t = torch.tensor(X_test, dtype=torch.float32)
y_test_t = torch.tensor(y_test, dtype=torch.long)

class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.net(x)

model = MLP(64, 128, 10)
with torch.no_grad():
    logits = model(X_train_t)
    init_loss = F.cross_entropy(logits, y_train_t).item()

expected_loss = -np.log(1.0 / 10)  # uniform over 10 classes
print(f"Initial loss: {init_loss:.4f}")
print(f"Expected loss: {expected_loss:.4f} (-log(1/10))")
print(f"Match: {abs(init_loss - expected_loss) < 0.5}")

# If initial loss is WAY higher → weights are too large (confident wrong predictions)
# If initial loss is WAY lower → data leakage or bug
```

Check 2: Overfit a single batch
```python
# The model MUST be able to memorize a tiny batch
small_X = X_train_t[:32]
small_y = y_train_t[:32]

model_small = MLP(64, 128, 10)
optimizer = torch.optim.Adam(model_small.parameters(), lr=0.01)

for step in range(200):
    logits = model_small(small_X)
    loss = F.cross_entropy(logits, small_y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

acc = (model_small(small_X).argmax(1) == small_y).float().mean()
print(f"Single batch accuracy after 200 steps: {acc:.4f}")
print("Should be: 1.0 (if not, model or data is broken)")
```

If the model can’t overfit 32 examples, something is fundamentally wrong — data labels, model architecture, or gradient flow.
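The gradient-flow failure mode can be checked directly: after one backward pass, every trainable parameter should have a populated `.grad`. A minimal sketch on a toy stand-in model (in the tutorial this would be `model_small`):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))
X, y = torch.randn(32, 64), torch.randint(0, 10, (32,))

loss = F.cross_entropy(model(X), y)
loss.backward()

for name, p in model.named_parameters():
    # A missing .grad means the parameter is detached from the graph
    assert p.grad is not None, f"{name}: no gradient at all"
    frac_zero = (p.grad == 0).float().mean().item()
    print(f"{name:12s} fraction of zero gradients: {frac_zero:.2f}")
```

A parameter whose gradient is entirely zeros (e.g. behind a dead ReLU, or accidentally frozen) will never learn, even though the loss still goes down elsewhere.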
Step 3: Start Simple
```python
# Don't start with your dream architecture. Start with the simplest thing that works.

# Baseline 1: Logistic regression
from sklearn.linear_model import LogisticRegression

lr_model = LogisticRegression(max_iter=1000)
lr_model.fit(X_train, y_train)
baseline = lr_model.score(X_test, y_test)
print(f"Logistic regression baseline: {baseline:.4f}")

# Baseline 2: Small MLP, no tricks
model = MLP(64, 64, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(50):
    logits = model(X_train_t)
    loss = F.cross_entropy(logits, y_train_t)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

test_acc = (model(X_test_t).argmax(1) == y_test_t).float().mean()
print(f"Simple MLP: {test_acc:.4f}")
print(f"Beat baseline: {test_acc > baseline}")
```

Add complexity only when simple models saturate.
Step 4: Learning Rate Finder
The most important hyperparameter. Find it empirically:
```python
def lr_finder(model_class, X, y, min_lr=1e-6, max_lr=10, steps=200):
    model = model_class(64, 128, 10)
    optimizer = torch.optim.SGD(model.parameters(), lr=min_lr)
    lr_factor = (max_lr / min_lr) ** (1 / steps)
    lrs, losses_out = [], []
    for step in range(steps):
        logits = model(X)
        loss = F.cross_entropy(logits, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lrs.append(optimizer.param_groups[0]['lr'])
        losses_out.append(loss.item())
        # Exponentially increase lr
        for g in optimizer.param_groups:
            g['lr'] *= lr_factor
        if loss.item() > 4 * losses_out[0]:  # diverging
            break
    return lrs, losses_out

lrs, losses = lr_finder(MLP, X_train_t, y_train_t)

plt.figure(figsize=(10, 5))
plt.plot(lrs, losses)
plt.xscale('log')
plt.xlabel("Learning rate")
plt.ylabel("Loss")
plt.title("LR Finder — pick the rate just before loss starts climbing")
plt.axvline(x=lrs[np.argmin(losses)], color='red', linestyle='--', alpha=0.5)
plt.show()
print(f"Suggested LR: {lrs[np.argmin(losses)]:.2e}")
```

Step 5: Monitor Everything
```python
def train_monitored(model, X_train, y_train, X_test, y_test, lr=1e-3, epochs=100):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = {'train_loss': [], 'test_loss': [], 'train_acc': [], 'test_acc': [],
               'grad_norms': [], 'weight_norms': []}
    for epoch in range(epochs):
        # Forward
        logits = model(X_train)
        loss = F.cross_entropy(logits, y_train)

        # Backward
        optimizer.zero_grad()
        loss.backward()

        # Monitor gradients BEFORE update
        total_grad_norm = 0
        for p in model.parameters():
            if p.grad is not None:
                total_grad_norm += p.grad.data.norm(2).item() ** 2
        history['grad_norms'].append(total_grad_norm ** 0.5)

        optimizer.step()

        # Monitor weights
        total_weight_norm = sum(p.data.norm(2).item() ** 2
                                for p in model.parameters()) ** 0.5
        history['weight_norms'].append(total_weight_norm)

        # Metrics (switch to eval mode so dropout layers are disabled at test time)
        model.eval()
        with torch.no_grad():
            history['train_loss'].append(loss.item())
            history['train_acc'].append((logits.argmax(1) == y_train).float().mean().item())
            test_logits = model(X_test)
            history['test_loss'].append(F.cross_entropy(test_logits, y_test).item())
            history['test_acc'].append((test_logits.argmax(1) == y_test).float().mean().item())
        model.train()
    return history

model = MLP(64, 128, 10)
history = train_monitored(model, X_train_t, y_train_t, X_test_t, y_test_t)

fig, axes = plt.subplots(2, 2, figsize=(12, 10))
axes[0, 0].plot(history['train_loss'], label='train')
axes[0, 0].plot(history['test_loss'], label='test')
axes[0, 0].set_title("Loss"); axes[0, 0].legend()
axes[0, 1].plot(history['train_acc'], label='train')
axes[0, 1].plot(history['test_acc'], label='test')
axes[0, 1].set_title("Accuracy"); axes[0, 1].legend()
axes[1, 0].plot(history['grad_norms'])
axes[1, 0].set_title("Gradient norm")
axes[1, 0].set_yscale('log')
axes[1, 1].plot(history['weight_norms'])
axes[1, 1].set_title("Weight norm")
plt.tight_layout()
plt.show()
```

What to look for
| Signal | Diagnosis | Fix |
|---|---|---|
| Train loss flat from start | LR too low or dead model | Increase LR |
| Train loss oscillates wildly | LR too high | Decrease LR |
| Train loss ↓, test loss ↑ | Overfitting | Regularize, more data |
| Train ≈ test, both high | Underfitting | More capacity, more features |
| Gradient norm → 0 | Vanishing gradients | Residuals, better init, LR warmup |
| Gradient norm → ∞ | Exploding gradients | Gradient clipping, lower LR |
| Weight norm grows unbounded | No regularization | Add weight decay |
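The gradient-clipping fix from the table is one line in PyTorch: `torch.nn.utils.clip_grad_norm_` rescales all gradients in place so their global norm is at most `max_norm`, and returns the norm measured before clipping (handy for logging). A sketch on a toy stand-in model:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

X = torch.randn(32, 64)
y = torch.randint(0, 10, (32,))

loss = F.cross_entropy(model(X), y)
optimizer.zero_grad()
loss.backward()

# Rescale gradients in place so the global norm is at most max_norm;
# the return value is the norm BEFORE clipping
pre_clip_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
print(f"grad norm before clipping: {float(pre_clip_norm):.3f}")
```

Clip between `loss.backward()` and `optimizer.step()`; clipping after the step does nothing.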
Step 6: Regularize
Only after the model overfits (proves it has capacity):
```python
class MLPRegularized(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.net(x)

model_reg = MLPRegularized(64, 256, 10, dropout=0.3)
optimizer = torch.optim.AdamW(model_reg.parameters(), lr=1e-3, weight_decay=1e-4)
history_reg = train_monitored(model_reg, X_train_t, y_train_t, X_test_t, y_test_t, epochs=200)

plt.plot(history_reg['train_loss'], label='train')
plt.plot(history_reg['test_loss'], label='test')
plt.legend(); plt.title("With dropout + weight decay")
plt.show()
print(f"Final test acc: {history_reg['test_acc'][-1]:.4f}")
```

Regularization order
- Get more data (or augment)
- Weight decay (AdamW, not Adam)
- Dropout
- Early stopping
- Reduce model size (last resort — you want a big model that’s regularized)
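Early stopping, fourth on this list, is simple to wire up by hand: track validation loss, keep the best checkpoint, and stop after `patience` epochs without improvement. A sketch with toy stand-in data (in the tutorial these would be the digits tensors):

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Toy stand-in tensors for illustration
X_tr, y_tr = torch.randn(256, 64), torch.randint(0, 10, (256,))
X_val, y_val = torch.randn(64, 64), torch.randint(0, 10, (64,))

model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

best_loss, best_state, patience, bad_epochs = float('inf'), None, 10, 0
for epoch in range(200):
    loss = F.cross_entropy(model(X_tr), y_tr)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    with torch.no_grad():
        val_loss = F.cross_entropy(model(X_val), y_val).item()
    if val_loss < best_loss:
        best_loss, bad_epochs = val_loss, 0
        best_state = copy.deepcopy(model.state_dict())  # snapshot best weights
    else:
        bad_epochs += 1
        if bad_epochs >= patience:  # no improvement for `patience` epochs
            print(f"stopping at epoch {epoch}, best val loss {best_loss:.4f}")
            break

model.load_state_dict(best_state)  # roll back to the best checkpoint
```

The `copy.deepcopy` matters: `state_dict()` returns references to the live tensors, so without a copy the "best" checkpoint would keep changing as training continues.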
The Full Recipe (Summary)
1. LOOK AT YOUR DATA
- Visualize, check labels, check distributions, check for leakage
2. SANITY CHECK
- Initial loss = -log(1/C)?
- Can overfit a single batch to 100%?
- Does the simplest baseline work?
3. START SIMPLE
- Small model, no tricks
- Beat the baseline, then add complexity
4. FIND THE LEARNING RATE
- LR finder sweep
- If in doubt, 3e-4 for Adam is a decent start
5. TRAIN AND MONITOR
- Plot train/val loss every epoch
- Plot gradient norms (catch vanishing/exploding)
- Plot weight norms (catch unbounded growth)
6. REGULARIZE (only after overfitting)
- Data augmentation → weight decay → dropout → early stopping
7. SQUEEZE
- LR schedule (cosine decay with warmup)
- Ensembling (average 3-5 models)
- Larger model with more regularization
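Of the "squeeze" items, ensembling is the most mechanical: train a few copies of the model from different random seeds and average their softmax probabilities at prediction time. A sketch with toy stand-in data (in practice you would ensemble the tuned model from the earlier steps):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
# Toy stand-in tensors for illustration
X_tr, y_tr = torch.randn(512, 64), torch.randint(0, 10, (512,))
X_te = torch.randn(64, 64)

def train_one(seed):
    torch.manual_seed(seed)  # different init per ensemble member
    model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    for _ in range(50):
        loss = F.cross_entropy(model(X_tr), y_tr)
        opt.zero_grad()
        loss.backward()
        opt.step()
    return model

models = [train_one(s) for s in range(3)]

# Ensemble prediction: average the per-member softmax probabilities
with torch.no_grad():
    probs = torch.stack([F.softmax(m(X_te), dim=1) for m in models]).mean(0)
pred = probs.argmax(1)
```

Averaging probabilities (rather than hard votes) keeps the ensemble's confidence information, which also makes it easier to calibrate or threshold downstream.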
Exercises
- Break it on purpose: Introduce a subtle bug (swap two labels, corrupt 10% of the data, normalize incorrectly). Can you detect it with these checks? Which check catches it?
- Apply to your project: Take any tutorial model (e.g., 20 - Build GPT from Scratch) and run through the full recipe. What does the LR finder suggest? Does the initial loss match expectations?
- Learning rate schedule: Implement cosine decay with warmup. Compare final accuracy with a fixed LR vs. a scheduled one.
- Gradient monitoring tool: Build a reusable function that hooks into any `nn.Module` and records per-layer gradient stats (mean, std, max, fraction of zeros).
This is the meta-tutorial — it applies to every other tutorial in this series.