Block Diffusion: Interpolating Between Autoregressive and Diffusion Language Models (2503.09573v3)
Abstract: Diffusion language models offer unique benefits over autoregressive models due to their potential for parallelized generation and controllability, yet they lag in likelihood modeling and are limited to fixed-length generation. In this work, we introduce a class of block diffusion language models that interpolate between discrete denoising diffusion and autoregressive models. Block diffusion overcomes key limitations of both approaches by supporting flexible-length generation and improving inference efficiency with KV caching and parallel token sampling. We propose a recipe for building effective block diffusion models that includes an efficient training algorithm, estimators of gradient variance, and data-driven noise schedules to minimize the variance. Block diffusion sets a new state-of-the-art performance among diffusion models on language modeling benchmarks and enables generation of arbitrary-length sequences. We provide the code, along with the model weights and blog post on the project page: https://m-arriola.com/bd3lms
Summary
- The paper introduces BD3-LMs that integrate autoregressive block-level generation with discrete denoising diffusion, enabling flexible-length text synthesis.
- It trains efficiently with either a two-pass strategy based on KV caching or a vectorized single-pass strategy based on a custom attention mask; the vectorized variant is reported to be 20-25% faster than the two-pass one.
- Experimental results demonstrate improved perplexity on benchmarks like LM1B and OpenWebText, with enhanced sample quality and accelerated inference.
This paper introduces Block Discrete Denoising Diffusion Language Models (BD3-LMs), a novel class of models that bridges the gap between autoregressive (AR) and discrete denoising diffusion models for language generation. The primary motivation is to overcome key limitations of both paradigms: diffusion models often struggle with likelihood modeling, are restricted to fixed-length generation, and lack efficient inference mechanisms like KV caching, while AR models generate tokens sequentially, limiting speed.
BD3-LMs operate by being autoregressive over blocks of tokens while performing discrete denoising diffusion within each block. This hybrid approach allows for flexible-length generation and improves inference efficiency through KV caching and parallel token sampling within blocks.
Core Concepts and Implementation
1. Model Architecture and Likelihood:
- A sequence $x$ of $L$ tokens is divided into $B$ blocks, each of length $L'$, so $L = B \cdot L'$.
- The log-likelihood is factorized autoregressively over these blocks:
$$\log p_\theta(x) = \sum_{b=1}^{B} \log p_\theta(x^b \mid x^{<b})$$
- Each conditional probability $p_\theta(x^b \mid x^{<b})$ is modeled by a discrete diffusion process specific to block $b$, conditioned on previously generated blocks $x^{<b}$.
- A single transformer neural network $f_\theta$ parameterizes the base denoiser for all blocks. It uses a block-causal attention mask, where tokens in block $b$ attend to other tokens within the (potentially noised) block $b$ and all clean tokens in preceding blocks $x^{<b}$.
- The model supports KV caching:
$$\text{logits}^b, K^b, V^b \leftarrow f_\theta^b(x_t^b, K^{1:b-1}, V^{1:b-1})$$
where $x_t^b$ is the noised version of block $b$ at timestep $t$, and $K^{1:b-1}, V^{1:b-1}$ are cached keys and values from previous blocks.
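The block-causal attention pattern described above can be sketched in a few lines of plain Python. This is an illustrative construction rather than the paper's implementation, and the function name `block_causal_mask` is an assumption made for this example. It covers the single-sequence view used at sampling time, where every preceding block is already clean.

```python
def block_causal_mask(num_blocks: int, block_size: int) -> list[list[bool]]:
    """mask[i][j] is True when query token i may attend to key token j.

    A token in block b sees every token in its own block and in all
    preceding blocks, i.e. block(i) >= block(j).
    """
    seq_len = num_blocks * block_size
    block_of = [pos // block_size for pos in range(seq_len)]
    return [[block_of[i] >= block_of[j] for j in range(seq_len)]
            for i in range(seq_len)]


if __name__ == "__main__":
    # B = 3 blocks of L' = 2 tokens; '#' marks an allowed attention edge.
    for row in block_causal_mask(num_blocks=3, block_size=2):
        print("".join("#" if allowed else "." for allowed in row))
```

With a block size of 1 the pattern reduces to an ordinary causal mask, and with a single block it becomes full bidirectional attention, which mirrors the interpolation between AR and diffusion models.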
2. Training Objective:
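- The training objective is derived by applying the negative evidence lower bound (NELBO) to each block-conditional term:
$$\mathcal{L}_{\text{BD}}(x;\theta) := \sum_{b=1}^{B} \mathcal{L}(x^b, x^{<b}; \theta)$$
where $\mathcal{L}(x^b, x^{<b}; \theta)$ is the standard diffusion NELBO for block $b$ conditioned on $x^{<b}$.
- For masked BD3-LMs (using a masking noise process), a simplified objective is adopted:
$$\mathcal{L}_{\text{BD}}(x;\theta) := \sum_{b=1}^{B} \mathbb{E}_{t \sim [0,1]} \mathbb{E}_q \frac{\alpha_t'}{1-\alpha_t} \log p_\theta(x^b \mid x_t^b, x^{<b})$$
where $\alpha_t$ defines the noise schedule (the probability of a token not being masked at time $t$), and $\alpha_t'$ is its derivative with respect to $t$. Because $\alpha_t$ decreases in $t$, $\alpha_t'$ is negative, so the weight turns the non-positive log-likelihood into a non-negative loss.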
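To make the weighting term concrete, consider a linear schedule $\alpha_t = 1 - t$ (an assumption for this example; other schedules give other weights). Then $\alpha_t' = -1$ and the weight becomes $-1/t$, so one block's term is the negative log-likelihood of its masked tokens scaled by $1/t$. The sketch below computes a single-sample estimate of that term. The function name and its inputs are hypothetical, and it follows the usual masked-diffusion convention that only masked positions contribute.

```python
import math


def masked_block_loss(token_log_probs, is_masked, t):
    """Single-sample estimate of one block's term in the masked objective.

    Assumes a linear schedule alpha_t = 1 - t, so alpha_t' / (1 - alpha_t) = -1 / t.
    token_log_probs[i] is log p_theta(x_i | x_t^b, x^{<b}) for the true token x_i;
    only positions that were masked in x_t^b contribute.
    """
    if not 0.0 < t <= 1.0:
        raise ValueError("t must lie in (0, 1]")
    weight = -1.0 / t  # alpha_t' / (1 - alpha_t) for the linear schedule
    return sum(weight * lp for lp, masked in zip(token_log_probs, is_masked) if masked)


if __name__ == "__main__":
    # Block of 4 tokens, 2 of them masked at t = 0.5, each predicted with probability 0.25.
    log_probs = [math.log(0.25)] * 4
    loss = masked_block_loss(log_probs, [True, False, True, False], t=0.5)
    print(round(loss, 4))  # 4 * ln(4) = 5.5452
```

The $1/t$ factor grows quickly for small $t$, which gives some intuition for why the choice of sampled mask rates affects the variance of the estimator.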
3. Efficient Training Algorithm (Algorithm 1):
- Naively computing the loss would require $B$ separate forward passes, because denoising block $b$ takes a noised $x_t^b$ as input while conditioning on the clean preceding blocks $x^{<b}$.
- Two-Pass Approach:
  1. First Pass (KV Cache Precomputation): Compute keys and values $K^{1:B}, V^{1:B}$ for the entire clean sequence $x$ in one forward pass: $(\emptyset, K^{1:B}, V^{1:B}) \leftarrow f_\theta(x)$.
  2. Second Pass (Denoising): For each block $b$, sample a noise level $t_b$ and create the noised block $x_{t_b}^b$. Compute denoised predictions for all blocks simultaneously using the precomputed KV cache: $\text{logits}^b, \emptyset, \emptyset \leftarrow f_\theta^b(x_{t_b}^b, K^{1:b-1}, V^{1:b-1})$.
- Vectorized Single-Pass Training: A more efficient method concatenates the noisy data $x_{\text{noisy}} = x_{t_1}^1 \oplus \cdots \oplus x_{t_B}^B$ and the clean data $x$ into a single input sequence of length $2L$. A custom attention mask (detailed in the paper's appendix on attention masks) lets noisy tokens attend to other noisy tokens in their block and to clean tokens in preceding blocks, while clean tokens attend only to clean tokens in their own and preceding blocks. This leverages efficient attention kernels like FlashAttention or FlexAttention (see the paper's appendix on FlexAttention kernels), yielding a 20-25% training speed-up over the two-pass approach.
```python
import torch
import torch.nn.functional as F


def block_diffusion_mask(num_blocks, block_size):
    """Boolean (2L, 2L) mask, True = may attend. Rows/cols [0, L) are noisy tokens, [L, 2L) clean."""
    block_id = torch.arange(num_blocks * block_size) // block_size
    q_blk, k_blk = block_id[:, None], block_id[None, :]
    m_bd = q_blk == k_blk   # M_BD: noisy block b attends to noisy block b
    m_obc = q_blk > k_blk   # M_OBC: noisy block b attends to clean blocks < b
    m_bc = q_blk >= k_blk   # M_BC: clean block b attends to clean blocks <= b
    no_attn = torch.zeros_like(m_bd)  # clean tokens never attend to noisy tokens
    # (See Appendix Fig. 7 of the paper for a visualization)
    return torch.cat([torch.cat([m_bd, m_obc], dim=1),
                      torch.cat([no_attn, m_bc], dim=1)], dim=0)


def vectorized_bd3_lm_train_step(model, optimizer, clean_x, num_blocks, block_size, mask_token_id):
    """One training step on a single 1-D token sequence `clean_x` of length L (CPU tensors).

    Sketch only. Assumes `model(tokens, attention_mask=...)` returns (2L, vocab) logits
    and a linear schedule alpha_t = 1 - t; neither is prescribed by the text above.
    """
    seq_len = num_blocks * block_size

    # 1. Sample one noise level t_b per block (U(0, 1] here; a clipped schedule also works)
    #    and mask each token of block b independently with probability 1 - alpha_t = t_b.
    t = torch.rand(num_blocks).clamp_min(1e-3)
    t_per_token = t.repeat_interleave(block_size)
    is_masked = torch.rand(seq_len) < t_per_token
    noisy_x = torch.where(is_masked, torch.full_like(clean_x, mask_token_id), clean_x)

    # 2. Concatenate noisy and clean sequences (length 2L)
    combined_input = torch.cat([noisy_x, clean_x])

    # 3. Build the specialized attention mask
    attention_mask = block_diffusion_mask(num_blocks, block_size)

    # 4. Single forward pass. The first L positions correspond to the noisy tokens
    #    and predict the original clean tokens.
    logits = model(combined_input, attention_mask=attention_mask)[:seq_len]

    # 5. Masked-diffusion loss: cross-entropy on masked tokens, weighted by
    #    -alpha_t' / (1 - alpha_t) = 1 / t_b. Conditioning on clean blocks < b is
    #    handled implicitly by the mask.
    token_nll = F.cross_entropy(logits, clean_x, reduction="none")
    loss = (token_nll * is_masked / t_per_token).sum() / seq_len

    # 6. Backpropagate and update model parameters
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()
```
4. Efficient Sampling Algorithm (Algorithm 2):
- Blocks are generated sequentially.
- For each block $b$:
  - Sample the clean block $\hat{x}^b$ using a diffusion sampling procedure (e.g., D3PM sampler) conditioned on previously generated clean blocks $\hat{x}^{<b}$ (via their cached keys and values $K^{1:b-1}, V^{1:b-1}$). This step involves multiple denoising steps within the block.
  - Compute and cache keys and values for the newly sampled block $\hat{x}^b$: $\emptyset, K^b, V^b \leftarrow f_\theta^b(\hat{x}^b)$.
  - Append $\hat{x}^b$ to the generated sequence and update the overall KV cache.
- This allows for arbitrary-length sequence generation and benefits from parallel generation within each block.
```python
import torch


def bd3_lm_sample(model, num_blocks_to_generate, block_size, diffusion_sampler):
    """Generate blocks one at a time, reusing cached keys/values of earlier blocks.

    `model.predict_denoised_block`, `model.compute_kv`, and `diffusion_sampler` are
    placeholder interfaces for this sketch, not a specific library API.
    """
    generated_blocks = []
    kv_cache_k, kv_cache_v = [], []

    for _ in range(num_blocks_to_generate):
        # Conditioning on earlier blocks is handled by passing the KV cache:
        # the current block's queries attend to the cached keys/values.
        def predict(noisy_block_xt, t):
            return model.predict_denoised_block(noisy_block_xt, t, kv_cache_k, kv_cache_v)

        # Run the iterative denoising process for the current block
        x_hat_b = diffusion_sampler(model_block_predictor=predict, block_size=block_size)

        # Cache keys/values for the newly generated clean block,
        # passing the previous K, V for context
        _, new_k_b, new_v_b = model.compute_kv(x_hat_b, kv_cache_k, kv_cache_v)

        generated_blocks.append(x_hat_b)
        kv_cache_k.append(new_k_b)
        kv_cache_v.append(new_v_b)

    return torch.cat(generated_blocks)
```
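The `diffusion_sampler` argument above is left abstract. As a rough illustration of what an ancestral sampler for masked diffusion inside one block might look like, here is a minimal sketch in plain Python. It assumes a linear schedule $\alpha_t = 1 - t$, under which a token that is still masked at time $t$ is unmasked on the way to $s < t$ with probability $(\alpha_s - \alpha_t)/(1 - \alpha_t) = (t - s)/t$. The `denoise` callable, the `MASK` id, and the function name are hypothetical stand-ins; a real denoiser would also receive the KV cache of the preceding blocks and would return a distribution to sample from rather than fixed tokens.

```python
import random

MASK = -1  # hypothetical id for the [MASK] token


def sample_block(denoise, block_size, num_steps, rng):
    """Ancestral sampling for one block of a masked diffusion model.

    `denoise(block)` is a stand-in for the conditioned denoiser: it returns one
    proposed clean token per position. Assumes alpha_t = 1 - t.
    """
    block = [MASK] * block_size  # at t = 1 every token is masked
    for step in range(num_steps, 0, -1):
        t, s = step / num_steps, (step - 1) / num_steps
        unmask_prob = (t - s) / t  # equals 1 on the final step (s = 0)
        proposal = denoise(block)
        block = [
            proposal[i] if tok == MASK and rng.random() < unmask_prob else tok
            for i, tok in enumerate(block)
        ]
    return block


if __name__ == "__main__":
    rng = random.Random(0)
    # Dummy denoiser that proposes token id 7 everywhere, for demonstration only.
    print(sample_block(lambda blk: [7] * len(blk), block_size=4, num_steps=8, rng=rng))
```

Tokens that have been unmasked are never changed again, and because the unmasking probability reaches 1 on the last step, the returned block contains no `MASK` entries.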
5. Addressing Gradient Variance and Improving Performance:
- A key finding is that the perplexity gap between diffusion models and AR models can be attributed to high variance in the gradients of the diffusion objective during training.
- Case Study (L′=1): When block size is 1, BD3-LM is theoretically equivalent to AR. However, standard masked diffusion (masking ~50% of tokens) results in higher perplexity than AR. This is because the diffusion objective effectively trains on fewer tokens per step. By using a "full masking" schedule ($q(x_t^\ell = [\text{MASK}] \mid x^\ell) = 1$), the BD3-LM (L′=1) matches AR performance, and gradient variance is significantly reduced.
- Clipped Noise Schedules: To minimize gradient variance for L′>1, the paper proposes "clipped" noise schedules where mask rates (1−αt) are sampled uniformly from a sub-interval [β,ω] instead of [0,1]. This avoids extreme masking rates (very few or very many masks) which provide poor learning signals and lead to high-variance gradients.
- Data-Driven Schedule Optimization: The optimal β and ω are found to be block-size dependent. They are learned adaptively during training by performing a grid search at regular intervals to find values that minimize the variance of the NELBO estimator (used as a proxy for gradient variance):
$$\min_{\beta,\omega} \operatorname{Var}_{X,t}\left[\mathcal{L}(X;\theta,\beta,\omega)\right]$$
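The clipped schedule and the variance-based grid search can be sketched as follows. This is a simplified illustration, not the paper's procedure: `loss_estimate` is a hypothetical callable standing in for one stochastic NELBO evaluation on a batch, and the candidate ranges and toy loss in the demo are arbitrary values chosen for the example.

```python
import random
import statistics


def sample_mask_rate(beta, omega, rng):
    """Clipped schedule: draw the mask rate 1 - alpha_t uniformly from [beta, omega]."""
    return rng.uniform(beta, omega)


def pick_clipping_range(loss_estimate, candidates, num_samples, rng):
    """Return the (beta, omega) pair whose loss estimates have the lowest variance.

    `loss_estimate(mask_rate)` is a hypothetical callable that returns one
    stochastic NELBO estimate for data noised at the given mask rate.
    """
    def variance(beta, omega):
        losses = [loss_estimate(sample_mask_rate(beta, omega, rng))
                  for _ in range(num_samples)]
        return statistics.pvariance(losses)

    return min(candidates, key=lambda pair: variance(*pair))


if __name__ == "__main__":
    rng = random.Random(0)
    # Toy stand-in for the loss: it blows up at low mask rates, mimicking the
    # 1 / (mask rate) weighting of the masked objective under a linear schedule.
    toy_loss = lambda mask_rate: 1.0 / mask_rate
    candidates = [(0.01, 1.0), (0.3, 0.8), (0.45, 0.95)]
    print(pick_clipping_range(toy_loss, candidates, num_samples=500, rng=rng))
```

In a real training loop the search would be rerun at regular intervals, since the text above notes that the schedule is adapted during training.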
Experimental Results and Practical Implications
- State-of-the-Art Perplexity: BD3-LMs achieve new state-of-the-art perplexities among discrete diffusion models on LM1B and OpenWebText benchmarks, significantly closing the gap to AR models. For example, on LM1B, BD3-LM (L′=4) achieves ≤28.23 PPL, compared to MDLM's ≤31.78 PPL.
- Variable-Length Generation: BD3-LMs can generate sequences much longer than their training context (e.g., up to ~10x longer than fixed-length diffusion models like SEDD on OWT).
- Improved Sample Quality: BD3-LMs show better generative perplexity (Gen. PPL, evaluated by GPT2-Large) compared to prior diffusion methods like SEDD, MDLM, and SSD-LM, often with an order of magnitude fewer generation steps (NFEs) than methods like SSD-LM.
- For L=2048, BD3-LM (L′=4) achieves Gen. PPL of 23.6 with 2K NFEs, while SSD-LM (L′=25, comparable NFEs) gets 281.9, and MDLM gets 41.3.
- Efficiency of Clipped Schedules: Data-driven clipped noise schedules are shown to reduce training variance and improve test perplexity compared to standard linear or other common schedules. The optimal clipping range varies with block size (e.g., heavier masking for smaller L′).
- Computational Cost: Training BD3-LMs is inherently more expensive than standard diffusion because of the block-autoregressive structure, which requires either multiple passes or a longer effective sequence. The proposed vectorized training algorithm keeps this overhead manageable (less than 2x the cost of standard diffusion). Pre-training with a standard diffusion loss before fine-tuning with the block diffusion objective can further reduce costs.
Implementation Considerations
- Computational Requirements: Training requires careful management of memory and computation, especially with the vectorized approach (concatenating sequences). Efficient attention kernels (FlashAttention, FlexAttention) are crucial.
- Choosing Block Size (L′): The optimal block size is task-dependent. Smaller L′ approaches AR behavior (more sequential steps, potentially better perplexity). Larger L′ increases parallelism but might make learning harder or loosen the NELBO bound more. Experiments show L′=4 often gives the best perplexity.
- Noise Schedule Tuning: Implementing the data-driven clipped schedule optimization requires periodic evaluation of NELBO variance for different [β,ω] ranges. This adds some overhead but is shown to be beneficial.
- KV Cache Implementation: Standard transformer KV caching mechanisms can be adapted. The key is to correctly pass and update the cache across block generation steps during sampling, and to use it appropriately during the second pass or vectorized pass of training.
- Deployment: For inference, the block-sequential generation means latency will be higher than fully parallel diffusion models but potentially lower than token-by-token AR models if L′ is large enough and intra-block parallelism is exploited.
In summary, BD3-LMs offer a practical framework for building high-quality, flexible-length LLMs that combine strengths from AR and diffusion paradigms. The paper provides concrete algorithms for training and sampling, addresses the critical issue of gradient variance through novel noise schedules, and demonstrates strong empirical results. The code and model weights are made available, facilitating adoption and further research.