
MTMed3D: Unified 3D Medical Imaging

Updated 22 November 2025
  • MTMed3D is a multi-task framework designed to perform object detection, semantic segmentation, and classification on volumetric MRIs.
  • It integrates a shared 3D Swin-Transformer encoder with specialized CNN decoders to balance performance and efficiency across tasks.
  • MTMed3D achieves state-of-the-art detection and competitive segmentation/classification results while significantly lowering computational costs and inference time.

MTMed3D is an end-to-end multi-task Transformer-based framework developed for 3D medical imaging, designed to jointly perform object detection, semantic segmentation, and classification from volumetric data. Addressing the inefficiencies of single-task models, MTMed3D introduces a unified architecture where all three core tasks share a 3D Swin-Transformer encoder, with subsequent specialized CNN decoders. The model is evaluated on multi-modal brain tumor MRIs (BraTS 2018/2019), delivering state-of-the-art detection performance while maintaining competitive segmentation and classification, with substantial reductions in computational cost, parameter count, and inference latency compared to single-task models (Li et al., 15 Nov 2025).

1. Model Architecture

1.1. Shared 3D Swin-Transformer Encoder

MTMed3D takes as input a multi-modal MRI volume $X \in \mathbb{R}^{4 \times 240 \times 240 \times 155}$, which is center-cropped and normalized to $\tilde X \in \mathbb{R}^{4 \times 96 \times 96 \times 96}$. The volume is divided into non-overlapping 3D patches, which are then linearly embedded:

$$Z^0 = \operatorname{LinearEmbed}\bigl(\operatorname{PatchPartition}(\tilde X)\bigr) \in \mathbb{R}^{M \times C},$$

where $M$ is the number of patches and $C$ the feature dimension. Four hierarchical stages of Swin-Transformer blocks and patch-merging layers generate a set of multi-scale feature maps $\{Z^i\}_{i=1}^4$.
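As a concrete illustration, the patch partition and linear embedding are typically realized as a single strided 3D convolution. The minimal PyTorch sketch below assumes a patch size of 2 and an embedding dimension of $C=48$, values typical of Swin UNETR-style encoders rather than figures confirmed by the paper:

```python
import torch
import torch.nn as nn

class PatchEmbed3D(nn.Module):
    """Non-overlapping 3D patch partition + linear embedding (illustrative sketch)."""
    def __init__(self, in_ch=4, embed_dim=48, patch_size=2):
        super().__init__()
        # kernel == stride == patch_size realizes the non-overlapping partition
        self.proj = nn.Conv3d(in_ch, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):                      # x: (B, 4, 96, 96, 96)
        z = self.proj(x)                       # (B, C, 48, 48, 48)
        return z.flatten(2).transpose(1, 2)    # (B, M, C) with M = 48^3 patch tokens

x = torch.randn(1, 4, 96, 96, 96)
print(PatchEmbed3D()(x).shape)                 # torch.Size([1, 110592, 48])
```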

Each Swin-Transformer block alternates Windowed Multi-Head Self-Attention (W-MSA) with Shifted-Window Multi-Head Self-Attention (SW-MSA), ensuring localized and global context aggregation:

$$\begin{aligned}
Z'_l &= \mathrm{W\text{-}MSA}(\mathrm{LN}(Z_{l-1})) + Z_{l-1},\\
Z_l &= \mathrm{MLP}(\mathrm{LN}(Z'_l)) + Z'_l,\\
Z'_{l+1} &= \mathrm{SW\text{-}MSA}(\mathrm{LN}(Z_l)) + Z_l,\\
Z_{l+1} &= \mathrm{MLP}(\mathrm{LN}(Z'_{l+1})) + Z'_{l+1}.
\end{aligned}$$

This design enables hard sharing of contextual, multi-scale features—eliminating redundancy across tasks.
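The residual, pre-norm layout of the four update equations can be sketched as follows. Full 3D window partitioning and cyclic shifting are omitted, and nn.MultiheadAttention stands in for W-MSA/SW-MSA purely to show the block structure; this is a deliberate simplification, not the paper's implementation:

```python
import torch
import torch.nn as nn

class SwinBlockPair(nn.Module):
    """Pre-norm residual pattern of one W-MSA + SW-MSA block pair (sketch only)."""
    def __init__(self, dim=48, heads=3, mlp_ratio=4):
        super().__init__()
        self.norms = nn.ModuleList([nn.LayerNorm(dim) for _ in range(4)])
        self.attn_w = nn.MultiheadAttention(dim, heads, batch_first=True)   # W-MSA stand-in
        self.attn_sw = nn.MultiheadAttention(dim, heads, batch_first=True)  # SW-MSA stand-in
        self.mlps = nn.ModuleList([
            nn.Sequential(nn.Linear(dim, dim * mlp_ratio), nn.GELU(),
                          nn.Linear(dim * mlp_ratio, dim)) for _ in range(2)])

    def forward(self, z):                                        # z: (B, M, C)
        h = self.norms[0](z); z = z + self.attn_w(h, h, h)[0]    # Z'_l
        z = z + self.mlps[0](self.norms[1](z))                   # Z_l
        h = self.norms[2](z); z = z + self.attn_sw(h, h, h)[0]   # Z'_{l+1}
        z = z + self.mlps[1](self.norms[3](z))                   # Z_{l+1}
        return z

print(SwinBlockPair()(torch.randn(1, 64, 48)).shape)             # torch.Size([1, 64, 48])
```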

1.2. Task-Specific CNN Decoders

  • Detection Decoder: Features $\{F_2, F_3, F_4, F_5\}$ are unified in channel space using $1 \times 1 \times 1$ convolutions, normalized (GroupNorm), and activated (ReLU). The decoder employs a Path Aggregation Network (PANet) to propagate low-level features and facilitate multi-scale object prediction, producing classification scores and 3D bounding-box offsets.
  • Segmentation Decoder: Adopts a U-Net-style architecture as in Swin UNETR. Encoder features are routed via long skip connections into residual and deconvolutional layers, outputting region probability maps for Whole Tumor (WT), Tumor Core (TC), and Enhancing Tumor (ET).
  • Classification Decoder: Refined segmentation features are fed into a DenseNet-121, whose feature maps are concatenated at every layer ($x_l = H_l([x_0, x_1, \dots, x_{l-1}])$). After global pooling, a fully connected layer predicts high- vs. low-grade glioma status. A schematic sketch of this shared-encoder, three-head layout follows this list.
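The way a single shared forward pass feeds all three heads can be illustrated with the toy sketch below; every module is a trivial stand-in (the actual model uses the Swin encoder, PANet detection head, Swin UNETR-style decoder, and DenseNet-121 described above):

```python
import torch
import torch.nn as nn

class MTMed3DSketch(nn.Module):
    """Schematic only: hard parameter sharing of one encoder across three heads.
    Each component is a placeholder, not the paper's actual module."""
    def __init__(self, in_ch=4, feat_ch=48, n_regions=3, n_classes=2):
        super().__init__()
        self.encoder = nn.Conv3d(in_ch, feat_ch, 3, stride=2, padding=1)        # shared backbone stand-in
        self.seg_head = nn.Sequential(                                           # U-Net-style decoder stand-in
            nn.ConvTranspose3d(feat_ch, n_regions, 2, stride=2), nn.Sigmoid())
        self.det_head = nn.Conv3d(feat_ch, 1 + 6, 1)                             # score + 6 box offsets per location
        self.cls_head = nn.Sequential(nn.AdaptiveAvgPool3d(1), nn.Flatten(),
                                      nn.Linear(n_regions, n_classes))           # DenseNet-121 stand-in on seg features

    def forward(self, x):
        f = self.encoder(x)          # single shared forward pass
        seg = self.seg_head(f)       # WT/TC/ET probability maps
        det = self.det_head(f)       # detection scores + 3D box offsets
        cls = self.cls_head(seg)     # HGG vs. LGG logits from refined segmentation features
        return {"seg": seg, "det": det, "cls": cls}

out = MTMed3DSketch()(torch.randn(1, 4, 96, 96, 96))
print({k: tuple(v.shape) for k, v in out.items()})
```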

2. Multi-Task Learning and Optimization

2.1. Composite Training Objective

MTMed3D is trained with a joint objective:

$$\mathcal{L}_{\mathrm{total}} = w_1\,\mathcal{L}_{\mathrm{seg}} + w_2\,\mathcal{L}_{\mathrm{cls}} + w_3\,\mathcal{L}_{\mathrm{det}},$$

where $\mathcal{L}_{\mathrm{seg}}$ (segmentation) is the Dice loss over the three regions, $\mathcal{L}_{\mathrm{cls}}$ (classification) is a Focal loss, and $\mathcal{L}_{\mathrm{det}}$ (detection) combines a Smooth L1 loss for bounding-box regression with a Focal loss for object classification.
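A hedged sketch of this joint objective is given below, using MONAI's DiceLoss and torchvision's sigmoid_focal_loss as stand-ins for the paper's Dice and Focal losses; the exact loss configurations and any internal weighting inside $\mathcal{L}_{\mathrm{det}}$ are assumptions for illustration:

```python
import torch.nn.functional as F
from monai.losses import DiceLoss
from torchvision.ops import sigmoid_focal_loss

dice = DiceLoss(sigmoid=True)  # Dice over WT/TC/ET region channels

def total_loss(seg_logits, seg_gt, cls_logits, cls_gt,
               box_pred, box_gt, obj_logits, obj_gt, w):
    """L_total = w1*L_seg + w2*L_cls + w3*L_det (sketch; w = [w1, w2, w3])."""
    l_seg = dice(seg_logits, seg_gt)                                       # segmentation: Dice loss
    l_cls = sigmoid_focal_loss(cls_logits, cls_gt, reduction="mean")       # classification: Focal loss
    l_det = (F.smooth_l1_loss(box_pred, box_gt)                            # detection: box regression
             + sigmoid_focal_loss(obj_logits, obj_gt, reduction="mean"))   # detection: object classification
    return w[0] * l_seg + w[1] * l_cls + w[2] * l_det
```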

2.2. Gradient-Normalization for Dynamic Task Weighting

Task balancing is achieved using GradNorm, adaptively adjusting $\{w_i\}$ during training. Let $G_w^{(i)}(t) = \|\nabla_w[w_i(t)\,L_i(t)]\|_2$ be the gradient norm for task $i$ at iteration $t$, with average $\overline G_w(t) = \frac{1}{T}\sum_i G_w^{(i)}(t)$, and relative inverse training rate

$$r_i(t) = \frac{L_i(t)/L_i(0)}{\frac{1}{T}\sum_j \left[L_j(t)/L_j(0)\right]}.$$

GradNorm minimizes $\bigl|G_w^{(i)}(t)-\overline G_w(t)\,r_i(t)\bigr|$ with respect to $\{w_i\}$. This ensures balanced gradient flow and prevents any single task from monopolizing encoder learning capacity. Synergistic feature learning is promoted as decoders back-propagate into the shared Swin backbone.
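An illustrative GradNorm update step following the formulas above is sketched below. This is not the authors' code: the asymmetry exponent of the original GradNorm method is taken as 1 (matching the expression above), the weight learning rate is an assumption, and `shared_params` would typically be the parameters of the last shared encoder layer:

```python
import torch

def gradnorm_update(task_losses, L0, w, shared_params, lr_w=0.025):
    """task_losses: list of L_i(t) (graph-connected to shared_params);
    L0: initial losses L_i(0); w: 1-D weight tensor with requires_grad=True."""
    T = len(task_losses)
    # G_w^(i)(t): gradient norm of w_i * L_i w.r.t. the shared parameters
    G = torch.stack([
        torch.norm(torch.cat([g.flatten() for g in torch.autograd.grad(
            w[i] * task_losses[i], shared_params, retain_graph=True, create_graph=True)]))
        for i in range(T)])
    G_bar = G.mean().detach()                                              # average gradient norm
    ratios = torch.stack([task_losses[i].detach() / L0[i] for i in range(T)])
    r = ratios / ratios.mean()                                             # relative inverse training rate r_i(t)
    gradnorm_loss = torch.abs(G - G_bar * r).sum()                         # |G_w^(i)(t) - G_bar(t) * r_i(t)|
    grad_w = torch.autograd.grad(gradnorm_loss, w)[0]
    with torch.no_grad():
        w -= lr_w * grad_w      # gradient step on the task weights only
        w *= T / w.sum()        # renormalize so the weights sum to T
    return w
```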

3. Data Sources, Preprocessing, and Training

3.1. Datasets and Labels

  • Datasets: BraTS 2018 (285 cases: 210 high-grade glioma, 75 low-grade glioma) and BraTS 2019 (335 cases: 259 HGG, 76 LGG).
  • Modalities: T1, T1ce, T2, FLAIR (stacked into 4-channel volumes).
  • Labels: Segmentation masks for WT/TC/ET; categorical classification label (HGG/LGG); detection bounding boxes auto-generated from the segmentation masks (see the sketch after this list).
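A minimal example of deriving a 3D bounding box from a segmentation mask is shown below. The exact rule used in the paper (for instance, which tumor region defines the box) is not specified here, so this is only an assumed tight-bounding-box construction:

```python
import numpy as np

def mask_to_box(mask: np.ndarray):
    """Return (z_min, y_min, x_min, z_max, y_max, x_max) enclosing all nonzero voxels."""
    coords = np.argwhere(mask > 0)
    if coords.size == 0:
        return None                      # no tumor voxels in this volume
    mins = coords.min(axis=0)
    maxs = coords.max(axis=0)
    return (*mins, *maxs)

mask = np.zeros((96, 96, 96), dtype=np.uint8)
mask[30:50, 40:60, 20:45] = 1            # toy whole-tumor mask
print(mask_to_box(mask))                 # (30, 40, 20, 49, 59, 44)
```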

3.2. Augmentation and Training Protocols

  • Spatial augmentations: random flips along the three axes ($p=0.5$), random 3D rotations ($p=0.75$).
  • Photometric augmentations: per-channel intensity shifts and scalings in the range $(-0.1, 0.1)$.
  • Cropping: All volumes cropped to $96^3$ to accommodate GPU memory constraints.
  • Optimization: AdamW optimizer with cosine-annealing learning rate. Branch-specific learning rates: $10^{-4}$ (segmentation), $10^{-5}$ (detection/classification); see the configuration sketch after this list.
  • Batch size: 1.
  • Validation: Five-fold cross-validation; typical training runs are 200–300 epochs.
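A configuration sketch of the reported optimizer setup follows; the modules are placeholders, and the learning rate assigned to the shared encoder is an assumption (the paper only states the branch-specific rates):

```python
import torch
import torch.nn as nn

# Placeholder modules standing in for the shared encoder and the three decoders.
model = nn.ModuleDict({
    "encoder": nn.Conv3d(4, 48, 3, padding=1),
    "seg":     nn.Conv3d(48, 3, 1),
    "det":     nn.Conv3d(48, 7, 1),
    "cls":     nn.Linear(48, 2),
})

optimizer = torch.optim.AdamW([
    {"params": model["encoder"].parameters(), "lr": 1e-4},   # shared backbone (rate assumed)
    {"params": model["seg"].parameters(),     "lr": 1e-4},   # segmentation branch
    {"params": model["det"].parameters(),     "lr": 1e-5},   # detection branch
    {"params": model["cls"].parameters(),     "lr": 1e-5},   # classification branch
])
# Cosine annealing over the roughly 200-300 epoch runs reported above
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
```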

3.3. Hardware

Experiments executed on Intel i7-13700F CPU and NVIDIA RTX 4070 Ti (12 GB VRAM).

4. Quantitative Results and Empirical Analysis

4.1. Overall Task Performance

| Task | MTMed3D (5-fold Avg/Best) | Leading baseline(s) |
| --- | --- | --- |
| Detection | mAP@[0.1:0.5] = 0.9711 / 0.9082 | RetinaNet: 0.8664 / 0.6768 |
| | [email protected] = 0.8650 | nnDetection / MedYOLO: ~0.8360 / 0.8610 |
| Segmentation | Dice WT = 0.8793, TC = 0.8019, ET = 0.7755 | Top single-task methods: WT = 0.9082, TC = 0.8183, ET = 0.8038 |
| Classification | Acc = 0.9193, Sens = 0.9761, Spec = 0.7634 | Mask R-CNN (2D) = 0.9630, CNN (2D) = 0.9715, SE-ResNeXt = 0.9745 |

MTMed3D outperforms previous 3D detection methods and remains competitive with leading single-task architectures in segmentation and classification.

4.2. Efficiency Metrics

  • Task-averaged accuracy: MTMed3D: Seg Dice avg $= 0.8189$; Acc $= 0.9193$; Det mAP@[0.1:0.5] $= 0.9082$, vs. single-task averages: Seg Dice $= 0.8052$; Acc $= 0.9017$; Det mAP@[0.1:0.5] $= 0.8217$.
  • Resource reductions vs. sum of three single-task models:
    • MACs: $2 \times 10^{11}$ vs $4 \times 10^{11}$ ($-46.6\%$)
    • FLOPs: $4 \times 10^{11}$ vs $7 \times 10^{11}$ ($-42.9\%$)
    • Parameters: $8 \times 10^7$ vs $1 \times 10^8$ ($-47.8\%$)
    • Inference time: $0.10$ s vs $0.22$ s ($-55.7\%$)
    • Model size: 307 MB vs 588 MB (a simple measurement sketch follows this list)
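The parameter-count and inference-time comparisons can be reproduced with generic PyTorch bookkeeping; the snippet below is a plain measurement sketch, not the paper's profiling tooling (which also reports MACs/FLOPs via dedicated counters):

```python
import time
import torch
import torch.nn as nn

def profile(model: nn.Module, input_shape=(1, 4, 96, 96, 96), n_runs=10):
    """Count parameters and average forward-pass latency for a given model."""
    n_params = sum(p.numel() for p in model.parameters())
    x = torch.randn(*input_shape)
    model.eval()
    with torch.no_grad():
        model(x)                                   # warm-up pass
        t0 = time.perf_counter()
        for _ in range(n_runs):
            model(x)
        dt = (time.perf_counter() - t0) / n_runs
    return n_params, dt

params, latency = profile(nn.Sequential(nn.Conv3d(4, 16, 3, padding=1), nn.Conv3d(16, 3, 1)))
print(f"{params} parameters, {latency:.3f} s per forward pass")
```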

4.3. Ablation Studies

  • PANet vs. FPN: PANet head in detection, combined with GradNorm, yields superior ET Dice, lower Hausdorff distance (HD), and improved mAP/mAR.
  • Balancing strategy: GradNorm produces more balanced multi-task results and notably higher classification specificity and accuracy than MGDA.

5. Design Rationale, Limitations, and Prospective Developments

Hard parameter sharing via a 3D Swin-Transformer backbone leverages both long-range dependencies and hierarchical feature context, supporting robust multi-task performance, particularly for detection due to transfer of segmentation cues. A single forward pass through MTMed3D suffices for all three tasks, roughly halving the computational, memory, and latency footprints relative to three separate single-task models.

Classification accuracy, however, remains below leading dedicated 2D/3D classifiers. This is attributed to the lack of a learnable $[\mathrm{CLS}]$ token, imbalanced 3D data, and the absence of slice-level pre-processing. Future directions proposed include introducing a learnable classification token or cross-attention mechanisms, pre-training the Swin encoder on larger 3D medical datasets, expanding the framework to additional anatomical sites and tasks such as registration or prognosis, and evaluating dynamic task weighting/auxiliary losses beyond GradNorm.

6. Context and Significance in Multi-Task Medical Imaging

MTMed3D constitutes the first end-to-end, Swin-Transformer-based multi-task network for 3D medical imaging that simultaneously addresses detection, segmentation, and classification. This demonstrates both technical feasibility and substantial system-level efficiency for multi-purpose clinical AI pipelines, suggesting a shift toward unifying architectures in volumetric diagnostic applications (Li et al., 15 Nov 2025).
