DeeBERT: Dynamic Early Exiting for Accelerating BERT Inference
(2004.12993v1)
Published 27 Apr 2020 in cs.CL and cs.LG
Abstract: Large-scale pre-trained language models such as BERT have brought significant improvements to NLP applications. However, they are also notorious for being slow in inference, which makes them difficult to deploy in real-time applications. We propose a simple but effective method, DeeBERT, to accelerate BERT inference. Our approach allows samples to exit earlier without passing through the entire model. Experiments show that DeeBERT is able to save up to ~40% inference time with minimal degradation in model quality. Further analyses show different behaviors in the BERT transformer layers and also reveal their redundancy. Our work provides new ideas to efficiently apply deep transformer-based models to downstream tasks. Code is available at https://github.com/castorini/DeeBERT.
The paper "DeeBERT: Dynamic Early Exiting for Accelerating BERT Inference" (Xin et al., 2020) introduces a method to speed up inference for large transformer-based LLMs like BERT and RoBERTa by allowing "easier" input samples to exit the network early. This is particularly useful for deploying these models in latency-sensitive applications.
The core idea of DeeBERT is to add additional classification layers, termed "off-ramps," after each transformer layer in the pre-trained model. The original classification layer at the end of the network becomes the final off-ramp.
Here's how DeeBERT works in practice:
Model Architecture: DeeBERT starts from a standard BERT or RoBERTa model, which consists of an embedding layer followed by n transformer layers. One classification off-ramp is added after each transformer layer, giving n off-ramps in total. Each off-ramp takes the hidden states from its preceding transformer layer and outputs a probability distribution over the classes of the downstream task.
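To make the architecture concrete, here is a minimal sketch in plain PyTorch. It substitutes generic `nn.TransformerEncoderLayer` blocks for BERT's actual encoder, and the names (`DeeModel`, `off_ramps`) and default sizes are illustrative assumptions, not the paper's released code.

```python
import torch
import torch.nn as nn

class DeeModel(nn.Module):
    """Toy stand-in for BERT/RoBERTa with one off-ramp per transformer layer."""

    def __init__(self, vocab_size=30522, hidden=768, n_layers=12, n_heads=12, n_classes=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=hidden, nhead=n_heads, batch_first=True)
            for _ in range(n_layers)
        ])
        # One linear classification "off-ramp" after every transformer layer;
        # the last one plays the role of the original classification head.
        self.off_ramps = nn.ModuleList([nn.Linear(hidden, n_classes) for _ in range(n_layers)])

    def forward(self, input_ids):
        h = self.embed(input_ids)
        logits_per_ramp = []
        for layer, ramp in zip(self.layers, self.off_ramps):
            h = layer(h)
            logits_per_ramp.append(ramp(h[:, 0]))  # classify from the first ([CLS]-style) token
        return logits_per_ramp  # n logits tensors, one per off-ramp
```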
Fine-Tuning: The fine-tuning process on a specific downstream dataset is modified into two stages (a training sketch follows the two stage descriptions):
Stage 1: The entire model (embeddings, all transformer layers, and the final off-ramp) is fine-tuned end-to-end using the loss from only the final off-ramp. This is identical to standard BERT fine-tuning and ensures that the final off-ramp matches the quality of a conventionally fine-tuned model.
Stage 2: The parameters of the embedding and transformer layers (trained in Stage 1) are frozen. Only the intermediate off-ramps (layers 1 to n−1) are trained. The loss for this stage is the sum of the losses from all intermediate off-ramps. This stage trains the intermediate off-ramps to make predictions based on the representations available at earlier layers without disrupting the learned weights of the core transformer layers.
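The sketch below illustrates the two stages for the `DeeModel` defined earlier. The optimizer choice, learning rate, epoch count, and a `loader` yielding `(input_ids, labels)` batches are assumptions for illustration, not details taken from the paper.

```python
import torch
import torch.nn.functional as F

def fine_tune_stage1(model, loader, epochs=3, lr=2e-5):
    # Stage 1: update all parameters using only the final off-ramp's loss,
    # i.e. ordinary BERT-style fine-tuning.
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    for _ in range(epochs):
        for input_ids, labels in loader:
            loss = F.cross_entropy(model(input_ids)[-1], labels)
            opt.zero_grad()
            loss.backward()
            opt.step()

def fine_tune_stage2(model, loader, epochs=3, lr=2e-5):
    # Stage 2: freeze the embedding and transformer layers; train only the
    # intermediate off-ramps with the sum of their losses.
    for module in (model.embed, model.layers):
        for p in module.parameters():
            p.requires_grad = False
    ramp_params = [p for ramp in model.off_ramps[:-1] for p in ramp.parameters()]
    opt = torch.optim.AdamW(ramp_params, lr=lr)
    for _ in range(epochs):
        for input_ids, labels in loader:
            all_logits = model(input_ids)
            loss = sum(F.cross_entropy(z, labels) for z in all_logits[:-1])
            opt.zero_grad()
            loss.backward()
            opt.step()
```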
Inference: During inference for an input sample x, the following procedure (Algorithm 1 in the paper) is followed; a code sketch appears after the steps:
The sample passes through the embedding layer and the first transformer layer.
The output hidden state from the first transformer layer is fed into the first off-ramp, which produces a probability distribution z1.
The entropy of z1 is calculated.
If entropy(z1) is below a pre-defined threshold S, the sample's inference stops, and z1 is returned as the prediction.
If the entropy is not below S, the sample proceeds to the second transformer layer.
This process repeats for each remaining layer i (up to i = n): the sample passes through transformer layer i, the output is fed to off-ramp i, and the entropy of the resulting distribution zi is checked against S.
If entropy(zi)<S, the prediction zi is returned, and the sample exits.
If the sample reaches the final off-ramp (i=n) without exiting early, its output zn is returned unconditionally.
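The loop below sketches this early-exit procedure for the `DeeModel` defined earlier, assuming a batch size of one (the exit decision is made per sample). Entropy is computed in nats, and the threshold `S` is supplied by the caller; the default value is illustrative.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def deebert_infer(model, input_ids, S=0.3):
    # Run transformer layers one at a time; exit at the first off-ramp whose
    # output entropy drops below S. The last off-ramp always returns.
    h = model.embed(input_ids)
    n = len(model.layers)
    for i, (layer, ramp) in enumerate(zip(model.layers, model.off_ramps), start=1):
        h = layer(h)
        probs = F.softmax(ramp(h[:, 0]), dim=-1)
        entropy = -(probs * probs.clamp_min(1e-12).log()).sum(dim=-1)
        if i == n or entropy.item() < S:  # assumes batch size 1
            return probs, i               # prediction and the layer it exited at
```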
The entropy threshold S is a crucial hyperparameter that controls the trade-off between inference speed and model accuracy. A higher S means the model is less "strict" about confidence, allowing more samples to exit earlier, resulting in faster inference but potentially lower accuracy. A lower S requires higher confidence for early exit, leading to fewer early exits, slower inference, but typically higher accuracy, closer to the baseline model's performance. The optimal S is typically chosen by evaluating performance and speed trade-offs on a development set.
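For intuition about what S means, the snippet below computes the entropy (in nats) of a few binary output distributions; the threshold comparisons in the comments are illustrative, not settings from the paper.

```python
import math

def entropy(probs):
    """Shannon entropy in nats of a discrete probability distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)

print(entropy([0.95, 0.05]))  # ~0.199: confident; exits early whenever S > ~0.2
print(entropy([0.60, 0.40]))  # ~0.673: uncertain; continues for most reasonable S
print(entropy([0.50, 0.50]))  # ~0.693 (ln 2): maximum entropy for two classes
```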
Practical Implementation Considerations:
Adding Off-Ramps: This involves adding linear classification layers with output dimensions equal to the number of classes in the downstream task. These layers take the hidden states from the output of each transformer block as input.
Training Loop Modification: The standard fine-tuning loop needs to be adapted to perform the two distinct stages of training. The first stage is standard, while the second requires freezing certain parameters and computing losses from multiple output heads simultaneously.
Inference Loop Modification: The standard sequential pass through layers needs to be wrapped in a loop that includes the entropy calculation and the conditional early exit logic based on the threshold S.
Threshold Selection: Choosing the right S is key. This involves experimenting with different S values on a development set and plotting the speed-accuracy curve (as shown in Figure 3 of the paper) to identify an operating point that meets the specific application's requirements; a sweep sketch follows this list.
Computational Overhead: The added off-ramps are simple linear layers, adding negligible computation compared to the transformer layers. The main computational saving comes from skipping subsequent transformer layers. The entropy calculation is also very fast.
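One simple way to pick S is to sweep candidate values on the development set, reusing the `deebert_infer` sketch above. The threshold grid and the use of mean exit layer as a cost proxy are assumptions for illustration.

```python
def sweep_thresholds(model, dev_loader, thresholds=(0.05, 0.1, 0.2, 0.4, 0.7)):
    # For each candidate S, record dev-set accuracy and the average exit layer
    # (a rough proxy for inference cost), then inspect the trade-off curve.
    results = []
    for S in thresholds:
        correct, exit_sum, total = 0, 0, 0
        for input_ids, label in dev_loader:  # assumes batch size 1
            probs, exit_layer = deebert_infer(model, input_ids, S=S)
            correct += int(probs.argmax(dim=-1).item() == int(label))
            exit_sum += exit_layer
            total += 1
        results.append((S, correct / total, exit_sum / total))
    return results  # list of (S, accuracy, mean exit layer)
```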
Performance and Findings:
DeeBERT achieves significant inference time savings (up to ~40-47%) on GLUE tasks with minimal (or sometimes no) performance degradation compared to the original BERT/RoBERTa baseline. Larger savings are possible with acceptable drops in accuracy (2-4%).
Compared to methods like DistilBERT or LayerDrop, DeeBERT offers a flexible speed-accuracy trade-off at inference time from a single fine-tuned model, rather than producing a fixed-size smaller model. It also avoids the need for further pre-training, relying only on fine-tuning.
Analysis shows that samples considered "easier" by the model (those that can be predicted with high confidence) tend to exit earlier.
Layer-wise analysis reveals differences between BERT and RoBERTa and hints at redundancy in later layers, especially in larger models (BERT-large, RoBERTa-large), where the accuracy gain plateaus or even slightly drops in the final layers. This redundancy is what enables early exiting without significant performance loss for many samples.
In essence, DeeBERT provides a practical method to dynamically adapt the computational cost of BERT/RoBERTa inference based on the complexity of the input sample and the desired speed-accuracy trade-off, making these powerful models more viable for real-time applications.