Attention & Transformers#

Scaled Dot-Product Attention#

Given queries $Q$, keys $K$, values $V$:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

  • $Q \in \mathbb{R}^{n \times d_k}$, $K \in \mathbb{R}^{m \times d_k}$, $V \in \mathbb{R}^{m \times d_v}$
  • Division by $\sqrt{d_k}$ prevents softmax saturation in high dimensions
  • Output: weighted sum of values, weights = similarity(query, key)

Complexity: $O(n^2 d)$ time and $O(n^2)$ memory (for the attention matrix) in sequence length $n$.
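A minimal NumPy sketch of the formula above, using the shapes as defined (no batch dimension, no masking):

```python
import numpy as np

def softmax(x, axis=-1):
    # subtract the row max for numerical stability
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # (n, m) similarity logits
    weights = softmax(scores, axis=-1)   # each row sums to 1
    return weights @ V                   # (n, d_v) weighted sum of values

rng = np.random.default_rng(0)
n, m, d_k, d_v = 4, 6, 8, 5
Q = rng.normal(size=(n, d_k))
K = rng.normal(size=(m, d_k))
V = rng.normal(size=(m, d_v))
out = attention(Q, K, V)
print(out.shape)  # (4, 5)
```

The $O(n^2)$ memory cost is visible here: `scores` materializes one logit per query-key pair.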

Multi-Head Attention#

Run $h$ attention heads in parallel, concatenate outputs:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O$$

$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$

Each head learns different aspects of relationships. Typical: $h = 8$–$32$ heads, $d_k = d_\text{model}/h$.
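A NumPy sketch of the head split for self-attention (so $Q$, $K$, $V$ all come from the same input $x$); the fused projection matrices and their shapes are illustrative assumptions:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x, W_q, W_k, W_v, W_o, h):
    """Split d_model into h heads, attend per head, concat, project."""
    n, d_model = x.shape
    d_k = d_model // h                             # per-head dimension
    Q, K, V = x @ W_q, x @ W_k, x @ W_v            # (n, d_model) each
    # reshape to (h, n, d_k): one independent attention problem per head
    Q = Q.reshape(n, h, d_k).transpose(1, 0, 2)
    K = K.reshape(n, h, d_k).transpose(1, 0, 2)
    V = V.reshape(n, h, d_k).transpose(1, 0, 2)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)  # (h, n, n)
    out = softmax(scores) @ V                         # (h, n, d_k)
    out = out.transpose(1, 0, 2).reshape(n, d_model)  # concat heads
    return out @ W_o                                  # output projection

rng = np.random.default_rng(0)
n, d_model, h = 5, 16, 4
x = rng.normal(size=(n, d_model))
Ws = [rng.normal(size=(d_model, d_model)) * 0.1 for _ in range(4)]
y = multi_head_attention(x, *Ws, h=h)
print(y.shape)  # (5, 16)
```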

Transformer Block#

```
x = x + MultiHeadAttention(LayerNorm(x))   # residual + attention
x = x + FFN(LayerNorm(x))                  # residual + feed-forward
```

FFN = Linear($d_\text{model} \to 4d_\text{model}$) → GELU → Linear($4d_\text{model} \to d_\text{model}$)

Pre-norm (shown above) is more stable than post-norm for deep models.
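A runnable sketch of the pre-norm block, assuming an unprojected single-head self-attention (just enough to exercise the residual wiring) and the tanh approximation of GELU:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def self_attention(x):
    # single head, no learned projections: toy stand-in for MultiHeadAttention
    d = x.shape[-1]
    return softmax(x @ x.T / np.sqrt(d)) @ x

def transformer_block(x, W1, b1, W2, b2):
    x = x + self_attention(layer_norm(x))             # residual + attention
    x = x + gelu(layer_norm(x) @ W1 + b1) @ W2 + b2   # residual + FFN
    return x

rng = np.random.default_rng(0)
n, d = 4, 8
x = rng.normal(size=(n, d))
W1, b1 = rng.normal(size=(d, 4 * d)) * 0.1, np.zeros(4 * d)   # expand to 4*d
W2, b2 = rng.normal(size=(4 * d, d)) * 0.1, np.zeros(d)       # project back
y = transformer_block(x, W1, b1, W2, b2)
print(y.shape)  # (4, 8)
```

Note that the norms sit inside the residual branches: the skip path from `x` to the output is untouched, which is the source of pre-norm's training stability.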

Positional Encoding#

Attention is permutation-invariant — need to inject position info.

Sinusoidal (original): $\text{PE}(\text{pos}, 2i) = \sin(\text{pos} / 10000^{2i/d_\text{model}})$, $\text{PE}(\text{pos}, 2i+1) = \cos(\text{pos} / 10000^{2i/d_\text{model}})$

Learned: trainable embedding per position (GPT, BERT).

RoPE (rotary): rotate $Q/K$ by position angle — relative positions, extrapolates better.

ALiBi: add a head-specific linear bias $-m \cdot |i - j|$ to attention logits; extrapolates to longer sequences cheaply.
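The sinusoidal scheme can be computed directly from its formula; a short sketch assuming an even $d_\text{model}$:

```python
import numpy as np

def sinusoidal_pe(max_len, d_model):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d)); PE[pos, 2i+1] = cos(same angle)."""
    pos = np.arange(max_len)[:, None]        # (max_len, 1)
    i = np.arange(0, d_model, 2)[None, :]    # even indices 0, 2, 4, ...
    angles = pos / 10000 ** (i / d_model)    # (max_len, d_model // 2)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)             # even dims get sine
    pe[:, 1::2] = np.cos(angles)             # odd dims get cosine
    return pe

pe = sinusoidal_pe(50, 16)
print(pe.shape)   # (50, 16)
print(pe[0, :4])  # position 0: [0. 1. 0. 1.]
```

Each dimension pair is a sinusoid at a different wavelength, so any fixed offset between positions corresponds to a fixed rotation of the pair.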

Encoder vs. Decoder#

|                 | Encoder                | Decoder                   |
|-----------------|------------------------|---------------------------|
| Attention mask  | full (all-to-all)      | causal (left-only)        |
| Use             | representations (BERT) | generation (GPT)          |
| Cross-attention | n/a                    | attends to encoder output |

Encoder-decoder (T5, BART): encoder processes input, decoder generates output.
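The causal mask used by decoders can be sketched in NumPy: blocked logits are set to $-\infty$ so they receive zero weight after the softmax.

```python
import numpy as np

def causal_mask(n):
    """Lower-triangular mask: position i may attend to positions j <= i only."""
    return np.tril(np.ones((n, n), dtype=bool))

def masked_softmax(scores, mask):
    scores = np.where(mask, scores, -np.inf)  # blocked positions -> -inf
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

scores = np.zeros((4, 4))                 # uniform logits for illustration
w = masked_softmax(scores, causal_mask(4))
print(w[1])  # [0.5 0.5 0.  0. ]: token 1 sees only tokens 0 and 1
```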

Complexity vs. RNN#

|                   | Transformer     | RNN                         |
|-------------------|-----------------|-----------------------------|
| Parallel training | yes             | no (sequential)             |
| Context length    | $O(n^2)$ memory | $O(1)$ memory, $O(n)$ time  |
| Long-range deps   | strong          | weak (vanishing gradients)  |

Key Variants#

| Model      | Type    | Notes                    |
|------------|---------|--------------------------|
| BERT       | encoder | masked LM, bidirectional |
| GPT series | decoder | autoregressive LM        |
| T5         | enc-dec | text-to-text unified     |
| LLaMA      | decoder | efficient, open weights  |
| PaLM       | decoder | Pathways, massive scale  |