Large Language Diffusion Models (2502.09992v2)
Abstract: Autoregressive models (ARMs) are widely regarded as the cornerstone of LLMs. We challenge this notion by introducing LLaDA, a diffusion model trained from scratch under the pre-training and supervised fine-tuning (SFT) paradigm. LLaDA models distributions through a forward data masking process and a reverse process, parameterized by a vanilla Transformer to predict masked tokens. By optimizing a likelihood bound, it provides a principled generative approach for probabilistic inference. Across extensive benchmarks, LLaDA demonstrates strong scalability, outperforming our self-constructed ARM baselines. Remarkably, LLaDA 8B is competitive with strong LLMs like LLaMA3 8B in in-context learning and, after SFT, exhibits impressive instruction-following abilities in case studies such as multi-turn dialogue. Moreover, LLaDA addresses the reversal curse, surpassing GPT-4o in a reversal poem completion task. Our findings establish diffusion models as a viable and promising alternative to ARMs, challenging the assumption that key LLM capabilities discussed above are inherently tied to ARMs. Project page and codes: https://ml-gsai.github.io/LLaDA-demo/.
Summary
- The paper presents LLaDA, a diffusion-based model that challenges autoregressive approaches by using a discrete token masking and reverse denoising process.
- It utilizes a standard Transformer architecture trained with an ELBO objective to iteratively reconstruct masked sequences.
- Empirical results show competitive in-context learning, instruction following, and superior reversal handling compared to traditional autoregressive models.
The paper "Large Language Diffusion Models" (2502.09992) introduces LLaDA (Large Language Diffusion Assistant), a diffusion-based model designed for large-scale language generation, presenting it as a potential alternative to the prevalent autoregressive models (ARMs) like GPT and LLaMA. The work challenges the assumption that core LLM capabilities such as in-context learning (ICL) and instruction following are intrinsically tied to the autoregressive paradigm. LLaDA is trained from scratch using standard pre-training and supervised fine-tuning (SFT) methodologies.
Methodology: LLaDA Diffusion Process
LLaDA employs a discrete diffusion process operating directly on token sequences. The core idea involves two stages: a forward noising process and a reverse denoising (generation) process.
Forward Process (Data Masking): The forward process, denoted as q(xt∣xt−1), progressively corrupts an initial clean sequence x0 by introducing mask tokens ([MASK]) over T discrete timesteps. This can be conceptualized as sampling from a transition kernel that replaces tokens with [MASK] according to a predefined schedule. Unlike continuous diffusion, which adds Gaussian noise, this process operates in the discrete token space. Let x0=(w1,w2,...,wL) be the original sequence of length L. At each step t, tokens are randomly selected and replaced with [MASK] based on a transition probability, leading to increasingly masked sequences x1,x2,...,xT. The final state xT typically approximates a sequence composed entirely or predominantly of mask tokens. The probability of transitioning from xt−1 to xt often follows a simple rule, like independently masking each non-masked token with a certain probability at step t. The overall distribution q(xt∣x0) can usually be computed in closed form, representing the probability of observing the masked sequence xt given the original x0 after t steps.
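To make this concrete, here is a minimal sketch (not the paper's code) of sampling xt from q(xt∣x0) under an assumed linear masking schedule, where each token is independently replaced with [MASK] with probability t/T:

```python
import random

MASK = "[MASK]"

def forward_mask(x0, t, num_steps, rng=random):
    """Sample x_t ~ q(x_t | x_0): independently replace each token of the
    clean sequence with [MASK] with probability t / num_steps (linear schedule)."""
    p_mask = t / num_steps
    return [MASK if rng.random() < p_mask else token for token in x0]

# Example: the sequence becomes progressively more masked as t approaches T.
x0 = ["the", "cat", "sat", "on", "the", "mat"]
for t in (2, 5, 10):
    print(t, forward_mask(x0, t, num_steps=10))
```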
Reverse Process (Mask Prediction): The reverse process, parameterized by a neural network pθ(xt−1∣xt,t), aims to reverse the masking process. Given a masked sequence xt and the timestep t, the model predicts the less masked sequence xt−1. Generation starts from a fully masked sequence xT (or a sequence sampled from the prior distribution p(xT)) and iteratively applies the learned reverse transitions pθ(xt−1∣xt,t) for t=T,T−1,...,1 to generate the final sequence x0.
Model Parameterization: LLaDA utilizes a standard Transformer architecture (referred to as "vanilla") to parameterize the reverse process pθ(xt−1∣xt,t). The Transformer takes the masked sequence xt and the timestep t (usually encoded as an embedding and added to the input) as input. Its objective is to predict the original tokens at the masked positions. Specifically, for each position i where xt(i)=[MASK], the model outputs a probability distribution over the vocabulary for the original token x0(i).
Training Objective: The model is trained by optimizing a variational lower bound (ELBO) on the data log-likelihood log pθ(x0). This objective typically simplifies to a weighted denoising (mask-prediction) objective averaged across timesteps. A common formulation minimizes the negative log-likelihood of the original tokens given the masked sequence xt:
$$
\mathcal{L}(\theta) = \mathbb{E}_{t \sim \mathcal{U}(1,T),\; x_0 \sim \mathcal{D},\; x_t \sim q(x_t \mid x_0)}\big[-\log p_\theta(x_0 \mid x_t, t)\big]
$$
Here, D is the training data distribution. The term pθ(x0∣xt,t) often simplifies to predicting the masked tokens. In practice, the model might predict the probability distribution for each masked token independently or jointly, depending on the specific factorization chosen. The paper frames this as optimizing a likelihood bound that "provides a principled generative approach for probabilistic inference."
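As a rough illustration (a sketch, not the paper's exact objective or weighting), the inner term reduces to a cross-entropy evaluated only at the positions that were masked in xt:

```python
import torch
import torch.nn.functional as F

def masked_denoising_loss(logits, x0_ids, is_masked):
    """Cross-entropy between the model's predictions for x_t and the original
    tokens x_0, averaged over the masked positions only.

    logits:    [batch, length, vocab_size] model outputs given (x_t, t)
    x0_ids:    [batch, length] original token ids
    is_masked: [batch, length] bool, True where x_t == [MASK]
    """
    batch, length, vocab_size = logits.shape
    per_token = F.cross_entropy(
        logits.reshape(-1, vocab_size),
        x0_ids.reshape(-1),
        reduction="none",
    ).reshape(batch, length)
    # Unmasked positions carry no learning signal in this formulation.
    return (per_token * is_masked).sum() / is_masked.sum().clamp(min=1)
```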
Implementation Details
Implementing LLaDA involves several key components:
1. Tokenization and Embedding: Standard subword tokenization (e.g., BPE) is used. Token embeddings are fed into the Transformer. A special [MASK] token is added to the vocabulary.
2. Masking Schedule: A noise or masking schedule determines the probability of masking tokens at each timestep t. Common schedules include linear, cosine, or square-root schedules, adapted for the discrete masking process. The choice of schedule and the total number of diffusion steps T can significantly impact performance and sampling speed.
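For illustration, a schedule can be expressed as a function mapping timestep t to a target masking fraction; the shapes below are common choices and assumptions, not the paper's specific schedule:

```python
import math

def linear_mask_ratio(t, num_steps):
    return t / num_steps

def cosine_mask_ratio(t, num_steps):
    # Masks slowly at first, then accelerates toward full masking.
    return 1.0 - math.cos(0.5 * math.pi * t / num_steps)

def sqrt_mask_ratio(t, num_steps):
    # Masks aggressively early, then levels off.
    return math.sqrt(t / num_steps)

# All three satisfy ratio(0) = 0 (clean sequence) and ratio(num_steps) = 1 (fully masked).
```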
3. Transformer Architecture: A standard decoder-only or encoder-decoder Transformer can be adapted. Given the task is to predict masked tokens based on the surrounding context (masked sequence xt), an encoder-like architecture (similar to BERT) or a non-causal decoder seems appropriate. The paper mentions a "vanilla Transformer," suggesting a standard architecture without major modifications specific to diffusion, potentially leveraging bidirectional attention over xt. Timestep t is typically incorporated via sinusoidal embeddings added to the token embeddings.
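A minimal sketch of such a parameterization, assuming a BERT-style bidirectional encoder with sinusoidal timestep embeddings added to the token embeddings (the actual LLaDA architecture and hyperparameters may differ):

```python
import math
import torch
import torch.nn as nn

class MaskPredictor(nn.Module):
    """Bidirectional Transformer mapping (x_t, t) to per-position logits over the vocabulary."""

    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6, max_len=1024):
        super().__init__()
        self.d_model = d_model
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers)
        self.out = nn.Linear(d_model, vocab_size)

    def timestep_embedding(self, t, device):
        """Standard sinusoidal embedding of the scalar timestep t."""
        half = self.d_model // 2
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=device) / half)
        angles = t * freqs
        return torch.cat([torch.sin(angles), torch.cos(angles)])  # [d_model]

    def forward(self, x_t_ids, t):
        # x_t_ids: [batch, length] token ids of the (partially) masked sequence.
        batch, length = x_t_ids.shape
        positions = torch.arange(length, device=x_t_ids.device).unsqueeze(0)
        h = self.token_emb(x_t_ids) + self.pos_emb(positions)
        # Broadcast the timestep embedding over batch and positions.
        h = h + self.timestep_embedding(float(t), x_t_ids.device).view(1, 1, -1)
        h = self.encoder(h)   # no causal mask: full bidirectional attention over x_t
        return self.out(h)    # [batch, length, vocab_size]
```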
4. Training:
- Pre-training: LLaDA is pre-trained on large text corpora. During each training step:
- Sample a clean sequence x0 from the dataset.
- Sample a random timestep t∼U(1,T).
- Generate a masked sequence xt by applying the forward process: xt∼q(xt∣x0).
- Feed xt and t into the Transformer model pθ.
- Compute the loss, typically cross-entropy between the model's predicted distributions for masked tokens and the true original tokens in x0.
- Update model parameters θ using gradient descent.
- Supervised Fine-Tuning (SFT): After pre-training, LLaDA is fine-tuned on instruction-following datasets (e.g., question-answering pairs, dialogue data). The format concatenates prompt and response into a single sequence, with the diffusion masking objective applied to the response tokens while the prompt is left unmasked. This stage adapts the model to follow instructions and generate desired outputs; a sketch of this data preparation follows below.
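A sketch of that SFT data preparation, assuming only response tokens are eligible for masking (helper names and the linear schedule are illustrative assumptions):

```python
import torch

def build_sft_batch(prompt_ids, response_ids, t, num_steps, mask_id):
    """Prepare one SFT example: concatenate prompt and response, then mask only
    response tokens with probability t / num_steps.

    prompt_ids, response_ids: 1-D long tensors of token ids
    Returns (x_t_ids, x0_ids, is_masked) suitable for a masked denoising loss.
    """
    x0_ids = torch.cat([prompt_ids, response_ids])
    p_mask = t / num_steps

    # Only response positions are eligible for masking; the prompt stays clean.
    eligible = torch.zeros_like(x0_ids, dtype=torch.bool)
    eligible[len(prompt_ids):] = True
    is_masked = eligible & (torch.rand(x0_ids.shape, device=x0_ids.device) < p_mask)

    x_t_ids = torch.where(is_masked, torch.full_like(x0_ids, mask_id), x0_ids)
    return x_t_ids, x0_ids, is_masked
```

Pre-training follows the same recipe with an empty prompt, so every position is eligible for masking.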
5. Inference (Sampling): Generating text involves iteratively applying the learned reverse transition pθ(xt−1∣xt,t).
```python
import random

MASK = "[MASK]"

def generate(model, length, num_steps, vocab, rng=random):
    """Iteratively denoise a fully masked sequence into text (simplified sketch).

    model.predict(x_t, t) is assumed to return, for every position, a probability
    distribution over the vocabulary, i.e. a list of shape [length][vocab_size].
    """
    # Start from the fully masked sequence x_T.
    x_t = [MASK] * length

    for t in range(num_steps, 0, -1):
        # Predict p(x_0 | x_t, t) for every position given the current state.
        token_probabilities = model.predict(x_t, t)

        # Option 1: sample x_0 for all positions, then sample x_{t-1} from the
        #           posterior q(x_{t-1} | x_t, x_0).
        # Option 2 (simpler heuristic, used here): sample tokens for the currently
        #           masked positions and unmask only as many as the schedule allows.
        predicted_x0 = [
            rng.choices(vocab, weights=probs, k=1)[0]
            for probs in token_probabilities
        ]

        # Positions still masked at step t.
        masked_at_t = [i for i, tok in enumerate(x_t) if tok == MASK]
        # Under a linear schedule, roughly length * (t - 1) / num_steps positions
        # should remain masked at step t - 1.
        num_masked_next = (length * (t - 1)) // num_steps
        keep_masked = set(rng.sample(masked_at_t, min(num_masked_next, len(masked_at_t))))

        # Unmask the remaining masked positions with the sampled predictions.
        x_prev = list(x_t)
        for i in masked_at_t:
            if i not in keep_masked:
                x_prev[i] = predicted_x0[i]

        x_t = x_prev

    return x_t  # fully unmasked generated sequence
```
Empirical Evaluation and Results
LLaDA's performance was evaluated against self-constructed ARM baselines and existing strong LLMs.
- Scalability: The paper reports strong scalability, with LLaDA models outperforming their ARM counterparts (trained by the authors for direct comparison) across various model sizes. This suggests that the diffusion framework is amenable to large-scale training.
- In-Context Learning: LLaDA 8B demonstrated competitive performance on ICL benchmarks compared to established ARM models like LLaMA3 8B. This finding is significant as ICL is often considered a hallmark capability strongly associated with autoregressive generation.
- Instruction Following: After SFT, LLaDA showed proficiency in instruction following, illustrated through case studies involving multi-turn dialogues. This indicates that the diffusion framework can be successfully adapted via SFT to align with user intentions, similar to ARMs.
- Reversal Curse: A notable claim is LLaDA's ability to address the "reversal curse": the tendency of standard ARMs to fail on relations stated in the reverse of their training order (e.g., a model trained that "A is B" often cannot answer "What is B?" with A). The paper specifically highlights that LLaDA surpasses GPT-4o on a reversal poem completion task, where the model must produce the line preceding a given line. This suggests that attending bidirectionally to the full masked sequence xt during the reverse process may be advantageous for tasks requiring non-left-to-right reasoning or manipulation.
Discussion and Implications
The introduction of LLaDA presents several implications for the LLM field:
- Viability of Diffusion Models: LLaDA provides empirical evidence that diffusion models are a viable architectural choice for large-scale language modeling, capable of achieving performance comparable to strong ARMs on key benchmarks like ICL and instruction following. This potentially broadens the architectural search space for future foundational models.
- Challenging ARM Dominance: The results question the necessity of the autoregressive formulation for achieving advanced LLM capabilities. If capabilities like ICL are not exclusive to ARMs, it opens avenues for exploring alternative generative frameworks that might offer different trade-offs.
- Potential Advantages: Diffusion models might possess inherent advantages for certain tasks. The reported success on the reversal curse suggests better handling of bidirectional dependencies or non-sequential relationships compared to left-to-right ARMs. Furthermore, diffusion models offer possibilities for controllable generation by manipulating the sampling process or conditioning information. Non-autoregressive generation, characteristic of diffusion sampling (predicting multiple tokens in parallel within each step), could potentially lead to faster inference than sequential token-by-token generation in ARMs, although iterative refinement over many steps is still required.
- Limitations and Trade-offs: Diffusion models typically require multiple iterative refinement steps for generation, which can be computationally expensive; although each step predicts many tokens in parallel, total generation time depends on the number of steps T. The complexity of the sampling process and the choice of masking schedule are critical design decisions. Evaluating computational efficiency (training and inference FLOPs, latency) against optimized ARMs remains important.
Conclusion
LLaDA (2502.09992) positions diffusion models as a competitive alternative framework for building LLMs. By demonstrating strong performance in scalability, in-context learning, instruction following, and specific tasks like sequence reversal, the work challenges the prevailing dominance of autoregressive models. While further research is needed to fully understand the trade-offs and optimize diffusion-based LLMs, LLaDA signifies a potentially significant development in exploring diverse architectures for advanced language generation.