MTMed3D: Unified 3D Medical Imaging
- MTMed3D is a multi-task framework designed to perform object detection, semantic segmentation, and classification on volumetric MRIs.
- It integrates a shared 3D Swin-Transformer encoder with specialized CNN decoders to balance performance and efficiency across tasks.
- MTMed3D achieves state-of-the-art detection and competitive segmentation/classification results while significantly lowering computational costs and inference time.
MTMed3D is an end-to-end multi-task Transformer-based framework developed for 3D medical imaging, designed to jointly perform object detection, semantic segmentation, and classification from volumetric data. Addressing the inefficiencies of single-task models, MTMed3D introduces a unified architecture where all three core tasks share a 3D Swin-Transformer encoder, with subsequent specialized CNN decoders. The model is evaluated on multi-modal brain tumor MRIs (BraTS 2018/2019), delivering state-of-the-art detection performance while maintaining competitive segmentation and classification, with substantial reductions in computational cost, parameter count, and inference latency compared to single-task models (Li et al., 15 Nov 2025).
1. Model Architecture
1.1. Shared 3D Swin-Transformer Encoder
MTMed3D takes as input a multi-modal MRI volume $X \in \mathbb{R}^{4 \times H \times W \times D}$ (the four modalities stacked as channels), which is center-cropped and intensity-normalized. The volume is divided into non-overlapping 3D patches, which are linearly embedded into a token sequence
$$z_0 \in \mathbb{R}^{N \times C},$$
where $N$ is the number of patches and $C$ the feature dimension. Four hierarchical stages of Swin-Transformer blocks and patch-merging layers generate a set of multi-scale feature maps $\{F_1, F_2, F_3, F_4\}$.
Each Swin-Transformer block alternates Windowed Multi-Head Self-Attention (W-MSA) with Shifted-Window Multi-Head Self-Attention (SW-MSA), ensuring localized and global context aggregation:
$$\hat{z}^{l} = \text{W-MSA}\bigl(\text{LN}(z^{l-1})\bigr) + z^{l-1}, \qquad z^{l} = \text{MLP}\bigl(\text{LN}(\hat{z}^{l})\bigr) + \hat{z}^{l},$$
$$\hat{z}^{l+1} = \text{SW-MSA}\bigl(\text{LN}(z^{l})\bigr) + z^{l}, \qquad z^{l+1} = \text{MLP}\bigl(\text{LN}(\hat{z}^{l+1})\bigr) + \hat{z}^{l+1}.$$
This design enables hard sharing of contextual, multi-scale features—eliminating redundancy across tasks.
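The patch partition and linear embedding step can be sketched in PyTorch as a strided 3D convolution. This is a minimal illustration, not the paper's code: the module name `PatchEmbed3D`, the patch size, embedding dimension, and the $128^3$ example crop are all assumptions.

```python
import torch
import torch.nn as nn

class PatchEmbed3D(nn.Module):
    """Split a 4-channel MRI volume into non-overlapping 3D patches and
    linearly embed each patch (illustrative sketch; names are assumptions)."""

    def __init__(self, in_channels=4, patch_size=2, embed_dim=48):
        super().__init__()
        # A strided Conv3d is equivalent to patch partition + linear projection.
        self.proj = nn.Conv3d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):                      # x: (B, 4, H, W, D)
        x = self.proj(x)                       # (B, C, H/p, W/p, D/p)
        B, C, H, W, D = x.shape
        tokens = x.flatten(2).transpose(1, 2)  # (B, N, C) with N = H*W*D / p^3
        return self.norm(tokens), (H, W, D)

# Example: one 4-modality volume (crop size here is only an assumption)
vol = torch.randn(1, 4, 128, 128, 128)
tokens, grid = PatchEmbed3D()(vol)
print(tokens.shape)  # torch.Size([1, 262144, 48]) for a 64^3 patch grid
```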
1.2. Task-Specific CNN Decoders
- Detection Decoder: Encoder features are first unified in channel dimension by convolutional projections, normalized with GroupNorm, and activated with ReLU. The decoder then employs a Path Aggregation Network (PANet) to propagate low-level features and facilitate multi-scale object prediction, producing classification scores and 3D bounding-box offsets.
- Segmentation Decoder: Adopts a U-Net-style architecture as in Swin UNETR. Encoder features are routed via long skip connections into residual + deconvolutional layers, outputting region probability maps for Whole Tumor (WT), Tumor Core (TC), and Enhancing Tumor (ET).
- Classification Decoder: Refined segmentation features are fed to a DenseNet-121, in which feature maps are concatenated at every layer ($x_\ell = H_\ell([x_0, x_1, \ldots, x_{\ell-1}])$). After global pooling, a fully connected layer predicts high- vs. low-grade glioma status (see the structural sketch after this list).
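The hard-parameter-sharing layout — one shared encoder feeding three task heads in a single forward pass — can be sketched as follows. The sub-modules are placeholders standing in for the Swin backbone, PANet detection head, Swin-UNETR-style segmentation decoder, and DenseNet-121 classifier described above; names and return signatures are assumptions.

```python
import torch.nn as nn

class MTMed3DSkeleton(nn.Module):
    """Structural sketch of the shared-encoder / multi-decoder design.
    Sub-modules are injected placeholders, not the paper's implementation."""

    def __init__(self, encoder, det_decoder, seg_decoder, cls_decoder):
        super().__init__()
        self.encoder = encoder          # shared 3D Swin-Transformer backbone
        self.det_decoder = det_decoder  # PANet-style detection head
        self.seg_decoder = seg_decoder  # U-Net-style segmentation decoder
        self.cls_decoder = cls_decoder  # DenseNet-121 classification head

    def forward(self, x):
        # A single forward pass through the shared backbone serves all three tasks.
        feats = self.encoder(x)                       # multi-scale features F1..F4
        boxes, scores = self.det_decoder(feats)       # 3D boxes + class scores
        seg_logits, seg_feats = self.seg_decoder(feats)
        cls_logits = self.cls_decoder(seg_feats)      # grading from refined seg features
        return {"det": (boxes, scores), "seg": seg_logits, "cls": cls_logits}
```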
2. Multi-Task Learning and Optimization
2.1. Composite Training Objective
MTMed3D is trained with a joint objective
$$\mathcal{L}_{\text{total}} = w_{\text{seg}}\,\mathcal{L}_{\text{seg}} + w_{\text{cls}}\,\mathcal{L}_{\text{cls}} + w_{\text{det}}\,\mathcal{L}_{\text{det}},$$
where $\mathcal{L}_{\text{seg}}$ is the Dice loss over the three tumor regions, $\mathcal{L}_{\text{cls}}$ is the Focal Loss for glioma grading, and $\mathcal{L}_{\text{det}}$ combines Smooth L1 loss for bounding-box regression with Focal Loss for object classification; the task weights $w_i$ are adjusted dynamically (Section 2.2).
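A minimal sketch of this composite objective is given below, assuming MONAI's `DiceLoss`, torchvision's `sigmoid_focal_loss`, and PyTorch's `SmoothL1Loss` as the per-task criteria; the paper may implement its own losses, and the dictionary keys are placeholder names.

```python
import torch
import torch.nn as nn
from monai.losses import DiceLoss                   # assumed criterion choices;
from torchvision.ops import sigmoid_focal_loss      # the paper may use its own

dice = DiceLoss(sigmoid=True)        # segmentation: WT / TC / ET channel maps
smooth_l1 = nn.SmoothL1Loss()        # detection: 3D bounding-box regression

def total_loss(outputs, targets, w):
    """Composite objective: w_seg*L_seg + w_cls*L_cls + w_det*L_det (sketch)."""
    l_seg = dice(outputs["seg"], targets["seg"])
    l_cls = sigmoid_focal_loss(outputs["cls"], targets["cls"], reduction="mean")
    l_det = (smooth_l1(outputs["box_reg"], targets["box_reg"])
             + sigmoid_focal_loss(outputs["box_cls"], targets["box_cls"], reduction="mean"))
    return w["seg"] * l_seg + w["cls"] * l_cls + w["det"] * l_det
```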
2.2. Gradient-Normalization for Dynamic Task Weighting
Task balancing is achieved using GradNorm, which adaptively adjusts the task weights $w_i$ during training. Let $G_i(t) = \lVert \nabla_W\, w_i(t)\,\mathcal{L}_i(t) \rVert$ be the gradient norm for task $i$ at iteration $t$, taken with respect to the last shared encoder weights $W$, with task-averaged norm $\bar{G}(t)$ and relative inverse training rate $r_i(t)$.
GradNorm minimizes
$$\mathcal{L}_{\text{grad}} = \sum_i \bigl| G_i(t) - \bar{G}(t)\, r_i(t)^{\alpha} \bigr|$$
with respect to the task weights $w_i$. This ensures balanced gradient flow and prevents any single task from monopolizing the encoder's learning capacity. Synergistic feature learning is promoted as all three decoders back-propagate into the shared Swin backbone.
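One GradNorm weight update can be sketched as below. This follows the generic GradNorm recipe rather than the paper's exact implementation; `alpha`, the weight learning rate, and all variable names are assumptions.

```python
import torch

def gradnorm_step(task_losses, init_losses, weights, shared_params, alpha=1.5, lr=0.025):
    """One GradNorm update of the task weights w_i (sketch, assumed hyperparameters).

    task_losses   : list of current per-task losses L_i(t)
    init_losses   : per-task losses recorded at t = 0
    weights       : 1-D tensor of task weights w_i with requires_grad=True
    shared_params : parameters of the last shared encoder layer
    """
    # Gradient norm G_i(t) of each weighted task loss w.r.t. the shared layer.
    norms = []
    for w_i, L_i in zip(weights, task_losses):
        g = torch.autograd.grad(w_i * L_i, shared_params,
                                retain_graph=True, create_graph=True)
        norms.append(torch.norm(torch.cat([gi.flatten() for gi in g])))
    norms = torch.stack(norms)

    with torch.no_grad():
        loss_ratios = torch.stack([L_i / L0 for L_i, L0 in zip(task_losses, init_losses)])
        inverse_rate = loss_ratios / loss_ratios.mean()      # r_i(t)
        target = norms.mean() * inverse_rate ** alpha        # G_bar(t) * r_i(t)^alpha

    # L_grad = sum_i |G_i(t) - target_i|, minimized w.r.t. the task weights only.
    grad_loss = torch.abs(norms - target).sum()
    weight_grads = torch.autograd.grad(grad_loss, weights)[0]
    with torch.no_grad():
        weights -= lr * weight_grads
        weights *= len(weights) / weights.sum()  # renormalize: weights sum to #tasks
    return weights
```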
3. Data Sources, Preprocessing, and Training
3.1. Datasets and Labels
- Datasets: BraTS 2018 (285 cases: 210 high-grade glioma, 75 low-grade glioma) and BraTS 2019 (335 cases: 259 HGG, 76 LGG).
- Modalities: T1, T1ce, T2, FLAIR (stacked into 4-channel volumes).
- Labels: Segmentation masks for WT/TC/ET; categorical classification label (HGG/LGG); detection bounding boxes autogenerated from segmentation.
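Detection boxes are derived from the segmentation masks; a simple way to do this (a sketch, not necessarily the paper's exact procedure) is to take the extent of the non-zero voxels of a tumor mask along each axis.

```python
import numpy as np

def mask_to_bbox3d(mask):
    """Tight 3D bounding box (zmin, ymin, xmin, zmax, ymax, xmax) around the
    non-zero voxels of a binary tumor mask (illustrative sketch)."""
    coords = np.argwhere(mask > 0)
    if coords.size == 0:
        return None                      # no tumor voxels in this volume
    mins = coords.min(axis=0)
    maxs = coords.max(axis=0) + 1        # exclusive upper bound
    return (*mins, *maxs)

# Example: box around the whole-tumor region of one case
# wt_mask = (segmentation_label > 0)     # WT = union of all tumor sub-regions
# bbox = mask_to_bbox3d(wt_mask)
```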
3.2. Augmentation and Training Protocols
- Spatial augmentations: random flips along each of the three spatial axes and random 3D rotations.
- Photometric augmentations: per-channel intensity shifts and scalings.
- Cropping: all volumes are center-cropped to a fixed spatial size to accommodate GPU memory constraints.
- Optimization: AdamW optimizer with a cosine-annealing learning-rate schedule and branch-specific learning rates for the segmentation and detection/classification heads (a sketch of this setup follows the list).
- Batch size: 1.
- Validation: Five-fold cross-validation; typical training runs span 200–300 epochs.
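A minimal sketch of such an optimizer setup is shown below: AdamW with separate parameter groups per branch and a cosine-annealing schedule. The learning-rate values, weight decay, and the names `model`, `train_loader`, `total_loss`, and `task_weights` (from the earlier sketches) are placeholders, not the paper's reported settings.

```python
import torch

# Branch-specific learning rates via parameter groups (placeholder values).
optimizer = torch.optim.AdamW(
    [
        {"params": model.encoder.parameters()},                   # shared backbone
        {"params": model.seg_decoder.parameters(), "lr": 1e-4},   # segmentation branch
        {"params": model.det_decoder.parameters(), "lr": 1e-3},   # detection branch
        {"params": model.cls_decoder.parameters(), "lr": 1e-3},   # classification branch
    ],
    lr=1e-4, weight_decay=1e-5,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)

for epoch in range(300):                     # typical runs: 200-300 epochs
    for batch in train_loader:               # batch size 1 (one full 3D volume)
        optimizer.zero_grad()
        outputs = model(batch["image"])
        loss = total_loss(outputs, batch, task_weights)   # composite objective (Sec. 2.1)
        loss.backward()
        optimizer.step()
    scheduler.step()
```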
3.3. Hardware
Experiments executed on Intel i7-13700F CPU and NVIDIA RTX 4070 Ti (12 GB VRAM).
4. Quantitative Results and Empirical Analysis
4.1. Overall Task Performance
| Task | MTMed3D (5-fold Avg/Best) | Leading Baseline(s) |
|---|---|---|
| Detection | mAP@[0.1:0.5] = 0.9711 / 0.9082; mAP@0.5 = 0.8650 | RetinaNet: 0.8664 / 0.6768; nnDetection/MedYOLO: ~0.8360 / 0.8610 |
| Segmentation | Dice WT = 0.8793, TC = 0.8019, ET = 0.7755 | Top single-task methods: WT = 0.9082, TC = 0.8183, ET = 0.8038 |
| Classification | Acc = 0.9193, Sens = 0.9761, Spec = 0.7634 | Mask R-CNN (2D): 0.9630; CNN (2D): 0.9715; SE-ResNeXt: 0.9745 |
MTMed3D outperforms previous 3D detection methods while remaining competitive with the segmentation accuracy of leading single-task architectures.
4.2. Efficiency Metrics
- Task-averaged accuracy: MTMed3D's average segmentation Dice, classification accuracy, and detection mAP@[0.1:0.5] remain close to the corresponding single-task averages, with detection favoring the multi-task model (per-task values as in Section 4.1).
- Resource usage vs. the sum of three single-task models:
  - MACs, FLOPs, and parameter count: each roughly halved
  - Inference time: 0.10 s vs. 0.22 s (≈55% reduction)
  - Model size: 307 MB vs. 588 MB (≈48% reduction)
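Parameter-count and latency comparisons of this kind can be reproduced with a short measurement snippet like the one below. This is a generic sketch: the profiling tool the authors used for MACs/FLOPs is not specified here, and the input shape is an assumption.

```python
import time
import torch

def count_params(model):
    """Total number of trainable and non-trainable parameters."""
    return sum(p.numel() for p in model.parameters())

@torch.no_grad()
def mean_inference_time(model, input_shape=(1, 4, 128, 128, 128), runs=20):
    """Average GPU inference latency over several runs (assumed input shape)."""
    model.eval().cuda()
    x = torch.randn(*input_shape, device="cuda")
    for _ in range(3):                 # warm-up iterations
        model(x)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(runs):
        model(x)
    torch.cuda.synchronize()
    return (time.time() - start) / runs
```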
4.3. Ablation Studies
- PANet vs. FPN: PANet head in detection, combined with GradNorm, yields superior ET Dice, lower Hausdorff distance (HD), and improved mAP/mAR.
- Balancing strategy: GradNorm produces more balanced multi-task results and notably higher classification specificity and accuracy than MGDA.
5. Design Rationale, Limitations, and Prospective Developments
Hard parameter sharing via a 3D Swin-Transformer backbone leverages both long-range dependencies and hierarchical feature context, supporting robust multi-task performance, particularly for detection due to transfer of segmentation cues. A single forward pass through MTMed3D suffices for all three tasks, resulting in halved computational, memory, and latency footprints relative to triplicate single-task systems.
Classification accuracy, however, remains below leading dedicated 2D/3D classifiers. This is attributed to the absence of a dedicated learnable classification token, class imbalance in the 3D data, and the lack of slice-level pre-processing. Proposed future directions include introducing a learnable classification token or cross-attention mechanisms, pre-training the Swin encoder on larger 3D medical datasets, expanding the framework to additional anatomical sites and tasks such as registration or prognosis, and evaluating dynamic task weighting and auxiliary losses beyond GradNorm.
6. Context and Significance in Multi-Task Medical Imaging
MTMed3D constitutes the first end-to-end, Swin-Transformer-based multi-task network for 3D medical imaging that simultaneously addresses detection, segmentation, and classification. This demonstrates both technical feasibility and substantial system-level efficiency for multi-purpose clinical AI pipelines, suggesting a shift toward unifying architectures in volumetric diagnostic applications (Li et al., 15 Nov 2025).