PPO from Scratch
Goal: Implement Proximal Policy Optimization (PPO) from scratch in PyTorch. Understand why PPO works, how the clipped surrogate objective prevents policy collapse, and implement the full training loop.
Based on: Proximal Policy Optimization Algorithms (Schulman et al., 2017) and OpenAI SpinningUp PPO
Prerequisites: Reinforcement Learning concepts, PyTorch, understand policy gradients (REINFORCE)
Why PPO?
Policy gradient methods (like REINFORCE) have a fundamental problem: they can take too large policy updates, causing catastrophic drops in performance (policy collapse). PPO constrains updates to be “conservative.”
Two key innovations:
- Clipped surrogate objective — remove the incentive to push the policy ratio outside a small interval around 1 in a single update
- Multiple epochs of updates on each rollout — data efficiency
The Core Problem: Policy Gradient Instability
Standard policy gradient:
∇_θ J(θ) = E[∇_θ log π_θ(a|s) · A(s,a)]
The problem: if an action gets unexpectedly good rewards, policy gradient will increase its probability dramatically — even from a single lucky episode. This can cause the policy to diverge.
PPO’s insight: constrain how much the policy can change in one update.
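A toy sketch (not from the paper) makes the failure mode concrete: one REINFORCE-style step with an exaggerated learning rate and a large, noisy advantage estimate flips a uniform 4-action policy to near-deterministic. The step size and advantage value here are invented for illustration.

```python
import torch

# Uniform policy over 4 actions; one "lucky" sample with a large advantage
logits = torch.zeros(4, requires_grad=True)
opt = torch.optim.SGD([logits], lr=5.0)  # deliberately oversized step

probs_before = torch.softmax(logits, dim=-1).detach()

# REINFORCE loss for action 0 with a large, noisy advantage estimate of 10
loss = -torch.log_softmax(logits, dim=-1)[0] * 10.0
loss.backward()
opt.step()

probs_after = torch.softmax(logits, dim=-1).detach()
print(probs_before)  # tensor([0.2500, 0.2500, 0.2500, 0.2500])
print(probs_after)   # action 0 now has probability ≈ 1
```

Nothing in this update bounds how far the policy moves; PPO's clipped objective adds exactly that bound.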
Clipped Surrogate Objective
PPO maximizes this objective:
def ppo_objective(old_log_probs, new_log_probs, advantages, clip_epsilon=0.2):
    """
    Clipped surrogate objective.
    ratio = π_new(a|s) / π_old(a|s)
    L = min(ratio * A, clip(ratio, 1-ε, 1+ε) * A)
    When A > 0 (good action): don't increase probability too much
    When A < 0 (bad action): no extra credit for decreasing probability too much
    """
    ratio = torch.exp(new_log_probs - old_log_probs)
    # Unclipped objective
    surr1 = ratio * advantages
    # Clipped objective
    surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
    # Take the worse of the two (conservative bound)
    return -torch.min(surr1, surr2).mean()

Intuition: the min means we only benefit from the unclipped objective while the ratio stays inside the clip range. Once the ratio moves past [1-ε, 1+ε] in the direction the advantage favors, the clipped term becomes the minimum and its gradient is zero, so there is no incentive to push further.
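A small standalone check (the objective is restated inline so the snippet runs on its own) shows the zero-gradient behavior. The numbers are illustrative: one action whose probability has already grown 50% under the new policy, with a positive advantage:

```python
import torch

def ppo_objective(old_log_probs, new_log_probs, advantages, clip_epsilon=0.2):
    # Restated from above: clipped surrogate, returned as a loss (negated)
    ratio = torch.exp(new_log_probs - old_log_probs)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
    return -torch.min(surr1, surr2).mean()

# One good action (A > 0) whose probability already rose 50% (ratio = 1.5)
old_lp = torch.tensor([0.0])
new_lp = torch.log(torch.tensor([1.5])).requires_grad_(True)
adv = torch.tensor([1.0])

loss = ppo_objective(old_lp, new_lp, adv, clip_epsilon=0.2)
loss.backward()
print(loss.item())   # -1.2: the clipped term (1 + ε) * A is the minimum
print(new_lp.grad)   # tensor([0.]): no gradient to push the ratio further
```

With the ratio at 1.5 > 1 + ε, the clipped branch wins the min and the gradient through `clamp` is exactly zero.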
Advantage Estimation (GAE)
We need advantages A(s,a) for the objective. Raw returns are noisy; GAE provides a bias-variance tradeoff:
def compute_gae(rewards, values, dones, next_values, gamma=0.99, lam=0.95):
    """
    Generalized Advantage Estimation.
    A_t = Σ_{l=0}^{∞} (γλ)^l * δ_{t+l}
    where δ_t = r_t + γ*V(s_{t+1}) - V(s_t) is the TD error
    Args:
        rewards: rewards at each step
        values: value estimates V(s_t)
        dones: episode termination flags
        next_values: V(s_{t+1}) at each step; only the final entry is
            actually read, to bootstrap the last step
        gamma: discount factor
        lam: GAE lambda (bias-variance tradeoff)
    Returns:
        advantages: GAE advantage estimates
        returns: regression targets for the value function (A_t + V(s_t))
    """
    advantages = torch.zeros_like(rewards)
    last_advantage = 0.0
    # Work backwards, accumulating discounted TD errors
    for t in reversed(range(len(rewards))):
        if t == len(rewards) - 1:
            next_value = next_values[t]
        else:
            next_value = values[t + 1]
        # TD error: r_t + γ*V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_value * (1 - dones[t]) - values[t]
        # GAE recursion: A_t = δ_t + γλ * A_{t+1}
        advantages[t] = last_advantage = delta + gamma * lam * (1 - dones[t]) * last_advantage
    returns = advantages + values
    return advantages, returns

Lambda parameter:
- λ = 0: high bias, low variance (just TD(0))
- λ = 1: low bias, high variance (Monte Carlo)
Typical: λ = 0.95-0.99
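The effect of λ can be seen on a toy three-step episode (rewards and value estimates invented for illustration; the function is restated compactly so the snippet runs standalone):

```python
import torch

def compute_gae(rewards, values, dones, next_values, gamma=0.99, lam=0.95):
    # Restated from above so this snippet runs on its own
    advantages = torch.zeros_like(rewards)
    last_advantage = 0.0
    for t in reversed(range(len(rewards))):
        next_value = next_values[t] if t == len(rewards) - 1 else values[t + 1]
        delta = rewards[t] + gamma * next_value * (1 - dones[t]) - values[t]
        advantages[t] = last_advantage = delta + gamma * lam * (1 - dones[t]) * last_advantage
    return advantages, advantages + values

# Toy 3-step episode: reward 1 per step, constant value estimate 0.5
rewards = torch.tensor([1.0, 1.0, 1.0])
values = torch.tensor([0.5, 0.5, 0.5])
dones = torch.tensor([0.0, 0.0, 1.0])        # episode ends at the last step
next_values = torch.tensor([0.5, 0.5, 0.0])  # only the final entry is read

adv_td, _ = compute_gae(rewards, values, dones, next_values, lam=0.0)  # TD(0)
adv_mc, _ = compute_gae(rewards, values, dones, next_values, lam=1.0)  # Monte Carlo
print(adv_td)  # each entry is just the one-step TD error δ_t
print(adv_mc)  # each entry is the discounted return minus V(s_t)
```

At λ = 0 each advantage is a single TD error (low variance, biased by the value estimate); at λ = 1 the first step's advantage equals the full discounted return 1 + 0.99 + 0.99² minus V(s_0) (unbiased, high variance).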
Actor-Critic Architecture
PPO uses both a policy (actor) and value function (critic):
import torch
import torch.nn as nn

class ActorCritic(nn.Module):
    """Shared backbone, separate heads for policy and value."""
    def __init__(self, obs_dim, act_dim, hidden_dim=64):
        super().__init__()
        # Shared feature extractor
        self.shared = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
        )
        # Actor head (policy)
        self.actor = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, act_dim),
            nn.Softmax(dim=-1),
        )
        # Critic head (value function)
        self.critic = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        features = self.shared(x)
        return self.actor(features), self.critic(features)

    def get_action(self, obs):
        """Sample action from policy."""
        probs, value = self.forward(obs)
        dist = torch.distributions.Categorical(probs)
        action = dist.sample()
        log_prob = dist.log_prob(action)
        return action, log_prob, value.squeeze(-1)

    def evaluate(self, obs, actions):
        """Evaluate actions for PPO update."""
        probs, values = self.forward(obs)
        dist = torch.distributions.Categorical(probs)
        log_probs = dist.log_prob(actions)
        entropy = dist.entropy()
        return log_probs, values.squeeze(-1), entropy

PPO Update
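The update consumes `evaluate`'s outputs, so their shapes matter: log-probs, values, and entropies should all be flat `(batch,)` tensors. A quick standalone check (the class is restated in compact form so the snippet runs on its own; the CartPole-like dims are purely illustrative):

```python
import torch
import torch.nn as nn

class ActorCritic(nn.Module):
    # Compact restatement of the ActorCritic above
    def __init__(self, obs_dim, act_dim, hidden_dim=64):
        super().__init__()
        self.shared = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim), nn.Tanh())
        self.actor = nn.Sequential(
            nn.Linear(hidden_dim, act_dim), nn.Softmax(dim=-1))
        self.critic = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        features = self.shared(x)
        return self.actor(features), self.critic(features)

    def evaluate(self, obs, actions):
        probs, values = self.forward(obs)
        dist = torch.distributions.Categorical(probs)
        return dist.log_prob(actions), values.squeeze(-1), dist.entropy()

# Everything the PPO update touches should be shaped (batch,)
policy = ActorCritic(obs_dim=4, act_dim=2)   # CartPole-like dims
obs = torch.randn(8, 4)                      # batch of 8 observations
actions = torch.randint(0, 2, (8,))
log_probs, values, entropy = policy.evaluate(obs, actions)
print(log_probs.shape, values.shape, entropy.shape)
```

The `squeeze(-1)` on the critic output is what keeps `values` at `(batch,)` instead of `(batch, 1)`; without it the value loss silently broadcasts.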
def ppo_update(policy, optimizer, rollout, clip_epsilon=0.2,
               entropy_coef=0.01, value_coef=0.5, max_grad_norm=0.5):
    """
    Perform PPO update on collected rollout.
    Args:
        policy: ActorCritic network
        optimizer: optimizer
        rollout: dict with obs, actions, log_probs, values, rewards, dones
        clip_epsilon: PPO clipping parameter
        entropy_coef: entropy bonus coefficient
        value_coef: value loss coefficient
        max_grad_norm: gradient clipping threshold
    """
    # Convert the rollout lists to tensors
    obs = torch.tensor(rollout['obs'], dtype=torch.float32)
    actions = torch.tensor(rollout['actions'])
    old_log_probs = torch.tensor(rollout['log_probs'])
    values = torch.tensor(rollout['values'])
    rewards = torch.tensor(rollout['rewards'], dtype=torch.float32)
    dones = torch.tensor(rollout['dones'], dtype=torch.float32)
    # Compute advantages
    with torch.no_grad():
        # Bootstrap value for the state after the rollout (0 if an episode ended there)
        last_value = 0.0 if dones[-1] else policy(obs[-1:])[1].item()
        # next_values[t] = V(s_{t+1}); the final entry bootstraps the last step
        next_values = torch.cat([values[1:], torch.tensor([last_value])])
        advantages, returns = compute_gae(
            rewards, values, dones, next_values, gamma=0.99, lam=0.95
        )
    # Normalize advantages (helps training stability)
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
    # Multiple epochs of updates
    for _ in range(10):  # ppo_epochs
        # Get current policy distribution
        log_probs, values_pred, entropy = policy.evaluate(obs, actions)
        # PPO policy loss (clipped surrogate)
        policy_loss = ppo_objective(old_log_probs, log_probs, advantages, clip_epsilon)
        # Value function loss
        value_loss = torch.nn.functional.mse_loss(values_pred, returns)
        # Entropy bonus (encourages exploration)
        entropy_loss = -entropy.mean()
        # Total loss
        loss = policy_loss + value_coef * value_loss + entropy_coef * entropy_loss
        # Gradient step
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(policy.parameters(), max_grad_norm)
        optimizer.step()

Full Training Loop
import numpy as np

def train_ppo(env_fn, steps_per_epoch=4000, epochs=50):
    env = env_fn()
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n
    policy = ActorCritic(obs_dim, act_dim, hidden_dim=64)
    optimizer = torch.optim.Adam(policy.parameters(), lr=3e-4)
    for epoch in range(epochs):
        # Collect rollout
        rollout = {
            'obs': [], 'actions': [], 'log_probs': [],
            'values': [], 'rewards': [], 'dones': []
        }
        obs, _ = env.reset()
        for _ in range(steps_per_epoch):
            with torch.no_grad():  # no gradients needed while collecting data
                action, log_prob, value = policy.get_action(
                    torch.tensor(obs, dtype=torch.float32))
            rollout['obs'].append(obs)
            rollout['actions'].append(action.item())
            rollout['log_probs'].append(log_prob.item())
            rollout['values'].append(value.item())
            obs, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated
            rollout['rewards'].append(reward)
            rollout['dones'].append(done)
            if done:
                obs, _ = env.reset()
        # PPO update
        ppo_update(policy, optimizer, rollout)
        # Logging
        print(f"Epoch {epoch} | Avg Reward: {np.mean(rollout['rewards']):.2f}")

Key Implementation Details
Computing the Ratio in Log Space

We compute the probability ratio from log-probabilities to avoid numerical instability:

# Numerically stable ratio computation
ratio = torch.exp(new_log_probs - old_log_probs)
# Instead of: ratio = new_probs / old_probs
# Probabilities can be very small, so the direct quotient is unstable

Early Stopping with KL Divergence
Some implementations monitor the KL divergence between the old and new policy during the update epochs and stop early once it exceeds a target (commonly around 0.01–0.02), preventing the new policy from drifting too far:

def compute_kl_div(old_log_probs, new_log_probs):
    """Sample-based estimate of KL(old || new), for monitoring."""
    return (old_log_probs - new_log_probs).mean()

Learning Rate Decay
PPO is sensitive to learning rate. Many implementations use:
- Linear decay of the learning rate to 0 over the course of training
- Or an adaptive optimizer such as Adam with a small constant rate (3e-4 is a common default)
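A linear-decay schedule is a one-liner with `torch.optim.lr_scheduler.LambdaLR`; a sketch, with a placeholder model, optimizer, and epoch count standing in for the real training objects:

```python
import torch

model = torch.nn.Linear(4, 2)  # stand-in for the ActorCritic
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
total_epochs = 50

# Linearly scale the LR from its initial value down to 0 over training
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda epoch: 1 - epoch / total_epochs)

for epoch in range(total_epochs):
    # ... collect rollout and call ppo_update here ...
    scheduler.step()

print(optimizer.param_groups[0]['lr'])  # 0 after the final epoch
```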
Common Pitfalls
- Don’t clip too aggressively (ε = 0.2 is standard; too small = slow learning, too large = instability)
- Normalize observations — RL is sensitive to scale
- Don’t bootstrap past terminal states — use V = 0 after a true termination (bootstrapping is still appropriate after a time-limit truncation)
- Don’t reuse data for too many epochs — the rollout was collected by the old policy, and as the new policy drifts the clipped objective stops being a valid local approximation
- Entropy should decrease gradually — if it hits zero early, the policy has prematurely collapsed to a deterministic one
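The entropy pitfall is easy to monitor, since `Categorical.entropy()` gives it directly. A toy comparison of a healthy and a collapsed 4-action policy (probabilities invented for illustration):

```python
import torch

# Entropy of the action distribution is a cheap collapse detector
uniform = torch.distributions.Categorical(
    probs=torch.tensor([0.25, 0.25, 0.25, 0.25]))
peaked = torch.distributions.Categorical(
    probs=torch.tensor([0.997, 0.001, 0.001, 0.001]))

print(uniform.entropy())  # ln(4) ≈ 1.386 — healthy exploration
print(peaked.entropy())   # near 0 — policy has (almost) collapsed
```

Logging the mean entropy returned by `evaluate` every epoch gives the same signal on real training data.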
Exercises
- Implement GAE with different λ values and observe bias-variance tradeoff
- Add value function clipping (as in many reference implementations) — clip value function changes too
- Implement async PPO (APPO) — use multiple workers for data collection
- Test on harder envs — LunarLander, Walker2d, Humanoid from Gymnasium
See Also
- Reinforcement Learning Roadmap — concepts this builds on
- OpenAI SpinningUp PPO — the reference implementation this tutorial follows
- Proximal Policy Optimization Algorithms — original PPO paper