Group Shifted Window Attention (GSWA)
- Group Shifted Window Attention (GSWA) is an architectural modification that partitions the standard multi-head attention into parallel groups to significantly reduce memory usage.
- It employs alternating unshifted and shifted window strategies with learnable bias and residual cascades to maintain spatial connectivity and performance.
- Empirical results demonstrate that with a moderate group count (e.g., G=4), GSWA achieves over 50% memory savings with negligible (<0.1 dB) loss in PSNR on image restoration benchmarks.
Group Shifted Window Attention (GSWA) is an architectural modification of the shifted window attention mechanism used in Swin Transformer, designed to reduce memory usage and accelerate training in high-resolution vision tasks. GSWA operates by decomposing the multi-head self-attention in shifted window frameworks into parallel, cascaded groups over head subsets, enabling large memory savings with negligible loss in performance in dense prediction tasks and image restoration (Liu et al., 2021, Cai et al., 2024).
1. Mathematical Formulation
GSWA generalizes the standard shifted window multi-head self-attention (SW-MSA) and windowed multi-head self-attention (W-MSA) by splitting the attention heads into $G$ disjoint groups and applying all attention computations within each group individually. For input $X \in \mathbb{R}^{H \times W \times C}$, partitioned into non-overlapping windows of size $M \times M$, each group $g \in \{1, \dots, G\}$ computes

$$Q_g = X W_Q^g, \qquad K_g = X W_K^g, \qquad V_g = X W_V^g,$$

where $W_Q^g, W_K^g, W_V^g \in \mathbb{R}^{C \times d}$ and $d = C/G$. For each window or shifted window, group-wise scaled dot-product attention with relative positional bias and masking is

$$O_g = \mathrm{SoftMax}\!\left(\frac{Q_g K_g^{\top}}{\sqrt{d}} + B_g + \mathrm{Mask}\right) V_g.$$

Here, $B_g$ denotes a learnable relative position bias (shifted in SW-MSA), and $\mathrm{Mask}$ applies Swin-standard masking to block discontiguous attention in shifted layouts.
To promote cross-group interaction, GSWA cascades group outputs:

$$O_g \leftarrow O_g + O_{g-1}, \qquad g = 2, \dots, G.$$

The group outputs are concatenated, re-projected via $W_O \in \mathbb{R}^{C \times C}$, and finally aggregated with MLP and residual connections, as in classical Swin blocks.
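As a concrete illustration, the per-group attention and the cascaded residual described above can be sketched in NumPy for a single window. All function names, weight shapes, and the uniform group width are illustrative assumptions here, not taken from the AgileIR implementation:

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def gswa_window(x, w_q, w_k, w_v, w_o, bias, G):
    """Group-wise attention within one window, with cascaded residuals.

    x:          (N, C) tokens of a single M x M window (N = M*M)
    w_q/k/v/o:  (C, C) projection weights; Q/K/V weights are sliced per group
    bias:       (G, N, N) relative position bias, one table slice per group
    """
    N, C = x.shape
    d = C // G                       # per-group channel width
    outs, prev = [], None
    for g in range(G):
        sl = slice(g * d, (g + 1) * d)
        q = x @ w_q[:, sl]           # (N, d)
        k = x @ w_k[:, sl]
        v = x @ w_v[:, sl]
        a = softmax(q @ k.T / np.sqrt(d) + bias[g])
        o = a @ v                    # (N, d)
        if prev is not None:         # cascade: add the previous group's output
            o = o + prev
        prev = o
        outs.append(o)
    # concatenate groups along channels and re-project
    return np.concatenate(outs, axis=-1) @ w_o   # (N, C)
```

Only one group's Q/K/V/attention tensors are live at a time inside the loop, which is the source of the memory savings discussed below.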
2. Architectural Integration and Alternating Shifted Windows
GSWA adopts Swin’s alternating pattern of W-MSA and SW-MSA, partitioning features into regular windows and shifted windows every other block. During W-MSA operations, windows are non-overlapping and unshifted (shift $= 0$). In SW-MSA blocks, features are cyclically shifted by $(\lfloor M/2 \rfloor, \lfloor M/2 \rfloor)$, with masking and bias table indices adjusted accordingly.
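The cyclic shift and its inverse are plain rolls of the feature map; a toy demonstration (the 8×8 map and window size are arbitrary choices for illustration):

```python
import numpy as np

M = 4                                     # window size (illustrative)
x = np.arange(8 * 8).reshape(8, 8)        # toy 8x8 feature map

# SW-MSA: cyclically shift the map by (-M//2, -M//2) before windowing
shifted = np.roll(x, shift=(-M // 2, -M // 2), axis=(0, 1))
# ... windowed attention runs on `shifted`, with masks blocking attention
#     between tokens that wrapped around the border ...

# reverse_shift restores the original spatial layout afterwards
restored = np.roll(shifted, shift=(M // 2, M // 2), axis=(0, 1))
assert (restored == x).all()
```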
The learnable relative position bias table $B$ is learned per head group. When windows are shifted, $B$'s indexing is correspondingly shifted so that relative displacements are preserved, ensuring the biases remain consistent with spatial relationships across blocks. This bias mechanism supports cross-window communication and enables the hierarchical feature pyramid in Swin-based models (Liu et al., 2021, Cai et al., 2024).
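The bias table has $(2M-1)^2$ entries and is indexed by the pairwise relative displacement between tokens. The standard Swin construction of that index can be sketched in NumPy (a reimplementation of the well-known recipe, not code from either paper):

```python
import numpy as np

def relative_position_index(M):
    """Index into the (2M-1)^2-entry bias table for an M x M window."""
    coords = np.stack(np.meshgrid(np.arange(M), np.arange(M), indexing="ij"))
    flat = coords.reshape(2, -1)                  # (2, M*M)
    rel = flat[:, :, None] - flat[:, None, :]     # (2, M*M, M*M) displacements
    rel = rel.transpose(1, 2, 0)                  # (M*M, M*M, 2)
    rel[:, :, 0] += M - 1                         # shift rows to start at 0
    rel[:, :, 1] += M - 1                         # shift cols to start at 0
    rel[:, :, 0] *= 2 * M - 1                     # row-major flattening
    return rel.sum(-1)                            # (M*M, M*M) table indices
```

Every token pair with the same relative displacement shares one table entry, which is what keeps the bias consistent when windows are shifted.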
A layer pseudocode skeleton for one GSWA block is as follows:
```
Input: X ∈ ℝ^{H×W×C}
X_s = cyclic_shift(X) if SW-MSA else X    # shift by (⌊M/2⌋, ⌊M/2⌋)
for g in 1..G:
    Q_g = X_s @ W_Q^g
    K_g = X_s @ W_K^g
    V_g = X_s @ W_V^g
    B_g = shifted_B if SW-MSA else B
    A_g = softmax(Q_g @ K_g^T / sqrt(d) + B_g + Mask_g)
    O_g = A_g @ V_g
    if g > 1: O_g += O_{g-1}              # cascaded cross-group residual
O_cat = concat(O_1, ..., O_G, dim=heads)
Y = O_cat @ W_O
if SW-MSA: Y = reverse_shift(Y)
X_out = X + Y                              # residual connection
X_out = X_out + MLP(LN(X_out))
```
3. Complexity and Memory Analysis
Time complexity for window-based or group-shifted attention remains $4HWC^2 + 2M^2HWC$, as in Swin: the $HW/M^2$ windows each contain $M^2$ tokens, so cost is linear in $HW$ for fixed $M$. Since each group’s attention attends within all windows in parallel, total computational cost, summed over groups, does not exceed that of ordinary (S)W-MSA. The primary savings are instead in memory: during backpropagation, GSWA materializes only $1/G$ of the activations (Q, K, V, A) per group before concatenation, reducing peak intermediate storage along the head dimension by a factor close to $G$.
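The roughly $G$-fold shrinkage of live attention activations can be illustrated with a simple byte-count model. This is a deliberate simplification that ignores weights, output buffers, and framework overhead, and all parameter names are illustrative:

```python
def attention_activation_bytes(H, W, M, heads, head_dim, G, bytes_per=2):
    """Rough size of the per-group attention activations (Q, K, V, A)
    that must be held live at once; bytes_per=2 assumes fp16.

    With grouping, only heads/G heads' worth of Q/K/V/A is materialized
    before the group's output is written, so the live set shrinks ~G-fold.
    """
    n_win = (H // M) * (W // M)          # number of non-overlapping windows
    tokens = M * M                       # tokens per window
    qkv = 3 * n_win * tokens * (heads // G) * head_dim
    attn = n_win * (heads // G) * tokens * tokens
    return (qkv + attn) * bytes_per
```

Because both the Q/K/V and attention-map terms scale linearly in `heads // G`, halving the live head count halves this estimate, matching the $1/G$ claim above.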
Empirical measurements on NVIDIA A100-80G GPUs, with a fixed window size and batch size 256 on DIV2K, show:
- SwinIR-light with SW-MSA: 67.5 GB peak memory
- AgileIR (GSWA) with $G=4$: 30.2 GB

This corresponds to a roughly 55% reduction, directly enabling significantly larger batch sizes or deeper models (Cai et al., 2024).
4. Empirical Performance in Image Restoration
In image super-resolution on Set5, using $G=4$, GSWA achieves near-parity with Swin-based baselines at markedly better memory efficiency:
| Method | Set5 ×2 (PSNR Y, dB) | Set5 ×4 (PSNR Y, dB) | Peak GPU Mem (GB) |
|---|---|---|---|
| SwinIR-small | 38.14 | 32.44 | 50 / 67 |
| QuantSR‐T (8-bit) | 38.10 | 32.18 | N/A |
| AgileIR+ (GSWA) | 38.05 | 32.20 | 30 (G=4) |
An ablation over the group count $G$ shows:
- $G=1$ (vanilla): 50 GB, 38.14 dB
- $G=2$: 38 GB, 38.10 dB
- $G=4$: 30 GB, 38.05 dB
- $G=8$: 28 GB, 37.95 dB
This demonstrates that a moderate group count ($G=4$) produces roughly 50% memory savings with a performance drop of less than 0.1 dB in PSNR, illustrating the practical value of groupwise attention partitioning (Cai et al., 2024).
5. Design Choices and Practical Guidelines
- Group count $G$: $G=1$ reduces to classical Swin attention. $G=4$ yields roughly 2× memory reduction with only a minor PSNR penalty. Larger $G$ continues to cut memory demand, but with increasingly visible accuracy penalties.
- Window size $M$: the reduced per-group head dimension enables a moderately increased $M$ (e.g., 12 or 16), marginally enlarging the effective receptive field of each attention window without increasing total memory.
- Batch size scaling: with GSWA ($G=4$), a typical 80 GB GPU handles batch sizes up to 256, whereas vanilla SwinIR runs out of memory at batch sizes above 128.
- Training throughput: while per-step runtime is comparable to baseline SwinIR, the larger feasible batches under GSWA increase overall throughput (in images per second) by 10–15% due to improved GPU utilization, on both single- and multi-GPU setups.
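These guidelines can be combined into a back-of-envelope batch-size estimate. This is a toy model, not a profiler: it assumes attention activations dominate per-sample memory and shrink roughly $1/G$, and every number in the usage below is hypothetical:

```python
def feasible_batch(gpu_gb, static_gb, act_gb_per_sample_g1, G):
    """Back-of-envelope batch-size estimate under grouping.

    gpu_gb:               total device memory
    static_gb:            weights / optimizer state (assumed fixed in G)
    act_gb_per_sample_g1: per-sample activation memory at G=1
    Assumes activations shrink ~1/G with GSWA; all inputs are estimates.
    """
    per_sample = act_gb_per_sample_g1 / G
    return int((gpu_gb - static_gb) / per_sample)
```

For example, with a hypothetical 80 GB device, 5 GB of static state, and 0.5 GB/sample of $G=1$ activations, moving from $G=1$ to $G=4$ would quadruple the feasible batch under this model.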
GSWA is thus suited for resource-constrained scenarios, especially where large batch sizes, high-resolution inputs, or deep stacking of shifted-window attention is required (Cai et al., 2024).
6. Relationship to Swin Transformer and Hierarchical Feature Learning
As in Swin Transformer, alternating regular and shifted window blocks in GSWA “stitches” together interior and boundary tokens across layers, supporting global information flow within a two-layer scope without resorting to global attention. The patch-merging strategy is also retained: after fixed numbers of (S)W-MSA (now GSWA) layers, patch merging reduces resolution and doubles feature channels, constructing a four-level feature pyramid accessible to downstream heads (detection, segmentation, classification, restoration). GSWA thus inherits the hierarchical and cross-window connectivity properties central to Swin (Liu et al., 2021, Cai et al., 2024).
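Patch merging itself is a space-to-depth gather followed by a linear projection; a NumPy sketch (Swin additionally applies LayerNorm before the projection, omitted here for brevity, and the function name is illustrative):

```python
import numpy as np

def patch_merge(x, w):
    """Swin-style patch merging: concatenate each 2x2 neighborhood into
    4C channels, then project to 2C, halving spatial resolution.

    x: (H, W, C) feature map with even H, W
    w: (4C, 2C) projection weights
    """
    H, W, C = x.shape
    x = x.reshape(H // 2, 2, W // 2, 2, C).transpose(0, 2, 1, 3, 4)
    x = x.reshape(H // 2, W // 2, 4 * C)   # gather 2x2 neighbors per token
    return x @ w                            # (H/2, W/2, 2C)
```

Applied after each stage, this halves resolution and doubles channels, producing the four-level pyramid described above.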
7. Applications and Limitations
GSWA is primarily demonstrated in image restoration, especially super-resolution, with compelling results in AgileIR on Set5 and DIV2K benchmarks. It is compatible with other vision tasks structured for Swin Transformer backbones. A plausible implication is that similar groupwise decomposition could benefit other dense or high-resolution domains where standard multi-head attention is memory-bottlenecked. However, increasing $G$ too aggressively does result in measurable PSNR drops, limiting its use in applications that strictly require maximum fidelity (Cai et al., 2024).
References:
- (Liu et al., 2021): "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows"
- (Cai et al., 2024): "AgileIR: Memory-Efficient Group Shifted Windows Attention for Agile Image Restoration"