3D-CPS: Semi-Supervised Volumetric Segmentation
- The paper demonstrates that integrating cross-pseudo supervision with twin U-Nets substantially improves segmentation of small or low-contrast anatomical structures.
- 3D-CPS is a semi-supervised learning paradigm that employs synchronized U-Nets exchanging hard pseudo-labels during joint training to leverage both labeled and unlabeled data.
- The methodology achieves measurable gains in mean Dice (mDSC) and mean Normalized Surface Distance (mNSD) on abdominal organ segmentation, validating its data efficiency relative to standard supervised nnU-Net configurations.
3D Cross-Pseudo Supervision (3D-CPS) is a semi-supervised learning paradigm built upon the nnU-Net framework, designed to address data efficiency constraints in volumetric medical image segmentation. By leveraging both labeled and large quantities of unlabeled data, 3D-CPS establishes a cross-pseudo supervision mechanism between two synchronized U-Nets, improving performance particularly for small or low-contrast anatomical structures. The method is evaluated in the context of abdominal organ segmentation for the MICCAI FLARE2022 challenge and demonstrates substantial improvements over baseline nnU-Net configurations (Huang et al., 2022).
1. Architectural Principles
3D-CPS extends the baseline nnU-Net architecture by introducing a semi-supervised, co-training–style configuration:
- Twin Networks: Two sibling U-Nets, denoted $T_1$ and $T_2$, are instantiated with identical topology (either 2D or 3D nnU-Net variants) but independent weight initializations and separate optimizers.
- Joint Training: Both networks are trained on identical mini-batches, which combine labeled and unlabeled samples. During training, each model’s outputs are provided to its counterpart as pseudo-labels, implementing the cross-pseudo supervision paradigm.
- Ensemble Removal: The final nnU-Net model ensembling and some postprocessing stages are omitted for ablation; only the trained twin networks $T_1$ and $T_2$ are retained at inference. Each model can be run independently or in concert with nnU-Net’s sliding-window inference and test-time augmentation (TTA) strategies.
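As a toy illustration of the twin-network setup (a hypothetical helper; the actual 3D-CPS implementation instantiates full nnU-Net models), two parameter vectors of identical size but independent random initializations can be sketched as:

```python
import random

def init_weights(n_params, seed):
    """Toy stand-in for one network's parameters: identical size
    ("identical topology") but a different random draw per seed
    ("independent weight initialization")."""
    rng = random.Random(seed)
    return [rng.gauss(0.0, 0.02) for _ in range(n_params)]

# Two sibling networks T1 and T2: same shape, different weights;
# in the real setup each is also paired with its own optimizer.
w1 = init_weights(8, seed=1)
w2 = init_weights(8, seed=2)
```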
2. Loss Formulation and Training Objective
The 3D-CPS framework combines conventional supervised segmentation losses with a cross-pseudo supervision (CPS) loss term:
- Supervised Loss: For each labeled example $(x, y) \in \mathcal{D}_L$ (the labeled dataset), a composite loss is used:

  $\mathcal{L}_{\mathrm{sup}} = \ell(P_1(x), y) + \ell(P_2(x), y)$

  where $\ell$ denotes the standard nnU-Net sum of cross-entropy and soft Dice losses, and $P_1$, $P_2$ are the softmax outputs of the two networks.
- CPS Loss: Each network supplies its own one-hot (hard) pseudo-labels to supervise the peer network on both labeled and unlabeled data:

  $\mathcal{L}_{\mathrm{cps}} = \ell(P_1(x), \hat{Y}_2) + \ell(P_2(x), \hat{Y}_1)$

  where

  $\hat{Y}_i = \operatorname{stopgrad}(\operatorname{argmax} P_i(x)), \quad i \in \{1, 2\}$

  and $\operatorname{stopgrad}$ denotes the stop-gradient operation, combined with $\operatorname{argmax}$ to yield hard pseudo-labels.
- Total Loss: The final objective introduces a ramped schedule for the CPS loss:

  $\mathcal{L}_{\mathrm{total}} = \mathcal{L}_{\mathrm{sup}} + \lambda \, \mathcal{L}_{\mathrm{cps}}$

  where $\lambda$ is increased linearly from $0$ to $0.5$ over the first $E_r$ epochs ($E_r = 500$), and then held constant at $0.5$. This mechanism prevents early learning from unreliable pseudo-labels.
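The ramp schedule and loss combination can be sketched in a few lines of Python (function names `cps_weight` and `total_loss` are illustrative, not from the paper):

```python
def cps_weight(epoch, ramp_epochs=500, max_weight=0.5):
    """Lambda schedule: linear ramp from 0 to max_weight over the first
    ramp_epochs epochs, then held constant (the 3D-CPS setting uses
    max_weight = 0.5 and ramp_epochs = 500)."""
    return max_weight * min(epoch / ramp_epochs, 1.0)

def total_loss(l_sup, l_cps, epoch, ramp_epochs=500):
    """L_total = L_sup + lambda(epoch) * L_cps."""
    return l_sup + cps_weight(epoch, ramp_epochs) * l_cps
```

At epoch 250, for example, the CPS term contributes at half its final weight: `cps_weight(250)` returns `0.25`.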
3. Pseudo-Label Generation and Mechanisms for Error Mitigation
Central to 3D-CPS is the careful management of pseudo-label quality and gradient propagation:
- Hard One-Hot Labels: Pseudo-labels are produced as hard one-hot vectors via $\operatorname{argmax}$ on the network’s softmax output, followed by a stop-gradient operation, such that gradients do not back-propagate through a network’s own predictions.
- Progressive Loss Weighting: By increasing the CPS loss weight from zero, networks only begin leveraging each other's pseudo-labels after early stages of supervised-only learning, which empirically reduces noise propagation and error accumulation from unreliable initial predictions.
- Peer Supervision: Each network acts both as a student (learning from its peer's pseudo-labels) and as a teacher (generating pseudo-labels for the peer)—a mutual co-training relationship.
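A minimal sketch of the pseudo-label construction, in pure Python over per-voxel probability vectors (in a real framework the result would additionally be detached from the computation graph, e.g. via `.detach()` in PyTorch, to realize the stop-gradient):

```python
def hard_pseudo_labels(probs):
    """Convert per-voxel softmax probabilities into hard one-hot
    pseudo-labels via argmax. probs is a list of probability vectors,
    one per voxel, e.g. [[0.1, 0.7, 0.2], ...]."""
    labels = []
    for p in probs:
        k = max(range(len(p)), key=p.__getitem__)  # argmax class index
        labels.append([1 if i == k else 0 for i in range(len(p))])
    return labels
```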
4. Data Processing and Architectural Deviations from nnU-Net
Several notable modifications are included to optimize semi-supervised performance:
- Intensity Normalization: Unlike nnU-Net, which normalizes intensities using only the mask foreground, 3D-CPS computes intensity statistics using the full volume on both labeled and unlabeled scans, ensuring identical normalization.
- Data Augmentation: The complete nnU-Net data augmentation pipeline (random rotations, scaling, elastic deformations, gamma transformations, etc.) is applied equally to both labeled and unlabeled images.
- Topology and Patch Size: The U-Net backbone is preserved: the 3D variant uses 6 encoder–decoder levels and the 2D variant uses 8 levels, each with the patch size configured automatically by nnU-Net.
- Forced Spacing: For the challenge submission and to fit hardware constraints, “forced spacing” (resampling to a coarser voxel size) is applied, accommodating deployment within a 28 GB GPU memory budget.
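The full-volume normalization can be sketched as a plain z-score over every voxel (a simplified stand-in for the actual preprocessing; `normalize_full_volume` is an illustrative name):

```python
from statistics import mean, pstdev

def normalize_full_volume(voxels):
    """Z-score normalization computed over ALL voxels of a scan, so that
    labeled and unlabeled volumes receive identical treatment (in contrast
    to normalizing with foreground-mask statistics only)."""
    mu = mean(voxels)
    sigma = pstdev(voxels) or 1.0  # guard against constant volumes
    return [(v - mu) / sigma for v in voxels]
```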
5. Training Protocol and Implementation
The methodological specifics of 3D-CPS training for FLARE2022 are as follows:
| Aspect | 3D-CPS Protocol | 2D-CPS Protocol |
|---|---|---|
| Labeled Volumes | 50 | 50 |
| Unlabeled Volumes | 1000 out of 2000 | 1000 out of 2000 |
| Batch Composition | 2 labeled + 2 unlabeled per network | 12 labeled + 12 unlabeled per network |
| Optimizer | SGD, momentum 0.99, with weight decay | Same |
| LR Schedule | Initial LR 0.01, ReduceLROnPlateau | Same |
| Epochs | 1000 | 1000 |
| CPS Ramp-up ($E_r$, epochs) | 500 | 500 |
| Inference | Sliding window (step 0.7), TTA; forced spacing for submission | Same |
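The batch composition rows above can be sketched as a tiny helper (illustrative, not from the paper’s code):

```python
def mixed_batch(labeled, unlabeled, n_lab=2, n_unlab=2):
    """Compose one per-network training batch: n_lab labeled samples
    followed by n_unlab unlabeled ones (2 + 2 in the 3D protocol,
    12 + 12 in the 2D protocol), plus a mask marking which entries
    carry ground-truth labels for the supervised loss."""
    batch = list(labeled[:n_lab]) + list(unlabeled[:n_unlab])
    has_label = [True] * n_lab + [False] * n_unlab
    return batch, has_label
```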
6. Quantitative Benchmarks and Performance Evaluation
Performance is evaluated via mean Dice Similarity Coefficient (mDSC) and mean Normalized Surface Distance (mNSD) over 20 held-out cases (MICCAI FLARE2022 validation):
- 2D variant:
- Baseline nnU-Net (supervised): mDSC = 0.7846, mNSD = 0.8274
- +CPS: mDSC = 0.8069, mNSD = 0.8562
- Net improvement: +0.0223 DSC, +0.0288 NSD
- 3D variant:
- Baseline nnU-Net (supervised): mDSC = 0.8627, mNSD = 0.8969
- +CPS (3D-CPS): mDSC = 0.8794, mNSD = 0.9128
- Net improvement: +0.0167 DSC, +0.0159 NSD
Notably, small and low-contrast organs (e.g., gallbladder, duodenum, adrenal glands) demonstrated comparatively greater absolute improvement with CPS integration (Huang et al., 2022).
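For reference, the Dice Similarity Coefficient underlying mDSC can be computed for binary masks as follows (a generic implementation; the challenge’s official evaluation additionally uses the tolerance-based NSD, omitted here):

```python
def dice_coefficient(pred, target):
    """DSC = 2|A ∩ B| / (|A| + |B|) for two binary masks given as
    equal-length 0/1 sequences (flattened volumes)."""
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    return 2.0 * intersection / total if total else 1.0
```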
7. High-Level Training Loop and Workflow
The fundamental cycle of CPS co-training is succinctly captured in the following pseudocode (notation per source):
```
for epoch = 1 to max_epochs:
    λ ← compute_lambda(epoch)   # linear ramp to 0.5 over first Er epochs
    for each batch:
        # 1. Sample minibatch of labeled (xℓ, yℓ) and unlabeled xu
        xℓ, yℓ ← next_labeled_batch()
        xu ← next_unlabeled_batch()
        X ← concat(xℓ, xu)

        # 2. Forward pass both networks
        P1 ← T1(X)              # [B, C, ...], float confidences
        P2 ← T2(X)

        # 3. Supervised loss on labeled samples
        Lsup ← ℓ_sup(P1[ℓ], yℓ) + ℓ_sup(P2[ℓ], yℓ)

        # 4. Cross-pseudo supervision on all samples
        Y1 ← stopgrad(argmax(P1))
        Y2 ← stopgrad(argmax(P2))
        Lcps ← ℓ_sup(P1, Y2) + ℓ_sup(P2, Y1)

        # 5. Total loss
        Ltotal ← Lsup + λ * Lcps

        # 6. Backward and optimize each network
        optimizer1.zero_grad()
        optimizer2.zero_grad()
        Ltotal.backward()
        optimizer1.step()
        optimizer2.step()
```
3D-CPS provides a principled, data-efficient semi-supervised approach for volumetric medical segmentation by uniting self-configuring preprocessing, systematic cross-pseudo supervision, and progressive loss weighting. The methodology demonstrates consistent improvements in key segmentation metrics without recourse to additional labeled data or pre-trained external models (Huang et al., 2022).