BurstAttention: An Efficient Distributed Attention Framework for Extremely Long Sequences (2403.09347v4)

Published 14 Mar 2024 in cs.DC and cs.LG

Abstract: Effective attention modules have played a crucial role in the success of Transformer-based LLMs, but the quadratic time and memory complexities of these attention modules also pose a challenge when processing long sequences. One potential solution for the long sequence problem is to utilize distributed clusters to parallelize the computation of attention modules across multiple devices (e.g., GPUs). However, adopting a distributed approach inevitably introduces extra memory overheads to store local attention results and incurs additional communication costs to aggregate local results into global ones. In this paper, we propose a distributed attention framework named "BurstAttention" to optimize memory access and communication operations at both the global cluster and local device levels. In our experiments, we compare BurstAttention with other competitive distributed attention solutions for long sequence processing. The experimental results under different length settings demonstrate that BurstAttention offers significant advantages for processing long sequences compared with these competitive baselines, reducing communication overheads by 40% and achieving a 1.37x speedup when training on 128K-length sequences across 32 A100 GPUs.

An Overview of BurstAttention: Addressing Long Sequence Processing in Distributed LLM Architectures

Introduction

The research paper titled "BurstAttention: An Efficient Distributed Attention Framework for Extremely Long Sequences" presents a novel approach to overcoming the computational inefficiencies of attention mechanisms in Transformer-based LLMs when processing extremely long sequences. The authors propose BurstAttention, a distributed attention framework designed to optimize memory access and communication operations across distributed computing clusters.

Problem Statement

Transformer architectures, despite their undeniable success in shaping the landscape of LLMs, are hampered by the quadratic time and memory complexity of their attention modules, which poses significant challenges when handling long sequences. Existing solutions such as FlashAttention and RingAttention bring improvements, but each targets a different bottleneck: FlashAttention optimizes memory access within a single device, while RingAttention distributes the sequence across devices. Neither fully resolves the memory overheads and communication costs that arise in a distributed setting.
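Concretely, for a sequence of length n and head dimension d, standard scaled dot-product attention materializes an n × n score matrix, which is the source of the quadratic cost:

```latex
\mathrm{Attention}(Q,K,V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}}\right)V,
\qquad Q,K,V \in \mathbb{R}^{n \times d},\ \ QK^{\top} \in \mathbb{R}^{n \times n}
```

Both time and memory therefore scale as O(n²) in the sequence length, which becomes prohibitive at the 128K-token scale targeted by this work.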

Methodology

BurstAttention integrates and enhances ideas from prior methods, aiming to combine the parallelism of distributed clusters with the efficiency of single-device optimizations. The framework employs a two-level partitioning strategy:

  1. Inter-Device Partitioning: The input sequence is divided across multiple devices (e.g., GPUs) so that each device performs only local attention computations on its shard, substantially reducing per-device memory usage.
  2. Intra-Device Partitioning: Each local shard is further split into small tiles within a device so that block-wise computation can run out of fast on-chip SRAM rather than slower high-bandwidth memory, optimizing local attention computation (a minimal sketch of this two-level split follows below).
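The following is a minimal, single-process sketch of that split (an illustration only, not the authors' implementation; the `partition_sequence` helper, the device count, and the tile size are all assumed for the example). It chunks the sequence dimension across devices and then tiles each chunk for a block-wise local kernel:

```python
import torch

def partition_sequence(q, k, v, num_devices, tile_len):
    """Illustrative two-level split: sequence -> per-device shards -> SRAM-sized tiles."""
    # q, k, v: [seq_len, head_dim]; level 1 chunks along the sequence dimension,
    # one shard per device; level 2 tiles each shard for the block-wise local kernel.
    shards = []
    for dq, dk, dv in zip(q.chunk(num_devices), k.chunk(num_devices), v.chunk(num_devices)):
        tiles = list(zip(dq.split(tile_len), dk.split(tile_len), dv.split(tile_len)))
        shards.append(tiles)
    return shards  # shards[device][tile] -> (q_tile, k_tile, v_tile)

# Example: a 32K-token sequence, 4 devices, 2K-token tiles.
q = torch.randn(32768, 128)
k = torch.randn(32768, 128)
v = torch.randn(32768, 128)
shards = partition_sequence(q, k, v, num_devices=4, tile_len=2048)
print(len(shards), len(shards[0]), shards[0][0][0].shape)  # 4 devices, 4 tiles each, [2048, 128]
```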

The framework introduces:

  • Global Attention Optimization (GAO): Avoids the memory overhead of storing per-device attention results by accumulating them on the fly with online softmax, so that local results can be merged into exact global outputs without ever being materialized in full (see the sketch after this list).
  • Local Attention Optimization (LAO): Exploits the high bandwidth of on-chip SRAM for block-wise computation of each device's local attention, and uses data buffers to overlap communication with computation.
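To make GAO concrete, here is a minimal single-process sketch of the online-softmax merge (an illustration, not the paper's distributed code; `partial_attention` and `merge` are hypothetical helpers, and the two-block split stands in for K/V shards held on different devices). Each block yields an unnormalized partial output plus a running row-wise maximum and normalizer, and partial states merge exactly without storing the full attention matrix:

```python
import torch

def partial_attention(q, k_block, v_block):
    """Attention of q against one K/V block, returned in unnormalized form."""
    s = q @ k_block.T / q.shape[-1] ** 0.5              # [nq, nk] local scores
    m = s.max(dim=-1, keepdim=True).values              # row-wise running max
    p = torch.exp(s - m)                                 # stabilized exponentials
    return p @ v_block, m, p.sum(dim=-1, keepdim=True)   # (output, max, normalizer)

def merge(state_a, state_b):
    """Online-softmax merge of two partial states; exact and order-independent."""
    o_a, m_a, l_a = state_a
    o_b, m_b, l_b = state_b
    m = torch.maximum(m_a, m_b)
    a, b = torch.exp(m_a - m), torch.exp(m_b - m)        # rescale each state to the shared max
    return o_a * a + o_b * b, m, l_a * a + l_b * b

# Sanity check: two "devices" each hold half of K/V; the merged, normalized
# result matches full attention computed in one shot.
torch.manual_seed(0)
q, k, v = torch.randn(4, 8), torch.randn(16, 8), torch.randn(16, 8)
o, _, l = merge(partial_attention(q, k[:8], v[:8]), partial_attention(q, k[8:], v[8:]))
full = torch.softmax(q @ k.T / 8 ** 0.5, dim=-1) @ v
print(torch.allclose(o / l, full, atol=1e-6))            # True
```

Because the merge is associative, the same recurrence applies whether the blocks are SRAM tiles within one GPU (LAO) or shards arriving from other devices (GAO).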

Results

Through comprehensive experiments, BurstAttention achieves significant improvements over existing distributed attention solutions across varying sequence lengths and model sizes. For instance, compared with tensor parallelism combined with FlashAttention, the framework achieves a 1.37x speedup and reduces communication overheads by 40% when training 128K-length sequences on 32 A100 GPUs.

Inference Latency: BurstAttention reduces first-token latency for LLaMA models and supports longer sequences than competing approaches, making it well suited to practical applications where long inputs are common.

Training Performance: The method achieves nearly a 2.0x speedup over baselines for sequences beyond 128K tokens, without sacrificing per-device efficiency, owing to careful memory management and the overlap of communication with computation.

Implications and Future Work

Practical Implications: The reduction in computational overhead paves the way for real-time AI applications like chatbots and language generation systems, where rapid processing of extensive user input is crucial.

Theoretical Implications: This research contributes to the discourse on distributed computing frameworks in machine learning, illustrating efficient partitioning techniques and optimization strategies applicable beyond just LLMs.

Future Work: Subsequent work could explore integrating BurstAttention with sparse attention mechanisms and analyzing the resulting trade-offs between efficiency and accuracy. Extending BurstAttention to other domains that require efficient long-sequence processing is another compelling avenue.

Conclusion

BurstAttention presents a robust framework effectively tackling the complexities imposed by long-sequence processing in transformer models at scale. By minimizing communication and memory overheads through innovative partitioning and optimization strategies, it offers both theoretical insights and practical enhancements for evolving LLM architectures.

References (30)
  1. PaLM 2 technical report. arXiv preprint arXiv:2305.10403, 2023.
  2. On the opportunities and risks of foundation models. arXiv preprint arXiv:2108.07258, 2021.
  3. Language models are few-shot learners. In Proceedings of NeurIPS, pp.  1877–1901, 2020.
  4. Training deep nets with sublinear memory cost. arXiv preprint arXiv:1604.06174, 2016.
  5. Rethinking attention with performers. arXiv preprint arXiv:2009.14794, 2020.
  6. PaLM: Scaling language modeling with pathways. arXiv preprint arXiv:2204.02311, 2022.
  7. FlashAttention: Fast and memory-efficient exact attention with io-awareness. In Proceedings of NeurIPS, pp.  16344–16359, 2022.
  8. LongNet: Scaling transformers to 1,000,000,000 tokens. arXiv preprint arXiv:2307.02486, 2023.
  9. Pre-trained models: Past, present and future. AI Open, 2:225–250, 2021.
  10. GPipe: efficient training of giant neural networks using pipeline parallelism. In Proceedings of NeurIPS, pp. 103–112, 2019.
  11. Perceiver: General perception with iterative attention. In Proceedings of ICML, pp.  4651–4664, 2021.
  12. Transformers are RNNs: Fast autoregressive transformers with linear attention. In Proceedings of ICML, pp.  5156–5165, 2020.
  13. Reducing activation recomputation in large transformer models. In Proceedings of MLSYS, 2023.
  14. Set Transformer: A framework for attention-based permutation-invariant neural networks. In Proceedings of ICML, pp.  3744–3753, 2019.
  15. Sequence parallelism: Long sequence training from system perspective. arXiv preprint arXiv:2105.13120, 2021.
  16. Online normalizer calculation for softmax. arXiv preprint arXiv:1805.02867, 2018.
  17. Efficient large-scale language model training on gpu clusters using Megatron-LM. In Proceedings of SC, 2021.
  18. Training language models to follow instructions with human feedback. In Proceedings of NeurIPS, pp. 27730–27744, 2022.
  19. The devil in linear transformer. In Proceedings of EMNLP, pp.  7025–7041, 2022.
  20. Self-attention does not need O(n^2) memory. arXiv preprint arXiv:2112.05682, 2021.
  21. Exploring the limits of transfer learning with a unified text-to-text transformer. The Journal of Machine Learning Research, 21:5485–5551, 2020.
  22. ZeRO: Memory optimizations toward training trillion parameter models. In Proceedings of SC, 2020.
  23. ZeRO-Offload: Democratizing billion-scale model training. In Proceedings of ATC, pp.  551–564, 2021.
  24. LLaMA: Open and efficient foundation language models. arXiv preprint arXiv:2302.13971, 2023a.
  25. LLaMA 2: Open foundation and fine-tuned chat models. arXiv preprint arXiv:2307.09288, 2023b.
  26. Valiant, L. G. A bridging model for parallel computation. Communications of the ACM, pp.  103–111, 1990.
  27. Attention is all you need. In Proceedings of NeurIPS, 2017.
  28. Linformer: Self-attention with linear complexity. arXiv preprint arXiv:2006.04768, 2020.
  29. Lightweight and efficient end-to-end speech recognition using low-rank transformer. In Proceedings of ICASSP, pp.  6144–6148, 2020.
  30. A survey of large language models. arXiv preprint arXiv:2303.18223, 2023.
Authors (7)
  1. Weilin Zhao (22 papers)
  2. Xu Han (270 papers)
  3. Cheng Yang (168 papers)
  4. Zhiyuan Liu (433 papers)
  5. Chuan Shi (92 papers)
  6. Maosong Sun (337 papers)
  7. Ao Sun (53 papers)
Citations (5)