Mixed Precision & Scaling#
Mixed Precision Training (FP16/BF16)#
Keep a master copy of the weights in FP32; run the forward/backward computation in FP16/BF16:
- FP32: 32-bit float, range $\approx \pm 3.4 \times 10^{38}$, precision 7 decimal digits
- FP16: 16-bit float, range $\approx \pm 65504$, precision 3–4 decimal digits
- BF16: 16-bit, same exponent range as FP32 (wider range than FP16), less precision
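The ranges above follow directly from each format's bit layout. A quick sketch in plain Python (the helper `max_finite` is illustrative, not a library function) computes the largest finite value from the mantissa and exponent bit counts:

```python
# Largest finite value of a binary float format:
# (2 - 2**-mantissa_bits) * 2**max_exponent,
# where max_exponent = 2**(exponent_bits - 1) - 1 (top exponent is reserved).
def max_finite(mantissa_bits: int, exponent_bits: int) -> float:
    max_exp = 2 ** (exponent_bits - 1) - 1
    return (2 - 2 ** -mantissa_bits) * 2.0 ** max_exp

fp16 = max_finite(10, 5)   # 5 exponent bits, 10 mantissa bits
bf16 = max_finite(7, 8)    # 8 exponent bits (same as FP32), 7 mantissa bits
fp32 = max_finite(23, 8)   # 8 exponent bits, 23 mantissa bits

print(fp16)  # 65504.0
```

Note that BF16's maximum lands within about 1% of FP32's $3.4 \times 10^{38}$ despite having only 7 mantissa bits, which is exactly the "same range, less precision" trade-off.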
Loss scaling: multiply the loss by a scale factor before the backward pass so small gradients don't underflow in FP16; unscale (divide) the gradients before the optimizer step.
```python
scaler = torch.cuda.amp.GradScaler()

with torch.autocast(device_type='cuda'):
    loss = model(x)                # forward pass runs in FP16 where safe
scaler.scale(loss).backward()      # scale loss so small gradients survive FP16
scaler.step(optimizer)             # unscales gradients; skips step on inf/NaN
scaler.update()                    # adjusts the scale factor for the next step
optimizer.zero_grad()
```

BF16 is preferred for LLM training: no loss scaling is needed, and it has hardware support on A100 and newer.
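The underflow problem can be demonstrated without a GPU: Python's `struct` module supports the IEEE half-precision format code `'e'`, so we can round-trip a value through FP16 and watch a small gradient vanish unless it is scaled first (the scale factor $2^{16}$ below is illustrative):

```python
import struct

def to_fp16(x: float) -> float:
    """Round-trip a Python float through IEEE half precision."""
    return struct.unpack('e', struct.pack('e', x))[0]

grad = 1e-8                  # below FP16's smallest subnormal (~5.96e-8)
print(to_fp16(grad))         # 0.0 -- the gradient underflows to zero

scale = 2.0 ** 16
print(to_fp16(grad * scale) / scale)   # ~1e-8 -- survives after unscaling
```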
Gradient Accumulation#
Simulate large batch by accumulating gradients over multiple steps:
```python
for i, (x, y) in enumerate(loader):
    loss = model(x, y) / accum_steps   # divide so accumulated grads average correctly
    loss.backward()                    # gradients accumulate in .grad across steps
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```

Effective batch $= \text{batch\_size} \times \text{accum\_steps}$. Useful when GPU memory limits the per-step batch size.
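Why divide the loss by `accum_steps`: the accumulated gradient then equals the mean gradient over the full effective batch. A toy check with a one-parameter linear model (the function names are for illustration):

```python
# For loss = (w*x - y)^2, the gradient is dloss/dw = 2*x*(w*x - y).
def grad(w, batch):
    """Mean gradient over a (micro)batch."""
    return sum(2 * x * (w * x - y) for x, y in batch) / len(batch)

data = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 6.0)]
w = 0.5
accum_steps = 2
micro = [data[:2], data[2:]]          # two microbatches of size 2

full = grad(w, data)                                    # one batch of 4
accum = sum(grad(w, mb) / accum_steps for mb in micro)  # accumulated
print(abs(full - accum) < 1e-12)  # True
```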
Data Parallelism#
Split batch across GPUs, sync gradients:
- DDP (DistributedDataParallel): each GPU has full model copy, all-reduce gradients
- FSDP (Fully Sharded Data Parallel): shard model params, gradients, and optimizer state across GPUs — enables training models larger than single-GPU memory
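The DDP pattern can be sketched without `torch.distributed`: each replica computes a local gradient on its shard of the batch, the gradients are averaged (the all-reduce), and every replica applies the identical update, keeping the copies in sync. This is a conceptual simulation, not the real API:

```python
# For loss = (w*x - y)^2 per sample, the local gradient on a shard is
# the mean of 2*x*(w*x - y) over that shard.
def local_grad(w, shard):
    return sum(2 * x * (w * x - y) for x, y in shard) / len(shard)

batch = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 6.0)]
shards = [batch[:2], batch[2:]]    # batch split across 2 "GPUs"
w = [0.5, 0.5]                     # one full model copy per "GPU"
lr = 0.1

grads = [local_grad(w[i], shards[i]) for i in range(2)]
g = sum(grads) / len(grads)        # all-reduce: average the gradients
w = [wi - lr * g for wi in w]      # identical update on every replica
print(w[0] == w[1])  # True -- replicas stay in sync
```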
Tensor Parallelism#
Split model layers across GPUs (Megatron-LM style). Each GPU holds a slice of each weight matrix.
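The key identity behind a Megatron-style column-parallel linear layer is that splitting $W$ column-wise and concatenating the partial outputs reproduces the full matmul. A minimal sketch with plain lists (no real communication, purely illustrative):

```python
# x @ W where W is stored as a list of columns.
def matvec(x, W):
    return [sum(xi * wi for xi, wi in zip(x, col)) for col in W]

x = [1.0, 2.0]
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]  # 4 output columns
W0, W1 = W[:2], W[2:]      # shard: 2 columns per "GPU"

full = matvec(x, W)
sharded = matvec(x, W0) + matvec(x, W1)   # concatenate partial outputs
print(sharded == full)  # True
```

Row-parallel layers work dually: each GPU holds a slice of the rows, and the partial outputs are summed (all-reduced) instead of concatenated.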
Pipeline Parallelism#
Assign different layers to different GPUs. Microbatches flow through the pipeline.
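Microbatching is what makes the pipeline efficient: with $p$ stages and $m$ microbatches, a GPipe-style forward schedule finishes in $p + m - 1$ ticks rather than $p \times m$, leaving an idle "bubble" fraction of $(p-1)/(m+p-1)$. A small forward-only simulation (illustrative, ignoring the backward pass):

```python
# Simulate m microbatches flowing through p pipeline stages, one tick per stage.
def pipeline_ticks(p: int, m: int) -> int:
    busy_until = [0] * p                 # tick at which each stage becomes free
    for _ in range(m):                   # each microbatch visits stages in order
        t = 0
        for s in range(p):
            t = max(t, busy_until[s])    # wait for the stage to be free
            busy_until[s] = t + 1        # occupy it for one tick
            t += 1
    return max(busy_until)

print(pipeline_ticks(4, 8))   # 11 == 4 + 8 - 1, vs 32 if run sequentially
```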
Scaling Laws (Chinchilla)#
For compute-optimal training:
$$N_\text{opt} \propto C^{0.5}, \qquad D_\text{opt} \propto C^{0.5}$$
Where $N$ = model parameters, $D$ = training tokens, $C$ = compute FLOPs.
Chinchilla rule: ~20 tokens per parameter. GPT-3 (175B params, ~300B tokens) was undertrained by this rule; Chinchilla (70B params, ~1.4T tokens) outperformed it at a similar compute budget.
Compute: training costs $\approx 6ND$ FLOPs total — the forward pass is $\approx 2ND$ (2 FLOPs per multiply-add, per parameter, per token), and the backward pass costs about twice the forward, giving $\times 3$ overall.
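Combining $C \approx 6ND$ with the 20-tokens-per-parameter rule gives a closed form: $C = 6N(20N) = 120N^2$, so $N_\text{opt} = \sqrt{C/120}$ and $D_\text{opt} = 20N_\text{opt}$. A quick sanity check against Chinchilla's own sizing (the function name is illustrative):

```python
import math

# C = 6*N*D and D = 20*N  =>  C = 120*N^2  =>  N = sqrt(C/120), D = 20*N.
def chinchilla_optimal(C_flops: float):
    N = math.sqrt(C_flops / 120)
    return N, 20 * N

# Budget of ~70B params x 1.4T tokens: C = 6 * 70e9 * 1.4e12 = 5.88e23 FLOPs
N, D = chinchilla_optimal(5.88e23)
print(f"N = {N:.2e} params, D = {D:.2e} tokens")  # recovers 7e10 and 1.4e12
```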