Multi-Axis Attention in MaxViT-UNet
- Multi-axis attention is a technique that decomposes global attention into localized window and grid operations to reduce computational complexity.
- MaxViT-UNet integrates MBConv and transformer-based modules within a U-Net framework, balancing expressivity and efficiency for dense segmentation.
- Empirical results show improvements in Dice scores and reduced resource usage, validating multi-axis attention for high-resolution medical image segmentation.
Multi-axis attention is an architectural principle that integrates spatially orthogonal attention mechanisms within deep neural networks for dense prediction, enabling efficient capture of both local and global contextual dependencies. Within the U-Net family, the most prominent multi-axis attention realizations are the MaxViT family (notably MaxViT-UNet and QMaxViT-Unet+), as well as hybrid approaches such as AFTer-UNet, all designed to balance expressivity and computational tractability in high-resolution segmentation. These architectures combine convolutional inductive biases, hierarchical representations, and transformer-based non-local modeling, using attention decompositions that keep computational cost and memory usage tractable (linear in the number of spatial tokens for MaxViT-style blocks).
1. Multi-Axis Attention: Definitions and Mathematical Framework
Multi-axis attention decomposes the conventional global self-attention, which has quadratic complexity in the number of tokens, into a sequence or parallel arrangement of restricted attention sub-operators, each acting along a lower-dimensional axis or within a localized spatial region.
In MaxViT-UNet, this decomposition entails two principal components operating over an input feature map $X \in \mathbb{R}^{H \times W \times C}$:
- Window (blocked) attention: The spatial feature map is divided into non-overlapping $P \times P$ windows. Self-attention is performed independently within each window: $X \to \left(\tfrac{HW}{P^2},\, P^2,\, C\right)$, with attention over the $P^2$ axis.
- Grid (dilated global) attention: Features are sampled at regular grid intervals, yielding $G^2$ tokens per “grid group”, and attention is performed within these groups: $X \to \left(\tfrac{HW}{G^2},\, G^2,\, C\right)$, with attention over the $G^2$ axis.
The outputs from window and grid attention branches are typically fused via summation or channel-wise concatenation, restoring the feature tensor’s original shape. Each attention module is followed by a pre-normalized residual pathway and MLP sub-block. This yields $O(HW)$ complexity per layer for fixed window and grid sizes $P$ and $G$.
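The window and grid partitions described above can be sketched with plain array reshapes. The following is a minimal NumPy illustration, not the authors' implementation: the function names are hypothetical, attention is single-head with identity query/key/value projections, and the two stages are applied one after the other with residual connections (the sequential arrangement). It assumes the height and width are divisible by the window size `p` and grid size `g`.

```python
import numpy as np

def block_partition(x, p):
    """(H, W, C) -> (HW/p^2, p*p, C): each group is one contiguous p x p window."""
    h, w, c = x.shape
    x = x.reshape(h // p, p, w // p, p, c)
    return x.transpose(0, 2, 1, 3, 4).reshape(-1, p * p, c)

def block_unpartition(groups, h, w, p):
    """Inverse of block_partition."""
    c = groups.shape[-1]
    x = groups.reshape(h // p, w // p, p, p, c)
    return x.transpose(0, 2, 1, 3, 4).reshape(h, w, c)

def grid_partition(x, g):
    """(H, W, C) -> (HW/g^2, g*g, C): each group holds g x g pixels spaced H/g, W/g apart."""
    h, w, c = x.shape
    x = x.reshape(g, h // g, g, w // g, c)
    return x.transpose(1, 3, 0, 2, 4).reshape(-1, g * g, c)

def grid_unpartition(groups, h, w, g):
    """Inverse of grid_partition."""
    c = groups.shape[-1]
    x = groups.reshape(h // g, w // g, g, g, c)
    return x.transpose(2, 0, 3, 1, 4).reshape(h, w, c)

def self_attention(tokens):
    """Softmax attention inside each group; identity Q/K/V projections (a simplification)."""
    d = tokens.shape[-1]
    scores = tokens @ tokens.transpose(0, 2, 1) / np.sqrt(d)  # (groups, n, n)
    scores -= scores.max(axis=-1, keepdims=True)               # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ tokens

def multi_axis_attention(x, p, g):
    """Window attention followed by grid attention, each with a residual connection."""
    h, w, _ = x.shape
    x = x + block_unpartition(self_attention(block_partition(x, p)), h, w, p)
    x = x + grid_unpartition(self_attention(grid_partition(x, g)), h, w, g)
    return x
```

Because each group contains only `p*p` or `g*g` tokens, the score matrices stay small regardless of image size; the grid groups are what let distant pixels interact within a single layer.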
AFTer-UNet (Axial Fusion Transformer) exemplifies a different multi-axis decomposition for 3D stacked volumes: attention is applied along the “axial” (slice) and “spatial” (within-slice) axes in sequence, with the inter-slice operator realized as 1D attention over a set of neighboring slices at a fixed spatial position. This reduces memory burden relative to full 3D attention; the per-head cost at the bottleneck remains quadratic in the number of within-slice tokens (Yan et al., 2021).
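The sequential intra-slice and inter-slice attention described here can be illustrated with a small NumPy sketch. This is an assumption-laden toy, not the AFTer-UNet code: `attend` and `axial_fusion` are hypothetical names, attention is single-head with identity projections, and the inter-slice step attends over every slice in a short stack rather than over a sampled subset of neighbors.

```python
import numpy as np

def attend(tokens):
    """Softmax self-attention within each group; tokens has shape (groups, n, C)."""
    d = tokens.shape[-1]
    scores = tokens @ tokens.transpose(0, 2, 1) / np.sqrt(d)
    scores -= scores.max(axis=-1, keepdims=True)   # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ tokens

def axial_fusion(feats):
    """feats: (D, h, w, C) bottleneck features for a stack of D slices."""
    d, h, w, c = feats.shape
    tokens = feats.reshape(d, h * w, c)
    # Spatial axis: each slice attends over its own h*w tokens.
    tokens = tokens + attend(tokens)
    # Axial axis: each spatial position attends over the D slices (1D attention).
    across = tokens.transpose(1, 0, 2)             # (h*w, D, C)
    across = across + attend(across)
    return across.transpose(1, 0, 2).reshape(d, h, w, c)
```

No score matrix here ever spans all `D*h*w` tokens at once, which is where the memory saving over full 3D attention comes from.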
2. Integration of MaxViT Blocks in U-Net Architectures
A MaxViT-UNet replaces standard convolutional encoder and decoder blocks with MaxViT blocks, which are composed sequentially as follows (Tu et al., 2022, Khan et al., 2023, Nguyen-Tat et al., 14 Feb 2025):
- MBConv (Mobile inverted bottleneck + SE).
- Block (window-based) attention + MLP.
- Grid (global) attention + MLP.
Each MaxViT block begins with an expansion/projection operator, computes windowed self-attention in non-overlapping patches, injects global context via grid attention, and employs residual pre-activations throughout. The arrangement of MaxViT blocks into U-Net topologies is hierarchical: progressively downsampling in the encoder (via stride-2 MBConv), and upsampling via transposed convolutions in the decoder. Skip connections remain, with decoded features at each stage fused before passing through MaxViT hybrid blocks. Hyperparameters such as block/window size ($P$), grid size ($G$), expansion ratio, and the number of heads control the trade-off between receptive field and cost (Tu et al., 2022, Khan et al., 2023, Nguyen-Tat et al., 14 Feb 2025).
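One practical consequence of non-overlapping windows and grids is that every stage's feature map should divide evenly by the window and grid sizes (unless an implementation pads). The sketch below is a hypothetical helper, not taken from any of the cited codebases; it assumes each encoder stage halves the resolution once and simply lists the resulting feature-map sizes, raising an error when a stage would not partition cleanly.

```python
def encoder_stage_shapes(height, width, num_stages, window, grid):
    """Return the (H, W) seen by each encoder stage, assuming one stride-2
    downsampling per stage and no padding (both are assumptions)."""
    shapes = []
    h, w = height, width
    for stage in range(1, num_stages + 1):
        h, w = h // 2, w // 2
        if h % window or w % window or h % grid or w % grid:
            raise ValueError(
                f"stage {stage}: feature map {h}x{w} is not divisible by "
                f"window={window} / grid={grid}"
            )
        shapes.append((h, w))
    return shapes

# A mirrored decoder would visit the same sizes in reverse order.
print(encoder_stage_shapes(256, 256, 4, 8, 8))
```

Checking this up front avoids reshape errors deep inside a model when the input size or the window and grid sizes are changed.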
In QMaxViT-Unet+, MaxViT blocks entirely replace all encoder/decoder units, building the architecture solely from MBConv–block attention–MLP–grid attention–MLP stages, and supplementing with edge enhancement and query-based Transformer decoders for weakly-supervised (scribble) segmentation (Nguyen-Tat et al., 14 Feb 2025).
3. Computational Complexity and Resource Efficiency
MaxViT’s multi-axis strategy mitigates the quadratic scaling of standard MHSA as follows:
- Each window attention: $O(HW \cdot P^2)$ for $HW$ pixels and local neighborhood $P \times P$.
- Each grid attention: $O(HW \cdot G^2)$.
- With two passes per block, total per-block cost is $O(HW \cdot (P^2 + G^2))$, linear in the spatial map for constant $P$ and $G$ (Tu et al., 2022).
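To make the scaling argument concrete, the following sketch counts query-key pairs (the size of the attention score matrices) for full, window, and grid attention. It is a back-of-the-envelope illustration only: the function name is hypothetical, the channel dimension and head count are ignored, and the sizes are assumed to divide evenly.

```python
def attention_pair_counts(h, w, p, g):
    """Count query-key pairs for an h x w map with window size p and grid size g."""
    n = h * w
    full = n * n                              # every token attends to every token
    window = (n // (p * p)) * (p * p) ** 2    # n / p^2 windows, p^2 x p^2 pairs each
    grid = (n // (g * g)) * (g * g) ** 2      # n / g^2 groups, g^2 x g^2 pairs each
    return {"full": full, "window": window, "grid": grid,
            "multi_axis": window + grid}

small = attention_pair_counts(64, 64, 8, 8)
large = attention_pair_counts(128, 128, 8, 8)
print(small)
print(large["full"] // small["full"], large["multi_axis"] // small["multi_axis"])
```

Doubling both spatial dimensions multiplies the full-attention count by 16 but the window-plus-grid count by only 4, matching the growth of the number of pixels.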
Empirically, MaxViT-UNet achieves lower or comparable parameter counts and FLOPs to classic U-Net and Swin-UNet, e.g., 24.72M params and 7.51 GFLOPs (vs. 29.06M/50.64G for U-Net) on nuclei segmentation benchmarks (Khan et al., 2023). QMaxViT-Unet+ reports 39.10G MACs despite >100M parameters, with smaller operational cost than contemporary CNN-Transformer hybrids (Nguyen-Tat et al., 14 Feb 2025).
AFTer-UNet further exploits multi-axis decomposition to substantially reduce GPU memory compared to 3D self-attention: the model fits within 11GB (RTX2080Ti) for whole-volume multi-organ segmentation tasks, and its parameter count (41.5M) is competitive with TransUNet/CoTr (Yan et al., 2021).
4. Empirical Performance in Medical Segmentation
Multi-axis attention U-Nets exhibit gains in both dense fully-supervised and weakly/partially supervised segmentation:
- QMaxViT-Unet+ achieves DSC/HD95 of 89.1%/1.316mm (ACDC), 88.4%/2.226mm (MS-CMRSeg), 71.4%/4.996mm (SUN-SEG), and 69.4%/50.122mm (BUSI), with consistent improvements of 2-4 points in Dice and lower Hausdorff distances relative to all CNN or hybrid baselines (Nguyen-Tat et al., 14 Feb 2025).
- MaxViT-UNet (CNN/Transformer hybrid): Dice = 0.8378 (MoNuSeg18), 0.8215 (MoNuSAC20), outperforming U-Net and Swin-UNet by 2-5% (Khan et al., 2023).
- AFTer-UNet: Dice = 92.32% (Thorax-85), 81.02% (Synapse), with maximal benefit in elongated or cross-slice organs due to explicit inter-slice fusion (Yan et al., 2021).
Ablation studies show that Dice rises with both more axial neighbors and more AFT layers, confirming the value of deeper multi-slice aggregation; dense sampling is optimal for context capture (Yan et al., 2021).
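For reference, the Dice similarity coefficient (DSC) reported throughout this section is $2|A \cap B| / (|A| + |B|)$ for a predicted mask $A$ and ground-truth mask $B$. Below is a minimal NumPy sketch for binary masks; the function name and the empty-mask convention are assumptions, and the cited papers may aggregate per class or per case differently.

```python
import numpy as np

def dice_score(pred, target):
    """Dice coefficient between two binary masks of the same shape."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    intersection = np.logical_and(pred, target).sum()
    denom = pred.sum() + target.sum()
    if denom == 0:
        # Both masks empty: treated here as perfect agreement (a convention, not a rule).
        return 1.0
    return float(2.0 * intersection / denom)

pred = np.array([[1, 1], [0, 0]])
target = np.array([[1, 0], [0, 0]])
print(dice_score(pred, target))  # 2 * 1 / (2 + 1)
```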
5. Comparative Analysis: MaxViT, AFTer-UNet, and Other Axial Variants
The following table compares selected multi-axis attention U-Nets:
| Architecture | Attention Decomposition | Main Location in UNet | Complexity |
|---|---|---|---|
| MaxViT-UNet | Blocked + Grid (2D local/global) | Encoder/Decoder | $O(HW)$ per block |
| QMaxViT-Unet+ | Blocked + Grid (2D) | Encoder/Decoder | $O(HW)$ per block |
| AFTer-UNet | Intra-slice + Inter-slice (2D+1D) | Bottleneck only | Quadratic in within-slice tokens, per head |
| GASA-UNet | Global axial, 3 axes (1D*3) | Bottleneck only (3D) | 8 |
MaxViT strategies enable attention at all spatial scales in every stage, via hierarchical multi-stage pyramid architectures. AFTer-UNet restricts axial attention to the bottleneck, leading to lower overhead, but applies full, unwindowed attention over all tokens within each slice, plus explicit 1D attention across slices, which is particularly advantageous for cross-slice anatomical consistency. In contrast, MaxViT’s repeated window+grid fusion excels at per-layer local/global context mixing, supporting U-Net-style progressive decoding (Yan et al., 2021, Khan et al., 2023, Nguyen-Tat et al., 14 Feb 2025, Tu et al., 2022).
6. Theoretical and Practical Considerations
- Expressivity: Multi-axis decomposition allows the architecture to approximate full global attention across the image (or volume) using far fewer operations and less memory, crucial for high-resolution or volumetric tasks (Tu et al., 2022, Yan et al., 2021).
- Inductive bias: By simultaneously retaining convolutional (MBConv) blocks and deploying attention only where it is most valuable (e.g., at the bottleneck or in deep features), these U-Nets exploit both spatial locality and global context.
- Scalability: Linear complexity at constant window and grid sizes ($P$, $G$) enables deployment on images up to 512×512 (or higher), in contrast to MHSA models which scale quadratically (Tu et al., 2022).
- Limitations: In AFTer-UNet, spatial (within-slice) attention remains quadratic in the number of within-slice tokens, so the bottleneck must operate at highly downsampled resolutions; MaxViT may be further optimized by combining slice-wise attention in video or 3D cases (Yan et al., 2021). Some studies include no direct ablation of MaxViT blocks versus plain convolution, though documented gains are attributed to richer local/global modeling (Nguyen-Tat et al., 14 Feb 2025).
7. Extensions and Outlook
Hybridization of multi-axis attention continues to evolve:
- GASA-UNet generalizes to global axial self-attention over three axes in full 3D U-Nets, reporting improvements for structures with complex boundaries (Sun et al., 2024).
- Cross-fertilization between axial (AFTer-UNet) and window/grid (MaxViT) approaches is possible: e.g., block-wise spatial attention plus 1D axial “inter-slice” fusion for video/volumes or grid-style global attention within AFTer-UNet bottlenecks (Yan et al., 2021).
Multi-axis attention will likely remain foundational for high-resolution, memory-efficient dense prediction models in medical image analysis and beyond, providing scalable mechanisms for capturing both fine anatomical detail and long-range semantic coherence.