# Attention & Transformers

## Scaled Dot-Product Attention
Given queries $Q$, keys $K$, values $V$:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$
- $Q \in \mathbb{R}^{n \times d_k}$, $K \in \mathbb{R}^{m \times d_k}$, $V \in \mathbb{R}^{m \times d_v}$
- Division by $\sqrt{d_k}$ prevents softmax saturation in high dimensions
- Output: weighted sum of values, weights = similarity(query, key)
Complexity: $O(n^2 d)$ time and memory in sequence length $n$.
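The formula above can be sketched directly in NumPy; this is a minimal single-query-set illustration (no batching, no masking), not a production implementation:

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max before exponentiating for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    # Q: (n, d_k), K: (m, d_k), V: (m, d_v)  ->  output: (n, d_v)
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # (n, m) similarity logits
    weights = softmax(scores, axis=-1)   # each row sums to 1
    return weights @ V                   # weighted sum of values
```

Each output row is a convex combination of value rows, with weights given by query-key similarity, matching the shapes listed above.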
## Multi-Head Attention
Run $h$ attention heads in parallel, concatenate outputs:
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O$$
$$\text{head}_i = \text{Attention}(QW_i^Q,\ KW_i^K,\ VW_i^V)$$
Each head learns different aspects of relationships. Typical: $h = 8$–$32$ heads, $d_k = d_\text{model}/h$.
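A minimal NumPy sketch of self-attention with $h$ heads, assuming square projection matrices of shape $(d_\text{model}, d_\text{model})$ and $d_k = d_\text{model}/h$ as stated above; the head split is implemented as a reshape rather than $h$ separate weight matrices:

```python
import numpy as np

def multi_head_attention(x, W_q, W_k, W_v, W_o, h):
    # x: (n, d_model); W_q, W_k, W_v, W_o: (d_model, d_model)
    n, d_model = x.shape
    d_k = d_model // h

    def split_heads(t):
        # (n, d_model) -> (h, n, d_k): one slice of channels per head
        return t.reshape(n, h, d_k).transpose(1, 0, 2)

    Q, K, V = split_heads(x @ W_q), split_heads(x @ W_k), split_heads(x @ W_v)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)   # (h, n, n)
    scores -= scores.max(axis=-1, keepdims=True)       # stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    heads = weights @ V                                # (h, n, d_k)
    concat = heads.transpose(1, 0, 2).reshape(n, d_model)
    return concat @ W_o                                # output projection
```

All heads attend over the same sequence but in different learned subspaces; the final `W_o` mixes their outputs back into $d_\text{model}$ dimensions.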
## Transformer Block
```
x = x + MultiHeadAttention(LayerNorm(x))  # residual + attention
x = x + FFN(LayerNorm(x))                 # residual + feed-forward
```

FFN = Linear($d_\text{model} \to 4d_\text{model}$) → GELU → Linear($4d_\text{model} \to d_\text{model}$)
Pre-norm (shown above) is more stable than post-norm for deep models.
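The pre-norm block can be made concrete in a few lines of NumPy. This is a sketch assuming a caller-supplied `attn` callable and the tanh approximation of GELU; learnable LayerNorm scale/shift parameters are omitted for brevity:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each token's features to zero mean, unit variance.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def transformer_block(x, attn, W1, b1, W2, b2):
    # attn: any self-attention callable; W1: (d, 4d), W2: (4d, d)
    x = x + attn(layer_norm(x))        # pre-norm residual attention
    h = gelu(layer_norm(x) @ W1 + b1)  # expand to 4 * d_model
    return x + h @ W2 + b2             # project back, second residual
```

Note that normalization happens *inside* each residual branch, so the residual stream itself is never normalized; this is what makes pre-norm stable at depth.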
## Positional Encoding
Attention is permutation-invariant — need to inject position info.
Sinusoidal (original): $\text{PE}(\text{pos}, 2i) = \sin(\text{pos} / 10000^{2i/d_\text{model}})$, $\text{PE}(\text{pos}, 2i+1) = \cos(\text{pos} / 10000^{2i/d_\text{model}})$
Learned: trainable embedding per position (GPT, BERT).
RoPE (rotary): rotate $Q/K$ by position angle — relative positions, extrapolates better.
ALiBi: add a head-specific linear bias $-m \cdot |i-j|$ to attention logits — efficient extrapolation.
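The sinusoidal scheme above is straightforward to vectorize; a minimal sketch assuming an even $d_\text{model}$, with sines on even indices and cosines on odd indices:

```python
import numpy as np

def sinusoidal_pe(n_pos, d_model):
    # Assumes d_model is even.
    pos = np.arange(n_pos)[:, None]            # (n_pos, 1)
    i = np.arange(d_model // 2)[None, :]       # (1, d_model/2)
    angles = pos / 10000 ** (2 * i / d_model)  # (n_pos, d_model/2)
    pe = np.zeros((n_pos, d_model))
    pe[:, 0::2] = np.sin(angles)  # PE(pos, 2i)
    pe[:, 1::2] = np.cos(angles)  # PE(pos, 2i+1)
    return pe
```

Each dimension pair oscillates at a different frequency, so every position gets a unique fingerprint and relative offsets correspond to fixed linear transformations.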
## Encoder vs. Decoder

| | Encoder | Decoder |
|---|---|---|
| Attention mask | full (all-to-all) | causal (left-only) |
| Use | representations (BERT) | generation (GPT) |
| Cross-attention | — | attends to encoder output |
Encoder-decoder (T5, BART): encoder processes input, decoder generates output.
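The causal mask in the table above is usually applied by setting future-position logits to $-\infty$ before the softmax; a minimal NumPy sketch:

```python
import numpy as np

def causal_mask(n):
    # True above the diagonal marks future positions to block.
    return np.triu(np.ones((n, n), dtype=bool), k=1)

def masked_softmax(scores):
    # scores: (n, n) attention logits; future entries get -inf.
    n = scores.shape[-1]
    scores = np.where(causal_mask(n), -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(scores)          # exp(-inf) = 0: masked positions get zero weight
    return e / e.sum(axis=-1, keepdims=True)
```

After masking, row $i$ distributes all of its attention over positions $j \le i$, so generation at step $i$ cannot peek at later tokens.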
## Complexity vs. RNN

| | Transformer | RNN |
|---|---|---|
| Parallel training | yes | no (sequential) |
| Context length | $O(n^2)$ memory | $O(1)$ memory, $O(n)$ time |
| Long-range deps | strong | weak (vanishing gradient) |
## Key Variants
| Model | Type | Notes |
|---|---|---|
| BERT | encoder | masked LM, bidirectional |
| GPT series | decoder | autoregressive LM |
| T5 | enc-dec | text-to-text unified |
| LLaMA | decoder | efficient, open weights |
| PaLM | decoder | pathways, massive scale |