The paper "Learned Token Pruning for Transformers" tackles the significant challenge of high inference costs in transformer models, which scale quadratically with the input sequence length. To mitigate this issue, the authors introduce a novel method called Learned Token Pruning (LTP). This approach aims to make transformers more efficient by adaptively removing tokens deemed unimportant as the input sequence progresses through the transformer layers.
Key Contributions and Methodology:
- Token Pruning Based on Attention Scores:
  - The central idea of LTP is to prune tokens whose attention-based importance score falls below a threshold. The threshold is not fixed by hand; it is learned during training, separately for each transformer layer (see the first sketch after this list).
  - Because the decision is a simple per-token threshold on attention scores rather than a top-k token selection, LTP avoids the sorting overhead of top-k methods and remains efficient and hardware-friendly.
- Adaptive Thresholding:
  - Because pruning is threshold-based rather than fixed-ratio, the pruned sequence length varies dynamically with the input: simpler or shorter inputs can shed more tokens, while harder inputs retain more.
  - The thresholds are trained with a soft, differentiable pruning mask, which lets the model learn to identify and keep the tokens that are critical for maintaining task performance (see the second sketch after this list).
- Performance Evaluation:
  - The authors evaluate LTP extensively on the GLUE benchmark suite, where it improves on previous token pruning methods.
  - At a similar computational cost (measured in FLOPs), LTP achieves up to roughly 2.5% higher accuracy than the prior state-of-the-art token pruning method.
- Efficiency Gains:
  - LTP reduces FLOPs by up to 2.1x while keeping the accuracy degradation below 1%. Because pruned tokens are dropped from the sequence rather than merely masked, every subsequent layer operates on fewer tokens (see the third sketch after this list).
  - These savings translate into practical speedups: up to 1.9x higher throughput on Intel Haswell CPUs and up to 2.0x on NVIDIA V100 GPUs.
- Robustness to Input Length Variations:
  - Because the number of retained tokens adapts to each input rather than being fixed in advance, LTP remains effective when input sentence lengths vary, delivering consistent gains across sequences of different lengths.
- Implementation and Open Source Code:
  - The method is implemented in PyTorch, and the authors have open-sourced their code, making it accessible for further research and practical use.
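
For concreteness, here is a minimal sketch of the pruning criterion in PyTorch. The function names, tensor layouts, and the scalar `threshold` argument are illustrative assumptions rather than the authors' released code: a token's importance is the attention it receives, averaged over heads and query positions, and at inference a token is kept only if that score clears the layer's learned threshold.

```python
import torch

def token_importance(attn_probs: torch.Tensor) -> torch.Tensor:
    # attn_probs: (batch, heads, seq_len, seq_len) post-softmax attention.
    # A token's importance is the attention it receives as a key, averaged
    # over heads and over all query positions.
    return attn_probs.mean(dim=(1, 2))  # -> (batch, seq_len)

def keep_mask(importance: torch.Tensor, threshold: float) -> torch.Tensor:
    # Hard thresholding used at inference time: keep a token only if its
    # importance clears the layer's learned threshold.
    return importance >= threshold  # -> (batch, seq_len) boolean mask
```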
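
The learned thresholds become trainable by relaxing the hard keep/drop decision into a soft, sigmoid-shaped mask during training. The module below is a sketch of that idea; the module name, temperature, and initial threshold are assumptions for illustration, not the authors' implementation.

```python
import torch
import torch.nn as nn

class LearnedThreshold(nn.Module):
    """Soft pruning during training: the per-layer threshold is a learnable
    parameter that receives gradients through a sigmoid mask. At inference
    it is replaced by the hard mask shown above."""

    def __init__(self, init_threshold: float = 0.01, temperature: float = 1e-3):
        super().__init__()
        self.threshold = nn.Parameter(torch.tensor(init_threshold))  # assumed init
        self.temperature = temperature                               # assumed value

    def forward(self, hidden: torch.Tensor, importance: torch.Tensor) -> torch.Tensor:
        # hidden:     (batch, seq_len, dim)  token representations
        # importance: (batch, seq_len)       attention-based scores
        soft_mask = torch.sigmoid((importance - self.threshold) / self.temperature)
        # Tokens far below the threshold are scaled toward zero; tokens far
        # above it pass through almost unchanged.
        return hidden * soft_mask.unsqueeze(-1)
```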
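
The FLOPs and latency savings come from actually removing pruned tokens, so every later layer runs on a shorter sequence. Below is a hypothetical, batch-size-1 sketch of that inference-time step; a real implementation would also shrink the attention mask and handle padding.

```python
import torch

def shorten_sequence(hidden: torch.Tensor, keep: torch.Tensor) -> torch.Tensor:
    # hidden: (1, seq_len, dim) token representations after a layer.
    # keep:   (1, seq_len) boolean mask from the layer's learned threshold.
    # Physically gather the kept tokens so subsequent layers see a shorter
    # sequence, which is where the compute savings come from.
    keep_idx = keep[0].nonzero(as_tuple=True)[0]
    return hidden[:, keep_idx, :]  # (1, num_kept, dim)
```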
Implications:
Learned Token Pruning makes transformer models markedly more practical to deploy in resource-constrained environments. By adaptively pruning tokens during inference, LTP reduces computational cost and improves throughput while keeping accuracy largely intact. The method is particularly well suited to NLP tasks in which input sequences vary widely in length, since the amount of pruning adapts to each input.
The work addresses a critical bottleneck in transformer deployment and sets a new state of the art for token pruning techniques.