This paper introduces DIAMOND (DIffusion As a Model Of eNvironment Dreams), a reinforcement learning (RL) agent trained entirely within a world model that uses a diffusion model to predict future observations. The core motivation is that traditional world models often rely on discrete latent representations, which can discard important visual details necessary for optimal decision-making in visually complex environments. Diffusion models, known for high-fidelity image generation, offer a way to model the environment directly in pixel space, potentially preserving these critical details.
Core Idea and Implementation
- Diffusion-based World Model: Instead of encoding observations into discrete tokens, DIAMOND models the transition dynamics directly in image space with a conditional diffusion model. Specifically, it adapts the EDM framework from "Elucidating the Design Space of Diffusion-Based Generative Models" (Karras et al., 2022).
  - Model Input: The diffusion model takes as input a noised version of the next frame, $x_{t+1}^\tau$, the diffusion time $\tau$, a history of previous clean frames $x_{\le t}^0$, and the corresponding actions $a_{\le t}$.
  - Model Output: It predicts the clean next frame $x_{t+1}^0$.
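Training Objective: The model is trained with a denoising score matching objective, which reduces to an L2 reconstruction loss between the denoiser's prediction and the ground-truth clean frame $x_{t+1}^0$:

$$\mathcal{L}(\theta) = \mathbb{E}\left[ \left\| D_\theta\left(x_{t+1}^\tau, \tau, x_{\le t}^0, a_{\le t}\right) - x_{t+1}^0 \right\|^2 \right]$$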
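The specific parameterization follows EDM (Eq. 6, 7), which writes the denoiser as $D_\theta(x_{t+1}^\tau, y_t^\tau) = c_{\text{skip}}^\tau\, x_{t+1}^\tau + c_{\text{out}}^\tau\, F_\theta\left(c_{\text{in}}^\tau\, x_{t+1}^\tau, y_t^\tau\right)$ with conditioning $y_t^\tau = (c_{\text{noise}}^\tau, x_{\le t}^0, a_{\le t})$. The preconditioners $c_{\text{in}}^\tau, c_{\text{out}}^\tau, c_{\text{skip}}^\tau, c_{\text{noise}}^\tau$ depend on the noise level $\sigma(\tau)$ and stabilize training across noise levels. During training, the noise level is sampled from a log-normal distribution, $\ln \sigma \sim \mathcal{N}(P_{\text{mean}}, P_{\text{std}}^2)$.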
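To make the EDM parameterization concrete, here is a minimal scalar sketch of the preconditioners, the log-normal noise-level sampling, and the regression target for the raw network $F_\theta$. It assumes the EDM paper's default constants ($\sigma_{\text{data}} = 0.5$, $P_{\text{mean}} = -1.2$, $P_{\text{std}} = 1.2$), which DIAMOND may not use, and works on a single pixel because the math is elementwise. All helper names are hypothetical and not taken from the DIAMOND codebase.

```python
import math
import random

# EDM defaults from Karras et al. (2022); assumed here, DIAMOND's values may differ.
SIGMA_DATA = 0.5
P_MEAN, P_STD = -1.2, 1.2


def edm_preconditioners(sigma, sigma_data=SIGMA_DATA):
    """Return (c_skip, c_out, c_in, c_noise) for noise level sigma, as defined in EDM."""
    denom = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / denom
    c_out = sigma * sigma_data / math.sqrt(denom)
    c_in = 1.0 / math.sqrt(denom)
    c_noise = 0.25 * math.log(sigma)
    return c_skip, c_out, c_in, c_noise


def sample_sigma(rng):
    """Training noise level with ln(sigma) ~ N(P_MEAN, P_STD^2), i.e. log-normal."""
    return math.exp(rng.gauss(P_MEAN, P_STD))


def denoiser_output(f_out, x_noisy, sigma):
    """D = c_skip * x_noisy + c_out * F, where f_out is the raw network output F."""
    c_skip, c_out, _, _ = edm_preconditioners(sigma)
    return c_skip * x_noisy + c_out * f_out


def network_target(x_clean, x_noisy, sigma):
    """Regression target for F such that D reproduces x_clean exactly."""
    c_skip, c_out, _, _ = edm_preconditioners(sigma)
    return (x_clean - c_skip * x_noisy) / c_out


# One training example for a single pixel (real frames are tensors; the math is elementwise).
rng = random.Random(0)
x_clean = 0.3
sigma = sample_sigma(rng)
x_noisy = x_clean + sigma * rng.gauss(0.0, 1.0)
target = network_target(x_clean, x_noisy, sigma)
```

Because EDM weights the loss by $\lambda(\sigma) = 1/(c_{\text{out}}^\tau)^2$, regressing $F_\theta$ onto this target is equivalent to an L2 loss on $D_\theta$ weighted by $\lambda(\sigma)$, which keeps the effective target scale roughly constant across noise levels.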
Network Architecture: A standard 2D U-Net is used as the core denoising network $F_\theta$. Past observations $x_{\le t}^0$ are concatenated channel-wise to the noisy input $x_{t+1}^\tau$. Past actions $a_{\le t}$ and the noise-level embedding $c_{\text{noise}}^\tau$ are incorporated through Adaptive Group Normalization (AdaGN) layers in the U-Net's residual blocks.
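The two conditioning paths can be sketched with NumPy: past frames are stacked along the channel axis of the noisy input, and an AdaGN step group-normalizes a feature map and then applies a per-channel scale and shift predicted from a conditioning vector. The U-Net itself is omitted. The frame count, image size, sinusoidal noise features, one-hot action encoding, and single linear projection are illustrative assumptions, not DIAMOND's actual layers.

```python
import numpy as np


def group_norm(x, num_groups, eps=1e-5):
    """Group normalization of a single feature map x with shape (C, H, W)."""
    c, h, w = x.shape
    g = x.reshape(num_groups, c // num_groups, h, w)
    mean = g.mean(axis=(1, 2, 3), keepdims=True)
    var = g.var(axis=(1, 2, 3), keepdims=True)
    return ((g - mean) / np.sqrt(var + eps)).reshape(c, h, w)


def ada_group_norm(x, cond, w_scale, w_shift, num_groups):
    """AdaGN: group-normalize x, then scale and shift each channel with
    coefficients predicted from `cond` (here by one linear map, an assumption)."""
    scale = cond @ w_scale  # shape (C,)
    shift = cond @ w_shift  # shape (C,)
    return group_norm(x, num_groups) * (1.0 + scale[:, None, None]) + shift[:, None, None]


def build_unet_input(noisy_next, past_frames):
    """Channel-wise stacking: (3, H, W) noisy frame plus L clean (3, H, W) frames -> (3 * (L + 1), H, W)."""
    return np.concatenate([noisy_next, *past_frames], axis=0)


def conditioning_vector(c_noise, past_actions, num_actions, noise_dim=8):
    """Hypothetical conditioning: sinusoidal features of c_noise plus one-hot past actions."""
    freqs = np.arange(1, noise_dim // 2 + 1)
    noise_emb = np.concatenate([np.sin(freqs * c_noise), np.cos(freqs * c_noise)])
    action_emb = np.eye(num_actions)[past_actions].reshape(-1)
    return np.concatenate([noise_emb, action_emb])


# Illustrative sizes: L = 4 past frames, 8x8 RGB frames, 6 discrete actions.
rng = np.random.default_rng(0)
L, H, W, A, C_HIDDEN = 4, 8, 8, 6, 16
x_in = build_unet_input(rng.standard_normal((3, H, W)),
                        [rng.standard_normal((3, H, W)) for _ in range(L)])
cond = conditioning_vector(c_noise=0.1, past_actions=[0, 2, 5, 1], num_actions=A)
hidden = rng.standard_normal((C_HIDDEN, H, W))  # a feature map inside a residual block
w_scale = 0.01 * rng.standard_normal((cond.size, C_HIDDEN))
w_shift = 0.01 * rng.standard_normal((cond.size, C_HIDDEN))
out = ada_group_norm(hidden, cond, w_scale, w_shift, num_groups=4)
```

Routing actions and noise level through normalization statistics, rather than concatenating them as extra image channels, lets them modulate every residual block without being spatially broadcast.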
- Choice of Diffusion Framework (EDM vs. DDPM): The paper argues that EDM is crucial for efficiency. Standard DDPM (Ho et al., 2020) requires many denoising steps (high NFE) for good quality. EDM's improved preconditioning and training objective (Eq. 7) allow stable, high-quality generation with very few steps (e.g., NFE=1 to 10). This is critical because the world model needs to be sampled repeatedly during agent training, making low NFE essential for computational feasibility. Experiments show EDM avoids the compounding errors seen with DDPM at low NFE (Figure 4).
- Sampling (Imagination): To generate each frame of an imagined trajectory ("dream"), the model starts from Gaussian noise and iteratively applies the learned denoiser $D_\theta$ using a numerical solver for the reverse SDE (Eq. 3) or the corresponding probability-flow ODE.
  - Solver: Euler's method is found to be effective.
  - Number of Function Evaluations (NFE): While NFE=1 works surprisingly well thanks to EDM, it can lead to blurry predictions in stochastic environments (averaging over modes). NFE=3 is used in experiments as a balance between capturing sharper details (modes) and computational cost (Figure 5). Each sampling step requires conditioning on the previously generated frames and the actions selected by the policy.
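- Reinforcement Learning Agent:
  - Architecture: A separate actor-critic agent (policy $\pi_\phi$ and value function $V_\phi$) is used, built on a shared CNN-LSTM backbone. It takes the current generated observation $x_t$ as input.
  - Training: The agent is trained entirely within the imagined environment generated by the diffusion world model. It is optimized with REINFORCE using a value baseline, with $\lambda$-returns as the value target, to improve sample efficiency within imagination.
  - Value Loss: $\mathcal{L}_V(\phi) = \mathbb{E}_{\pi_\phi}\left[\sum_{t=0}^{H-1} \left(V_\phi(x_t) - \operatorname{sg}(\Lambda_t)\right)^2\right]$, where $\Lambda_t = r_t + \gamma (1 - d_t)\left[(1-\lambda) V_\phi(x_{t+1}) + \lambda \Lambda_{t+1}\right]$ for $t < H$ and $\Lambda_H = V_\phi(x_H)$. Here $\operatorname{sg}$ is the stop-gradient operator and $H$ the imagination horizon.
  - Policy Loss: $\mathcal{L}_\pi(\phi) = - \mathbb{E}_{\pi_\phi} \left[ \sum_{t=0}^{H-1} \log\left(\pi_\phi\left(a_t \mid x_{\le t}\right)\right) \operatorname{sg}\left(\Lambda_t - V_\phi\left(x_t\right)\right) + \eta \, \mathcal{H}\left(\pi_\phi\left(a_t \mid x_{\le t}\right)\right) \right]$, where $\eta$ weights the entropy bonus $\mathcal{H}$.
  - Reward/Termination: A separate model $R_\psi$, also with a CNN-LSTM architecture, is trained on real data to predict rewards and termination signals within imagination.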
- Overall Training Loop (Algorithm 1):
  - Collect experience from the real environment using the current policy $\pi_\phi$.
  - Update the diffusion world model $D_\theta$ using data from the replay buffer $\mathcal{D}$.
  - Update the reward/termination predictor $R_\psi$ using data from $\mathcal{D}$.
  - Train the actor-critic agent $(\pi_\phi, V_\phi)$ by generating imagined trajectories with $D_\theta$ and $R_\psi$ and applying the RL losses.
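The few-step sampling described under Sampling can be illustrated with a toy deterministic Euler solver for the EDM probability-flow ODE, $dx/d\sigma = (x - D(x, \sigma))/\sigma$. In this sketch the data are scalars equal to $+1$ or $-1$ with equal probability, so the exact denoiser is $\tanh(x/\sigma^2)$; it is a hypothetical stand-in for the trained, history-conditioned $D_\theta$. The noise schedule uses EDM's default $\sigma_{\min} = 0.002$, $\sigma_{\max} = 80$, $\rho = 7$, which are assumptions here.

```python
import math


def karras_sigmas(nfe, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """EDM noise schedule: `nfe` levels from sigma_max down to sigma_min, then a final 0."""
    if nfe == 1:
        return [sigma_max, 0.0]
    hi, lo = sigma_max ** (1 / rho), sigma_min ** (1 / rho)
    sigmas = [(hi + i / (nfe - 1) * (lo - hi)) ** rho for i in range(nfe)]
    return sigmas + [0.0]


def euler_sample(denoiser, z, nfe):
    """Euler integration of dx/dsigma = (x - D(x, sigma)) / sigma with exactly `nfe`
    denoiser calls. z is standard Gaussian noise; the start point is sigma_max * z."""
    sigmas = karras_sigmas(nfe)
    x = sigmas[0] * z
    for s_cur, s_next in zip(sigmas[:-1], sigmas[1:]):
        d = (x - denoiser(x, s_cur)) / s_cur
        x = x + (s_next - s_cur) * d
    return x


def bimodal_denoiser(x, sigma):
    """Exact E[x0 | x] when x0 is +1 or -1 with equal probability (toy stand-in for D_theta)."""
    return math.tanh(x / sigma ** 2)


one_step = euler_sample(bimodal_denoiser, z=0.5, nfe=1)    # close to 0: the average of both modes
three_step = euler_sample(bimodal_denoiser, z=0.5, nfe=3)  # lands on the mode at +1
```

With one function evaluation the sample is the posterior mean, the scalar analogue of a blurry frame that averages over modes; with three evaluations it commits to a mode. This is only a toy analogue of the NFE=1 versus NFE=3 trade-off discussed in the paper.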
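The last step of Algorithm 1 relies on autoregressive imagination: each new frame is sampled conditioned on a sliding window of the most recent frames and actions, and the policy acts on the generated frames. A minimal sketch follows. `world_model`, `reward_model` and `policy` are hypothetical callables standing in for $D_\theta$ (e.g., few-step Euler sampling), $R_\psi$ and $\pi_\phi$; the window length is simply taken from the number of context frames supplied.

```python
from collections import deque


def imagine(world_model, reward_model, policy, init_frames, init_actions, horizon):
    """Roll out up to `horizon` imagined steps starting from real context.

    init_frames:  the L most recent real frames x_{t-L+1}, ..., x_t.
    init_actions: the L - 1 actions taken between those frames.
    world_model(frames, actions) -> next frame   (hypothetical stand-in for D_theta)
    reward_model(frames, actions) -> (reward, done)   (hypothetical stand-in for R_psi)
    policy(frame) -> action   (hypothetical stand-in for pi_phi)
    """
    context_len = len(init_frames)
    frames = deque(init_frames, maxlen=context_len)
    actions = deque(init_actions, maxlen=context_len)
    trajectory = []
    for _ in range(horizon):
        action = policy(frames[-1])
        actions.append(action)  # window now holds one action per frame
        next_frame = world_model(list(frames), list(actions))
        reward, done = reward_model(list(frames), list(actions))
        trajectory.append((frames[-1], action, reward, done))
        frames.append(next_frame)  # the deque drops the oldest frame automatically
        if done:
            break
    return trajectory


# Toy usage: "frames" are integers and the fake world model adds the action to the last frame.
traj = imagine(
    world_model=lambda frames, actions: frames[-1] + actions[-1],
    reward_model=lambda frames, actions: (1.0, frames[-1] >= 3),
    policy=lambda frame: 1,
    init_frames=[0, 0, 0, 0],
    init_actions=[0, 0, 0],
    horizon=10,
)
```

Using `deque(maxlen=...)` keeps the conditioning window at a fixed length, matching the frame-stacking style of conditioning; the resulting rewards, terminations and values then feed the RL losses.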
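The value and policy losses can be computed from an imagined rollout as below. This plain-Python sketch implements the recursion $\Lambda_t = r_t + \gamma(1 - d_t)\left[(1-\lambda)V_\phi(x_{t+1}) + \lambda\Lambda_{t+1}\right]$ with $\Lambda_H = V_\phi(x_H)$, the squared-error value loss, and REINFORCE with a value baseline plus an entropy bonus. The default $\gamma$, $\lambda$ and $\eta$ values are illustrative assumptions, not necessarily the paper's.

```python
import math


def lambda_returns(rewards, dones, values, gamma=0.985, lam=0.95):
    """Lambda-returns for t = 0..H-1.
    rewards, dones: length H; values: length H + 1 (last entry gives Lambda_H = V(x_H))."""
    horizon = len(rewards)
    returns = [0.0] * horizon
    next_return = values[horizon]
    for t in reversed(range(horizon)):
        bootstrap = (1 - lam) * values[t + 1] + lam * next_return
        next_return = rewards[t] + gamma * (1 - dones[t]) * bootstrap
        returns[t] = next_return
    return returns


def entropy(probs):
    """Shannon entropy of a categorical action distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def actor_critic_losses(log_probs, action_probs, values, returns, eta=0.001):
    """Value loss sum_t (V_t - Lambda_t)^2 and policy loss
    -sum_t [log pi(a_t) * (Lambda_t - V_t) + eta * H(pi)].
    Plain floats carry no gradients, so sg(.) is implicit here; in an autodiff
    framework the returns and the advantage must be detached."""
    horizon = len(returns)
    value_loss = sum((values[t] - returns[t]) ** 2 for t in range(horizon))
    policy_loss = -sum(
        log_probs[t] * (returns[t] - values[t]) + eta * entropy(action_probs[t])
        for t in range(horizon)
    )
    return value_loss, policy_loss
```

Setting $\lambda = 0$ recovers one-step TD targets and $\lambda = 1$ recovers Monte Carlo returns bootstrapped at the horizon; intermediate values trade bias against variance in the imagined targets.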
Practical Implications and Key Results
- State-of-the-Art Performance: DIAMOND achieves a mean Human Normalized Score (HNS) of 1.46 on the Atari 100k benchmark, setting a new record for agents trained purely within a world model.
- Visual Fidelity Matters: Qualitative comparisons (Figure 6) show that DIAMOND generates more visually consistent and detailed trajectories than IRIS (Micheli et al., 2022), a strong transformer-based discrete world model. For example, DIAMOND correctly renders details such as score updates and distinguishes enemies from rewards more reliably. This improved fidelity likely contributes to better agent performance, especially in games like Asterix, Breakout, and Road Runner, where small objects are critical.
- Computational Efficiency: By using EDM with only NFE=3, the diffusion world model keeps its sampling cost manageable, comparable to or even lower than that of some discrete models like IRIS (which used NFE=16). DIAMOND also uses fewer parameters (13M) than IRIS (30M) or DreamerV3 (18M) (Hafner et al., 2023).
- Drop-in Substitute: Since the world model operates in pixel space, it can stand in directly for the real environment for analysis or interaction (e.g., the playable models provided with the code).
- Code Availability: The implementation is open-sourced at https://github.com/eloialonso/diamond, facilitating replication and extension.
Implementation Considerations and Limitations
- Conditioning: The current model uses simple stacking of the most recent frames for history. More sophisticated temporal modeling, such as transformers across time in the style of DiT (Peebles et al., 2022), might better capture long-range dependencies but was found less effective in initial experiments (Appendix G).
- Reward/Termination Model: These are predicted by a separate model. Integrating them into the diffusion process is non-trivial and left for future work.
- Environment Scope: Evaluation is primarily on discrete-action Atari environments. Performance in continuous control or more complex 3D environments needs further investigation (though Appendix G provides promising initial results on CS:GO and Driving datasets).
- Computational Cost: While NFE is low, training still requires significant GPU resources (approx. 2.9 days per game on an RTX 4090).
In summary, DIAMOND demonstrates that using diffusion models as world models is a viable and effective approach for RL, particularly when preserving visual detail is important. The key implementation choices involve using the EDM framework for efficient sampling (low NFE) and conditioning the U-Net appropriately on past states and actions. Its strong performance and improved visual fidelity over discrete models highlight the potential of diffusion for building more capable world models.