SLIM-Diff: Compact Joint Diffusion for Imaging
- SLIM-Diff is a joint image-mask diffusion framework that uses a shared-bottleneck U-Net to couple anatomical and lesion features in data-scarce settings.
- It employs a tunable Lp loss to balance image realism and mask sharpness, outperforming standard diffusion models in FLAIR MRI augmentation.
- The framework reduces parameter count significantly while enabling simultaneous synthesis of high-fidelity FLAIR images and corresponding lesion masks for rare neurological disorders.
SLIM-Diff denotes a compact joint image–mask diffusion framework designed for data-scarce medical imaging regimes, particularly for epilepsy-focused FLAIR MRI. Its core contributions are (i) a single shared-bottleneck U-Net enforcing coupling between anatomy and lesion geometry through a 2-channel representation, and (ii) explicit loss-geometry tuning via a tunable objective. Unlike standard large-scale diffusion architectures, SLIM-Diff is tailored for low-data settings and the simultaneous synthesis of both FLAIR images and corresponding lesion masks, targeting robust data augmentation and generative segmentation for rare disorders such as focal cortical dysplasia (FCD) (Pascual-González et al., 3 Feb 2026).
1. Background and Motivation
Medical imaging datasets for rare neurological disorders, such as FCD type II, are highly limited, typically yielding no more than 80 lesion-positive samples in public collections. This scarcity, combined with the subtle and highly localized nature of lesion morphology on FLAIR MRI, triggers instability and overfitting in high-capacity generative models (e.g., canonical Stable Diffusion or LDM U-Nets with 860M parameters). The simultaneous generation of anatomically plausible images and spatially consistent lesion masks can enable task-specific data augmentation for segmentation and classification pipelines—but only if the generative process faithfully preserves the structure of both modalities and does not merely memorize the limited training set.
SLIM-Diff addresses these requirements by adopting a tightly-coupled architecture (shared-bottleneck U-Net), reducing parameter count by more than an order of magnitude, and optimizing with a generalized loss geometry rather than canonical regression (Pascual-González et al., 3 Feb 2026).
2. Mathematical Formulation
SLIM-Diff implements the discrete-time DDPM framework ($T$ diffusion steps, cosine noise schedule) for a joint image–mask tensor $x_0 \in \mathbb{R}^{2 \times H \times W}$. The forward process adds Gaussian noise, $x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$. The reverse process, parameterized by a U-Net $f_\theta$, targets one of three prediction parameterizations:
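The forward process and cosine schedule can be sketched in NumPy; the constants $T = 1000$ and $s = 0.008$ are illustrative defaults from the original cosine-schedule formulation, not values confirmed by the paper:

```python
import numpy as np

def cosine_alpha_bar(T, s=0.008):
    """Cumulative signal coefficients abar_t for the cosine noise schedule
    (illustrative constants; abar_0 = 1, decreasing toward ~0)."""
    t = np.arange(T + 1)
    f = np.cos(((t / T) + s) / (1 + s) * np.pi / 2) ** 2
    return f / f[0]

def q_sample(x0, t, alpha_bar, rng):
    """Forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps."""
    eps = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1 - alpha_bar[t]) * eps, eps

abar = cosine_alpha_bar(1000)
rng = np.random.default_rng(0)
x0 = rng.standard_normal((2, 64, 64))   # joint image-mask tensor, 2 channels
xt, eps = q_sample(x0, 500, abar, rng)
```

Note that the 2-channel layout mirrors the joint image–mask tensor: noise is added to image and mask simultaneously, which is what couples the two modalities during denoising.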
- $\epsilon$-prediction (standard; predicts the added noise)
- $v$-prediction (velocity, $v = \sqrt{\bar\alpha_t}\,\epsilon - \sqrt{1 - \bar\alpha_t}\, x_0$)
- direct $x_0$-prediction (clean image + mask)
The loss is a general $L_p$ norm over the prediction target, $\mathcal{L} = \lVert y - f_\theta(x_t, t, c) \rVert_p^p$, where $y$ denotes the chosen target and the exponent $p$ is tunable, with the specific choice trading image fidelity against mask morphology.
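The three regression targets and the generalized $L_p$ objective can be written out in a minimal sketch (function names are illustrative, not the paper's API):

```python
import numpy as np

def prediction_target(kind, x0, eps, alpha_bar_t):
    """Regression target for each parameterization (standard DDPM / v-pred
    definitions)."""
    if kind == "eps":
        return eps
    if kind == "v":   # velocity: v = sqrt(abar_t)*eps - sqrt(1 - abar_t)*x0
        return np.sqrt(alpha_bar_t) * eps - np.sqrt(1 - alpha_bar_t) * x0
    if kind == "x0":
        return x0
    raise ValueError(kind)

def lp_loss(pred, target, p):
    """Generalized L_p objective: mean |target - pred|^p."""
    return np.mean(np.abs(target - pred) ** p)

# Sub-quadratic example: p = 1.5 down-weights large residuals relative to p = 2.
loss = lp_loss(np.zeros((2, 8, 8)), np.ones((2, 8, 8)), p=1.5)
```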
3. Shared-Bottleneck U-Net Architecture
The SLIM-Diff U-Net is strictly single-stream: the FLAIR image and binary lesion mask are concatenated as a 2-channel input. The encoder consists of four downsampling levels, each with two residual blocks using GroupNorm (32 groups) and strided convolutions for spatial reduction. Multihead self-attention (32-dim heads) is included at the two deepest scales. A shared bottleneck compresses both modalities, imposing an information choke point that enforces representational coupling.
The decoder mirrors the encoder. Skip connections preserve spatial resolution, and a final convolution predicts both modalities. Conditioning is supplied via:
- Axial position + pathology indicator (60 discrete tokens: 30 axial bins × {control, lesion}), embedded via a learnable embedding and sinusoidal positional encoding.
- Timestep encoding (sinusoidal + 2-layer MLP).
- Both are injected into ResBlocks via FiLM-style bias modulation after the first convolution.
The total parameter count is 26.9M, a small fraction of that of large-scale diffusion models.
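The conditioning path can be illustrated as follows; the bin-major token layout, embedding width, and projection matrix are hypothetical stand-ins for the paper's learned embedding and FiLM bias head:

```python
import numpy as np

rng = np.random.default_rng(0)
N_BINS, N_PATH, D = 30, 2, 128   # 30 axial bins x {control, lesion} = 60 tokens

emb_table = rng.standard_normal((N_BINS * N_PATH, D)) * 0.02

def condition_token(axial_bin, is_lesion):
    """Flatten (axial bin, pathology flag) into one of 60 discrete tokens.
    The bin-major layout is an assumption for illustration."""
    return axial_bin * N_PATH + is_lesion

def film_bias(h, cond_emb, W):
    """FiLM-style bias modulation: project the condition embedding to a
    per-channel bias and add it after the block's first convolution."""
    bias = cond_emb @ W                 # shape (C,)
    return h + bias[:, None, None]      # broadcast over spatial dims

C = 64
W = rng.standard_normal((D, C)) * 0.02
h = rng.standard_normal((C, 32, 32))    # feature map after the first conv
tok = condition_token(axial_bin=12, is_lesion=1)
h_mod = film_bias(h, emb_table[tok], W)
```

Bias-only modulation is the lightest FiLM variant; richer variants also predict a per-channel scale, but the text above describes bias injection only.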
4. Loss Geometry and Empirical Analysis
SLIM-Diff replaces the canonical $\epsilon$-prediction loss with explicit tuning over both the prediction target and the exponent $p$:
- For $x_0$-prediction, $p < 2$ (sub-quadratic) down-weights large residuals (e.g., hyperintense lesion pixels), improving global image realism (measured by KID and LPIPS).
- $p = 2$ gives optimal mask boundary sharpness (lowest MMD-MF), as the uniform quadratic penalty preserves subtle geometry.
- $p > 2$ over-penalizes outliers, degrading performance.
- Across all $p$, $x_0$-prediction significantly outperforms $\epsilon$-prediction and velocity-based targets in both image and mask metrics.
Table: Summary of best/typical configurations
| Target / $p$ | KID | LPIPS | MMD-MF |
|---|---|---|---|
| $\epsilon$-prediction | 0.432 | 0.821 | 15.06 |
| $x_0$, $p < 2$ | 0.012 | 0.305 | 1.43 |
| $x_0$, $p = 2$ | 0.034 | 0.310 | 0.95 |
Qualitative samples with $p < 2$ exhibit realistic FLAIR contrast with coherent lesion masks; $p = 2$ sharpens mask boundaries at the cost of slightly over-smoothed intensities.
No evidence of pixel-wise memorization was found; Kernel Maximum Mean Discrepancy tests confirmed distributional (not copy-based) generation.
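A generic kernel-MMD check of this kind can be sketched as follows; the RBF bandwidth and estimator details are illustrative, not the paper's exact protocol:

```python
import numpy as np

def rbf_mmd2(X, Y, sigma=1.0):
    """Squared maximum mean discrepancy with an RBF kernel between two
    sample sets (rows are flattened images). Values near zero indicate
    the two sets are distributionally similar; large values indicate a
    mismatch (or, between train and synthetic sets, potential copying)."""
    def k(A, B):
        d2 = (np.sum(A**2, 1)[:, None] + np.sum(B**2, 1)[None, :]
              - 2 * A @ B.T)
        return np.exp(-d2 / (2 * sigma**2))
    Kxx, Kyy, Kxy = k(X, X), k(Y, Y), k(X, Y)
    n, m = len(X), len(Y)
    # Unbiased diagonal-removed estimator for the within-set terms.
    return ((Kxx.sum() - np.trace(Kxx)) / (n * (n - 1))
            + (Kyy.sum() - np.trace(Kyy)) / (m * (m - 1))
            - 2 * Kxy.mean())
```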
5. Training and Sampling Procedures
The network is trained with AdamW under cosine learning-rate annealing, weight EMA (0.999 decay), and early stopping (validation patience of 25 epochs). Preprocessing includes MNI registration, skull-stripping, N4 bias correction, ANTs SyN registration, and percentile normalization. Data is organized at the subject and slice level, with lesion oversampling to counteract class imbalance.
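The lesion-oversampling step can be sketched as inverse-frequency sampling weights; this generic scheme is an assumption for illustration, since the paper's exact weighting is not reproduced here:

```python
import numpy as np

def oversampling_weights(is_lesion):
    """Per-slice sampling weights that equalize the expected number of
    lesion and control draws (inverse class frequency, normalized)."""
    is_lesion = np.asarray(is_lesion, dtype=bool)
    n_pos, n_neg = is_lesion.sum(), (~is_lesion).sum()
    w = np.where(is_lesion, 1.0 / n_pos, 1.0 / n_neg)
    return w / w.sum()

labels = [0] * 90 + [1] * 10          # 10% lesion-positive slices
w = oversampling_weights(labels)
# lesion slices collectively receive the same sampling mass as controls
```

In a PyTorch pipeline these weights would typically feed a weighted random sampler over the slice index.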
Sampling employs DDIM (300 steps), conditioned on the desired axial depth and pathology indicator. Both training and inference require explicit condition tokens.
Pseudocode (Algorithm 1—Training):

```
for epoch in 1...max_epochs:
    for (x0, c) in minibatch:
        t ~ Uniform(1, T)
        eps ~ N(0, I)
        xt = sqrt(bar_alpha_t)*x0 + sqrt(1-bar_alpha_t)*eps
        output = f_theta(xt, t, c)
        L = ||target(x0, eps, t) - output||_p^p
        # backpropagation and parameter update
        # EMA update
        # early-stopping check
```
Pseudocode (Algorithm 2—Sampling):

```
x_S ~ N(0, I)
for s in S...1:
    t = schedule(s)
    out = f_theta_bar(x_s, t, c)
    x_{s-1} = DDIM_update(x_s, out, t, eta)
return x0_image, x0_mask
```
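The DDIM update used in sampling admits a compact NumPy sketch; this is the standard DDIM step, and `eps_hat` assumes the network output has already been converted to a noise estimate (a conversion needed for the $x_0$ and $v$ parameterizations):

```python
import numpy as np

def ddim_step(x_t, eps_hat, abar_t, abar_prev, eta=0.0, rng=None):
    """One DDIM update: deterministic for eta = 0, stochastic otherwise."""
    # Predicted clean sample from the noise estimate.
    x0_hat = (x_t - np.sqrt(1 - abar_t) * eps_hat) / np.sqrt(abar_t)
    # Stochasticity level sigma_t, scaled by eta.
    sigma = (eta * np.sqrt((1 - abar_prev) / (1 - abar_t))
                 * np.sqrt(1 - abar_t / abar_prev))
    # Direction pointing back toward x_t.
    dir_coeff = np.sqrt(np.maximum(1 - abar_prev - sigma**2, 0.0))
    x_prev = np.sqrt(abar_prev) * x0_hat + dir_coeff * eps_hat
    if eta > 0 and rng is not None:
        x_prev = x_prev + sigma * rng.standard_normal(x_t.shape)
    return x_prev
```

With $\eta = 0$ the trajectory is fully deterministic given the initial noise, which is why DDIM can use far fewer steps (here 300) than the training-time discretization.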
6. Limitations and Applicability
SLIM-Diff operates on 2D slices, potentially introducing slight inter-slice inconsistencies. No direct comparison is provided against multi-stream and multi-stage joint generative frameworks (such as MedSegFactory or brainSPADE) under matched data conditions. Extension to full 3D or enforcing explicit slice-to-volume consistency is suggested as future work. For clinical deployment, care must be taken to evaluate robustness across lesion types, scanners, and domains; fine-tuning may be necessary.
Memory and computational requirements are modest, supporting real-time slice-level synthesis on standard GPUs.
7. Code and Reproducibility
The official implementation and pretrained models are available at https://github.com/MarioPasc/slim-diff. To reproduce, preprocessing (registration, skull stripping, bias correction), slice extraction, and balanced splitting are required. Training proceeds with configurable --target and --p_norm options; at convergence, the sample.py script enables generation under specified anatomical or pathological conditions.
In summary, SLIM-Diff establishes a methodological foundation for joint image–mask diffusion in data-scarce medical regimes, demonstrating that shared low-capacity architectures with explicit loss-geometry control enable robust synthesis of anatomically faithful images and masks while resisting memorization, and supporting downstream data augmentation for rare-disease imaging research (Pascual-González et al., 3 Feb 2026).