Efficient MaskFormer with Prototype-based Attention
- The paper introduces a novel prototype-based cross-attention mechanism that reduces computational complexity by collapsing key tokens into representative prototypes.
- It employs an efficient multi-scale pixel decoder with context-based self-modulation and deformable convolutions to significantly lower FLOPs and latency.
- PEM achieves competitive segmentation accuracy on datasets like Cityscapes and ADE20K while being resource-efficient, making it ideal for edge deployment.
Prototype-based Efficient MaskFormer (PEM) is a transformer-based architecture for image segmentation designed to unify semantic and panoptic segmentation under a single, computationally efficient framework. Unlike prior segmenters that rely on heavy cross-attention and resource-intensive pixel decoders, PEM introduces a novel prototype-based cross-attention mechanism and a multi-scale decoder leveraging context-based self-modulation and deformable convolutions. These methodological refinements deliver substantial gains in computational efficiency without sacrificing segmentation accuracy, establishing PEM as a high-performing alternative suitable for both research and deployment on edge devices (Cavagnero et al., 29 Feb 2024).
1. Architectural Foundation and Workflow
PEM generalizes the MaskFormer paradigm by integrating two core modules to improve both efficacy and efficiency:
- Efficient Multi-Scale Pixel Decoder: A fully-convolutional, feature pyramid-based decoder incorporating context-based self-modulation (CSM) and deformable convolutions. CSM injects adaptable global context, while deformable convolutions permit spatially dynamic receptive fields, increasing semantic expressivity at each scale.
- Prototype-based Masked Cross-Attention (PEM-CA): In the transformer decoder, PEM-CA replaces standard cross-attention with a prototype selection mechanism. It refines a set of object queries using only the most representative prototypes, greatly reducing attention computation.
The pipeline proceeds as follows: an input image is processed by a backbone (e.g., ResNet50), and the resulting features are decoded by the FPN-augmented pixel decoder into high-resolution, multi-scale representations. These features interact with learnable queries in the transformer decoder equipped with PEM-CA, yielding refined mask and class predictions. Using identical heads and loss functions for semantic and panoptic segmentation provides a unified multi-task model.
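To make the workflow concrete, below is a minimal PyTorch-style sketch of this forward pass. The class `PEMSketch` and the placeholder `backbone`, `pixel_decoder`, and `transformer_decoder` modules are illustrative assumptions, not the authors' implementation; default sizes such as the hidden dimension of 256 are likewise assumed.

```python
import torch
import torch.nn as nn

class PEMSketch(nn.Module):
    """Illustrative MaskFormer-style pipeline: backbone -> pixel decoder -> query decoder -> heads."""
    def __init__(self, backbone, pixel_decoder, transformer_decoder,
                 num_queries=100, hidden_dim=256, num_classes=19):
        super().__init__()
        self.backbone = backbone                        # e.g., ResNet50 returning multi-scale features
        self.pixel_decoder = pixel_decoder              # CSM + deformable FPN (Section 3)
        self.transformer_decoder = transformer_decoder  # stack of PEM-CA blocks (Section 2)
        self.queries = nn.Embedding(num_queries, hidden_dim)
        self.class_head = nn.Linear(hidden_dim, num_classes + 1)  # +1 for the "no object" class
        self.mask_head = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, images):
        feats = self.backbone(images)                         # multi-scale backbone features
        pixel_feats, mask_feats = self.pixel_decoder(feats)   # decoder features + high-res mask features
        q = self.queries.weight.unsqueeze(0).expand(images.size(0), -1, -1)
        q = self.transformer_decoder(q, pixel_feats)          # queries refined via PEM-CA
        class_logits = self.class_head(q)                     # (B, N, num_classes + 1)
        mask_embed = self.mask_head(q)                        # (B, N, hidden_dim)
        # mask logits: dot product between query embeddings and high-resolution pixel features
        mask_logits = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_feats)
        return class_logits, mask_logits
```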
2. Prototype-Based Masked Cross-Attention (PEM-CA)
PEM-CA introduces a selective cross-attention mechanism that leverages foreground feature redundancy to reduce computational cost. Let $F_l$ denote the pixel-decoder feature map at scale $l$, flattened into $H_lW_l$ tokens, and $Q$ the $N$ learnable object queries:
- Key and Query Projection: The flattened features and the object queries are linearly projected into a shared embedding space, yielding keys $K \in \mathbb{R}^{H_lW_l \times d}$ and projected queries $\hat{Q} \in \mathbb{R}^{N \times d}$.
- Prototype Selection: Similarity scores $S = \hat{Q}K^\top \in \mathbb{R}^{N \times H_lW_l}$ are computed. With an additive binary mask $M$ (foreground: $0$, background: $-\infty$), the maximal foreground match is identified per query, $\pi(i) = \arg\max_j (S + M)_{ij}$. The gathered prototypes $P = K_\pi \in \mathbb{R}^{N \times d}$ represent distinct foreground keys.
- Attention Computation: Element-wise attention between each query and its prototype is normalized and modulated, $A = \mathrm{softmax}(\hat{Q} \odot P)\,\gamma$ with $\gamma$ a learnable scale, and mapped back to the output space with a residual connection, $Q_{\mathrm{out}} = Q + AW_o$.
By collapsing the keys to $N$ prototypes instead of $H_lW_l$ tokens, PEM-CA reduces the attention computation from scaling with $N \cdot H_lW_l$ to scaling with $N$, where $N \ll H_lW_l$. Empirically, on Cityscapes this optimization markedly accelerates the cross-attention stage and roughly halves overall model latency (Cavagnero et al., 29 Feb 2024).
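A minimal sketch of PEM-CA under the formulation above, assuming a per-query boolean foreground mask and softmax normalization over channels; module and tensor names are illustrative, and details may differ from the paper's implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PrototypeCrossAttention(nn.Module):
    """Prototype-based masked cross-attention: one foreground prototype per query."""
    def __init__(self, dim):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.proj_out = nn.Linear(dim, dim)
        self.gamma = nn.Parameter(torch.ones(dim))             # learnable modulation

    def forward(self, queries, feats, fg_mask):
        # queries: (B, N, C) object queries; feats: (B, HW, C) flattened pixel features
        # fg_mask: (B, N, HW) boolean, True where a pixel is foreground for a query
        q = self.to_q(queries)                                 # (B, N, C)
        k = self.to_k(feats)                                   # (B, HW, C)
        sim = torch.einsum("bnc,bmc->bnm", q, k)               # query-key similarities
        sim = sim.masked_fill(~fg_mask, float("-inf"))         # background: -inf, foreground: unchanged
        idx = sim.argmax(dim=-1)                               # (B, N): best foreground key per query
        protos = torch.gather(                                 # gather one prototype per query: (B, N, C)
            k, 1, idx.unsqueeze(-1).expand(-1, -1, k.size(-1)))
        attn = F.softmax(q * protos, dim=-1) * self.gamma      # element-wise attention, channel-normalized
        return queries + self.proj_out(attn)                   # residual connection
```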
3. Multi-Scale Pixel Decoder: Context-Based Self-Modulation and Deformable Convolutions
The PEM pixel decoder operates across four spatial scales. For each:
- CSM: Each scale's feature map $F_l$ is globally aggregated via average pooling into a context vector $g_l$. The modulated features $\tilde{F}_l = F_l \odot \sigma(g_l)$, with sigmoid $\sigma$, inject adaptive global context per channel.
- Deformable Feature Pyramid: Fused features are constructed recursively with deformable convolutions ($\mathrm{DCN}$) and bilinear upsampling ($\mathrm{Up}$): the coarsest level is $P_4 = \mathrm{DCN}(\tilde{F}_4)$, and each finer level is $P_l = \mathrm{DCN}(\tilde{F}_l + \mathrm{Up}(P_{l+1}))$ for $l < 4$. The full-resolution output $P_1$ is used directly for mask prediction.
CSM augments semantic capacity by introducing instance-dependent channel weighting, while deformable convolutions provide spatial adaptivity. Both mechanisms, validated by ablation, are essential for retaining accuracy in the efficient PEM pipeline.
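The snippet below is an illustrative sketch of one such decoder stage, pairing a CSM-style channel gate with a deformable 3x3 convolution from `torchvision.ops.DeformConv2d`; the offset-prediction layer and the exact fusion order are assumptions, not the paper's published code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import DeformConv2d

class CSM(nn.Module):
    """Context-based self-modulation: global aggregation followed by a sigmoid channel gate."""
    def __init__(self, channels):
        super().__init__()
        self.fc = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        ctx = F.adaptive_avg_pool2d(x, 1)        # global context: (B, C, 1, 1)
        return x * torch.sigmoid(self.fc(ctx))   # adaptive per-channel modulation

class DecoderStage(nn.Module):
    """Fuse a CSM-modulated lateral feature with the upsampled coarser level,
    then refine with a deformable 3x3 convolution."""
    def __init__(self, channels):
        super().__init__()
        self.csm = CSM(channels)
        self.offset = nn.Conv2d(channels, 2 * 3 * 3, kernel_size=3, padding=1)  # DCN offsets
        self.dcn = DeformConv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, lateral, coarser=None):
        x = self.csm(lateral)
        if coarser is not None:                  # bilinear upsampling of the coarser pyramid level
            x = x + F.interpolate(coarser, size=x.shape[-2:],
                                  mode="bilinear", align_corners=False)
        return self.dcn(x, self.offset(x))       # spatially adaptive receptive field
```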
4. Computational Analysis and Efficiency
PEM achieves a significant reduction in floating-point operations (FLOPs) and inference latency compared to standard MaskFormer and Mask2Former architectures. For cross-attention:
- Standard cross-attention: FLOPs scale with $N \cdot H_lW_l$ (every query attends to every pixel token).
- PEM-CA: FLOPs scale with $N$ (every query attends only to its selected prototype), with $N \ll H_lW_l$.
In practice, PEM requires 237G total FLOPs on Cityscapes (panoptic, ResNet50), compared with 519G for Mask2Former, and model latency is nearly halved. Overall, speedup factors approach 2x, and the efficiency gains grow with input resolution and feature-map size.
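For intuition, the snippet below counts query-key interactions under the scaling stated above, using hypothetical sizes (a 64x128 feature map and 100 queries) rather than figures from the paper; it ignores constant factors and the cost of the projection and prototype-selection steps.

```python
# Hypothetical sizes, not values taken from the paper.
N, H, W = 100, 64, 128              # object queries, feature-map height and width

standard_interactions = N * H * W   # every query attends to every pixel token
pem_ca_interactions = N             # every query attends to a single selected prototype

print(f"standard cross-attention: {standard_interactions:,} query-key interactions")
print(f"PEM-CA:                   {pem_ca_interactions:,} query-prototype interactions")
print(f"reduction factor:         {standard_interactions // pem_ca_interactions}x (= H*W)")
```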
5. Experimental Setup and Hyperparameters
PEM was evaluated on semantic and panoptic segmentation using the Cityscapes (19 classes) and ADE20K (150 classes) datasets. Backbones include ImageNet-pretrained ResNet50 and STDC1/STDC2. The transformer decoder uses two stages with eight attention heads, six layers, and 100 object queries; the pixel decoder hidden dimension is 128. Training employed the AdamW optimizer with weight decay 0.05 and a cosine learning-rate schedule:
- Cityscapes: lr = , batch = 32, 90k iterations.
- ADE20K: lr = , batch = 32, 160k iterations.
Losses are applied with deep supervision at each decoder block: classification (BCE, weight 2.0) and mask (BCE, weight 5.0; Dice, weight 5.0) (Cavagnero et al., 29 Feb 2024).
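A sketch of how the stated loss weights could be combined for a single decoder block, assuming predictions have already been matched to their targets (the matching step is omitted); `dice_loss` and `pem_block_loss` are illustrative helpers, and in training this sum would be accumulated over every deeply supervised block.

```python
import torch.nn.functional as F

def dice_loss(mask_logits, mask_target, eps=1.0):
    """Soft Dice loss between per-query mask logits (Q, H, W) and binary mask targets."""
    pred = mask_logits.sigmoid().flatten(1)
    target = mask_target.flatten(1)
    inter = (pred * target).sum(-1)
    union = pred.sum(-1) + target.sum(-1)
    return (1 - (2 * inter + eps) / (union + eps)).mean()

def pem_block_loss(class_logits, class_target, mask_logits, mask_target,
                   w_cls=2.0, w_bce=5.0, w_dice=5.0):
    """Weighted sum of classification BCE, mask BCE, and mask Dice for one decoder block."""
    loss_cls = F.binary_cross_entropy_with_logits(class_logits, class_target)   # BCE classification
    loss_bce = F.binary_cross_entropy_with_logits(mask_logits, mask_target)     # per-pixel mask BCE
    loss_dice = dice_loss(mask_logits, mask_target)                             # region-level Dice
    return w_cls * loss_cls + w_bce * loss_bce + w_dice * loss_dice
```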
6. Quantitative Results and Model Comparison
Performance metrics are benchmarked against Mask2Former and YOSO across both tasks and datasets. The tables below summarize core findings:
Panoptic Segmentation (PQ, Cityscapes and ADE20K)
| Model | PQ (Cityscapes) | FPS | FLOPs (G) | Params (M) | PQ (ADE20K) | FPS | FLOPs (G) | Params (M) |
|---|---|---|---|---|---|---|---|---|
| Mask2Former (R50) | 62.1 | 4.1 | 519 | 44.0 | 39.7 | 19.5 | 103 | 44.0 |
| YOSO (R50) | 59.7 | 11.1 | 265 | 42.6 | 38.0 | 35.4 | 52 | 42.0 |
| PEM (R50) | 61.1 | 13.8 | 237 | 35.6 | 38.5 | 35.7 | 47 | 35.6 |
Semantic Segmentation (mIoU, Cityscapes and ADE20K)
| Model | mIoU (Cityscapes) | FPS | FLOPs (G) | Params (M) | mIoU (ADE20K) | FPS | FLOPs (G) | Params (M) |
|---|---|---|---|---|---|---|---|---|
| Mask2Former (R50) | 79.4% | 6.6 | 523 | 44.0 | 47.2% | 21.5 | 70.1 | 44.0 |
| YOSO (R50) | 79.4% | 11.1 | 268 | 42.6 | 44.7% | 35.3 | 37.3 | 42.0 |
| PEM (R50) | 79.9% | 13.8 | 240 | 35.6 | 45.5% | 35.7 | 46.9 | 35.6 |
| PEM (STDC1/2) | 78.3–79.0% | 22–24 | 92–118 | 17–21 | 39.6–45.0% | 36–44 | 16–19.3 | 17–21 |
PEM achieves comparable or superior accuracy to leading, resource-intensive baselines at markedly greater throughput and lower computational footprint. For Cityscapes, PEM attains 79.9% mIoU and 61.1 PQ at 13.8 FPS with just 240G FLOPs and 35.6M parameters when using ResNet50. Notably, faster, lighter backbones (STDC1/2) yield high FPS (up to 24.3), albeit with some accuracy tradeoff.
7. Ablation Studies and Method Component Analysis
Comprehensive ablation studies isolate the contributions of PEM’s architectural components:
- Attention Mechanism: The full PEM-CA yields 61.1 PQ, with masking and prototype selection contributing +2.6 and +12.4 PQ, respectively, over their ablated baselines.
- Pixel Decoder: Removing CSM or the deformable convolutions substantially reduces PQ (to 60.0 and 57.1, respectively), in exchange for lower latency.
- Queries and Decoder Layers: The optimal number of queries is 100, with negligible gains beyond it; increasing the number of decoder layers from 0 to 6 improves PQ (from 58 to 61) at the cost of additional latency.
This suggests that the synergy between prototype selection and contextual modulation is critical for balancing efficiency and segmentation accuracy in PEM.
Summary and Contextual Significance
PEM demonstrates that prototype-based cross-attention and efficient multi-scale decoding are effective strategies to reduce computational demands while maintaining state-of-the-art segmentation accuracy across semantic and panoptic tasks. Its unified pipeline enables deployment on resource-constrained systems and supports rapid inference, challenging the dominance of heavier transformer-based segmenters in multi-task scenarios (Cavagnero et al., 29 Feb 2024). A plausible implication is that PEM’s methodology may influence future transformer architectures for vision tasks, particularly where efficiency and multi-task learning are priorities.