Mixed Precision & Scaling

Mixed Precision Training (FP16/BF16)

Keep a master copy of the weights in FP32, and run most of the forward and backward computation in FP16 or BF16:

  • FP32: 1 sign, 8 exponent, 23 mantissa bits; range $\approx \pm 3.4 \times 10^{38}$, about 7 significant decimal digits
  • FP16: 1 sign, 5 exponent, 10 mantissa bits; range $\approx \pm 65504$, about 3–4 significant decimal digits
  • BF16: 1 sign, 8 exponent, 7 mantissa bits; same exponent range as FP32 (much wider than FP16), but only about 2–3 significant decimal digits
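The range/precision trade-off in the list above can be checked without a GPU. This sketch assumes only the Python standard library: `struct`'s `'e'` format packs IEEE half precision (FP16), and BF16 is emulated by rounding an FP32 bit pattern to its upper 16 bits. `to_fp16` and `to_bf16` are hypothetical helpers, not a library API, and the BF16 emulation ignores NaN inputs.

```python
import struct


def to_fp16(x: float) -> float:
    """Round-trip x through IEEE FP16; raises OverflowError above the FP16 range."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Emulate BF16 by keeping the top 16 bits of the FP32 encoding.

    Rounds to nearest, ties to even. Assumes x is finite and within FP32 range.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add just under half of the dropped low 16 bits, plus the kept LSB (ties to even).
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits = (bits >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Range: 70000 does not fit in FP16, but does fit in BF16 (coarsely rounded).
try:
    to_fp16(70000.0)
except OverflowError:
    print("FP16 overflow")
print(to_bf16(70000.0))        # 70144.0

# Precision: FP16 resolves 1.001 much more finely than BF16.
print(to_fp16(1.001))          # 1.0009765625
print(to_bf16(1.001))          # 1.0

# Underflow: a gradient-sized value vanishes in FP16 but survives in BF16.
print(to_fp16(1e-8))           # 0.0
print(to_bf16(1e-8) > 0.0)     # True
```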

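Loss scaling: FP16 cannot represent nonzero magnitudes below about $6 \times 10^{-8}$ (its smallest subnormal, $2^{-24}$), so small gradients underflow to zero. Multiplying the loss by a scale factor $S$ before backpropagation multiplies every gradient by $S$ as well (chain rule), lifting them into FP16's range; the gradients are divided by $S$ again before the optimizer step.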

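import torch

# model, optimizer and x are assumed to be defined elsewhere; model(x) returns the loss
scaler = torch.cuda.amp.GradScaler()          # dynamic loss scaling for FP16
optimizer.zero_grad()
with torch.autocast(device_type='cuda', dtype=torch.float16):
    loss = model(x)                           # eligible ops run in FP16
scaler.scale(loss).backward()                 # backprop the scaled loss
scaler.step(optimizer)                        # unscales grads; skips the step on inf/NaN
scaler.update()                               # adjusts the scale for the next iteration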
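To see why scaling rescues small gradients and how the scale factor adapts over time, here is a pure-Python sketch of dynamic loss scaling. It assumes a single scalar gradient, emulates FP16 with the standard `struct` module, and models an overflow as the `OverflowError` that `struct` raises (on a GPU the value would become inf instead). `DynamicLossScaler` and `scaled_fp16_grad` are hypothetical names; the defaults mirror the documented defaults of `torch.cuda.amp.GradScaler` (initial scale $2^{16}$, growth factor 2, backoff factor 0.5, growth interval 2000), but this is not its implementation.

```python
import struct


def to_fp16(x: float) -> float:
    """Round-trip x through IEEE FP16; raises OverflowError above about 65504."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


class DynamicLossScaler:
    """Hypothetical sketch: back off on overflow, grow after a run of clean steps."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._clean_steps = 0

    def update(self, found_inf: bool) -> None:
        if found_inf:
            # Overflow: this optimizer step is skipped and the scale shrinks.
            self.scale *= self.backoff_factor
            self._clean_steps = 0
        else:
            self._clean_steps += 1
            if self._clean_steps == self.growth_interval:
                # Many clean steps in a row: try a larger scale.
                self.scale *= self.growth_factor
                self._clean_steps = 0


def scaled_fp16_grad(true_grad: float, scale: float):
    """Return (unscaled gradient, found_inf) for an FP16 gradient of the scaled loss."""
    try:
        g16 = to_fp16(true_grad * scale)   # backward pass on scale * loss
    except OverflowError:
        return None, True                  # would be inf on real hardware
    return g16 / scale, False              # unscale in higher precision


print(to_fp16(1e-8))                       # 0.0: the unscaled gradient underflows
scaler = DynamicLossScaler()
grad, found_inf = scaled_fp16_grad(1e-8, scaler.scale)
print(found_inf, grad)                     # False, and grad is approximately 1e-8

grad, found_inf = scaled_fp16_grad(1.0, scaler.scale)  # 1.0 * 65536 > 65504
print(found_inf)                           # True: skip this step
scaler.update(found_inf)
print(scaler.scale)                        # 32768.0
```

Backing off quickly and growing slowly keeps the scale near the largest value that does not overflow, which maximizes the headroom against underflow.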

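BF16 is generally preferred for LLM training: because it shares FP32's exponent range, gradients rarely underflow, so loss scaling is usually unnecessary. It has native hardware support on NVIDIA A100 (Ampere) and newer GPUs, as well as on Google TPUs.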

Gradient Accumulation

Simulate a large batch by running several smaller micro-batches and accumulating their gradients before each optimizer step:

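# model, loader, optimizer and accum_steps are assumed to be defined elsewhere.
# Gradients from successive backward() calls are summed into each param.grad.
optimizer.zero_grad()
for i, (x, y) in enumerate(loader):
    loss = model(x, y) / accum_steps   # average, not sum, over the effective batch
    loss.backward()                    # accumulates into param.grad
    if (i + 1) % accum_steps == 0:
        optimizer.step()               # one update per accum_steps micro-batches
        optimizer.zero_grad()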

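Effective batch size $= \text{batch\_size} \times \text{accum\_steps}$. This is useful when GPU memory limits how large a batch fits in one forward/backward pass. Dividing the loss by $\text{accum\_steps}$ makes the accumulated gradient the average over the effective batch rather than the sum.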
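Why divide by `accum_steps`? A pure-Python check on a one-parameter linear model with mean-squared-error loss (a toy stand-in for `model`, chosen only for illustration) shows that summing the micro-batch gradients of `loss / accum_steps` reproduces the full-batch gradient, provided the micro-batches are equal-sized and the loss is a per-example mean. `mse_grad` and `accumulated_grad` are hypothetical helpers.

```python
import math


def mse_grad(w, xs, ys):
    """d/dw of mean((w * x - y) ** 2) over the given examples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, micro_batch_size):
    """Accumulate gradients of loss / accum_steps over equal-sized micro-batches."""
    assert len(xs) % micro_batch_size == 0, "micro-batches must be equal-sized"
    accum_steps = len(xs) // micro_batch_size
    grad = 0.0                                   # plays the role of param.grad
    for k in range(accum_steps):
        lo, hi = k * micro_batch_size, (k + 1) * micro_batch_size
        grad += mse_grad(w, xs[lo:hi], ys[lo:hi]) / accum_steps
    return grad


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x + 1.0 for x in xs]
w = 0.5

full = mse_grad(w, xs, ys)                                # one batch of 8
accum = accumulated_grad(w, xs, ys, micro_batch_size=2)   # 4 micro-batches of 2
print(math.isclose(full, accum))                          # True
```

The equivalence does not hold exactly for layers whose output depends on the batch, such as BatchNorm, because they only ever see one micro-batch at a time.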

Data Parallelism

Each GPU processes a different slice of the batch, and gradients are synchronized across GPUs before every optimizer step:

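  • DDP (DistributedDataParallel): each GPU holds a full copy of the model and optimizer state; during the backward pass, gradients are averaged across GPUs with an all-reduce, so every replica applies the same update
  • FSDP (Fully Sharded Data Parallel): shards model parameters, gradients, and optimizer state across GPUs, gathering a layer's full parameters only while that layer is being computed; this enables training models whose training state would not fit in a single GPU's memory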
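Two small plain-Python sketches, both toys under stated assumptions rather than how PyTorch implements either wrapper. First, DDP's gradient synchronization: each simulated rank computes the gradient of a one-parameter MSE model on its shard, and an averaging all-reduce (written naively as sum-then-broadcast; real backends such as NCCL use ring or tree algorithms) gives every rank the full-batch gradient. Second, the memory argument for FSDP, assuming the common estimate of 16 bytes of training state per parameter for mixed-precision Adam (FP16 weights and gradients plus FP32 master weights and two FP32 Adam moments) and ignoring activations. All function names are hypothetical.

```python
import math


def local_grad(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) on one rank's shard."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(values):
    """All-reduce semantics with averaging: every rank ends up with the mean."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


def ddp_gradients(w, xs, ys, world_size):
    """Shard the batch, compute local gradients, then synchronize them."""
    shard = len(xs) // world_size
    local = [local_grad(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
             for r in range(world_size)]
    return all_reduce_mean(local)


def model_state_gb(num_params, world_size=1, sharded=False, bytes_per_param=16):
    """Per-GPU weights + gradients + optimizer state in GB (activations excluded)."""
    total_bytes = num_params * bytes_per_param
    return (total_bytes / world_size if sharded else total_bytes) / 1e9


xs = [float(x) for x in range(1, 9)]
ys = [2.0 * x + 1.0 for x in xs]
grads = ddp_gradients(0.5, xs, ys, world_size=4)
print(grads)                                              # same value on all 4 ranks
print(math.isclose(grads[0], local_grad(0.5, xs, ys)))    # True: full-batch gradient

# A hypothetical 7B-parameter model on 8 GPUs:
print(model_state_gb(7_000_000_000, 8, sharded=False))    # 112.0 (DDP: full copy per GPU)
print(model_state_gb(7_000_000_000, 8, sharded=True))     # 14.0 (FSDP: 1/8 per GPU)
```

Under this estimate, 112 GB per GPU exceeds even an 80 GB A100, while the sharded 14 GB leaves room for activations.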

Tensor Parallelism

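Split individual layers across GPUs (Megatron-LM style): each GPU holds a slice of each weight matrix and computes a partial result, and the partial results are combined with collective communication such as all-reduce. In a Megatron-style MLP block, the first weight matrix is split by columns and the second by rows, so the forward pass needs only one all-reduce.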
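Splitting the first MLP matrix by columns and the second by rows, as Megatron-LM does, can be checked with small integer matrices in plain Python. This is a sketch under simplifying assumptions: simulated GPUs as list slices, ReLU standing in for the GeLU used in Megatron-LM's MLP, and the all-reduce written as a plain sum of partial outputs. Helper names are hypothetical.

```python
def matmul(a, b):
    """Naive matrix product of nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def relu(m):
    return [[max(0, v) for v in row] for row in m]


def add(a, b):
    return [[u + v for u, v in zip(ra, rb)] for ra, rb in zip(a, b)]


def split_cols(m, parts):
    n = len(m[0]) // parts
    return [[row[p * n:(p + 1) * n] for row in m] for p in range(parts)]


def split_rows(m, parts):
    n = len(m) // parts
    return [m[p * n:(p + 1) * n] for p in range(parts)]


def tensor_parallel_mlp(x, A, B, world_size):
    """relu(x A) B with A split by columns and B split by rows across ranks."""
    partials = []
    for A_p, B_p in zip(split_cols(A, world_size), split_rows(B, world_size)):
        h_p = relu(matmul(x, A_p))          # rank p: its slice of the hidden layer, no communication
        partials.append(matmul(h_p, B_p))   # rank p: a partial sum of the output
    out = partials[0]
    for p in partials[1:]:                  # all-reduce (sum) across ranks
        out = add(out, p)
    return out


x = [[1, -2, 3, 0], [2, 1, -1, 4]]
A = [[1, 0, 2, -1], [0, 1, -1, 2], [3, -2, 0, 1], [1, 1, 1, 1]]
B = [[1, 2, 0], [0, -1, 3], [2, 0, 1], [-1, 1, 1]]

reference = matmul(relu(matmul(x, A)), B)   # single-GPU computation
print(tensor_parallel_mlp(x, A, B, world_size=2) == reference)   # True
```

The column split is what lets the elementwise nonlinearity run locally: each rank owns whole hidden units, so no communication is needed until the row-split second matrix produces partial sums.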

Pipeline Parallelism

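Assign consecutive groups of layers (stages) to different GPUs. The batch is split into microbatches that flow through the pipeline, so different stages can work on different microbatches at the same time; the idle time while the pipeline fills and drains is called the pipeline bubble.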
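A forward-only, GPipe-style schedule makes the idle "bubble" concrete: early stages wait for the pipeline to drain and later stages wait for it to fill. This sketch assumes every stage takes one time step per microbatch and ignores the backward pass and communication; `forward_schedule` and `bubble_fraction` are hypothetical names. With $p$ stages and $m$ microbatches the forward pass takes $m + p - 1$ steps, of which each stage is idle for $p - 1$, so the idle fraction is $(p - 1)/(m + p - 1)$.

```python
def forward_schedule(num_stages, num_microbatches):
    """timeline[t][s] = microbatch that stage s processes at step t, or None if idle."""
    steps = num_microbatches + num_stages - 1
    timeline = []
    for t in range(steps):
        row = []
        for s in range(num_stages):
            mb = t - s                      # microbatch mb reaches stage s at step mb + s
            row.append(mb if 0 <= mb < num_microbatches else None)
        timeline.append(row)
    return timeline


def bubble_fraction(num_stages, num_microbatches):
    """Fraction of stage-time slots that are idle in the schedule above."""
    return (num_stages - 1) / (num_microbatches + num_stages - 1)


for t, row in enumerate(forward_schedule(num_stages=4, num_microbatches=6)):
    print(t, ["--" if mb is None else f"m{mb}" for mb in row])

print(bubble_fraction(4, 6))     # 3 / 9, i.e. about 0.333
print(bubble_fraction(4, 32))    # about 0.086: more microbatches, smaller bubble
```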

Scaling Laws (Chinchilla)

For compute-optimal training, model size and training data should be scaled up in roughly equal proportion:

$$N_\text{opt} \propto C^{0.5}, \qquad D_\text{opt} \propto C^{0.5}$$

where $N$ = model parameters, $D$ = training tokens, and $C$ = total training compute in FLOPs.

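Chinchilla rule of thumb: about 20 training tokens per parameter. By this measure GPT-3 (175B parameters, about 300B tokens) was undertrained, and Chinchilla (70B parameters, 1.4T tokens) outperformed much larger models, including GPT-3 and Gopher (280B).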
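The undertraining claim is plain arithmetic on the published parameter and token counts. `tokens_per_param` is a hypothetical helper, and the 20 tokens/parameter target is the rule of thumb above, not an exact law.

```python
def tokens_per_param(tokens, params):
    """Training tokens seen per model parameter."""
    return tokens / params


print(tokens_per_param(300e9, 175e9))   # GPT-3: about 1.7
print(tokens_per_param(1.4e12, 70e9))   # Chinchilla: 20.0

# Tokens a 175B-parameter model would need under the ~20 tokens/parameter rule:
print(20 * 175e9 / 1e12)                # 3.5 (trillion tokens)
```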

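Training compute: $C \approx 6ND$ FLOPs in total. The forward pass costs about $2N$ FLOPs per token (one multiply-add, i.e. 2 FLOPs, per parameter), and the backward pass costs about twice the forward pass ($\approx 4N$), giving $\approx 6N$ FLOPs per token over $D$ tokens.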
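Combining $C \approx 6ND$ with the ~20 tokens/parameter rule gives a back-of-the-envelope allocation: substituting $D = 20N$ yields $C \approx 120N^2$, so $N_\text{opt} \approx \sqrt{C/120}$ and $D_\text{opt} \approx 20N_\text{opt}$, both growing as $C^{0.5}$ as in the scaling law above. The sketch below relies on those two approximations only (function names are hypothetical); it is a rule-of-thumb estimate, not the Chinchilla paper's fitted curves.

```python
import math


def training_flops(n_params, n_tokens):
    """C ~ 6ND: ~2N FLOPs/token forward plus ~4N FLOPs/token backward."""
    return 6 * n_params * n_tokens


def compute_optimal(flops, tokens_per_param=20):
    """Solve C = 6 * N * (k * N) for N, with D = k * N."""
    n = math.sqrt(flops / (6 * tokens_per_param))
    return n, tokens_per_param * n


c = training_flops(70e9, 1.4e12)
print(f"{c:.2e}")                     # 5.88e+23 FLOPs for 70B params x 1.4T tokens

n, d = compute_optimal(c)
print(f"N = {n:.3g}, D = {d:.3g}")    # N = 7e+10, D = 1.4e+12

# 4x the compute -> 2x the parameters and 2x the tokens (both scale as C**0.5)
n4, d4 = compute_optimal(4 * c)
print(round(n4 / n, 6), round(d4 / d, 6))   # 2.0 2.0
```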