Patchwise Stochastic Attention (PSAL)
- Patchwise Stochastic Attention (PSAL) is a deep learning mechanism that combines patch-based reasoning with stochastic components to achieve scalability and efficiency.
- Depending on the variant, it uses either sparse approximate nearest-neighbor search, which reduces memory demands, or Gaussian patch embeddings, which quantify uncertainty at the patch level.
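- Empirical results show that the sparse variant scales to high-resolution colorization, inpainting, and super-resolution with quality matching or approaching full attention, while the Gaussian variant improves calibration, out-of-distribution detection, and corruption robustness.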
Patchwise Stochastic Attention (PSAL) refers to a collection of attention mechanisms in deep learning architectures that combine non-local, patch-based reasoning with stochastic components. PSAL approaches enable scalable, memory-efficient attention across high-resolution inputs and, in some variants, incorporate uncertainty quantification via stochastic representations. The term PSAL encompasses two distinct families: (1) Patch-based Stochastic Attention Layers employing stochastic nearest-neighbor search for sparse, scalable attention (Cherel et al., 2022), and (2) Patchwise Stochastic Attention mechanisms that represent patches as probabilistic distributions, enabling attention computations informed by measures such as the Wasserstein distance (Erick et al., 2023).
1. Mathematical Formulations
1.1 Patch-Based Stochastic Attention Layer (Sparse Approximate NN-Based)
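Given an input feature map $X \in \mathbb{R}^{H \times W \times C}$, patches of size $p \times p$ are extracted via a linear operator $\mathcal{P}$:
$$\mathcal{P}(X) = [x_1, \dots, x_N]^\top \in \mathbb{R}^{N \times d},$$
where $N$ is the number of patches and $d = p^2 C$.
Attention is classically computed using a dense score matrix $S = QK^\top \in \mathbb{R}^{N \times N}$, whose rows and columns index query and key patches. PSAL restricts attention to a sparse set $\mathcal{N}(i)$ of $k$ approximate nearest neighbors per query $i$, forming a sparse score matrix $\tilde{S}$:
$$\tilde{S}_{ij} = \begin{cases} S_{ij}, & j \in \mathcal{N}(i), \\ -\infty, & \text{otherwise (not stored)}, \end{cases}$$
with the attention weights $A = \operatorname{softmax}(\tilde{S})$ computed as a row-wise softmax restricted to $\mathcal{N}(i)$ and the output $Y = AV$, i.e. $y_i = \sum_{j \in \mathcal{N}(i)} A_{ij} v_j$.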
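The sparse attention described above can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the implementation of Cherel et al.: the neighbor sets $\mathcal{N}(i)$ are found by exact top-$k$ search over blocks of queries (a stand-in for PatchMatch, so compute stays $O(N^2)$ while only $N \times k$ weights are kept), patches use stride 1 with no padding, and the helper names `extract_patches` and `sparse_topk_attention` are hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def extract_patches(x: np.ndarray, p: int) -> np.ndarray:
    """Map an (H, W, C) feature map to an (N, d) patch matrix with d = p*p*C.

    Stride 1, no padding, so N = (H - p + 1) * (W - p + 1).
    """
    win = sliding_window_view(x, (p, p), axis=(0, 1))  # (H-p+1, W-p+1, C, p, p)
    return win.transpose(0, 1, 3, 4, 2).reshape(-1, p * p * x.shape[2])


def sparse_topk_attention(Q, K, V, k, chunk=256):
    """Row-wise softmax restricted to the k highest-scoring keys of each query.

    Scores are formed one block of queries at a time, so at most chunk x N of
    them exist at once; only the N x k neighbour indices and weights are kept.
    """
    N = Q.shape[0]
    nbrs = np.empty((N, k), dtype=np.int64)
    weights = np.empty((N, k))
    for start in range(0, N, chunk):
        s = Q[start:start + chunk] @ K.T                   # block of S = Q K^T
        idx = np.argpartition(-s, k - 1, axis=1)[:, :k]    # N(i): k largest scores
        sk = np.take_along_axis(s, idx, axis=1)            # stored entries of S~
        sk = np.exp(sk - sk.max(axis=1, keepdims=True))    # numerically stable softmax
        nbrs[start:start + chunk] = idx
        weights[start:start + chunk] = sk / sk.sum(axis=1, keepdims=True)
    Y = np.einsum("nk,nkd->nd", weights, V[nbrs])          # y_i = sum_j A_ij v_j
    return Y, nbrs, weights


rng = np.random.default_rng(0)
feat = rng.standard_normal((16, 16, 4))
P = extract_patches(feat, p=3)            # N = 14 * 14 = 196 patches, d = 36
Y, nbrs, A = sparse_topk_attention(P, P, P, k=3)
print(P.shape, Y.shape)                   # (196, 36) (196, 36)
```

Setting `k` equal to the number of keys recovers ordinary dense softmax attention, which is a convenient correctness check for this kind of sketch.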
1.2 Distribution-Based Patchwise Stochastic Attention
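Each patch $x_i$ is embedded as a Gaussian:
$$z_i \sim \mathcal{N}(\mu_i, \Sigma_i),$$
where $\mu_i \in \mathbb{R}^d$ is the mean and $\Sigma_i \in \mathbb{R}^{d \times d}$ is a symmetric positive-definite covariance. The 2-Wasserstein distance between two Gaussians $\mathcal{N}(\mu_1, \Sigma_1)$ and $\mathcal{N}(\mu_2, \Sigma_2)$ is
$$W_2^2 = \lVert \mu_1 - \mu_2 \rVert_2^2 + \operatorname{tr}\left(\Sigma_1 + \Sigma_2 - 2\left(\Sigma_2^{1/2} \Sigma_1 \Sigma_2^{1/2}\right)^{1/2}\right),$$
which for diagonal covariances reduces to $\lVert \mu_1 - \mu_2 \rVert_2^2 + \lVert \Sigma_1^{1/2} - \Sigma_2^{1/2} \rVert_F^2$.
Attention logits are given by the negated distance $s_{ij} = -W_2^2\big(\mathcal{N}(\mu_i, \Sigma_i), \mathcal{N}(\mu_j, \Sigma_j)\big)$, so that closer distributions receive larger weights, normalized with softmax.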
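A small NumPy check of the closed-form 2-Wasserstein distance between Gaussians, and of attention built from negated squared $W_2$ distances, is sketched below. It is illustrative only: the attention helper assumes diagonal covariances (the text only says "elliptical Gaussians"), each Gaussian serves as both query and key for simplicity, and all function names are hypothetical.

```python
import numpy as np


def spd_sqrt(S):
    """Principal square root of a symmetric positive semi-definite matrix."""
    w, U = np.linalg.eigh(S)
    return (U * np.sqrt(np.clip(w, 0.0, None))) @ U.T


def w2_squared(mu1, S1, mu2, S2):
    """Closed-form squared 2-Wasserstein distance between N(mu1, S1) and N(mu2, S2)."""
    r2 = spd_sqrt(S2)
    cross = spd_sqrt(r2 @ S1 @ r2)
    return float(np.sum((mu1 - mu2) ** 2) + np.trace(S1 + S2 - 2.0 * cross))


def w2_squared_diag(mu1, var1, mu2, var2):
    """Diagonal-covariance case: ||mu1 - mu2||^2 + ||sigma1 - sigma2||^2 (broadcasts)."""
    return (np.sum((mu1 - mu2) ** 2, axis=-1)
            + np.sum((np.sqrt(var1) - np.sqrt(var2)) ** 2, axis=-1))


def wasserstein_attention(mu, var):
    """Row-wise softmax of logits s_ij = -W2^2 between Gaussians i and j, shape (N, N)."""
    d2 = w2_squared_diag(mu[:, None], var[:, None], mu[None], var[None])
    logits = -d2 - (-d2).max(axis=1, keepdims=True)
    A = np.exp(logits)
    return A / A.sum(axis=1, keepdims=True)


# 1-D sanity check, analytic value: (0 - 3)^2 + (1 - 2)^2 = 10
print(w2_squared(np.array([0.0]), np.array([[1.0]]), np.array([3.0]), np.array([[4.0]])))

rng = np.random.default_rng(0)
mu, var = rng.standard_normal((6, 4)), rng.uniform(0.1, 2.0, (6, 4))
A = wasserstein_attention(mu, var)        # each row sums to 1
```

Because $W_2^2$ of a Gaussian with itself is zero, each row's largest weight sits on the diagonal in this self-attention toy.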
2. Stochastic and Approximate Attention Mechanisms
Sparse PSAL realizes non-local attention by leveraging the PatchMatch algorithm to identify $k$ approximate nearest neighbors per query patch, reducing computational and memory requirements from $O(N^2)$ to $O(kN)$ (Cherel et al., 2022). PatchMatch proceeds iteratively via:
- Random initialization of neighbors.
- Alternating propagation (using jump-flooding) and random search steps.
- Updates with the best candidates found, maintaining a heap of the top-$k$ matches.
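The PatchMatch loop described above can be illustrated with the classic $k = 1$ variant below. This is a sketch under simplifying assumptions, not the GPU implementation used by Cherel et al.: it uses sequential scanline propagation in alternating directions instead of jump flooding, keeps a single best match instead of a top-$k$ heap, compares descriptor grids (each cell standing for one patch) under squared $\ell_2$ distance, and the function name `patchmatch` is hypothetical.

```python
import numpy as np


def patchmatch(A, B, iters=4, seed=0):
    """k = 1 PatchMatch from descriptor grid A (Ha, Wa, D) to B (Hb, Wb, D).

    Returns the nearest-neighbour field (Ha, Wa, 2) and its squared L2 costs (Ha, Wa).
    """
    rng = np.random.default_rng(seed)
    Ha, Wa, _ = A.shape
    Hb, Wb, _ = B.shape

    def dist(i, j, y, x):
        diff = A[i, j] - B[y, x]
        return float(diff @ diff)

    # Random initialisation of every query's match.
    nnf = np.stack([rng.integers(0, Hb, (Ha, Wa)), rng.integers(0, Wb, (Ha, Wa))], axis=-1)
    best = np.array([[dist(i, j, *nnf[i, j]) for j in range(Wa)] for i in range(Ha)])

    def try_candidate(i, j, y, x):
        # Keep a candidate only if it lies inside B and lowers the current cost.
        if 0 <= y < Hb and 0 <= x < Wb:
            c = dist(i, j, y, x)
            if c < best[i, j]:
                best[i, j] = c
                nnf[i, j] = (y, x)

    for it in range(iters):
        step = 1 if it % 2 == 0 else -1                  # alternate scan direction
        rows = range(Ha) if step == 1 else range(Ha - 1, -1, -1)
        cols = range(Wa) if step == 1 else range(Wa - 1, -1, -1)
        for i in rows:
            for j in cols:
                # Propagation: shift the matches of the neighbours visited just before.
                if 0 <= j - step < Wa:
                    y, x = nnf[i, j - step]
                    try_candidate(i, j, y, x + step)
                if 0 <= i - step < Ha:
                    y, x = nnf[i - step, j]
                    try_candidate(i, j, y + step, x)
                # Random search around the current match with a halving radius.
                radius = max(Hb, Wb)
                while radius >= 1:
                    y, x = nnf[i, j]
                    try_candidate(i, j, y + int(rng.integers(-radius, radius + 1)),
                                  x + int(rng.integers(-radius, radius + 1)))
                    radius //= 2
    return nnf, best


rng = np.random.default_rng(1)
A = rng.standard_normal((8, 8, 5))
B = rng.standard_normal((10, 10, 5))
nnf, cost = patchmatch(A, B, iters=4)
# Compare with exhaustive search over all 100 key positions.
exact = ((A[:, :, None, :] - B.reshape(1, 1, -1, 5)) ** 2).sum(-1).min(-1)
print("fraction of exact matches:", np.mean(np.isclose(cost, exact)))
```

Each accepted update strictly lowers a query's cost, so costs never increase across iterations; how close they get to the exhaustive optimum depends on the data and the number of iterations.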
In the Gaussian PSAL variant, stochasticity is introduced at the representational level: image patches are modeled as elliptical Gaussians whose covariances encode patch-level uncertainty, and attention is computed from distances between these distributions.
3. Differentiability and Aggregation Strategies
Differentiability in sparse PSAL is achieved for $k > 1$ by applying the softmax over the restricted neighbor set $\mathcal{N}(i)$; with $k = 1$ the single weight is always 1, so no gradient reaches the queries and keys. Gradients propagate through the attention weights directly. Two aggregation strategies ensure effective learning:
- PSAL-k: The top-$k$ matches are maintained per query, yielding $O(kN)$ memory.
- PSAL-Agg: Supports are augmented with the neighbors of spatial neighbors (a bounded number of extra candidates per query), improving match robustness at a proportionally larger but still linear memory cost. Both approaches maintain end-to-end differentiability for network training (Cherel et al., 2022).
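Why the restricted softmax needs more than one neighbor to pass gradients to queries and keys can be checked numerically. The sketch below (hypothetical helper names, NumPy only, central finite differences) compares the Jacobian of a single query's output with respect to the query for $k = 1$ and $k = 3$ neighbors.

```python
import numpy as np


def attend(q, keys, values):
    """One query's output: softmax over its k neighbour scores, then a weighted sum of values."""
    s = keys @ q
    a = np.exp(s - s.max())
    return (a / a.sum()) @ values


def jacobian_wrt_query(q, keys, values, eps=1e-6):
    """Central finite-difference Jacobian d(output)/dq, shape (d_v, d)."""
    J = np.zeros((values.shape[1], q.size))
    for t in range(q.size):
        e = np.zeros_like(q)
        e[t] = eps
        J[:, t] = (attend(q + e, keys, values) - attend(q - e, keys, values)) / (2 * eps)
    return J


rng = np.random.default_rng(0)
q = rng.standard_normal(4)
keys, values = rng.standard_normal((3, 4)), rng.standard_normal((3, 2))

J1 = jacobian_wrt_query(q, keys[:1], values[:1])   # k = 1: the single weight is always 1
J3 = jacobian_wrt_query(q, keys, values)           # k = 3: weights depend on q
print(np.abs(J1).max(), np.abs(J3).max())          # first value is exactly 0.0
```

For $k > 1$ the finite-difference result should agree with the analytic softmax Jacobian $V^\top(\operatorname{diag}(a) - aa^\top)K$; gradients with respect to the values still flow when $k = 1$.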
4. Integration into Deep Architectures
PSAL is implemented as a drop-in replacement for full dot-product or non-local attention layers in CNNs or ViTs. Integration procedure:
- Project feature maps to query, key, value triplets via 1x1 convolutions.
- Extract patches to form the $Q$, $K$, $V$ patch matrices.
- Employ PSAL to compute the sparse (or Wasserstein-based) attention output.
- Aggregate outputs and scatter to the pixel or feature locations.
- Add residual connections and proceed with convolutional or transformer blocks (Cherel et al., 2022, Erick et al., 2023).
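The five integration steps above can be strung together in a toy NumPy layer. Everything here is an illustrative assumption rather than the published architecture: 1x1 convolutions are written as per-pixel matrix products with hypothetical weights `Wq`, `Wk`, `Wv`, neighbors come from exact top-$k$ search instead of PatchMatch, patches use stride 1 without padding, and overlapping patch outputs are scattered back by averaging.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def extract(x, p):
    """(H, W, C) -> (N, p*p*C) with stride 1 and no padding."""
    win = sliding_window_view(x, (p, p), axis=(0, 1))       # (H-p+1, W-p+1, C, p, p)
    return win.transpose(0, 1, 3, 4, 2).reshape(-1, p * p * x.shape[2])


def scatter(patches, H, W, p):
    """Put each patch back at its location; pixels covered by several patches are averaged."""
    Hp, Wp = H - p + 1, W - p + 1
    grid = patches.reshape(Hp, Wp, p, p, -1)
    out = np.zeros((H, W, grid.shape[-1]))
    count = np.zeros((H, W, 1))
    for di in range(p):
        for dj in range(p):
            out[di:di + Hp, dj:dj + Wp] += grid[:, :, di, dj]
            count[di:di + Hp, dj:dj + Wp] += 1
    return out / count


def topk_attention(Q, K, V, k):
    """Softmax over the k highest scores per query (exact search, stand-in for PatchMatch)."""
    s = Q @ K.T
    idx = np.argpartition(-s, k - 1, axis=1)[:, :k]
    sk = np.take_along_axis(s, idx, axis=1)
    a = np.exp(sk - sk.max(axis=1, keepdims=True))
    a /= a.sum(axis=1, keepdims=True)
    return np.einsum("nk,nkd->nd", a, V[idx])


def psal_block(x, Wq, Wk, Wv, p=3, k=4):
    H, W, _ = x.shape
    q, key, v = x @ Wq, x @ Wk, x @ Wv                       # 1) 1x1 convs = per-pixel linear maps
    Q, K, V = extract(q, p), extract(key, p), extract(v, p)  # 2) patch matrices
    Y = topk_attention(Q, K, V, k)                           # 3) sparse attention on patches
    y = scatter(Y, H, W, p)                                  # 4) back to pixel locations
    return x + y                                             # 5) residual (Wv must keep C channels)


rng = np.random.default_rng(0)
C = 8
x = rng.standard_normal((12, 12, C))
Wq, Wk, Wv = (rng.standard_normal((C, C)) / np.sqrt(C) for _ in range(3))
out = psal_block(x, Wq, Wk, Wv)
print(out.shape)                                             # (12, 12, 8)
```

Averaging in `scatter` makes it a left inverse of `extract`, so an identity attention would return the value map unchanged; other reconstruction rules (e.g. keeping only patch centers) are equally plausible design choices.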
In Wasserstein-based PSAL built on ViT-B and the data2vec self-supervised framework, transformer blocks are augmented with projection layers for the mean and variance, and a multi-head Wasserstein self-attention mechanism replaces conventional dot-product attention. This adds memory and compute overhead (e.g., 49.3 h for 300 epochs on CIFAR-100 with 8×A40 GPUs, versus 4 h for the baseline) (Erick et al., 2023).
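A multi-head layer in the spirit of the paragraph above can be sketched as follows. All specifics are assumptions for illustration, not details from Erick et al.: the projection names (`Wq_mu`, `Wq_var`, and so on) are hypothetical, variances are kept positive with a softplus, covariances are diagonal, logits are negated squared $W_2$ distances divided by a temperature `tau`, and output variances are obtained by squaring the attention weights, which is valid only if the value embeddings are independent.

```python
import numpy as np


def softplus(z):
    return np.logaddexp(0.0, z)                       # log(1 + e^z), numerically stable


def wasserstein_mha(x, params, n_heads, tau=1.0):
    """x: (N, D) tokens -> Gaussian outputs (means, variances) and weights (n_heads, N, N)."""
    N, D = x.shape
    dh = D // n_heads

    def split(z):                                     # (N, D) -> (n_heads, N, dh)
        return z.reshape(N, n_heads, dh).transpose(1, 0, 2)

    def merge(z):                                     # (n_heads, N, dh) -> (N, D)
        return z.transpose(1, 0, 2).reshape(N, D)

    # Hypothetical projection layers for means and (positive) diagonal variances.
    mu_q, mu_k, mu_v = (split(x @ params[n]) for n in ("Wq_mu", "Wk_mu", "Wv_mu"))
    var_q, var_k, var_v = (split(softplus(x @ params[n]) + 1e-6)
                           for n in ("Wq_var", "Wk_var", "Wv_var"))

    # Squared W2 between diagonal Gaussians for every (head, query, key) triple.
    d2 = (((mu_q[:, :, None] - mu_k[:, None]) ** 2).sum(-1)
          + ((np.sqrt(var_q)[:, :, None] - np.sqrt(var_k)[:, None]) ** 2).sum(-1))
    logits = -d2 / tau
    A = np.exp(logits - logits.max(-1, keepdims=True))
    A /= A.sum(-1, keepdims=True)

    mu_out = A @ mu_v                                 # mean of a weighted sum
    var_out = (A ** 2) @ var_v                        # variance if values are independent
    return merge(mu_out), merge(var_out), A


rng = np.random.default_rng(0)
D = 8
params = {n: rng.standard_normal((D, D)) / np.sqrt(D)
          for n in ("Wq_mu", "Wk_mu", "Wv_mu", "Wq_var", "Wk_var", "Wv_var")}
x = rng.standard_normal((5, D))
mu, var, A = wasserstein_mha(x, params, n_heads=2)
print(mu.shape, var.shape, A.shape)                   # (5, 8) (5, 8) (2, 5, 5)
```

Computing all pairwise distances for both means and standard deviations roughly doubles the work of a dot-product head in this sketch, which gives some intuition for the overhead reported above.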
5. Regularization and Loss Formulations
Distribution-based PSAL introduces Wasserstein-based regularization terms:
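- Generic loss: the task objective plus a weighted Wasserstein regularizer, $\mathcal{L} = \mathcal{L}_{\text{task}} + \lambda\,\mathcal{L}_{\mathcal{W}}$.
- Pre-training: the data2vec objective combined with a Wasserstein term $\mathcal{L}_{\mathcal{W}}$ on the Gaussian patch embeddings.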
- Fine-tuning: Cross-entropy loss combined with positive/negative Wasserstein contrastive terms.
This enforces that similar patches (e.g., masked/unmasked) are close under the Wasserstein geometry and enhances uncertainty-awareness in the SSL representation (Erick et al., 2023).
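The structure of the fine-tuning objective, cross-entropy plus positive/negative Wasserstein terms, can be sketched as below. The exact published form is not given in this text; the hinge with `margin`, the weight `lam`, the diagonal-covariance distance, and all function names are hypothetical choices that only illustrate pulling positive pairs together and pushing negative pairs apart under $W_2$.

```python
import numpy as np


def w2_sq_diag(mu1, var1, mu2, var2):
    """Squared 2-Wasserstein distance between diagonal Gaussians, per row."""
    return ((mu1 - mu2) ** 2).sum(-1) + ((np.sqrt(var1) - np.sqrt(var2)) ** 2).sum(-1)


def cross_entropy(logits, labels):
    z = logits - logits.max(-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(-1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


def wasserstein_contrastive(mu_a, var_a, mu_b, var_b, same, margin=1.0):
    """Pull pairs with same=True together; push other pairs at least `margin` apart (hinge)."""
    d2 = w2_sq_diag(mu_a, var_a, mu_b, var_b)
    pos = np.where(same, d2, 0.0).sum() / max(same.sum(), 1)
    neg = np.where(~same, np.maximum(0.0, margin - d2), 0.0).sum() / max((~same).sum(), 1)
    return pos + neg


def total_loss(logits, labels, mu_a, var_a, mu_b, var_b, same, lam=0.1):
    """Cross-entropy plus a lam-weighted Wasserstein contrastive term (hypothetical weighting)."""
    return cross_entropy(logits, labels) + lam * wasserstein_contrastive(
        mu_a, var_a, mu_b, var_b, same)


rng = np.random.default_rng(0)
mu_a, var_a = rng.standard_normal((6, 4)), rng.uniform(0.5, 1.5, (6, 4))
same = np.array([True, True, True, False, False, False])
mu_b = np.where(same[:, None], mu_a, mu_a + 5.0)    # positives identical, negatives far apart
logits, labels = rng.standard_normal((6, 10)), rng.integers(0, 10, 6)
# Both Wasserstein terms vanish here, so the total equals the cross-entropy alone.
print(total_loss(logits, labels, mu_a, var_a, mu_b, var_a, same))
```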
6. Empirical Results and Evaluation
6.1 Sparse NN-Based PSAL
- Guided colorization (256×256 input): ℓ₂ error for PSAL-3 is 0.00228, PSAL-Agg is 0.00194, outperforming full attention (subsampled), Performer, Reformer, LinearAttention, and LocalAttention in both error and memory (0.2–0.7 GB vs 2–10 GB) (Cherel et al., 2022).
- Image inpainting (Places2 validation): PSAL-3 achieves ℓ₁≈11.6%, ℓ₂≈3.6%, PSNR 16.6 dB, SSIM 54.1%, comparable or superior to prior attention mechanisms, enabling high-resolution inpainting (up to 3300×3300) on 11 GB GPUs.
- Super-resolution (Urban100): at ×2, PSAL reaches 33.375 dB PSNR versus 33.383 dB for full attention; at ×4, 27.184 dB versus 27.288 dB.
6.2 Wasserstein PSAL in SSL and ViT
- CIFAR-100 In-distribution: 69.42% Top-1 accuracy, NLL 1.223, ECE 0.445 (comparable to baseline).
- OOD Detection (CIFAR-100→CIFAR-10): AUROC 0.629 vs baseline 0.519.
- Corruption Robustness (CIFAR-100-C): mCE 0.487 vs baseline 0.506.
- Semi-supervised (10% labels): 57.86% accuracy vs 56.98% (baseline); ECE 0.462 vs 0.466.
Key takeaways:
- Wasserstein PSAL yields comparable or slightly lower expected calibration error (ECE) and stronger OOD detection and corruption robustness, with minimal accuracy tradeoff.
- Wasserstein-driven stochastic attention leverages the Gaussian parameterization to quantify patch-level uncertainty and enables distance-aware learning (Erick et al., 2023).
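The calibration results above are reported as ECE. For reference, a common equal-width-binning estimator is sketched below; the bin count (15) and the function name are assumptions, since the text does not state how ECE was binned in the cited experiments.

```python
import numpy as np


def expected_calibration_error(probs, labels, n_bins=15):
    """ECE with equal-width confidence bins: sum_b (|B_b| / n) * |acc(B_b) - conf(B_b)|."""
    conf = probs.max(axis=1)
    pred = probs.argmax(axis=1)
    correct = (pred == labels).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            ece += in_bin.mean() * abs(correct[in_bin].mean() - conf[in_bin].mean())
    return ece


probs = np.tile([0.8, 0.2], (10, 1))
labels = np.array([0] * 8 + [1] * 2)          # 80% correct at 80% confidence
print(expected_calibration_error(probs, labels))   # close to 0: confidence matches accuracy
```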
7. Comparative Discussion and Limitations
Advantages and limitations
| Variant | Core Strengths | Limiting Factors |
|---|---|---|
| PSAL-NN (Sparse) | Linear memory in $N$, genuine non-local attention at high resolution, differentiable ($k > 1$), pluggable | Potential for local artifacts from mismatched NNs; requires careful parameter tuning |
| PSAL-Gaussian (Wass.) | Uncertainty quantification, distance-aware representations, improved calibration and robustness | Model/data needs, overhead in compute/memory |
Sparse PSAL makes high-resolution, non-local attention practical by cutting memory by orders of magnitude (e.g., from about 250 GB to under 5 GB for a 512×512 feature map). Wasserstein PSAL exploits stochastic Gaussian representations to inject uncertainty-awareness throughout training and inference (Cherel et al., 2022, Erick et al., 2023).
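The dense-versus-sparse memory figures can be sanity-checked with back-of-the-envelope arithmetic. The sketch assumes one patch per pixel ($N = 512^2$, i.e. stride 1 with padded borders), float32 scores, int64 neighbor indices, and $k = 3$; it counts only the attention score storage, not features or other activations.

```python
# Memory of the N x N score matrix vs. a top-k sparse representation for a 512x512 map.
N = 512 * 512                      # one patch per pixel (assumption)
dense_bytes = N * N * 4            # float32 score for every (query, key) pair
k = 3
sparse_bytes = N * k * (4 + 8)     # float32 score + int64 index per stored neighbour
GiB, MiB = 2 ** 30, 2 ** 20
print(f"dense : {dense_bytes / GiB:.0f} GiB")    # dense : 256 GiB
print(f"sparse: {sparse_bytes / MiB:.0f} MiB")   # sparse: 9 MiB
```

Under these assumptions the dense matrix alone needs 256 GiB (about 275 GB), in line with the roughly 250 GB quoted above, while the sparse scores are negligible; most of the sub-5 GB footprint would then come from other tensors such as feature maps and patch matrices.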
Potential limitations include sensitivity to approximate NN assignment in sparse PSAL (sometimes mitigated by aggregation), and computational overhead from stochastic operations and Wasserstein regularization in distribution-based PSAL. Teacher-forced or aggregation strategies are necessary for robust learning in the sparse setting (Cherel et al., 2022).
References
- Patch-Based Stochastic Attention for Image Editing (Cherel et al., 2022)
- Stochastic Vision Transformers with Wasserstein Distance-Aware Attention (Erick et al., 2023)