Batch Normalization

Normalize activations across the batch dimension to stabilize and accelerate training.

Batch Norm (BN)

For a mini-batch $\mathcal{B} = \{x_1, \ldots, x_m\}$ (statistics are computed separately for each feature, i.e. each channel in a CNN):

$$\mu_\mathcal{B} = \frac{1}{m} \sum_i x_i$$

$$\sigma^2_\mathcal{B} = \frac{1}{m} \sum_i (x_i - \mu_\mathcal{B})^2$$

$$\hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma^2_\mathcal{B} + \varepsilon}}$$

$$y_i = \gamma \hat{x}_i + \beta \qquad \text{(learned scale and shift)}$$

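$\gamma$ and $\beta$ are learnable per-feature parameters that let the network undo the normalization if needed: setting $\gamma = \sqrt{\sigma^2_\mathcal{B} + \varepsilon}$ and $\beta = \mu_\mathcal{B}$ recovers $y_i = x_i$.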
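A minimal plain-Python sketch of the training-time forward pass above, assuming each sample is a list of $d$ features (real layers operate on tensors; `batch_norm_train` is a hypothetical name used only for illustration):

```python
import math

def batch_norm_train(batch, gamma, beta, eps=1e-5):
    """Hypothetical helper: batch norm over m samples, each a list of d features.

    Mean and variance are computed per feature across the batch dimension.
    Returns the outputs plus the per-feature batch mean and biased variance.
    """
    m, d = len(batch), len(batch[0])
    mean = [sum(x[j] for x in batch) / m for j in range(d)]
    # Biased variance (divide by m), matching sigma^2_B above
    var = [sum((x[j] - mean[j]) ** 2 for x in batch) / m for j in range(d)]
    out = []
    for x in batch:
        x_hat = [(x[j] - mean[j]) / math.sqrt(var[j] + eps) for j in range(d)]
        out.append([gamma[j] * x_hat[j] + beta[j] for j in range(d)])
    return out, mean, var

batch = [[1.0, 10.0], [3.0, 30.0]]
out, mean, var = batch_norm_train(batch, [1.0, 1.0], [0.0, 0.0])
# mean == [2.0, 20.0], var == [1.0, 100.0]; each column of out is roughly [-1, 1]
```

Passing `gamma = sqrt(var + eps)` and `beta = mean` back in reproduces the input, which is the "undo" case described above.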

Train vs. Inference

  • Training: normalize using mini-batch statistics $(\mu_\mathcal{B}, \sigma^2_\mathcal{B})$
  • Inference: use running estimates accumulated during training (exponential moving average)
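```python
import torch.nn as nn
num_features = 64  # number of channels C in an (N, C, H, W) input
bn = nn.BatchNorm2d(num_features)
# At inference, call model.eval() to switch every BN layer to its running stats
```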
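The running estimates can be sketched in plain Python as an exponential moving average updated once per training step. `RunningStats` is a hypothetical helper; PyTorch's BatchNorm layers use the same update form, `running = (1 - momentum) * running + momentum * batch_stat`, with a default `momentum=0.1` and running stats initialized to mean 0 and variance 1, although PyTorch stores an unbiased variance estimate, which this sketch skips for simplicity.

```python
import math

class RunningStats:
    """Hypothetical helper: EMA of per-feature batch statistics for inference."""

    def __init__(self, d, momentum=0.1):
        self.momentum = momentum
        self.mean = [0.0] * d  # running mean starts at 0
        self.var = [1.0] * d   # running variance starts at 1

    def update(self, batch_mean, batch_var):
        # Training step: blend the current mini-batch statistics into the EMA
        mo = self.momentum
        self.mean = [(1 - mo) * r + mo * b for r, b in zip(self.mean, batch_mean)]
        self.var = [(1 - mo) * r + mo * b for r, b in zip(self.var, batch_var)]

    def normalize(self, x, gamma, beta, eps=1e-5):
        # Inference: statistics are fixed, so each sample is handled independently
        return [g * (xj - mu) / math.sqrt(v + eps) + b
                for xj, mu, v, g, b in zip(x, self.mean, self.var, gamma, beta)]

rs = RunningStats(d=1)
rs.update([2.0], [4.0])  # running mean -> 0.2, running variance -> 1.3
```

Because inference uses these fixed statistics, a single sample's output no longer depends on which other samples share its batch.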

Where to Place BN

Common convention (ResNet original): Conv → BN → ReLU

Pre-activation (ResNet v2): BN → ReLU → Conv, reported to make very deep ResNets easier to train
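The two orderings can be sketched as PyTorch blocks. The function names are hypothetical, and the 3×3 convolution with padding 1 is an arbitrary choice for illustration; real ResNet blocks also add a residual connection, which is omitted here.

```python
import torch
import torch.nn as nn

def conv_bn_relu(c_in, c_out):
    # Original ResNet ordering: Conv -> BN -> ReLU
    return nn.Sequential(
        # BN's learned shift beta makes a conv bias redundant, so it is disabled
        nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(c_out),
        nn.ReLU(inplace=True),
    )

def bn_relu_conv(c_in, c_out):
    # Pre-activation (ResNet v2) ordering: BN -> ReLU -> Conv
    # Here BN normalizes the block's *input* channels, so it takes c_in
    return nn.Sequential(
        nn.BatchNorm2d(c_in),
        nn.ReLU(inplace=True),
        nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
    )

x = torch.randn(2, 8, 5, 5)  # (N, C, H, W)
y = bn_relu_conv(8, 16)(x)   # shape (2, 16, 5, 5)
```

Note that only the first ordering guarantees a non-negative block output, since its last operation is ReLU.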

Variants

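| Variant | Normalizes over | Typical use |
|---|---|---|
| Batch Norm | batch + spatial dims, per channel | CNNs |
| Layer Norm | all feature dims, per sample | Transformers, RNNs |
| Instance Norm | spatial dims, per sample and channel | style transfer |
| Group Norm | groups of channels (+ spatial dims), per sample | small-batch training, detection |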
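To make the table concrete, a sketch in PyTorch that normalizes an (N, C, H, W) tensor over the axes each variant uses, with γ and β omitted. The helper `normalize` and the choice of 3 groups are hypothetical; up to floating-point error the results should match `torch.nn.functional.batch_norm`, `layer_norm`, `instance_norm`, and `group_norm` at their default `eps=1e-5`.

```python
import torch

def normalize(x, dims, eps=1e-5):
    # Zero mean and unit (biased) variance over the given dims; no gamma/beta
    mean = x.mean(dim=dims, keepdim=True)
    var = x.var(dim=dims, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)

x = torch.randn(4, 6, 5, 5)  # (N, C, H, W)
N, C, H, W = x.shape

bn = normalize(x, (0, 2, 3))     # Batch Norm: per channel, over batch + spatial
ln = normalize(x, (1, 2, 3))     # Layer Norm: per sample, over all of (C, H, W)
inorm = normalize(x, (2, 3))     # Instance Norm: per sample and channel, over spatial
G = 3                            # Group Norm: G groups of C // G channels each
gn = normalize(x.view(N, G, C // G, H, W), (2, 3, 4)).view(N, C, H, W)
```

Only Batch Norm reduces over dimension 0, which is why it is the one variant whose output depends on the batch.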

Layer Normalization

For a single sample $x \in \mathbb{R}^d$, LN normalizes across its $d$ features rather than across the batch, so it works even with batch size 1:

$$\mu = \frac{1}{d} \sum_j x_j, \qquad \sigma^2 = \frac{1}{d} \sum_j (x_j - \mu)^2$$

$$\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}}, \qquad y = \gamma \hat{x} + \beta$$

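LN uses no batch statistics, so it behaves identically in training and inference (there are no running estimates to track). It is the standard normalization in Transformer architectures, although some later models replace it with RMSNorm.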
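A minimal plain-Python sketch of the LN formulas for one sample, assuming the sample is a list of $d$ features and $\gamma$, $\beta$ are per-feature lists (`layer_norm` is a hypothetical name, not a library call):

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    """Hypothetical helper: layer norm of one sample x with d features."""
    d = len(x)
    mu = sum(x) / d                              # mean over the sample's features
    var = sum((xj - mu) ** 2 for xj in x) / d    # biased variance over the same features
    return [g * (xj - mu) / math.sqrt(var + eps) + b
            for xj, g, b in zip(x, gamma, beta)]

y = layer_norm([1.0, 2.0, 3.0], [1.0] * 3, [0.0] * 3)
# mu = 2, var = 2/3, so y is roughly [-1.2247, 0.0, 1.2247]
```

Nothing here looks at other samples, so the same function serves training and inference unchanged.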

RMSNorm

A simplified layer norm that skips mean centering and only rescales by the root mean square:

$$\text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \cdot \gamma, \qquad \text{RMS}(x) = \sqrt{\frac{1}{d} \sum_i x_i^2}$$

Used in T5 and LLaMA. Dropping the mean subtraction (and the shift $\beta$) makes it slightly cheaper than LN, with similar performance in practice.
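A minimal plain-Python sketch of RMSNorm for one sample. The small `eps` added inside the square root is an assumption for numerical safety (it is not in the formula above, though implementations commonly include one), and `rms_norm` is a hypothetical name.

```python
import math

def rms_norm(x, gamma, eps=1e-6):
    """Hypothetical helper: RMSNorm of one sample x with d features."""
    # No mean subtraction and no beta: divide by the root mean square, then scale
    rms = math.sqrt(sum(xj * xj for xj in x) / len(x) + eps)  # eps is an assumption
    return [g * xj / rms for xj, g in zip(x, gamma)]

y = rms_norm([3.0, 4.0], [1.0, 1.0])   # RMS = sqrt(12.5), so y is roughly [0.8485, 1.1314]
z = rms_norm([1.0, 1.0], [1.0, 1.0])   # roughly [1.0, 1.0]
```

The second call shows the difference from LN: a constant input is left (almost) unchanged instead of being mapped to zero, because the mean is never subtracted.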