Neural Network Training Recipes

Goal: A systematic, checklist-driven approach to training neural networks. Not theory — practical debugging steps that catch 90% of problems. Adapted from Karpathy’s “A Recipe for Training Neural Networks” blog post.

Prerequisites: Gradient Descent, Bias-Variance Tradeoff, Evaluation Metrics, Vanishing and Exploding Gradients


The Core Problem

Neural network training is not “apply algorithm, get result.” It’s a debugging process. The code runs, the loss goes down, but the model is silently broken in ways that only surface in production.

This recipe is a checklist. Follow it every time.


Step 1: Understand Your Data

Before writing any model code, stare at your data.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
 
# Load data
digits = load_digits()
X, y = digits.data.astype(np.float32), digits.target
 
# LOOK at it
print(f"Shape: {X.shape}, Labels: {np.unique(y)}")
print(f"Min: {X.min()}, Max: {X.max()}, Mean: {X.mean():.2f}")
print(f"Class distribution: {np.bincount(y)}")
 
# Visualize samples
fig, axes = plt.subplots(2, 10, figsize=(15, 3))
for cls in range(10):
    idx = np.where(y == cls)[0][:2]
    for j, i in enumerate(idx):
        axes[j, cls].imshow(X[i].reshape(8, 8), cmap='gray')
        axes[j, cls].set_title(str(cls), fontsize=8)
        axes[j, cls].axis('off')
plt.suptitle("Check: do the labels match the images?")
plt.show()
 
# Check for duplicates
print(f"Duplicate rows: {len(X) - len(np.unique(X, axis=0))}")
 
# Check for NaN/Inf
print(f"NaNs: {np.isnan(X).sum()}, Infs: {np.isinf(X).sum()}")

Data checklist

  • Can you visualize a random sample and verify the label?
  • Are there class imbalance issues?
  • Are features on similar scales?
  • Any duplicates, NaNs, or outliers?
  • Is train/val/test split stratified?
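
The last two checklist items can be verified mechanically: split with `stratify` and compare per-class proportions across the splits. A minimal sketch that reloads the digits data so it runs standalone:

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

digits = load_digits()
X, y = digits.data.astype(np.float32), digits.target

# stratify=y keeps class proportions (nearly) identical in train and test
X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

train_frac = np.bincount(y_tr) / len(y_tr)
test_frac = np.bincount(y_te) / len(y_te)
print(f"Max class-fraction gap: {np.abs(train_frac - test_frac).max():.4f}")
# A large gap means the split is not stratified (or some classes are tiny)
```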

Step 2: Sanity Checks Before Training

Check 1: Loss at initialization

# Prepare data
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, stratify=y, random_state=42)
 
X_train_t = torch.tensor(X_train, dtype=torch.float32)
y_train_t = torch.tensor(y_train, dtype=torch.long)
X_test_t = torch.tensor(X_test, dtype=torch.float32)
y_test_t = torch.tensor(y_test, dtype=torch.long)
 
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )
 
    def forward(self, x):
        return self.net(x)
 
model = MLP(64, 128, 10)
with torch.no_grad():
    logits = model(X_train_t)
    init_loss = F.cross_entropy(logits, y_train_t).item()
 
expected_loss = -np.log(1.0 / 10)  # uniform over 10 classes
print(f"Initial loss:  {init_loss:.4f}")
print(f"Expected loss: {expected_loss:.4f} (-log(1/10))")
print(f"Match: {abs(init_loss - expected_loss) < 0.5}")
# If initial loss is WAY higher → weights are too large (confident wrong predictions)
# If initial loss is WAY lower → data leakage or bug
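
If the initial loss comes out far above -log(1/C), one common remedy is to shrink the final layer's initial weights so the network starts out near-uniform. A sketch (the 0.01 scale factor is an arbitrary illustrative choice, and the `nn.Sequential` stand-in mirrors the `MLP` above):

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))

# Shrink the last layer so initial logits are near zero -> near-uniform softmax
with torch.no_grad():
    net[2].weight.mul_(0.01)
    net[2].bias.zero_()

x = torch.randn(256, 64)
y = torch.randint(0, 10, (256,))
init_loss = F.cross_entropy(net(x), y).item()
print(f"Initial loss: {init_loss:.4f}  target: {-np.log(1/10):.4f}")
```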

Check 2: Overfit a single batch

# The model MUST be able to memorize a tiny batch
small_X = X_train_t[:32]
small_y = y_train_t[:32]
 
model_small = MLP(64, 128, 10)
optimizer = torch.optim.Adam(model_small.parameters(), lr=0.01)
 
for step in range(200):
    logits = model_small(small_X)
    loss = F.cross_entropy(logits, small_y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
 
with torch.no_grad():
    acc = (model_small(small_X).argmax(1) == small_y).float().mean().item()
print(f"Single batch accuracy after 200 steps: {acc:.4f}")
print("Should be: 1.0 (if not, model or data is broken)")

If the model can’t overfit 32 examples, something is fundamentally wrong — data labels, model architecture, or gradient flow.
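
A quick way to localize the gradient-flow failure mode: run one backward pass and confirm every parameter actually received a nonzero gradient. A minimal sketch on random data, with an `nn.Sequential` stand-in for the `MLP` above so it runs on its own:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))
x = torch.randn(32, 64)
y = torch.randint(0, 10, (32,))

loss = F.cross_entropy(model(x), y)
loss.backward()

# A None gradient means the parameter is detached from the loss;
# an all-zero gradient points at dead units or a masking bug
for name, p in model.named_parameters():
    status = "MISSING" if p.grad is None else f"norm={p.grad.norm():.4f}"
    print(f"{name:20s} {status}")
```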


Step 3: Start Simple

# Don't start with your dream architecture. Start with the simplest thing that works.
 
# Baseline 1: Logistic regression
from sklearn.linear_model import LogisticRegression
lr_model = LogisticRegression(max_iter=1000)
lr_model.fit(X_train, y_train)
baseline = lr_model.score(X_test, y_test)
print(f"Logistic regression baseline: {baseline:.4f}")
 
# Baseline 2: Small MLP, no tricks
model = MLP(64, 64, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
 
for epoch in range(50):
    logits = model(X_train_t)
    loss = F.cross_entropy(logits, y_train_t)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
 
with torch.no_grad():
    test_acc = (model(X_test_t).argmax(1) == y_test_t).float().mean().item()
print(f"Simple MLP: {test_acc:.4f}")
print(f"Beat baseline: {test_acc > baseline}")

Add complexity only when simple models saturate.


Step 4: Learning Rate Finder

The most important hyperparameter. Find it empirically:

def lr_finder(model_class, X, y, min_lr=1e-6, max_lr=10, steps=200):
    model = model_class(64, 128, 10)
    optimizer = torch.optim.SGD(model.parameters(), lr=min_lr)
    lr_factor = (max_lr / min_lr) ** (1 / steps)
 
    lrs, losses_out = [], []
    for step in range(steps):
        logits = model(X)
        loss = F.cross_entropy(logits, y)
 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
        lrs.append(optimizer.param_groups[0]['lr'])
        losses_out.append(loss.item())
 
        # Exponentially increase lr
        for g in optimizer.param_groups:
            g['lr'] *= lr_factor
 
        if loss.item() > 4 * losses_out[0]:  # diverging
            break
 
    return lrs, losses_out
 
lrs, losses = lr_finder(MLP, X_train_t, y_train_t)
 
plt.figure(figsize=(10, 5))
plt.plot(lrs, losses)
plt.xscale('log')
plt.xlabel("Learning rate")
plt.ylabel("Loss")
plt.title("LR Finder — pick the rate just before loss starts climbing")
plt.axvline(x=lrs[np.argmin(losses)], color='red', linestyle='--', alpha=0.5)
plt.show()
print(f"LR at minimum loss: {lrs[np.argmin(losses)]:.2e}")
print(f"Rule of thumb: train at ~1/10 of that: {lrs[np.argmin(losses)] / 10:.2e}")

Step 5: Monitor Everything

def train_monitored(model, X_train, y_train, X_test, y_test, lr=1e-3, epochs=100):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = {'train_loss': [], 'test_loss': [], 'train_acc': [], 'test_acc': [],
               'grad_norms': [], 'weight_norms': []}
 
    for epoch in range(epochs):
        # Forward
        logits = model(X_train)
        loss = F.cross_entropy(logits, y_train)
 
        # Backward
        optimizer.zero_grad()
        loss.backward()
 
        # Monitor gradients BEFORE update
        total_grad_norm = 0
        for p in model.parameters():
            if p.grad is not None:
                total_grad_norm += p.grad.data.norm(2).item() ** 2
        history['grad_norms'].append(total_grad_norm ** 0.5)
 
        optimizer.step()
 
        # Monitor weights
        total_weight_norm = sum(p.data.norm(2).item() ** 2 for p in model.parameters()) ** 0.5
        history['weight_norms'].append(total_weight_norm)
 
        # Metrics (eval mode so dropout layers are disabled during measurement)
        model.eval()
        with torch.no_grad():
            history['train_loss'].append(loss.item())
            history['train_acc'].append((logits.argmax(1) == y_train).float().mean().item())
            test_logits = model(X_test)
            history['test_loss'].append(F.cross_entropy(test_logits, y_test).item())
            history['test_acc'].append((test_logits.argmax(1) == y_test).float().mean().item())
        model.train()
 
    return history
 
model = MLP(64, 128, 10)
history = train_monitored(model, X_train_t, y_train_t, X_test_t, y_test_t)
 
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
 
axes[0, 0].plot(history['train_loss'], label='train')
axes[0, 0].plot(history['test_loss'], label='test')
axes[0, 0].set_title("Loss"); axes[0, 0].legend()
 
axes[0, 1].plot(history['train_acc'], label='train')
axes[0, 1].plot(history['test_acc'], label='test')
axes[0, 1].set_title("Accuracy"); axes[0, 1].legend()
 
axes[1, 0].plot(history['grad_norms'])
axes[1, 0].set_title("Gradient norm")
axes[1, 0].set_yscale('log')
 
axes[1, 1].plot(history['weight_norms'])
axes[1, 1].set_title("Weight norm")
 
plt.tight_layout()
plt.show()

What to look for

| Signal | Diagnosis | Fix |
|---|---|---|
| Train loss flat from start | LR too low or dead model | Increase LR |
| Train loss oscillates wildly | LR too high | Decrease LR |
| Train loss ↓, test loss ↑ | Overfitting | Regularize, more data |
| Train ≈ test, both high | Underfitting | More capacity, more features |
| Gradient norm → 0 | Vanishing gradients | Residuals, better init, LR warmup |
| Gradient norm → ∞ | Exploding gradients | Gradient clipping, lower LR |
| Weight norm grows unbounded | No regularization | Add weight decay |
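
The exploding-gradient fix is one line in PyTorch: clip between `backward()` and `step()`. A minimal sketch on random data:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(32, 64)
y = torch.randint(0, 10, (32,))

loss = F.cross_entropy(model(x), y)
optimizer.zero_grad()
loss.backward()

# Rescale the global gradient norm to at most max_norm before the update;
# the return value is the norm measured BEFORE clipping
pre_clip = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
print(f"Gradient norm before clipping: {float(pre_clip):.4f}")
```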

Step 6: Regularize

Only after the model overfits (proves it has capacity):

class MLPRegularized(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )
 
    def forward(self, x):
        return self.net(x)
 
model_reg = MLPRegularized(64, 256, 10, dropout=0.3)
optimizer = torch.optim.AdamW(model_reg.parameters(), lr=1e-3, weight_decay=1e-4)
 
history_reg = train_monitored(model_reg, X_train_t, y_train_t, X_test_t, y_test_t, epochs=200)
 
plt.plot(history_reg['train_loss'], label='train')
plt.plot(history_reg['test_loss'], label='test')
plt.legend(); plt.title("With dropout + weight decay")
plt.show()
print(f"Final test acc: {history_reg['test_acc'][-1]:.4f}")

Regularization order

  1. Get more data (or augment)
  2. Weight decay (AdamW, not Adam)
  3. Dropout
  4. Early stopping
  5. Reduce model size (last resort — you want a big model that’s regularized)
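
Item 4 amounts to a small loop wrapper: stop when the validation loss hasn't improved for `patience` epochs and keep the best epoch. A sketch with a hypothetical `early_stopping_loop` helper (the synthetic loss curve and patience value are illustrative):

```python
def early_stopping_loop(val_losses, patience=10):
    """Scan per-epoch validation losses; return the best epoch seen
    before `patience` consecutive non-improving epochs."""
    best, best_epoch, wait = float('inf'), 0, 0
    for epoch, val_loss in enumerate(val_losses):
        if val_loss < best:
            best, best_epoch, wait = val_loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                break
    return best_epoch

# Synthetic curve: improves for 30 epochs, then steadily worsens
losses = [1.0 / (e + 1) for e in range(30)] + [0.05 + 0.01 * e for e in range(50)]
print(early_stopping_loop(losses, patience=10))  # → 29
```

In a real training loop you would checkpoint the model weights at each new best epoch and restore them after stopping.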

The Full Recipe (Summary)

1. LOOK AT YOUR DATA
   - Visualize, check labels, check distributions, check for leakage

2. SANITY CHECK
   - Initial loss = -log(1/C)?
   - Can overfit a single batch to 100%?
   - Does the simplest baseline work?

3. START SIMPLE
   - Small model, no tricks
   - Beat the baseline, then add complexity

4. FIND THE LEARNING RATE
   - LR finder sweep
   - If in doubt, 3e-4 for Adam is a decent start

5. TRAIN AND MONITOR
   - Plot train/val loss every epoch
   - Plot gradient norms (catch vanishing/exploding)
   - Plot weight norms (catch unbounded growth)

6. REGULARIZE (only after overfitting)
   - Data augmentation → weight decay → dropout → early stopping

7. SQUEEZE
   - LR schedule (cosine decay with warmup)
   - Ensembling (average 3-5 models)
   - Larger model with more regularization
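
Ensembling in step 7 is usually just averaging the logits (or probabilities) of independently trained models and taking the argmax once. A sketch with untrained stand-in models in place of real trained checkpoints:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-ins for 3 independently trained models (same architecture,
# different random seeds in practice)
models = [nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))
          for _ in range(3)]

x = torch.randn(16, 64)

with torch.no_grad():
    # Average logits across ensemble members, then predict
    avg_logits = torch.stack([m(x) for m in models]).mean(dim=0)
    preds = avg_logits.argmax(dim=1)
print(preds.shape)  # one prediction per input
```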

Exercises

  1. Break it on purpose: Introduce a subtle bug (swap two labels, corrupt 10% of data, normalize wrong). Can you detect it with these checks? Which check catches it?

  2. Apply to your project: Take any tutorial model (e.g., 20 - Build GPT from Scratch) and run through the full recipe. What does the LR finder suggest? Does the initial loss match expectations?

  3. Learning rate schedule: Implement cosine decay with warmup. Compare final accuracy with fixed LR vs scheduled.

  4. Gradient monitoring tool: Build a reusable function that hooks into any nn.Module and records per-layer gradient stats (mean, std, max, fraction of zeros).


This is the meta-tutorial — it applies to every other tutorial in this series.