Batch Normalization

What

Normalize activations within each mini-batch to have zero mean and unit variance. Then apply learnable scale (γ) and shift (β).

x_norm = (x - batch_mean) / sqrt(batch_var + ε)
output = γ × x_norm + β
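
The two formulas above can be sketched directly in PyTorch. This is an illustrative hand-rolled version (names like `gamma` and `beta` are placeholders for the learnable parameters, here left at their usual initial values of 1 and 0):

```python
import torch

torch.manual_seed(0)
x = torch.randn(32, 4)                     # mini-batch: 32 samples, 4 features
eps = 1e-5

batch_mean = x.mean(dim=0)                 # per-feature mean over the batch
batch_var = x.var(dim=0, unbiased=False)   # biased variance, as used in training

x_norm = (x - batch_mean) / torch.sqrt(batch_var + eps)

gamma = torch.ones(4)                      # learnable scale, init 1
beta = torch.zeros(4)                      # learnable shift, init 0
output = gamma * x_norm + beta

print(x_norm.mean(dim=0))                  # ~0 per feature
print(x_norm.var(dim=0, unbiased=False))   # ~1 per feature
```

After normalization each feature has (approximately) zero mean and unit variance across the batch; γ and β then let the network undo or rescale this if that helps optimization.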

Why it helps

  • Stabilizes training — activations don’t drift to extreme values
  • Allows higher learning rates → faster training
  • Mild regularization effect (due to batch noise)
  • Reduces sensitivity to weight initialization

In PyTorch

import torch.nn as nn
 
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.BatchNorm1d(256),   # normalize after linear
    nn.ReLU(),
    nn.Linear(256, 10),
)
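
One practical detail worth knowing: BatchNorm behaves differently in training and evaluation. In `train()` mode it normalizes with the current batch's statistics and updates running averages; in `eval()` mode it uses those stored running statistics instead. A small sketch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
x = torch.randn(8, 4)

bn.train()            # uses batch statistics, updates running mean/var
y_train = bn(x)

bn.eval()             # uses the stored running statistics instead
y_eval = bn(x)

# The same input generally produces different outputs in the two modes.
print(torch.allclose(y_train, y_eval))
```

This is why forgetting to call `model.eval()` before inference is a common source of mysteriously degraded results.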

Variants

Method         Normalizes over                  Use case
BatchNorm      Batch dimension                  CNNs, feedforward nets
LayerNorm      Feature dimension                Transformers, RNNs
GroupNorm      Groups of channels               Small batch sizes
InstanceNorm   Single sample, single channel    Style transfer

LayerNorm is used in transformers because it normalizes each sample independently, so it doesn’t depend on batch statistics or batch size.
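
That batch-independence is easy to check: LayerNorm produces the same per-sample output whether a sample is processed alone or inside a larger batch. A minimal sketch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
ln = nn.LayerNorm(4)
x = torch.randn(8, 4)

full = ln(x)          # each sample normalized over its own 4 features
single = ln(x[:1])    # the first sample alone; batch size 1 works fine

print(torch.allclose(full[:1], single))
```

A `BatchNorm1d` layer in training mode would raise an error on a batch of size 1, since a single sample gives no usable batch variance.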