Block Diffusion: Interpolating Between Autoregressive and Diffusion Language Models (2503.09573v3)

Published 12 Mar 2025 in cs.LG and cs.AI

Abstract: Diffusion language models offer unique benefits over autoregressive models due to their potential for parallelized generation and controllability, yet they lag in likelihood modeling and are limited to fixed-length generation. In this work, we introduce a class of block diffusion language models that interpolate between discrete denoising diffusion and autoregressive models. Block diffusion overcomes key limitations of both approaches by supporting flexible-length generation and improving inference efficiency with KV caching and parallel token sampling. We propose a recipe for building effective block diffusion models that includes an efficient training algorithm, estimators of gradient variance, and data-driven noise schedules to minimize the variance. Block diffusion sets a new state-of-the-art performance among diffusion models on language modeling benchmarks and enables generation of arbitrary-length sequences. We provide the code, along with the model weights and blog post on the project page: https://m-arriola.com/bd3lms

Summary

  • The paper introduces BD3-LMs that integrate autoregressive block-level generation with discrete denoising diffusion, enabling flexible-length text synthesis.
  • It introduces two-pass and vectorized training strategies that use custom attention masks and KV caching, with the vectorized variant yielding a 20–25% training speed-up over the two-pass approach.
  • Experimental results demonstrate improved perplexity on benchmarks like LM1B and OpenWebText, with enhanced sample quality and accelerated inference.

This paper introduces Block Discrete Denoising Diffusion Language Models (BD3-LMs), a novel class of models that bridges the gap between autoregressive (AR) and discrete denoising diffusion models for language generation. The primary motivation is to overcome key limitations of both paradigms: diffusion models often struggle with likelihood modeling, are restricted to fixed-length generation, and lack efficient inference mechanisms such as KV caching, while AR models generate tokens strictly sequentially, which limits generation speed.

BD3-LMs operate by being autoregressive over blocks of tokens while performing discrete denoising diffusion within each block. This hybrid approach allows for flexible-length generation and improves inference efficiency through KV caching and parallel token sampling within blocks.

Core Concepts and Implementation

1. Model Architecture and Likelihood:

  • A sequence of $L$ tokens $\mathbf{x}$ is divided into $B$ blocks, each of length $L'$, so $L = B \cdot L'$.
  • The log-likelihood is factorized autoregressively over these blocks:

    $$\log p_\theta(\mathbf{x}) = \sum_{b = 1}^{B} \log p_\theta(\mathbf{x}^{b} \mid \mathbf{x}^{<b})$$

  • Each conditional probability $p_\theta(\mathbf{x}^{b} \mid \mathbf{x}^{<b})$ is modeled by a discrete diffusion process specific to block $b$, conditioned on previously generated blocks $\mathbf{x}^{<b}$.
  • A single transformer network $f_\theta$ parameterizes the base denoiser for all blocks. It uses a block-causal attention mask, where tokens in block $b$ attend to other tokens within the (potentially noised) block $b$ and to all clean tokens in preceding blocks $\mathbf{x}^{<b}$ (a minimal construction of such a mask is sketched after this list).
  • The model supports KV caching:

    $$\text{logits}^b, \mathbf{K}^b, \mathbf{V}^b \gets f_\theta^b(\mathbf{x}_t^b, \mathbf{K}^{1:b-1}, \mathbf{V}^{1:b-1})$$

    where $\mathbf{x}_t^b$ is the noised version of block $b$ at timestep $t$, and $\mathbf{K}^{1:b-1}, \mathbf{V}^{1:b-1}$ are the cached keys and values from previous blocks.
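
To make the block-causal masking concrete, here is a minimal sketch of one way to build such a mask in PyTorch, assuming a boolean convention where True means "may attend" (the function name is illustrative, not the paper's code):

import torch

def block_causal_mask(seq_len: int, block_size: int) -> torch.Tensor:
    # Position i may attend to position j iff block(j) <= block(i).
    block_ids = torch.arange(seq_len) // block_size              # block index of each token
    return block_ids.unsqueeze(1) >= block_ids.unsqueeze(0)      # (seq_len, seq_len), True = attend

# Example: L = 8, L' = 4 -> tokens in the second block attend to both blocks,
# tokens in the first block attend only within the first block.
mask = block_causal_mask(seq_len=8, block_size=4)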

2. Training Objective:

  • The training objective is derived by applying the Negative ELBO (NELBO) to each block-conditional term:

    $$\mathcal{L}_\text{BD}(\mathbf{x}; \theta) := \sum_{b=1}^{B} \mathcal{L}(\mathbf{x}^b, \mathbf{x}^{<b}; \theta)$$

    where $\mathcal{L}(\mathbf{x}^b, \mathbf{x}^{<b}; \theta)$ is the standard diffusion NELBO for block $b$ conditioned on $\mathbf{x}^{<b}$.

  • For masked BD3-LMs (using a masking noise process), a simplified objective is adopted (a short sketch of the resulting weighting follows this list):

    $$\mathcal{L}_\text{BD}(\mathbf{x}; \theta) := \sum_{b=1}^{B} \mathbb{E}_{t \sim [0, 1]} \, \mathbb{E}_{q} \, \frac{\alpha_t'}{1-\alpha_t} \log p_\theta(\mathbf{x}^b \mid \mathbf{x}_{t}^b, \mathbf{x}^{<b})$$

    where $\alpha_t$ defines the noise schedule (the probability that a token is not masked at time $t$), and $\alpha_t'$ is its derivative.
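
To illustrate the weighting above: under a linear schedule $\alpha_t = 1 - t$, the factor $\alpha_t'/(1-\alpha_t)$ reduces to $-1/t$, so each block term becomes a $1/t$-weighted cross-entropy over the masked positions. A hedged sketch, with assumed argument names and a per-token mean reduction:

import torch.nn.functional as F

def masked_block_nelbo_term(logits, clean_block, mask_positions, t):
    # logits: (L', vocab), clean_block: (L',) token ids, mask_positions: (L',) bool,
    # t: scalar noise level in (0, 1]. Cross-entropy only over masked positions,
    # weighted by 1/t (the linear-schedule form of -alpha_t' / (1 - alpha_t)).
    nll = F.cross_entropy(logits[mask_positions], clean_block[mask_positions], reduction="mean")
    return nll / t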

3. Efficient Training Algorithm (Algorithm 1):

  • Naively computing the loss would require $B$ separate forward passes, one per block, since denoising block $b$ uses a noised $\mathbf{x}_t^b$ while conditioning on the clean previous blocks $\mathbf{x}^{<b}$.
  • Two-Pass Approach:

    1. First Pass (KV Cache Precomputation): Compute keys and values $\mathbf{K}^{1:B}, \mathbf{V}^{1:B}$ for the entire clean sequence $\mathbf{x}$ in one forward pass: $(\emptyset, \mathbf{K}^{1:B}, \mathbf{V}^{1:B}) \gets f_\theta(\mathbf{x})$.
    2. Second Pass (Denoising): For each block $b$, sample a noise level $t_b$ and create the noised block $\mathbf{x}_{t_b}^b$. Compute denoised predictions for all blocks simultaneously using the precomputed KV cache: $\text{logits}^b, \emptyset, \emptyset \gets f_\theta^b(\mathbf{x}_{t_b}^b, \mathbf{K}^{1:b-1}, \mathbf{V}^{1:b-1})$.
  • Vectorized Single-Pass Training: An even more efficient method concatenates the noisy data $\mathbf{x}_\text{noisy} = \mathbf{x}^1_{t_1} \oplus \dots \oplus \mathbf{x}^B_{t_B}$ and the clean data $\mathbf{x}$ into a single input sequence of length $2L$. A custom attention mask (detailed in the paper's appendix) lets noisy tokens attend to other noisy tokens in their block and to clean tokens in preceding blocks. This leverages efficient attention kernels such as FlashAttention or FlexAttention and yields a 20-25% training speed-up over the two-pass approach, as sketched in the code below.

def vectorized_bd3_lm_train_step(model, optimizer, clean_sequence_x, num_blocks, block_size):
    # Helper functions (sample_noise_level, noise_process, concatenate,
    # create_specialized_block_diffusion_mask, compute_block_diffusion_nelbo)
    # are illustrative placeholders for the actual implementation.

    # 1. Sample a noise level t_b for each block and build the noisy sequence x_t
    noisy_blocks_x_t = []
    noise_levels_t = []
    for b in range(num_blocks):
        block_x_b = clean_sequence_x[b*block_size : (b+1)*block_size]
        t_b = sample_noise_level()  # e.g., from U[0,1] or a clipped schedule
        noise_levels_t.append(t_b)
        noisy_blocks_x_t.append(noise_process(block_x_b, t_b))  # e.g., mask tokens w.p. 1 - alpha_{t_b}
    noisy_sequence_x_t = concatenate(noisy_blocks_x_t)

    # 2. Concatenate noisy and clean sequences into a single input of length 2L
    combined_input = concatenate([noisy_sequence_x_t, clean_sequence_x])

    # 3. Build the specialized attention mask (M_full):
    #    M_BD:  noisy block b attends to noisy block b
    #    M_OBC: noisy block b attends to clean blocks < b
    #    M_BC:  clean block b attends to clean blocks <= b
    attention_mask = create_specialized_block_diffusion_mask(num_blocks, block_size)

    # 4. Single forward pass; the transformer layers use the custom attention mask.
    #    The first L positions (the noisy half) predict the original clean tokens.
    output_logits = model(combined_input, attention_mask=attention_mask)
    predicted_denoised_logits = output_logits[:len(clean_sequence_x)]

    # 5. Compute the block-diffusion NELBO (e.g., the weighted cross-entropy of the
    #    masked-diffusion objective), one term per block; conditioning on clean
    #    previous blocks is handled implicitly by the attention mask.
    loss = compute_block_diffusion_nelbo(predicted_denoised_logits, clean_sequence_x,
                                         noise_levels_t, num_blocks)

    # 6. Backpropagate and update model parameters
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss
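
The helper create_specialized_block_diffusion_mask above is left abstract. One way to realize the $2L \times 2L$ boolean mask it describes, with the noisy half first and True meaning "may attend", is sketched below; this follows the M_BD / M_OBC / M_BC structure in the comments rather than the paper's exact kernel:

import torch

def create_specialized_block_diffusion_mask(num_blocks: int, block_size: int) -> torch.Tensor:
    # Positions 0..L-1 hold the noisy sequence, positions L..2L-1 the clean sequence.
    L = num_blocks * block_size
    blk = torch.arange(L) // block_size                  # block id of each position
    same_block = blk.unsqueeze(1) == blk.unsqueeze(0)    # within the same block
    prev_block = blk.unsqueeze(1) > blk.unsqueeze(0)     # strictly earlier blocks
    causal_blk = blk.unsqueeze(1) >= blk.unsqueeze(0)    # earlier or same block

    mask = torch.zeros(2 * L, 2 * L, dtype=torch.bool)
    mask[:L, :L] = same_block    # M_BD:  noisy block b -> noisy block b
    mask[:L, L:] = prev_block    # M_OBC: noisy block b -> clean blocks < b
    mask[L:, L:] = causal_blk    # M_BC:  clean block b -> clean blocks <= b
    # Clean tokens never attend to noisy tokens (mask[L:, :L] stays False).
    return mask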

4. Efficient Sampling Algorithm (Algorithm 2):

  • Blocks are generated sequentially.
  • For each block $b$:
    1. Sample the clean block $\hat{\mathbf{x}}^b$ using a diffusion sampling procedure (e.g., a D3PM-style sampler) conditioned on previously generated clean blocks $\hat{\mathbf{x}}^{<b}$ via their cached keys and values $\mathbf{K}^{1:b-1}, \mathbf{V}^{1:b-1}$. This step involves multiple denoising steps within the block (a minimal within-block sampler sketch is given after the code below).
    2. Compute and cache keys and values for the newly sampled block $\hat{\mathbf{x}}^b$: $\emptyset, \mathbf{K}^{b}, \mathbf{V}^{b} \gets f_\theta^b(\hat{\mathbf{x}}^b)$.
    3. Append $\hat{\mathbf{x}}^b$ to the generated sequence and update the overall KV cache.
  • This allows for arbitrary-length sequence generation and benefits from parallel generation within each block.

def bd3_lm_sample(model, num_blocks_to_generate, block_size, diffusion_sampler):
    generated_sequence = []
    kv_cache_K = []
    kv_cache_V = []

    for b in range(num_blocks_to_generate):
        # Conditioning on previously generated blocks is handled by passing the KV
        # cache: the denoiser attends to these cached keys/values through its
        # block-causal attention. diffusion_sampler runs the iterative denoising
        # process for the current block.
        current_block_sampled_x_hat = diffusion_sampler(
            model_block_predictor=lambda noisy_block_xt, t: model.predict_denoised_block(
                noisy_block_xt, t, kv_cache_K, kv_cache_V),
            block_size=block_size,
        )

        # Cache keys/values for the newly generated clean block; previous K,V are
        # passed so the block's representations see their left context.
        _, new_K_b, new_V_b = model.compute_kv(current_block_sampled_x_hat, kv_cache_K, kv_cache_V)

        generated_sequence.append(current_block_sampled_x_hat)
        kv_cache_K.append(new_K_b)
        kv_cache_V.append(new_V_b)

    return concatenate(generated_sequence)
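
The diffusion_sampler argument above is likewise left abstract. Assuming the masking noise process with a linear schedule $\alpha_t = 1 - t$, an MDLM-style ancestral sampler for a single block could look as follows; the mask_id token and the step count are placeholder assumptions, and the paper's exact sampler may differ:

import torch

def masked_block_sampler(model_block_predictor, block_size, num_steps=32, mask_id=0):
    # Start from an all-[MASK] block; at each step from time t to s, every
    # still-masked position is unmasked with probability (t - s) / t
    # (the linear-schedule form of (alpha_s - alpha_t) / (1 - alpha_t)),
    # taking the token predicted by the denoiser.
    block = torch.full((block_size,), mask_id, dtype=torch.long)
    for step in range(num_steps, 0, -1):
        t, s = step / num_steps, (step - 1) / num_steps
        logits = model_block_predictor(block, t)                   # (block_size, vocab_size)
        proposal = torch.distributions.Categorical(logits=logits).sample()
        still_masked = block == mask_id
        unmask = still_masked & (torch.rand(block_size) < (t - s) / t)
        block = torch.where(unmask, proposal, block)
    return block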

5. Addressing Gradient Variance and Improving Performance:

  • A key finding is that the perplexity gap between diffusion models and AR models can be attributed to high variance in the gradients of the diffusion objective during training.
  • Case Study ($L'=1$): When the block size is 1, a BD3-LM is theoretically equivalent to an AR model. However, standard masked diffusion (masking roughly 50% of tokens on average) yields higher perplexity than AR because the diffusion objective effectively trains on fewer tokens per step. Using a "full masking" schedule ($q(\mathbf{x}_t^\ell = \text{[MASK]} \mid \mathbf{x}^\ell) = 1$), the $L'=1$ BD3-LM matches AR performance, and gradient variance is significantly reduced.
  • Clipped Noise Schedules: To minimize gradient variance for $L' > 1$, the paper proposes "clipped" noise schedules in which mask rates $1-\alpha_t$ are sampled uniformly from a sub-interval $[\beta, \omega]$ rather than $[0, 1]$. This avoids extreme masking rates (very few or very many masked tokens), which provide poor learning signals and lead to high-variance gradients.
  • Data-Driven Schedule Optimization: The optimal $\beta$ and $\omega$ are block-size dependent. They are learned adaptively during training by performing a grid search at regular intervals to find the values that minimize the variance of the NELBO estimator, used as a proxy for gradient variance (a code sketch follows the equation below):

    $$\min_{\beta, \omega} \text{Var}_{\mathbf{X}, t} \left[ \mathcal{L}(\mathbf{X}; \theta, \beta, \omega) \right]$$
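
A hedged sketch of the clipped schedule and the periodic grid search it implies; the function names, candidate grid, and use of per-sample NELBO variance are illustrative assumptions:

import torch

def sample_clipped_mask_rate(beta: float, omega: float, shape=()) -> torch.Tensor:
    # Mask rate 1 - alpha_t drawn uniformly from [beta, omega] instead of [0, 1].
    return beta + (omega - beta) * torch.rand(shape)

def grid_search_clip_bounds(nelbo_estimates_fn, candidate_ranges, num_samples=128):
    # Pick the (beta, omega) pair whose NELBO estimates have the lowest empirical
    # variance; nelbo_estimates_fn(mask_rates) is assumed to return one NELBO
    # estimate per sampled mask rate.
    best_pair, best_var = None, float("inf")
    for beta, omega in candidate_ranges:
        rates = sample_clipped_mask_rate(beta, omega, shape=(num_samples,))
        var = nelbo_estimates_fn(rates).var().item()
        if var < best_var:
            best_pair, best_var = (beta, omega), var
    return best_pair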

Experimental Results and Practical Implications

  • State-of-the-Art Perplexity: BD3-LMs achieve new state-of-the-art perplexities among discrete diffusion models on the LM1B and OpenWebText benchmarks, significantly closing the gap to AR models. For example, on LM1B, BD3-LM with $L'=4$ achieves $\le 28.23$ PPL, compared to MDLM's $\le 31.78$ PPL.
  • Variable-Length Generation: BD3-LMs can generate sequences much longer than their training context (e.g., up to ~10x longer than fixed-length diffusion models like SEDD on OWT).
  • Improved Sample Quality: BD3-LMs show better generative perplexity (Gen. PPL, evaluated by GPT2-Large) compared to prior diffusion methods like SEDD, MDLM, and SSD-LM, often with an order of magnitude fewer generation steps (NFEs) than methods like SSD-LM.
    • For sequences of length $L=2048$, BD3-LM ($L'=4$) achieves a Gen. PPL of 23.6 with 2K NFEs, while SSD-LM ($L'=25$, comparable NFEs) gets 281.9 and MDLM gets 41.3.
  • Efficiency of Clipped Schedules: Data-driven clipped noise schedules are shown to reduce training variance and improve test perplexity compared to standard linear or other common schedules. The optimal clipping range varies with block size (e.g., heavier masking for smaller $L'$).
  • Computational Cost: Training BD3-LMs is inherently more expensive than standard diffusion due to the block-autoregressive nature and potentially multiple passes or larger effective sequence lengths. The proposed vectorized training algorithm keeps this overhead manageable (within <2x of standard diffusion). Pre-training with a standard diffusion loss before fine-tuning with the block diffusion objective can further reduce costs.

Implementation Considerations

  • Computational Requirements: Training requires careful management of memory and computation, especially with the vectorized approach (concatenating sequences). Efficient attention kernels (FlashAttention, FlexAttention) are crucial.
  • Choosing Block Size ($L'$): The optimal block size is task-dependent. Smaller $L'$ approaches AR behavior (more sequential steps, potentially better perplexity), while larger $L'$ increases parallelism but can make learning harder and loosen the NELBO bound more. Experiments show $L'=4$ often gives the best perplexity.
  • Noise Schedule Tuning: Implementing the data-driven clipped schedule optimization requires periodic evaluation of the NELBO variance for different $[\beta, \omega]$ ranges. This adds some overhead but is shown to be beneficial.
  • KV Cache Implementation: Standard transformer KV caching mechanisms can be adapted. The key is to correctly pass and update the cache across block generation steps during sampling, and to use it appropriately during the second pass or vectorized pass of training.
  • Deployment: For inference, block-sequential generation means latency will be higher than for fully parallel diffusion models but potentially lower than for token-by-token AR models if $L'$ is large enough and intra-block parallelism is exploited.

In summary, BD3-LMs offer a practical framework for building high-quality, flexible-length LLMs that combine strengths from AR and diffusion paradigms. The paper provides concrete algorithms for training and sampling, addresses the critical issue of gradient variance through novel noise schedules, and demonstrates strong empirical results. The code and model weights are made available, facilitating adoption and further research.
