SpAtten: Co-Design for Sparse Attention
- SpAtten is an end-to-end co-design framework that employs cascade token and head pruning to reduce the O(N²) complexity in Transformer attention.
- It leverages progressive quantization and a custom top-k hardware accelerator to minimize data movement and computation costs.
- The approach achieves substantial speedups and energy savings by dynamically optimizing precision and pruning non-essential tokens and heads.
SpAtten is an end-to-end algorithm-architecture co-design framework for efficient sparse attention in Transformer self-attention layers, combining layerwise dynamic token and head pruning with progressive quantization and a tailored hardware accelerator. The approach targets the quadratic computational and memory cost of attention in the input sequence length, together with its low arithmetic intensity, which create bottlenecks in both the compute-bound and memory-bound regimes of natural language processing workloads.
1. Motivation and Design Principles
Self-attention incurs $O(N^2)$ work and heavy DRAM traffic, particularly pronounced during token-by-token sequence generation (as in GPT-2), producing a memory-bound workload on conventional CPUs and GPUs. In BERT-style summarization, which is primarily compute-bound, overhead persists due to intricate data movement (split, transpose, concatenate) that general-purpose hardware supports poorly. To close this memory-compute gap, SpAtten targets simultaneous reductions in compute, data movement, and bit-width via:
- Cascade Token Pruning: Layerwise elimination of unimportant tokens, reducing the effective sequence length $N$.
- Cascade Head Pruning: Dynamic selection of relevant attention heads, decreasing the total head count $H$.
- Progressive Quantization: Two-stage quantization of queries, keys, and values (Q/K/V), initially using only the most significant bits and selectively increasing precision as dictated by softmax distribution sharpness.
- Custom Accelerator: A pipelined hardware design that avoids fetching or computing on pruned data and dynamically adapts to precision requirements, utilizing on-chip buffers and high-throughput top-$k$ engines.
This unified strategy aims to minimize end-to-end memory and compute demands while maintaining model accuracy (Wang et al., 2020).
2. Cascade Token Pruning Mechanism
Token Importance Scoring
For each layer $l$, SpAtten computes a cumulative importance score $s_j$ for each K/V token $j$ ($1 \le j \le N$). The importance accumulator is updated by summing the attention probabilities that all heads and all queries assign to token $j$:

$$s_j \leftarrow s_j + \sum_{h=1}^{H} \sum_{i=1}^{N} A_{h,i,j}$$

where $A_{h,i,j}$ is the softmax attention probability that query $i$ in head $h$ assigns to token $j$.
This summation provides a global measure across heads and query positions, capturing aggregate token influence.
Pruning Schedule and On-the-Fly Top-k
A global token-keep ratio prescribes the number of active tokens $k^{(l)}$ at each layer $l$, following a keep-ratio schedule interpolated across layers. Once a token is pruned at any layer, it is permanently removed from all subsequent layers. Selection uses a custom quick-select hardware engine that emits the $k^{(l)}$ highest-scoring tokens in a streaming manner without a full sort.
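A minimal sketch of such a layerwise schedule, assuming a simple linear interpolation of the keep ratio across layers (the start/end ratios and the helper name `token_keep_count` are illustrative, not values from the paper):

```python
import math

def token_keep_count(layer: int, num_layers: int, seq_len: int,
                     start_ratio: float = 1.0, end_ratio: float = 0.3) -> int:
    """Number of K/V tokens kept at a given layer, interpolating the keep ratio
    linearly from start_ratio (first layer) to end_ratio (last layer).
    The ratio endpoints are illustrative defaults, not the paper's settings."""
    t = layer / max(num_layers - 1, 1)
    ratio = start_ratio + (end_ratio - start_ratio) * t
    return max(1, math.ceil(ratio * seq_len))

# Example: a 12-layer model with a 128-token input keeps fewer tokens per layer.
print([token_keep_count(l, 12, 128) for l in range(12)])
```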
Algorithmic Summary
At each layer:
- Compute attention probabilities over all heads, queries, and tokens.
- Update the cumulative importance scores $s_j$.
- Use the accelerator's top-$k$ engine to select the surviving tokens.
- Shrink K/V (and Q if applicable) tensors accordingly.
This enables a staged ("cascade") reduction in sequence width, with early layers keeping more tokens and later layers more aggressively pruning.
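A minimal NumPy sketch of one layer of cascade token pruning following these steps; the array shapes and names (`attn_probs`, `importance`) are illustrative, and `argpartition` stands in for the hardware quick-select, which likewise avoids a full sort:

```python
import numpy as np

def cascade_token_prune(attn_probs: np.ndarray, importance: np.ndarray, num_keep: int):
    """One layer of cascade token pruning.

    attn_probs : (H, Nq, Nk) softmax attention probabilities of this layer.
    importance : (Nk,) cumulative importance accumulator carried across layers.
    num_keep   : number of K/V tokens to retain (from the keep-ratio schedule).

    Returns the indices of surviving tokens and their updated importance scores;
    pruned tokens are dropped permanently, so later layers only see the survivors.
    """
    # Accumulate importance: sum attention probability over heads and queries.
    importance = importance + attn_probs.sum(axis=(0, 1))
    # Keep the num_keep highest-scoring tokens (partial selection, no full sort).
    keep = np.argpartition(importance, -num_keep)[-num_keep:]
    keep.sort()  # preserve the original token order for the next layer
    return keep, importance[keep]

# Toy example: 4 heads, 16 queries/keys, keep 8 tokens.
H, N = 4, 16
probs = np.random.dirichlet(np.ones(N), size=(H, N))  # each row sums to 1
keep, scores = cascade_token_prune(probs, np.zeros(N), num_keep=8)
print(keep, scores.shape)
```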
3. Layerwise Cascade Head Pruning
Head Importance
After each head $h$'s output $E_h$ is computed, a global head-importance score $s_h$ is accumulated:

$$s_h \leftarrow s_h + \sum_{i=1}^{N} \sum_{d=1}^{D} \left| E_{h,i,d} \right|$$

The magnitude of $s_h$ signifies head $h$'s contribution to the block output. Heads are pruned according to a scheduled per-layer "head-keep" ratio analogous to the token regimen.
Complexity Reduction
Attention computation per layer is reduced from $O(H \cdot N^2 \cdot d)$ to $O(H' \cdot N'^2 \cdot d)$, where $N' \le N$ and $H' \le H$ are the retained token and head counts, with further savings in the downstream two-layer feed-forward (FFN) blocks due to the reduced token count.
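A hedged sketch of the head-importance accumulation and pruning described above, assuming per-head outputs of shape (heads, tokens, head_dim); the function names and the toy dimensions are illustrative:

```python
import numpy as np

def accumulate_head_importance(head_out: np.ndarray, head_scores: np.ndarray) -> np.ndarray:
    """head_out: (H, N, D) per-head attention outputs for one layer.
    head_scores: (H,) running accumulator of head importance.
    Importance grows with the summed magnitude of each head's output elements."""
    return head_scores + np.abs(head_out).sum(axis=(1, 2))

def prune_heads(head_scores: np.ndarray, num_keep: int) -> np.ndarray:
    """Return the indices of the num_keep most important heads (head-keep schedule)."""
    keep = np.argpartition(head_scores, -num_keep)[-num_keep:]
    return np.sort(keep)

# Toy example: 12 heads, 32 tokens, 64-dim heads; keep the top 8 heads.
scores = accumulate_head_importance(np.random.randn(12, 32, 64), np.zeros(12))
print(prune_heads(scores, num_keep=8))
```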
4. Hardware-Aware Top-k Selection
Pipeline and Optimizations
The hardware top-$k$ engine streams score arrays into a quick-select core using random pivots (generated via an LFSR), partitions the input into less-than/greater-than FIFOs, identifies the $k$-th largest threshold, and compacts the surviving indices. It leverages:
- 16 parallel comparators per cycle.
- Double-buffered FIFOs (64 entries/stage) to hide DRAM latency.
- Multi-stage pipelining for sustained throughput.
This achieves $O(n)$ expected latency over a small number of partitioning passes, maintaining streaming data order and avoiding global shuffles.
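A software analogue of this selection logic, as a sanity-check sketch: the random pivot mirrors the LFSR-driven hardware, although a three-way partition is used here for robustness to ties, and the FIFO depths and 16-way comparator parallelism are not modeled:

```python
import random

def topk_threshold(scores, k):
    """Return the k-th largest value via quick-select with random pivots.
    Each pass partitions the remaining items; the hardware streams the
    partitions into FIFOs instead of building Python lists."""
    assert 1 <= k <= len(scores)
    items = list(scores)
    while True:
        pivot = items[random.randrange(len(items))]   # random pivot (LFSR in hardware)
        gt = [x for x in items if x > pivot]
        eq = [x for x in items if x == pivot]
        lt = [x for x in items if x < pivot]
        if k <= len(gt):
            items = gt                      # threshold lies among strictly greater values
        elif k <= len(gt) + len(eq):
            return pivot                    # pivot itself is the k-th largest
        else:
            k -= len(gt) + len(eq)          # discard gt/eq, continue in the smaller values
            items = lt

def topk_indices(scores, k):
    """Compact the indices of all scores reaching the top-k threshold."""
    threshold = topk_threshold(scores, k)
    return [i for i, s in enumerate(scores) if s >= threshold][:k]

print(topk_indices([0.3, 0.9, 0.1, 0.7, 0.5], k=2))  # -> [1, 3]
```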
5. Progressive Quantization for Q/K/V
Two-Stage Quantization Process
Q, K, and V vectors in DRAM are stored as concatenated most-significant-bit (MSB) and least-significant-bit (LSB) planes:
- First stage: Fetch only the MSB plane; compute attention and softmax in reduced precision.
- Dynamic confidence test: If the row-wise maximum softmax probability falls below a threshold (indicating a "flat" distribution and possible quantization error), fetch the LSB plane, reconstruct Q/K at higher precision, and recompute.
- Otherwise: Retain coarse computation output for "peaked" distributions.
Error Propagation
Quantization error in the attention logits perturbs the softmax outputs most when the distribution is flat; when the softmax is peaked and the maximum probability approaches 1, the perturbation of the dominant output becomes negligible. The aggregate error therefore vanishes for confident softmax outputs, and the confidence test above refetches LSBs precisely in the cases where quantization error would matter.
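A schematic NumPy sketch of the two-stage MSB/LSB flow for a single attention row, assuming symmetric fixed-point quantization, an 8-bit code split into 4-bit planes, and a confidence test on the maximum softmax probability; these bit-widths, the threshold, and the helper names are illustrative assumptions rather than the paper's configuration, and only the keys are quantized here for brevity:

```python
import numpy as np

def quantize_split(x: np.ndarray, total_bits: int = 8, msb_bits: int = 4):
    """Quantize x to total_bits and split the integer code into MSB and LSB planes.
    Dequantizing the MSB plane alone yields a coarse approximation of x."""
    scale = np.abs(x).max() / (2 ** (total_bits - 1) - 1) + 1e-12
    q = np.clip(np.round(x / scale), -(2 ** (total_bits - 1)), 2 ** (total_bits - 1) - 1).astype(int)
    lsb_bits = total_bits - msb_bits
    msb = q >> lsb_bits            # coarse plane, fetched first
    lsb = q - (msb << lsb_bits)    # refinement plane, fetched only on demand
    return msb, lsb, lsb_bits, scale

def progressive_attention_row(q_vec: np.ndarray, K: np.ndarray, threshold: float = 0.3):
    """Stage 1: attention with MSB-only keys. If the softmax row is 'flat'
    (max probability below threshold), stage 2 adds the LSBs and recomputes."""
    msb, lsb, lsb_bits, scale = quantize_split(K)
    k_coarse = (msb << lsb_bits) * scale
    logits = k_coarse @ q_vec / np.sqrt(len(q_vec))
    probs = np.exp(logits - logits.max()); probs /= probs.sum()
    if probs.max() < threshold:                        # flat distribution: refine
        k_fine = ((msb << lsb_bits) + lsb) * scale
        logits = k_fine @ q_vec / np.sqrt(len(q_vec))
        probs = np.exp(logits - logits.max()); probs /= probs.sum()
    return probs

probs = progressive_attention_row(np.random.randn(64), np.random.randn(16, 64))
print(probs.sum())  # ~1.0
```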
6. Accelerator Architecture and Dataflow
Block Structure
- HBM2 DRAM (16×128-bit@2GHz)
- Q/K/V fetchers, programmable crossbars, address FIFOs
- Double-buffered key/value SRAM, 512×12-bit multiplier arrays, floating-point softmax units
- Progressive quantization control
- Token/head top-$k$ engines, accumulation modules, controller sequencer
Computation Pipeline
Per layer:
- Load Q.
- Apply token top-$k$ on the importance scores $s_j$; generate addresses of the surviving $K_j$.
- Fetch MSB-only K_j into key SRAM.
- Attend and softmax.
- Verify quantization confidence; fetch LSB if necessary.
- Update the token importance scores $s_j$.
- Multiply the attention probabilities by V to produce the head output $E_h$.
- Update the head importance score $s_h$.
- After all queries and heads, apply head top-$k$.
- Prune Q/K/V chunks.
- Residual, layer-norm, FFN handled on GPU or shared multipliers.
Architecture maximizes pipeline utilization through double buffering, crossbars, and coarse-to-fine bitwidth conversion modules.
7. Performance Results and Comparisons
Empirical Outcomes
Evaluated on 30 NLP benchmarks, SpAtten shows:
- 10× average DRAM access reduction (no accuracy loss)
- 2.1× overall computation reduction (1.9× via token pruning, 1.1× via head pruning)
- End-to-end speedup: 162× over TITAN Xp GPU, 347× over Intel Xeon, 1095× over Jetson Nano, 5071× over Raspberry Pi
- Energy savings: 1193× (TITAN), 4059× (Xeon), 406× (Nano), 1910× (Pi) (Wang et al., 2020)
Comparative Summary with MNNFast and A³
| Feature | MNNFast | A³ | SpAtten |
|---|---|---|---|
| Token Pruning | local only | per-head local | global cascade |
| Head Pruning | — | — | cascade on-the-fly |
| Quantization | fixed | fixed | progressive MSB→LSB |
| DRAM Reduction | none | none | 10× |
| GOP/s | 120 | 221 | 360 |
| Energy Eff. (GOP/J) | 120 | 269 | 382 |
| Area Eff. (GOP/s/mm²) | — | 106 | 238 |
Both MNNFast and A³ must fetch all Q/K/V before pruning, yielding no DRAM savings and accelerating only compute-bound BERT-style workloads. SpAtten's cascade approach yields global savings in both self-attention and downstream FFN layers and applies to generative models (e.g., GPT-2). With progressive quantization, SpAtten closely approaches the HBM bandwidth limit in memory-bound settings.
Roofline Analysis
For BERT (high arithmetic intensity), SpAtten achieves 1.61 TFLOPS against a 2 TFLOPS hardware roof, compared to ~0.45 TFLOPS on an NVIDIA GTX1080Ti. For GPT-2 (low intensity), SpAtten approaches the 512 GB/s HBM bandwidth ceiling, while the GPU is limited by memory efficiency.
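The roofline comparison reduces to attainable throughput = min(peak compute, arithmetic intensity × memory bandwidth); the following small worked example uses the 2 TFLOPS roof and 512 GB/s bandwidth quoted above, while the arithmetic-intensity values themselves are illustrative placeholders:

```python
def roofline_attainable(peak_flops: float, bandwidth: float, intensity: float) -> float:
    """Attainable throughput (FLOP/s) under the roofline model: bounded either by
    peak compute or by memory bandwidth times arithmetic intensity (FLOP/byte)."""
    return min(peak_flops, bandwidth * intensity)

PEAK = 2e12   # 2 TFLOPS compute roof, as quoted above
BW = 512e9    # 512 GB/s HBM bandwidth, as quoted above

# Illustrative intensities: BERT-style attention is compute-bound (high FLOP/byte),
# GPT-2-style generation is memory-bound (a few FLOP/byte).
for name, ai in [("BERT-like, 64 FLOP/byte", 64.0), ("GPT-2-like, 2 FLOP/byte", 2.0)]:
    print(f"{name}: {roofline_attainable(PEAK, BW, ai) / 1e12:.2f} TFLOPS attainable")
```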
SpAtten exemplifies a synergy of algorithmic sparsification and dynamic quantization, tightly integrated with a streaming, pipelined accelerator design. The result is significant reductions in DRAM traffic, computational load, and energy consumption for Transformer attention, establishing a new efficiency benchmark in hardware/software co-design for deep sequence models (Wang et al., 2020).