Large Language Models to Diffusion Finetuning (2501.15781v2)
Abstract: We propose a new finetuning method to provide pre-trained large language models (LMs) the ability to scale test-time compute through the diffusion framework. By increasing the number of diffusion steps, we show our finetuned models achieve monotonically increasing accuracy, directly translating to improved performance across downstream tasks. Furthermore, our finetuned models can expertly answer questions on specific topics by integrating powerful guidance techniques, and autonomously determine the compute required for a given problem by leveraging adaptive ODE solvers. Our method is universally applicable to any foundation model pre-trained with a cross-entropy loss and does not modify any of its original weights, fully preserving its strong single-step generation capabilities. We show our method is more effective and fully compatible with traditional finetuning approaches, introducing an orthogonal new direction to unify the strengths of the autoregressive and diffusion frameworks.
Summary
- The paper presents L2D, a method that transforms pre-trained LMs into multi-step diffusion models to iteratively refine predictions.
- It uses a cross-entropy loss and a parallel diffusion path with LoRA for efficient integration alongside frozen LM weights.
- Experimental results show monotonic performance gains on math and coding tasks as the number of diffusion steps increases.
This paper introduces "LM to Diffusion (L2D)", a novel finetuning method that endows pre-trained large language models (LMs) with the test-time compute scaling capabilities characteristic of diffusion models. The core idea is to cast LMs as single-step diffusion models and then finetune them to operate effectively over multiple diffusion steps. This allows L2D-finetuned models to monotonically improve their accuracy by increasing the number of diffusion steps at inference, without altering the original LM weights, thereby preserving their strong single-step generation capabilities.
L2D Method: Unifying Autoregressive and Diffusion Frameworks
- Gaussian Diffusion Foundation: L2D leverages the Gaussian diffusion framework. The process involves a corruption phase where target data points $x_1$ (token embeddings) are mixed with noise $x_0 \sim \mathcal{N}(0, \sigma^2 I)$ to produce $x_t = \alpha_t x_1 + \beta_t x_0$. The paper uses schedules $\alpha_t = t$ and $\beta_t = 1 - t$, which makes the diffusion path a straight line between noise and data, similar to rectified flow.
- L2D Parametrization and Training:
- Instead of the common Mean Squared Error (MSE) loss used in continuous diffusion, L2D employs a cross-entropy loss for training. This aligns it directly with traditional LM training.
The loss is formulated as:
$$\mathcal{L}_{CE}(\theta) = -\mathbb{E}_{x_0, x_1, t}\left[\log f_\theta(x_t, t, c)_y\right]$$
where $x_1 = V_y$ is the target token embedding, $x_t = t x_1 + (1 - t) x_0$ is the noisy input at timestep $t \sim U[0, 1]$, $c$ is the context of preceding tokens, and $f_\theta$ is the diffusion model predicting logits over the vocabulary.
This setup allows the model $f_\theta$ to predict vocabulary logits, just like a standard LM, but conditioned on the noisy token $x_t$ and timestep $t$.
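The corruption step and the cross-entropy objective described above can be sketched in a few lines of plain Python. This is a minimal illustration, not the paper's implementation: `toy_logits` is a hypothetical stand-in for $f_\theta$ that ignores the context $c$, and the embedding table `V` is whatever small list of vectors the caller supplies.

```python
import math
import random

def corrupt(x1, t, sigma, rng):
    """Mix a clean embedding x1 with Gaussian noise: x_t = t*x1 + (1-t)*x0."""
    x0 = [rng.gauss(0.0, sigma) for _ in x1]
    return [t * a + (1.0 - t) * b for a, b in zip(x1, x0)]

def cross_entropy(logits, y):
    """Negative log-probability of class y under softmax(logits)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[y]

def toy_logits(x_t, t, V):
    """Hypothetical stand-in for f_theta: dot product of x_t with each embedding.
    A real model would also use the timestep t and the context c."""
    return [sum(a * b for a, b in zip(x_t, v)) for v in V]

def l2d_loss_sample(V, y, sigma, rng):
    """One Monte Carlo sample of the cross-entropy diffusion loss for token y."""
    t = rng.random()                      # t ~ U[0, 1]
    x_t = corrupt(V[y], t, sigma, rng)    # noisy version of the target embedding
    return cross_entropy(toy_logits(x_t, t, V), y)
```

Averaging `l2d_loss_sample` over many draws of the noise and the timestep approximates the expectation in the loss. At `t = 1` the corrupted input equals the clean embedding, and at `t = 0` it is pure noise.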
- L2D Inference:
Inference follows an ordinary differential equation (ODE) formulation, specifically adopting the constant expected velocity from rectified flow:
$$dx_t = \frac{\hat{x} - x_t}{1 - t}$$
- To get $\hat{x}$ (the estimate of the clean token embedding) from the model's categorical output, L2D samples a token $y_t \sim f_\theta(x_t, t, c)$ and then sets $\hat{x} = V_{y_t}$ (the embedding of the sampled token). This introduces stochasticity, which is found to be beneficial.
- The denoising process starts with $x_0 \sim \mathcal{N}(0, \sigma^2 I)$ and iteratively refines $x_t$ using an ODE solver (e.g., Euler integration: $x_{t+\Delta t} = x_t + \Delta t \cdot dx_t$) over $T$ steps.
- The final prediction is made by sampling $y \sim f_\theta(x_1, 1, c)$.
The inference algorithm is summarized as:
Algorithm 1: Diffusion language modeling predictions
1: Input diffusion model $f_\theta$, context $c$, budget $T$
2: Initialize $t = 0$, $\Delta t = 1/(T-1)$
3: Sample $x_t \sim \mathcal{N}(0, \sigma^2 I)$
4: FOR $i = 1, 2, \dots, T-1$ DO
5:     Sample $y_t \sim f_\theta(x_t, t, c)$
6:     Set $\hat{x} = V_{y_t}$
7:     Compute $dx_t = (\hat{x} - x_t) / (1 - t)$
8:     Update $t = t + \Delta t$, $x_t = x_t + \Delta t \cdot dx_t$
9: END FOR
10: Return $y \sim f_\theta(x_1, 1, c)$
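As a rough illustration, Algorithm 1 with Euler integration might look like the following in plain Python. The function names, the list-based vectors, and the calling convention `f_theta(x_t, t, c)` returning a list of logits are assumptions made for this sketch; the source does not specify an API.

```python
import math
import random

def sample_token(logits, rng):
    """Draw a token index from softmax(logits)."""
    m = max(logits)
    weights = [math.exp(v - m) for v in logits]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

def l2d_generate(f_theta, V, c, T, sigma, rng):
    """Euler version of Algorithm 1. Requires T >= 2.
    f_theta(x_t, t, c) is assumed to return logits over the vocabulary."""
    steps = T - 1
    dt = 1.0 / steps
    x_t = [rng.gauss(0.0, sigma) for _ in range(len(V[0]))]  # pure noise at t = 0
    for i in range(steps):
        t = i * dt                                   # stays below 1 inside the loop
        y_t = sample_token(f_theta(x_t, t, c), rng)  # stochastic token guess
        x_hat = V[y_t]                               # embedding of the sampled token
        dx = [(h - x) / (1.0 - t) for h, x in zip(x_hat, x_t)]
        x_t = [x + dt * d for x, d in zip(x_t, dx)]  # Euler update
    return sample_token(f_theta(x_t, 1.0, c), rng)   # final prediction at t = 1
```

Computing `t` as `i * dt` rather than accumulating it avoids floating-point drift. On the last iteration `1 - t` equals `dt`, so the final Euler update moves `x_t` onto the last sampled embedding, up to rounding.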
- LMs as Single-step Diffusion Models: The paper argues that traditional LMs are effectively single-step diffusion models where t=0 (input is pure noise, uncorrelated with the target). L2D finetuning extends this by enabling multi-step reasoning, leveraging the pre-trained LM's "System 1" understanding.
L2D Implementation Details
L2D is designed as a modular extension to pre-trained transformers, introducing a parallel "diffusion path" while keeping the original LM weights frozen.
- Architecture (Figure 2):
- The original LM path ($f_{\theta_l}$) remains frozen.
- A new, parallel "diffusion path" ($f_{\theta_d}$) is introduced to process the diffusion token $x_t$. This path has a transformer architecture similar to the main LM.
- Initialization: Layers in $f_{\theta_d}$ are initialized with weights from the corresponding layers in $f_{\theta_l}$.
- Optimization: Only $f_{\theta_d}$ is trained, typically using Low-Rank Adaptation (LoRA), making finetuning efficient and minimizing memory overhead.
- Diffusion Path Components: Transformer blocks in $f_{\theta_d}$ consist of MLPs and cross-attention modules. The diffusion token $x^k_t$ (for target token $y^k$) attends to keys and values from the self-attention module in $f_{\theta_l}$ for the context.
- Information Merging: Information from $f_{\theta_d}$ is merged back to $f_{\theta_l}$ only at the final layer, before the LM's linear head, via an element-wise weighted sum: $f_{\theta_l} + w_d f_{\theta_d}$. The diffusion token $x^k_t$ for target $y^k$ affects the latents of the previous token $x^{k-1}$ in the main path.
- Advantages of this Design:
- Efficient Inference: Latents from $f_{\theta_l}$ and the KV cache are computed once per generated token, regardless of the number of diffusion steps. Only the smaller $f_{\theta_d}$ is computed iteratively.
- Parallelizable Training: Diffusion losses for all sequence positions can be computed independently for each input context.
- L2D Conditioning:
- Diffusion Space Vocabulary: Token embeddings for the diffusion path ($V$) are derived from the base LM's vocabulary ($V^l$) by learning a linear mapping $W_v$ to a lower dimension $\bar{d}$ (e.g., 256), followed by norm scaling: $V_y = \bar{d}\,\frac{W_v V^l_y}{\lVert W_v V^l_y \rVert_2}$. A small "translation module" maps these $\bar{d}$-dimensional embeddings back to the LM's hidden dimension $d$.
- Timestep Conditioning: Timestep t is incorporated via:
1. Sinusoidal features processed to output shift/scale parameters for LayerNorms in $f_{\theta_d}$.
2. Time-conditioned element-wise rescaling before residual sums in $f_{\theta_d}$ blocks.
3. Conditioning the final merging weight: $w_d(t) = w_{\theta_d}(t) - w_{\theta_d}(0)$. This ensures $w_d(0) = 0$, so the diffusion path has no effect at $t = 0$, preserving the original LM's single-step behavior.
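The zero-at-start merging weight can be illustrated with a small sketch. Here `w_raw` stands for any learned time-dependent weight (the paper's $w_{\theta_d}$); the linear example in the usage note and both helper names are hypothetical.

```python
def make_merge_weight(w_raw):
    """Wrap a raw time-dependent weight so the result is exactly 0 at t = 0."""
    w0 = w_raw(0.0)
    return lambda t: w_raw(t) - w0

def merge(h_lm, h_diff, w_d_t):
    """Element-wise weighted sum of the frozen LM latent and the diffusion latent."""
    return [a + w_d_t * b for a, b in zip(h_lm, h_diff)]
```

For example, with the hypothetical `w_raw = lambda t: 0.3 + 0.5 * t`, `make_merge_weight(w_raw)(0.0)` is `0.0`, so `merge` returns the frozen LM latent unchanged at `t = 0` whatever the diffusion path outputs.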
- Classifier-Free Guidance (CFG):
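- During training, class embeddings ($g_j$) are added to timestep embeddings. A null class $g_0$ is used with a class-dropout probability.
- During inference, guided predictions are made: $\hat{x}_g = w_g \cdot f_\theta(x_t, t, g_j, c) + (1 - w_g) \cdot f_\theta(x_t, t, g_0, c)$, where $w_g$ is the guidance strength. The two weights sum to one, so $w_g = 1$ recovers the purely class-conditional prediction and $w_g > 1$ extrapolates away from the null-class prediction.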
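A minimal sketch of the guided combination follows. It is written in the standard classifier-free guidance form, in which the weights on the class-conditional and null-class predictions sum to one; that convention is an assumption about the intended formula, and `cfg_combine` is a hypothetical helper name.

```python
def cfg_combine(cond, uncond, w_g):
    """Classifier-free guidance: w_g * cond + (1 - w_g) * uncond, element-wise.
    cond   : prediction given the class embedding g_j
    uncond : prediction given the null class g_0
    """
    return [w_g * a + (1.0 - w_g) * b for a, b in zip(cond, uncond)]
```

With `w_g = 1.0` the result is the conditional prediction, with `w_g = 0.0` it is the null-class prediction, and values above 1 push the output further in the direction `cond - uncond`.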
Experimental Results
- Models: Llama 3 (1B, 8B Instruct) and Qwen 2.5 (1.5B, 7B Instruct).
- Finetuning Data: A subset of SmolTalk dataset, focused on math and coding.
- Evaluation Tasks:
- Mathematics: GSM8K, MATH
- Coding: HumanEval, MBPP (pass@10)
- General Knowledge: MMLU, MMLU-Pro
- Hyperparameters: $\sigma = 64$ for noise scaling, diffusion dimension $\bar{d} = 256$. Inference typically uses a midpoint ODE solver with 8 discretization levels (15 $f_{\theta_d}$ evaluations).
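The source states only the totals (8 discretization levels, 15 evaluations). One accounting that reproduces this figure, offered as an assumption rather than the paper's stated breakdown, is two velocity evaluations per midpoint step over 7 intervals, plus one final prediction at $t = 1$. The scalar sketch below counts the calls made by a generic midpoint integrator; `velocity` is any function of `(x, t)` and is hypothetical here.

```python
def midpoint_integrate(velocity, x, levels):
    """Integrate dx/dt = velocity(x, t) from t=0 to t=1 with the midpoint rule.
    `levels` time points give `levels - 1` intervals. Returns (x_final, calls)."""
    calls = 0
    intervals = levels - 1
    dt = 1.0 / intervals
    for i in range(intervals):
        t = i * dt
        k1 = velocity(x, t)                   # slope at the start of the interval
        calls += 1
        x_mid = x + 0.5 * dt * k1             # half step to the midpoint
        k2 = velocity(x_mid, t + 0.5 * dt)    # slope at the midpoint
        calls += 1
        x = x + dt * k2                       # full step using the midpoint slope
    return x, calls
```

With `levels = 8` the integrator makes 14 calls, and one more model call for the final prediction gives 15. The midpoint rule never evaluates the velocity at `t = 1`, so the `1 / (1 - t)` factor in the rectified-flow velocity stays finite.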
Key Findings:
- Performance Improvement: L2D consistently improves performance, especially on math and coding tasks, with minimal additional parameters (LoRA modules). (Table 1)
- Superior to Traditional Finetuning: L2D outperforms LoRA and full-weight finetuning on the same data, suggesting it qualitatively differs by augmenting computation rather than just altering weights.
- Inference-Time Scaling:
- Performance monotonically increases with more diffusion steps (Figure 1) and with progression of timestep t (Figure 3).
- Most benefits are achieved with a small number of steps (e.g., 15 evaluations).
- Adaptive Diffusion Process: Using an adaptive Runge-Kutta ODE solver allows the model to dynamically adjust compute per token, leading to further improvements. The number of steps varies by task difficulty (Figure 4, Table 2).
- Full fθd Optimization & Compatibility:
- Fully finetuning $f_{\theta_d}$ (instead of LoRA) can yield further gains at higher computational cost.
- L2D is compatible with prior traditional finetuning (applying L2D on an already LoRA/fully finetuned model boosts performance further). (Table 2)
- Classifier-Free Guidance: CFG provides visible performance gains, especially when $w_g$ is tuned per task domain (Figure 5, Table 2). Math tasks benefit from higher $w_g$.
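To make the adaptive-solver finding above more concrete, the sketch below shows how an error-controlled solver spends more evaluations on some problems than others. It uses the Heun-Euler embedded pair, a simple adaptive Runge-Kutta scheme chosen here for brevity; the source does not say which adaptive solver the paper uses. The `relaxation` velocity fields are hypothetical toys, and using their rate as a proxy for task difficulty is only an analogy.

```python
import math

def adaptive_heun(velocity, x, tol, dt=0.25, t_end=1.0):
    """Integrate dx/dt = velocity(x, t) on [0, t_end] with step-size control.
    Returns (x_final, number_of_velocity_evaluations)."""
    t, evals = 0.0, 0
    while t < t_end:
        dt = min(dt, t_end - t)               # never step past t_end
        k1 = velocity(x, t)
        k2 = velocity(x + dt * k1, t + dt)
        evals += 2
        err = abs(0.5 * dt * (k2 - k1))       # |Heun step - Euler step|
        if err <= tol:                        # accept the Heun step
            x += 0.5 * dt * (k1 + k2)
            t += dt
        if err > 0.0:                         # shrink or grow the next step
            dt *= min(2.0, max(0.2, 0.9 * math.sqrt(tol / err)))
        else:
            dt *= 2.0
    return x, evals

def relaxation(rate, target=1.0):
    """Toy velocity field pulling x toward `target` at the given rate."""
    return lambda x, t: rate * (target - x)
```

Running `adaptive_heun(relaxation(1.0), 0.0, 1e-3)` and `adaptive_heun(relaxation(20.0), 0.0, 1e-3)` uses the same tolerance, but the faster-changing field forces smaller steps and therefore more evaluations.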
Ablations and Additional Details (Appendices):
- L2D benefits from higher learning rates than traditional finetuning.
- The choice of noise scale $\sigma$ is important; $\sigma = 64$ was found to be effective.
- Initializing $f_{\theta_d}$ from $f_{\theta_l}$ is crucial for performance.
- Sampling $\hat{x}$ for velocity computation is better than using the expectation for deterministic ODE solvers.
- Simpler ODE solvers (Euler, midpoint) work well for few steps; adaptive solvers and advanced timestep schedules (cosmap) show further potential.
Discussion and Future Work
L2D offers a practical way to combine the strengths of autoregressive LMs (strong pre-training, "System 1" knowledge) with diffusion models (test-time compute scaling, iterative refinement, guidance). It enables LMs to "think harder" on difficult problems by allocating more computation. Future work could explore more advanced diffusion techniques, finer-grained guidance, and applications in personalization. The authors state that the code and data splits will be released.
This paper presents a compelling method for enhancing pre-trained LMs. By adding a lightweight, diffusion-based reasoning path that reuses and augments the existing LM's knowledge, L2D allows for flexible scaling of computational effort at inference time. This leads to improved performance on complex tasks and opens up new avenues for controlling LM behavior through techniques like adaptive solvers and classifier-free guidance, all while preserving the original model's capabilities.