Model-Based RL

What

Learn a model of the environment (how the world works), then use that model to plan or generate synthetic experience. Instead of learning only from real interactions, the agent can “imagine” what would happen and learn from those imagined trajectories.

The fundamental tradeoff: model-free RL is simple but sample-hungry. Model-based RL is sample-efficient but only as good as the model.

Model-free vs model-based

| Aspect | Model-free | Model-based |
| --- | --- | --- |
| Learns a model? | No | Yes: f(s, a) → s′, r |
| Sample efficiency | Low (needs millions of steps) | High (can learn from fewer real interactions) |
| Compute per step | Low | High (planning/simulation) |
| Accuracy | Only limited by policy capacity | Limited by model accuracy |
| When to use | Simulator available, cheap interactions | Real world (robots, drones), expensive/dangerous interactions |

The world model

A world model predicts what happens next:

s_{t+1}, r_t = f_θ(s_t, a_t)

Given the current state and an action, predict the next state and reward.

Types of world models

| Type | What it learns | Example |
| --- | --- | --- |
| Deterministic | Single next-state prediction | Simple dynamics models |
| Stochastic | Distribution over next states | Ensemble models, VAE-based |
| Latent | Dynamics in a learned latent space | Dreamer, RSSM |

Simple world model in PyTorch

import torch
import torch.nn as nn
import numpy as np
 
class WorldModel(nn.Module):
    """Predict next state and reward from current state and action."""
 
    def __init__(self, state_dim, action_dim, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        )
        self.state_head = nn.Linear(hidden, state_dim)
        self.reward_head = nn.Linear(hidden, 1)
 
    def forward(self, state, action):
        """Predict (next_state, reward)."""
        x = torch.cat([state, action], dim=-1)
        h = self.net(x)
        next_state = self.state_head(h)
        reward = self.reward_head(h)
        return next_state, reward.squeeze(-1)
 
def train_world_model(model, buffer, n_steps=1000, batch_size=256, lr=1e-3):
    """Train world model on collected experience."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
 
    for step in range(n_steps):
        # Sample batch from replay buffer
        idx = np.random.choice(len(buffer["states"]), batch_size)
        states = torch.FloatTensor(buffer["states"][idx])
        actions = torch.FloatTensor(buffer["actions"][idx])
        next_states = torch.FloatTensor(buffer["next_states"][idx])
        rewards = torch.FloatTensor(buffer["rewards"][idx])
 
        pred_next, pred_reward = model(states, actions)
        state_loss = nn.functional.mse_loss(pred_next, next_states)
        reward_loss = nn.functional.mse_loss(pred_reward, rewards)
        loss = state_loss + reward_loss
 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
        if step % 200 == 0:
            print(f"Step {step}: state_loss={state_loss:.4f}, reward_loss={reward_loss:.4f}")
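Once trained, the model can generate imagined trajectories by feeding its own predictions back into itself. A minimal sketch (the `imagine_rollout` helper and its `policy_fn` argument are illustrative names, not library functions); it works with any model mapping `(state, action) -> (next_state, reward)`, like the `WorldModel` above:

```python
import torch

def imagine_rollout(model, start_state, policy_fn, horizon=10):
    """Generate a synthetic trajectory by unrolling the learned model.

    model:     callable (state, action) -> (next_state, reward)
    policy_fn: callable state -> action
    """
    states, actions, rewards = [], [], []
    state = start_state
    with torch.no_grad():
        for _ in range(horizon):
            action = policy_fn(state)
            next_state, reward = model(state, action)
            states.append(state)
            actions.append(action)
            rewards.append(reward)
            state = next_state  # feed the prediction back in
    return states, actions, rewards
```

Long horizons amplify model error (see "When model-based fails" below), so imagined rollouts are usually kept short in practice.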

Dyna architecture

Dyna (Sutton, 1991) is the simplest model-based RL framework:

  1. Interact with real environment → store experience
  2. Update value function / policy from real experience (standard RL)
  3. Train world model on collected experience
  4. Generate N simulated transitions with the world model
  5. Update value function / policy from simulated experience too
  6. Go to 1

The key insight: simulated experience from the model supplements real experience, dramatically improving sample efficiency.

Dyna-Q implementation

import numpy as np
 
class DynaQ:
    """Dyna-Q: Q-learning + world model for simulated planning.
    Works on discrete state/action spaces (grid world).
    """
    def __init__(self, n_states, n_actions, alpha=0.1, gamma=0.95,
                 epsilon=0.1, n_planning_steps=10):
        self.Q = np.zeros((n_states, n_actions))
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.n_planning = n_planning_steps
 
        # World model: store observed transitions
        self.model = {}  # (s, a) -> (r, s')
 
    def choose_action(self, state):
        if np.random.random() < self.epsilon:
            return np.random.randint(self.Q.shape[1])
        return np.argmax(self.Q[state])
 
    def update(self, state, action, reward, next_state):
        """Direct RL update + model learning + planning."""
        # 1. Direct RL update (standard Q-learning)
        td_target = reward + self.gamma * np.max(self.Q[next_state])
        self.Q[state, action] += self.alpha * (td_target - self.Q[state, action])
 
        # 2. Update world model
        self.model[(state, action)] = (reward, next_state)
 
        # 3. Planning: simulate experience from the model
        for _ in range(self.n_planning):
            # Sample a previously visited (s, a) pair
            s, a = list(self.model.keys())[np.random.randint(len(self.model))]
            r, s_next = self.model[(s, a)]
 
            # Q-learning update on simulated transition
            td_target = r + self.gamma * np.max(self.Q[s_next])
            self.Q[s, a] += self.alpha * (td_target - self.Q[s, a])
 
def run_dyna_q_experiment(env, n_episodes=100, n_planning_steps=10):
    """Compare Q-learning (planning=0) vs Dyna-Q (planning=N)."""
    agent = DynaQ(
        n_states=env.observation_space.n,
        n_actions=env.action_space.n,
        n_planning_steps=n_planning_steps,
    )
 
    episode_rewards = []
    for ep in range(n_episodes):
        state, _ = env.reset()
        total_reward = 0
        done = False
 
        while not done:
            action = agent.choose_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.update(state, action, reward, next_state)
            state = next_state
            total_reward += reward
 
        episode_rewards.append(total_reward)
 
    return episode_rewards

Comparing Dyna-Q vs plain Q-learning

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
 
# Use FrozenLake or Taxi as test environments
env = gym.make("Taxi-v3")
 
# Run with different planning steps
results = {}
for n_plan in [0, 5, 20, 50]:
    rewards = run_dyna_q_experiment(env, n_episodes=500, n_planning_steps=n_plan)
    # Smooth for plotting
    smoothed = np.convolve(rewards, np.ones(20)/20, mode="valid")
    results[n_plan] = smoothed
 
plt.figure(figsize=(10, 6))
for n_plan, rewards in results.items():
    label = "Q-learning" if n_plan == 0 else f"Dyna-Q (plan={n_plan})"
    plt.plot(rewards, label=label)
plt.xlabel("Episode")
plt.ylabel("Reward (smoothed)")
plt.title("Dyna-Q: more planning = faster learning")
plt.legend()
plt.grid(True)
plt.savefig("dyna_q_comparison.png", dpi=150)
plt.show()

Dreamer (v1/v2/v3)

Dreamer learns a world model entirely in latent space (compressed representation), then trains a policy by imagining trajectories within that latent space.

How Dreamer works

  1. Encoder: map observations to latent states
  2. RSSM (Recurrent State Space Model): predict next latent state from current latent state + action
  3. Decoder: reconstruct observations from latent states (for model training)
  4. Reward predictor: predict reward from latent states
  5. Policy: trained entirely in imagination (latent rollouts), no real environment needed

Real world: o_1, a_1, r_1, o_2, a_2, r_2, ...
  ↓ encoder
Latent:     z_1 → z_2 → z_3 → ...  (RSSM predictions)
  ↓ policy training
Actor-Critic: trained on imagined latent trajectories

Why latent space? Real observations (images) are high-dimensional and hard to predict accurately. Latent states are compact, and the model only needs to capture dynamics-relevant information.
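To make the idea concrete, here is a deliberately simplified sketch: a deterministic GRU latent transition standing in for Dreamer's stochastic RSSM (the class name `LatentDynamics` is made up for illustration, and the decoder and stochastic latent are omitted). The point is the shape of the approach: encode the observation once, then roll forward entirely in latent space.

```python
import torch
import torch.nn as nn

class LatentDynamics(nn.Module):
    """Toy latent world model: deterministic GRU transition + reward head.
    (Dreamer's RSSM also has a stochastic latent and a decoder; omitted here.)"""

    def __init__(self, obs_dim, action_dim, latent_dim=32):
        super().__init__()
        self.encoder = nn.Linear(obs_dim, latent_dim)   # obs -> latent
        self.cell = nn.GRUCell(action_dim, latent_dim)  # (action, latent) -> next latent
        self.reward_head = nn.Linear(latent_dim, 1)     # latent -> reward

    def encode(self, obs):
        return torch.tanh(self.encoder(obs))

    def step(self, latent, action):
        """One imagined step: no real observation needed after the initial encode."""
        next_latent = self.cell(action, latent)
        reward = self.reward_head(next_latent).squeeze(-1)
        return next_latent, reward
```

Training the policy then amounts to calling `step` repeatedly from encoded start states and backpropagating through the imagined rewards.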

Dreamer v3 (2023) works across diverse domains: Atari, continuous control, 3D environments, Minecraft.

MuZero

MuZero (DeepMind, 2020) takes model-based RL to the extreme: it learns a model of the environment without ever predicting observations. It only learns what’s needed for planning.

Three learned components:

  1. Representation: observation → hidden state
  2. Dynamics: hidden state + action → next hidden state + reward
  3. Prediction: hidden state → policy + value (for MCTS planning)
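The three components can be sketched as small networks (shapes only; `MuZeroNets` is an illustrative name, and the real system uses deep residual networks with MCTS on top):

```python
import torch
import torch.nn as nn

class MuZeroNets(nn.Module):
    """Skeleton of MuZero's three learned functions. Note there is no decoder
    anywhere: hidden states never have to reconstruct real observations."""

    def __init__(self, obs_dim, n_actions, hidden=64):
        super().__init__()
        # h: observation -> hidden state
        self.representation = nn.Sequential(nn.Linear(obs_dim, hidden), nn.ReLU())
        # g: hidden state + action -> next hidden state and reward
        self.dynamics = nn.Linear(hidden + n_actions, hidden + 1)
        # f: hidden state -> policy logits and value
        self.prediction = nn.Linear(hidden, n_actions + 1)

    def initial(self, obs):
        return self.representation(obs)

    def recurrent(self, state, action_onehot):
        out = self.dynamics(torch.cat([state, action_onehot], dim=-1))
        return out[..., :-1], out[..., -1]   # next hidden state, reward

    def predict(self, state):
        out = self.prediction(state)
        return out[..., :-1], out[..., -1]   # policy logits, value
```

Planning calls `initial` once on the real observation, then expands the MCTS tree using only `recurrent` and `predict`.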

MuZero achieved superhuman performance on Atari and board games (Go, chess, shogi) without knowing the rules of any game.

When model-based wins

  1. Expensive real interactions: each real-world trial with a drone costs time, energy, and risk of crash. A world model lets you do 1000 simulated trials per real trial.
  2. Safety-critical domains: military sim, medical robotics. Can’t explore dangerous actions in the real world.
  3. Known structure: if the physics are well-understood (e.g., rigid body dynamics), the model is accurate and model-based planning is very effective.

When model-based fails

  1. Model errors compound: small prediction errors at each step accumulate over long rollouts. After 50 imagined steps, the model may be in a completely unrealistic state.
  2. Model exploitation: the policy finds inputs that exploit model inaccuracies (get high predicted reward for states that don’t exist in reality). This is the model-based equivalent of reward hacking.
  3. Complex environments: environments with many interacting objects, flexible materials, or chaotic dynamics are hard to model accurately.

Mitigation: use ensembles of models (epistemic uncertainty), limit imagined rollout length, periodically retrain the model with fresh data.
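The ensemble idea fits in a few lines (assuming models with the `(state, action) -> (next_state, reward)` interface of the `WorldModel` above; `ensemble_predict` is an illustrative helper, not a library function):

```python
import torch

def ensemble_predict(models, state, action):
    """Average an ensemble's next-state predictions and use their
    disagreement as an uncertainty signal (high std -> don't trust
    the prediction; truncate the imagined rollout there)."""
    preds = torch.stack([m(state, action)[0] for m in models])
    mean = preds.mean(dim=0)
    uncertainty = preds.std(dim=0).mean()  # scalar disagreement score
    return mean, uncertainty
```

Where all members agree, the state is probably well covered by training data; where they diverge, the model is extrapolating and the policy should not be allowed to exploit it.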

Connection to autonomous systems

The sim-to-real problem is fundamentally about model accuracy. See Tutorial - Sim-to-Real Transfer for practical approaches.

  • Drone navigation: train in simulation (Gazebo, AirSim), transfer to real drone. The simulator IS the world model.
  • Model Predictive Control (MPC): explicitly use a dynamics model to plan optimal actions over a short horizon. Used in robotics and autonomous vehicles.
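The simplest MPC variant is random shooting: sample many candidate action sequences, roll each through the model, and execute only the first action of the best one. A sketch, assuming a batched model with the `(state, action) -> (next_state, reward)` interface from above (`random_shooting_mpc` is an illustrative name):

```python
import torch

def random_shooting_mpc(model, state, action_dim, horizon=5, n_candidates=100):
    """Plan by random shooting: score candidate action sequences under the
    model, execute the first action of the best one, then replan next step."""
    # Candidate sequences with actions in [-1, 1]: (n_candidates, horizon, action_dim)
    candidates = torch.rand(n_candidates, horizon, action_dim) * 2 - 1
    states = state.unsqueeze(0).expand(n_candidates, -1).clone()
    returns = torch.zeros(n_candidates)
    with torch.no_grad():
        for t in range(horizon):
            states, rewards = model(states, candidates[:, t])
            returns += rewards
    best = returns.argmax()
    return candidates[best, 0]  # execute only the first action, then replan
```

The cross-entropy method (CEM) refines this by iteratively resampling candidates around the best sequences instead of sampling uniformly.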

Self-test questions

  1. Why is model-based RL more sample-efficient than model-free RL?
  2. What is the Dyna architecture, and how does it combine real and simulated experience?
  3. Why does Dreamer learn dynamics in latent space instead of observation space?
  4. What is model exploitation, and how is it analogous to reward hacking?
  5. When would you choose model-free over model-based RL?

Exercises

  1. Dyna-Q comparison: Run the Dyna-Q experiment above with 0, 5, 20, 50 planning steps. Plot learning curves. How much does planning help?
  2. Neural world model: Train the WorldModel class on CartPole transitions. Collect 10k real transitions, train the model, then generate 10k synthetic transitions. Compare the synthetic data distribution with real data.
  3. Plan with learned model: Use the trained world model to do simple planning: for each state, evaluate all possible actions by predicting the next state and reward. Choose the action with highest predicted value. Does this beat random?