Conditional SPAdaIN for Adaptive Registration
- CSAIN is a conditional module that adapts instance normalization spatially to generate diverse deformation fields for image registration.
- It integrates with a Laplacian-pyramid U-Net to enable region-specific regularization without needing retraining for varied hyperparameter settings.
- Quantitative evaluations on brain MRI data show improved Dice scores over a spatially invariant baseline, and Gaussian smoothing of the regularization map reduces deformation foldings relative to sharp-boundary maps.
Conditional Spatially-Adaptive Instance Normalization (CSAIN) is a module for deep neural networks that enables spatially-varying and adaptive regularization in the context of deformable image registration. The core innovation is the conditioning of instance normalization on a spatial hyperparameter map, allowing a single registration model to produce a family of plausible deformation fields governed by local regularization weights, with no need for retraining for each configuration. This approach addresses inherent limitations in previous methods that required training separate models per hyperparameter setting and did not support spatially-dependent regularization (Wang et al., 2023).
1. Mathematical Formulation of CSAIN
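Let $F \in \mathbb{R}^{|\Omega| \times C}$ denote a feature map at a given network layer, where $\Omega$ is the set of spatial positions and $C$ is the number of channels. For spatially adaptive conditioning, a spatial hyperparameter map $\Lambda$ is provided, typically corresponding to region-specific regularization weights.
CSAIN is implemented as follows:
- Instance Normalization: each channel is normalized over its spatial positions,
$$\hat{F}_{x,c} = \frac{F_{x,c} - \mu_c}{\sigma_c}, \qquad \mu_c = \frac{1}{|\Omega|} \sum_{x \in \Omega} F_{x,c}, \qquad \sigma_c = \sqrt{\frac{1}{|\Omega|} \sum_{x \in \Omega} (F_{x,c} - \mu_c)^2 + \epsilon}.$$
- Spatial Conditioning via Learned Scale and Shift:
  - $\Lambda$ is resampled to match the feature resolution, yielding $\Lambda'$.
  - Two shallow convolutional layers generate per-channel, per-location scale and shift $\gamma, \beta \in \mathbb{R}^{|\Omega| \times C}$:
$$\gamma = \mathrm{Conv}_{\gamma}(\Lambda'), \qquad \beta = \mathrm{Conv}_{\beta}(\Lambda').$$
- The CSAIN-modulated feature:
$$\mathrm{CSAIN}(F, \Lambda)_{x,c} = \gamma_{x,c}\, \hat{F}_{x,c} + \beta_{x,c},$$
or, vectorized over all channels,
$$\mathrm{CSAIN}(F, \Lambda) = \gamma \odot \frac{F - \mu}{\sigma} + \beta,$$
where $\odot$ denotes the Hadamard product and $\mu, \sigma$ are broadcast per channel.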
This formulation permits each spatial location and channel to be adaptively scaled and shifted in response to arbitrarily specified local regularization weights.
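The normalize-then-modulate computation described above can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: it works on a 2D feature map, and each conditioning branch is reduced to a single hypothetical 1x1 convolution on the one-channel map (an affine function per channel), whereas the source describes shallow convolutional layers. The names `instance_norm`, `csain`, `w_gamma`, `b_gamma`, `w_beta`, and `b_beta` are illustrative.

```python
import numpy as np

def instance_norm(feat, eps=1e-5):
    """Normalize each channel of feat (C, H, W) over its spatial positions."""
    mu = feat.mean(axis=(1, 2), keepdims=True)
    var = feat.var(axis=(1, 2), keepdims=True)
    return (feat - mu) / np.sqrt(var + eps)

def csain(feat, cond, w_gamma, b_gamma, w_beta, b_beta):
    """Scale and shift normalized features using a conditioning map.

    feat: (C, H, W) features; cond: (H, W) map already at feature resolution.
    w_*, b_*: (C,) parameters of assumed 1x1 convolutions on a one-channel input.
    """
    # Per-channel, per-location scale and shift derived from the map.
    gamma = w_gamma[:, None, None] * cond[None] + b_gamma[:, None, None]
    beta = w_beta[:, None, None] * cond[None] + b_beta[:, None, None]
    return gamma * instance_norm(feat) + beta

rng = np.random.default_rng(0)
feat = rng.normal(size=(4, 8, 8))
cond = np.ones((8, 8))
cond[:, 4:] = 3.0  # a larger regularization weight on the right half
out = csain(feat, cond,
            w_gamma=np.full(4, 0.5), b_gamma=np.ones(4),
            w_beta=np.zeros(4), b_beta=np.zeros(4))
```

With these toy parameters the left half of the normalized features is scaled by 1.5 and the right half by 2.5, so the same features are modulated differently depending on the local value of the map.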
2. Network Integration and Architecture
The CSAIN module is instantiated within a Laplacian-pyramid U-Net backbone ("LapIRN") comprising $L$ resolution levels. Each level incorporates a downsampling encoder, a stack of residual blocks, and a decoder. The residual blocks are replaced by CSAIN blocks, each of which contains two consecutive CSAIN layers (whose convolutions produce the scale $\gamma$ and shift $\beta$) interleaved with LeakyReLU activations, along with a pre-activation skip connection.
Encoding of the conditioning map proceeds as follows:
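- Binary region masks $M_k$ select each of $K$ anatomical regions. A vector of region-specific weights $\lambda = (\lambda_1, \dots, \lambda_K)$ induces a one-channel map $\Lambda = \sum_{k=1}^{K} \lambda_k M_k$.
- To mitigate sharp boundaries, $\Lambda$ is convolved with a Gaussian kernel (with a set standard deviation, in voxels, and window size), resulting in the smoothed map $\tilde{\Lambda}$.
- At each feature resolution, $\tilde{\Lambda}$ is resampled and routed to the corresponding CSAIN block's conditioning layers.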
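The three encoding steps above (compose the map from region weights, smooth it, resample it) can be sketched as follows. This is a hedged illustration: the kernel parameters `sigma=1.0` and `radius=2` are placeholder values rather than the ones used in the source, strided subsampling stands in for whatever resampling the network uses, and all function names are hypothetical. The example is 2D for brevity, although the code is written to accept arrays of any dimensionality.

```python
import numpy as np

def gaussian_kernel1d(sigma, radius):
    """Normalized 1D Gaussian kernel with 2 * radius + 1 taps."""
    x = np.arange(-radius, radius + 1)
    k = np.exp(-0.5 * (x / sigma) ** 2)
    return k / k.sum()

def smooth(vol, sigma, radius):
    """Separable Gaussian smoothing with edge padding along every axis."""
    k = gaussian_kernel1d(sigma, radius)
    out = vol.astype(float)
    for axis in range(out.ndim):
        pad = [(radius, radius) if a == axis else (0, 0) for a in range(out.ndim)]
        padded = np.pad(out, pad, mode="edge")
        out = np.apply_along_axis(
            lambda v: np.convolve(v, k, mode="valid"), axis, padded)
    return out

def conditioning_map(labels, weights, sigma=1.0, radius=2):
    """labels: integer region ids; weights: one regularization weight per id."""
    # Indexing the weight vector by label equals summing weight_k * mask_k.
    raw = np.asarray(weights, dtype=float)[labels]
    return smooth(raw, sigma, radius)

def downsample(vol, factor):
    """Strided subsampling as a simple stand-in for resampling."""
    return vol[(slice(None, None, factor),) * vol.ndim]

labels = np.zeros((8, 8), dtype=int)
labels[:, 4:] = 1                            # two toy regions
lam = conditioning_map(labels, [1.0, 3.0])   # smooth ramp from 1 to 3
lam_half = downsample(lam, 2)                # map for a coarser feature level
```

Far from the region boundary the map keeps the region's own weight, and near the boundary it ramps between the two weights instead of jumping.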
3. Deformable Registration Framework
Conditional SPAdaIN is deployed within an end-to-end deformable registration framework. The system takes as input a fixed image $I_f$, a moving image $I_m$, and a spatial regularization map $\tilde{\Lambda}$. The output is a dense displacement field $u$, yielding a deformation $\phi = \mathrm{Id} + u$.
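The objective at pyramid level $l$ is:
$$\mathcal{L}_l = -\mathrm{NCC}_w\!\left(I_f^{l},\, I_m^{l} \circ \phi^{l}\right) + \sum_{x \in \Omega_l} \tilde{\Lambda}^{l}(x)\, \left\| \nabla u^{l}(x) \right\|_2^2,$$
where $I_f^{l}, I_m^{l}$ are images downsampled to level $l$, $\mathrm{NCC}_w$ is local normalized cross-correlation with window size $w$, and $\nabla$ is the spatial gradient. The spatially-varying regularization is enforced by elementwise multiplication with $\tilde{\Lambda}^{l}$, the regularization map resampled to level $l$. The total loss sums over all $L$ levels:
$$\mathcal{L} = \sum_{l=1}^{L} \mathcal{L}_l.$$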
This structure enables spatially-varying regularization directly within the data flow of the network.
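To make the spatially weighted smoothness term concrete, the sketch below computes it for a 2D displacement field. It is an illustration under assumptions: the similarity term (local normalized cross-correlation) is omitted, the penalty is averaged over positions rather than summed, gradients are taken with `numpy.gradient`, and the function name `weighted_smoothness` is hypothetical.

```python
import numpy as np

def weighted_smoothness(disp, lam):
    """Mean over positions of lam(x) * ||grad u(x)||^2.

    disp: (2, H, W) displacement components; lam: (H, W) regularization map.
    """
    sq = np.zeros(lam.shape)
    for comp in disp:
        gy, gx = np.gradient(comp)   # finite differences along axis 0 and axis 1
        sq += gy ** 2 + gx ** 2
    return float((lam * sq).mean())

X = np.tile(np.arange(6.0), (6, 1))              # column index at each position
disp = np.stack([np.zeros((6, 6)), 0.5 * X])     # displacement with gradient 0.5
uniform = weighted_smoothness(disp, np.ones((6, 6)))
doubled = weighted_smoothness(disp, 2.0 * np.ones((6, 6)))
```

Doubling the map everywhere doubles the penalty for the same displacement, while raising it only inside one region would penalize roughness in that region alone.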
4. Training Regimen and Inference Modality
Training is conducted using the OASIS T1 brain MRI dataset (416 volumes, pre-aligned, skull-stripped), partitioned into 340 training, 20 validation, and 56 test subjects. Registration pairs are established via subject permutation. Anatomical regions ($K = 5$) are delineated as background, cortex, subcortical gray matter, white matter, and cerebrospinal fluid. For each minibatch, region weights $\lambda_k$ are sampled uniformly from a fixed range, composed into $\Lambda$, and Gaussian-smoothed to $\tilde{\Lambda}$ per the protocol above. The model parameters $\theta$ (network weights, including CSAIN kernels) are optimized with Adam at a fixed learning rate, minimizing the overall loss $\mathcal{L}$.
Inference utilizes the fixed network parameters. To obtain a single best deformation, $\lambda$ may be selected manually or via automated search:
$$\lambda^{*} = \arg\max_{\lambda}\; \mathcal{M}(\phi_{\lambda}),$$
with $\phi_{\lambda}$ the registered output parameterized by $\lambda$ and $\mathcal{M}$ a selection criterion such as Dice overlap. Critically, only $\lambda$ is tuned at inference, without retraining, enabling rapid generation of multiple plausible outputs under varying spatially adaptive smoothness.
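Because the network stays frozen, an automated search only needs to propose region weights and score the resulting registration. The sketch below uses plain random search, which is an assumption: the source does not say which search procedure is used. The weight range, the number of trials, and `toy_score` (a stand-in for running the frozen network and scoring its output) are all hypothetical.

```python
import random

def search_weights(score_fn, num_regions, low, high, trials=200, seed=0):
    """Random search over region weights; returns the best weights and score."""
    rng = random.Random(seed)
    best_w, best_s = None, float("-inf")
    for _ in range(trials):
        w = [rng.uniform(low, high) for _ in range(num_regions)]
        s = score_fn(w)
        if s > best_s:
            best_w, best_s = w, s
    return best_w, best_s

# Toy stand-in for "register with these weights, then measure quality".
# The target values are arbitrary and carry no experimental meaning.
target = [3.0, 2.0, 2.5, 2.0, 0.5]

def toy_score(w):
    return -sum((a - b) ** 2 for a, b in zip(w, target))

best_w, best_s = search_weights(toy_score, num_regions=5, low=0.0, high=5.0)
```

In a real setting each call to the scoring function would be one forward pass of the trained model plus a metric evaluation, which is why such a search is cheap compared with retraining per configuration.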
5. Quantitative and Qualitative Evaluation
Empirical evaluation on the OASIS test set (56 subjects, 5 regions) demonstrates:
| Method | Avg. Dice | % folds ($\det J_\phi \le 0$) | Avg … | Std(…) | Region weights $\lambda$ |
|---|---|---|---|---|---|
| CSAIN w/ Gaussian-smoothed $\tilde{\Lambda}$ | 0.764 | 1.04 | … | … | [3.76, 2.42, 2.61, 2.33, 0.67] |
| CSAIN w/o Gaussian ($\Lambda$ only) | 0.759 | 1.22 | … | … | [3.58, 1.83, 2.18, 1.98, 0.56] |
| Baseline (CIR-DM, uniform $\lambda$) | 0.749 | 0.66 | … | … | [1, 1, 1, 1, 1] |
Key findings include:
- CSAIN with spatially-varying, Gaussian-smoothed regularization achieves a roughly 1.5-point improvement in Dice score over the spatially-invariant baseline (0.764 vs. 0.749).
- Adaptive $\lambda$ selection reduces deformation gradient magnitudes (achieving smoother transformations) while maintaining or improving accuracy.
- Gaussian smoothing of $\Lambda$ reduces the percentage of foldings (locations where the Jacobian determinant is non-positive, $\det J_\phi \le 0$) compared to sharp boundary maps (1.04% vs. 1.22%).
Ablation indicates that removing Gaussian smoothing results in a minor Dice decrease (0.764 to 0.759) and increased foldings (1.04% to 1.22%), supporting the benefit of enforcing smooth transitions in the hyperparameter map. Qualitative analysis shows that varying a single $\lambda_k$ affects Dice and deformation regularity locally within the specified anatomical region, with minimal impact on distant areas. This indicates effective spatial adaptation (Wang et al., 2023).
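The two headline metrics in this section, Dice overlap and the percentage of foldings, can be computed as in the following sketch. It is a 2D illustration with hypothetical function names, using the standard definitions: Dice is twice the intersection over the sum of the two region sizes, and a folding is a position where the Jacobian determinant of the deformation (identity plus displacement) is non-positive. Finite differences from `numpy.gradient` are an assumed discretization.

```python
import numpy as np

def dice(seg_a, seg_b, label):
    """Dice overlap of one label between two segmentations."""
    a, b = seg_a == label, seg_b == label
    denom = a.sum() + b.sum()
    return 2.0 * np.logical_and(a, b).sum() / denom if denom else 1.0

def percent_folds(disp):
    """Percent of positions where det(J_phi) <= 0, with phi(x) = x + u(x).

    disp: (2, H, W); disp[0] displaces along axis 0, disp[1] along axis 1.
    """
    du0_d0, du0_d1 = np.gradient(disp[0])
    du1_d0, du1_d1 = np.gradient(disp[1])
    # Jacobian of phi is the identity plus the gradient of the displacement.
    det = (1.0 + du0_d0) * (1.0 + du1_d1) - du0_d1 * du1_d0
    return 100.0 * float((det <= 0).mean())

seg_a = np.array([[1, 1], [0, 0]])
seg_b = np.array([[1, 0], [0, 0]])
overlap = dice(seg_a, seg_b, label=1)            # 2 * 1 / (2 + 1)

X = np.tile(np.arange(5.0), (5, 1))
identity_folds = percent_folds(np.zeros((2, 5, 5)))                 # no folding
mirror_folds = percent_folds(np.stack([np.zeros((5, 5)), -2.0 * X]))  # flips axis 1
```

A displacement that mirrors the image along one axis has a Jacobian determinant of -1 everywhere, so every position counts as folded, whereas the zero displacement produces none.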
6. Significance and Implications
CSAIN provides a mechanism for controlling spatially-variant regularization in deformable image registration using a single network conditioned at inference on user- or optimizer-specified hyperparameter maps. This obviates the need for retraining across hyperparameter sweeps and supports automatic or interactive hyperparameter selection. The approach is validated experimentally through improved Dice scores and enhanced control over deformation regularity, with fine-grained adaptation documented both quantitatively and qualitatively across anatomical regions. The methodology integrates seamlessly with common encoder-decoder networks and generalizes to variable region definitions via mask-based construction of conditioning maps.
The results suggest that CSAIN constitutes an advance in the design of conditional, normalization-based deep learning layers for spatially adaptive regularization in image registration workflows, meriting further investigation and adoption in broader medical imaging scenarios (Wang et al., 2023).