SBM-Transformer: Data-Adaptive Sparse Attention
- SBM-Transformer is a Transformer variant that uses mixed-membership stochastic block models to generate dynamic, sparse attention masks, reducing the quadratic cost of dense self-attention.
- It employs a straight-through estimator for differentiable stochastic sampling, enabling gradient-based training of input-dependent attention mechanisms.
- Empirical results on benchmarks like LRA and GLUE show competitive accuracy with significantly lower memory and computational requirements.
The SBM-Transformer is a variant of the Transformer architecture that replaces expensive dense self-attention with a data-adaptive, learnable sparse attention mechanism parameterized by Stochastic Block Models (SBMs). Instead of attending to all possible query-key pairs, each attention head infers a sparse bipartite graph over the input tokens, drastically reducing computation while retaining expressiveness. The SBM-Transformer achieves this by learning both latent cluster assignments and inter-cluster affinities, generating input-dependent attention masks whose topology adapts to the sequence. The use of a straight-through estimator (STE) renders the stochastic edge-sampling differentiable for gradient-based learning. This attention mechanism has been shown to provide competitive or superior accuracy to full and other sparse attention variants on both synthetic and standard benchmarks, while realizing considerable reductions in computation and memory (Cho et al., 2022). The SBM-Transformer paradigm has also been adopted in other architectural contexts, such as code summarization over ASTs, confirming its broad applicability (Oh et al., 2024).
1. Data-Adaptive Attention via Mixed-Membership Stochastic Block Models
In SBM-Transformer, each attention head is paired with a mixed-membership SBM. For an input feature matrix $X \in \mathbb{R}^{n \times d}$, each head forms queries, keys, and values as in the standard Transformer:

$$Q = XW_Q, \qquad K = XW_K, \qquad V = XW_V, \qquad W_Q, W_K, W_V \in \mathbb{R}^{d \times d_h}.$$
Each head introduces a learnable cluster embedding matrix $C \in \mathbb{R}^{k \times d_h}$ for $k$ latent clusters, together with a small MLP $f$. Nodes (tokens) receive soft cluster assignments:

$$Y_Q = \mathrm{softmax}\!\left(QC^\top\right), \qquad Y_K = \mathrm{softmax}\!\left(KC^\top\right) \in \mathbb{R}^{n \times k},$$

where each row is a mixed-membership distribution over the $k$ clusters.
The inter-cluster affinity matrix is

$$B = \sigma\!\left(f(C)\, f(C)^\top\right) \in [0,1]^{k \times k},$$

yielding a probability of attention (an edge) between query $i$ and key $j$:

$$P_{ij} = (Y_Q)_{i,:}\; B\; (Y_K)_{j,:}^\top \in [0,1].$$
The binary attention mask $M \in \{0,1\}^{n \times n}$, with $M_{ij} \sim \mathrm{Bernoulli}(P_{ij})$, is sampled without materializing the full $n \times n$ probability matrix (e.g., via fastRG, at a cost roughly linear in the number of sampled edges). The mask then selects the active query–key pairs for attention.
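A minimal PyTorch sketch of this mask-generation step is given below. It follows the formulas above but draws the mask with a dense Bernoulli sample purely for clarity; the reference method instead uses fastRG so that $P$ is never materialized. The function and variable names, as well as the two-layer MLP, are illustrative assumptions rather than the authors' implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def sbm_attention_mask(Q, K, C, mlp):
    """Sketch of SBM mask generation for a single attention head.

    Q, K : (n, d_h) query/key matrices of the head
    C    : (k, d_h) learnable cluster embeddings
    mlp  : small MLP mapping (k, d_h) -> (k, d_h), used for the block matrix
    Returns the binary mask M (n, n) and the edge-probability matrix P.
    """
    Y_q = F.softmax(Q @ C.T, dim=-1)       # (n, k) soft memberships of queries
    Y_k = F.softmax(K @ C.T, dim=-1)       # (n, k) soft memberships of keys

    H = mlp(C)                             # (k, d_h)
    B = torch.sigmoid(H @ H.T)             # (k, k) inter-cluster affinities in [0, 1]

    P = Y_q @ B @ Y_k.T                    # (n, n) edge probabilities P_ij in [0, 1]

    # Dense Bernoulli draw for illustration only; fastRG samples the same
    # distribution without ever forming the full n x n matrix P.
    M = torch.bernoulli(P)
    return M, P

# Illustrative usage with assumed shapes: n = 6 tokens, d_h = 8, k = 4 clusters.
n, d_h, k = 6, 8, 4
mlp = nn.Sequential(nn.Linear(d_h, d_h), nn.ReLU(), nn.Linear(d_h, d_h))
Q, K = torch.randn(n, d_h), torch.randn(n, d_h)
C = torch.randn(k, d_h)
M, P = sbm_attention_mask(Q, K, C, mlp)
```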
2. Sparse Masked Attention and Computational Advantages
Given the binary attention mask $M \in \{0,1\}^{n \times n}$, the SBM-Transformer computes attention only along sampled edges. The masked attention operation is

$$\mathrm{Attn}(Q, K, V; M) = \widetilde{\mathrm{softmax}}_M\!\left(\frac{QK^\top}{\sqrt{d_h}}\right)V,$$

where $\widetilde{\mathrm{softmax}}_M$ is the row-wise softmax applied only at positions with $M_{ij} = 1$ (scores at masked-out positions are set to $-\infty$). This brings the attention cost to $O(m)$, where $m = \sum_{i,j} M_{ij}$ is the number of active edges, substantially lower than the $O(n^2)$ of dense attention. The number of edges $m$ is data-adaptive and input-dependent; in practice the sampled masks typically cover only a small fraction of the $n^2$ possible pairs (see Section 5).
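The masked attention itself can be sketched as follows. This version applies the mask to a dense score matrix, mirroring the dense-fallback implementation discussed in Section 7, so it illustrates the semantics of $\widetilde{\mathrm{softmax}}_M$ rather than the $O(m)$ cost; names are illustrative.

```python
import math
import torch

def masked_attention(Q, K, V, M):
    """Attention restricted to the edges where M_ij = 1.

    Q, K, V : (n, d_h) per-head projections
    M       : (n, n) binary mask sampled from the SBM
    """
    d_h = Q.shape[-1]
    scores = (Q @ K.T) / math.sqrt(d_h)                  # (n, n) dense scores
    scores = scores.masked_fill(M == 0, float("-inf"))   # drop unsampled edges

    # A row with no sampled edge would be all -inf and softmax to NaN;
    # such queries simply fall back to a zero output here.
    attn = torch.softmax(scores, dim=-1)
    attn = torch.nan_to_num(attn, nan=0.0)
    return attn @ V                                      # (n, d_h)
```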
3. Differentiable Sampling and Training via the Straight-Through Estimator
Attention mask sampling is non-differentiable. The SBM-Transformer handles this with the straight-through estimator (STE): the forward pass uses the discrete mask $M$, while the backward pass takes gradients with respect to the edge-probability matrix $P = Y_Q B Y_K^\top$, treating the sampled mask as if it were continuous. For a sampled edge $(i,j)$ with $M_{ij} = 1$, gradients accumulate as

$$\frac{\partial \mathcal{L}}{\partial P_{ij}} \;\leftarrow\; \frac{\partial \mathcal{L}}{\partial M_{ij}},$$

which propagates learning signal to the memberships $Y_Q, Y_K$ and the block matrix $B$.
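A common way to realize such an estimator in an autograd framework is the reparameterized sum below: the forward value equals the hard sample, while gradients flow into the probabilities. This is a generic STE sketch under that assumption, not necessarily the authors' exact backward rule.

```python
import torch

def straight_through_mask(P):
    """Generic straight-through estimator for a Bernoulli attention mask.

    Forward:  returns a hard 0/1 sample M ~ Bernoulli(P).
    Backward: the sampling step is treated as the identity, so dL/dP is
              taken to be dL/dM and gradients reach Y_Q, B, and Y_K via P.
    """
    M = torch.bernoulli(P.detach())   # hard sample, carries no gradient itself
    return M + P - P.detach()         # value equals M; gradient flows through P
```

Note that for the gradient to actually reach $P$, the mask has to enter the attention computation as a differentiable factor (e.g., weighting the unnormalized attention scores) rather than only through a boolean fill as in the dense sketch above.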
An exploration enhancement is also adopted so that edges are not permanently dropped, ensuring sufficient exploration of attention structures during training.
4. Theoretical Expressiveness and Universality
The SBM-Transformer is a universal approximator of sequence-to-sequence functions in expectation. For any continuous sequence function $f: \mathbb{R}^{n \times d} \to \mathbb{R}^{n \times d}$ on a compact domain, an SBM-Transformer $g$ can be constructed such that

$$\left(\int \big\| f(X) - \mathbb{E}[g(X)] \big\|_p^p \, dX\right)^{1/p} \;\le\; \epsilon$$

for any $\epsilon > 0$ and $1 \le p < \infty$, given sufficient model capacity. This follows from constructing SBMs that realize the required graph structures (e.g., block-diagonal and global-relay patterns) and concatenating Hamiltonian paths, which guarantees the path-wise and global connectivity conditions for universal sequence modeling (Cho et al., 2022).
5. Empirical Performance
Extensive benchmarks in both synthetic and real-world scenarios demonstrate the competitiveness of SBM-Transformer. On Long Range Arena (LRA) tasks (sequence lengths up to 4K), SBM-Transformer achieves accuracy equal to or surpassing full-attention Transformers while using only 20–30% of the possible edges in attention masks. For instance:
| Model | ListOps | Text | Retrieval | Image | Pathfinder | Avg. |
|---|---|---|---|---|---|---|
| Full-attention | 37.22% | 64.93% | 79.55% | 40.38% | 74.26% | 59.27% |
| SBM-Transformer | 37.45% | 65.79% | 80.00% | 41.31% | 75.12% | 59.93% |
| (Mask Density) | (20.1%) | (26.1%) | (29.5%) | (20.5%) | (18.6%) | – |
On GLUE, in a BERT-style setup, SBM-Transformer matches or exceeds dense attention and other sparse variants at an average mask density of 13.5%; on SST-2, for example, it scores 89.8 versus 89.8 for full attention. Substantial FLOP reductions and comparable or lower peak memory are consistently observed.
Ablations confirm the adaptive sparsity: SBM-Transformer densifies masks in harder instances and learns to specialize attention, in contrast to hand-crafted or uniform sparsifiers (Cho et al., 2022).
6. Application to Tree-Structured and Code Data
The SBM attention mechanism has been adapted for Transformers operating on Abstract Syntax Trees (ASTs) in code summarization tasks. In this context, each node–node pair is assigned an attention probability via learned node-cluster and cluster-cluster affinity matrices:
$$P_{ij} \;=\; y_i^\top\, B\, y_j,$$

where $y_i$ and $y_j$ are node–cluster dot products for nodes $i$ and $j$, and $B$ is the symmetric block-affinity matrix. Sampling and the STE enable dynamic, data-adaptive masks during training. Empirically, SBM attention in the CSA-Trans encoder yields improved summarization accuracy (BLEU-4 increases of 0.38–0.43), with 10–40% reductions in backward time and peak memory over standard and graph-based variants. Notably, the attention heatmaps are sparser and more interpretable, preserving non-local relationships that fixed AST sparsity patterns discard (Oh et al., 2024).
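As a rough illustration of this symmetric variant, the sketch below computes $P$ from a single shared membership matrix; the softmax over node–cluster dot products and the sigmoid on the block matrix are assumptions made so the outputs lie in $[0,1]$, and are not necessarily CSA-Trans's exact formulation.

```python
import torch

def ast_sbm_edge_probs(X, C, W_b):
    """Sketch of SBM edge probabilities over AST node pairs.

    X   : (n, d) AST node representations
    C   : (k, d) learnable cluster embeddings
    W_b : (k, k) unconstrained parameter behind the block-affinity matrix
    """
    Y = torch.softmax(X @ C.T, dim=-1)        # (n, k) node-cluster memberships (assumed softmax)
    B = torch.sigmoid(0.5 * (W_b + W_b.T))    # (k, k) symmetric block-affinity matrix in [0, 1]
    return Y @ B @ Y.T                        # (n, n) with P_ij = y_i^T B y_j
```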
7. Implementation, Hyperparameters, and Limitations
The reference SBM-Transformer implementation defaults to dense operations with masked entries due to limited support for unstructured sparsity in standard deep learning libraries; a truly sparse graph-attention implementation would be needed to realize the full computational benefits. The cluster count $k$ is typically set to 128. Task-dependent configurations (layer/head counts, embedding sizes) mirror those in common Transformer experiments.
Limitations include the lack of hardware-optimized sparse kernel support, which prevents realizing the maximal speedups. Potential extensions identified include degree-corrected and hierarchical SBMs, dynamic adjustment of the cluster count $k$, and integration with block-sparse attention variants.
SBM-Transformer introduces a principled, universal, and data-adaptive sparse attention mechanism, leveraging mixed-membership Stochastic Block Models for computational efficiency and increased interpretability, with demonstrated performance gains across language, vision, and structured code domains (Cho et al., 2022, Oh et al., 2024).