1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
|
class DeepFusionTransformerLayer:
    """One deep-fusion layer pairing an LLM layer with a DiT layer.

    The text stream runs through the LLM layer's self-attention. The image
    stream runs through the DiT layer, whose attention uses image queries
    against the concatenation of text keys/values (projected by the LLM
    layer's ``to_k`` / ``to_v``) and image keys/values (projected by the DiT
    layer's ``to_k`` / ``to_v``).

    NOTE(review): this is a plain Python class, not an ``nn.Module``; nothing
    here freezes ``llm_layer`` or registers parameters. The "frozen" /
    "trainable" roles must be enforced by the caller.
    """

    def __init__(self, llm_layer, dit_layer):
        """Store the two layers.

        Args:
            llm_layer: Object exposing ``self_attn`` (callable, and carrying
                ``to_k`` / ``to_v`` projections). Intended to stay frozen.
            dit_layer: Object exposing ``norm1``, ``norm2``, ``to_q``,
                ``to_k``, ``to_v``, ``attn``, ``drop``, ``mlp`` and,
                optionally, ``adaLN_modulation``. Intended to be trained.
        """
        self.llm_layer = llm_layer
        self.dit_layer = dit_layer

    def forward(
        self,
        text_hidden_states,
        image_hidden_states,
        text_attention_mask,
        image_attention_mask,
        timestep_embedding=None,
    ):
        """Run one fused step over the text and image streams.

        Args:
            text_hidden_states: Text activations; assumed to be
                ``(batch, text_len, dim)`` since keys/values are
                concatenated along dim 1 — TODO confirm.
            image_hidden_states: Image activations; assumed to be
                ``(batch, image_len, dim)`` — TODO confirm.
            text_attention_mask: Forwarded unchanged to the LLM
                self-attention.
            image_attention_mask: Accepted for interface compatibility;
                currently not used by this method.
            timestep_embedding: Optional conditioning fed to
                ``dit_layer.adaLN_modulation`` when that attribute exists.

        Returns:
            Tuple ``(llm_output_text, image_hidden_states)``: the LLM
            self-attention output for the text stream and the updated image
            stream.
        """
        # Text stream (simplified: only the LLM self-attention is applied;
        # the LLM MLP is intentionally left out of this sketch).
        llm_output_text = self.llm_layer.self_attn(
            text_hidden_states, attention_mask=text_attention_mask
        )

        # Sequence lengths are needed to build the fusion mask. Dim 1 is the
        # sequence axis, matching the torch.cat(dim=1) below.
        text_seq_len = text_hidden_states.shape[1]
        img_seq_len = image_hidden_states.shape[1]

        # 1. Pre-norm (+ optional adaLN modulation) on the image stream.
        # Keep the un-normalised input for the residual connection, in the
        # same way the MLP sub-block below does x + mlp(norm2(x)).
        residual = image_hidden_states
        normed_image = self.dit_layer.norm1(image_hidden_states)
        if timestep_embedding is not None and hasattr(self.dit_layer, 'adaLN_modulation'):
            scale, shift = self.dit_layer.adaLN_modulation(timestep_embedding).chunk(2, dim=1)
            # A per-sample (batch, dim) modulation needs a token axis so it
            # broadcasts over every image token instead of mis-aligning with
            # the batch dimension.
            if scale.dim() == normed_image.dim() - 1:
                scale = scale.unsqueeze(1)
                shift = shift.unsqueeze(1)
            normed_image = normed_image * (1 + scale) + shift

        # 2. Shared attention: queries from image tokens only.
        q_img = self.dit_layer.to_q(normed_image)

        # Text keys/values come from the LLM layer's own projections (the
        # paper's "key and value states of the text hidden states"); a
        # compatible projection could be substituted here.
        k_text = self.llm_layer.self_attn.to_k(text_hidden_states)
        v_text = self.llm_layer.self_attn.to_v(text_hidden_states)
        k_img = self.dit_layer.to_k(normed_image)
        v_img = self.dit_layer.to_v(normed_image)

        # Keys/values are laid out as [text, image] along the sequence axis.
        combined_k = torch.cat([k_text, k_img], dim=1)
        combined_v = torch.cat([v_text, v_img], dim=1)

        # The helper returns the full (T+I, T+I) mask for both streams. Only
        # image tokens act as queries here, so keep just their rows, giving a
        # (I, T+I) mask that lines up with q_img against combined_k.
        # NOTE(review): the mask uses True == "cannot attend"; confirm that
        # dit_layer.attn follows the same convention.
        fusion_mask = create_deep_fusion_mask(text_seq_len, img_seq_len)
        image_query_mask = fusion_mask[text_seq_len:].to(q_img.device)

        image_attn_output = self.dit_layer.attn(
            query=q_img,
            key=combined_k,
            value=combined_v,
            attention_mask=image_query_mask,
        )
        image_hidden_states = residual + self.dit_layer.drop(image_attn_output)

        # 3. MLP for the DiT stream (pre-norm residual).
        image_hidden_states = image_hidden_states + self.dit_layer.mlp(
            self.dit_layer.norm2(image_hidden_states)
        )

        # llm_output_text feeds the next LLM layer; image_hidden_states feeds
        # the next DiT layer.
        return llm_output_text, image_hidden_states
def create_deep_fusion_mask(text_len, image_len, device=None):
    """Build the boolean deep-fusion attention mask over [text, image] tokens.

    Rows index queries and columns index keys/values, both ordered as
    ``text_len`` text tokens followed by ``image_len`` image tokens. In the
    returned mask ``True`` means "masked out / cannot attend":

    * text query  -> text key:  causal (future text positions are masked)
    * text query  -> image key: fully masked (text never sees image tokens)
    * image query -> text key:  fully allowed
    * image query -> image key: fully allowed (bidirectional)

    Args:
        text_len: Number of text tokens.
        image_len: Number of image tokens.
        device: Optional device on which to allocate the mask; defaults to
            PyTorch's current default device.

    Returns:
        A ``(text_len + image_len, text_len + image_len)`` ``torch.bool``
        tensor.
    """
    total_len = text_len + image_len

    # Start from "everything allowed"; the image-query rows stay all-False,
    # which is exactly the full image->text and image->image attention.
    mask = torch.zeros(total_len, total_len, dtype=torch.bool, device=device)

    # Text attends to text causally: mask everything above the diagonal.
    causal_allowed = torch.tril(
        torch.ones(text_len, text_len, dtype=torch.bool, device=device)
    )
    mask[:text_len, :text_len] = ~causal_allowed

    # Text never attends to image tokens.
    mask[:text_len, text_len:] = True

    return mask