Efficient MaskFormer with Prototype-based Attention

Updated 12 December 2025
  • The paper introduces a novel prototype-based cross-attention mechanism that reduces computational complexity by collapsing key tokens into representative prototypes.
  • It employs an efficient multi-scale pixel decoder with context-based self-modulation and deformable convolutions to significantly lower FLOPs and latency.
  • PEM achieves competitive segmentation accuracy on datasets like Cityscapes and ADE20K while being resource-efficient, making it ideal for edge deployment.

Prototype-based Efficient MaskFormer (PEM) is a transformer-based architecture for image segmentation designed to unify semantic and panoptic segmentation tasks under a single, computationally-efficient framework. Unlike prior segmenters requiring heavy cross-attention and resource-intensive pixel decoders, PEM introduces a novel prototype-based cross-attention mechanism and a multi-scale decoder leveraging context-based self-modulation and deformable convolutions. These methodological refinements deliver substantial gains in computational efficiency without sacrificing segmentation accuracy, establishing PEM as a high-performing alternative suitable for both research and deployment on edge devices (Cavagnero et al., 29 Feb 2024).

1. Architectural Foundation and Workflow

PEM generalizes the MaskFormer paradigm by integrating two core modules to improve both efficacy and efficiency:

  • Efficient Multi-Scale Pixel Decoder: A fully-convolutional, feature pyramid-based decoder incorporating context-based self-modulation (CSM) and deformable convolutions. CSM injects adaptable global context, while deformable convolutions permit spatially dynamic receptive fields, increasing semantic expressivity at each scale.
  • Prototype-based Masked Cross-Attention (PEM-CA): In the transformer decoder, PEM-CA replaces standard cross-attention with a prototype selection mechanism. It refines a set of $N$ object queries using only the most representative prototypes, greatly reducing attention computation.

The pipeline proceeds as follows: an input image is processed by a backbone (e.g., ResNet50), then features are decoded via the FPN-augmented pixel decoder into high-resolution, multi-scale representations. These features interact with $N$ learnable queries in the transformer decoder equipped with PEM-CA, yielding refined mask and class predictions. The use of identical heads and loss functions for semantic and panoptic segmentation provides a unified multi-task model.
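To make the workflow concrete, the following is a minimal, runnable PyTorch sketch of the MaskFormer-style structure described above. It is not the authors' implementation: stub modules stand in for the backbone features and pixel decoder, a standard transformer decoder stands in for the PEM-CA decoder, and all module and parameter names (e.g., `PixelDecoderStub`, `PEMSkeleton`) are illustrative assumptions.

```python
# Sketch of the PEM workflow (illustrative, not the reference code): backbone
# features are decoded into multi-scale maps, N learnable queries are refined
# by a transformer decoder, and mask logits come from a dot product between
# query embeddings and the high-resolution pixel features.
import torch
import torch.nn as nn

class PixelDecoderStub(nn.Module):
    """Stand-in for PEM's multi-scale pixel decoder (CSM + deformable convs)."""
    def __init__(self, in_dims=(256, 512, 1024, 2048), dim=256):
        super().__init__()
        self.lateral = nn.ModuleList(nn.Conv2d(c, dim, 1) for c in in_dims)

    def forward(self, feats):                   # feats: 4 maps, fine to coarse
        outs = [l(f) for l, f in zip(self.lateral, feats)]
        return outs[0], outs[1:]                # high-res map, coarser maps

class PEMSkeleton(nn.Module):
    def __init__(self, num_queries=100, dim=256, num_classes=19):
        super().__init__()
        self.queries = nn.Embedding(num_queries, dim)
        self.pixel_decoder = PixelDecoderStub(dim=dim)
        layer = nn.TransformerDecoderLayer(dim, nhead=8, batch_first=True)
        self.decoder = nn.TransformerDecoder(layer, num_layers=6)  # PEM-CA would replace cross-attn here
        self.cls_head = nn.Linear(dim, num_classes + 1)            # +1 for "no object"

    def forward(self, feats):
        hi_res, coarse = self.pixel_decoder(feats)
        b = hi_res.shape[0]
        mem = torch.cat([f.flatten(2).transpose(1, 2) for f in coarse], dim=1)
        q = self.decoder(self.queries.weight.unsqueeze(0).expand(b, -1, -1), mem)
        cls_logits = self.cls_head(q)                               # [B, N, classes+1]
        mask_logits = torch.einsum("bnc,bchw->bnhw", q, hi_res)     # [B, N, H, W]
        return cls_logits, mask_logits

# Usage with dummy ResNet50-like feature maps:
feats = [torch.randn(1, c, s, s) for c, s in [(256, 64), (512, 32), (1024, 16), (2048, 8)]]
cls_logits, mask_logits = PEMSkeleton()(feats)
```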

2. Prototype-Based Masked Cross-Attention (PEM-CA)

PEM-CA introduces a selective cross-attention mechanism that leverages foreground feature redundancy to reduce computational costs. Let $\mathbf{F}_i \in \mathbb{R}^{H_i \times W_i \times C}$ be the feature map at scale $i$ and $Q_{\mathrm{in}} \in \mathbb{R}^{N \times C}$ denote the $N$ learnable object queries:

  1. Key and Query Projection: Features are flattened and projected: $K = \mathrm{Proj}_K(\mathrm{Flatten}(\mathbf{F}_i)) \in \mathbb{R}^{(H_i W_i)\times D}$ and $Q = \mathrm{Proj}_Q(Q_{\mathrm{in}}) \in \mathbb{R}^{N\times D}$.
  2. Prototype Selection: Similarity scores $S \in \mathbb{R}^{(H_i W_i)\times N}$ are computed as $S = K Q^{\top}$. With a binary mask $\mathcal{M}$ (foreground: $0$, background: $-\infty$), the maximal foreground match is identified for each query $j$: $G_j = \arg\max_{p}\,[S_{p,j} + \mathcal{M}_{p,j}]$. The gathered prototypes $K_p = K[G] \in \mathbb{R}^{N\times D}$ represent $N$ distinct foreground keys.
  3. Attention Computation: Element-wise attention $A = (Q \odot K_p) W_A$ is normalized and modulated: $\hat{A} = A / \|A\|_2$ and $B = \alpha \odot (\hat{A} + K_p)$ (with learnable $\alpha$), then mapped to the output space with a residual connection: $Q_{\mathrm{out}} = B W_{\mathrm{out}} + Q$.

By collapsing keys to $N$ prototypes instead of $H_i W_i$ tokens, PEM-CA reduces the cross-attention complexity from $O(NKD)$ to $O(ND)$, where $K = H_i W_i$. Empirically, for Cityscapes at scale $\mathbf{F}_2$, this optimization delivers a $2\times$ speedup and halves overall latency (Cavagnero et al., 29 Feb 2024).
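The three steps above admit a compact, single-scale sketch. The code below is one plausible reading of PEM-CA under the stated formulas, not the reference implementation: multi-head handling and the generation of the foreground mask from intermediate predictions are simplified, and all names (`PrototypeCrossAttention`, `w_a`, `w_out`) are assumptions.

```python
# Single-scale, single-head sketch of prototype-based masked cross-attention.
import torch
import torch.nn as nn
import torch.nn.functional as F

class PrototypeCrossAttention(nn.Module):
    def __init__(self, c=256, d=256):
        super().__init__()
        self.proj_k = nn.Linear(c, d)
        self.proj_q = nn.Linear(c, d)
        self.w_a = nn.Linear(d, d, bias=False)
        self.w_out = nn.Linear(d, d, bias=False)
        self.alpha = nn.Parameter(torch.ones(d))          # learnable modulation

    def forward(self, feat, q_in, fg_mask):
        # feat: [B, H*W, C] flattened features; q_in: [B, N, C] queries;
        # fg_mask: [B, H*W, N] boolean, True where pixel p is foreground for query j.
        K = self.proj_k(feat)                              # [B, HW, D]
        Q = self.proj_q(q_in)                              # [B, N, D]
        S = torch.bmm(K, Q.transpose(1, 2))                # [B, HW, N] similarity S = K Q^T
        S = S.masked_fill(~fg_mask, float("-inf"))         # background positions -> -inf
        G = S.argmax(dim=1)                                # [B, N] best foreground pixel per query
        K_p = torch.gather(K, 1, G.unsqueeze(-1).expand(-1, -1, K.size(-1)))  # prototypes [B, N, D]
        A = self.w_a(Q * K_p)                              # element-wise attention
        A_hat = F.normalize(A, p=2, dim=-1)                # A / ||A||_2
        B_mod = self.alpha * (A_hat + K_p)                 # B = alpha ⊙ (A_hat + K_p)
        return self.w_out(B_mod) + Q                       # residual to the queries

# Toy usage: 100 queries attending over a 64x64 feature map.
attn = PrototypeCrossAttention()
feat = torch.randn(2, 64 * 64, 256)
queries = torch.randn(2, 100, 256)
fg = torch.rand(2, 64 * 64, 100) > 0.5
out = attn(feat, queries, fg)                              # [2, 100, 256]
```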

3. Multi-Scale Pixel Decoder: Context-Based Self-Modulation and Deformable Convolutions

The PEM pixel decoder operates across four spatial scales. For each:

  • CSM: $\mathbf{F}'_i = \mathrm{Conv}_{1\times1}(\mathbf{F}_i)$ is globally aggregated via $\Omega_i = \mathrm{MLP}(\mathrm{GAP}(\mathbf{F}'_i))$. Modulated features $\mathbf{F}^c_i = \mathbf{F}'_i \odot \sigma(\Omega_i) + \mathbf{F}'_i$ (with sigmoid $\sigma$) inject adaptive global context per channel.
  • Deformable Feature Pyramid: Fused features $\hat{\mathbf{F}}_i$ are constructed recursively with deformable convolutions and bilinear upsampling. For $i=4$: $\hat{\mathbf{F}}_4 = \mathrm{DefConv}(\mathbf{F}^c_4 + \mathrm{Proj}(\mathrm{GAP}(\mathbf{F}_4)))$, and for $i=2,3$: $\hat{\mathbf{F}}_i = \mathrm{DefConv}(\mathbf{F}^c_i + \mathrm{Up}(\hat{\mathbf{F}}_{i+1}))$. The full-resolution $\hat{\mathbf{F}}_1$ is directly used for mask prediction.

CSM augments semantic capacity by introducing instance-dependent channel weighting, while deformable convolutions provide spatial adaptivity. Both mechanisms, validated by ablation, are essential for retaining accuracy in the efficient PEM pipeline.
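As an illustration of the CSM step alone (the deformable fusion would typically use an operator such as torchvision.ops.DeformConv2d and is omitted here), the following hedged PyTorch sketch applies the $\mathbf{F}^c_i = \mathbf{F}'_i \odot \sigma(\Omega_i) + \mathbf{F}'_i$ modulation; layer sizes and the MLP shape are assumptions.

```python
# Sketch of context-based self-modulation for one pyramid level.
import torch
import torch.nn as nn

class ContextSelfModulation(nn.Module):
    def __init__(self, in_ch, out_ch=128):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, out_ch, kernel_size=1)            # F'_i
        self.mlp = nn.Sequential(nn.Linear(out_ch, out_ch), nn.ReLU(),
                                 nn.Linear(out_ch, out_ch))             # Omega_i

    def forward(self, x):                      # x: [B, C_in, H, W]
        f = self.proj(x)                       # 1x1 projection
        ctx = f.mean(dim=(2, 3))               # GAP over spatial dims -> [B, C]
        omega = self.mlp(ctx)
        gate = torch.sigmoid(omega)[..., None, None]
        return f * gate + f                    # F^c_i = F'_i ⊙ σ(Ω_i) + F'_i

# Each pyramid level would be modulated like this before deformable fusion:
x = torch.randn(1, 512, 32, 32)
f_c = ContextSelfModulation(512)(x)            # [1, 128, 32, 32]
```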

4. Computational Analysis and Efficiency

PEM achieves a significant reduction in floating-point operations (FLOPs) and inference latency compared to standard MaskFormer and Mask2Former architectures. For cross-attention:

  • Standard cross-attention: $\approx 2NKD$ FLOPs.
  • PEM-CA: $\approx ND$ FLOPs.

In practice, PEM reduces total model FLOPs from $\sim$500G to 237G on Cityscapes, with latency nearly halved. Overall, speedup factors approach $2\times$, and the efficiency gains grow with input resolution and feature-map size.
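A back-of-the-envelope check of the per-layer cross-attention cost illustrates the gap; the $N$, $D$, and feature-map resolution below are illustrative assumptions, not figures from the paper.

```python
# Compare the approximate per-layer cross-attention cost expressions above
# (standard ≈ 2*N*K*D vs PEM-CA ≈ N*D) for illustrative sizes.
N, D = 100, 256
K = 128 * 256                       # key tokens at one feature scale (assumed 128x256 map)

standard_flops = 2 * N * K * D      # query-key scores plus value aggregation
pem_ca_flops = N * D                # one prototype per query

print(f"standard : {standard_flops / 1e9:.2f} GFLOPs")
print(f"PEM-CA   : {pem_ca_flops / 1e6:.3f} MFLOPs")
print(f"reduction: {standard_flops / pem_ca_flops:.0f}x")
```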

5. Experimental Setup and Hyperparameters

PEM was evaluated on semantic and panoptic segmentation across the Cityscapes (19 classes) and ADE20K (150 classes) datasets. Backbone architectures include ResNet50 (ImageNet-pretrained) and STDC1/2. Transformer decoder configuration: two stages, $C = D = 256$, eight heads, six layers, 100 queries. Pixel decoder hidden dimension is 128. Training employed the AdamW optimizer with weight decay 0.05 and cosine learning rate schedules:

  • Cityscapes: learning rate $7\times10^{-4}$, batch size 32, 90k iterations.
  • ADE20K: learning rate $4\times10^{-4}$, batch size 32, 160k iterations.

Losses are deep-supervised at each decoder block: classification (BCE, weight 2.0) and mask (BCE, weight 5.0; Dice, weight 5.0) (Cavagnero et al., 29 Feb 2024).
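A minimal sketch of how these stated loss weights could be combined per decoder block is shown below; it omits the Hungarian matching used in MaskFormer-style training, and the helper names and tensor layouts are assumptions for illustration.

```python
# Combine classification BCE (weight 2.0), mask BCE (weight 5.0), and Dice
# (weight 5.0) for one decoder block; targets are assumed to be float tensors
# already aligned to predictions (matching step omitted).
import torch
import torch.nn.functional as F

def dice_loss(pred_logits, target, eps=1.0):
    p = pred_logits.sigmoid().flatten(1)
    t = target.flatten(1)
    num = 2 * (p * t).sum(-1) + eps
    den = p.sum(-1) + t.sum(-1) + eps
    return (1 - num / den).mean()

def pem_block_loss(cls_logits, cls_targets, mask_logits, mask_targets,
                   w_cls=2.0, w_bce=5.0, w_dice=5.0):
    loss_cls = F.binary_cross_entropy_with_logits(cls_logits, cls_targets)
    loss_bce = F.binary_cross_entropy_with_logits(mask_logits, mask_targets)
    loss_dice = dice_loss(mask_logits, mask_targets)
    return w_cls * loss_cls + w_bce * loss_bce + w_dice * loss_dice

# Deep supervision: sum this loss over every decoder block's predictions.
```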

6. Quantitative Results and Model Comparison

Performance metrics are benchmarked against Mask2Former and YOSO across both tasks and datasets. The tables below summarize core findings:

Panoptic Segmentation (PQ, Cityscapes and ADE20K)

| Model | PQ (Cityscapes) | FPS | FLOPs (G) | Params (M) | PQ (ADE20K) | FPS | FLOPs (G) | Params (M) |
|---|---|---|---|---|---|---|---|---|
| Mask2Former (R50) | 62.1 | 4.1 | 519 | 44.0 | 39.7 | 19.5 | 103 | 44.0 |
| YOSO (R50) | 59.7 | 11.1 | 265 | 42.6 | 38.0 | 35.4 | 52 | 42.0 |
| PEM (R50) | 61.1 | 13.8 | 237 | 35.6 | 38.5 | 35.7 | 47 | 35.6 |

Semantic Segmentation (mIoU, Cityscapes and ADE20K)

| Model | mIoU (Cityscapes) | FPS | FLOPs (G) | Params (M) | mIoU (ADE20K) | FPS | FLOPs (G) | Params (M) |
|---|---|---|---|---|---|---|---|---|
| Mask2Former (R50) | 79.4% | 6.6 | 523 | 44.0 | 47.2% | 21.5 | 70.1 | 44.0 |
| YOSO (R50) | 79.4% | 11.1 | 268 | 42.6 | 44.7% | 35.3 | 37.3 | 42.0 |
| PEM (R50) | 79.9% | 13.8 | 240 | 35.6 | 45.5% | 35.7 | 46.9 | 35.6 |
| PEM (STDC1/2) | 78.3–79.0% | 22–24 | 92–118 | 17–21 | 39.6–45.0% | 36–44 | 16–19.3 | 17–21 |

PEM achieves comparable or superior accuracy to leading, resource-intensive baselines at markedly greater throughput and lower computational footprint. For Cityscapes, PEM attains 79.9% mIoU and 61.1 PQ at 13.8 FPS with just 240G FLOPs and 35.6M parameters when using ResNet50. Notably, faster, lighter backbones (STDC1/2) yield high FPS (up to 24.3), albeit with some accuracy tradeoff.

7. Ablation Studies and Method Component Analysis

Comprehensive ablation studies isolate the contributions of PEM’s architectural components:

  • Attention Mechanism: Full PEM-CA yields 61.1 PQ, with masking and prototype selection contributing +2.6 and +12.4 PQ, respectively, over their ablated baselines.
  • Pixel Decoder: Removing CSM or deformable convolutions substantially reduces PQ (to 60.0 and 57.1, respectively) while improving efficiency (latency drops to $\sim$72 ms).
  • Queries and Decoder Layers: The optimal number of queries is $N = 100$, with negligible gains at $N = 200$; increasing the decoder depth from 0 to 6 layers improves PQ (from 58 to 61) at the cost of latency.

This suggests that the synergy between prototype selection and contextual modulation is critical for balancing efficiency and segmentation accuracy in PEM.

Summary and Contextual Significance

PEM demonstrates that prototype-based cross-attention and efficient multi-scale decoding are effective strategies to reduce computational demands while maintaining state-of-the-art segmentation accuracy across semantic and panoptic tasks. Its unified pipeline enables deployment on resource-constrained systems and supports rapid inference, challenging the dominance of heavier transformer-based segmenters in multi-task scenarios (Cavagnero et al., 29 Feb 2024). A plausible implication is that PEM’s methodology may influence future transformer architectures for vision tasks, particularly where efficiency and multi-task learning are priorities.
