Diffusion-Empowered AutoMedSAM
- The paper introduces AutoMedSAM, an end-to-end framework that automates prompt generation for semantic segmentation via a diffusion-based dual-branch prompt encoder.
- It leverages a joint uncertainty-aware multi-loss strategy and adapts MedSAM’s backbone to optimize class-specific mask prediction.
- Empirical evaluations across CT, MR, and X-ray modalities demonstrate superior segmentation accuracy and robust cross-dataset generalization.
Diffusion-Empowered AutoPrompt MedSAM (AutoMedSAM) is an end-to-end medical image segmentation framework that extends the Segment Anything Model (SAM) and its medical adaptation, MedSAM. Addressing the notable challenges of manual prompt dependency and lack of semantic labeling in conventional MedSAM, AutoMedSAM integrates a diffusion-based dual-branch prompt encoder to automate class-conditioned segmentation. This framework enables fully automated mask prediction with semantic association, optimized via a joint uncertainty-aware multi-loss strategy, and demonstrates superior segmentation accuracy and generalization across multiple clinical imaging modalities (Huang et al., 5 Feb 2025).
1. Architecture Overview
AutoMedSAM retains the architectural backbone of MedSAM, composed of a frozen image encoder $E_I$ and a mask decoder $D_M$, while fundamentally replacing the manual prompt encoder with a diffusion-based class prompt encoder $E_P$. The input image $I$ is encoded to feature maps $F_I = E_I(I)$. Given an anatomical class index $c$, the encoder generates two prompt embeddings, $(P_s, P_d) = E_P(F_I, c)$, where $P_s$ (sparse prompt) encodes global cues and $P_d$ (dense prompt) encodes local features. The mask decoder then combines the image features, positional embedding $P_p$, and both prompt vectors to predict the segmentation mask: $M_{pred} = D_M(F_I, P_p, P_s, P_d)$. This pipeline eliminates the need for manual clicks, boxes, or scribbles and embeds semantic class information directly into the segmentation masks (Huang et al., 5 Feb 2025).
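A minimal sketch of this prompt-free inference path, assuming callable modules named after the paper's notation ($E_I$, $E_P$, $D_M$, $P_p$); the signatures and shapes here are illustrative, not the released implementation:

```python
import torch

@torch.no_grad()
def automedsam_predict(image, class_idx, E_I, E_P, D_M, P_p):
    """Prompt-free inference: image + class index -> semantic mask.

    E_I: frozen MedSAM image encoder, E_P: diffusion class prompt
    encoder, D_M: mask decoder, P_p: positional embedding (names
    follow the paper's notation; signatures are assumptions).
    """
    F_I = E_I(image)                   # image -> feature maps F_I
    P_s, P_d = E_P(F_I, class_idx)     # sparse (global) + dense (local) prompts
    M_pred = D_M(F_I, P_p, P_s, P_d)   # class-conditioned mask prediction
    return M_pred
```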
2. Diffusion-Based Class Prompt Encoder Design
AutoMedSAM’s class prompt encoder operates as a conditional diffusion model. The class index $c$ is projected and reshaped into a conditioning tensor $c_{exp}$. For forward diffusion, isotropic Gaussian noise $\epsilon_t \sim \mathcal{N}(0, \sigma_t^2 I)$ with $\sigma_t = 1/(t+1)$ is added to the image feature,
$$F_t = F_I + c_{exp} + \epsilon_t.$$
This forms the noisy, class-conditioned embedding.
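A minimal sketch of this forward step, with the noise scale $1/(t+1)$ taken from the training pseudocode in Section 5 (tensor shapes are assumptions):

```python
import torch

def forward_diffusion(F_I, c_exp, t):
    """Add class conditioning and step-dependent Gaussian noise.

    F_I:   image features, e.g. (B, C, H, W)
    c_exp: class embedding already projected/reshaped to match F_I
    t:     diffusion step; noise std is 1/(t+1), so larger t injects
           progressively weaker noise.
    """
    sigma_t = 1.0 / (t + 1)
    eps_t = sigma_t * torch.randn_like(F_I)  # eps_t ~ N(0, sigma_t^2 I)
    return F_I + c_exp + eps_t               # noisy, class-conditioned F_t
```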
The reverse diffusion employs a U-Net structure, processing $F_t$ through convolutional layers with class re-injection at each layer. The encoder branches into:
- Dense/local branch: element-wise spatial attention is computed over $F_t$, followed by masked feature multiplication and upsampling to produce the dense prompt $P_d$.
- Sparse/global branch: channel attention leverages spatially average-pooled features and produces the sparse prompt $P_s$ via channelwise scaling.
Final prompt embeddings are typically concatenated, $P = \mathrm{Concat}(P_s, P_d)$, enabling integration of both fine-grained and global context within the prompt representation (Huang et al., 5 Feb 2025). A compact sketch of the two branches follows.
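The sketch assumes an SE-style channel-attention head for the sparse prompt and a 1×1-convolution spatial-attention head for the dense prompt; layer widths and the upsampling factor are illustrative, not the paper's exact configuration:

```python
import torch
import torch.nn as nn

class DualBranchPromptHead(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        # Dense/local branch: element-wise spatial attention mask.
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(channels, 1, kernel_size=1), nn.Sigmoid())
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear",
                                    align_corners=False)
        # Sparse/global branch: channel attention over pooled features.
        self.channel_attn = nn.Sequential(
            nn.Linear(channels, channels // 4), nn.ReLU(),
            nn.Linear(channels // 4, channels), nn.Sigmoid())

    def forward(self, F_t):
        A = self.spatial_attn(F_t)        # (B, 1, H, W) attention map
        P_d = self.upsample(A * F_t)      # masked features, upsampled
        g = F_t.mean(dim=(2, 3))          # spatial average pooling -> (B, C)
        P_s = g * self.channel_attn(g)    # channelwise scaling -> sparse prompt
        return P_s, P_d
```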
3. Prompt Integration and Segmentation Mask Generation
The mask decoder incorporates prompt embeddings via cross-attention mechanisms of the standard form
$$\mathrm{Attn}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V,$$
with queries drawn from decoder tokens and keys/values from the prompt embeddings.
Semantic prompt features are injected into the decoder’s latent space, ensuring that output masks encode both object shape and class semantics. This design provides fully automated semantic segmentation for specified anatomical classes, broadening utility for both clinical and non-expert contexts (Huang et al., 5 Feb 2025).
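As a concrete illustration, here is a single-head sketch of this injection; the projection matrices are illustrative, and the actual decoder follows SAM/MedSAM's two-way transformer design rather than this simplified one-directional form:

```python
import torch
import torch.nn.functional as F

def prompt_cross_attention(tokens, prompts, W_q, W_k, W_v):
    """tokens:  (B, N_t, D) decoder mask tokens (queries)
    prompts: (B, N_p, D) sparse + dense prompt embeddings (keys/values)
    W_q/W_k/W_v: (D, D) projections (single-head, an assumption)."""
    Q = tokens @ W_q
    K = prompts @ W_k
    V = prompts @ W_v
    attn = F.softmax(Q @ K.transpose(-2, -1) / K.shape[-1] ** 0.5, dim=-1)
    return tokens + attn @ V  # residual injection of semantic prompt cues
```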
4. Joint Training with Uncertainty-Aware Loss Balancing
AutoMedSAM is optimized with a joint objective comprising five loss components:
- Sparse prompt MSE: $\mathcal{L}_1 = \mathrm{MSE}(P_s, P_s^{\mathrm{MedSAM}})$, distilling the sparse prompt toward MedSAM's prompt embedding.
- Dense prompt MSE: $\mathcal{L}_2 = \mathrm{MSE}(P_d, P_d^{\mathrm{MedSAM}})$.
- Dice loss: $\mathcal{L}_3 = 1 - \dfrac{2\,|M_{pred} \cap M_{gt}|}{|M_{pred}| + |M_{gt}|}$.
- Cross-entropy loss: $\mathcal{L}_4 = \mathrm{CE}(M_{pred}, M_{gt})$.
- Shape-distance loss: $\mathcal{L}_5 = \mathrm{ShapeDist}(M_{pred}, M_{gt})$, penalizing boundary discrepancies between predicted and ground-truth masks.
Loss terms are dynamically weighted using the uncertainty weighting framework of Tsai et al., with learnable per-loss uncertainties $\sigma_i$:
$$\mathcal{L} = \sum_{i=1}^{5} \left( \frac{1}{2\sigma_i^2}\,\mathcal{L}_i + \log \sigma_i \right).$$
This obviates manual tuning of loss weights and facilitates balanced learning across heterogeneous objectives (Huang et al., 5 Feb 2025).
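A minimal sketch of this weighting, assuming the common log-variance parameterization $s_i = \log\sigma_i^2$ (the paper's exact parameterization may differ):

```python
import torch
import torch.nn as nn

class UncertaintyWeightedLoss(nn.Module):
    """Combine K losses with learnable uncertainty weights:
    L_total = sum_i exp(-s_i) * L_i + s_i, where s_i = log(sigma_i^2)."""
    def __init__(self, num_losses: int = 5):
        super().__init__()
        self.log_vars = nn.Parameter(torch.zeros(num_losses))

    def forward(self, *losses):
        total = 0.0
        for s, L in zip(self.log_vars, losses):
            total = total + torch.exp(-s) * L + s  # precision weight + regularizer
        return total
```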
5. Training Procedure
During training, the image encoder $E_I$ remains frozen while the prompt encoder $E_P$ and mask decoder $D_M$ are updated. Optimization employs AdamW with a learning rate of 5e-4 and a reduce-on-plateau scheduler (factor 0.9, patience 5), batch size 16, for up to 100 epochs. The core process follows:
```python
for epoch in range(100):
    for I, c, M_gt in train_loader:
        with torch.no_grad():
            F_I = E_I(I)                            # frozen image encoder
        t = random.randint(0, T - 1)                # sample diffusion step
        eps_t = torch.randn_like(F_I) / (t + 1)     # eps_t ~ N(0, (1/(t+1))^2 I)
        F_t = F_I + c_expand(c) + eps_t             # forward diffusion
        P_s, P_d = E_P.reverse_diffusion(F_t, c)    # dual-branch prompt encoding
        M_pred = D_M(F_I, P_p, P_s, P_d)            # mask prediction
        L1 = MSE(P_s, MedSAM_s)                     # sparse prompt distillation
        L2 = MSE(P_d, MedSAM_d)                     # dense prompt distillation
        L3 = Dice(M_pred, M_gt)
        L4 = CE(M_pred, M_gt)
        L5 = ShapeDist(M_pred, M_gt)
        loss = uncertainty_weighted(L1, L2, L3, L4, L5)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```
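The quoted optimizer and scheduler settings map directly onto standard PyTorch APIs; a sketch of the setup (the parameter grouping and freezing idiom are assumptions, only the learning rate, factor, and patience come from the paper):

```python
import torch

# Only the prompt encoder E_P and mask decoder D_M are trained;
# the image encoder E_I stays frozen.
for p in E_I.parameters():
    p.requires_grad_(False)

params = list(E_P.parameters()) + list(D_M.parameters())
optimizer = torch.optim.AdamW(params, lr=5e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.9, patience=5)

# After each validation epoch:
# scheduler.step(val_loss)
```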
6. Empirical Evaluation
AutoMedSAM is evaluated across diverse medical imaging datasets: AbdomenCT-1K (CT, 5 organs), BraTS (MR-FLAIR, tumor), Kvasir-SEG (endoscopy, polyp), Chest-XML (X-ray, lung), and in cross-dataset scenarios (AMOS, BraTS-T1CE). Performance is measured using Dice Similarity Coefficient (DSC) and Normalized Surface Distance (NSD).
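For reference, DSC on binary masks can be computed as in the minimal sketch below; NSD additionally requires surface extraction and a distance tolerance, which is omitted here:

```python
import numpy as np

def dice_similarity(pred: np.ndarray, gt: np.ndarray, eps: float = 1e-8) -> float:
    """DSC = 2|P ∩ G| / (|P| + |G|) for binary masks."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    return float(2.0 * inter / (pred.sum() + gt.sum() + eps))
```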
Representative Quantitative Results (AbdomenCT-1K):
| Method | DSC (%) | NSD (%) |
|---|---|---|
| MedSAM | 93.505 | 92.969 |
| SurgicalSAM | 75.505 | 70.119 |
| AutoMedSAM (O) | 94.580 | 95.148 |
On single-object datasets (BraTS, Kvasir, Chest-XML), AutoMedSAM outperforms all baselines by 1–5 points in DSC and NSD. Cross-dataset evaluation (train: AbdomenCT, test: AMOS) yields DSC 71.14% for AutoMedSAM vs. 56.93% for SurgicalSAM. Ablation studies demonstrate the benefits of dual-branch prompts, diffusion, and uncertainty weighting (Huang et al., 5 Feb 2025).
7. Strengths, Limitations, and Future Directions
AutoMedSAM delivers a fully automated, semantically labeled segmentation workflow, eliminating manual prompt annotation and enabling class-aware mask generation. The dual-branch diffusion encoder captures both global and local context, and uncertainty weighting harmonizes joint optimization. Nevertheless, computational overhead from diffusion steps is nontrivial, and current deployments require a predefined class index set, precluding open-vocabulary extension. There may be performance degradation on extremely small or highly noisy structures. Future work will target lightweight diffusion models, open-set recognition, and scaling to 3D volumetric data.
AutoMedSAM establishes a state-of-the-art, prompt-free, and semantically explicit segmentation paradigm for clinical and non-expert end users (Huang et al., 5 Feb 2025).