Batch Normalization
What
Normalize activations within each mini-batch to have zero mean and unit variance. Then apply learnable scale (γ) and shift (β).
x_norm = (x - batch_mean) / sqrt(batch_var + ε)
output = γ × x_norm + β
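The two formulas above can be checked directly against PyTorch: in training mode, a freshly initialized `nn.BatchNorm1d` (γ = 1, β = 0) should reproduce the hand-computed normalization. A minimal sketch (variable names like `x` and `eps` are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(32, 4) * 3 + 5                 # batch of 32, 4 features, mean≈5, std≈3

eps = 1e-5
batch_mean = x.mean(dim=0)                     # per-feature mean over the batch
batch_var = x.var(dim=0, unbiased=False)       # biased variance, as BN uses for normalization
x_norm = (x - batch_mean) / torch.sqrt(batch_var + eps)

bn = nn.BatchNorm1d(4, eps=eps)                # gamma=1, beta=0 at initialization
out = bn(x)                                    # training mode: normalizes with batch stats

print(torch.allclose(out, x_norm, atol=1e-5))  # True: same computation
```

Note that BatchNorm uses the biased variance (`unbiased=False`) when normalizing, which is why that flag appears above.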
Why it helps
- Stabilizes training — activations don’t drift to extreme values
- Allows higher learning rates → faster training
- Mild regularization effect (due to batch noise)
- Reduces sensitivity to weight initialization
In PyTorch
import torch.nn as nn
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.BatchNorm1d(256),  # normalize after the linear layer, before the activation
    nn.ReLU(),
    nn.Linear(256, 10),
)

Variants
| Method | Normalizes over | Use case |
|---|---|---|
| BatchNorm | Batch dimension | CNNs, feedforward nets |
| LayerNorm | Feature dimension | Transformers, RNNs |
| GroupNorm | Groups of channels | Small batch sizes |
| InstanceNorm | Single sample, single channel | Style transfer |
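The "normalizes over" column can be made concrete with a 4D activation tensor (batch N, channels C, height H, width W). A sketch of the four variants on the same input (the shapes here are arbitrary examples):

```python
import torch
import torch.nn as nn

x = torch.randn(8, 6, 5, 5)              # (N, C, H, W)

out_bn = nn.BatchNorm2d(6)(x)            # stats per channel, over (N, H, W)
out_ln = nn.LayerNorm([6, 5, 5])(x)      # stats per sample, over (C, H, W)
out_gn = nn.GroupNorm(3, 6)(x)           # stats per sample, per group of 2 channels
out_in = nn.InstanceNorm2d(6)(x)         # stats per sample, per channel, over (H, W)

# Each output has near-zero mean over its own normalization axes:
print(out_bn.mean(dim=(0, 2, 3)).abs().max() < 1e-4)   # per-channel mean ≈ 0
print(out_ln.mean(dim=(1, 2, 3)).abs().max() < 1e-4)   # per-sample mean ≈ 0
```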
LayerNorm is preferred in transformers because its statistics are computed per sample rather than per batch: it works with any batch size (including 1) and behaves identically in training and inference.
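BatchNorm's batch dependence is also why calling `model.eval()` matters: training mode normalizes with the current batch's statistics (and updates running averages), while eval mode uses the stored running statistics. A sketch of the difference:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
x = torch.randn(16, 4) + 10        # data far from the initial running mean of 0

bn.train()
out_train = bn(x)                  # normalized with this batch's mean/var

bn.eval()
out_eval = bn(x)                   # normalized with the running mean/var instead

# The two modes disagree until the running stats have converged:
print(torch.allclose(out_train, out_eval))   # False here
```

Forgetting `eval()` at inference time, or running inference with batch size 1 in training mode, is a classic source of BatchNorm bugs.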