Agglomerative Attention in Transformers
- Agglomerative Attention is an attention mechanism that softly clusters tokens into a fixed number of classes, replacing quadratic dot-product attention with a linear-cost, cluster-based procedure.
- It employs soft class assignments and per-class summaries to aggregate information efficiently, significantly cutting memory and computational costs.
- Empirical results demonstrate that when paired with convolutional encoding, agglomerative attention achieves competitive language modeling performance with improved scalability.
Agglomerative Attention is an attention mechanism for transformer-based neural networks. It replaces standard dot-product multi-head attention, whose cost grows quadratically with sequence length, with a cluster-and-summarize procedure. For a fixed number of classes and model width, this reduces both memory and time complexity to linear in sequence length. The mechanism works in three steps: it softly assigns input tokens to a fixed number of classes, computes a summary for each class, and produces for each query a learned linear mixture of those summaries, weighted by the query's own class assignments. Although it gives up the fine-grained pairwise interactions of full attention, agglomerative attention achieves comparable performance on language modeling tasks, particularly when paired with convolutional encoding. It also substantially reduces computation and memory (Spellings, 2019).
1. Mathematical Formulation
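Let $X \in \mathbb{R}^{N_r \times w}$ be the reference sequence and $Y \in \mathbb{R}^{N_q \times w}$ the query sequence, with $N_r$ and $N_q$ the sequence lengths and $w$ the model width. In self-attention settings ($X = Y$, $N_r = N_q = N$), the method introduces a fixed number of classes $n_c$ (analogous to heads). It also introduces learned linear classification projections $W_r \in \mathbb{R}^{w \times n_c}$ and $W_q \in \mathbb{R}^{w \times n_c}$ for reference and query assignments, respectively.
For each reference token $x_i$ (the $i$-th row of $X$), the soft assignment to class $k$ is
$$a_{ik} = \frac{\exp\left((x_i W_r)_k\right)}{\sum_{k'=1}^{n_c} \exp\left((x_i W_r)_{k'}\right)},$$
and analogously, for each query token $y_j$,
$$b_{jk} = \frac{\exp\left((y_j W_q)_k\right)}{\sum_{k'=1}^{n_c} \exp\left((y_j W_q)_{k'}\right)}.$$
Each class $k$ is associated with a value projection $V_k \in \mathbb{R}^{w \times w}$. In the full attention setting, the class-wise summary (agglomeration) is the assignment-weighted mean of the projected reference tokens:
$$s_k = \frac{\sum_{i=1}^{N_r} a_{ik}\, x_i V_k}{\sum_{i=1}^{N_r} a_{ik}}.$$
For masked (causal) attention, the sums run up to each query index $j$, giving position-dependent summaries
$$s_k^{(j)} = \frac{\sum_{i \le j} a_{ik}\, x_i V_k}{\sum_{i \le j} a_{ik}}.$$
Each query gathers the per-class summaries by weighted concatenation (in the masked case, $s_k$ is replaced by $s_k^{(j)}$):
$$h_j = \left[\, b_{j1} s_1,\; b_{j2} s_2,\; \dots,\; b_{j n_c} s_{n_c} \,\right] \in \mathbb{R}^{n_c w}.$$
The output for each token is a learned linear mixture of this concatenation:
$$o_j = h_j W_o, \qquad W_o \in \mathbb{R}^{n_c w \times w}.$$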
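The formulation can be sketched in NumPy as a non-causal forward pass. This is a minimal illustration, not the reference implementation. The function and argument names are hypothetical. The per-class value projections are assumed to be stacked into one `(n_c, w, w)` array. Each class summary is assumed to be the assignment-weighted mean of the projected reference tokens.

```python
import numpy as np


def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along `axis`."""
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def agglomerative_attention(X, Y, W_r, W_q, V, W_o):
    """Non-causal agglomerative attention (illustrative sketch).

    X:   (N_r, w)       reference sequence
    Y:   (N_q, w)       query sequence
    W_r: (w, n_c)       reference classification projection
    W_q: (w, n_c)       query classification projection
    V:   (n_c, w, w)    one value projection V_k per class (assumed layout)
    W_o: (n_c * w, w)   output mixing matrix
    Returns an array of shape (N_q, w).
    """
    A = softmax(X @ W_r)                         # (N_r, n_c): a_ik
    B = softmax(Y @ W_q)                         # (N_q, n_c): b_jk
    P = np.einsum("iw,kwv->kiv", X, V)           # (n_c, N_r, w): x_i V_k
    S = np.einsum("ik,kiv->kv", A, P) / A.sum(axis=0)[:, None]  # (n_c, w): s_k
    H = (B[:, :, None] * S[None]).reshape(Y.shape[0], -1)       # (N_q, n_c*w): concat_k b_jk s_k
    return H @ W_o                               # (N_q, w): o_j = h_j W_o


# Usage: self-attention (X = Y) with hypothetical sizes and random weights.
rng = np.random.default_rng(0)
N, w, n_c = 6, 4, 3
X = rng.normal(size=(N, w))
W_r = rng.normal(size=(w, n_c))
W_q = rng.normal(size=(w, n_c))
V = rng.normal(size=(n_c, w, w))
W_o = rng.normal(size=(n_c * w, w))
out = agglomerative_attention(X, X, W_r, W_q, V, W_o)
print(out.shape)  # (6, 4)
```

No intermediate has an $N \times N$ shape. The largest arrays are the `(n_c, N, w)` projections and the `(N, n_c * w)` concatenation.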
2. Algorithmic Procedure
The agglomerative attention layer consists of the following major stages:
- Soft Class Assignment: Compute reference and query assignments via softmaxed linear projections.
- Per-Class Projection: Each input token is projected with each value-projection matrix $V_k$.
- Class-wise Agglomeration: Aggregate token projections into class-wise summaries, with support for both global and causally-masked aggregation.
- Query-Guided Mixture: For every query position, combine class summaries by weighting with the query’s class assignment vector, then concatenate and linearly mix to yield final output.
Vectorization enables these computations to be performed efficiently using matrix operations.
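For the causally masked case, the vectorized computation can be sketched with prefix sums. `np.cumsum` along the sequence axis replaces the global sums, so each position aggregates only tokens at or before it. The names and the weighted-mean normalization are assumptions made for illustration. The final lines check that editing the last token leaves every earlier output unchanged.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def causal_agglomerative_attention(X, W_r, W_q, V, W_o):
    """Causal agglomerative self-attention: position j aggregates only tokens i <= j.

    X: (N, w), W_r / W_q: (w, n_c), V: (n_c, w, w), W_o: (n_c * w, w).
    Prefix sums over the sequence axis replace the global sums, so every
    position j receives its own class summaries s_k^(j).
    """
    N = X.shape[0]
    A = softmax(X @ W_r)                          # (N, n_c): a_ik
    B = softmax(X @ W_q)                          # (N, n_c): b_jk
    P = np.einsum("iw,kwv->ikv", X, V)            # (N, n_c, w): x_i V_k
    num = np.cumsum(A[:, :, None] * P, axis=0)    # (N, n_c, w): sum_{i<=j} a_ik x_i V_k
    den = np.cumsum(A, axis=0)[:, :, None]        # (N, n_c, 1): sum_{i<=j} a_ik, always > 0
    S = num / den                                 # (N, n_c, w): s_k^(j)
    H = (B[:, :, None] * S).reshape(N, -1)        # (N, n_c * w): concat_k b_jk s_k^(j)
    return H @ W_o                                # (N, w)


# Causality check with hypothetical sizes: editing the last token must not
# change any earlier output.
rng = np.random.default_rng(1)
N, w, n_c = 5, 3, 2
X = rng.normal(size=(N, w))
params = (rng.normal(size=(w, n_c)), rng.normal(size=(w, n_c)),
          rng.normal(size=(n_c, w, w)), rng.normal(size=(n_c * w, w)))
out = causal_agglomerative_attention(X, *params)
X_edit = X.copy()
X_edit[-1] += 10.0
out_edit = causal_agglomerative_attention(X_edit, *params)
print(np.allclose(out[:-1], out_edit[:-1]))  # True
```

Prefix sums keep the masked variant at linear cost. The only change from the non-causal case is that the aggregation is taken over prefixes, which is also what changes the normalizing denominators.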
3. Computational and Memory Complexity
Agglomerative attention achieves a fundamental efficiency gain compared to standard dot-product attention. Core complexity aspects include:
- Assignment logits: $O(N w n_c)$ to compute both reference and query assignments.
- Per-class projections: $O(N n_c w^2)$, or $O(N w^2)$ if $n_c$ is small.
- Agglomeration: $O(N n_c w)$ in the full attention setting.
- Final mixing: $O(N n_c w^2)$, typically realized as a bias-free residual feed-forward operation.
Aggregating these, the overall cost ranges from $O(N w^2)$ (small $n_c$) to $O(N n_c w^2)$, which is linear in the sequence length $N$. In contrast, conventional dot-product attention scales as $O(N^2 w)$ because of its pairwise operations and requires $O(N^2)$ memory per attention head. Agglomerative attention needs only $O(N n_c)$ memory for assignments and $O(N n_c w)$ for mixing, which makes longer sequences practical (Spellings, 2019).
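To make the scaling concrete, the hypothetical helpers below tally the leading-order terms of the cost breakdown in this section, along with the dominant memory intermediates. Constant factors, softmax costs, and the Q/K/V projections of full attention are deliberately ignored. The numbers are therefore illustrative, not measured.

```python
def agglomerative_ops(N: int, w: int, n_c: int) -> dict:
    """Leading-order multiply-add counts for one agglomerative self-attention layer."""
    return {
        "assignment_logits": 2 * N * w * n_c,     # X W_r and X W_q
        "per_class_projection": N * n_c * w * w,  # x_i V_k for every token and class
        "agglomeration": N * n_c * w,             # class-weighted sums over tokens
        "mixing": N * (n_c * w) * w,              # concatenated mixture times W_o
    }


def full_attention_pairwise_ops(N: int, w: int) -> int:
    """Multiply-adds for Q K^T plus (attention weights) V. The Q/K/V/output
    projections, which are O(N w^2), are left out."""
    return 2 * N * N * w


def memory_elements(N: int, w: int, n_c: int, heads: int = 1) -> dict:
    """Element counts of the dominant per-layer intermediates."""
    return {
        "full_scores": heads * N * N,    # one N x N score matrix per head
        "agg_assignments": 2 * N * n_c,  # reference + query assignments
        "agg_mixture": N * n_c * w,      # concatenated b_jk s_k fed to W_o
    }


# Hypothetical configuration: doubling N doubles the agglomerative cost
# but quadruples the pairwise cost of full attention.
w, n_c = 64, 8
for N in (1024, 2048, 4096):
    agg = sum(agglomerative_ops(N, w, n_c).values())
    full = full_attention_pairwise_ops(N, w)
    mem = memory_elements(N, w, n_c)
    print(N, agg, full, mem["full_scores"], mem["agg_assignments"] + mem["agg_mixture"])
```

The agglomerative terms carry larger constants ($n_c w^2$ per token), so full attention can still be cheaper at short lengths. This is consistent with the existence of a crossover sequence length.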
4. Integration into Transformer Architectures
Agglomerative attention can be substituted directly for standard multi-head self-attention in transformer-like architectures without other architectural changes. Each multi-head attention layer is replaced by an AgglomerativeAttention layer. All other sub-layers are retained, including feed-forward layers, normalization, and positional encodings. The two distinct classification projections ($W_r$ and $W_q$) let information flow in a directed way from reference tokens to query tokens. Models are trained with standard generative pre-training protocols, using the ADADELTA optimizer and early stopping on validation loss. For character-level modeling, a causal convolution front-end (filter width 8) can replace the embedding layer to provide a stronger token-level inductive bias.
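A causal convolution front-end of the kind described can be sketched as follows. The text fixes only causality and a filter width of 8. The one-hot character input, the left zero-padding, and the absence of an activation are assumptions, and the helper name `causal_conv1d` is hypothetical.

```python
import numpy as np


def causal_conv1d(x: np.ndarray, kernel: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """Causal 1D convolution: output j sees only inputs j - width + 1 .. j.

    x:      (N, c_in)             input sequence, e.g. one-hot characters
    kernel: (width, c_in, c_out)  kernel[t] weights the input t steps back
    bias:   (c_out,)
    The sequence is zero-padded on the left, so the output also has length N.
    """
    width, c_in, c_out = kernel.shape
    N = x.shape[0]
    padded = np.concatenate([np.zeros((width - 1, c_in)), x], axis=0)
    out = np.empty((N, c_out))
    for j in range(N):
        window = padded[j : j + width][::-1]  # window[t] = x[j - t], zeros before the start
        out[j] = np.einsum("tc,tco->o", window, kernel) + bias
    return out


# Hypothetical sizes: 27 symbols (text8 uses a-z plus space), model width 16,
# filter width 8 as in the text.
rng = np.random.default_rng(0)
vocab, w, width = 27, 16, 8
tokens = rng.integers(0, vocab, size=20)
x = np.eye(vocab)[tokens]                              # (20, 27) one-hot encoding
kernel = rng.normal(scale=0.1, size=(width, vocab, w))
bias = np.zeros(w)
h = causal_conv1d(x, kernel, bias)                     # (20, 16), fed to the first attention layer
```

The output at each position already mixes the previous eight characters. This local context is what the per-token embedding lookup cannot provide.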
5. Empirical Performance and Scaling
Experiments show both the computational advantage and the competitive accuracy of agglomerative attention. When a single layer is timed on CPU at increasing sequence lengths, wall-clock time for agglomerative attention grows linearly. Beyond a crossover sequence length, it is substantially faster than full attention.
Language modeling results on text8 (character-level) and WikiText-2 (word-level) exhibit the following:
| Attention | Encoding | # Params (text8 / WikiText-2) | Test BPC (text8) / PPL (WikiText-2) | Epoch Time, s (text8 / WikiText-2) |
|---|---|---|---|---|
| Full | Embedding | 64K/1.5M | 2.271 / 122.0 | 75 / 41 |
| Full | Convolution | 89K/1.6M | 2.177 / 134.5 | 81 / 44 |
| Agglomerative | Embedding | 57K/1.47M | 2.520 / 134.0 | 54 / 31 |
| Agglomerative | Convolution | 81K/1.58M | 2.183 / 132.6 | 57 / 34 |
With convolutional encoding, agglomerative attention nearly matches full attention using the same encoding. It is 0.006 BPC worse on text8 and 1.9 perplexity better on WikiText-2, while cutting epoch times by roughly 23–30% (Spellings, 2019).
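The accuracy gaps and epoch-time reductions can be recomputed directly from the convolutional rows of the table:

```python
# Convolutional-encoding rows of the table, as (full, agglomerative) pairs.
bpc_text8 = (2.177, 2.183)   # test BPC on text8 (lower is better)
ppl_wt2 = (134.5, 132.6)     # test perplexity on WikiText-2 (lower is better)
epoch_s_text8 = (81, 57)     # epoch time in seconds, text8
epoch_s_wt2 = (44, 34)       # epoch time in seconds, WikiText-2


def pct_faster(full: float, agg: float) -> float:
    """Percentage reduction in epoch time when switching full -> agglomerative."""
    return 100.0 * (full - agg) / full


print(f"text8 BPC gap (agg - full): {bpc_text8[1] - bpc_text8[0]:+.3f}")    # ... +0.006
print(f"WikiText-2 PPL gap (agg - full): {ppl_wt2[1] - ppl_wt2[0]:+.1f}")  # ... -1.9
print(f"text8 epoch-time reduction: {pct_faster(*epoch_s_text8):.1f}%")    # ... 29.6%
print(f"WikiText-2 epoch-time reduction: {pct_faster(*epoch_s_wt2):.1f}%")  # ... 22.7%
```

A negative perplexity gap means agglomerative attention scored better on WikiText-2.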
6. Hyperparameter Analysis and Design Choices
Empirical ablations indicate several notable observations:
- Sequence encoding: Omitting the convolutional front-end incurs a significant accuracy penalty at the character level (2.520 vs. 2.183 test BPC on text8 for agglomerative attention).
- Number of Classes ($n_c$): Not exhaustively optimized; $n_c$ should be treated as a tunable hyperparameter that balances representational coarseness against computational and memory costs.
- Masking: Masked (causal) attention changes only the aggregation step: its sums, including the normalizing denominators, run over prefixes rather than the whole sequence. For autoregressive language modeling, masking is required so that no position sees future tokens.
- Width vs. Scaling: The linear complexity allows model width $w$ to be increased for a given computational budget, partially offsetting the loss of attention granularity.
A plausible implication is that optimal performance hinges on a problem-appropriate choice of $n_c$ and input encoding strategy.
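One way to see the coarseness side of this trade-off is the degenerate case $n_c = 1$. There, every assignment is exactly 1, so every query receives the same global summary. The compact sketch below checks this for the attention output alone. It uses hypothetical names and random weights, and assumes each class summary is the assignment-weighted mean of the projected tokens. In a full transformer, residual connections would still carry token-specific information.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def agglomerate(X, W_r, W_q, V, W_o):
    """Compact non-causal agglomerative self-attention (illustrative sketch)."""
    A, B = softmax(X @ W_r), softmax(X @ W_q)                        # (N, n_c) each
    S = np.einsum("ik,iw,kwv->kv", A, X, V) / A.sum(axis=0)[:, None]  # (n_c, w) summaries
    return (B[:, :, None] * S).reshape(len(X), -1) @ W_o             # (N, w)


rng = np.random.default_rng(0)
N, w = 8, 4
X = rng.normal(size=(N, w))


def random_params(n_c: int):
    """Random W_r, W_q, V, W_o for a given number of classes (hypothetical helper)."""
    return (rng.normal(size=(w, n_c)), rng.normal(size=(w, n_c)),
            rng.normal(size=(n_c, w, w)), rng.normal(size=(n_c * w, w)))


out1 = agglomerate(X, *random_params(1))
out4 = agglomerate(X, *random_params(4))
print(np.allclose(out1, out1[0]))  # True: every position receives the same vector
print(np.allclose(out4, out4[0]))  # False: class assignments differentiate positions
```

Raising $n_c$ increases how finely positions can be distinguished. It also scales the per-class projection and mixing costs linearly.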
7. Summary and Contextualization
Agglomerative Attention constitutes a principled, efficient alternative to dot-product self-attention, exchanging fine-scale token interactions for a soft clustering and summary mechanism. Its major advantages are:
- $O(N n_c w^2)$ runtime and $O(N n_c w)$ memory, both linear in sequence length, improving feasibility for long-sequence tasks.
- Near-equivalent accuracy to dot-product attention in language modeling benchmarks, especially with convolutional token encoding.
- Enables wider and larger models within fixed computational budgets, benefiting scalability for practical deployments.
This mechanism is suitable as a drop-in replacement where long sequence processing or compute constraints are a concern, with potential for further gains through domain-specific hyperparameter tuning and architectural adaptation (Spellings, 2019).