Causal Diffusion Transformers for Generative Modeling
(2412.12095v2)
Published 16 Dec 2024 in cs.CV
Abstract: We introduce Causal Diffusion as the autoregressive (AR) counterpart of Diffusion models. It is a next-token(s) forecasting framework that is friendly to both discrete and continuous modalities and compatible with existing next-token prediction models like LLaMA and GPT. While recent works attempt to combine diffusion with AR models, we show that introducing sequential factorization to a diffusion model can substantially improve its performance and enables a smooth transition between AR and diffusion generation modes. Hence, we propose CausalFusion - a decoder-only transformer that dual-factorizes data across sequential tokens and diffusion noise levels, leading to state-of-the-art results on the ImageNet generation benchmark while also enjoying the AR advantage of generating an arbitrary number of tokens for in-context reasoning. We further demonstrate CausalFusion's multimodal capabilities through a joint image generation and captioning model, and showcase CausalFusion's ability for zero-shot in-context image manipulations. We hope that this work could provide the community with a fresh perspective on training multimodal models over discrete and continuous data.
Summary
The paper introduces CausalFusion, a novel generative model that unifies autoregressive and diffusion paradigms through a decoder-only transformer architecture for both discrete and continuous data.
CausalFusion achieves state-of-the-art results on ImageNet generation benchmarks and enables arbitrary token generation for in-context reasoning and cross-modal integration with language.
The model naturally supports zero-shot image editing and learns superior image representations compared to prior diffusion transformers, improving performance on downstream tasks like classification and captioning.
The paper introduces CausalFusion, an autoregressive counterpart to diffusion models, designed for both discrete and continuous data modalities. CausalFusion integrates sequential and noise-level data factorization, unifying the advantages of autoregressive and diffusion paradigms. The approach allows for adjustable autoregressive and diffusion steps, facilitating a smooth transition between traditional autoregressive and diffusion modes. The model is a decoder-only transformer that factorizes data across sequential tokens and diffusion noise levels.
Key aspects of CausalFusion are:
It achieves state-of-the-art results on the ImageNet generation benchmark.
It enables arbitrary token generation for in-context reasoning.
It facilitates a smooth, cohesive integration with language modeling for cross-modal generation and reasoning.
The paper begins by highlighting the distinction between autoregressive and diffusion models. Autoregressive models factorize data along the sequential axis, conditioning the probability of each token on all preceding tokens. This is expressed mathematically as:
$$q(x_{1:L}) = q(x_1)\prod_{l=2}^{L} q(x_l \mid x_{1:l-1})$$
where:
$x$ is a training image sample.
$x_{1:L}$ is its sequence of tokens.
$L$ is the number of tokens.
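As a minimal illustration of this factorization (my own sketch, not code from the paper), a causally masked decoder trained with teacher forcing minimizes exactly the sum of per-token conditional negative log-likelihoods implied by the product above:

```python
import torch
import torch.nn.functional as F

def ar_nll(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Negative log-likelihood under the next-token factorization.

    logits: (B, L, V), where logits[:, l] is produced by a causally masked
            decoder from tokens[:, :l+1] and predicts token l+1.
    tokens: (B, L) discrete token ids.
    """
    # Shift so position l predicts token l+1; the first token is modeled
    # unconditionally (e.g., from a BOS token) and is dropped here for brevity.
    pred = logits[:, :-1].reshape(-1, logits.size(-1))
    target = tokens[:, 1:].reshape(-1)
    return F.cross_entropy(pred, target)

# Toy usage with random tensors standing in for a real model's outputs.
B, L, V = 2, 16, 1024
loss = ar_nll(torch.randn(B, L, V), torch.randint(0, V, (B, L)))
```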
Diffusion models, on the other hand, factorize data along the noise-level axis. The joint distribution is given by:
$$q(x_{0:T}) = q(x_0)\prod_{t=1}^{T} q(x_t \mid x_{t-1})$$
where:
$T$ is the number of diffusion steps.
$x_t$ represents the noisy version of the image at step $t$.
$q(x_t \mid x_{t-1}) = \mathcal{N}\!\left(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t I\right)$
$\beta_t$ is a variance schedule.
$I$ is the identity matrix.
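The forward process can be sketched as below. The linear $\beta_t$ schedule is a common illustrative choice rather than necessarily the paper's; the closed-form jump $q(x_t \mid x_0)$ is the standard identity used to noise a clean sample in one step during training:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)        # illustrative linear schedule (assumption)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)    # \bar{alpha}_t = prod_{i<=t} alpha_i

def q_step(x_prev: torch.Tensor, t: int) -> torch.Tensor:
    """Sample x_t ~ q(x_t | x_{t-1}) = N(sqrt(1 - beta_t) x_{t-1}, beta_t I)."""
    return torch.sqrt(1.0 - betas[t]) * x_prev + torch.sqrt(betas[t]) * torch.randn_like(x_prev)

def q_jump(x0: torch.Tensor, t: int) -> torch.Tensor:
    """Sample x_t ~ q(x_t | x_0) = N(sqrt(abar_t) x_0, (1 - abar_t) I)."""
    return torch.sqrt(alpha_bars[t]) * x0 + torch.sqrt(1.0 - alpha_bars[t]) * torch.randn_like(x0)
```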
CausalFusion extends this formulation to encompass autoregressive factorization:
κs is an index set identifying the subset of image tokens processed at the s-th AR step.
xt,κs represents the dual-factorized image tokens at the s-th AR step and t-th diffusion step.
The training objective involves approximating pθ(xt−1,κs∣xt,κs,x0,κ1:s−1) for all t and s, incorporating noised image tokens at the current AR step xt,κs and clean image tokens from previous AR steps x0,κ1:s−1.
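A hedged sketch of one training step under this objective, as read from the description above: only the tokens of the current AR step are noised, clean tokens from earlier AR steps are provided as context, and the loss is a standard epsilon-prediction MSE. The model interface and the epsilon-prediction target are illustrative assumptions, not the authors' exact code:

```python
import torch
import torch.nn.functional as F

def causalfusion_step(model, x0_tokens, kappa, s, t, alpha_bars):
    """One dual-factorized training step (illustrative).

    x0_tokens: (B, L, D) clean continuous image tokens.
    kappa:     list of index tensors; kappa[i] holds the token indices of AR step i+1.
    s:         current AR step (1-based); t: diffusion step; alpha_bars: (T,) schedule.
    """
    cur = kappa[s - 1]                                               # indices kappa_s
    prev = torch.cat(kappa[: s - 1]) if s > 1 else cur.new_empty(0)  # kappa_{1:s-1}

    noise = torch.randn_like(x0_tokens[:, cur])
    abar = alpha_bars[t]
    x_t_cur = abar.sqrt() * x0_tokens[:, cur] + (1 - abar).sqrt() * noise  # x_{t, kappa_s}

    # The model conditions on clean tokens from AR steps 1..s-1 and on the noised
    # tokens of step s (with t and s injected as extra conditioning).
    pred = model(x_clean_prev=x0_tokens[:, prev], x_noisy_cur=x_t_cur, t=t, s=s)
    return F.mse_loss(pred, noise)  # epsilon-prediction loss (assumption)

# Toy usage with a stand-in "model" that just returns zeros of the right shape.
B, L, D, T = 2, 8, 4, 1000
alpha_bars = torch.cumprod(1 - torch.linspace(1e-4, 0.02, T), dim=0)
kappa = [torch.tensor([0, 1, 2]), torch.tensor([3, 4]), torch.tensor([5, 6, 7])]
dummy = lambda x_clean_prev, x_noisy_cur, t, s: torch.zeros_like(x_noisy_cur)
loss = causalfusion_step(dummy, torch.randn(B, L, D), kappa, s=2, t=500, alpha_bars=alpha_bars)
```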
Experiments were conducted on the ImageNet dataset, training class-conditional image generation models at 256×256 resolution. The DiT-L/2 model was used as the base configuration. The original DiT incorporates conditional information and the diffusion time step through Adaptive Layer Normalization with zero initialization (AdaLN-zero). CausalFusion instead adopts an in-context design, treating class and time step conditions as tokens appended to the image token sequence. Several improvements were implemented to stabilize training (the first two are sketched in code after this list), including:
Injecting the diffusion time step by adding a time step embedding to the image token embeddings.
Applying head-wise QK layer normalization within the self-attention layers.
Incorporating a learning rate warmup stage during training.
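The first two stabilizers can be sketched as follows; module layout, shapes, and the sinusoidal time-step embedding are illustrative assumptions rather than the paper's exact implementation:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def timestep_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Standard sinusoidal embedding of the diffusion step t, shape (B, dim)."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

class QKNormAttention(nn.Module):
    """Self-attention with head-wise LayerNorm applied to queries and keys."""
    def __init__(self, dim: int, heads: int):
        super().__init__()
        self.heads, self.hd = heads, dim // heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)
        self.q_norm = nn.LayerNorm(self.hd)   # normalizes each head's channels
        self.k_norm = nn.LayerNorm(self.hd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, L, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = (z.view(B, L, self.heads, self.hd).transpose(1, 2) for z in (q, k, v))
        q, k = self.q_norm(q), self.k_norm(k)          # head-wise QK LayerNorm
        out = F.scaled_dot_product_attention(q, k, v)
        return self.proj(out.transpose(1, 2).reshape(B, L, -1))

# The time-step embedding is added to the image token embeddings (broadcast over L).
x = torch.randn(2, 16, 64)
x = x + timestep_embedding(torch.tensor([10, 500]), 64)[:, None, :]
y = QKNormAttention(64, 8)(x)
```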
Ablations on AR steps revealed that CausalFusion trained with a fixed number of AR steps cannot be robustly transferred to other inference settings. Models trained with more AR steps exhibited lower loss values, suggesting that the learning tasks become over-simplified as the number of AR steps increases. Training with a random number of AR steps resulted in a highly imbalanced $|\kappa_s|$ distribution, where over 95% of AR steps have $|\kappa_s| \le 16$, causing the model to overly rely on visible context.
To address these issues, the authors adjusted the difficulty of the generative tasks in CausalFusion to balance the impact of the training signal and ensure thorough exploration of the factorization space. This was achieved through two mechanisms (see the sketch after this list):
Random AR steps with decayed sampling: Exponentially decreasing the sampling probability as the number of AR steps $S$ increases.
Loss weighting along the AR axis: Modifying the weighting term $w(\cdot)$ in the diffusion loss to further consider the AR step $s$, focusing more on the hard generative tasks at early AR steps or larger noise levels.
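Both adjustments can be sketched as below; the exact decay rate and the weighting form are placeholders of my own, not the paper's choices:

```python
import torch

def sample_num_ar_steps(max_steps: int, decay: float = 0.5) -> int:
    """Sample the number of AR steps S with P(S = s) proportional to decay**(s-1)."""
    probs = decay ** torch.arange(max_steps, dtype=torch.float32)
    probs = probs / probs.sum()
    return int(torch.multinomial(probs, 1)) + 1

def ar_loss_weight(s: int, S: int, t: int, T: int) -> float:
    """Illustrative w(s, t): heavier weight for early AR steps and larger noise levels."""
    return (1.0 - 0.5 * (s - 1) / max(S, 1)) * (0.5 + 0.5 * t / T)

S = sample_num_ar_steps(max_steps=64)
w = ar_loss_weight(s=1, S=S, t=900, T=1000)
```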
System-level comparisons on ImageNet class-conditional generation demonstrated that CausalFusion-L achieves an FID-50k of 5.12 without classifier-free guidance (CFG), outperforming DiT-XL/2 by 4.5 points with 50% fewer parameters. CausalFusion-XL further improves this result to 3.61, and when using CFG, achieves a state-of-the-art result of 1.77. CausalFusion-XL also achieved an FID of 1.98 on 512×512 images with CFG.
The model naturally supports zero-shot image editing. A generalized causal attention mechanism was designed for the model to maintain causal dependencies across all AR steps while ensuring that each AR step relies only on clean image tokens from preceding AR steps.
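One way to realize such a mask, under a simplified reading of the mechanism: attention is block-causal over AR steps, so a token in step $s$ can attend to steps $1..s$ (the clean context of earlier steps plus its own step) but never to later steps. The clean/noisy token bookkeeping of the full mechanism is omitted here:

```python
import torch

def generalized_causal_mask(step_sizes: list[int]) -> torch.Tensor:
    """step_sizes[i] = number of tokens in AR step i+1.

    Returns an (L, L) boolean mask, True where attention is allowed.
    """
    step_id = torch.repeat_interleave(
        torch.arange(len(step_sizes)), torch.tensor(step_sizes)
    )                                             # AR step index of each token
    # Allowed iff the key's AR step is not later than the query's AR step.
    return step_id[None, :] <= step_id[:, None]

mask = generalized_causal_mask([3, 2, 3])         # e.g., |kappa_1|, |kappa_2|, |kappa_3| token counts
```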
CausalFusion can integrate the language modality by applying a separate next-token prediction loss on text, enabling it to jointly model both image and text data. When compared to Transfusion, CausalFusion demonstrates superior performance in both text-to-image generation and image captioning. CausalFusion also learns superior representations compared to DiT, outperforming it on fine-tuning tasks for ImageNet classification and MSCOCO captioning.
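A minimal sketch of the joint objective as described: a diffusion (MSE) loss on the noised image tokens plus a next-token cross-entropy loss on the text tokens, combined here with an illustrative weighting that the paper may set differently:

```python
import torch
import torch.nn.functional as F

def joint_loss(noise_pred, noise, text_logits, text_targets, lam: float = 1.0):
    """noise_pred/noise: (B, N, D) for the noised image tokens of the current AR step.
    text_logits: (B, M, V) next-token predictions; text_targets: (B, M) token ids."""
    image_loss = F.mse_loss(noise_pred, noise)
    text_loss = F.cross_entropy(text_logits.reshape(-1, text_logits.size(-1)),
                                text_targets.reshape(-1))
    return image_loss + lam * text_loss  # lam is an illustrative balance term

loss = joint_loss(torch.randn(2, 8, 4), torch.randn(2, 8, 4),
                  torch.randn(2, 12, 1000), torch.randint(0, 1000, (2, 12)))
```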