d1: Scaling Reasoning in Diffusion Large Language Models via Reinforcement Learning (2504.12216v2)
Abstract: Recent LLMs have demonstrated strong reasoning capabilities that benefit from online reinforcement learning (RL). These capabilities have primarily been demonstrated within the left-to-right autoregressive (AR) generation paradigm. In contrast, non-autoregressive paradigms based on diffusion generate text in a coarse-to-fine manner. Although recent diffusion-based LLMs (dLLMs) have achieved competitive language modeling performance compared to their AR counterparts, it remains unclear if dLLMs can also leverage recent advances in LLM reasoning. To this end, we propose d1, a framework to adapt pre-trained masked dLLMs into reasoning models via a combination of supervised finetuning (SFT) and RL. Specifically, we develop and extend techniques to improve reasoning in pretrained dLLMs: (a) we utilize a masked SFT technique to distill knowledge and instill self-improvement behavior directly from existing datasets, and (b) we introduce a novel critic-free, policy-gradient based RL algorithm called diffu-GRPO, the first integration of policy gradient methods into masked dLLMs. Through empirical studies, we investigate the performance of different post-training recipes on multiple mathematical and planning benchmarks. We find that d1 yields the best performance and significantly improves the performance of a state-of-the-art dLLM. Our code is released at https://dLLM-reasoning.github.io/.
Summary
- The paper introduces a two-stage framework combining supervised finetuning and the novel diffu-GRPO RL algorithm to enhance reasoning in masked diffusion LLMs.
- diffu-GRPO estimates log-probabilities via random prompt masking and a mean-field approximation, enabling efficient policy gradient updates.
- Experiments demonstrate that d1 yields significant performance improvements across tasks like GSM8K, MATH500, Sudoku, and Countdown.
This paper introduces d1, a two-stage framework designed to enhance the reasoning capabilities of pre-trained masked diffusion LLMs (dLLMs). The authors address a central open question: while autoregressive (AR) LLMs have benefited significantly from reinforcement learning (RL) on reasoning tasks, it is unclear whether non-autoregressive dLLMs can achieve similar gains, given the fundamental differences in their generation process.
The d1 framework consists of:
- Supervised Finetuning (SFT): The pre-trained masked dLLM is first fine-tuned on high-quality reasoning traces. This stage aims to distill knowledge and instill self-improvement behaviors.
- Reinforcement Learning (RL) with diffu-GRPO: A novel policy-gradient-based RL algorithm called diffu-GRPO is introduced. This algorithm is specifically designed for masked dLLMs and adapts concepts from Group Relative Policy Optimization (GRPO).
Key Challenges and Solutions for RL in dLLMs
A core challenge in applying existing RL algorithms like PPO or GRPO to dLLMs is the computation of log-probabilities of generated sequences. AR models can easily compute this via sequential factorization. dLLMs, however, generate text iteratively and non-sequentially, making direct log-probability calculation difficult and computationally expensive.
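For contrast, AR models obtain the exact sequence log-probability from the chain-rule factorization, which has no counterpart in masked diffusion decoding:
$$\log \pi_\theta(o \mid q) = \sum_{k=1}^{|o|} \log \pi_\theta\!\left(o^k \mid q, o^{<k}\right)$$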
diffu-GRPO addresses this through:
- Efficient Log-Probability Estimation:
- Sequence Log-Probability: A mean-field approximation is used, decomposing the sequence log-probability into a sum of independent per-token log-probabilities: $\log \pi_\theta(o \mid q) \approx \sum_{k=1}^{|o|} \log \pi_\theta(o^k \mid q)$.
- Per-Token Log-Probability: A one-step estimation method is proposed. Given a prompt $q$ and a generated completion $o$, the prompt $q$ is perturbed by randomly masking its tokens with probability $p_{\text{mask}}$ to create $q'$. The per-token log-probability is then estimated using a single call to the unmasking predictor: $\log f_\theta(o^k \mid q' \oplus \text{MASK} \oplus \cdots \oplus \text{MASK})$ (see the code sketch after this list).
- Random Prompt Masking as Regularization:
- During policy gradient updates, the prompt q is randomly masked to q′ for each update step. This stochastic masking creates perturbed views of the same (prompt, completion) pairs.
- This acts as a form of regularization and data augmentation, allowing for a higher number of gradient updates (μ) per batch of samples without overfitting. This significantly reduces the number of computationally expensive online generations required, improving training efficiency.
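Below is a minimal PyTorch-style sketch of this one-step estimator, assuming a `model` that maps a token sequence to per-position vocabulary logits; the function name, `mask_id`, and the default `p_mask` value are illustrative rather than the authors' released code.

```python
import torch
import torch.nn.functional as F

def estimate_per_token_logprobs(model, prompt_ids, completion_ids, mask_id, p_mask=0.15):
    """One-step estimate of per-token log-probs for a completion.

    The prompt is randomly masked with probability p_mask, the completion
    positions are fully masked, and a single forward pass of the unmasking
    predictor scores every completion token independently (mean-field
    approximation).
    """
    # Perturb the prompt: q -> q'
    keep = torch.rand_like(prompt_ids, dtype=torch.float) >= p_mask
    prompt_masked = torch.where(keep, prompt_ids, torch.full_like(prompt_ids, mask_id))

    # Append fully masked completion positions: q' ⊕ MASK ... MASK
    masked_completion = torch.full_like(completion_ids, mask_id)
    model_input = torch.cat([prompt_masked, masked_completion], dim=-1)

    # Single call to the unmasking predictor (assumed to return [L, vocab] logits)
    logits = model(model_input.unsqueeze(0)).squeeze(0)
    completion_logits = logits[prompt_ids.shape[-1]:]

    # log f_theta(o^k | q' ⊕ MASK ... MASK) for each completion token o^k
    logprobs = F.log_softmax(completion_logits, dim=-1)
    token_logprobs = logprobs.gather(-1, completion_ids.unsqueeze(-1)).squeeze(-1)

    # Mean-field sequence log-prob: sum of per-token terms
    seq_logprob = token_logprobs.sum()
    return token_logprobs, seq_logprob
```

Because the prompt masking is re-sampled at every inner update, each call produces a slightly different view of the same (prompt, completion) pair, which is what provides the regularization effect described above.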
The diffu-GRPO loss function is an adaptation of the GRPO loss, using these estimated log-probabilities:
$$
\begin{aligned}
\mathcal{L}_{\text{diffu-GRPO}}(\theta) =\; & \mathbb{E}_{\substack{q \sim \mathcal{D},\; q' \sim \text{masking}(q),\\ o_1,\dots,o_G \sim \pi_{\theta_{\text{old}}}(\cdot \mid q)}}
\Bigg[ \frac{1}{G}\sum_{i=1}^{G}\frac{1}{|o_i|}\sum_{k=1}^{|o_i|}
\min\!\Bigg(
\frac{\phi^{\pi_\theta}(o_i^k \mid q')}{\phi^{\pi_{\theta_{\text{old}}}}(o_i^k \mid q')}\, A_i^{k}(\pi_{\theta_{\text{old}}}), \\
& \quad \operatorname{clip}\!\Bigg(
\frac{\phi^{\pi_\theta}(o_i^k \mid q')}{\phi^{\pi_{\theta_{\text{old}}}}(o_i^k \mid q')},\,
1-\varepsilon,\; 1+\varepsilon
\Bigg) A_i^{k}(\pi_{\theta_{\text{old}}})
\Bigg)
- \beta\, D_{\text{KL}}\!\Bigl[
\phi^{\pi_\theta}(\cdot \mid q') \,\Big\|\, \phi^{\pi_{\text{ref}}}(\cdot \mid q')
\Bigr]
\Bigg]
\end{aligned}
$$
where $\phi^{\pi}(o^k \mid q')$ and $\phi^{\pi}(o \mid q')$ are the estimated per-token and sequence probabilities under policy $\pi$, $A_i^k(\pi_{\theta_{\text{old}}})$ is the advantage computed with the old policy $\pi_{\theta_{\text{old}}}$, and $\pi_{\text{ref}}$ is the reference policy.
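For concreteness, here is a hedged sketch of how this clipped, advantage-weighted objective could be assembled from the estimated per-token log-probabilities. Shapes assume equal-length completions, the advantage is broadcast per token, and the KL term uses the token-level estimator common in GRPO implementations; the names are not from the released code.

```python
import torch

def diffu_grpo_loss(logp_new, logp_old, logp_ref, advantages, eps=0.2, beta=0.04):
    """GRPO-style clipped surrogate over per-token log-prob estimates.

    logp_new / logp_old / logp_ref: [G, T] per-token log-probs under the
    current, old, and reference policies (estimated via prompt masking).
    advantages: [G] group-relative advantages, broadcast to every token.
    """
    ratio = torch.exp(logp_new - logp_old)            # phi_theta / phi_theta_old
    adv = advantages.unsqueeze(-1)                    # [G, 1] -> broadcast over tokens

    unclipped = ratio * adv
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * adv
    surrogate = torch.minimum(unclipped, clipped)

    # Unbiased per-token KL estimate against the reference policy
    kl = torch.exp(logp_ref - logp_new) - (logp_ref - logp_new) - 1.0

    # Average over tokens within each completion, then over the group;
    # return the negated objective so it can be minimized by gradient descent.
    per_completion = (surrogate - beta * kl).mean(dim=-1)
    return -per_completion.mean()
```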
The diffu-GRPO algorithm is summarized as follows:
Algorithm: diffu-GRPO
1. Initialize current policy pi_theta from reference model pi_ref.
2. Loop until convergence:
3.   Set old policy pi_theta_old = pi_theta.
4.   Sample a prompt q.
5.   Sample G completions o_i from pi_theta_old given q.
6.   For each o_i, compute reward r_i and advantage A_i^k.
7.   For n = 1 to mu (number of inner updates):
8.     Randomly mask prompt q to get q' (masking each prompt token with probability p_mask).
9.     Estimate log-probabilities for pi_theta, pi_theta_old, and pi_ref given q' and o_i.
10.    Compute the diffu-GRPO objective L_diffu-GRPO(theta).
11.    Update pi_theta by gradient descent on L_diffu-GRPO(theta).
12. Return pi_theta.
Supervised Finetuning (SFT) Implementation
SFT is performed on the LLaDA model using the s1K dataset, which contains high-quality reasoning traces. The SFT process involves:
- Sampling a prompt-response pair $(p_0, r_0)$ and a random timestep $t$.
- Constructing a partially masked response $r_t$ by masking tokens of $r_0$ according to the noise schedule $\alpha_t = 1 - t$.
- Calculating the loss based on predicting the original tokens $r_0^i$ given the prompt $p_0$ and the masked response $r_t$ (see the sketch after this list): $$\mathcal{L}(\theta) = -\frac{1}{t\,|r_0|} \sum_{i=1}^{|r_0|} \mathbf{1}\!\left[r_t^i = \text{MASK}\right] \log f_\theta\!\left(r_0^i \mid p_0 \oplus r_t\right)$$
- Updating model parameters $\theta$.
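A minimal sketch of this masked SFT objective under the linear schedule $\alpha_t = 1 - t$, assuming a `model` that returns per-position logits; the function name and single-pair interface are illustrative.

```python
import torch
import torch.nn.functional as F

def masked_sft_loss(model, prompt_ids, response_ids, mask_id):
    """SFT loss for a masked dLLM on one (prompt, response) pair.

    Response tokens are masked with probability t (noise schedule
    alpha_t = 1 - t); the loss is the cross-entropy of predicting the
    original tokens at masked positions, reweighted by 1/t as in the
    masked-diffusion objective.
    """
    t = torch.rand(1).clamp(min=1e-3)                 # random timestep in (0, 1]
    masked = torch.rand_like(response_ids, dtype=torch.float) < t
    r_t = torch.where(masked, torch.full_like(response_ids, mask_id), response_ids)

    # Forward pass on p_0 ⊕ r_t; keep only the response positions
    model_input = torch.cat([prompt_ids, r_t], dim=-1)
    logits = model(model_input.unsqueeze(0)).squeeze(0)[prompt_ids.shape[-1]:]

    # Cross-entropy on masked positions only, averaged over |r_0| and scaled by 1/t
    token_nll = F.cross_entropy(logits, response_ids, reduction="none")
    loss = (token_nll * masked.float()).sum() / (t * response_ids.numel())
    return loss
```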
Practical SFT considerations include handling truncated sequences (s1K has long sequences), ensuring loss is computed on PAD tokens for effective generation termination, and balancing dataset difficulty with model strength.
Experimental Setup and Results
- Base Model: LLaDA-8B-Instruct.
- Tasks:
- Mathematical Reasoning: GSM8K, MATH500.
- Logical Reasoning: 4x4 Sudoku, Countdown (3 numbers).
- Training:
- SFT: On s1K dataset for 20 epochs, sequence length 4096.
- RL (diffu-GRPO): Task-specific training. GSM8K/MATH500 were trained on their respective training splits; Countdown/Sudoku on synthetic datasets. Online generation sequence length was limited to 256. LoRA was used for diffu-GRPO training (r=128, α=64).
- Evaluation: Zero-shot prompting with generation sequence lengths 128, 256, and 512.
Key Findings:
- diffu-GRPO consistently outperforms base LLaDA and SFT-only: Applying diffu-GRPO alone (LLaDA+diffu-GRPO) showed larger gains than SFT alone (LLaDA+SFT) across most setups.
- diffu-GRPO improves over its initialization: Whether starting from the base LLaDA or an SFT-adapted checkpoint, diffu-GRPO provided performance gains.
- d1 recipe (SFT + diffu-GRPO) yields the highest gains: d1-LLaDA outperformed pure diffu-GRPO in 11 out of 12 setups, indicating a synergistic effect. Gains were modest on GSM8K (+3.9%) and MATH500 (+4.0%) but significant on Countdown (+26.2%) and Sudoku (+10.0%), potentially because the base model is nearing saturation on the math tasks.
- Reasoning improvement beyond training sequence length: diffu-GRPO trained with a 256-token sequence length showed improvements at both 128- and 512-token evaluation lengths.
- Qualitative "aha moments": At sequence length 512, models trained with SFT (LLaDA+SFT and d1-LLaDA) exhibited self-correction and backtracking behaviors, likely learned from the s1K reasoning traces.
- An example from Appendix C for LLaDA+SFT on a GSM8K problem (calculating rows of 5 stars) shows the model first calculating the answer, then explicitly stating a check: "However, we need to check if the number of rows is 8, as if there are 8 rows of 5 stars, the total number of stars would be... This matches..."
- Another example for d1-LLaDA on a percentage problem (puppies with spots) shows the model making an initial calculation, then stating: "However, it seems there was a mistake in the calculation. Let's recheck the steps." followed by the correct calculation.
- Sequential scaling with generation length: Performance generally improved with longer sequence lengths for GSM8K and MATH500. Trends were mixed for Countdown and Sudoku, suggesting these search-intensive tasks might require stronger base dLLMs for robust scaling.
Ablations on diffu-GRPO:
- Randomized Masking Benefit: Random prompt masking during log-probability estimation consistently outperformed fixed masking. It allowed scaling the number of policy updates per batch (μ) to higher values (e.g., 12 or 24 vs. typical 2), leading to faster convergence and reduced computational cost by requiring fewer online generations.
- For example, Figure 4 shows that with random masking, μ=12 and μ=24 achieve similar or better GSM8K correctness rewards with far fewer generated completions (and thus less wall-clock time) compared to μ=2 or fixed masking scenarios.
- Effect of Masking Rate ($p_{\text{mask}}$): Lower masking probabilities (e.g., 0.1, 0.3) for the prompt during log-probability estimation led to more stable training and better performance. Higher rates (e.g., 0.5, 0.7) introduced instability. The paper suggests $p_{\text{mask}} \le 0.3$ is optimal.
Implementation Considerations
- Computational Cost of dLLM Generation: Online generation for RL is slow. The paper limits generation length to 256 during RL training due to this. Efficient inference algorithms for dLLMs are crucial for scaling RL training further.
- Log-Probability Estimation: The proposed one-step estimator with prompt masking is key for making policy gradient methods feasible for masked dLLMs. The mean-field approximation for sequence log-probability is a practical simplification.
- Hyperparameters for diffu-GRPO (a configuration sketch follows this list):
- Number of inner updates μ: can be higher with random masking (e.g., 12-24).
- Prompt token masking probability $p_{\text{mask}}$: around 0.15 was used in experiments, with ablations suggesting $p_{\text{mask}} \le 0.3$ is good.
- LoRA rank (r=128) and alpha (α=64) for RL.
- SFT Design Choices:
- Dataset: High-quality reasoning traces like s1K are beneficial.
- Sequence Length: Truncation might be necessary.
- Padding: Loss on PAD tokens is important. Padding to max model length can be better than batch-wise max length for small batches.
- LoRA rank (r=8) and alpha (α=16) for SFT.
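As an illustration of these adapter choices, here is a hedged configuration sketch using the Hugging Face PEFT library; only the ranks and alphas come from the paper, while the target modules and dropout are assumptions.

```python
from peft import LoraConfig

# LoRA adapter for diffu-GRPO RL training (r=128, alpha=64 per the paper)
rl_lora = LoraConfig(
    r=128,
    lora_alpha=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # assumed, not stated in the summary
    lora_dropout=0.05,                                        # assumed
)

# LoRA adapter for SFT (r=8, alpha=16 per the paper)
sft_lora = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # assumed
    lora_dropout=0.05,                                        # assumed
)
```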
Conclusion
The d1 framework, combining SFT with the novel diffu-GRPO RL algorithm, successfully scales reasoning capabilities in masked dLLMs. diffu-GRPO's efficient log-probability estimation with randomized prompt masking is a key innovation for applying policy gradient methods to this model class. The results demonstrate significant improvements on various reasoning benchmarks, showcasing a promising path for enhancing non-autoregressive LLMs. Future work includes developing more efficient decoding strategies for dLLMs to facilitate even more effective RL training.