
Exploring the Deep Fusion of Large Language Models and Diffusion Transformers for Text-to-Image Synthesis (2505.10046v1)

Published 15 May 2025 in cs.CV

Abstract: This paper does not describe a new method; instead, it provides a thorough exploration of an important yet understudied design space related to recent advances in text-to-image synthesis -- specifically, the deep fusion of LLMs and diffusion transformers (DiTs) for multi-modal generation. Previous studies mainly focused on overall system performance rather than detailed comparisons with alternative methods, and key design details and training recipes were often left undisclosed. These gaps create uncertainty about the real potential of this approach. To fill these gaps, we conduct an empirical study on text-to-image generation, performing controlled comparisons with established baselines, analyzing important design choices, and providing a clear, reproducible recipe for training at scale. We hope this work offers meaningful data points and practical guidelines for future research in multi-modal generation.

Summary

  • The paper demonstrates that deep fusion of a frozen LLM and a trainable DiT enhances text-image alignment (GenEval up to 0.51) and inference speed.
  • Methodical evaluations reveal that design choices like removing timestep conditioning and using mixed positional encodings can improve FID scores and model efficiency.
  • Experiments scaling to FuseDiT (2B parameters) provide actionable guidelines for multi-modal synthesis, achieving significant performance boosts with limited compute.

This paper presents an empirical study of the "deep fusion" of a frozen LLM and a trainable Diffusion Transformer (DiT) for text-to-image synthesis (2505.10046). Instead of introducing a novel method, the authors explore this specific design space, which they argue has been underexplored in terms of controlled comparisons, design choices, and reproducible training recipes. The primary goal is to offer meaningful data points and practical guidelines for future research in multi-modal generation.

The core deep fusion architecture involves integrating a frozen decoder-only LLM with a trainable DiT using layer-wise shared self-attention. The DiT's transformer architecture mirrors the LLM's, creating a two-stream system. Text embeddings are processed by the LLM, and noisy image latents by the DiT. At each layer, token sequences from both streams are concatenated for the self-attention operation, allowing the DiT to access linguistic context. A causal attention mask is used for text tokens and a bidirectional mask for image tokens, permitting image tokens to attend to text tokens but not vice-versa. This setup allows the key and value states of the text hidden states to be cached during inference, improving efficiency. The training objective uses the rectified flow formulation.
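
As a rough illustration of the rectified flow objective mentioned above, the sketch below builds one training pair in plain Python. It assumes the common convention x_t = (1 - t) x_0 + t ε with velocity target ε - x_0; the summary does not state which convention or timestep sampling the paper uses. The helper names `rectified_flow_pair` and `mse` are hypothetical.

```python
import random


def rectified_flow_pair(x0, t, rng=random):
    """Return (x_t, target_velocity) for clean latents x0 at time t in [0, 1].

    Assumed convention: x_t = (1 - t) * x0 + t * noise, target = noise - x0.
    """
    noise = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = [(1.0 - t) * a + t * n for a, n in zip(x0, noise)]
    target = [n - a for a, n in zip(x0, noise)]
    return x_t, target


def mse(pred, target):
    """Mean squared error between predicted and target velocities."""
    return sum((p - q) ** 2 for p, q in zip(pred, target)) / len(pred)


x0 = [0.5, -1.0, 2.0]  # toy "clean latent"
x_t, v = rectified_flow_pair(x0, t=0.3, rng=random.Random(0))
# A model that predicted v exactly would incur zero loss.
loss = mse(v, v)
```

Under this convention, stepping from x_t back along the target velocity by t recovers x_0, which is what the sampler exploits at inference time.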

Experimental Setup:

  • LLM: Frozen Gemma 2B (and later Gemma 2 2B).
  • DiT: Randomly initialized 2.5B parameter DiT, architecturally matching the LLM (2B backbone). It uses 2D frequency absolute positional encoding, adaLN-Zero timestep conditioning, ViT-style weight initialization, and a patch size of 2. QK normalization is applied.
  • VAE: 16-channel VAE from Stable Diffusion 3.
  • Dataset: CC12M with synthetic captions (10.9M pairs) for most experiments, resized to $512 \times 512$. Texts padded/truncated to 256 tokens. For scaled experiments, a mix of CC12M, SA-1B, and JourneyDB (approx. 26M pairs) is used.
  • Training: Batch size 512, AdamW optimizer, constant LR $1 \times 10^{-4}$, weight decay $1 \times 10^{-4}$, gradient clipping 1.0, BF16 precision. EMA with decay 0.99. 10% text dropout for unconditional generation. Trained on TPU v4-256 pods.
  • Inference: Euler discretization, 25 steps, CFG scale 6.
  • Evaluation: GenEval and DPG-Bench for text-image alignment, FID on MJHQ-30K for visual quality.
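
The inference settings above (Euler discretization, 25 steps, CFG scale 6) can be sketched as follows. This is a minimal illustration, not the paper's sampler: it assumes time runs from t = 1 (noise) to t = 0 (data), a uniform step size, and the standard guidance formula v_uncond + s (v_cond - v_uncond). `velocity_fn` is a hypothetical stand-in for the model.

```python
def cfg_velocity(v_cond, v_uncond, scale):
    """Classifier-free guidance: extrapolate from the unconditional prediction
    toward the conditional one."""
    return [u + scale * (c - u) for c, u in zip(v_cond, v_uncond)]


def euler_sample(x, velocity_fn, num_steps=25, cfg_scale=6.0):
    """Integrate dx/dt = v(x, t) from t = 1 down to t = 0 with Euler steps.

    velocity_fn(x, t, conditional) is a hypothetical model interface that
    returns a velocity with the same length as x.
    """
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = 1.0 - i * dt
        v = cfg_velocity(velocity_fn(x, t, True), velocity_fn(x, t, False), cfg_scale)
        x = [xi - dt * vi for xi, vi in zip(x, v)]
    return x
```

Each step calls the model twice (with and without text), which is why caching the text-side key and value states matters for latency.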

Key Findings and Comparisons:

  1. Deep Fusion vs. Shallow Fusion:
    • Shallow Fusion Baselines:
      • Self-attention DiT: Text representations (from LLM's last layer) projected to K, V states and concatenated with image K, V states in self-attention.
      • Cross-attention DiT: Text representations projected to K, V states for an additional cross-attention mechanism with image hidden states after self-attention.
    • Results: Deep fusion (2.45B params) achieved better text-image alignment (GenEval: 0.51) compared to self-attention DiT (2.47B params, GenEval: 0.42) and cross-attention DiT (2.62B params, GenEval: 0.49). However, shallow fusion models showed better visual quality (FID: Deep Fusion 27.33, Self-Attn 26.16, Cross-Attn 24.00). Deep fusion was slightly faster in inference (1.66s vs 1.75s/1.86s on A100).
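    | Method          | Params | GenEval ↑ | DPG ↑ | FID ↓ | Inference Latency (s) |
    |-----------------|--------|-----------|-------|-------|-----------------------|
    | Self-Attention  | 2.47B  | 0.42      | 73.9  | 26.16 | 1.75                  |
    | Cross-Attention | 2.62B  | 0.49      | 76.3  | 24.00 | 1.86                  |
    | Deep Fusion     | 2.45B  | 0.51      | 76.6  | 27.33 | 1.66                  |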
  2. Examining Design Choices for Deep Fusion:
    • Timestep Conditioning:

      • Adding CLIP L/14 text embeddings to adaLN-Zero improved FID (24.00 vs 27.33) but slightly hurt alignment (GenEval 0.50 vs 0.51).
      • Reducing adaLN parameters (adaLN-Single, Addition) or removing timestep conditioning altogether was explored.
      • Removing timestep conditioning entirely (1.98B params) surprisingly yielded the best FID (21.27) while maintaining comparable alignment (GenEval 0.49), reducing parameters by 20%.
      | Method       | Params | GenEval ↑ | DPG ↑ | FID ↓ |
      |--------------|--------|-----------|-------|-------|
      | adaLN-Zero   | 2.47B  | 0.51      | 76.6  | 27.33 |
      | w/o timestep | 1.98B  | 0.49      | 76.7  | 21.27 |

    • Positional Encoding:
      • Default: 1D RoPE for text + APE for image.
      • 1D RoPE (text) + 2D RoPE (image) performed best overall (GenEval 0.51, FID 25.42) compared to default (GenEval 0.51, FID 27.33) and M-RoPE (GenEval 0.49, FID 27.60).

    • Base LLM (Gemma 2B variants):
      • Instruction-tuned Gemma 2B IT (with or without "Imagine: " prompt) did not improve and sometimes slightly worsened performance compared to base Gemma 2B.
      • Gemma 2B with multi-modal tuning (from PaliGemma 3B PT) showed slight improvements (GenEval 0.52 vs 0.51, FID 26.30 vs 27.33).
      • Upgrading to Gemma 2 2B (next-gen model) yielded a drastic performance boost (GenEval 0.54 vs 0.51, DPG 79.1 vs 76.6, FID 23.94 vs 27.33), indicating strong dependency on base LLM capabilities.

  3. Training at Scale (FuseDiT):
    • Final Recipe:
      • Removed AdaLN-Zero modules.
      • Used 1D RoPE (text) + 2D RoPE (image).
      • Replaced Gemma 2B with Gemma 2 2B (adjusting DiT accordingly).
    • Trained for 800K steps on a mixed dataset of ~26M image-caption pairs (CC12M, SA-1B, JourneyDB with synthetic captions).
    • FuseDiT (2B params) achieved GenEval 0.60, DPG 81.6, and FID 7.54, surpassing many established systems despite limited compute and data.
  4. Further Explorations:
    • Architecture Alignment: Explored decoupling LLM and DiT architectures by reducing DiT hidden size or layers while fusing into middle LLM layers.
      • Reducing DiT hidden size (e.g., from 2048 to 1792) gracefully degraded alignment but sometimes improved FID (24.27 for 1792 hidden).
      • Reducing DiT layers (e.g., from 18 to 14) showed faster performance degradation for alignment. This suggests DiT architecture can be scaled independently to some extent.
    • Attention Mechanism: Replacing shared self-attention with a deep fusion variant of cross-attention (using LLM layer K,V states instead of projected last-layer K,V) showed minor gains in alignment (GenEval 0.52 vs 0.51) but increased latency by ~12% (1.86s vs 1.66s), leading to retention of self-attention.
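
The two percentages quoted in the findings above (a roughly 20% parameter reduction without timestep conditioning, and a roughly 12% latency increase for the cross-attention variant) can be rechecked from the reported figures. The helper `relative_change` is a hypothetical name used only for this check.

```python
def relative_change(new, old):
    """Signed relative change of `new` with respect to `old`."""
    return (new - old) / old


# Parameters: without timestep conditioning (1.98B) vs adaLN-Zero (2.47B).
param_cut = -relative_change(1.98, 2.47)
# Latency: deep-fusion cross-attention (1.86 s) vs shared self-attention (1.66 s).
latency_increase = relative_change(1.86, 1.66)

print(f"parameter reduction: {param_cut:.1%}")      # parameter reduction: 19.8%
print(f"latency increase:    {latency_increase:.1%}")  # latency increase:    12.0%
```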

Implementation of Deep Fusion:

The core idea is the shared self-attention mechanism across LLM and DiT layers.

import torch
import torch.nn as nn
import torch.nn.functional as F


class DeepFusionTransformerLayer(nn.Module):
    """One fused layer pairing a frozen LLM layer with a trainable DiT layer.

    Simplified, single-head sketch. The submodule names assumed on the two
    layers (self_attn, to_q, to_k, to_v, norm1, norm2, mlp, drop,
    adaLN_modulation) are illustrative, not a specific library's API.
    """

    def __init__(self, llm_layer, dit_layer):
        super().__init__()
        self.llm_layer = llm_layer  # frozen
        self.dit_layer = dit_layer  # trainable
        self.llm_layer.requires_grad_(False)

    def forward(self, text_hidden_states, image_hidden_states,
                text_attention_mask=None, timestep_embedding=None):
        text_len = text_hidden_states.shape[1]
        image_len = image_hidden_states.shape[1]

        # Text stream. The frozen LLM layer only sees text (causal mask), so its
        # outputs and K/V states do not depend on the noisy image or the timestep
        # and can be cached across denoising steps. The LLM layer's residual
        # connections and MLP are omitted here for brevity.
        text_output = self.llm_layer.self_attn(
            text_hidden_states, attention_mask=text_attention_mask)

        # Image stream, step 1: pre-norm, with optional adaLN timestep modulation.
        # The residual below uses the un-normalized input, so keep it separate.
        # (adaLN-Zero also predicts gates; they are left out of this sketch.)
        normed = self.dit_layer.norm1(image_hidden_states)
        if timestep_embedding is not None and hasattr(self.dit_layer, "adaLN_modulation"):
            scale, shift = self.dit_layer.adaLN_modulation(timestep_embedding).chunk(2, dim=-1)
            # Broadcast the per-sample modulation over the token dimension.
            normed = normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

        # Image stream, step 2: shared self-attention.
        # Queries come from image tokens; keys and values from text and image.
        q_img = self.dit_layer.to_q(normed)
        k_img = self.dit_layer.to_k(normed)
        v_img = self.dit_layer.to_v(normed)

        # Text K/V come from the frozen LLM's own projections at this layer.
        # This is one reading of the paper's "key and value states of the text
        # hidden states"; passing text through the DiT's K/V projections would
        # be the alternative.
        k_text = self.llm_layer.self_attn.to_k(text_hidden_states)
        v_text = self.llm_layer.self_attn.to_v(text_hidden_states)

        combined_k = torch.cat([k_text, k_img], dim=1)  # sequence dimension
        combined_v = torch.cat([v_text, v_img], dim=1)  # sequence dimension

        # Keep only the image-query rows of the full mask. The helper marks
        # blocked positions with True, while scaled_dot_product_attention
        # expects True for allowed positions, hence the inversion.
        blocked = create_deep_fusion_mask(text_len, image_len)[text_len:].to(q_img.device)
        image_attn_output = F.scaled_dot_product_attention(
            q_img, combined_k, combined_v, attn_mask=~blocked)
        image_hidden_states = image_hidden_states + self.dit_layer.drop(image_attn_output)

        # Image stream, step 3: MLP with residual connection.
        image_hidden_states = image_hidden_states + self.dit_layer.mlp(
            self.dit_layer.norm2(image_hidden_states))

        # text_output feeds the next LLM layer; image_hidden_states the next DiT layer.
        return text_output, image_hidden_states


def create_deep_fusion_mask(text_len, image_len):
    """Return a (T + I, T + I) boolean mask where True means "cannot attend".

    Rows are queries and columns are keys/values, with text tokens first:
      text  -> text:  causal (self and preceding tokens only)
      text  -> image: blocked
      image -> text:  allowed (all text tokens)
      image -> image: allowed (bidirectional)
    "Image tokens attend to text tokens but not vice versa" therefore means
    that text queries never see image keys.
    """
    total_len = text_len + image_len
    mask = torch.zeros(total_len, total_len, dtype=torch.bool)  # False = allowed

    # Text attending to text: block the strict upper triangle (future tokens).
    mask[:text_len, :text_len] = torch.triu(
        torch.ones(text_len, text_len, dtype=torch.bool), diagonal=1)
    # Text does not attend to image.
    mask[:text_len, text_len:] = True
    # Image rows stay all False: image queries see every text and image token.
    return mask

Attention Mask (Figure 1 clarification):

The mask ensures:

  1. Text tokens (queries) attending to text tokens (keys/values): Causal mask (can only attend to past and self).
  2. Text tokens (queries) attending to image tokens (keys/values): Masked out (cannot attend).
  3. Image tokens (queries) attending to text tokens (keys/values): Full attention (can attend to all text tokens).
  4. Image tokens (queries) attending to image tokens (keys/values): Full attention (can attend to all image tokens).
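
The four rules above can be visualized with a small, dependency-free sketch. Here `True` (printed as `x`) means the query in that row may attend to the key in that column; `fusion_mask` is a hypothetical helper name, and the toy sizes (3 text tokens, 2 image tokens) are chosen only for display.

```python
def fusion_mask(text_len, image_len):
    """Rows are queries, columns are keys; True means the query may attend."""
    total = text_len + image_len
    allowed = [[False] * total for _ in range(total)]
    for q in range(total):
        for k in range(total):
            if q < text_len:
                # Rules 1 and 2: text queries see only themselves and earlier
                # text tokens (k <= q < text_len, so k is always a text token).
                allowed[q][k] = k <= q
            else:
                # Rules 3 and 4: image queries see every text and image token.
                allowed[q][k] = True
    return allowed


for row in fusion_mask(text_len=3, image_len=2):
    print("".join("x" if a else "." for a in row))
# x....
# xx...
# xxx..
# xxxxx
# xxxxx
```

Because the text rows never contain an image column, the text stream is independent of the noisy image, which is what makes caching its key and value states possible.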

Training Recipe (FuseDiT):

  • Base LLM: Frozen Gemma 2 2B.
  • DiT: Matches Gemma 2 2B architecture, trained from scratch.
  • Timestep Conditioning: None (removed AdaLN-Zero).
  • Positional Encoding: 1D RoPE for text sequence, 2D RoPE for image patch sequence.
  • Data: ~26M image-caption pairs (CC12M, SA-1B, JourneyDB with synthetic captions).
  • Optimizer: AdamW ($\beta_1 = 0.9$, $\beta_2 = 0.999$), LR $1 \times 10^{-4}$, WD $1 \times 10^{-4}$.
  • Batch Size: 512.
  • Precision: BF16.
  • Objective: Rectified Flow.
  • Duration: 800K steps.
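
To make the positional-encoding entry in the recipe more concrete, here is a minimal sketch of one common way to build 2D RoPE for image patches: split the channels in half, rotate the first half by the patch's row index and the second half by its column index. The summary does not describe the paper's exact construction, so the channel split, the adjacent-pair rotation, and the base of 10000 are all assumptions, and the function names are hypothetical.

```python
import math


def rope_rotate(vec, pos, base=10000.0):
    """Apply 1D rotary encoding to an even-length vector at position `pos`."""
    d = len(vec)
    out = []
    for i in range(0, d, 2):
        theta = pos / (base ** (i / d))
        c, s = math.cos(theta), math.sin(theta)
        # Rotate the channel pair (i, i + 1) by angle theta.
        out += [vec[i] * c - vec[i + 1] * s, vec[i] * s + vec[i + 1] * c]
    return out


def rope_2d(vec, row, col):
    """Axial 2D RoPE: first half of the channels encodes the row index,
    second half the column index. len(vec) must be a multiple of 4."""
    half = len(vec) // 2
    return rope_rotate(vec[:half], row) + rope_rotate(vec[half:], col)


def dot(a, b):
    """Plain dot product, standing in for an attention logit."""
    return sum(x * y for x, y in zip(a, b))


q = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
k = [0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1]
# Both pairs of patches are offset by (1 row, 2 columns).
logit_a = dot(rope_2d(q, 2, 3), rope_2d(k, 1, 1))
logit_b = dot(rope_2d(q, 5, 7), rope_2d(k, 4, 5))
```

The two logits agree up to floating-point error, showing that the attention score depends only on the relative row and column offset between patches, which is the property that distinguishes RoPE from absolute positional encodings.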

This recipe provides a strong baseline for deep fusion models, emphasizing the importance of the base LLM's quality and thoughtful reduction of components like timestep conditioning for efficiency and performance. The paper concludes that deep fusion is a promising direction, and their empirical work helps bridge gaps in understanding its practical application.
