1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
|
import torch
import torch.nn as nn
import torch.nn.functional as F
def generate_geometry_prior(depth_map, patch_size, H, W):
    """Build a pairwise geometry prior G from a depth map.

    For every pair of patches (i, j), G combines the absolute difference of
    their average depths with their Manhattan distance on the patch grid:
    G[b, i, j] = w1 * |z_i - z_j| + w2 * (|row_i - row_j| + |col_i - col_j|).

    Args:
        depth_map: Depth map tensor of shape (B, 1, H_img, W_img), where
            H_img == H * patch_size and W_img == W * patch_size.
        patch_size: Side length of each square patch (e.g. 16).
        H, W: Number of patches along the height and width axes.

    Returns:
        Tensor of shape (B, H*W, H*W) holding the fused geometry prior.
    """
    # Fix: B was previously referenced but never defined (NameError).
    B = depth_map.shape[0]

    # 1. Per-patch depth representation: average depth over each patch window.
    patch_depths = F.avg_pool2d(depth_map, kernel_size=patch_size, stride=patch_size)
    patch_depths = patch_depths.reshape(B, H * W)  # (B, HW)

    # 2. Depth distance matrix D: pairwise absolute depth differences.
    z_diff = patch_depths.unsqueeze(2) - patch_depths.unsqueeze(1)  # (B, HW, HW)
    D = torch.abs(z_diff)

    # 3. Spatial distance matrix S: pairwise Manhattan distance on the grid.
    coords_h = torch.arange(H, device=depth_map.device)
    coords_w = torch.arange(W, device=depth_map.device)
    coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing='ij'), dim=-1)
    coords_flat = coords.view(H * W, 2)  # (HW, 2), row-major patch order
    s_diff = coords_flat.unsqueeze(1) - coords_flat.unsqueeze(0)  # (HW, HW, 2)
    # Cast the integer distances to the depth dtype before fusing with D.
    S = s_diff.abs().sum(dim=-1).to(depth_map.dtype)  # (HW, HW)
    S = S.unsqueeze(0).expand(B, -1, -1)  # (B, HW, HW)

    # 4. Fuse D and S. Fixed example weights here; in practice these would be
    # learnable parameters owned by the model.
    w1 = 0.5
    w2 = 0.5
    return w1 * D + w2 * S  # (B, HW, HW)
class GeometrySelfAttention(nn.Module):
    """Multi-head self-attention modulated by a geometry-prior decay.

    Each head h scales its post-softmax attention map element-wise by
    beta_h ** G, where beta_h is a per-head decay rate sampled linearly in
    [decay_min, decay_max] and G is a precomputed pairwise geometry prior
    (e.g. from ``generate_geometry_prior``).
    """

    def __init__(self, dim, num_heads, decay_min=0.75, decay_max=1.0):
        """
        Args:
            dim: Embedding dimension (must be divisible by num_heads).
            num_heads: Number of attention heads.
            decay_min, decay_max: Range from which per-head decay rates
                are linearly sampled.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        # Fix: register as a buffer so the rates follow .to()/.cuda(), appear
        # in state_dict, and need no per-forward device reassignment (the
        # previous in-forward attribute mutation breaks under DataParallel).
        self.register_buffer(
            "decay_rates", torch.linspace(decay_min, decay_max, num_heads)
        )

    def forward(self, x, geometry_prior_G):
        """Apply geometry-decayed self-attention.

        Args:
            x: Input features of shape (B, HW, C).
            geometry_prior_G: Precomputed geometry prior of shape (B, HW, HW).

        Returns:
            Tensor of shape (B, HW, C).
        """
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]  # each (B, num_heads, N, head_dim)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, num_heads, N, N)
        attn = attn.softmax(dim=-1)

        # Decay matrix beta^G, broadcast per head: (1, H, 1, 1) ** (B, 1, N, N).
        decay_rates_b = self.decay_rates.view(1, -1, 1, 1)
        decay_matrix = decay_rates_b ** geometry_prior_G.unsqueeze(1)
        # NOTE: rows no longer sum to 1 after this modulation; the decay is
        # deliberately applied after softmax, matching the original design.
        attn = attn * decay_matrix

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)
|