
Causal Diffusion Transformers for Generative Modeling (2412.12095v2)

Published 16 Dec 2024 in cs.CV

Abstract: We introduce Causal Diffusion as the autoregressive (AR) counterpart of Diffusion models. It is a next-token(s) forecasting framework that is friendly to both discrete and continuous modalities and compatible with existing next-token prediction models like LLaMA and GPT. While recent works attempt to combine diffusion with AR models, we show that introducing sequential factorization to a diffusion model can substantially improve its performance and enables a smooth transition between AR and diffusion generation modes. Hence, we propose CausalFusion - a decoder-only transformer that dual-factorizes data across sequential tokens and diffusion noise levels, leading to state-of-the-art results on the ImageNet generation benchmark while also enjoying the AR advantage of generating an arbitrary number of tokens for in-context reasoning. We further demonstrate CausalFusion's multimodal capabilities through a joint image generation and captioning model, and showcase CausalFusion's ability for zero-shot in-context image manipulations. We hope that this work could provide the community with a fresh perspective on training multimodal models over discrete and continuous data.

Summary

  • The paper introduces CausalFusion, a novel generative model that unifies autoregressive and diffusion paradigms through a decoder-only transformer architecture for both discrete and continuous data.
  • CausalFusion achieves state-of-the-art results on ImageNet generation benchmarks and enables arbitrary token generation for in-context reasoning and cross-modal integration with language.
  • The model naturally supports zero-shot image editing and learns superior image representations compared to prior diffusion transformers, improving performance on downstream tasks like classification and captioning.

The paper introduces CausalFusion, an autoregressive counterpart to diffusion models, designed for both discrete and continuous data modalities. CausalFusion integrates sequential and noise-level data factorization, unifying the advantages of autoregressive and diffusion paradigms. The approach allows for adjustable autoregressive and diffusion steps, facilitating a smooth transition between traditional autoregressive and diffusion modes. The model is a decoder-only transformer that factorizes data across sequential tokens and diffusion noise levels.

Key aspects of CausalFusion are:

  • It achieves state-of-the-art results on the ImageNet generation benchmark.
  • It enables arbitrary token generation for in-context reasoning.
  • It facilitates a smooth, cohesive integration with language modeling for cross-modal generation and reasoning.

The paper begins by highlighting the distinction between autoregressive and diffusion models. Autoregressive models factorize data along the sequential axis, conditioning the probability of each token on all preceding tokens. This is expressed mathematically as:

$$q(\mathbf{x}_{1:L}) = q(\mathbf{x}_{1}) \prod_{l=2}^{L} q(\mathbf{x}_{l} \mid \mathbf{x}_{1:l-1})$$

where:

  • $\mathbf{X}$ is a sample of training images.
  • $\mathbf{x}_{1:L}$ is its sequence of tokens.
  • $L$ is the number of tokens.
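
To make the chain-rule factorization concrete, here is a minimal PyTorch sketch that scores a token sequence as a sum of per-token conditional log-probabilities. The prefix-averaging "model" is a toy stand-in for a real decoder-only transformer (e.g., a LLaMA/GPT-style network), used purely to illustrate the factorization.

```python
import torch
import torch.nn.functional as F

# Toy stand-in: real AR models use a decoder-only transformer for the
# conditionals; a prefix-mean + linear head suffices to show the chain rule.
vocab, L, dim = 16, 8, 32
embed = torch.nn.Embedding(vocab, dim)
head = torch.nn.Linear(dim, vocab)

x = torch.randint(0, vocab, (L,))                 # a token sequence x_{1:L}
log_q = torch.log(torch.tensor(1.0 / vocab))      # assume a uniform prior q(x_1)
for l in range(1, L):                             # terms q(x_l | x_{1:l-1})
    ctx = embed(x[:l]).mean(dim=0)                # summary of the prefix x_{1:l-1}
    log_q = log_q + F.log_softmax(head(ctx), dim=-1)[x[l]]

# log q(x_{1:L}) = log q(x_1) + sum_{l=2}^{L} log q(x_l | x_{1:l-1})
print(float(log_q))
```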

Diffusion models, on the other hand, factorize data along the noise-level axis. The joint distribution is given by:

$$q(\mathbf{x}_{0:T}) = q(\mathbf{x}_0) \prod_{t=1}^{T} q(\mathbf{x}_t \mid \mathbf{x}_{t-1})$$

where:

  • $T$ is the number of diffusion steps.
  • $\mathbf{x}_t$ represents the noisy version of the image at step $t$.
  • $q(\mathbf{x}_t \mid \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 - \beta_t}\,\mathbf{x}_{t-1}, \beta_t \mathbf{I})$, where:
    • $\beta_t$ is a variance schedule.
    • $\mathbf{I}$ is the identity matrix.
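
This forward chain also yields the standard closed-form marginal $q(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}(\sqrt{\bar\alpha_t}\,\mathbf{x}_0, (1-\bar\alpha_t)\mathbf{I})$ with $\bar\alpha_t = \prod_{s \le t}(1-\beta_s)$. A short sketch, using an illustrative linear $\beta_t$ schedule rather than any schedule specified in the paper:

```python
import torch

T = 1000
beta = torch.linspace(1e-4, 0.02, T)            # illustrative variance schedule
alpha_bar = torch.cumprod(1.0 - beta, dim=0)    # \bar{alpha}_t = prod_{s<=t} (1 - beta_s)

def q_sample(x0: torch.Tensor, t: int) -> torch.Tensor:
    """Sample x_t ~ q(x_t | x_0) via the closed-form Gaussian marginal."""
    eps = torch.randn_like(x0)
    return alpha_bar[t].sqrt() * x0 + (1.0 - alpha_bar[t]).sqrt() * eps

x0 = torch.randn(3, 32, 32)                     # stand-in for a clean image/latent
xt = q_sample(x0, t=500)                        # its noisy version at step t = 500
```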

CausalFusion extends this formulation to encompass autoregressive factorization:

$$q(\mathbf{x}_{0:T,\kappa_s} \mid \mathbf{x}_{0,\kappa_{1:s-1}}) = q(\mathbf{x}_{0,\kappa_s}) \prod_{t=1}^{T} q(\mathbf{x}_{t,\kappa_s} \mid \mathbf{x}_{t-1,\kappa_s}, \mathbf{x}_{0,\kappa_{1:s-1}})$$

where:

  • $S$ denotes the total number of AR steps.
  • $\kappa_s$ is an index set identifying the subset of image tokens processed at the $s$-th AR step.
  • $\mathbf{x}_{t,\kappa_s}$ represents the dual-factorized image tokens at the $s$-th AR step and $t$-th diffusion step.

The training objective is to approximate $p_\theta(\mathbf{x}_{t-1,\kappa_s} \mid \mathbf{x}_{t,\kappa_s}, \mathbf{x}_{0,\kappa_{1:s-1}})$ for all $t$ and $s$, conditioning on the noised image tokens at the current AR step, $\mathbf{x}_{t,\kappa_s}$, and the clean image tokens from previous AR steps, $\mathbf{x}_{0,\kappa_{1:s-1}}$.
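
A minimal sketch of one dual-factorized training example under this objective follows. The random token grouping, the linear stand-in denoiser, and the $\epsilon$-prediction loss are illustrative assumptions, not the paper's exact implementation.

```python
import torch

L, dim, T, S = 256, 64, 1000, 4                 # tokens, width, diffusion/AR steps
x0 = torch.randn(L, dim)                        # clean image tokens x_0

perm = torch.randperm(L)
kappa = perm.chunk(S)                           # index sets kappa_1 .. kappa_S

s = torch.randint(0, S, (1,)).item()            # current AR step (0-indexed)
t = torch.randint(1, T, (1,)).item()            # current diffusion step

# Noise only the tokens of the current AR step, kappa_s.
alpha_bar = torch.cumprod(1 - torch.linspace(1e-4, 0.02, T), dim=0)
eps = torch.randn(len(kappa[s]), dim)
x_t_s = alpha_bar[t].sqrt() * x0[kappa[s]] + (1 - alpha_bar[t]).sqrt() * eps

# Clean tokens from all previous AR steps serve as visible context.
clean_ctx = x0[torch.cat(kappa[:s])] if s > 0 else x0.new_zeros(0, dim)

model = torch.nn.Linear(dim, dim)               # stand-in for the transformer denoiser
inp = torch.cat([clean_ctx, x_t_s])             # clean context, then noisy tokens
pred = model(inp)[clean_ctx.shape[0]:]          # predictions for the kappa_s tokens
loss = ((pred - eps) ** 2).mean()               # epsilon-prediction diffusion loss
```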

Experiments were conducted on the ImageNet dataset, training class-conditional image generation models at $256 \times 256$ resolution. The DiT-L/2 model was used as the base configuration. The original DiT incorporates conditional information and the diffusion time step through Adaptive Layer Normalization with zero initialization (AdaLN-zero). CausalFusion instead adopts an in-context design, treating the class and time step conditions as tokens appended to the image token sequence. Several improvements were implemented to stabilize training, including:

  • Injecting the diffusion time step by adding a time step embedding to the image token embeddings.
  • Applying head-wise QK layer normalization within the self-attention layers (see the sketch after this list).
  • Incorporating a learning rate warmup stage during training.
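
As a sketch of the QK-normalization trick, here is a self-attention layer that applies LayerNorm to each head's queries and keys before the attention logits are formed. This is a simplified illustration (no output projection, biases, or masking), not the paper's code.

```python
import torch
import torch.nn.functional as F

def qk_norm_attention(x, wq, wk, wv, n_heads):
    """Self-attention with head-wise LayerNorm on queries and keys."""
    B, L, D = x.shape
    hd = D // n_heads
    q = wq(x).view(B, L, n_heads, hd)
    k = wk(x).view(B, L, n_heads, hd)
    v = wv(x).view(B, L, n_heads, hd)
    # Normalize each head's query/key vectors before forming attention logits;
    # this bounds the logit scale and stabilizes training.
    q = F.layer_norm(q, (hd,))
    k = F.layer_norm(k, (hd,))
    attn = torch.einsum("blhd,bmhd->bhlm", q, k) / hd**0.5
    out = torch.einsum("bhlm,bmhd->blhd", attn.softmax(dim=-1), v)
    return out.reshape(B, L, D)

D, H = 64, 8
x = torch.randn(2, 16, D)
wq, wk, wv = (torch.nn.Linear(D, D) for _ in range(3))
y = qk_norm_attention(x, wq, wk, wv, H)   # (2, 16, 64)
```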

Ablations on AR steps revealed that CausalFusion trained with a fixed number of AR steps does not transfer robustly to other inference settings. Models trained with more AR steps exhibited lower loss values, suggesting that the learning task becomes over-simplified as the number of AR steps increases. Training with a random number of AR steps resulted in a highly imbalanced $|\kappa_s|$ distribution, where over 95% of AR steps have $|\kappa_s| \leq 16$, causing the model to rely too heavily on visible context.

To address these issues, the authors adjusted the difficulties of the generative tasks in CausalFusion to balance training signal impact and ensure thorough exploration of the factorization space. This was achieved through:

  • Random AR steps with decayed sampling: exponentially decreasing the sampling probability as $S$ increases.
  • Loss weighting along the AR axis: modifying the weighting term $w(\cdot)$ in the diffusion loss to also depend on the AR step $s$, focusing more on the hard generative tasks at early AR steps or larger noise levels (both ideas are sketched below).
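
A small sketch of both ideas, with an assumed geometric decay rate and an assumed weight form; the paper's exact choices may differ.

```python
import torch

# Decayed sampling over the number of AR steps: P(S = s) ∝ decay**(s - 1),
# so large S (many small, easy steps) is sampled exponentially less often.
S_max, decay = 64, 0.95                         # assumed values
probs = decay ** torch.arange(S_max, dtype=torch.float32)
S = torch.multinomial(probs / probs.sum(), 1).item() + 1

def ar_loss_weight(s: int, S: int, t: int, T: int) -> float:
    """Illustrative w(s, t): upweight early AR steps and large noise levels."""
    return (1.0 - s / S) + t / T
```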

System-level comparisons on ImageNet class-conditional generation demonstrated that CausalFusion-L achieves an FID-50k of 5.12 without classifier-free guidance (CFG), outperforming DiT-XL/2 by 4.5 points with 50% fewer parameters. CausalFusion-XL further improves this result to 3.61 and, when using CFG, achieves a state-of-the-art result of 1.77. CausalFusion-XL also achieved an FID of 1.98 on $512 \times 512$ images with CFG.

The model naturally supports zero-shot image editing. A generalized causal attention mechanism was designed for the model to maintain causal dependencies across all AR steps while ensuring that each AR step relies only on clean image tokens from preceding AR steps.
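
A simplified sketch of the step-level structure of such a mask: tokens in AR step $s$ may attend within their own step and to all earlier steps. The real mechanism additionally distinguishes clean from noisy copies of each token, which is omitted here.

```python
import torch

def block_causal_mask(group_sizes: list[int]) -> torch.Tensor:
    """Boolean mask: token i may attend to token j iff step(i) >= step(j)."""
    step_id = torch.cat(
        [torch.full((n,), i) for i, n in enumerate(group_sizes)]
    )
    return step_id.unsqueeze(1) >= step_id.unsqueeze(0)

mask = block_causal_mask([4, 4, 8])   # |kappa_s| sizes for S = 3 AR steps
```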

CausalFusion can integrate the language modality by applying a separate next-token prediction loss on text, enabling it to jointly model both image and text data. When compared to Transfusion, CausalFusion demonstrates superior performance in both text-to-image generation and image captioning. CausalFusion also learns superior representations compared to DiT, outperforming it on fine-tuning tasks for ImageNet classification and MSCOCO captioning.
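
Concretely, such joint training can be sketched as a diffusion-style MSE term on the image tokens plus a next-token cross-entropy term on the text tokens; every tensor below is a stand-in.

```python
import torch
import torch.nn.functional as F

eps_pred = torch.randn(64, 32, requires_grad=True)       # predicted noise (stand-in)
eps_true = torch.randn(64, 32)                           # true injected noise
text_logits = torch.randn(20, 1000, requires_grad=True)  # per-position vocab logits
text_targets = torch.randint(0, 1000, (20,))             # next-token targets

loss = F.mse_loss(eps_pred, eps_true) + F.cross_entropy(text_logits, text_targets)
loss.backward()                                          # one joint update signal
```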
