Batch Normalization#
Normalize activations across the batch dimension to stabilize and accelerate training.
Batch Norm (BN)#
For a mini-batch $\mathcal{B} = \{x_1, \ldots, x_m\}$:
$$\mu_\mathcal{B} = \frac{1}{m} \sum_i x_i$$
$$\sigma^2_\mathcal{B} = \frac{1}{m} \sum_i (x_i - \mu_\mathcal{B})^2$$
$$\hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma^2_\mathcal{B} + \varepsilon}}$$
$$y_i = \gamma \hat{x}_i + \beta \qquad \text{(learned scale and shift)}$$
$\gamma$ and $\beta$ are learnable parameters that allow the network to undo normalization if needed.
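The four equations above can be sketched directly in PyTorch (tensor shapes here are illustrative, and $\gamma$, $\beta$ are shown at their usual initialization of 1 and 0):

```python
import torch

torch.manual_seed(0)
x = torch.randn(32, 8)        # mini-batch: m = 32 samples, 8 features
gamma = torch.ones(8)         # learned scale γ (initialized to 1)
beta = torch.zeros(8)         # learned shift β (initialized to 0)
eps = 1e-5

mu = x.mean(dim=0)                        # μ_B: per-feature batch mean
var = x.var(dim=0, unbiased=False)        # σ²_B: per-feature batch variance
x_hat = (x - mu) / torch.sqrt(var + eps)  # normalize
y = gamma * x_hat + beta                  # scale and shift
```

At this initialization each feature of `y` has approximately zero mean and unit variance over the batch; training then adjusts `gamma` and `beta` by gradient descent.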
Train vs. Inference#
- Training: normalize using mini-batch statistics $(\mu_\mathcal{B}, \sigma^2_\mathcal{B})$
- Inference: use running estimates accumulated during training (exponential moving average)
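The running-estimate update can be sketched as a simple exponential moving average (the momentum value and the use of the unbiased variance match PyTorch's convention; the batch distribution here is made up for illustration):

```python
import torch

torch.manual_seed(0)
momentum = 0.1                 # PyTorch's default BatchNorm momentum
running_mean = torch.zeros(8)
running_var = torch.ones(8)

for _ in range(100):                      # simulated training steps
    x = torch.randn(32, 8) * 2.0 + 3.0    # batches drawn with mean 3, std 2
    mu = x.mean(dim=0)
    var = x.var(dim=0, unbiased=True)     # PyTorch tracks the unbiased variance
    running_mean = (1 - momentum) * running_mean + momentum * mu
    running_var = (1 - momentum) * running_var + momentum * var

# running_mean and running_var drift toward the true statistics (mean 3, var 4),
# which are then used in place of batch statistics at inference.
```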
```python
import torch.nn as nn

bn = nn.BatchNorm2d(num_features)  # one (γ, β) pair and one running-stats pair per channel
# At inference, use model.eval() to switch to running stats
```

Where to Place BN#
Common convention (ResNet original): Conv → BN → ReLU
Pre-activation (ResNet v2, often better): BN → ReLU → Conv
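The two orderings can be sketched as `nn.Sequential` blocks (channel count and kernel size are illustrative):

```python
import torch
import torch.nn as nn

# Original ResNet ordering: Conv → BN → ReLU
post_act = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),  # bias is redundant before BN
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
)

# Pre-activation ordering (ResNet v2): BN → ReLU → Conv
pre_act = nn.Sequential(
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
)
```

Note the `bias=False` on the convolutions: BN's $\beta$ subtracts any constant offset, so a conv bias directly before BN is wasted parameters.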
Variants#
| Variant | Normalize over | Use |
|---|---|---|
| Batch Norm | batch + spatial dims | CNNs |
| Layer Norm | feature dim per sample | Transformers, RNNs |
| Instance Norm | spatial dims per sample | style transfer |
| Group Norm | groups of channels per sample | small batch / detection |
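Each row of the table maps to a PyTorch module; a sketch applying all four to the same activation tensor (shapes and group count are illustrative):

```python
import torch
import torch.nn as nn

x = torch.randn(4, 32, 16, 16)   # (N, C, H, W)

norms = {
    "batch":    nn.BatchNorm2d(32),                           # stats over (N, H, W) per channel
    "layer":    nn.LayerNorm([32, 16, 16]),                   # stats over (C, H, W) per sample
    "instance": nn.InstanceNorm2d(32),                        # stats over (H, W) per sample and channel
    "group":    nn.GroupNorm(num_groups=8, num_channels=32),  # stats over each group of 4 channels per sample
}

# Every variant preserves the input shape; they differ only in which
# axes the mean and variance are computed over.
outputs = {name: norm(x) for name, norm in norms.items()}
```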
Layer Normalization#
LN normalizes across the feature dimension (not batch), so it works with batch size = 1:
$$\mu = \frac{1}{d} \sum_j x_j, \qquad \sigma^2 = \frac{1}{d} \sum_j (x_j - \mu)^2$$
$$\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}}, \qquad y = \gamma \hat{x} + \beta$$
There is no train/inference distinction, since no batch statistics are tracked. LN is the standard normalization in Transformer architectures.
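The LN equations reduce the feature dimension per sample; a sketch checking the manual computation against `nn.LayerNorm` (shapes illustrative, module at its default initialization of $\gamma = 1$, $\beta = 0$):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 16)                 # (batch, features); also works with batch size 1
d = x.shape[-1]

mu = x.mean(dim=-1, keepdim=True)                  # μ: per-sample mean over features
var = x.var(dim=-1, unbiased=False, keepdim=True)  # σ²: per-sample variance
x_hat = (x - mu) / torch.sqrt(var + 1e-5)          # matches nn.LayerNorm's default eps

ln = nn.LayerNorm(d)                   # γ initialized to 1, β to 0
y = ln(x)                              # matches x_hat at initialization
```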
RMSNorm#
A simplified layer norm: skip the mean centering and only rescale:
$$\text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \cdot \gamma, \qquad \text{RMS}(x) = \sqrt{\frac{1}{d} \sum_i x_i^2}$$
Used in LLaMA and T5. Slightly faster than LN (no mean to compute or subtract) with similar performance.
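The formula translates to a small module; a sketch (placing $\varepsilon$ inside the square root for numerical stability is a common implementation choice, not part of the formula above):

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Scale by 1/RMS(x) over the last dimension; no mean centering, no shift."""
    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d))  # learned scale γ only
        self.eps = eps

    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.gamma

norm = RMSNorm(16)
y = norm(torch.randn(4, 16))  # each row of y has RMS ≈ 1 at initialization
```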