PPO from Scratch

Goal: Implement Proximal Policy Optimization (PPO) from scratch in PyTorch. Understand why PPO works, how the clipped surrogate objective prevents policy collapse, and implement the full training loop.

Based on: Proximal Policy Optimization Algorithms (Schulman et al., 2017) and OpenAI SpinningUp PPO

Prerequisites: Reinforcement learning concepts, PyTorch, familiarity with policy gradients (REINFORCE)


Why PPO?

Policy gradient methods (like REINFORCE) have a fundamental problem: they can take too large policy updates, causing catastrophic drops in performance (policy collapse). PPO constrains updates to be “conservative.”

Two key innovations:

  1. Clipped surrogate objective — removes the incentive to change an action's probability beyond a small trusted range
  2. Multiple epochs with minibatches — reuses each batch of experience several times for data efficiency

The Core Problem: Policy Gradient Instability

Standard policy gradient:

∇J = E[∇log π(a|s) * A(s,a)]

The problem: if an action gets unexpectedly good rewards, policy gradient will increase its probability dramatically — even from a single lucky episode. This can cause the policy to diverge.

PPO’s insight: constrain how much the policy can change in one update.


Clipped Surrogate Objective

PPO maximizes this objective:

def ppo_objective(old_log_probs, new_log_probs, advantages, clip_epsilon=0.2):
    """
    Clipped surrogate objective.
    
    ratio = π_new(a|s) / π_old(a|s)
    
    L = min(ratio * A, clip(ratio, 1-ε, 1+ε) * A)
    
    When A > 0 (good action): don't increase probability too much
    When A < 0 (bad action): don't decrease probability too much
    """
    ratio = torch.exp(new_log_probs - old_log_probs)
    
    # Unclipped objective
    surr1 = ratio * advantages
    
    # Clipped objective  
    surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
    
    # Take the worse of the two (conservative bound)
    return -torch.min(surr1, surr2).mean()

Intuition: The min means we only benefit from the unclipped objective while the ratio stays within the clip range. Once the ratio leaves [1-ε, 1+ε] in the favorable direction (raising a good action's probability, or lowering a bad one's), the clipped term is selected and its gradient is zero, so there is no incentive to push further. Changes in the unfavorable direction are never clipped — the bound stays pessimistic.
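A quick numerical check of this behavior (`ppo_objective` is repeated here so the snippet runs standalone; the log-prob values are arbitrary):

```python
import torch

def ppo_objective(old_log_probs, new_log_probs, advantages, clip_epsilon=0.2):
    ratio = torch.exp(new_log_probs - old_log_probs)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
    return -torch.min(surr1, surr2).mean()

# Good action (A > 0) whose probability already grew past 1 + eps:
# the clipped term wins the min, so the gradient vanishes.
clipped_lp = torch.tensor([0.5], requires_grad=True)   # ratio = e^0.5 ~ 1.65 > 1.2
ppo_objective(torch.tensor([0.0]), clipped_lp, torch.tensor([1.0])).backward()

# Same setup but with the ratio still inside the clip range: the gradient survives.
inside_lp = torch.tensor([0.1], requires_grad=True)    # ratio ~ 1.11 < 1.2
ppo_objective(torch.tensor([0.0]), inside_lp, torch.tensor([1.0])).backward()
```

`clipped_lp.grad` comes out exactly zero, while `inside_lp.grad` is nonzero.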


Advantage Estimation (GAE)

We need advantages A(s,a) for the objective. Raw returns are noisy; GAE provides a bias-variance tradeoff:

def compute_gae(rewards, values, dones, next_values, gamma=0.99, lam=0.95):
    """
    Generalized Advantage Estimation.
    
    A_t = Σ_{l=0}^{∞} (γλ)^l * δ_{t+l}
    where δ_t = r_t + γ*V(s_{t+1}) - V(s_t) is the TD error
    
    Args:
        rewards: rewards at each step
        values: value estimates V(s_t)
        dones: episode termination flags
        next_values: V(s_{t+1}) estimates; only the final entry is used, to bootstrap the last step
        gamma: discount factor
        lam: GAE lambda (bias-variance tradeoff)
    
    Returns:
        advantages: GAE advantage estimates
        returns: targets for value-function fitting (advantages + values)
    """
    advantages = torch.zeros_like(rewards)
    last_advantage = 0
    
    # Work backwards (TD error accumulation)
    for t in reversed(range(len(rewards))):
        if t == len(rewards) - 1:
            next_value = next_values[t]
        else:
            next_value = values[t + 1]
        
        # TD error: r_t + γ*V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_value * (1 - dones[t]) - values[t]
        
        # GAE accumulation
        advantages[t] = last_advantage = delta + gamma * lam * (1 - dones[t]) * last_advantage
    
    returns = advantages + values
    return advantages, returns

Lambda parameter:

  • λ = 0: high bias, low variance (just TD(0))
  • λ = 1: low bias, high variance (Monte Carlo)

Typical: λ = 0.95-0.99
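To make the λ extremes concrete, here is a small worked check (the function body repeats compute_gae from above so the snippet is standalone; the reward/value numbers are arbitrary):

```python
import torch

def compute_gae(rewards, values, dones, next_values, gamma=0.99, lam=0.95):
    # Same recurrence as above, repeated for a self-contained check
    advantages = torch.zeros_like(rewards)
    last_advantage = 0.0
    for t in reversed(range(len(rewards))):
        next_value = next_values[t] if t == len(rewards) - 1 else values[t + 1]
        delta = rewards[t] + gamma * next_value * (1 - dones[t]) - values[t]
        advantages[t] = last_advantage = delta + gamma * lam * (1 - dones[t]) * last_advantage
    return advantages, advantages + values

rewards = torch.tensor([1.0, 1.0, 1.0])
values = torch.tensor([0.5, 0.4, 0.3])
dones = torch.tensor([0.0, 0.0, 1.0])        # one 3-step episode
next_values = torch.zeros(3)

# lam = 0: each advantage is exactly the one-step TD error delta_t
adv_td, _ = compute_gae(rewards, values, dones, next_values, lam=0.0)

# lam = 1: each advantage is the discounted return minus V(s_t) (Monte Carlo)
adv_mc, _ = compute_gae(rewards, values, dones, next_values, lam=1.0)
mc_return_0 = 1.0 + 0.99 * 1.0 + 0.99**2 * 1.0   # discounted return from t=0
```

With λ = 0 the advantages are the raw TD errors (0.896, 0.897, 0.7 for these numbers); with λ = 1 the first advantage equals the full discounted return minus V(s_0).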


Actor-Critic Architecture

PPO uses both a policy (actor) and value function (critic):

import numpy as np
import torch
import torch.nn as nn
 
class ActorCritic(nn.Module):
    """Shared backbone, separate heads for policy and value."""
    
    def __init__(self, obs_dim, act_dim, hidden_dim=64):
        super().__init__()
        
        # Shared feature extractor
        self.shared = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
        )
        
        # Actor head (policy)
        self.actor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, act_dim),
            nn.Softmax(dim=-1),
        )
        
        # Critic head (value function)
        self.critic = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        )
    
    def forward(self, x):
        features = self.shared(x)
        return self.actor(features), self.critic(features)
    
    def get_action(self, obs):
        """Sample action from policy."""
        probs, value = self.forward(obs)
        dist = torch.distributions.Categorical(probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        return action, log_prob, value.squeeze(-1)
    
    def evaluate(self, obs, actions):
        """Evaluate actions for PPO update."""
        probs, values = self.forward(obs)
        dist = torch.distributions.Categorical(probs)
        log_probs = dist.log_prob(actions)
        entropy = dist.entropy()
        return log_probs, values.squeeze(-1), entropy
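The `get_action`/`evaluate` pair relies on `torch.distributions.Categorical` reproducing the same log-probabilities when rebuilt from the same probabilities — that is what lets the rollout store `old_log_probs` as plain floats. A minimal check:

```python
import torch

probs = torch.softmax(torch.randn(4, 3), dim=-1)   # 4 states, 3 discrete actions
dist = torch.distributions.Categorical(probs)

actions = dist.sample()                  # what get_action stores in the rollout
stored_log_probs = dist.log_prob(actions)

# During the PPO update, evaluate() rebuilds the distribution and must
# give back the same log-probs for the same (obs, action) pairs
recomputed = torch.distributions.Categorical(probs).log_prob(actions)
```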

PPO Update

def ppo_update(policy, optimizer, rollout, clip_epsilon=0.2, 
               entropy_coef=0.01, value_coef=0.5, max_grad_norm=0.5):
    """
    Perform PPO update on collected rollout.
    
    Args:
        policy: ActorCritic network
        optimizer: optimizer
        rollout: dict with obs, actions, log_probs, values, rewards, dones
        clip_epsilon: PPO clipping parameter
        entropy_coef: entropy bonus coefficient
        value_coef: value loss coefficient
    """
    obs = torch.tensor(np.array(rollout['obs']), dtype=torch.float32)
    actions = torch.tensor(rollout['actions'])
    old_log_probs = torch.tensor(rollout['log_probs'])
    values = torch.tensor(rollout['values'])
    rewards = torch.tensor(rollout['rewards'], dtype=torch.float32)
    dones = torch.tensor(rollout['dones'], dtype=torch.float32)
    
    # Compute advantages
    with torch.no_grad():
        # Bootstrap value for the final state (0 if the episode ended there).
        # Note: obs[-1] is the last *stored* observation; a stricter
        # implementation would store the post-step observation instead.
        last_value = 0.0 if dones[-1] else policy(obs[-1:])[1].item()
        # V(s_{t+1}) at each step: values shifted left, bootstrap at the end
        next_values = torch.cat([values[1:], torch.tensor([last_value])])
        advantages, returns = compute_gae(
            rewards, values, dones, next_values, gamma=0.99, lam=0.95
        )
    
    # Normalize advantages (helps training stability)
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    
    # Multiple epochs of updates
    for _ in range(10):  # ppo_epochs
        # Get current policy distribution
        log_probs, values_pred, entropy = policy.evaluate(obs, actions)
        
        # PPO policy loss (clipped surrogate)
        policy_loss = ppo_objective(old_log_probs, log_probs, advantages, clip_epsilon)
        
        # Value function loss
        value_loss = torch.nn.functional.mse_loss(values_pred, returns)
        
        # Entropy bonus (encourages exploration)
        entropy_loss = -entropy.mean()
        
        # Total loss
        loss = policy_loss + value_coef * value_loss + entropy_coef * entropy_loss
        
        # Gradient step
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(policy.parameters(), max_grad_norm)
        optimizer.step()

Full Training Loop

def train_ppo(env_fn, steps_per_epoch=4000, epochs=50):
    env = env_fn()
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n
    
    policy = ActorCritic(obs_dim, act_dim, hidden_dim=64)
    optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)
    
    for epoch in range(epochs):
        # Collect rollout
        rollout = {
            'obs': [], 'actions': [], 'log_probs': [],
            'values': [], 'rewards': [], 'dones': []
        }
        
        obs, _ = env.reset()
        done = False
        
        for _ in range(steps_per_epoch):
            # No gradients needed while collecting experience
            with torch.no_grad():
                action, log_prob, value = policy.get_action(torch.tensor(obs, dtype=torch.float32))
            
            rollout['obs'].append(obs)
            rollout['actions'].append(action.item())
            rollout['log_probs'].append(log_prob.item())
            rollout['values'].append(value.item())
            
            obs, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated
            
            rollout['rewards'].append(reward)
            rollout['dones'].append(done)
            
            if done:
                obs, _ = env.reset()
        
        # PPO update
        ppo_update(policy, optimizer, rollout)
        
        # Logging: per-step mean reward is uninformative on envs with
        # constant rewards, so report mean episode return in this batch
        num_episodes = max(1, sum(rollout['dones']))
        print(f"Epoch {epoch} | Avg Episode Return: {np.sum(rollout['rewards']) / num_episodes:.2f}")

Key Implementation Details

Computing the Ratio in Log Space

We compute the probability ratio from log-probabilities to avoid numerical instability:

# Numerically stable ratio computation
ratio = torch.exp(new_log_probs - old_log_probs)
 
# Instead of: ratio = new_probs / old_probs
# Because probabilities can be very small
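To see why this matters, consider two very unlikely actions — the log-prob values below are hypothetical, chosen only to force float32 underflow:

```python
import torch

# Log-probabilities of two extremely unlikely actions
old_log_p = torch.tensor([-800.0])
new_log_p = torch.tensor([-799.0])

# Direct division: both exp() calls underflow to 0, giving 0/0 = nan
naive_ratio = torch.exp(new_log_p) / torch.exp(old_log_p)

# Subtracting in log space first keeps the computation well-conditioned
stable_ratio = torch.exp(new_log_p - old_log_p)   # e^1
```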

Early Stopping with KL Divergence

SpinningUp-style implementations stop the update epochs early once an approximate KL divergence between the old and new policy exceeds a target; other variants instead add a KL penalty to the loss to keep the new policy from drifting too far:

def compute_kl_div(old_log_probs, new_log_probs):
    """Sample-based estimate of KL(old || new) for monitoring/early stopping."""
    return (old_log_probs - new_log_probs).mean()
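In practice the KL estimate gates the epoch loop. A sketch with a stand-in for `policy.evaluate` (the `target_kl = 0.015` threshold follows SpinningUp's default; the per-epoch drift here is artificial):

```python
import torch

target_kl = 0.015
old_log_probs = torch.tensor([-1.0, -0.5, -2.0])

stopped_at = None
for epoch in range(10):
    # Stand-in for re-evaluating the policy: log-probs drift down each epoch
    new_log_probs = old_log_probs - 0.01 * (epoch + 1)
    approx_kl = (old_log_probs - new_log_probs).mean().item()
    if approx_kl > 1.5 * target_kl:
        stopped_at = epoch
        break   # new policy drifted too far; end the update epochs early
```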

Learning Rate Decay

PPO is sensitive to learning rate. Many implementations use:

  • Linear decay to 0
  • Or adaptive methods like AdamW with weight decay
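Linear decay is straightforward with `torch.optim.lr_scheduler.LambdaLR`; a sketch (the `nn.Linear` stands in for the ActorCritic network):

```python
import torch
import torch.nn as nn

policy = nn.Linear(4, 2)                     # stand-in for ActorCritic
optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)

total_epochs = 50
# Scale the base lr by (1 - epoch/total_epochs): linear decay to 0
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda epoch: 1.0 - epoch / total_epochs
)

lrs = []
for epoch in range(total_epochs):
    # ... rollout collection and ppo_update(...) would go here ...
    lrs.append(optimizer.param_groups[0]['lr'])
    scheduler.step()
```

The learning rate starts at 3e-4 and reaches 0 after the final epoch's `scheduler.step()`.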

Common Pitfalls

  1. Don’t clip too aggressively (ε = 0.2 is standard; too small = slow learning, too large = instability)
  2. Normalize observations — RL is sensitive to scale
  3. Bootstrap for terminal states — set V(s_T) = 0 for done episodes
  4. Don’t reuse data too many epochs — PPO’s advantage is data efficiency; overfitting to old data hurts
  5. Entropy should decrease — if entropy hits zero too fast, you collapsed
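For pitfall 2, a running-statistics normalizer in the style used by common PPO codebases (a sketch; the update step is the standard parallel-variance formula):

```python
import numpy as np

class RunningMeanStd:
    """Track running mean/variance of observations across batches."""
    def __init__(self, shape):
        self.mean = np.zeros(shape)
        self.var = np.ones(shape)
        self.count = 1e-4          # tiny prior weight avoids division by zero

    def update(self, x):
        batch_mean, batch_var, batch_count = x.mean(axis=0), x.var(axis=0), x.shape[0]
        delta = batch_mean - self.mean
        total = self.count + batch_count
        new_mean = self.mean + delta * batch_count / total
        # Combine the two variances (parallel-variance formula)
        m2 = self.var * self.count + batch_var * batch_count \
             + delta**2 * self.count * batch_count / total
        self.mean, self.var, self.count = new_mean, m2 / total, total

    def normalize(self, x):
        return (x - self.mean) / np.sqrt(self.var + 1e-8)

np.random.seed(0)
rms = RunningMeanStd(shape=(3,))
obs_batch = np.random.randn(1000, 3) * 10 + 5   # badly scaled observations
rms.update(obs_batch)
normed = rms.normalize(obs_batch)
```

After normalization the batch is close to zero mean and unit variance, which keeps the Tanh layers of the ActorCritic network out of saturation.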

Exercises

  1. Implement GAE with different λ values and observe bias-variance tradeoff
  2. Add value function clipping — clip the value-function update around the old value estimates too (as in OpenAI Baselines' PPO2)
  3. Implement async PPO (APPO) — use multiple workers for data collection
  4. Test on harder envs — LunarLander, Walker2d, Humanoid from Gymnasium

See Also