Learned Token Pruning for Transformers (2107.00910v3)

Published 2 Jul 2021 in cs.CL

Abstract: Deploying transformer models in practice is challenging due to their inference cost, which scales quadratically with input sequence length. To address this, we present a novel Learned Token Pruning (LTP) method which adaptively removes unimportant tokens as an input sequence passes through transformer layers. In particular, LTP prunes tokens with an attention score below a threshold value which is learned for each layer during training. Our threshold-based method allows the length of the pruned sequence to vary adaptively based on the input sequence, and avoids algorithmically expensive operations such as top-k token selection. We extensively test the performance of LTP on GLUE tasks and show that our method outperforms the prior state-of-the-art token pruning methods by up to ~2.5% higher accuracy with the same amount of FLOPs. In particular, LTP achieves up to 2.1x FLOPs reduction with less than 1% accuracy drop, which results in up to 1.9x and 2.0x throughput improvement on Intel Haswell CPUs and NVIDIA V100 GPUs, respectively. Furthermore, we demonstrate that LTP is more robust than prior methods to variations on input sentence lengths. Our code has been developed in PyTorch and has been open-sourced.

The paper "Learned Token Pruning for Transformers" tackles the significant challenge of high inference costs in transformer models, which scale quadratically with the input sequence length. To mitigate this issue, the authors introduce a novel method called Learned Token Pruning (LTP). This approach aims to make transformers more efficient by adaptively removing tokens deemed unimportant as the input sequence progresses through the transformer layers.
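
As a concrete illustration, the per-token "importance" that drives this pruning can be derived from a layer's attention map. The PyTorch sketch below averages the attention each token receives across heads and query positions; the tensor layout and the exact aggregation are illustrative assumptions, not the paper's reference implementation.

import torch

def token_importance(attn_probs: torch.Tensor) -> torch.Tensor:
    """Importance score for each token in one transformer layer.

    attn_probs: softmax attention probabilities with shape
                (batch, num_heads, seq_len, seq_len).
    Returns:    scores with shape (batch, seq_len), one per (key) token,
                taken here as the attention it receives averaged over
                heads and query positions.
    """
    # dim=1 averages over heads, dim=2 over query positions,
    # leaving one score per key token.
    return attn_probs.mean(dim=(1, 2))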

Key Contributions and Methodology:

  1. Token Pruning based on Attention Scores:
    • The central idea of LTP is to prune tokens whose attention scores fall below a threshold. This threshold is not static but learned during the training process for each layer of the transformer.
    • Because pruning relies on a simple per-layer threshold over these attention-derived scores (as sketched above) rather than on computationally expensive top-k token selection, LTP remains efficient and scalable.
  2. Adaptive Thresholding:
    • The learned, threshold-based mechanism allows the pruned sequence length to vary dynamically with each input sequence (a training-time sketch of this mechanism follows this list).
    • Because the thresholds are learned during training, the model retains the critical tokens needed to maintain performance while discarding unimportant ones.
  3. Performance Evaluation:
    • The authors extensively evaluate LTP on the GLUE benchmark suite, showing significant improvements over previous token pruning methods.
    • The method achieves up to 2.5% higher accuracy at similar computational costs (measured in FLOPs).
  4. Efficiency Gains:
    • With LTP, there is a notable reduction in FLOPs—up to 2.1 times—while maintaining less than a 1% drop in accuracy.
    • These efficiency gains translate into practical benefits, such as up to 1.9 times and 2.0 times improvement in throughput on Intel Haswell CPUs and NVIDIA V100 GPUs, respectively.
  5. Robustness to Input Length Variations:
    • An important attribute of LTP is its robustness to variations in input sentence lengths. This robustness ensures consistent performance gains across different types of input sequences.
  6. Implementation and Open Source Code:
    • The method is implemented in PyTorch, and the authors have open-sourced their code, making it accessible for further research and practical applications.
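
The sketch below shows how a learnable per-layer threshold could be applied on top of such importance scores: a sigmoid soft mask keeps the threshold differentiable during training, while inference uses a hard cut. The module name, the temperature parameter, and the soft-mask formulation are assumptions made for illustration; the paper's released PyTorch code is the authoritative implementation.

import torch
import torch.nn as nn

class ThresholdTokenPruner(nn.Module):
    """Sketch of threshold-based token pruning with a learnable threshold.

    During training, a sigmoid soft mask lets gradients reach the threshold;
    at inference, tokens whose score falls below the threshold are dropped.
    """

    def __init__(self, init_threshold: float = 0.01, temperature: float = 1e-3):
        super().__init__()
        self.threshold = nn.Parameter(torch.tensor(init_threshold))
        self.temperature = temperature

    def forward(self, hidden: torch.Tensor, scores: torch.Tensor):
        # hidden: (batch, seq_len, dim); scores: (batch, seq_len)
        if self.training:
            # Soft masking: low-score tokens are attenuated rather than
            # removed, so the threshold itself receives gradients.
            mask = torch.sigmoid((scores - self.threshold) / self.temperature)
            return hidden * mask.unsqueeze(-1), mask
        # Hard pruning at inference: keep only tokens at or above the threshold.
        keep = scores >= self.threshold
        return hidden, keep

In this sketch, scores would come from a function like token_importance shown earlier. At batch size 1, the boolean mask can be used to gather the surviving tokens (e.g. hidden[:, keep[0]]) before the next layer; handling a variable number of kept tokens per example in a batch is an engineering detail omitted here.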

Implications:

The introduction of Learned Token Pruning represents a significant advancement in making transformer models more feasible for deployment in resource-constrained environments. By adaptively pruning tokens during inference, LTP not only reduces computational costs but also improves overall throughput, making transformers more adaptable and scalable in real-world applications. This method can be particularly beneficial in NLP tasks where input sequences can vary widely in length, offering robust performance while ensuring efficiency.

The work presented in this paper addresses critical bottlenecks in transformer deployment and sets a new benchmark for token pruning techniques.

Authors (7)
  1. Sehoon Kim (30 papers)
  2. Sheng Shen (68 papers)
  3. David Thorsley (4 papers)
  4. Amir Gholami (60 papers)
  5. Woosuk Kwon (9 papers)
  6. Joseph Hassoun (7 papers)
  7. Kurt Keutzer (199 papers)
Citations (124)