INT-FlashAttention: Enabling Flash Attention for INT8 Quantization (2409.16997v2)

Published 25 Sep 2024 in cs.LG and cs.AI

Abstract: As the foundation of LLMs, the self-attention module faces the challenge of quadratic time and memory complexity with respect to sequence length. FlashAttention accelerates attention computation and reduces its memory usage by leveraging the GPU memory hierarchy. A promising research direction is to integrate FlashAttention with quantization methods. This paper introduces INT-FlashAttention, the first INT8 quantization architecture compatible with the forward workflow of FlashAttention, which significantly improves the inference speed of FlashAttention on Ampere GPUs. We implement our INT-FlashAttention prototype with fully INT8 activations and general matrix-multiplication (GEMM) kernels, making it the first attention operator with fully INT8 input. As a general token-level post-training quantization framework, INT-FlashAttention is also compatible with other data formats such as INT4. Experimental results show INT-FlashAttention achieves 72% faster inference speed and 82% smaller quantization error compared to standard FlashAttention with the FP16 and FP8 data formats.

INT-FlashAttention: Enabling Flash Attention for INT8 Quantization

The paper titled "INT-FlashAttention: Enabling Flash Attention for INT8 Quantization" focuses on integrating INT8 quantization with FlashAttention to enhance the inference speed and accuracy of LLMs. The authors introduce a novel architecture, INT-FlashAttention, specifically designed for Ampere GPUs, which lack support for the FP8 data format integral to newer Hopper GPUs.

Background

LLMs like GPT and Llama rely on self-attention to capture dependencies within sequences. However, self-attention has quadratic time and memory complexity in sequence length, which complicates its application to longer sequences. FlashAttention addresses this by using a tiling strategy and the GPU memory hierarchy to optimize memory usage, reducing it from quadratic to linear with respect to sequence length.
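
As an illustration of the tiling idea (not the paper's kernel), the following NumPy sketch processes the key/value matrices block by block while maintaining running softmax statistics, so the full attention-score matrix is never materialized; real FlashAttention also tiles the query dimension into on-chip SRAM blocks. All names here are illustrative.

```python
# Minimal sketch of tiled attention with an online softmax (illustrative,
# not the FlashAttention CUDA kernel). Memory for the score matrix is
# O(n * block_size) instead of O(n^2).
import numpy as np

def tiled_attention(Q, K, V, block_size=64):
    n, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros_like(Q)
    row_max = np.full(n, -np.inf)   # running max of scores per query
    row_sum = np.zeros(n)           # running sum of exponentials per query
    for start in range(0, n, block_size):
        Kb = K[start:start + block_size]
        Vb = V[start:start + block_size]
        S = Q @ Kb.T * scale                          # scores for this block
        new_max = np.maximum(row_max, S.max(axis=1))
        P = np.exp(S - new_max[:, None])              # re-based exponentials
        correction = np.exp(row_max - new_max)        # rescale earlier partials
        row_sum = row_sum * correction + P.sum(axis=1)
        out = out * correction[:, None] + P @ Vb
        row_max = new_max
    return out / row_sum[:, None]

# Sanity check against the naive quadratic-memory reference.
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((256, 32)) for _ in range(3))
S = Q @ K.T / np.sqrt(32)
P = np.exp(S - S.max(axis=1, keepdims=True))
reference = (P / P.sum(axis=1, keepdims=True)) @ V
assert np.allclose(tiled_attention(Q, K, V), reference, atol=1e-6)
```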

Quantization further improves LLM efficiency by reducing the bit width of data formats, thereby lowering memory usage and computational demands. Current hardware supports various formats, including FP16, FP8, and INT8, and quantization techniques are broadly categorized into training-phase and post-training approaches. This paper focuses on post-training quantization to make FlashAttention feasible for INT8 data formats, particularly on Ampere GPUs.
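
A minimal sketch of the kind of token-level (per-row) symmetric INT8 post-training quantization discussed here is shown below; the function names and the max-absolute-value scaling rule are illustrative assumptions, not the paper's code.

```python
# Symmetric per-token INT8 quantization: each row gets its own scale,
# which preserves per-token dynamic range better than one scale per tensor.
import numpy as np

def quantize_per_token(x, bits=8):
    qmax = 2 ** (bits - 1) - 1                      # 127 for INT8
    scale = np.abs(x).max(axis=-1, keepdims=True) / qmax
    scale = np.where(scale == 0, 1.0, scale)        # guard all-zero rows
    q = np.clip(np.round(x / scale), -qmax - 1, qmax).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

x = np.random.default_rng(0).standard_normal((4, 16)).astype(np.float32)
q, s = quantize_per_token(x)
print(np.abs(dequantize(q, s) - x).max())           # small round-off error
```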

INT-FlashAttention Architecture

INT-FlashAttention extends FlashAttention by working with fully INT8-quantized Query (Q), Key (K), and Value (V) matrices and using INT8 general matrix-multiplication (GEMM) kernels. Its token-level quantization strategy improves both speed and accuracy compared to traditional tensor-level methods.

The framework keeps the INT8-quantized inputs in high-bandwidth memory (HBM). During the forward pass, INT8 GEMM operations preserve token-level information and rely on integer arithmetic for computational efficiency. The quantized matrices are rescaled with their associated quantization scales, allowing precise yet inexpensive updates within the INT8 domain. Integrating this scheme into the FlashAttention workflow yields considerable empirical performance gains.
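
To make the dataflow concrete, here is a small, untiled NumPy sketch consistent with the description above: Q and K carry per-token scales, the score GEMM runs on integers and is dequantized with the outer product of those scales, and V uses a single tensor-level scale (in line with the future-work note later in this summary). The softmax and the second matrix product are kept in floating point here for clarity, whereas the actual kernel fuses everything into the FlashAttention tiling loop with INT8 GEMMs; every name below is illustrative.

```python
# Illustrative, untiled sketch of an INT8 attention forward pass with
# per-token scales for Q/K and a tensor-level scale for V.
import numpy as np

def q_per_token(x, qmax=127):
    s = np.abs(x).max(axis=-1, keepdims=True) / qmax
    s = np.where(s == 0, 1.0, s)
    return np.clip(np.round(x / s), -qmax - 1, qmax).astype(np.int8), s

def int8_attention_forward(Q, K, V):
    n, d = Q.shape
    qQ, sQ = q_per_token(Q)                    # int8 values, (n, 1) scales
    qK, sK = q_per_token(K)
    sV = np.abs(V).max() / 127                 # one scale for the whole V
    qV = np.clip(np.round(V / sV), -128, 127).astype(np.int8)

    # Integer GEMM accumulated in int32, then dequantized with the outer
    # product of the per-token scales before the softmax.
    S_int = qQ.astype(np.int32) @ qK.astype(np.int32).T
    S = S_int.astype(np.float32) * (sQ * sK.T) / np.sqrt(d)

    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)
    return (P @ qV.astype(np.float32)) * sV    # dequantize the output

rng = np.random.default_rng(1)
Q, K, V = (rng.standard_normal((128, 64)).astype(np.float32) for _ in range(3))
out = int8_attention_forward(Q, K, V)          # close to FP32 attention output
```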

Experimental Results

The authors evaluated performance on an NVIDIA RTX 4090 GPU, comparing INT-FlashAttention against FlashAttention implemented with FP16 and FP8 data formats.

Inference Speed:

  • INT-FlashAttention achieved a 72% improvement in inference speed over FlashAttention with FP16 inputs, and performance comparable to FlashAttention with the FP8 data format.

Quantization Accuracy:

  • Measured by mean relative error (MRE), INT-FlashAttention showed significant error reductions: 46% and 82% smaller quantization error than FP8 FlashAttention under normally and uniformly distributed activations, respectively (the MRE metric is sketched below).
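
The MRE metric referenced above can be computed as in the following sketch; the activation distributions and noise levels are illustrative, not the paper's evaluation setup.

```python
# Mean relative error (MRE) of an approximation against an FP32 reference.
import numpy as np

def mre(approx, reference, eps=1e-8):
    return np.mean(np.abs(approx - reference) / (np.abs(reference) + eps))

rng = np.random.default_rng(0)
ref = rng.standard_normal((1024, 64)).astype(np.float32)   # FP32 reference
noisy = ref + rng.normal(scale=1e-2, size=ref.shape).astype(np.float32)
print(f"MRE: {mre(noisy, ref):.4%}")
```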

Implications and Future Directions

The introduction of INT-FlashAttention presents several implications for future research and practical deployment:

  1. Efficiency in Data Centers: INT-FlashAttention’s compatibility with the prevalent Ampere GPUs offers immediate benefits in terms of computational efficiency and energy usage, pertinent for large-scale deployments in data centers.
  2. Enhanced Model Scaling: The reductions in memory and computational requirements facilitate the scaling of LLMs to longer sequences, potentially improving the robustness and versatility of such models in real-world applications.
  3. Future Research: Future work includes refining the quantization of the V matrix to per-block rather than tensor-level granularity, which may further improve accuracy and yield additional computational savings. Combining INT-FlashAttention with Hadamard transformations could also open new avenues for optimizing inference.
  4. Open-Source Contribution: The open-sourcing of the codebase ensures transparency and encourages further innovation within the research community, fostering collaborative improvement and adaptation of the INT-FlashAttention architecture.

Conclusion

INT-FlashAttention represents a significant methodological advance in improving the inference efficiency and accuracy of LLMs, tailored to INT8 quantization on Ampere GPUs. Its adoption could lead to more efficient large-scale AI deployments, with both theoretical and practical implications for future AI research.

Authors (9)
  1. Shimao Chen (4 papers)
  2. Zirui Liu (58 papers)
  3. Zhiying Wu (8 papers)
  4. Ce Zheng (45 papers)
  5. Peizhuang Cong (4 papers)
  6. Zihan Jiang (19 papers)
  7. Lei Su (46 papers)
  8. Tong Yang (153 papers)
  9. Yuhan Wu (32 papers)
