Modality-Agnostic Deformable Attention
- Modality-Agnostic Deformable Attention is a parameter-free, locality-constrained method that retrieves continuous pixel-level displacements for robust image registration.
- It utilizes multi-resolution CNN feature extraction with a coarse-to-fine strategy to optimize both intra- and inter-modal alignments efficiently.
- Experimental evaluations show superior performance in Dice scores and target registration error compared to conventional methods on diverse imaging datasets.
Modality-Agnostic Deformable Attention is a parameter-free attention mechanism for pixel-level correspondence retrieval in deformable image registration, independent of the input modality. This technique, embodied in the Vector Field Attention (VFA) framework, utilizes multi-resolution convolutional neural networks for feature extraction and a locality-constrained attention module that directly produces continuous displacement fields via attention-weighted aggregation, without requiring learnable parameters in the matching stage. VFA is end-to-end differentiable and compatible with both intra- and inter-modality registration problems, excelling in accuracy and computational efficiency relative to baseline methods.
1. Theoretical Foundation and Framework Architecture
The VFA framework adopts a three-stage decomposition for deformable image registration:
- Feature Extraction: For input fixed ($I_f$) and moving ($I_m$) images defined on a spatial domain $\Omega \subset \mathbb{R}^d$, two parallel multi-resolution U-Net-style CNNs extract feature maps at scales $l = 1, \dots, L$ ($l = 1$ for finest, $l = L$ for coarsest), yielding $F_f^l$ and $F_m^l$ of matching dimensions. U-Net weights are shared for intra-modal registration and independent for inter-modal registration.
- Feature Matching (Attention): For each discrete voxel $x$ at scale $l$, a local attention window $\{x + \delta : \delta \in \{-1, 0, 1\}^d\}$ is established in the (warped) moving feature map.
- Location Retrieval (Vector-Field Assembly): Displacement vectors are retrieved as attention-weighted summations of fixed offsets, forming a dense continuous vector field.
A coarse-to-fine strategy is used: the displacement estimated at a coarser scale is upsampled and the moving image feature maps are warped accordingly before processing each finer scale.
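For illustration, here is a minimal PyTorch-style sketch of the warping and upsampling steps used in this coarse-to-fine loop, assuming 2D channel-first tensors and displacements in voxel units; the helper names `warp` and `upsample_disp` are illustrative, not the reference implementation:

```python
import torch
import torch.nn.functional as F

def warp(feat, disp):
    """Resample feature map `feat` (N, C, H, W) at locations displaced by
    `disp` (N, 2, H, W), with displacements in voxel units, channels as (dx, dy)."""
    n, _, h, w = feat.shape
    ys, xs = torch.meshgrid(
        torch.arange(h, device=feat.device, dtype=feat.dtype),
        torch.arange(w, device=feat.device, dtype=feat.dtype),
        indexing="ij",
    )
    grid = torch.stack((xs, ys), dim=-1).unsqueeze(0)       # (1, H, W, 2), (x, y) order
    grid = grid + disp.permute(0, 2, 3, 1)                  # add displacements
    # normalize sampling locations to [-1, 1] as expected by grid_sample
    gx = 2.0 * grid[..., 0] / (w - 1) - 1.0
    gy = 2.0 * grid[..., 1] / (h - 1) - 1.0
    return F.grid_sample(feat, torch.stack((gx, gy), dim=-1), align_corners=True)

def upsample_disp(disp):
    """Move a coarse displacement field to the next finer scale:
    double the spatial size and double the displacement magnitudes."""
    return 2.0 * F.interpolate(disp, scale_factor=2, mode="bilinear", align_corners=True)
```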
2. Attention Module and Mechanism
The modality-agnostic attention module is parameter-free and consists of the following steps:
- Query and Key Construction:
At voxel $x$ in $F_f^l$, the query vector is $q = F_f^l(x) \in \mathbb{R}^{1 \times C}$; keys $k_\delta = F_m^l(x + \delta)$ are collected from the (warped) moving feature map within the window $\delta \in \{-1, 0, 1\}^d$.
- Value Definition:
The value matrix $V$ holds the fixed displacement vectors $v_\delta = -\delta$, indexed over all $3^d$ neighbors.
- Scaled Dot-Product Attention:
Similarity scores are computed as $s_\delta = \frac{q \cdot k_\delta}{\sqrt{C}}$ and normalized with a softmax over the window: $A_\delta = \operatorname{softmax}_\delta(s_\delta)$.
- Displacement Calculation:
The attention-weighted displacement at position $x$ is $u^l(x) = \sum_{\delta} A_\delta \cdot (-\delta)$, yielding continuous, sub-voxel estimates (a minimal code sketch is given below).
Key properties:
- No learnable parameters in the attention itself; only the feature extractor (CNN weights) and the displacement scaling factor $\beta$ are learned.
- Cosine similarity may replace dot product as similarity kernel, yielding visually more coherent cross-modal correspondence, though with higher GPU memory consumption and no significant change in Dice performance.
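As referenced above, the following is a minimal sketch of this parameter-free matching step, assuming 2D feature maps, channel-first tensors, and (for brevity) circular boundary handling via `torch.roll`; the function name and tensor layout are illustrative rather than the reference implementation:

```python
import itertools

import torch
import torch.nn.functional as F

def local_attention_displacement(feat_fixed, feat_moving, beta=1.0):
    """Parameter-free local attention: match each fixed-image feature against
    its 3x3 neighborhood in the moving feature map and return the
    attention-weighted displacement field (channels ordered as (dy, dx))."""
    c = feat_fixed.shape[1]
    offsets = list(itertools.product((-1, 0, 1), repeat=2))   # 9 offsets (dy, dx)
    # shifted[:, i] holds feat_moving sampled at x + delta_i (circular borders)
    shifted = torch.stack(
        [torch.roll(feat_moving, shifts=(-dy, -dx), dims=(2, 3)) for dy, dx in offsets],
        dim=1,
    )                                                          # (N, 9, C, H, W)
    # scaled dot-product scores and softmax over the 9 offsets
    scores = (feat_fixed.unsqueeze(1) * shifted).sum(dim=2) / c ** 0.5   # (N, 9, H, W)
    attn = F.softmax(scores, dim=1)
    # values are the fixed negative offsets; their weighted sum is the displacement
    values = -torch.tensor(offsets, dtype=feat_fixed.dtype, device=feat_fixed.device)
    disp = torch.einsum("nohw,oc->nchw", attn, values)         # (N, 2, H, W)
    return beta * disp
```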
3. Continuous Deformation and Regularization
The resultant displacement field is composed across scales and rendered as a continuous deformation:
- Multi-scale Composition:
At each scale $l$, the absolute map is $\phi^l(x) = x + \beta\, u^l(x)$, which is composed with the upsampled coarser map, $\phi^l \leftarrow \operatorname{compose}(\phi^l, \operatorname{up}(\phi^{l+1}))$, for hierarchical refinement.
- Smoothness Regularization:
A diffusion penalty encourages spatial smoothness, $\mathcal{L}_{reg}(u) = \sum_{x \in \Omega} \|\nabla u(x)\|_2^2$ (see the sketch after this list).
- Warped Image Rendering:
The final deformation governs image warping via a differentiable grid sampler: $I_w = I_m \circ \phi^1$.
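A common finite-difference form of the diffusion penalty, sketched in 2D under the same channel-first conventions as above (the exact discretization used by VFA may differ):

```python
import torch

def diffusion_regularizer(disp):
    """Diffusion (first-order) smoothness penalty on a displacement field
    `disp` of shape (N, 2, H, W): mean of squared forward finite differences."""
    dy = disp[:, :, 1:, :] - disp[:, :, :-1, :]
    dx = disp[:, :, :, 1:] - disp[:, :, :, :-1]
    return (dy ** 2).mean() + (dx ** 2).mean()
```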
4. Training Methodology and Modal-Agnostic Loss Functions
VFA supports unsupervised, semi-supervised, and weakly supervised training regimes:
- Unsupervised Loss Function:
$\mathcal{L} = \mathcal{L}_{sim}(I_f, I_m \circ \phi) + \lambda\, \mathcal{L}_{reg}(u)$
- $\mathcal{L}_{sim}$: normalized cross-correlation (NCC) for intra-modal registration; mutual information (MI) for inter-modal; mean-squared error (MSE) for lung CT.
- (Weakly) Supervised Loss Terms:
Provided anatomical labels ($S_f$, $S_m$) or landmarks ($p_f$, $p_m$), the losses additionally include a Dice loss on the warped labels and a mean Euclidean distance between corresponding landmarks after warping.
The total loss integrates these terms with appropriate weights (a sketch of the combination follows below).
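A hedged sketch of how these terms can be combined, assuming 2D channel-first tensors, one-hot label maps, and a precomputed regularization value (e.g., the diffusion penalty sketched earlier); the helper names and the weights `lam` and `gamma` are illustrative, and the MI term for the inter-modal case is omitted:

```python
import torch

def ncc_loss(fixed, warped, eps=1e-8):
    """Global normalized cross-correlation similarity term (intra-modal
    setting); returns 1 - NCC so that lower is better."""
    f = fixed - fixed.mean()
    w = warped - warped.mean()
    return 1.0 - (f * w).sum() / (f.norm() * w.norm() + eps)

def soft_dice_loss(labels_fixed, labels_warped, eps=1e-6):
    """Soft Dice loss between one-hot fixed labels and warped moving labels,
    both of shape (N, K, H, W) -- the weakly supervised term."""
    inter = (labels_fixed * labels_warped).sum(dim=(2, 3))
    denom = labels_fixed.sum(dim=(2, 3)) + labels_warped.sum(dim=(2, 3))
    return 1.0 - (2.0 * inter / (denom + eps)).mean()

def total_loss(fixed, warped, reg_term, labels_f=None, labels_w=None,
               lam=1.0, gamma=1.0):
    """L = Sim(I_f, I_w) + lambda * Reg(u) [+ gamma * Dice term]."""
    loss = ncc_loss(fixed, warped) + lam * reg_term
    if labels_f is not None and labels_w is not None:
        loss = loss + gamma * soft_dice_loss(labels_f, labels_w)
    return loss
```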
5. Experimental Results and Performance Comparison
Empirical evaluations demonstrate VFA's superior performance across datasets and modalities:
| Dataset | Loss / Metrics | VFA Performance | Baseline Comparison |
|---|---|---|---|
| IXI T1w atlas→subject | NCC+diffusion, λ=1 | 0.806±0.012 | Im2grid 0.792±0.012; TransMorph 0.774±0.029; VoxelMorph 0.726±0.048 |
| T2w→T1w MR (inter-modal) | MI, λ=0.2 | 0.725±0.022 | DMR 0.671±0.038; TransMorph 0.660±0.044; Im2grid 0.668±0.025 |
| Learn2Reg 2021 (OASIS, weakly supervised) | Dice, HD95, SDLogJ | DSC 0.834, HD95 1.66mm, SDLogJ 0.234 | Best DSC among all entrants |
| Learn2Reg 2022 lung CT (semi-supervised) | TRE, TRE30 | TRE 1.705mm, TRE30 2.311mm | Among top 3, best TRE30 overall |
Non-diffeomorphic voxels remain below 0.1% in intra-modal tasks.
6. Algorithmic Ablations and Modality-Agnostic Characteristics
VFA's design decouples feature extraction from spatial matching, yielding generality across imaging contrasts:
- Intra-modal registration: Shared CNN weights; NCC loss
- Inter-modal registration: Independent CNNs; MI loss
Ablations reveal:
- Replacing dot-product similarity with cosine similarity enhances visual coherence in feature matching across modalities, with minimal effect on Dice coefficient.
- VFA retains accuracy with a half-width U-Net, while baseline architectures do not match VFA even with increased capacity.
This suggests that VFA's parameter-free attention integrates robustly with different feature extractors and generalizes across domains without modality-specific tuning.
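The similarity-kernel ablation amounts to swapping the scoring function in the matching step; a minimal sketch, with the helper name and tensor shapes assumed for illustration:

```python
import torch
import torch.nn.functional as F

def similarity_scores(q, k, kind="dot"):
    """Per-offset matching score between query features q and key features k,
    both (N, C, H, W): scaled dot product (default) or the cosine-similarity
    variant explored in the ablation."""
    if kind == "cosine":
        return F.cosine_similarity(q, k, dim=1)          # (N, H, W), values in [-1, 1]
    c = q.shape[1]
    return (q * k).sum(dim=1) / c ** 0.5                 # scaled dot product
```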
7. Implementation Protocol
End-to-end differentiable computation enables direct integration with modern deep learning workflows. The following pseudocode summarizes the multi-scale inference and training loop for VFA:
```
Input:  I_f, I_m
Params: CNN_f, CNN_m (U-Nets), β, λ, other loss weights

for l = L…1 do
    # 1) extract or warp features
    F_f^l ← CNN_f features at level l from I_f
    if l == L:
        M_warp^l ← CNN_m features at level l from I_m
    else:
        M_warp^l ← warp(CNN_m^l(I_m), φ^{l+1})

    # 2) option: conv to adjust channels
    F_f^l, M_warp^l ← Conv(F_f^l), Conv(M_warp^l)

    # 3) attention matching on every x
    for each voxel x in F_f^l:
        Q = F_f^l(x)                                   # 1×C
        collect K_δ = M_warp^l(x+δ) for δ ∈ {-1,0,1}^d
        compute scores s_δ = (Q·K_δ)/sqrt(C)
        A_δ = softmax_δ(s_δ)
        u^l(x) = Σ_δ A_δ·(−δ)
    end

    φ^l(x) = x + β·u^l(x)
    if l < L:
        φ^l ← compose(φ^l, upsample(φ^{l+1}))
end

I_w = grid_sampler(I_m, φ^1)
Compute total loss L = Sim(I_f, I_w) + λ·Reg(u^1) [+ supervised losses…]
Backprop and update CNN_f, CNN_m, β
```
VFA's locality-constrained, modality-agnostic deformable attention paradigm extends to both supervised and unsupervised workflows and can be plugged into arbitrary feature extractors for registration tasks in medical imaging and beyond.