Large Language Models to Diffusion Finetuning (2501.15781v2)

Published 27 Jan 2025 in cs.CL, cs.AI, and cs.LG

Abstract: We propose a new finetuning method to provide pre-trained large language models (LMs) the ability to scale test-time compute through the diffusion framework. By increasing the number of diffusion steps, we show our finetuned models achieve monotonically increasing accuracy, directly translating to improved performance across downstream tasks. Furthermore, our finetuned models can expertly answer questions on specific topics by integrating powerful guidance techniques, and autonomously determine the compute required for a given problem by leveraging adaptive ODE solvers. Our method is universally applicable to any foundation model pre-trained with a cross-entropy loss and does not modify any of its original weights, fully preserving its strong single-step generation capabilities. We show our method is more effective and fully compatible with traditional finetuning approaches, introducing an orthogonal new direction to unify the strengths of the autoregressive and diffusion frameworks.

Summary

  • The paper presents L2D, a method that transforms pre-trained LMs into multi-step diffusion models to iteratively refine predictions.
  • It uses a cross-entropy loss and a parallel diffusion path with LoRA for efficient integration alongside frozen LM weights.
  • Experimental results show monotonic performance gains on math and coding tasks as the number of diffusion steps increases.

This paper introduces "LM to Diffusion (L2D)", a novel finetuning method that endows pre-trained large language models (LMs) with the test-time compute scaling capabilities characteristic of diffusion models. The core idea is to cast LMs as single-step diffusion models and then finetune them to operate effectively over multiple diffusion steps. This allows L2D-finetuned models to monotonically improve their accuracy by increasing the number of diffusion steps at inference, without altering the original LM weights, thereby preserving their strong single-step generation capabilities.

L2D Method: Unifying Autoregressive and Diffusion Frameworks

  1. Gaussian Diffusion Foundation: L2D leverages the Gaussian diffusion framework. The process involves a corruption phase where target data points $x_1$ (token embeddings) are mixed with noise $x_0 \sim \mathcal{N}(0, \sigma^2 I)$ to produce $x_t = \alpha_t x_1 + \beta_t x_0$. The paper uses the schedules $\alpha_t = t$ and $\beta_t = 1 - t$, simplifying the diffusion path, similar to rectified flow.
  2. L2D Parametrization and Training:
    • Instead of the common Mean Squared Error (MSE) loss used in continuous diffusion, L2D employs a cross-entropy loss for training. This aligns it directly with traditional LM training.
    • The loss is formulated as:

      $$L^{CE}(\theta) = -\mathbb{E}_{x_0, x_1, t}\left[\log\left(f_\theta(x_t, t, c)_{y}\right)\right]$$

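      where $x_1 = V_y$ is the target token embedding, $x_t = t x_1 + (1-t) x_0$ is the noisy input at timestep $t \sim U[0,1]$, $c$ is the context of preceding tokens, and $f_\theta$ is the diffusion model, whose output is a categorical distribution over the vocabulary (the softmax of its predicted logits), so $f_\theta(x_t, t, c)_y$ is the probability it assigns to the target token $y$.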

    • This setup allows the model $f_\theta$ to predict vocabulary logits, just like a standard LM, but conditioned on the noisy token $x_t$ and timestep $t$.

  3. L2D Inference:
    • Inference follows an Ordinary Differential Equation (ODE) formulation, specifically adopting the constant expected velocity from rectified flow:

      $$dx_t = \frac{\hat{x} - x_t}{1-t}$$

    • To get $\hat{x}$ (the estimate of the clean token embedding) from the model's categorical output, L2D samples a token $y_t \sim f_\theta(x_t, t, c)$ and then sets $\hat{x} = V_{y_t}$ (the embedding of the sampled token). This introduces stochasticity, which is found to be beneficial.
    • The denoising process starts with $x_0 \sim \mathcal{N}(0, \sigma^2 I)$ and iteratively refines $x_t$ using an ODE solver (e.g., Euler integration: $x_{t+\Delta_t} = x_t + \Delta_t \, dx_t$) within a budget of $T$ model evaluations ($T-1$ integration steps plus the final prediction).
    • The final prediction is made by sampling $y \sim f_\theta(x_1, 1, c)$.
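
The corruption process and cross-entropy objective described above can be sketched for a single target token as follows. This is a minimal NumPy illustration under stated assumptions, not the authors' code: `logits_fn(x_t, t, context)` is a hypothetical stand-in for the diffusion model $f_\theta$ that returns unnormalized logits, `V` is a hypothetical matrix of diffusion-space token embeddings (one row per token), and each call draws a single noise sample and a single timestep.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the vocabulary axis."""
    z = logits - logits.max()
    return z - np.log(np.exp(z).sum())


def l2d_training_loss(logits_fn, V, y, context, sigma=64.0, rng=None):
    """One-sample Monte Carlo estimate of L^CE for a single target token y.

    logits_fn: hypothetical model interface, (x_t, t, context) -> logits of shape (vocab,)
    V:         hypothetical diffusion-space embeddings, shape (vocab, d_bar)
    """
    rng = np.random.default_rng() if rng is None else rng
    x1 = V[y]                                   # clean target embedding x_1 = V_y
    x0 = rng.normal(0.0, sigma, size=x1.shape)  # noise x_0 ~ N(0, sigma^2 I)
    t = rng.uniform(0.0, 1.0)                   # timestep t ~ U[0, 1]
    x_t = t * x1 + (1.0 - t) * x0               # alpha_t = t, beta_t = 1 - t
    return -log_softmax(logits_fn(x_t, t, context))[y]


# Usage with a toy model that ignores its inputs: uniform logits over 5 tokens
V = np.random.default_rng(0).normal(size=(5, 4))
uniform_model = lambda x_t, t, context: np.zeros(len(V))
loss = l2d_training_loss(uniform_model, V, y=2, context=None)
print(loss)  # log(5) ≈ 1.609 for a uniform predictor
```

Because the loss is an ordinary cross-entropy on vocabulary logits, it has the same form as standard next-token training; only the extra inputs $x_t$ and $t$ differ.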

The inference algorithm is summarized as:

Algorithm 1: Diffusion language modeling predictions
Input: diffusion model f_theta, context c, budget T
Initialize t = 0, Delta_t = 1/(T-1)
Sample x_t ~ N(0, sigma^2 I)
FOR i = 1, 2, ..., T-1 DO
    Sample y_t ~ f_theta(x_t, t, c)
    Set hat_x = V_{y_t}
    Compute dx_t = (hat_x - x_t) / (1 - t)
    Update t = t + Delta_t, x_t = x_t + Delta_t * dx_t
END FOR
Return y ~ f_theta(x_1, 1, c)
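
Algorithm 1 translates almost line for line into NumPy. The sketch below is an illustration under stated assumptions: `logits_fn` is a hypothetical stand-in for $f_\theta$ returning logits over the vocabulary, `V` is a hypothetical diffusion-space embedding matrix, and the timestep is recomputed as `i * dt` instead of being accumulated, to avoid floating-point drift.

```python
import numpy as np


def softmax(logits):
    """Turn logits into a probability vector (stable)."""
    z = np.exp(logits - logits.max())
    return z / z.sum()


def l2d_predict_euler(logits_fn, V, context, T, sigma=64.0, rng=None):
    """Algorithm 1 with Euler integration, using T model evaluations in total.

    logits_fn: hypothetical model interface, (x_t, t, context) -> logits of shape (vocab,)
    V:         hypothetical diffusion-space embeddings, shape (vocab, d_bar)
    """
    if T < 2:
        raise ValueError("Algorithm 1 needs T >= 2 because Delta_t = 1/(T-1)")
    rng = np.random.default_rng() if rng is None else rng
    dt = 1.0 / (T - 1)
    x_t = rng.normal(0.0, sigma, size=V.shape[1])  # x_0 ~ N(0, sigma^2 I)
    for i in range(T - 1):
        t = i * dt                                  # computed directly to avoid float drift
        y_t = rng.choice(len(V), p=softmax(logits_fn(x_t, t, context)))
        x_hat = V[y_t]                              # embedding of the sampled token
        dx_t = (x_hat - x_t) / (1.0 - t)            # rectified-flow velocity estimate
        x_t = x_t + dt * dx_t                       # Euler step
    # The loop integrates from t = 0 to t = 1, so x_t now plays the role of x_1
    return int(rng.choice(len(V), p=softmax(logits_fn(x_t, 1.0, context))))


# Toy usage: a "model" that always puts all probability mass on token 3
V = np.random.default_rng(0).normal(size=(5, 4))
always_3 = lambda x_t, t, context: np.where(np.arange(5) == 3, 0.0, -np.inf)
print(l2d_predict_euler(always_3, V, context=None, T=4, rng=np.random.default_rng(1)))  # 3
```

On the final step $1 - t = \Delta_t$, so the Euler update lands exactly on $\hat{x}$: $x_1$ is the embedding of the last sampled token, and the final call to $f_\theta(x_1, 1, c)$ reads a prediction out of it.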

  4. LMs as Single-step Diffusion Models: The paper argues that traditional LMs are effectively single-step diffusion models evaluated at $t=0$, where the diffusion input is pure noise and therefore uncorrelated with the target. L2D finetuning extends this by enabling multi-step reasoning that builds on the pre-trained LM's "System 1" understanding.

L2D Implementation Details

L2D is designed as a modular extension to pre-trained transformers, introducing a parallel "diffusion path" while keeping the original LM weights frozen.

  • Architecture (Figure 2):
    • The original LM path ($f_{\theta_l}$) remains frozen.
    • A new, parallel "diffusion path" ($f_{\theta_d}$) is introduced to process the diffusion token $x_t$. This path has a transformer architecture similar to the main LM.
    • Initialization: Layers in $f_{\theta_d}$ are initialized with weights from the corresponding layers in $f_{\theta_l}$.
    • Optimization: Only $f_{\theta_d}$ is trained, typically using Low-Rank Adaptation (LoRA), making finetuning efficient and minimizing memory overhead.
    • Diffusion Path Components: Transformer blocks in $f_{\theta_d}$ consist of MLPs and cross-attention modules. The diffusion token $x_t^k$ (for target token $y^k$) attends to the keys and values of the context computed by the self-attention modules in $f_{\theta_l}$.
    • Information Merging: Information from $f_{\theta_d}$ is merged back into $f_{\theta_l}$ only at the final layer, before the LM's linear head, via an element-wise weighted sum: $f_{\theta_l} + w_d f_{\theta_d}$. The diffusion token $x_t^k$ for target $y^k$ affects the latents of the previous token $x^{k-1}$ in the main path (the position whose output predicts $y^k$).
  • Advantages of this Design:
    • Efficient Inference: Latents from $f_{\theta_l}$ and the KV cache are computed once per generated token, regardless of the number of diffusion steps. Only the smaller $f_{\theta_d}$ is computed iteratively.
    • Parallelizable Training: Diffusion losses for all sequence positions can be computed in parallel, since each position's diffusion token depends only on its own noisy target and its input context.
  • L2D Conditioning:
    • Diffusion Space Vocabulary: Token embeddings for the diffusion path ($V$) are derived from the base LM's vocabulary ($V^l$) by learning a linear mapping $W_v$ to a lower dimension $\bar{d}$ (e.g., 256), followed by norm scaling: $V_y = \sqrt{\bar{d}}\,\frac{W_v V^l_y}{\lVert W_v V^l_y \rVert_2}$. A small "translation module" maps these $\bar{d}$-dimensional embeddings back to the LM's hidden dimension $d$.
    • Timestep Conditioning: Timestep $t$ is incorporated via:
        1. Sinusoidal features processed to output shift/scale parameters for the LayerNorms in $f_{\theta_d}$.
        2. Time-conditioned element-wise rescaling before the residual sums in the $f_{\theta_d}$ blocks.
        3. Conditioning the final merging weight: $w_d(t) = w_{\theta_d}(t) - w_{\theta_d}(0)$. This ensures $w_d(0) = 0$, so the diffusion path has no effect at $t=0$, preserving the original LM's $t=0$ (single-step) behavior.
    • Classifier-Free Guidance (CFG):
        • During training, class embeddings ($g_j$) are added to the timestep embeddings. A null class $g_0$ is used with a class-dropout probability.
        • During inference, guided predictions are made as $\hat{x}_g = w_g \cdot f_{\theta}(x_t, t, g_j, c) + (1-w_g) \cdot f_{\theta}(x_t, t, g_0, c)$, where $w_g$ is the guidance strength. Equivalently, $\hat{x}_g = f_\theta(x_t, t, g_0, c) + w_g\,[f_\theta(x_t, t, g_j, c) - f_\theta(x_t, t, g_0, c)]$, so $w_g = 1$ recovers the plain class-conditional prediction.
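
Three of the conditioning pieces above reduce to short formulas: the normalized diffusion-space vocabulary, the time-gated merge weight, and the guided combination. The NumPy sketch below writes them out under stated assumptions: `w_theta` is a hypothetical stand-in for the learned, time-dependent merge weight (any function of $t$ returning a per-channel vector), hidden states are plain arrays, and `cfg_combine` applies guidance to whatever prediction vector it is given (for example logits). It illustrates the formulas only and is not the authors' implementation.

```python
import numpy as np


def diffusion_vocab(V_l, W_v):
    """V_y = sqrt(d_bar) * W_v V^l_y / ||W_v V^l_y||_2 for every token y.

    V_l: base-LM embeddings, shape (vocab, d); W_v: learned map, shape (d_bar, d).
    """
    d_bar = W_v.shape[0]
    proj = V_l @ W_v.T                                  # row y is W_v V^l_y
    return np.sqrt(d_bar) * proj / np.linalg.norm(proj, axis=1, keepdims=True)


def merge_weight(w_theta, t):
    """w_d(t) = w_theta(t) - w_theta(0), which is exactly zero at t = 0."""
    return w_theta(t) - w_theta(0.0)


def merge_paths(h_lm, h_diff, t, w_theta):
    """Element-wise weighted sum f_l + w_d(t) * f_d, taken before the LM head."""
    return h_lm + merge_weight(w_theta, t) * h_diff


def cfg_combine(pred_cond, pred_uncond, w_g):
    """Guided prediction w_g * cond + (1 - w_g) * uncond; w_g = 1 means no guidance."""
    return w_g * pred_cond + (1.0 - w_g) * pred_uncond


# Toy usage with random arrays (d = 8, d_bar = 4, vocab = 6)
rng = np.random.default_rng(0)
V = diffusion_vocab(rng.normal(size=(6, 8)), rng.normal(size=(4, 8)))
print(np.linalg.norm(V, axis=1))   # every row has norm sqrt(4) = 2

w_theta = lambda t: 0.3 + np.sin(3.0 * t) * np.ones(8)   # hypothetical learned weight
h_lm, h_diff = rng.normal(size=8), rng.normal(size=8)
print(np.allclose(merge_paths(h_lm, h_diff, 0.0, w_theta), h_lm))   # True: LM unchanged at t = 0
```

Subtracting $w_{\theta_d}(0)$ makes the $t=0$ case exact regardless of what the learned weight looks like, which is why the single-step behavior of the frozen LM is preserved by construction rather than by training.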

Experimental Results

  • Models: Llama 3 (1B, 8B Instruct) and Qwen 2.5 (1.5B, 7B Instruct).
  • Finetuning Data: A subset of the SmolTalk dataset, focused on math and coding.
  • Evaluation Tasks:
    • Mathematics: GSM8K, MATH
    • Coding: HumanEval, MBPP (pass@10)
    • General Knowledge: MMLU, MMLU-Pro
  • Hyperparameters: $\sigma=64$ for noise scaling, diffusion dimension $\bar{d}=256$. Inference typically uses a midpoint ODE solver with 8 discretization levels (15 $f_{\theta_d}$ evaluations).
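
The figure of 15 evaluations for 8 discretization levels is consistent with the following accounting, which is assumed here rather than taken from the paper: 8 levels give 7 intervals on $[0, 1]$, the midpoint method calls the model twice per interval ($2 \times 7 = 14$), and one more call at $t = 1$ produces the final prediction. The sketch below is a hypothetical midpoint variant of the sampling loop that counts its own model calls; `logits_fn` and `V` are illustrative stand-ins, and the velocity is the sampled-token estimate $(\hat{x} - x_t)/(1-t)$ from the inference description.

```python
import numpy as np


def softmax(logits):
    z = np.exp(logits - logits.max())
    return z / z.sum()


def l2d_predict_midpoint(logits_fn, V, context, levels=8, sigma=64.0, rng=None):
    """Midpoint-method sampling over `levels` time points; returns (token, n_evals).

    logits_fn and V are hypothetical stand-ins for f_theta and the diffusion vocabulary.
    """
    rng = np.random.default_rng() if rng is None else rng
    n_evals = 0

    def velocity(x, t):
        # One model call: sample a token and point x towards its embedding
        nonlocal n_evals
        n_evals += 1
        y = rng.choice(len(V), p=softmax(logits_fn(x, t, context)))
        return (V[y] - x) / (1.0 - t)

    ts = np.linspace(0.0, 1.0, levels)                  # levels points -> levels - 1 intervals
    x = rng.normal(0.0, sigma, size=V.shape[1])
    for t0, t1 in zip(ts[:-1], ts[1:]):
        h = t1 - t0
        k1 = velocity(x, t0)                            # slope at the start of the interval
        k2 = velocity(x + 0.5 * h * k1, t0 + 0.5 * h)   # slope at the midpoint
        x = x + h * k2
    n_evals += 1                                        # final prediction at t = 1
    y = int(rng.choice(len(V), p=softmax(logits_fn(x, 1.0, context))))
    return y, n_evals


V = np.random.default_rng(0).normal(size=(5, 4))
always_3 = lambda x_t, t, context: np.where(np.arange(5) == 3, 0.0, -np.inf)
print(l2d_predict_midpoint(always_3, V, context=None))   # (3, 15)
```

The midpoint evaluations never reach $t = 1$ (the latest is at $1 - h/2$), so the division by $1 - t$ stays well defined throughout.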

Key Findings:

  1. Performance Improvement: L2D consistently improves performance, especially on math and coding tasks, with minimal additional parameters (LoRA modules). (Table 1)
  2. Superior to Traditional Finetuning: L2D outperforms LoRA and full-weight finetuning on the same data, suggesting it qualitatively differs by augmenting computation rather than just altering weights.
  3. Inference-Time Scaling:
    • Performance monotonically increases with more diffusion steps (Figure 1) and as the timestep $t$ progresses (Figure 3).
    • Most benefits are achieved with a small number of steps (e.g., 15 evaluations).
  4. Adaptive Diffusion Process: Using an adaptive Runge-Kutta ODE solver allows the model to dynamically adjust compute per token, leading to further improvements. The number of steps varies by task difficulty (Figure 4, Table 2).
  5. Full $f_{\theta_d}$ Optimization & Compatibility:
    • Fully finetuning fθdf_{\theta_d} (instead of LoRA) can yield further gains at higher computational cost.
    • L2D is compatible with prior traditional finetuning (applying L2D on an already LoRA/fully finetuned model boosts performance further). (Table 2)
  6. Classifier-Free Guidance: CFG provides visible performance gains, especially when $w_g$ is tuned per task domain (Figure 5, Table 2). Math tasks benefit from higher $w_g$.

Ablations and Additional Details (Appendices):

  • L2D benefits from higher learning rates than traditional finetuning.
  • The choice of noise scale $\sigma$ (the standard deviation of $x_0$) is important; $\sigma=64$ was found to be effective.
  • Initializing $f_{\theta_d}$ from $f_{\theta_l}$ is crucial for performance.
  • Sampling $\hat{x}$ for the velocity computation works better than using its expectation with deterministic ODE solvers.
  • Simpler ODE solvers (Euler, midpoint) work well for few steps; adaptive solvers and advanced timestep schedules (cosmap) show further potential.
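
The sampling-versus-expectation ablation compares two ways of turning the model's categorical output into $\hat{x}$. The sketch below puts both side by side; the toy vocabulary and probabilities are made up for illustration. One property that is easy to check (our observation, not a claim taken from the paper) is that a sampled $\hat{x}$ is always an actual token embedding with norm $\sqrt{\bar{d}}$, whereas the expectation over two or more distinct embeddings of norm $\sqrt{\bar{d}}$ falls strictly inside that sphere.

```python
import numpy as np


def x_hat_sampled(probs, V, rng):
    """Sample a token y ~ probs and return its embedding V_y."""
    return V[rng.choice(len(V), p=probs)]


def x_hat_expected(probs, V):
    """Return the expectation sum_y p(y) V_y over the vocabulary."""
    return probs @ V


# Toy diffusion vocabulary with d_bar = 2: every row has norm sqrt(2)
V = np.sqrt(2.0) * np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
probs = np.array([0.5, 0.5, 0.0])
rng = np.random.default_rng(0)
print(np.linalg.norm(x_hat_sampled(probs, V, rng)))   # about 1.414: a real token embedding
print(np.linalg.norm(x_hat_expected(probs, V)))       # about 1.0: strictly inside the sphere
```

With the sampled variant, the trajectory stays stochastic even when the ODE solver itself is deterministic.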

Discussion and Future Work

L2D offers a practical way to combine the strengths of autoregressive LMs (strong pre-training, "System 1" knowledge) with diffusion models (test-time compute scaling, iterative refinement, guidance). It enables LMs to "think harder" on difficult problems by allocating more computation. Future work could explore more advanced diffusion techniques, finer-grained guidance, and applications in personalization. The code and data splits are promised to be released.

This paper presents a compelling method for enhancing pre-trained LMs. By adding a lightweight, diffusion-based reasoning path that reuses and augments the existing LM's knowledge, L2D allows for flexible scaling of computational effort at inference time. This leads to improved performance on complex tasks and opens up new avenues for controlling LM behavior through techniques like adaptive solvers and classifier-free guidance, all while preserving the original model's capabilities.
