Exploring the Deep Fusion of Large Language Models and Diffusion Transformers for Text-to-Image Synthesis (2505.10046v1)
Abstract: This paper does not describe a new method; instead, it provides a thorough exploration of an important yet understudied design space related to recent advances in text-to-image synthesis -- specifically, the deep fusion of LLMs and diffusion transformers (DiTs) for multi-modal generation. Previous studies mainly focused on overall system performance rather than detailed comparisons with alternative methods, and key design details and training recipes were often left undisclosed. These gaps create uncertainty about the real potential of this approach. To fill these gaps, we conduct an empirical study on text-to-image generation, performing controlled comparisons with established baselines, analyzing important design choices, and providing a clear, reproducible recipe for training at scale. We hope this work offers meaningful data points and practical guidelines for future research in multi-modal generation.
Summary
- The paper demonstrates that deep fusion of a frozen LLM and a trainable DiT improves text-image alignment over shallow-fusion baselines (GenEval 0.51 vs. 0.42 and 0.49) and slightly reduces inference latency, at some cost in FID.
- Controlled ablations show that design choices such as removing timestep conditioning and using 1D RoPE for text with 2D RoPE for image tokens improve FID and model efficiency.
- Scaling the final recipe to FuseDiT (2B parameters) yields GenEval 0.60, DPG 81.6, and FID 7.54 with limited compute and data, providing actionable guidelines for multi-modal synthesis.
This paper presents an empirical study of the "deep fusion" of a frozen LLM and a trainable Diffusion Transformer (DiT) for text-to-image synthesis (2505.10046). Instead of introducing a novel method, the authors provide a thorough exploration of this specific design space, which they argue has been underexplored in terms of controlled comparisons, design choices, and reproducible training recipes. The primary goal is to offer meaningful data points and practical guidelines for future research in multi-modal generation.
The core deep fusion architecture integrates a frozen decoder-only LLM with a trainable DiT through layer-wise shared self-attention. The DiT's transformer architecture mirrors the LLM's, creating a two-stream system: text embeddings are processed by the LLM and noisy image latents by the DiT. At each layer, the token sequences from both streams are concatenated for the self-attention operation, giving the DiT access to linguistic context. A causal attention mask is applied to text tokens and a bidirectional mask to image tokens, so image tokens can attend to text tokens but not vice versa. Because the text stream never sees image tokens, its key and value states do not depend on the noisy latents or the timestep; they can be computed once per prompt and cached during inference, which improves efficiency. The training objective uses the rectified flow formulation.
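The rectified-flow objective can be illustrated in a few lines of plain Python. This is a minimal sketch, assuming the common convention x_t = (1 - t)·x₀ + t·ε with velocity target ε - x₀ and uniformly sampled t. The paper's exact time-sampling scheme and loss weighting are not stated here, and `rectified_flow_pair` / `rf_loss` are hypothetical names.

```python
import random


def rectified_flow_pair(x0, noise, t):
    """Point on the straight path x_t = (1 - t) * x0 + t * noise, plus the velocity target."""
    x_t = [(1.0 - t) * a + t * e for a, e in zip(x0, noise)]
    v_target = [e - a for a, e in zip(x0, noise)]  # d x_t / d t, constant along the path
    return x_t, v_target


def rf_loss(pred_v, v_target):
    """Mean squared error between the predicted and the target velocity."""
    return sum((p - v) ** 2 for p, v in zip(pred_v, v_target)) / len(v_target)


# One toy training example: a 4-dim "latent" and Gaussian noise.
rng = random.Random(0)
x0 = [0.2, -0.7, 1.5, 0.0]
noise = [rng.gauss(0.0, 1.0) for _ in x0]
t = rng.random()  # assumption: t ~ U(0, 1)
x_t, v_target = rectified_flow_pair(x0, noise, t)
print(rf_loss([0.0] * len(x0), v_target))  # loss of a model that always predicts zero
```

In the real model, the DiT receives x_t (and, when timestep conditioning is kept, t) and predicts the velocity. The target is the same at every point on the straight path, which is what makes few-step Euler sampling workable.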
Experimental Setup:
- LLM: Frozen Gemma 2B (and later Gemma 2 2B).
- DiT: Randomly initialized 2.5B parameter DiT, architecturally matching the LLM (2B backbone). It uses 2D frequency absolute positional encoding, adaLN-Zero timestep conditioning, ViT-style weight initialization, and a patch size of 2. QK normalization is applied.
- VAE: 16-channel VAE from Stable Diffusion 3.
- Dataset: CC12M with synthetic captions (10.9M pairs) for most experiments, resized to 512×512. Texts padded/truncated to 256 tokens. For scaled experiments, a mix of CC12M, SA-1B, and JourneyDB (approx. 26M pairs) is used.
- Training: Batch size 512, AdamW optimizer, constant LR 1×10⁻⁴, weight decay 1×10⁻⁴, gradient clipping 1.0, BF16 precision. EMA with decay 0.99. 10% text dropout for unconditional generation. Trained on TPU v4-256 pods.
- Inference: Euler discretization, 25 steps, CFG scale 6.
- Evaluation: GenEval and DPG-Bench for text-image alignment, FID on MJHQ-30K for visual quality.
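The inference settings (Euler discretization, 25 steps, CFG scale 6) fit together as in the plain-Python sketch below. It assumes the rectified-flow convention x_t = (1 - t)·x₀ + t·ε, so sampling integrates from t = 1 (noise) to t = 0 (image). It also assumes the guidance form v_u + s·(v_c - v_u). `velocity` is a hypothetical stand-in for the DiT, and `cond=None` stands for the empty caption that the 10% text dropout trains the model to handle.

```python
def euler_cfg_sample(velocity, x_noise, cond, steps=25, cfg_scale=6.0):
    """Euler integration of dx/dt = v(x, t) from t = 1 (noise) to t = 0 (image) with CFG.

    `velocity(x, t, cond)` stands in for the DiT; cond=None means the empty caption.
    """
    x = list(x_noise)
    dt = 1.0 / steps
    for i in range(steps):
        t = 1.0 - i * dt
        v_cond = velocity(x, t, cond)
        v_uncond = velocity(x, t, None)
        # Classifier-free guidance: extrapolate away from the unconditional prediction.
        v = [u + cfg_scale * (c - u) for c, u in zip(v_cond, v_uncond)]
        x = [xi - dt * vi for xi, vi in zip(x, v)]  # step from t to t - dt
    return x


# Toy check: an "oracle" velocity for a single target point that ignores the caption.
target = [0.5, -1.0, 2.0]


def oracle_velocity(x, t, cond):
    # On the straight path x_t = (1 - t) * target + t * noise, v = (x_t - target) / t.
    return [(xi - ai) / t for xi, ai in zip(x, target)]


print(euler_cfg_sample(oracle_velocity, [1.0, 1.0, 1.0], cond="a red cube"))  # ~[0.5, -1.0, 2.0]
```

Because the toy oracle ignores the caption, the conditional and unconditional predictions coincide and guidance has no effect. With a real model, each step costs two DiT evaluations: one with the prompt and one with the empty caption.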
Key Findings and Comparisons:
- Deep Fusion vs. Shallow Fusion:
- Shallow Fusion Baselines:
- Self-attention DiT: Text representations (from LLM's last layer) projected to K, V states and concatenated with image K, V states in self-attention.
- Cross-attention DiT: Text representations projected to K, V states for an additional cross-attention mechanism with image hidden states after self-attention.
- Results: Deep fusion (2.45B params) achieved better text-image alignment (GenEval: 0.51) compared to self-attention DiT (2.47B params, GenEval: 0.42) and cross-attention DiT (2.62B params, GenEval: 0.49). However, shallow fusion models showed better visual quality (FID: Deep Fusion 27.33, Self-Attn 26.16, Cross-Attn 24.00). Deep fusion was slightly faster in inference (1.66s vs 1.75s/1.86s on A100).
| Method | Params | GenEval ↑ | DPG ↑ | FID ↓ | Inference Latency (s) |
|---|---|---|---|---|---|
| Self-Attention | 2.47B | 0.42 | 73.9 | 26.16 | 1.75 |
| Cross-Attention | 2.62B | 0.49 | 76.3 | 24.00 | 1.86 |
| Deep Fusion | 2.45B | 0.51 | 76.6 | 27.33 | 1.66 |
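To put "slightly faster" in numbers, the reported A100 latencies can be expressed as overhead relative to deep fusion. This is plain arithmetic on the table values; `relative_overhead` is a hypothetical helper.

```python
def relative_overhead(latency_s: dict, reference: str) -> dict:
    """Fractional extra latency of each method relative to `reference`."""
    base = latency_s[reference]
    return {name: t / base - 1.0 for name, t in latency_s.items()}


latency_s = {"Self-Attention": 1.75, "Cross-Attention": 1.86, "Deep Fusion": 1.66}
for name, extra in relative_overhead(latency_s, "Deep Fusion").items():
    print(f"{name:<16} {latency_s[name]:.2f}s  {extra:+.1%}")
```

By this measure, the self-attention DiT is about 5% slower per image than deep fusion and the cross-attention DiT about 12% slower.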
- Examining Design Choices for Deep Fusion:
Timestep Conditioning:
- Adding CLIP L/14 text embeddings to adaLN-Zero improved FID (24.00 vs 27.33) but slightly hurt alignment (GenEval 0.50 vs 0.51).
- Reducing adaLN parameters (adaLN-Single, Addition) or removing timestep conditioning altogether was explored.
- Removing timestep conditioning entirely (1.98B params) surprisingly yielded the best FID (21.27) while maintaining comparable alignment (GenEval 0.49), reducing parameters by 20%.
| Method | Params | GenEval ↑ | DPG ↑ | FID ↓ |
|---|---|---|---|---|
| adaLN-Zero | 2.47B | 0.51 | 76.6 | 27.33 |
| w/o timestep | 1.98B | 0.49 | 76.7 | 21.27 |
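The roughly 20% parameter saving from removing timestep conditioning can be sanity-checked roughly. The sketch below makes two assumptions: the standard DiT adaLN-Zero layout (one Linear(hidden, 6 × hidden) modulation layer per block), and a DiT that mirrors Gemma 2B's 18 layers and hidden size 2048, the sizes quoted in the architecture-alignment experiments. `adaln_zero_params` is a hypothetical helper, and the result is an estimate, not the paper's parameter breakdown.

```python
def adaln_zero_params(hidden: int, n_layers: int) -> int:
    """Estimated parameters in per-block adaLN-Zero modulation layers.

    Assumes each block has one Linear(hidden, 6 * hidden) producing shift/scale/gate
    for both the attention and MLP sub-blocks.
    """
    per_block = hidden * 6 * hidden + 6 * hidden  # weight + bias
    return n_layers * per_block


estimate = adaln_zero_params(hidden=2048, n_layers=18)
reported_drop = 2.47e9 - 1.98e9  # adaLN-Zero vs. w/o timestep, from the table
print(f"estimated adaLN-Zero params: {estimate / 1e9:.2f}B")  # 0.45B
print(f"reported reduction: {reported_drop / 1e9:.2f}B ({reported_drop / 2.47e9:.0%})")  # 0.49B (20%)
```

The estimate of about 0.45B accounts for most of the reported 0.49B drop. The remainder presumably comes from components not modeled here, such as the timestep-embedding MLP.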
Positional Encoding:
- Default: 1D RoPE for text + absolute positional encoding (APE) for image.
- 1D RoPE (text) + 2D RoPE (image) performed best overall (GenEval 0.51, FID 25.42), compared with the default (GenEval 0.51, FID 27.33) and M-RoPE (GenEval 0.49, FID 27.60).
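The 1D (text) + 2D (image) RoPE scheme can be sketched in plain Python as follows. This sketch makes three assumptions. First, the 2D variant splits each head's dimensions into a row half and a column half, which is one common layout; the paper's exact layout is not given here. Second, image positions start at (0, 0) independently of the text positions. Third, the SD3 VAE downsamples 8×, so a 512×512 image with patch size 2 becomes a 32×32 token grid. All function names are hypothetical.

```python
import math


def rope_1d(x, pos, base=10000.0):
    """Standard 1D RoPE: rotate each pair (x[2i], x[2i+1]) by pos * base**(-2i/d)."""
    d = len(x)
    out = []
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        out += [x[i] * c - x[i + 1] * s, x[i] * s + x[i + 1] * c]
    return out


def rope_2d(x, row, col, base=10000.0):
    """Assumed 2D RoPE variant: first half of the dims encodes the row, second half the column."""
    half = len(x) // 2
    return rope_1d(x[:half], row, base) + rope_1d(x[half:], col, base)


def image_positions(image_px=512, vae_factor=8, patch=2):
    """(row, col) of every image token: 512 px -> 64x64 latent -> 32x32 patches."""
    side = image_px // vae_factor // patch
    return [(r, c) for r in range(side) for c in range(side)]


text_positions = list(range(256))  # text tokens: 1D positions 0..255
print(len(text_positions) + len(image_positions()))  # 1280 tokens in the fused sequence
```

The key property is relativity. The attention score between two rotated vectors depends only on their position offset: the row offset and column offset separately for image tokens, and the token offset for text. This gives image tokens a notion of 2D neighborhood that a flattened 1D index would blur.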
Base LLM (Gemma 2B variants):
- Instruction-tuned Gemma 2B IT (with or without an "Imagine: " prompt prefix) did not improve, and sometimes slightly worsened, performance compared with base Gemma 2B.
- Gemma 2B with multi-modal tuning (taken from PaliGemma 3B PT) showed slight improvements (GenEval 0.52 vs 0.51, FID 26.30 vs 27.33).
- Upgrading to the next-generation Gemma 2 2B yielded a large performance boost (GenEval 0.54 vs 0.51, DPG 79.1 vs 76.6, FID 23.94 vs 27.33), indicating a strong dependence on the base LLM's capabilities.
- Training at Scale (FuseDiT):
- Final Recipe:
- Removed AdaLN-Zero modules.
- Used 1D RoPE (text) + 2D RoPE (image).
- Replaced Gemma 2B with Gemma 2 2B (adjusting DiT accordingly).
- Trained for 800K steps on a mixed dataset of ~26M image-caption pairs (CC12M, SA-1B, JourneyDB with synthetic captions).
- FuseDiT (2B params) achieved GenEval 0.60, DPG 81.6, and FID 7.54, surpassing many established systems despite limited compute and data.
- Further Explorations:
- Architecture Alignment: Explored decoupling LLM and DiT architectures by reducing DiT hidden size or layers while fusing into middle LLM layers.
- Reducing DiT hidden size (e.g., from 2048 to 1792) gracefully degraded alignment but sometimes improved FID (24.27 for 1792 hidden).
- Reducing DiT layers (e.g., from 18 to 14) degraded alignment more quickly. Together, these results suggest that the DiT architecture can be scaled independently of the LLM to some extent.
- Attention Mechanism: Replacing shared self-attention with a deep-fusion variant of cross-attention (using each LLM layer's K, V states instead of K, V projected from the last layer) gave a minor alignment gain (GenEval 0.52 vs 0.51) but increased latency by ~12% (1.86s vs 1.66s), so the authors kept shared self-attention.
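In the architecture-alignment experiments, a DiT with fewer layers than the LLM has to be paired with a subset of LLM layers. The sketch below shows one hypothetical way to fuse into the middle of the LLM stack: it centers the DiT's layers inside the LLM's. The paper's exact mapping is not specified here, and `middle_layer_mapping` is an illustrative name.

```python
def middle_layer_mapping(n_llm_layers: int, n_dit_layers: int) -> dict:
    """Hypothetical mapping {dit_layer: llm_layer} that centres the DiT inside the LLM stack."""
    if not 0 < n_dit_layers <= n_llm_layers:
        raise ValueError("need 0 < n_dit_layers <= n_llm_layers")
    offset = (n_llm_layers - n_dit_layers) // 2
    return {i: offset + i for i in range(n_dit_layers)}


print(middle_layer_mapping(18, 14))  # DiT layers 0..13 read K, V from LLM layers 2..15
```

With equal depths the mapping reduces to the one-to-one layer pairing of the default deep-fusion setup.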
Implementation of Deep Fusion:
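The core idea is a self-attention operation shared between each LLM layer and its corresponding DiT layer. The code below is a simplified, illustrative PyTorch sketch of one fused layer, not the authors' implementation.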
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def create_deep_fusion_mask(text_len: int, image_len: int) -> torch.Tensor:
    """Boolean mask over the concatenated [text; image] sequence (cf. Fig. 3 of the paper).

    True means "may attend", matching F.scaled_dot_product_attention.
    - text -> text: causal
    - text -> image: blocked (text never sees image tokens)
    - image -> text and image -> image: full (bidirectional)
    """
    total = text_len + image_len
    allowed = torch.zeros(total, total, dtype=torch.bool)
    allowed[:text_len, :text_len] = torch.tril(torch.ones(text_len, text_len, dtype=torch.bool))
    allowed[text_len:, :] = True
    return allowed


class StreamBlock(nn.Module):
    """Weights of one pre-norm transformer block for a single stream (LLM text or DiT image).

    Simplified: real Gemma layers use RMSNorm, RoPE, multi-query attention and gated MLPs.
    """

    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.norm1 = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, 3 * dim)
        self.out = nn.Linear(dim, dim)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def qkv_heads(self, x: torch.Tensor):
        b, n, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (b, n, d) -> (b, heads, n, head_dim)
        return [t.reshape(b, n, self.n_heads, d // self.n_heads).transpose(1, 2) for t in (q, k, v)]


class DeepFusionTransformerLayer(nn.Module):
    """One fused layer: a frozen LLM block and a trainable DiT block share one self-attention."""

    def __init__(self, llm_layer: StreamBlock, dit_layer: StreamBlock, use_adaln: bool = False):
        super().__init__()
        self.llm_layer = llm_layer.requires_grad_(False)  # frozen
        self.dit_layer = dit_layer  # trainable
        dim = dit_layer.out.in_features
        # Optional timestep modulation (simplified adaLN: shift/scale only, no gates);
        # the final FuseDiT recipe drops timestep conditioning entirely.
        self.adaln = nn.Linear(dim, 2 * dim) if use_adaln else None

    def forward(self, text_h, image_h, t_emb=None):
        b, n_txt, d = text_h.shape
        n_img = image_h.shape[1]

        x_txt = self.llm_layer.norm1(text_h)
        x_img = self.dit_layer.norm1(image_h)
        if self.adaln is not None and t_emb is not None:
            scale, shift = self.adaln(t_emb).unsqueeze(1).chunk(2, dim=-1)
            x_img = x_img * (1 + scale) + shift

        # Each stream projects its own tokens with its own weights. Text queries never
        # see image keys, so (k_t, v_t) depend only on the prompt and can be cached
        # across all denoising steps.
        q_t, k_t, v_t = self.llm_layer.qkv_heads(x_txt)
        q_i, k_i, v_i = self.dit_layer.qkv_heads(x_img)

        # Shared self-attention over the concatenated [text; image] sequence.
        q = torch.cat([q_t, q_i], dim=2)
        k = torch.cat([k_t, k_i], dim=2)
        v = torch.cat([v_t, v_i], dim=2)
        mask = create_deep_fusion_mask(n_txt, n_img).to(q.device)
        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        attn = attn.transpose(1, 2).reshape(b, n_txt + n_img, d)
        a_txt, a_img = attn.split([n_txt, n_img], dim=1)

        # Output projection, residual and MLP, each with its own stream's weights.
        text_h = text_h + self.llm_layer.out(a_txt)
        text_h = text_h + self.llm_layer.mlp(self.llm_layer.norm2(text_h))
        image_h = image_h + self.dit_layer.out(a_img)
        image_h = image_h + self.dit_layer.mlp(self.dit_layer.norm2(image_h))
        return text_h, image_h  # text_h feeds the next LLM layer, image_h the next DiT layer


if __name__ == "__main__":
    torch.manual_seed(0)
    layer = DeepFusionTransformerLayer(StreamBlock(64, 4), StreamBlock(64, 4))
    text, image = torch.randn(2, 8, 64), torch.randn(2, 16, 64)
    new_text, new_image = layer(text, image)
    print(new_text.shape, new_image.shape)  # torch.Size([2, 8, 64]) torch.Size([2, 16, 64])
```
Attention Mask (Figure 1 clarification):
The mask ensures:
- Text tokens (queries) attending to text tokens (keys/values): Causal mask (can only attend to past and self).
- Text tokens (queries) attending to image tokens (keys/values): Masked out (cannot attend).
- Image tokens (queries) attending to text tokens (keys/values): Full attention (can attend to all text tokens).
- Image tokens (queries) attending to image tokens (keys/values): Full attention (can attend to all image tokens).
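The practical consequence of these four rules can be checked with a tiny plain-Python example. Because text queries can never reach image keys, the text outputs stay the same whatever the image tokens are. The whole text stream, including each layer's text K, V, can therefore be computed once per prompt and reused across all denoising steps. This is a minimal sketch under simplifying assumptions: one head, identity Q/K/V projections, no RoPE. `fusion_mask` and `masked_attention` are hypothetical names.

```python
import math


def fusion_mask(n_text, n_image):
    """allowed[q][k] is True if query q may attend to key k; tokens ordered [text; image]."""
    n = n_text + n_image
    # Text queries: causal over text only (k <= q < n_text). Image queries: everything.
    return [[(k <= q) if q < n_text else True for k in range(n)] for q in range(n)]


def masked_attention(x, allowed):
    """Single-head attention with identity Q/K/V projections (a simplification)."""
    d = len(x[0])
    out = []
    for i, qi in enumerate(x):
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) if allowed[i][j] else -math.inf
                  for j, kj in enumerate(x)]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]  # masked keys get weight exp(-inf) = 0
        z = sum(w)
        out.append([sum(wj * xj[c] for wj, xj in zip(w, x)) / z for c in range(d)])
    return out


text = [[1.0, 0.0], [0.0, 1.0], [0.5, -0.5]]
image_a = [[2.0, 2.0], [3.0, -1.0]]
image_b = [[-4.0, 0.5], [0.1, 0.1]]
mask = fusion_mask(len(text), len(image_a))
out_a = masked_attention(text + image_a, mask)
out_b = masked_attention(text + image_b, mask)
print(out_a[:3] == out_b[:3])  # True: text outputs ignore the image tokens
```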
Training Recipe (FuseDiT):
- Base LLM: Frozen Gemma 2 2B.
- DiT: Matches Gemma 2 2B architecture, trained from scratch.
- Timestep Conditioning: None (removed AdaLN-Zero).
- Positional Encoding: 1D RoPE for text sequence, 2D RoPE for image patch sequence.
- Data: ~26M image-caption pairs (CC12M, SA-1B, JourneyDB with synthetic captions).
- Optimizer: AdamW (β₁ = 0.9, β₂ = 0.999), LR 1×10⁻⁴, WD 1×10⁻⁴.
- Batch Size: 512.
- Precision: BF16.
- Objective: Rectified Flow.
- Duration: 800K steps.
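The recipe above can be captured as a single config object, which also makes the implied data budget explicit: 800K steps × 512 samples = 409.6M samples, or roughly 16 passes over ~26M pairs. The dataset size is approximate, so the epoch count is too. The dataclass and its field names are hypothetical, not taken from the authors' code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FuseDiTRecipe:
    """The FuseDiT training recipe as a config object (field names are hypothetical)."""
    base_llm: str = "Gemma 2 2B (frozen)"
    timestep_conditioning: bool = False  # adaLN-Zero removed
    text_pos_enc: str = "1D RoPE"
    image_pos_enc: str = "2D RoPE"
    optimizer: str = "AdamW"
    betas: tuple = (0.9, 0.999)
    lr: float = 1e-4
    weight_decay: float = 1e-4
    batch_size: int = 512
    precision: str = "bf16"
    objective: str = "rectified_flow"
    train_steps: int = 800_000
    dataset_pairs: int = 26_000_000  # "~26M" image-caption pairs, approximate

    def samples_seen(self) -> int:
        return self.train_steps * self.batch_size

    def approx_epochs(self) -> float:
        return self.samples_seen() / self.dataset_pairs


recipe = FuseDiTRecipe()
print(recipe.samples_seen())             # 409600000
print(round(recipe.approx_epochs(), 2))  # 15.75
```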
This recipe provides a strong baseline for deep fusion models, emphasizing the importance of the base LLM's quality and thoughtful reduction of components like timestep conditioning for efficiency and performance. The paper concludes that deep fusion is a promising direction, and their empirical work helps bridge gaps in understanding its practical application.