Policy Gradient Methods

What

A family of RL algorithms that directly parameterize the policy as a neural network $\pi_\theta(a|s)$ and optimize it by gradient ascent on expected return. Instead of learning a value function and deriving a policy from it (like Q-learning), policy gradient methods learn the policy itself.

The policy network outputs a probability distribution over actions given a state. Training adjusts the network weights to make high-reward actions more probable.

Why It Matters

  • Continuous action spaces: Q-learning needs $\max_a Q(s,a)$, which requires enumerating all actions. Policy gradients output continuous actions directly (joint torques, steering angles)
  • Stochastic policies: some problems require randomized strategies (rock-paper-scissors, exploration). Policy gradients naturally represent distributions
  • Foundation for modern RL: PPO, the dominant RL algorithm (used in robotics, game playing, and RLHF for LLMs), is a policy gradient method
  • Theoretical guarantees: the policy gradient theorem gives an exact expression for the gradient of expected return; the only approximation comes from estimating that expectation with samples

How It Works

The Policy Gradient Theorem

The objective is to maximize expected return:

J(θ) = E_τ~π_θ [R(τ)]

where $\tau$ is a trajectory (sequence of states and actions) and $R(\tau)$ is its total reward.

The policy gradient theorem gives the gradient:

∇J(θ) = E [Σ_t ∇log π_θ(a_t|s_t) · G_t]

where G_t = Σ_{k=t}^{T} γ^{k-t} · r_k  (discounted return from step t)

Intuition: $\nabla \log \pi_\theta(a_t|s_t)$ points in the direction that increases the probability of action $a_t$. We scale it by $G_t$: if the return was high, push harder in that direction. If the return was low (or negative), push away.
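This intuition can be checked numerically. A minimal sketch (all names here are illustrative, not from the text): one gradient-ascent step on a two-action softmax policy, showing that a positive return raises the probability of the sampled action.

```python
import torch

logits = torch.zeros(2, requires_grad=True)   # pi = [0.5, 0.5]
action, G = 0, 10.0                           # suppose action 0 earned return 10

# Score function: gradient of log pi(a|s), scaled by the return G
log_prob = torch.log_softmax(logits, dim=-1)[action]
(log_prob * G).backward()

with torch.no_grad():
    logits += 0.1 * logits.grad               # gradient ASCENT on J
probs = torch.softmax(logits, dim=-1)
print(probs[0].item())                        # > 0.5: action 0 is now more likely
```

With a negative return the same update would push the probability of action 0 below 0.5 instead.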

The REINFORCE Algorithm

The simplest policy gradient method (Williams, 1992):

1. Collect a complete episode using current policy π_θ
2. For each timestep t:
   a. Compute return G_t = Σ_{k=t}^{T} γ^{k-t} · r_k
   b. Compute policy gradient: ∇log π_θ(a_t|s_t) · G_t
3. Update θ ← θ + α · (1/T) Σ_t ∇log π_θ(a_t|s_t) · G_t
4. Repeat
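The return computation in step 2a is usually done with a single backward pass over the episode, using the recursion G_t = r_t + γ·G_{t+1}. A standalone sketch (function name is illustrative):

```python
def discounted_returns(rewards, gamma=0.99):
    """Compute G_t for every timestep by accumulating backwards."""
    G = 0.0
    out = []
    for r in reversed(rewards):
        G = r + gamma * G   # G_t = r_t + gamma * G_{t+1}
        out.append(G)
    return out[::-1]        # restore chronological order

print(discounted_returns([1.0, 1.0, 1.0], gamma=0.5))  # [1.75, 1.5, 1.0]
```

The backward pass makes this O(T) rather than the O(T²) of summing γ^{k-t}·r_k separately for each t.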

REINFORCE is Monte Carlo: it uses complete episodes (no bootstrapping). This means unbiased gradients but high variance.

Variance Reduction with Baselines

Raw REINFORCE has high variance because $G_t$ varies wildly between episodes. Subtracting a baseline $b(s_t)$ reduces variance without introducing bias:

∇J(θ) = E [Σ_t ∇log π_θ(a_t|s_t) · (G_t - b(s_t))]

The optimal baseline is approximately $V(s_t)$ (the value function). This gives the advantage:

A_t = G_t - V(s_t)

Positive advantage: action was better than average — increase its probability. Negative advantage: action was worse than average — decrease its probability.

When you learn $V(s_t)$ with a separate network, you get Actor-Critic Methods.
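The baseline idea can be sketched in isolation before looking at full actor-critic methods. A hedged sketch with illustrative names (the `value_net` architecture and the 4-dim states matching CartPole are assumptions, not from the text):

```python
import torch
import torch.nn as nn

# A learned baseline V(s), trained alongside the policy
value_net = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 1))

states = torch.randn(5, 4)      # 5 visited states (CartPole: 4-dim)
returns = torch.randn(5)        # their Monte Carlo returns G_t

values = value_net(states).squeeze(-1)
# A_t = G_t - V(s_t); detach so the policy gradient doesn't flow into V
advantages = returns - values.detach()
# V is trained separately, by regression toward the observed returns
value_loss = nn.functional.mse_loss(values, returns)
# The policy loss would then be: -(log_probs * advantages).mean()
```

The `detach()` is the important detail: the baseline must not be updated through the policy loss, or it stops being an unbiased variance reducer.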

From REINFORCE to PPO

The evolution of policy gradient algorithms:

REINFORCE (1992)         → high variance, Monte Carlo
+ baseline               → REINFORCE with baseline, lower variance
+ learned value function → Actor-Critic (A2C)
+ trust region           → TRPO (constrained optimization, complex)
+ clipped objective      → PPO (simple, stable, the default)

PPO’s key idea: limit how much the policy changes per update. The clipped objective prevents catastrophic policy updates without the complexity of TRPO. See Actor-Critic and PPO for details.
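The clipped objective can be written in a few lines. This is a sketch of the standard PPO-Clip surrogate (not derived in this note), where `ratio` is π_new(a|s) / π_old(a|s):

```python
import torch

def ppo_clip_loss(log_probs_new, log_probs_old, advantages, eps=0.2):
    """Clipped surrogate: cap the incentive to move probability ratios
    outside [1 - eps, 1 + eps]."""
    ratio = torch.exp(log_probs_new - log_probs_old)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantages
    return -torch.min(unclipped, clipped).mean()  # minimize the negative

adv = torch.tensor([1.0])
same = ppo_clip_loss(torch.tensor([0.0]), torch.tensor([0.0]), adv)  # ratio = 1
far = ppo_clip_loss(torch.tensor([1.0]), torch.tensor([0.0]), adv)   # ratio = e
print(same.item(), far.item())  # the second gain is capped at 1 + eps
```

Because the objective flattens once the ratio leaves the clip range, the gradient there vanishes and a single batch cannot push the policy arbitrarily far from the one that collected the data.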

Code Example

REINFORCE on CartPole (PyTorch)

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import gymnasium as gym
import numpy as np
 
class PolicyNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, action_dim),
        )
 
    def forward(self, state):
        return torch.softmax(self.net(state), dim=-1)
 
def reinforce(env_name="CartPole-v1", episodes=500, gamma=0.99, lr=1e-2):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
 
    policy = PolicyNetwork(state_dim, action_dim)
    optimizer = optim.Adam(policy.parameters(), lr=lr)
 
    returns_history = []
 
    for ep in range(episodes):
        state, _ = env.reset()
        log_probs = []
        rewards = []
 
        # --- Collect episode ---
        done = False
        while not done:
            state_t = torch.tensor(state, dtype=torch.float32)
            probs = policy(state_t)
            dist = Categorical(probs)
            action = dist.sample()
            log_probs.append(dist.log_prob(action))
 
            state, reward, terminated, truncated, _ = env.step(action.item())
            rewards.append(reward)
            done = terminated or truncated
 
        # --- Compute discounted returns ---
        G = 0
        returns = []
        for r in reversed(rewards):
            G = r + gamma * G
            returns.insert(0, G)
        returns = torch.tensor(returns, dtype=torch.float32)
 
        # Normalize returns (variance reduction)
        if returns.std() > 0:
            returns = (returns - returns.mean()) / (returns.std() + 1e-8)
 
        # --- Policy gradient update ---
        loss = 0
        for log_prob, G_t in zip(log_probs, returns):
            loss -= log_prob * G_t  # negative because we minimize loss
 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
        total_reward = sum(rewards)
        returns_history.append(total_reward)
 
        if (ep + 1) % 50 == 0:
            avg = np.mean(returns_history[-50:])
            print(f"Episode {ep+1}: avg reward (last 50) = {avg:.1f}")
 
    env.close()
    return policy, returns_history
 
# Train
policy, history = reinforce(episodes=500)
# CartPole-v1 is solved when avg reward >= 475 over 100 episodes

Key Tradeoffs

| Aspect | Policy Gradient | Value-Based (DQN) |
|---|---|---|
| Action space | Continuous or discrete | Discrete only |
| Policy type | Stochastic (natural) | Deterministic (epsilon-greedy) |
| Sample efficiency | Low (on-policy, needs fresh data) | Higher (off-policy, replay buffer) |
| Variance | High (needs baselines) | Lower (bootstrapping) |
| Stability | Can be unstable without clipping | Target network helps |
| Convergence | Local optima possible | Can diverge with function approx |

On-policy vs off-policy: REINFORCE is on-policy — it can only learn from data generated by the current policy. DQN is off-policy — it learns from old data in a replay buffer, which makes it more sample-efficient. DQN's restriction to discrete actions comes from the $\max_a$ in its update, not from being off-policy.

Common Pitfalls

  • No return normalization: raw returns can have huge variance. Always normalize returns (subtract mean, divide by std) within each batch
  • Learning rate too high: policy gradient updates can be large. Start with 1e-3 or lower. PPO solves this with clipping
  • Too few episodes per update: with only 1 episode, the gradient estimate is very noisy. Use batches of episodes or switch to PPO with minibatch updates
  • Misusing torch.no_grad(): in REINFORCE as implemented above, the rollout forward pass must track gradients, because the stored log-probs are backpropagated through at update time; wrapping the rollout in no_grad() silently breaks learning. Only collect rollouts under no_grad() if you recompute log-probs during the update (as PPO does)
  • Discount factor too low: $\gamma = 0.9$ heavily discounts future rewards. For long-horizon tasks, use $\gamma = 0.99$ or $0.999$

Exercises

  1. Implement REINFORCE with a learned baseline: add a value network $V(s)$ that predicts the return, and subtract it from $G_t$ in the loss. Compare training curves with and without the baseline
  2. Modify the CartPole REINFORCE to work on LunarLander-v2 (discrete, 8-dim state). How many episodes does it need?
  3. Plot the training curve (episode reward vs episode number). Add a moving average line. At what episode does the agent consistently solve CartPole?
  4. Implement entropy regularization: add $-\beta H(\pi)$ to the loss where $H(\pi) = -\sum_a \pi(a|s) \log \pi(a|s)$. This encourages exploration. Try $\beta = 0.01$

Self-Test Questions

  1. What does the policy gradient theorem say in plain English? Why is $\nabla \log \pi$ the key quantity?
  2. Why does REINFORCE have high variance? What are two ways to reduce it?
  3. Explain why policy gradients can handle continuous action spaces but DQN cannot
  4. What is the advantage function $A(s,a)$ and why is it better than using raw returns?
  5. Trace the evolution from REINFORCE to PPO. What problem does each step solve?