When to Trust Your Model: Model-Based Policy Optimization (1906.08253v3)
Abstract: Designing effective model-based reinforcement learning algorithms is difficult because the ease of data generation must be weighed against the bias of model-generated data. In this paper, we study the role of model usage in policy optimization both theoretically and empirically. We first formulate and analyze a model-based reinforcement learning algorithm with a guarantee of monotonic improvement at each step. In practice, this analysis is overly pessimistic and suggests that real off-policy data is always preferable to model-generated on-policy data, but we show that an empirical estimate of model generalization can be incorporated into such analysis to justify model usage. Motivated by this analysis, we then demonstrate that a simple procedure of using short model-generated rollouts branched from real data has the benefits of more complicated model-based algorithms without the usual pitfalls. In particular, this approach surpasses the sample efficiency of prior model-based methods, matches the asymptotic performance of the best model-free algorithms, and scales to horizons that cause other model-based methods to fail entirely.
- Michael Janner
- Justin Fu
- Marvin Zhang
- Sergey Levine
Summary
This paper explores the critical trade-off in model-based reinforcement learning (MBRL) between leveraging a learned dynamics model for data augmentation (improving sample efficiency) and the detrimental effects of model inaccuracies (bias) that can hinder policy performance. The central question is how to best utilize a potentially imperfect model within a policy optimization loop to accelerate learning without compromising the final policy quality. The authors develop a theoretical framework to analyze model usage and propose a practical algorithm, Model-Based Policy Optimization (MBPO), demonstrating significant empirical gains.
Theoretical Analysis of Model Usage
The paper first establishes a theoretical basis for analyzing model-based policy updates, aiming for monotonic improvement guarantees similar to those found in model-free policy gradient methods.
Initial Monotonic Improvement Bound:
The analysis begins by considering a general MBRL scheme where a policy π is optimized using a learned model P^. Theorem 4.1 provides a lower bound on the true expected return η[π] based on the expected return under the model η^[π]:
η[π] ≥ η^[π] − 2γϵm/(1−γ)^2 − 4γϵπ/(1−γ)^3
Here, ϵm=Es∼dπD[DTV(P^(⋅∣s,π(s))∣∣P(⋅∣s,π(s)))] represents the average total variation divergence between the true dynamics P and the model P^ under the state distribution dπD induced by the data-collecting policy πD. ϵπ=Es∼dπD[DTV(π(⋅∣s)∣∣πD(⋅∣s))] measures the divergence between the current policy π and the policy πD used to collect the data for model training. This bound suggests that if the policy improvement achieved under the model, Δη^[π]=η^[π]−η^[πold], is sufficiently large to overcome the error terms, then monotonic improvement in the true environment is guaranteed. However, the constants are pessimistic, scaling poorly with the effective horizon 1/(1−γ); the ϵπ term in particular scales with its cube. This pessimism arises from worst-case assumptions about how model errors compound and how the state distribution shifts with the policy.
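To see how punishing these constants are, the short calculation below plugs illustrative values into the bound as written above; γ = 0.99 and the particular ϵ values are arbitrary assumptions chosen for illustration, not numbers from the paper.

```python
# Rough magnitude of the Theorem 4.1 penalty terms, using the bound as written above.
# The specific values of gamma, eps_m, and eps_pi are assumed for illustration only.
gamma, eps_m, eps_pi = 0.99, 0.01, 0.05

model_term = 2 * gamma * eps_m / (1 - gamma) ** 2    # scales with 1/(1-gamma)^2
policy_term = 4 * gamma * eps_pi / (1 - gamma) ** 3  # scales with 1/(1-gamma)^3

print(f"model-error penalty:  {model_term:,.0f}")    # ~198
print(f"policy-shift penalty: {policy_term:,.0f}")   # ~198,000
```

Even modest divergences translate into penalties far larger than typical returns, which is why the raw bound favors discarding the model entirely.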
Analysis of Branched Rollouts:
To mitigate compounding errors, the paper analyzes a specific model usage strategy: generating rollouts of finite length k starting from states st sampled from the real environment replay buffer Denv. Theorem 4.2 bounds the performance gap for these k-step branched rollouts: ∣η[π]−η^k[π]∣ ≤ … (simplified expression). The resulting bound shows that the error contribution from the model, ϵm, scales linearly with the rollout horizon k, while the error contribution from the policy shift before branching, ϵπ, decays exponentially with k. Despite this improved structure compared to infinite-horizon rollouts, the worst-case constants still often imply that k=0 (no model usage) is optimal, because they assume the model error ϵm may grow substantially with the policy shift ϵπ.
Incorporating Empirical Model Generalization:
The key insight is that the standard bounds are overly pessimistic because they do not account for the model's ability to generalize to states visited by the new policy π, even when π differs from πD. The authors argue that with sufficient data, the model error on the new policy's state distribution, ϵm′=Es∼dπ[DTV(P^(⋅∣s,π(s))∣∣P(⋅∣s,π(s)))], need not increase drastically with the policy divergence ϵπ. They empirically observe (Figure 1) that measured model error grows only mildly (sublinearly) with the divergence between the data-collecting and evaluation policies, and that this growth flattens further as the training dataset grows.
By incorporating an empirical estimate of ϵm′ (conceptually, measuring model error on states visited by π), a revised bound (Theorem 4.3) is derived with a more favorable dependence on model error: η[π] ≥ η^k[π] − 2γϵm′(1−γ^k)/(1−γ)^2 − … (simplified expression). Crucially, the error terms in this revised bound are minimized at a non-zero rollout length k∗>0 whenever the estimated model error on the current policy's states, ϵm′, is sufficiently low. This provides the theoretical justification for using short, finite-horizon model rollouts branched from real data: they limit error accumulation (linear scaling with k) while benefiting from the model's generalization capabilities.
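The qualitative effect can be illustrated with a toy error model in Python: one term that compounds linearly in k (model error inside the rollout) and one that decays with k (the pre-branch policy-shift contribution). The functional form and all constants below are schematic stand-ins for the elided expression, not the paper's actual bound.

```python
import numpy as np

gamma = 0.99

def total_error(k, eps_m_prime, eps_pi, c_model=1.0, c_policy=50.0):
    # Schematic: linearly compounding model error plus a policy-shift term that decays in k.
    # The constants c_model and c_policy are arbitrary; only the shape of the trade-off matters.
    return c_model * k * eps_m_prime + c_policy * gamma**k * eps_pi

ks = np.arange(0, 400)
for eps_m_prime in [0.005, 0.02, 0.1]:
    errs = total_error(ks, eps_m_prime, eps_pi=0.05)
    print(f"estimated on-policy model error {eps_m_prime}: best k = {int(ks[np.argmin(errs)])}")
# Low on-policy model error -> a non-zero rollout length minimizes the bound;
# high model error -> the minimizer collapses back to k = 0 (no model usage).
```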
MBPO Algorithm Implementation
Motivated by the theoretical insights, the paper proposes the Model-Based Policy Optimization (MBPO) algorithm. MBPO combines an ensemble of probabilistic dynamics models with an off-policy model-free RL algorithm (SAC) using the short, branched rollout strategy.
Core Components:
- Dynamics Model Ensemble: An ensemble of B probabilistic dynamics models {P^θi}i=1B is trained on the most recent data collected from the real environment, stored in a replay buffer Denv. Each model typically predicts the mean and variance of the next state distribution, st+1∼N(μθi(st,at),Σθi(st,at)). Training uses standard maximum likelihood estimation on transitions (st,at,rt,st+1) from Denv. Ensembles help capture model uncertainty.
- Off-Policy RL Algorithm: A state-of-the-art off-policy algorithm, Soft Actor-Critic (SAC), is used for policy learning. SAC maintains an actor πϕ and critics Qψ, trained using data sampled from a replay buffer.
- Branched Rollout Generation: This is the core mechanism for model usage (a minimal code sketch follows after this list).
- Periodically, the algorithm samples a batch of states {st} from the real data buffer Denv.
- For each sampled state st, the current policy πϕ is unrolled for k steps using the learned dynamics models. At each step τ within the rollout (0≤τ<k):
- An action at+τ is sampled from the policy πϕ(⋅∣st+τ).
- A model is chosen from the ensemble (e.g., randomly).
- The chosen model P^θi predicts the next state st+τ+1 and reward rt+τ.
- The generated k-step transitions (st+τ,at+τ,rt+τ,st+τ+1) for 0≤τ<k are collected.
- Model Data Buffer: These synthetic transitions are stored in a separate model replay buffer Dmodel.
- Policy Training: The SAC algorithm is trained by sampling batches of data from both the real buffer Denv and the model buffer Dmodel. A ratio hyperparameter controls the proportion of updates performed using real versus model-generated data.
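Below is a minimal sketch of the branched-rollout and batch-mixing mechanism referenced above, with the policy and dynamics models stubbed out as placeholders. Names such as policy, model_step, and real_ratio, and the specific sizes, are illustrative assumptions rather than the paper's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
obs_dim, act_dim, ensemble_size = 11, 3, 7   # illustrative sizes

def policy(s):
    # Stand-in for pi_phi(.|s): a random Gaussian policy.
    return rng.normal(size=(len(s), act_dim))

def model_step(i, s, a):
    # Stand-in for the i-th ensemble member P_hat_theta_i(s, a): fake next states and rewards.
    return s + 0.01 * rng.normal(size=s.shape), rng.normal(size=(len(s), 1))

def branched_rollouts(real_states, k=1):
    """Unroll the current policy for k steps from real states, picking a random model each step."""
    transitions, s = [], real_states
    for _ in range(k):
        a = policy(s)
        i = rng.integers(ensemble_size)
        s_next, r = model_step(i, s, a)
        transitions.extend(zip(s, a, r, s_next))
        s = s_next
    return transitions

def mixed_batch(real_buffer, model_buffer, batch_size=256, real_ratio=0.05):
    """Sample an update batch that is mostly model data plus a small real-data fraction (ratio assumed)."""
    n_real = int(batch_size * real_ratio)
    idx_r = rng.integers(len(real_buffer), size=n_real)
    idx_m = rng.integers(len(model_buffer), size=batch_size - n_real)
    return [real_buffer[j] for j in idx_r] + [model_buffer[j] for j in idx_m]

# Example: one rollout phase, branching from a batch of (fake) real states with k = 1.
synthetic = branched_rollouts(rng.normal(size=(4, obs_dim)), k=1)
```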
Algorithmic Flow:
```
Initialize policy π_ϕ, Q-functions Q_ψ, target Q-functions Q_ψ̄, model ensemble {P̂_θi},
    env buffer D_env, model buffer D_model.
for each environment step t do:
    Execute policy: a_t ~ π_ϕ(·|s_t)
    Observe transition: (s_t, a_t, r_t, s_{t+1})
    Store real transition: D_env ← D_env ∪ {(s_t, a_t, r_t, s_{t+1})}
    if time to train model then:
        Train model ensemble {P̂_θi} on recent data from D_env.
    if time to generate model data then:
        Clear D_model.
        for n_rollouts times do:
            Sample initial state s_0 ~ D_env.
            s ← s_0
            for step k' = 0 to k-1 do:
                a ~ π_ϕ(·|s)
                Randomly select model P̂_θi from ensemble.
                s_next, r ~ P̂_θi(s, a)   // predict next state and reward
                Store synthetic transition: D_model ← D_model ∪ {(s, a, r, s_next)}
                s ← s_next
    if time to train policy then:
        for n_policy_updates times do:
            Sample real batch B_env ~ D_env.
            Sample model batch B_model ~ D_model.
            Combine batches: B ← B_env ∪ B_model.
            Perform SAC update (actor and critic) using batch B.
            Update target networks Q_ψ̄.
end for
```
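Each ensemble member referred to above ("Train model ensemble {P̂_θi} on recent data from D_env") is, per the paper, a probabilistic feedforward network trained by maximum likelihood. Below is a sketch of one member in PyTorch; the layer sizes, the state-delta parameterization, and the class name DynamicsModel are our illustrative choices, not the paper's exact architecture.

```python
import torch
import torch.nn as nn

class DynamicsModel(nn.Module):
    """One ensemble member: predicts a Gaussian over (next_state - state, reward)."""
    def __init__(self, obs_dim, act_dim, hidden=200):
        super().__init__()
        out_dim = obs_dim + 1  # state delta plus scalar reward
        self.net = nn.Sequential(
            nn.Linear(obs_dim + act_dim, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, 2 * out_dim),      # concatenated mean and log-variance
        )

    def forward(self, s, a):
        mean, log_var = self.net(torch.cat([s, a], dim=-1)).chunk(2, dim=-1)
        return mean, log_var.clamp(-10.0, 4.0)   # clamp log-variance for numerical stability

def nll_loss(model, s, a, s_next, r):
    """Loss proportional to the Gaussian negative log-likelihood of the observed transition."""
    target = torch.cat([s_next - s, r], dim=-1)
    mean, log_var = model(s, a)
    return (((mean - target) ** 2) * torch.exp(-log_var) + log_var).sum(dim=-1).mean()

# Illustrative training step on a random batch standing in for samples from D_env.
obs_dim, act_dim, batch = 11, 3, 256
model = DynamicsModel(obs_dim, act_dim)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
s, a = torch.randn(batch, obs_dim), torch.randn(batch, act_dim)
s_next, r = torch.randn(batch, obs_dim), torch.randn(batch, 1)
opt.zero_grad()
nll_loss(model, s, a, s_next, r).backward()
opt.step()
```

An ensemble of B such models, each with its own initialization, can then be queried by sampling one member per rollout step, as in the pseudocode above.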
Practical Implications and Applications
MBPO demonstrated strong empirical results, offering significant advantages for applying RL in scenarios where real-world interaction is costly or time-consuming.
- Sample Efficiency: The most prominent result is the dramatic improvement in sample efficiency. On MuJoCo benchmarks like Hopper, Walker2d, and notably Ant, MBPO achieved high levels of performance using significantly fewer environment steps (often 10-20x fewer) compared to the model-free SAC baseline. It also outperformed prior MBRL methods like PETS and STEVE in sample efficiency.
- Asymptotic Performance: Unlike many older MBRL methods that might learn faster initially but plateau at suboptimal performance due to model bias, MBPO successfully matched the high asymptotic performance of the state-of-the-art model-free algorithm SAC on the tested environments.
- Scalability: MBPO proved effective on tasks with standard, non-truncated horizons (e.g., 1000 steps) and higher state dimensions (e.g., Ant), areas where methods relying on long model rollouts (like PETS) can struggle due to compounding errors.
- Simplicity of Short Rollouts: The ablation studies revealed that a very short rollout length, specifically k=1, yielded most of the performance benefits. This simplifies implementation significantly, as it avoids the need for complex trajectory optimization or long-horizon planning within the model, essentially performing one-step model-based lookahead to generate data points for the off-policy learner. While adaptively tuning k gave slightly better results, fixed k=1 was a robust and effective strategy.
- Mitigation of Model Exploitation: The paper found little evidence that the policy was exploiting inaccuracies in the model ensemble (i.e., finding high predicted returns that did not correspond to high real returns). This robustness is attributed mainly to the use of short rollouts of length k branched from real states. Because the rollouts start from the distribution of recently visited real states and are short, they are less likely to diverge into unrealistic state regions where the model might be inaccurate and exploitable.
Implementation Considerations
- Computational Cost: MBPO trades environment interaction cost for computational cost. Training the model ensemble and generating rollouts requires significant computation, but this can often be parallelized and is typically faster and cheaper than interacting with a physical system or complex simulator. The number of gradient updates per environment step is much higher than in typical model-free RL settings.
- Hyperparameter Sensitivity: Key hyperparameters include the following (an illustrative configuration is sketched at the end of this section):
- k: The model rollout length. k=1 is a strong default, but the optimal k may be task-dependent or benefit from minor tuning or scheduling.
- Model ensemble size (B): Typically 5-7 models provide a good balance between uncertainty capture and computational cost.
- Model training frequency: How often the models are retrained on the latest data from Denv.
- Gradient update ratio: The number of gradient steps per environment step and the ratio of model vs. real data used for updates. The paper used a high number of updates (e.g., 40) per environment step, primarily using model data.
- Model Architecture: The paper used simple probabilistic feedforward networks for the dynamics models. More complex architectures might be needed for environments with complex dynamics (e.g., images), increasing computational demands.
- Choice of Off-Policy Learner: While SAC was used, the branched rollout data generation strategy could likely be paired with other efficient off-policy algorithms like TD3.
- Potential Limitations: Performance hinges on the ability to learn a reasonably accurate local dynamics model. In highly stochastic environments or domains with very complex, hard-to-model dynamics, the benefits might diminish. The theoretical justification (Theorem 4.3) relies on empirical estimation of model generalization, which could be noisy or difficult to measure accurately in practice.
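To make these knobs concrete, here is an illustrative configuration in the spirit of the values discussed above. The dictionary keys are our own naming, and any value not stated in this summary (the real-data fraction, model training interval, rollout batch size, and SAC settings) is an assumption rather than the paper's exact setting.

```python
# Illustrative MBPO-style configuration (our naming; values marked "assumed" are not from the summary above).
mbpo_config = {
    "rollout_length_k": 1,                  # short branched rollouts; k = 1 captures most of the benefit
    "ensemble_size": 7,                     # within the 5-7 range noted above
    "policy_updates_per_env_step": 40,      # high update-to-data ratio, mostly on model data
    "real_data_ratio": 0.05,                # assumed: fraction of real transitions per update batch
    "model_train_interval_env_steps": 250,  # assumed: environment steps between model refits
    "rollout_batch_size": 100_000,          # assumed: number of branch start states per rollout phase
    "sac": {"lr": 3e-4, "batch_size": 256, "gamma": 0.99},  # assumed standard SAC settings
}
```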
Conclusion
"When to Trust Your Model: Model-Based Policy Optimization" provides valuable theoretical insights and a practical algorithmic solution (MBPO) for effectively integrating learned dynamics models into policy optimization. By leveraging ensembles and, critically, using short model rollouts (k) branched from real data, MBPO achieves substantial improvements in sample efficiency (often >10x) compared to model-free methods while matching their asymptotic performance. The effectiveness of even single-step (k=1) rollouts makes it a compelling and relatively simple approach for accelerating RL in data-constrained settings without the common pitfalls of model bias encountered in traditional MBRL methods relying on long rollouts.