Introduction
Linear attention mechanisms offer the exciting potential to replace traditional softmax attention in Transformers, whose computational complexity is quadratic in the sequence length, with linear-complexity alternatives. Despite these efficiency benefits, previously proposed linear attentions have often resulted in substantially reduced model quality compared to their softmax counterparts.
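To make the complexity contrast concrete, here is a minimal sketch (not the paper's code) of kernelized linear attention: replacing exp(q·k) with a feature-map dot product φ(q)·φ(k) lets the computation be reordered so the N×N attention matrix is never formed. The specific φ used here (a shifted ReLU) is only a placeholder assumption; Hedgehog's contribution is to learn φ instead.

```python
import torch

def softmax_attention(q, k, v):
    # q, k, v: (batch, seq_len, dim); materializes a full (N x N) weight matrix,
    # so time and memory grow quadratically with sequence length N.
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v

def linear_attention(q, k, v, phi=lambda x: torch.relu(x) + 1e-6):
    # Replace exp(q . k) with phi(q) . phi(k); associativity lets us compute
    # phi(k)^T v first (a D x D matrix), so cost grows linearly in N.
    q, k = phi(q), phi(k)                                   # (B, N, D)
    kv = k.transpose(-2, -1) @ v                            # (B, D, D)
    z = q @ k.sum(dim=1, keepdim=True).transpose(-2, -1)    # (B, N, 1) normalizer
    return (q @ kv) / (z + 1e-6)
```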
Bridging the Performance Gap
The paper identifies crucial properties of softmax attention that prior linear variants lack, namely low-entropy ("spiky") weight distributions and dot-product monotonicity, and introduces an approach to recover them. By using trainable single-layer MLPs (multi-layer perceptrons) as feature maps, the proposed method, dubbed Hedgehog, achieves high-performance linear attention that closely mirrors softmax attention's ability to produce spiky and monotonic weights. Hedgehog preserves linear computational complexity while also performing well across several regimes, including training from scratch and finetuned conversion of existing Transformers; a sketch of such a learned feature map follows below.
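The following is a hedged sketch of what such a learned feature map could look like, assuming (per the paper's description) a trainable single-layer MLP composed with an element-wise exponential; the layer sizes and the negated-projection concatenation are illustrative assumptions rather than the reference implementation.

```python
import torch
import torch.nn as nn

class HedgehogFeatureMap(nn.Module):
    """Illustrative Hedgehog-style feature map: a trainable single-layer MLP
    followed by an element-wise exponential, so phi(q) . phi(k) can mimic the
    spiky, monotonic weights of softmax attention (sketch, not the paper's code)."""

    def __init__(self, head_dim: int):
        super().__init__()
        # Trainable single-layer MLP applied to each query/key vector.
        self.proj = nn.Linear(head_dim, head_dim, bias=False)

    def forward(self, x):
        # Element-wise exp keeps features positive and encourages low-entropy
        # ("spiky") attention weights; concatenating exp(h) and exp(-h) is an
        # assumed trick to preserve monotonicity in both directions.
        h = self.proj(x)
        return torch.cat([torch.exp(h), torch.exp(-h)], dim=-1)

# Usage sketch: apply learned maps to queries and keys, then reuse kernelized
# linear attention as above. The maps could be trained by distillation, e.g. a
# cross-entropy loss between the teacher's softmax attention weights and the
# student's normalized phi(q) . phi(k) weights.
phi_q, phi_k = HedgehogFeatureMap(64), HedgehogFeatureMap(64)
```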
Empirical Validation
Numerous experiments validate the effectiveness of Hedgehog, showing performance that surpasses prior linear attention formulations. In training-from-scratch settings, Hedgehog performs strongly on standard benchmarks such as the Long Range Arena (LRA) tasks and WikiText-103 language modeling, closing 68.6% of the performance gap to softmax attention on the latter. In the finetuned-conversion and pretrained-conversion settings, Hedgehog consistently recovers over 99% of the original standard Transformer quality on tasks such as WikiText-103 and the GLUE benchmark, outpacing prior linear attentions by substantial margins, with improvements of up to 6 perplexity points and 8.7 GLUE score points, respectively.
Contributions and Scalability
The paper's method makes a compelling case for the practicality and scalability of linear attentions in Transformers, achieving state-of-the-art results among subquadratic models of similar size after converting pretrained GPT-2, and delivering significant improvements on the SAMSum summarization task with a scaled-up pretrained Llama-2 7B model. Notably, Hedgehog's attention preserves fidelity at longer sequence lengths and transfers effectively to new tasks, evidencing its adaptability and generalization. The findings suggest that by closely mimicking softmax attention, it is possible to achieve near-equivalent performance with linear complexity, combining efficiency and expressivity in a way prior linear attentions did not.