PolySketchFormer: Fast Transformers via Sketching Polynomial Kernels (2310.01655v3)
Abstract: The quadratic time and memory complexity inherent to self-attention mechanisms, with respect to sequence length, presents a critical computational bottleneck in the training and deployment of large-scale Transformer-based language models. Recent theoretical results indicate the intractability of sub-quadratic softmax attention approximation under reasonable complexity assumptions. This paper addresses this challenge by first demonstrating that polynomial attention with high degree can effectively replace softmax without sacrificing model quality. Next, we develop polynomial sketching techniques from numerical linear algebra to achieve linear-time polynomial attention with approximation guarantees. Crucially, our approach achieves this speedup without requiring the sparsification of attention matrices. We also present a block-based algorithm to apply causal masking efficiently. Combining these techniques, we provide PolySketchFormer, a practical linear-time Transformer architecture for language modeling that offers provable guarantees. We validate PolySketchFormer empirically by training language models capable of handling long contexts. These experiments utilize both synthetic and real-world datasets (PG19, Wikipedia, and C4) on Google Cloud TPUs. For context lengths of 32k and GPT-2-style models, our model achieves a 2.5-4x speedup in training compared to FlashAttention, with no observed degradation in quality across our experiments.
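To make the core idea concrete, the sketch below shows how polynomial attention can be evaluated in time linear in the sequence length once the polynomial kernel is replaced by an (approximate) feature map, with causal masking applied via prefix sums. This is a minimal illustration, not the paper's construction: the function names (`poly2_feature_map`, `causal_poly_attention`), the Gaussian tensorized projection used as the sketch, the degree-2 kernel, and the token-level cumulative sums are all simplifying assumptions made here for clarity.

```python
import jax
import jax.numpy as jnp

def poly2_feature_map(x, s1, s2):
    # Tensorized random projection: for independent Gaussian s1, s2 of shape (r, d),
    # the inner product of these features is an unbiased estimate of (q . k)^2.
    # A simplified stand-in for the paper's polynomial sketch.
    r = s1.shape[0]
    return (x @ s1.T) * (x @ s2.T) / jnp.sqrt(r)

def causal_poly_attention(q, k, v, s1, s2, eps=1e-6):
    # q, k: (n, d); v: (n, dv). Cost is O(n * r * dv), linear in sequence length n.
    phi_q = poly2_feature_map(q, s1, s2)  # (n, r)
    phi_k = poly2_feature_map(k, s1, s2)  # (n, r)
    # Causal masking via prefix sums: position i only attends to positions j <= i.
    kv = jnp.cumsum(phi_k[:, :, None] * v[:, None, :], axis=0)  # (n, r, dv)
    z = jnp.cumsum(phi_k, axis=0)                               # (n, r)
    num = jnp.einsum('nr,nrd->nd', phi_q, kv)
    # eps guards against a small/noisy denominator in this toy sketch; the
    # paper's construction is designed so the attention weights behave well.
    den = jnp.einsum('nr,nr->n', phi_q, z) + eps
    return num / den[:, None]

# Toy usage with arbitrary sizes.
key = jax.random.PRNGKey(0)
n, d, dv, r = 1024, 64, 64, 128
kq, kk, kv_, k1, k2 = jax.random.split(key, 5)
q = jax.random.normal(kq, (n, d))
k = jax.random.normal(kk, (n, d))
v = jax.random.normal(kv_, (n, dv))
s1 = jax.random.normal(k1, (r, d))
s2 = jax.random.normal(k2, (r, d))
out = causal_poly_attention(q, k, v, s1, s2)  # (n, dv)
```

The cumulative sums above realize causal masking in linear time but proceed position by position; the block-based causal-masking algorithm referred to in the abstract instead processes the sequence in blocks so that most of the work becomes dense matrix multiplications, which map far better onto TPU/GPU matrix units. The dimensions n, d, dv, and r above are arbitrary illustrative values.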
- Oblivious sketching of high-degree polynomial kernels. In Proceedings of the Thirty-First Annual ACM-SIAM Symposium on Discrete Algorithms, pp. 141–160. SIAM, 2020.
- Fast attention requires bounded entries. arXiv preprint arXiv:2302.13214, 2023.
- PaLM 2 technical report. arXiv preprint arXiv:2305.10403, 2023.
- Adaptive sampled softmax with kernel based sampling. In International Conference on Machine Learning, pp. 590–599. PMLR, 2018.
- Language models are few-shot learners. Advances in Neural Information Processing Systems, 33:1877–1901, 2020.
- Rethinking attention with performers. arXiv preprint arXiv:2009.14794, 2020. JAX implementation of Performer is available at https://github.com/google-research/google-research/blob/master/performer/fast_attention/jax/fast_attention.py.
- PaLM: Scaling language modeling with pathways. arXiv preprint arXiv:2204.02311, 2022.
- Tri Dao. FlashAttention-2: Faster attention with better parallelism and work partitioning. arXiv preprint arXiv:2307.08691, 2023.
- FlashAttention: Fast and memory-efficient exact attention with IO-awareness. Advances in Neural Information Processing Systems, 35:16344–16359, 2022.
- Language modeling with gated convolutional networks. In International Conference on Machine Learning, pp. 933–941. PMLR, 2017.
- An exploration of softmax alternatives belonging to the spherical loss family. arXiv preprint arXiv:1511.05042, 2015.
- BERT: Pre-training of deep bidirectional transformers for language understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), 2019.
- Wiki-40B: Multilingual language model dataset. In Proceedings of the Twelfth Language Resources and Evaluation Conference, pp. 2440–2452, 2020.
- Transformer quality in linear time. In International Conference on Machine Learning, pp. 9099–9117. PMLR, 2022.
- JAX authors. Implementation of FlashAttention in Pallas. https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/attention.py, 2023.
- Transformers are RNNs: Fast autoregressive transformers with linear attention. In International Conference on Machine Learning, pp. 5156–5165. PMLR, 2020.
- OpenAI. GPT-4 technical report, 2023.
- Self-attention does not need O(n^2) memory. arXiv preprint arXiv:2112.05682, 2021.
- Do transformers need deep long-range memory? In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics, pp. 7524–7529, 2020.
- Compressive transformers for long-range sequence modelling. arXiv preprint arXiv:1911.05507, 2019.
- Linear transformers are secretly fast weight programmers. In International Conference on Machine Learning, pp. 9355–9366. PMLR, 2021.
- Noam Shazeer. GLU variants improve transformer. arXiv preprint arXiv:2002.05202, 2020.
- RoFormer: Enhanced transformer with rotary position embedding. arXiv preprint arXiv:2104.09864, 2021.
- Efficient transformers: A survey. arXiv preprint arXiv:2009.06732, 2022.
- Transformer dissection: a unified understanding of transformer’s attention via the lens of kernel. arXiv preprint arXiv:1908.11775, 2019.
- Attention is all you need. In Advances in Neural Information Processing Systems, 2017.
- Linformer: Self-attention with linear complexity. arXiv preprint arXiv:2006.04768, 2020.