- The paper introduces DIAMOND, a diffusion-based world model that directly predicts pixel-level transitions in Atari, preserving visual details for better decision-making.
- The paper adopts the EDM diffusion framework, which keeps the number of denoising steps and hence the sampling cost low while maintaining higher image fidelity than discrete latent world models.
- The paper demonstrates state-of-the-art performance on Atari with a mean Human Normalized Score of 1.46, using fewer model parameters than previous approaches.
This paper introduces DIAMOND (DIffusion As a Model Of eNvironment Dreams), a reinforcement learning (RL) agent trained entirely within a world model that uses a diffusion model to predict future observations. The core motivation is that traditional world models often rely on discrete latent representations, which can discard important visual details necessary for optimal decision-making in visually complex environments. Diffusion models, known for high-fidelity image generation, offer a way to model the environment directly in pixel space, potentially preserving these critical details.
Core Idea and Implementation
- Diffusion-based World Model: Instead of encoding observations into discrete tokens, DIAMOND models the transition dynamics $p(x_{t+1} \mid x_{\le t}, a_{\le t})$ directly in image space using a conditional diffusion model. Specifically, it adapts the Elucidating the Design Space of Diffusion-Based Generative Models (EDM) framework (2206.00364).
- Model Input: The diffusion model $D_\theta$ takes as input a noised version of the next frame $x_{t+1}^\tau$, the diffusion time $\tau$, a history of $L$ previous clean frames $x_{t-L+1}^0, \dots, x_t^0$, and the corresponding actions $a_{t-L+1}, \dots, a_t$.
- Model Output: It predicts the clean next frame $x_{t+1}^0$.
- Training Objective: The model is trained using a denoising score matching objective, essentially an L2 reconstruction loss between the model's prediction and the ground-truth clean frame $x_{t+1}^0$:
$\mathcal{L}(\theta) = \mathbb{E}\left[ \left\| D_\theta\left(x_{t+1}^\tau, \tau, x_{\le t}^0, a_{\le t}\right) - x_{t+1}^0 \right\|^2 \right]$
The specific parameterization follows EDM (Eq. 6, 7), which uses preconditioners $(c_\text{in}^\tau, c_\text{out}^\tau, c_\text{skip}^\tau, c_\text{noise}^\tau)$ to stabilize training across different noise levels $\sigma(\tau)$. The noise level $\sigma(\tau)$ is sampled from a log-normal distribution during training.
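A minimal sketch of this preconditioned objective in PyTorch, assuming the standard EDM coefficient formulas and an abstract U-Net callable `F_theta`; the constant `SIGMA_DATA` and the log-normal parameters `p_mean`/`p_std` are illustrative placeholders, not the paper's exact hyperparameters:

```python
import torch

SIGMA_DATA = 0.5  # assumed data-scale constant, as in the EDM paper

def edm_preconditioners(sigma):
    """Standard EDM scalings applied at each noise level sigma."""
    c_skip = SIGMA_DATA**2 / (sigma**2 + SIGMA_DATA**2)
    c_out = sigma * SIGMA_DATA / (sigma**2 + SIGMA_DATA**2).sqrt()
    c_in = 1.0 / (sigma**2 + SIGMA_DATA**2).sqrt()
    c_noise = sigma.log() / 4.0
    return c_skip, c_out, c_in, c_noise

def denoiser(F_theta, noisy_next, sigma, history_frames, actions):
    """D_theta: wraps the raw U-Net F_theta with EDM preconditioning."""
    c_skip, c_out, c_in, c_noise = edm_preconditioners(sigma)
    raw = F_theta(c_in * noisy_next, c_noise, history_frames, actions)
    return c_skip * noisy_next + c_out * raw

def diffusion_loss(F_theta, next_frame, history_frames, actions,
                   p_mean=-1.2, p_std=1.2):
    """L2 denoising loss; sigma is drawn from a log-normal distribution."""
    b = next_frame.shape[0]
    log_sigma = torch.randn(b, 1, 1, 1, device=next_frame.device) * p_std + p_mean
    sigma = log_sigma.exp()
    noisy_next = next_frame + sigma * torch.randn_like(next_frame)
    pred = denoiser(F_theta, noisy_next, sigma, history_frames, actions)
    return ((pred - next_frame) ** 2).mean()
```

Wrapping the raw network in the preconditioned `denoiser` keeps its inputs and outputs well-scaled across noise levels, which is what stabilizes training here.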
- Network Architecture: A standard 2D U-Net is used for the core denoising network $F_\theta$. Past observations are concatenated channel-wise with the noisy input $x_{t+1}^\tau$. Past actions and the noise-level embedding $c_\text{noise}^\tau$ are incorporated via Adaptive Group Normalization (AdaGN) layers within the U-Net's residual blocks.
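A hedged sketch of how this conditioning can be wired up, assuming frame-stacked history, discrete actions, and an AdaGN-style residual block; the module names, widths, and single-block backbone are illustrative stand-ins rather than the released architecture:

```python
import torch
import torch.nn as nn

class AdaGNResBlock(nn.Module):
    """Residual block whose GroupNorm output is modulated (scale, shift)
    by a conditioning embedding (actions + noise level), AdaGN-style."""
    def __init__(self, channels, cond_dim, groups=8):
        super().__init__()
        self.norm = nn.GroupNorm(groups, channels)
        self.to_scale_shift = nn.Linear(cond_dim, 2 * channels)
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x, cond):
        scale, shift = self.to_scale_shift(cond).chunk(2, dim=-1)
        h = self.norm(x) * (1 + scale[..., None, None]) + shift[..., None, None]
        return x + self.conv(torch.relu(h))

class DenoiserBackbone(nn.Module):
    """Toy stand-in for the U-Net core F_theta: the noisy next frame is
    concatenated channel-wise with the L past frames before the first conv."""
    def __init__(self, frame_channels=3, history_len=4, num_actions=18,
                 width=64, cond_dim=128):
        super().__init__()
        in_ch = frame_channels * (history_len + 1)
        self.stem = nn.Conv2d(in_ch, width, 3, padding=1)
        self.action_emb = nn.Embedding(num_actions, cond_dim)
        self.noise_emb = nn.Linear(1, cond_dim)
        self.block = AdaGNResBlock(width, cond_dim)
        self.head = nn.Conv2d(width, frame_channels, 3, padding=1)

    def forward(self, noisy_next, c_noise, history_frames, actions):
        # history_frames: (B, L, C, H, W); actions: (B, L) discrete ids
        b, l, c, h, w = history_frames.shape
        x = torch.cat([history_frames.reshape(b, l * c, h, w), noisy_next], dim=1)
        cond = self.action_emb(actions).mean(dim=1) + self.noise_emb(c_noise.view(b, 1))
        return self.head(self.block(self.stem(x), cond))
```

Note the split in how conditioning enters: past frames go through the channel dimension, while actions and the noise level only modulate normalization statistics, mirroring the AdaGN scheme described above.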
- Choice of Diffusion Framework (EDM vs. DDPM): The paper argues that EDM is crucial for efficiency. Standard DDPM (2006.11239) requires many denoising steps (high NFE) for good quality. EDM's improved preconditioning and training objective (Eq. 7) allow for stable and high-quality generation with very few steps (e.g., NFE=1 to 10). This is critical because the world model needs to be sampled repeatedly during agent training, making low NFE essential for computational feasibility. Experiments show EDM avoids the compounding errors seen with DDPM at low NFE (Figure 4).
- Sampling (Imagination): To generate a trajectory ("dream"), the model starts from pure Gaussian noise $x_{t+1}^T$ for the next frame and iteratively applies the learned denoising function $D_\theta$ using a numerical solver for the reverse SDE (Eq. 3) or the corresponding ODE.
- Solver: Euler's method is found to be effective.
- Number of Function Evaluations (NFE): While NFE = 1 works surprisingly well thanks to EDM, it can produce blurry predictions in stochastic environments (averaging over modes). NFE = 3 is used in the experiments as a balance between capturing sharper details (modes) and computational cost (Figure 5). Each sampling step conditions on the $L$ previously generated frames and the actions selected by the policy (see the sketch after this list).
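As referenced above, a minimal Euler sampler under the EDM formulation could look as follows; it reuses the `denoiser` wrapper from the earlier sketch, and the Karras-style schedule constants (`sigma_min`, `sigma_max`, `rho`) are illustrative, not the paper's settings:

```python
import torch

def karras_sigmas(n_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels from sigma_max down to sigma_min, with a final 0 appended."""
    ramp = torch.linspace(0, 1, n_steps)
    sigmas = (sigma_max ** (1 / rho)
              + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    return torch.cat([sigmas, torch.zeros(1)])

@torch.no_grad()
def sample_next_frame(F_theta, history_frames, actions, frame_shape, n_steps=3):
    """Generate an approximate x_{t+1}^0 from pure noise with Euler steps (NFE = n_steps)."""
    sigmas = karras_sigmas(n_steps)
    x = torch.randn(frame_shape) * sigmas[0]          # x_{t+1}^T: pure noise
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        sig = torch.full((frame_shape[0], 1, 1, 1), sigma.item())
        denoised = denoiser(F_theta, x, sig, history_frames, actions)
        d = (x - denoised) / sigma                    # probability-flow ODE slope
        x = x + (sigma_next - sigma) * d              # one Euler step toward less noise
    return x
```

During imagination this is called autoregressively: the returned frame is appended to `history_frames` (dropping the oldest) before the next step is sampled.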
- Reinforcement Learning Agent:
- Architecture: A separate actor-critic agent $(\pi_\phi, V_\phi)$ is used, built on a shared CNN-LSTM backbone. It takes the current generated observation $x_t^0$ as input.
- Training: The agent is trained entirely within the imagined environment generated by the diffusion world model. It uses the REINFORCE algorithm with a value baseline $V_\phi(x_t)$ and $\lambda$-returns as value targets to improve sample efficiency within imagination (a minimal sketch of these losses follows the list).
- Value Loss: $\mathcal{L}_V(\phi) = \mathbb{E}_{\pi_\phi}\left[ \sum_{t=0}^{H-1} \left( V_\phi\left(x_t\right) - \operatorname{sg}\left(\Lambda_t\right) \right)^2 \right]$
- Policy Loss: $\mathcal{L}_\pi(\phi) = - \mathbb{E}_{\pi_\phi} \left[ \sum_{t=0}^{H-1} \log\left(\pi_\phi\left(a_t \mid x_{\le t}\right)\right) \operatorname{sg}\left(\Lambda_t - V_\phi\left(x_t\right)\right) + \eta \, \mathcal{H}\left(\pi_\phi \left(a_t \mid x_{\le t} \right) \right)\right]$
- Reward/Termination: A separate predictor $R_\psi$, also built on a CNN-LSTM architecture, is trained on real data to supply rewards and termination signals during imagination.
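A minimal sketch of those imagination-phase losses, assuming an imagined rollout of length H with bootstrapped values of length H+1; the discount, λ, and entropy-weight values are placeholders rather than the paper's hyperparameters:

```python
import torch

def lambda_returns(rewards, values, terminations, gamma=0.99, lam=0.95):
    """Backward recursion: Lambda_t = r_t + gamma*(1-d_t)*((1-lam)*V_{t+1} + lam*Lambda_{t+1})."""
    H = rewards.shape[0]
    returns = torch.zeros_like(rewards)
    next_return = values[-1]  # bootstrap from the value at the final imagined state
    for t in reversed(range(H)):
        cont = gamma * (1.0 - terminations[t])
        next_return = rewards[t] + cont * ((1 - lam) * values[t + 1] + lam * next_return)
        returns[t] = next_return
    return returns

def actor_critic_losses(log_probs, entropies, values, rewards, terminations, eta=0.001):
    """REINFORCE with a value baseline plus entropy bonus, and an L2 value loss.
    `values` has length H+1 (includes the bootstrap state); the rest have length H."""
    lam_ret = lambda_returns(rewards, values.detach(), terminations)
    advantage = (lam_ret - values[:-1]).detach()              # sg(Lambda_t - V(x_t))
    policy_loss = -(log_probs * advantage + eta * entropies).mean()
    value_loss = ((values[:-1] - lam_ret.detach()) ** 2).mean()  # sg(Lambda_t) targets
    return policy_loss, value_loss
```

The $\operatorname{sg}(\cdot)$ operators in the loss expressions above correspond to the `.detach()` calls here.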
- Overall Training Loop (Algorithm 1), sketched in code after this list:
- Collect experience from the real environment using the current policy $\pi_\phi$.
- Update the diffusion world model $D_\theta$ on data sampled from the replay buffer $\mathcal{D}$.
- Update the reward/termination predictor $R_\psi$ on data from $\mathcal{D}$.
- Train the actor-critic agent $(\pi_\phi, V_\phi)$ by generating imagined trajectories with $D_\theta$ and $R_\psi$ and applying the RL losses above.
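Putting the pieces together, the outer loop of Algorithm 1 could be organized as below; helper names such as `collect_steps`, `update_world_model`, `update_reward_end_model`, `imagine_rollout`, and `update_actor_critic` are hypothetical placeholders for the routines sketched earlier, not functions from the released code:

```python
def training_loop(env, policy, world_model, reward_end_model, replay_buffer,
                  num_epochs, steps_per_epoch, wm_updates, rl_updates, horizon):
    """One epoch = real-environment collection, world-model fitting, RL in imagination."""
    for _ in range(num_epochs):
        # 1. Collect real experience with the current policy pi_phi.
        replay_buffer.extend(collect_steps(env, policy, steps_per_epoch))

        # 2. Fit the diffusion world model D_theta on replayed segments.
        for _ in range(wm_updates):
            update_world_model(world_model, replay_buffer.sample_batch())

        # 3. Fit the reward/termination predictor R_psi on the same data.
        for _ in range(wm_updates):
            update_reward_end_model(reward_end_model, replay_buffer.sample_batch())

        # 4. Train the actor-critic purely inside imagined rollouts.
        for _ in range(rl_updates):
            starts = replay_buffer.sample_initial_frames()
            rollout = imagine_rollout(world_model, reward_end_model, policy,
                                      starts, horizon)
            update_actor_critic(policy, rollout)
```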
Practical Implications and Key Results
- State-of-the-Art Performance: DIAMOND achieves a mean Human Normalized Score (HNS) of 1.46 on the Atari 100k benchmark, setting a new record for agents trained purely within a world model.
- Visual Fidelity Matters: Qualitative comparisons (Figure 6) show DIAMOND generates more visually consistent and detailed trajectories than IRIS (2209.00588), a strong transformer-based discrete world model. For example, DIAMOND correctly renders details like scores updating or distinguishes enemies from rewards more reliably. This improved fidelity likely contributes to better agent performance, especially in games like Asterix, Breakout, and Road Runner where small objects are critical.
- Computational Efficiency: By using EDM and only NFE=3, the diffusion world model's sampling cost is kept manageable, comparable to or even faster than some discrete models like IRIS (which used NFE=16). DIAMOND uses fewer parameters (13M) than IRIS (30M) or DreamerV3 (18M) (2301.04104).
- Drop-in Substitute: Since the world model operates in pixel space, it can directly substitute the real environment for analysis or interaction (e.g., the playable models provided with the code).
- Code Availability: The implementation is open-sourced at https://github.com/eloialonso/diamond, facilitating replication and extension.
Implementation Considerations and Limitations
- Conditioning: The current model uses simple frame stacking ($L = 4$) for history. More sophisticated temporal modeling (e.g., using transformers across time, as in DiT (2212.09748)) might improve long-range dependencies but was found less effective in initial experiments (Appendix G).
- Reward/Termination Model: These are predicted by a separate model. Integrating them into the diffusion process is non-trivial and left for future work.
- Environment Scope: Evaluation is primarily on discrete-action Atari environments. Performance in continuous control or more complex 3D environments needs further investigation (though Appendix G provides promising initial results on CS:GO and Driving datasets).
- Computational Cost: While NFE is low, training still requires significant GPU resources (approx. 2.9 days per game on an RTX 4090).
In summary, DIAMOND demonstrates that using diffusion models as world models is a viable and effective approach for RL, particularly when preserving visual detail is important. The key implementation choices involve using the EDM framework for efficient sampling (low NFE) and conditioning the U-Net appropriately on past states and actions. Its strong performance and improved visual fidelity over discrete models highlight the potential of diffusion for building more capable world models.