Model-Based RL
What
Learn a model of the environment (how the world works), then use that model to plan or generate synthetic experience. Instead of learning only from real interactions, the agent can “imagine” what would happen and learn from those imagined trajectories.
The fundamental tradeoff: model-free RL is simple but sample-hungry. Model-based RL is sample-efficient but only as good as the model.
Model-free vs model-based
| Aspect | Model-free | Model-based |
|---|---|---|
| Learns a model? | No | Yes: f(s,a) → s’, r |
| Sample efficiency | Low (needs millions of steps) | High (can learn from fewer real interactions) |
| Compute per step | Low | High (planning/simulation) |
| Asymptotic performance | Limited only by policy/value capacity | Limited by model accuracy |
| When to use | Simulator available, cheap interactions | Real-world (robots, drones), expensive/dangerous interactions |
The world model
A world model predicts what happens next:
s_{t+1}, r_t = f_theta(s_t, a_t)
Given the current state and an action, predict the next state and reward.
Types of world models
| Type | What it learns | Example |
|---|---|---|
| Deterministic | Single next state prediction | Simple dynamics models |
| Stochastic | Distribution over next states | Ensemble models, VAE-based |
| Latent | Dynamics in learned latent space | Dreamer, RSSM |
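A minimal sketch of the stochastic variant, assuming a Gaussian over next states (all names and layer sizes here are illustrative, not from any specific paper):

```python
import torch
import torch.nn as nn

class StochasticWorldModel(nn.Module):
    """Predict a distribution over next states instead of a point estimate."""
    def __init__(self, state_dim=4, action_dim=2, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden), nn.ReLU(),
        )
        self.mean_head = nn.Linear(hidden, state_dim)
        self.logstd_head = nn.Linear(hidden, state_dim)

    def forward(self, state, action):
        h = self.net(torch.cat([state, action], dim=-1))
        mean = self.mean_head(h)
        std = self.logstd_head(h).clamp(-5, 2).exp()  # keep std in a sane range
        return torch.distributions.Normal(mean, std)

model = StochasticWorldModel()
dist = model(torch.randn(8, 4), torch.randn(8, 2))
sample = dist.sample()  # one plausible next state per row, shape (8, 4)
```

Sampling from the predicted distribution (rather than taking the mean) keeps imagined rollouts from collapsing onto a single optimistic trajectory.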
Simple world model in PyTorch
```python
import torch
import torch.nn as nn
import numpy as np


class WorldModel(nn.Module):
    """Predict next state and reward from current state and action."""

    def __init__(self, state_dim, action_dim, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        )
        self.state_head = nn.Linear(hidden, state_dim)
        self.reward_head = nn.Linear(hidden, 1)

    def forward(self, state, action):
        """Predict (next_state, reward)."""
        x = torch.cat([state, action], dim=-1)
        h = self.net(x)
        next_state = self.state_head(h)
        reward = self.reward_head(h)
        return next_state, reward.squeeze(-1)


def train_world_model(model, buffer, n_steps=1000, batch_size=256, lr=1e-3):
    """Train the world model on collected experience."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for step in range(n_steps):
        # Sample a batch from the replay buffer
        idx = np.random.choice(len(buffer["states"]), batch_size)
        states = torch.FloatTensor(buffer["states"][idx])
        actions = torch.FloatTensor(buffer["actions"][idx])
        next_states = torch.FloatTensor(buffer["next_states"][idx])
        rewards = torch.FloatTensor(buffer["rewards"][idx])

        pred_next, pred_reward = model(states, actions)
        state_loss = nn.functional.mse_loss(pred_next, next_states)
        reward_loss = nn.functional.mse_loss(pred_reward, rewards)
        loss = state_loss + reward_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 200 == 0:
            print(f"Step {step}: state_loss={state_loss.item():.4f}, "
                  f"reward_loss={reward_loss.item():.4f}")
```

Dyna architecture
Dyna (Sutton, 1991) is the simplest model-based RL framework:
1. Interact with the real environment → store experience
2. Update the value function / policy from real experience (standard RL)
3. Train the world model on the collected experience
4. Generate N simulated transitions with the world model
5. Update the value function / policy from the simulated experience too
6. Go to 1
The key insight: simulated experience from the model supplements real experience, dramatically improving sample efficiency.
Dyna-Q implementation
```python
import numpy as np


class DynaQ:
    """Dyna-Q: Q-learning + world model for simulated planning.

    Works on discrete state/action spaces (grid world).
    """

    def __init__(self, n_states, n_actions, alpha=0.1, gamma=0.95,
                 epsilon=0.1, n_planning_steps=10):
        self.Q = np.zeros((n_states, n_actions))
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.n_planning = n_planning_steps
        # World model: store observed transitions
        self.model = {}  # (s, a) -> (r, s')

    def choose_action(self, state):
        if np.random.random() < self.epsilon:
            return np.random.randint(self.Q.shape[1])
        return np.argmax(self.Q[state])

    def update(self, state, action, reward, next_state):
        """Direct RL update + model learning + planning."""
        # 1. Direct RL update (standard Q-learning)
        td_target = reward + self.gamma * np.max(self.Q[next_state])
        self.Q[state, action] += self.alpha * (td_target - self.Q[state, action])
        # 2. Update the world model
        self.model[(state, action)] = (reward, next_state)
        # 3. Planning: simulate experience from the model
        keys = list(self.model.keys())
        for _ in range(self.n_planning):
            # Sample a previously visited (s, a) pair
            s, a = keys[np.random.randint(len(keys))]
            r, s_next = self.model[(s, a)]
            # Q-learning update on the simulated transition
            td_target = r + self.gamma * np.max(self.Q[s_next])
            self.Q[s, a] += self.alpha * (td_target - self.Q[s, a])


def run_dyna_q_experiment(env, n_episodes=100, n_planning_steps=10):
    """Compare Q-learning (planning=0) vs Dyna-Q (planning=N)."""
    agent = DynaQ(
        n_states=env.observation_space.n,
        n_actions=env.action_space.n,
        n_planning_steps=n_planning_steps,
    )
    episode_rewards = []
    for ep in range(n_episodes):
        state, _ = env.reset()
        total_reward = 0
        done = False
        while not done:
            action = agent.choose_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.update(state, action, reward, next_state)
            state = next_state
            total_reward += reward
        episode_rewards.append(total_reward)
    return episode_rewards
```

Comparing Dyna-Q vs plain Q-learning
```python
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np

# Use FrozenLake or Taxi as test environments
env = gym.make("Taxi-v3")

# Run with different numbers of planning steps
results = {}
for n_plan in [0, 5, 20, 50]:
    rewards = run_dyna_q_experiment(env, n_episodes=500, n_planning_steps=n_plan)
    # Smooth for plotting
    smoothed = np.convolve(rewards, np.ones(20) / 20, mode="valid")
    results[n_plan] = smoothed

plt.figure(figsize=(10, 6))
for n_plan, rewards in results.items():
    label = "Q-learning" if n_plan == 0 else f"Dyna-Q (plan={n_plan})"
    plt.plot(rewards, label=label)
plt.xlabel("Episode")
plt.ylabel("Reward (smoothed)")
plt.title("Dyna-Q: more planning = faster learning")
plt.legend()
plt.grid(True)
plt.savefig("dyna_q_comparison.png", dpi=150)
plt.show()
```

Dreamer (v1/v2/v3)
Dreamer learns a world model entirely in latent space (compressed representation), then trains a policy by imagining trajectories within that latent space.
How Dreamer works
- Encoder: map observations to latent states
- RSSM (Recurrent State Space Model): predict next latent state from current latent state + action
- Decoder: reconstruct observations from latent states (for model training)
- Reward predictor: predict reward from latent states
- Policy: trained entirely in imagination (latent rollouts), no real environment needed
```
Real world:  o_1, a_1, r_1, o_2, a_2, r_2, ...
                 ↓ encoder
Latent:      z_1 → z_2 → z_3 → ...   (RSSM predictions)
                 ↓ policy training
Actor-critic trained on imagined latent trajectories
```
Why latent space? Real observations (images) are high-dimensional and hard to predict accurately. Latent states are compact, and the model only needs to capture dynamics-relevant information.
Dreamer v3 (2023) works across diverse domains: Atari, continuous control, 3D environments, Minecraft.
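The "train entirely in imagination" idea can be sketched as a latent rollout that never touches the environment. This is a toy illustration, not Dreamer's actual RSSM; every module name, architecture, and size below is made up for clarity:

```python
import torch
import torch.nn as nn

class LatentDynamics(nn.Module):
    """Predict the next latent state from (latent state, action)."""
    def __init__(self, latent_dim=32, action_dim=4):
        super().__init__()
        self.cell = nn.GRUCell(action_dim, latent_dim)

    def forward(self, z, a):
        return self.cell(a, z)

def imagine_rollout(dynamics, actor, reward_head, z0, horizon=15):
    """Roll the policy forward entirely in latent space (no env calls)."""
    z, zs, rewards = z0, [], []
    for _ in range(horizon):
        a = actor(z)                     # action from the current latent state
        z = dynamics(z, a)               # predicted next latent state
        zs.append(z)
        rewards.append(reward_head(z).squeeze(-1))
    return torch.stack(zs), torch.stack(rewards)

latent_dim, action_dim = 32, 4
dynamics = LatentDynamics(latent_dim, action_dim)
actor = nn.Sequential(nn.Linear(latent_dim, 64), nn.Tanh(),
                      nn.Linear(64, action_dim), nn.Tanh())
reward_head = nn.Linear(latent_dim, 1)

z0 = torch.zeros(8, latent_dim)          # batch of 8 imagined starting points
zs, rewards = imagine_rollout(dynamics, actor, reward_head, z0)
print(zs.shape, rewards.shape)           # (15, 8, 32) and (15, 8)
```

In Dreamer the actor-critic loss is then computed on these imagined `rewards` and backpropagated through the (differentiable) latent dynamics.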
MuZero
MuZero (DeepMind, 2020) takes model-based RL to the extreme: it learns a model of the environment without ever predicting observations. It only learns what’s needed for planning.
Three learned components:
- Representation: observation → hidden state
- Dynamics: hidden state + action → next hidden state + reward
- Prediction: hidden state → policy + value (for MCTS planning)
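The three components above can be sketched as follows. This is an illustrative toy, not MuZero's architecture; the dimensions and single-linear-layer networks are stand-ins:

```python
import torch
import torch.nn as nn

class MuZeroNets(nn.Module):
    """Toy versions of MuZero's three learned functions h, g, f."""
    def __init__(self, obs_dim=8, hidden_dim=32, n_actions=4):
        super().__init__()
        self.n_actions = n_actions
        # h: observation -> hidden state
        self.representation = nn.Linear(obs_dim, hidden_dim)
        # g: (hidden state, action) -> next hidden state + reward
        self.dynamics = nn.Linear(hidden_dim + n_actions, hidden_dim + 1)
        # f: hidden state -> policy logits + value
        self.prediction = nn.Linear(hidden_dim, n_actions + 1)

    def initial_inference(self, obs):
        s = self.representation(obs)
        out = self.prediction(s)
        return s, out[..., :-1], out[..., -1]       # state, policy logits, value

    def recurrent_inference(self, s, action):
        a = nn.functional.one_hot(action, self.n_actions).float()
        out = self.dynamics(torch.cat([s, a], dim=-1))
        s_next, reward = out[..., :-1], out[..., -1]
        out = self.prediction(s_next)
        return s_next, reward, out[..., :-1], out[..., -1]

nets = MuZeroNets()
s, logits, value = nets.initial_inference(torch.randn(1, 8))
# Unroll one step in hidden-state space, as MCTS does during planning
s2, r, logits2, v2 = nets.recurrent_inference(s, torch.tensor([2]))
```

Note that nothing here reconstructs observations: the hidden state only has to support predicting policy, value, and reward, which is exactly what MCTS needs.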
MuZero achieved superhuman performance on Atari and board games (Go, chess, shogi) without being given the rules of any of these games.
When model-based wins
- Expensive real interactions: each real-world trial with a drone costs time, energy, and risk of crash. A world model lets you do 1000 simulated trials per real trial.
- Safety-critical domains: military sim, medical robotics. Can’t explore dangerous actions in the real world.
- Known structure: if the physics are well-understood (e.g., rigid body dynamics), the model is accurate and model-based planning is very effective.
When model-based fails
- Model errors compound: small prediction errors at each step accumulate over long rollouts. After 50 imagined steps, the model may be in a completely unrealistic state.
- Model exploitation: the policy finds inputs that exploit model inaccuracies (get high predicted reward for states that don’t exist in reality). This is the model-based equivalent of reward hacking.
- Complex environments: environments with many interacting objects, flexible materials, or chaotic dynamics are hard to model accurately.
Mitigation: use ensembles of models (epistemic uncertainty), limit imagined rollout length, periodically retrain the model with fresh data.
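The ensemble idea can be sketched as follows: train several models and use their disagreement as a proxy for epistemic uncertainty, truncating imagined rollouts where disagreement is high. This is a generic illustration, not a specific paper's recipe:

```python
import torch
import torch.nn as nn

def make_model(state_dim, action_dim, hidden=64):
    return nn.Sequential(
        nn.Linear(state_dim + action_dim, hidden), nn.ReLU(),
        nn.Linear(hidden, state_dim),
    )

state_dim, action_dim, n_models = 4, 2, 5
# In practice each member would be trained on different batches/bootstraps
ensemble = [make_model(state_dim, action_dim) for _ in range(n_models)]

def predict_with_uncertainty(ensemble, state, action):
    x = torch.cat([state, action], dim=-1)
    preds = torch.stack([m(x) for m in ensemble])   # (n_models, batch, state_dim)
    mean = preds.mean(dim=0)                        # ensemble prediction
    disagreement = preds.std(dim=0).mean(dim=-1)    # per-sample uncertainty proxy
    return mean, disagreement

state = torch.randn(16, state_dim)
action = torch.randn(16, action_dim)
mean, unc = predict_with_uncertainty(ensemble, state, action)
# e.g. stop an imagined rollout once `unc` exceeds a chosen threshold
```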
Connection to autonomous systems
The sim-to-real problem is fundamentally about model accuracy. See Tutorial - Sim-to-Real Transfer for practical approaches.
- Drone navigation: train in simulation (Gazebo, AirSim), transfer to real drone. The simulator IS the world model.
- Model Predictive Control (MPC): explicitly use a dynamics model to plan optimal actions over a short horizon. Used in robotics and autonomous vehicles.
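A minimal MPC sketch using random shooting: sample candidate action sequences, roll each through the dynamics model, execute only the first action of the best sequence, then replan. Here a known toy point-mass simulator stands in for a learned model; all constants are illustrative:

```python
import numpy as np

def dynamics(state, action):
    """Toy point mass: action is an acceleration."""
    pos, vel = state
    vel = vel + 0.1 * action
    pos = pos + 0.1 * vel
    return np.array([pos, vel])

def reward(state):
    return -abs(state[0] - 1.0)       # drive the position toward 1.0

def mpc_action(state, rng, horizon=10, n_candidates=200):
    """Sample candidate action sequences, score them with the model,
    and return the first action of the best-scoring sequence."""
    sequences = rng.uniform(-1, 1, size=(n_candidates, horizon))
    returns = np.zeros(n_candidates)
    for i, seq in enumerate(sequences):
        s = state.copy()
        for a in seq:
            s = dynamics(s, a)
            returns[i] += reward(s)
    return sequences[np.argmax(returns)][0]

rng = np.random.default_rng(0)
state = np.array([0.0, 0.0])
for _ in range(50):                   # receding horizon: replan every step
    state = dynamics(state, mpc_action(state, rng))
print(f"final position: {state[0]:.2f}")  # should settle near the target at 1.0
```

Because only the first action of each plan is executed, model errors beyond the short horizon matter much less than in long imagined rollouts.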
Self-test questions
- Why is model-based RL more sample-efficient than model-free RL?
- What is the Dyna architecture, and how does it combine real and simulated experience?
- Why does Dreamer learn dynamics in latent space instead of observation space?
- What is model exploitation, and how is it analogous to reward hacking?
- When would you choose model-free over model-based RL?
Exercises
- Dyna-Q comparison: Run the Dyna-Q experiment above with 0, 5, 20, 50 planning steps. Plot learning curves. How much does planning help?
- Neural world model: Train the WorldModel class on CartPole transitions. Collect 10k real transitions, train the model, then generate 10k synthetic transitions. Compare the synthetic data distribution with real data.
- Plan with learned model: Use the trained world model to do simple planning: for each state, evaluate all possible actions by predicting the next state and reward. Choose the action with highest predicted value. Does this beat random?
Links
- Q-Learning and DQN — model-free value-based alternative
- RL Fundamentals — MDP framework
- Actor-Critic and PPO — model-free policy optimization
- Tutorial - Sim-to-Real Transfer — simulation as world model
- Reward Design and Curriculum — what the model predicts rewards for