Policy Gradient Methods#

Policy Gradient Theorem#

Directly optimize parameterized policy $\pi_\theta$:

$$\nabla_\theta J(\theta) = \mathbb{E}_{\pi}\!\left[\nabla_\theta \log \pi_\theta(a_t \mid s_t) \cdot G_t\right]$$

where $G_t = \sum_{k \geq t} \gamma^{k-t} r_k$ is the return.

REINFORCE (Monte Carlo PG):

$$\theta \leftarrow \theta + \alpha \, \nabla_\theta \log \pi_\theta(a_t \mid s_t) \cdot G_t$$

The Monte Carlo estimator has high variance; subtracting a baseline reduces it.
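The update above can be sketched for a single episode. This is a minimal illustration, not a library API: the caller is assumed to supply $\nabla_\theta \log \pi_\theta(a_t \mid s_t)$ for each step, and all names are hypothetical.

```python
import numpy as np

def reinforce_update(theta, grad_log_pi, rewards, alpha=0.01, gamma=0.99):
    """One REINFORCE update from one episode (illustrative sketch).

    grad_log_pi: list of arrays, grad_theta log pi(a_t | s_t) per step
    rewards: list of scalar rewards r_t
    """
    T = len(rewards)
    # Returns G_t = sum_{k >= t} gamma^{k-t} r_k, computed backwards.
    G = np.zeros(T)
    running = 0.0
    for t in reversed(range(T)):
        running = rewards[t] + gamma * running
        G[t] = running
    # Gradient ascent: theta += alpha * grad log pi(a_t|s_t) * G_t for each t.
    for t in range(T):
        theta = theta + alpha * grad_log_pi[t] * G[t]
    return theta
```

With `gamma=1` the returns reduce to plain reward sums, which makes the variance problem easy to see: early steps are credited with the entire episode's reward.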

Baseline#

Subtract baseline $b(s)$ (e.g., $V(s)$) to reduce variance without bias:

$$\nabla_\theta J \propto \mathbb{E}\!\left[\nabla_\theta \log \pi_\theta(a \mid s) \cdot (Q(s,a) - b(s))\right]$$

Advantage: $A(s,a) = Q(s,a) - V(s)$ — measures how much better action $a$ is than average.

Actor-Critic#

  • Actor: policy $\pi_\theta(a \mid s)$
  • Critic: value function $V_\phi(s)$ estimates baseline

$$\theta \leftarrow \theta + \alpha \, \nabla_\theta \log \pi_\theta(a_t \mid s_t) \cdot \delta_t$$

where $\delta_t = r_t + \gamma V_\phi(s_{t+1}) - V_\phi(s_t)$ is the TD error, which serves as a one-step estimate of the advantage.
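A single online transition of this scheme can be sketched as follows. The function name, the separate actor/critic learning rates, and the convention that the caller supplies the gradients are all assumptions for illustration.

```python
import numpy as np

def actor_critic_step(theta, w, grad_log_pi, v, v_next, r, grad_v,
                      alpha=0.01, beta=0.05, gamma=0.99):
    """One online actor-critic update at transition (s_t, a_t, r_t, s_{t+1}).

    v, v_next: critic estimates V_w(s_t) and V_w(s_{t+1}) (use 0 if terminal)
    grad_log_pi: grad_theta log pi(a_t | s_t)
    grad_v: grad_w V_w(s_t)
    """
    delta = r + gamma * v_next - v                # TD error = advantage estimate
    theta = theta + alpha * delta * grad_log_pi   # actor: policy gradient step
    w = w + beta * delta * grad_v                 # critic: semi-gradient TD(0)
    return theta, w, delta
```

Note the critic uses a *semi-gradient*: the bootstrap target $r_t + \gamma V_\phi(s_{t+1})$ is treated as a constant.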

PPO (Proximal Policy Optimization)#

Clip policy ratio to prevent large updates:

$$r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{old}}(a_t \mid s_t)}$$

$$L_\text{CLIP} = \mathbb{E}\!\left[\min\!\left(r_t \hat{A}_t,\; \text{clip}(r_t,\, 1-\varepsilon,\, 1+\varepsilon)\,\hat{A}_t\right)\right], \quad \varepsilon = 0.2$$

The full objective also subtracts a value-function loss and adds an entropy bonus for exploration: $L = L_\text{CLIP} - c_1 L_\text{VF} + c_2 S[\pi_\theta]$

Widely used: OpenAI Five, AlphaCode, RLHF for LLMs.
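The clipped objective is only a few lines of NumPy. A hedged sketch, assuming per-sample log-probabilities and advantages are already computed (array names are illustrative):

```python
import numpy as np

def ppo_clip_loss(log_pi, log_pi_old, adv, eps=0.2):
    """Clipped surrogate objective (to be maximized), averaged over a batch."""
    ratio = np.exp(log_pi - log_pi_old)            # r_t(theta)
    clipped = np.clip(ratio, 1.0 - eps, 1.0 + eps)
    # Pessimistic min: clipping only ever lowers the objective.
    return np.mean(np.minimum(ratio * adv, clipped * adv))
```

The `min` makes the bound one-sided: for positive advantages the ratio is capped at $1+\varepsilon$, while for negative advantages the unclipped (worse) term is kept, so the objective never rewards moving far from $\pi_{\theta_\text{old}}$.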

TRPO (Trust Region PO)#

Constrain KL divergence between old and new policy:

$$\max_\theta \; \mathbb{E}\left[r_t(\theta)\,\hat{A}_t\right] \quad \text{s.t.} \quad \mathbb{E}\left[\text{KL}\left(\pi_{\theta_\text{old}} \,\|\, \pi_\theta\right)\right] \leq \delta$$

PPO is a simpler first-order approximation of TRPO.

GAE (Generalized Advantage Estimation)#

Exponentially weighted sum of TD errors:

$$\hat{A}_t^{\text{GAE}(\lambda)} = \sum_{l \geq 0} (\gamma\lambda)^l \, \delta_{t+l}$$

  • $\lambda = 1$: high variance, low bias (Monte Carlo)
  • $\lambda = 0$: low variance, high bias (TD(0))
  • $\lambda = 0.95$: typical default
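In practice the weighted sum is computed with the backward recursion $\hat{A}_t = \delta_t + \gamma\lambda\,\hat{A}_{t+1}$, which follows directly from the definition. A sketch (the array conventions are assumptions):

```python
import numpy as np

def gae(rewards, values, gamma=0.99, lam=0.95):
    """GAE advantages via the backward recursion A_t = delta_t + gamma*lam*A_{t+1}.

    rewards: r_0 .. r_{T-1}
    values:  V(s_0) .. V(s_T), one longer than rewards (V(s_T) = 0 if terminal)
    """
    T = len(rewards)
    adv = np.zeros(T)
    running = 0.0
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values[t + 1] - values[t]  # TD error
        running = delta + gamma * lam * running
        adv[t] = running
    return adv
```

Setting `lam=1.0` recovers Monte Carlo returns minus the baseline; `lam=0.0` recovers the one-step TD error, matching the bias-variance endpoints listed above.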