
Striped Attention: Faster Ring Attention for Causal Transformers (2311.09431v1)

Published 15 Nov 2023 in cs.LG and cs.CL

Abstract: To help address the growing demand for ever-longer sequence lengths in transformer models, Liu et al. recently proposed Ring Attention, an exact attention algorithm capable of overcoming per-device memory bottlenecks by distributing self-attention across multiple devices. In this paper, we study the performance characteristics of Ring Attention in the important special case of causal transformer models, and identify a key workload imbalance due to triangular structure of causal attention computations. We propose a simple extension to Ring Attention, which we call Striped Attention, to fix this imbalance. Instead of devices having contiguous subsequences, each device has a subset of tokens distributed uniformly throughout the sequence, which we demonstrate leads to more even workloads. In experiments running Striped Attention on A100 GPUs and TPUv4s, we are able to achieve up to 1.45x end-to-end throughput improvements over the original Ring Attention algorithm on causal transformer training at a sequence length of 256k. Furthermore, on 16 TPUv4 chips, we were able to achieve 1.65x speedups at sequence lengths of 786k. We release the code for our experiments as open source.

Analysis of Striped Attention: A Performance Advancement in Distributed Attention Mechanisms for Causal Transformers

This paper presents "Striped Attention," an enhancement to the "Ring Attention" approach, targeting distributed attention computation in causal transformers. The paper specifically addresses imbalanced workloads in Ring Attention, a notable inefficiency for causal self-attention that arises from its triangular dependency structure. The authors propose Striped Attention, which redistributes work across devices more evenly, considerably increasing throughput when training causal transformers on long sequences.

Causal transformers, such as popular generative models like GPT and Llama, employ causal self-attention to restrict each token's attention computation to preceding tokens. This masking means roughly half of the attention score matrix never needs to be computed, an opportunity that prior distributed attention frameworks do not exploit evenly. The Ring Attention algorithm divides the attention workload across multiple processors arranged in a ring topology, assigning each device a contiguous block of the sequence. Under a causal mask, some of the resulting block-pair computations in a given ring round are fully unmasked while others are entirely masked, so per-device workloads are uneven and each round is paced by the most heavily loaded device.
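
The imbalance can be made concrete with a minimal sketch (not from the paper's released code). It assumes a simplified cost model in which each device's work is the count of query-key pairs surviving the causal mask, a toy sequence length of 16 split contiguously across 4 devices, and a ring schedule where device i processes the key/value block of device (i - r) mod N in round r; all of these specifics are illustrative assumptions.

```python
# Per-round workload each device sees under contiguous (Ring Attention style)
# partitioning with a causal mask. Toy sizes; cost model is "unmasked pairs".

def unmasked_fraction(query_idx, key_idx):
    """Fraction of (query, key) pairs in this block pair that survive the causal mask."""
    pairs = sum(1 for q in query_idx for k in key_idx if k <= q)
    return pairs / (len(query_idx) * len(key_idx))

def ring_round_workloads(S=16, N=4):
    block = S // N
    owned = [list(range(i * block, (i + 1) * block)) for i in range(N)]  # contiguous blocks
    for r in range(N):
        # In round r, device i attends its own queries against the key/value
        # block it has just received from device (i - r) mod N.
        fracs = [unmasked_fraction(owned[i], owned[(i - r) % N]) for i in range(N)]
        print(f"round {r}:", [f"{f:.2f}" for f in fracs])

ring_round_workloads()
# After round 0, each round mixes devices at 1.00 (fully unmasked) with devices
# at 0.00 (fully masked), so even a device that skips its masked block must wait
# for the busiest device before the ring can rotate again.
```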

Striped Attention solves this by changing how tokens are assigned to devices. Instead of assigning contiguous subsequences, tokens are distributed using a modulo operation on their indices, so each device holds a subset of tokens spread uniformly throughout the sequence. With this assignment, every device's block-pair computation in every ring round is roughly half masked, so all devices can skip a comparable amount of masked work under the causal mask and no single device becomes the bottleneck.
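
For comparison, the sketch below (again illustrative, not the authors' implementation) builds the striped assignment by taking each token index modulo the device count and measures the same per-round unmasked fractions under the same toy sizes and cost model assumed above.

```python
# Striped assignment: token t is owned by device t % N, so each device's shard
# is spread uniformly over the sequence rather than being a contiguous block.

def striped_partition(S, N):
    return [[t for t in range(S) if t % N == d] for d in range(N)]

def unmasked_fraction(query_idx, key_idx):
    pairs = sum(1 for q in query_idx for k in key_idx if k <= q)
    return pairs / (len(query_idx) * len(key_idx))

S, N = 16, 4
stripes = striped_partition(S, N)
for r in range(N):
    fracs = [unmasked_fraction(stripes[i], stripes[(i - r) % N]) for i in range(N)]
    print(f"round {r}:", [f"{f:.2f}" for f in fracs])
# With striping, the per-device fractions cluster around 0.5 in every round
# (0.38-0.62 at this toy size) instead of spanning 0.00 to 1.00, so all devices
# can skip a similar amount of masked work and none paces the ring alone.
```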

In experimental evaluations, Striped Attention showed significant speedups over Ring Attention. On A100 GPUs and TPUv4 chips, it achieved up to 1.45x end-to-end throughput improvements at a sequence length of 256k, and 1.65x speedups at 786k on 16 TPUv4 chips. These improvements reflect the rebalanced workloads and more effective use of causal masking to skip unnecessary computation.

Theoretical and practical implications of these findings are substantial. With the barriers to processing extremely long sequences lowered, Striped Attention holds promise for expanding the capabilities of causal transformers in real-time applications and model training. The balanced partitioning of workloads increases efficiency while preserving exact attention computation, a critical advantage over approximate methods.

Looking forward, the presented solution opens new avenues for further enhancements in parallelized computation for neural networks. The partitioning strategy could inspire similar rebalancing for other attention variants, different hardware configurations, or architectures with more complex attention patterns. Future work will likely focus on tighter integration with fused attention kernels such as FlashAttention, skipping masked computation at a finer granularity and further improving long-sequence training efficiency.

Overall, Striped Attention contributes significantly to the field of distributed machine learning systems, offering both theoretical insights and practical gains. As distributed machine learning advances, the community will benefit greatly from this work, making long-context models computationally feasible on existing hardware architectures without compromising model accuracy.

Authors (7)
  1. William Brandon (6 papers)
  2. Aniruddha Nrusimha (8 papers)
  3. Kevin Qian (3 papers)
  4. Zachary Ankner (10 papers)
  5. Tian Jin (24 papers)
  6. Zhiye Song (1 paper)
  7. Jonathan Ragan-Kelley (28 papers)
Citations (21)