Introduction
LLM inference makes poor use of the computational power of contemporary accelerators. The sequential, auto-regressive decoding process generates only one token per forward pass, so much of the available compute sits idle while the model's weights are repeatedly streamed from memory. Speculative decoding has been introduced to address this inefficiency, but a significant roadblock has been the difficulty of obtaining and deploying a separate draft model that proposes a sequence of tokens for the larger LLM to verify. This is exactly where the Medusa framework comes into play, offering a straightforward solution to the challenge of accelerating LLM inference.
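For context, the sketch below shows ordinary greedy auto-regressive decoding, which makes the one-token-per-forward-pass bottleneck explicit. The checkpoint name and generation length are placeholders; any causal LM with a similar interface behaves the same way.

```python
# Minimal sketch of standard auto-regressive decoding: one forward pass per
# generated token, so the accelerator is serialized on memory-bound steps.
# The model name and the 32-token budget are illustrative placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

input_ids = tokenizer("Speculative decoding is", return_tensors="pt").input_ids
past_key_values = None
next_token = None

with torch.no_grad():
    for _ in range(32):  # each iteration produces exactly one new token
        if past_key_values is None:
            out = model(input_ids=input_ids, use_cache=True)
        else:
            out = model(input_ids=next_token, past_key_values=past_key_values,
                        use_cache=True)
        past_key_values = out.past_key_values
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)

print(tokenizer.decode(input_ids[0], skip_special_tokens=True))
```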
Medusa Framework
The primary innovation introduced with Medusa is the addition of multiple decoding heads to the backbone LLM, enabling the prediction of several subsequent tokens in parallel. These heads are fine-tuned so that their predictions stay closely aligned with those of the backbone model. Two training procedures are outlined for integrating the predictive heads: Medusa-1 and Medusa-2. In Medusa-1, the backbone LLM remains frozen during training, so its core capabilities are left untouched while inference is accelerated. Medusa-2 is a more resource-intensive recipe in which the additional heads are fine-tuned jointly with the backbone LLM, potentially yielding even larger speedups.
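A rough sketch of this architecture is shown below, assuming each extra head is a small residual MLP over the backbone's final hidden state, with head k trained to predict the token k + 1 positions ahead. The class names, layer sizes, and the freeze_backbone flag are illustrative assumptions, not the reference implementation.

```python
import torch
import torch.nn as nn

class MedusaHead(nn.Module):
    """One extra decoding head: a small residual MLP over the backbone's last
    hidden state, followed by a projection to the vocabulary. Head k is trained
    to predict the token at position t + k + 1 (sketch, not the exact design)."""
    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = hidden_states + self.act(self.proj(hidden_states))
        return self.lm_head(hidden_states)  # logits for position t + k + 1

class MedusaModel(nn.Module):
    """Backbone LLM plus several Medusa-style heads. With freeze_backbone=True
    only the heads receive gradients (Medusa-1 style); with False the backbone
    is fine-tuned jointly with the heads (Medusa-2 style)."""
    def __init__(self, backbone: nn.Module, hidden_size: int, vocab_size: int,
                 num_heads: int = 4, freeze_backbone: bool = True):
        super().__init__()
        self.backbone = backbone
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad_(False)
        self.medusa_heads = nn.ModuleList(
            MedusaHead(hidden_size, vocab_size) for _ in range(num_heads)
        )

    def forward(self, input_ids: torch.Tensor):
        hidden = self.backbone(input_ids)          # (batch, seq, hidden)
        return [head(hidden) for head in self.medusa_heads]
```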
Addressing Challenges with Extensions
Several obstacles could impede the Medusa framework's widespread adoption, such as situations where suitable training data is unavailable. To tackle this, the researchers designed a self-distillation procedure that uses the LLM itself to generate training data for the Medusa heads. They also introduce a typical acceptance scheme as an alternative to the rejection sampling used in speculative decoding, selecting plausible predictions from the Medusa heads. This approach maintains generation quality while potentially increasing the rate at which candidate tokens are accepted during decoding.
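The sketch below illustrates one way such an entropy-aware acceptance check could look: a candidate token is kept only if the backbone model assigns it a probability above a threshold that tightens when the backbone's distribution is confident and relaxes when it is diffuse. The function name, the epsilon and delta constants, and the tensor shapes are assumptions for illustration, not the paper's exact formulation or tuned values.

```python
import torch

def typical_acceptance(candidate_ids, backbone_logits, epsilon=0.09, delta=0.3,
                       temperature=1.0):
    """Sketch of a typical-acceptance check: candidate tokens proposed by the
    extra heads are kept while the backbone assigns each of them a probability
    above an entropy-dependent threshold; decoding keeps the longest accepted
    prefix. Constants here are illustrative, not tuned values.

    candidate_ids:   (k,) proposed token ids
    backbone_logits: (k, vocab) backbone logits at the corresponding positions
    """
    probs = torch.softmax(backbone_logits / temperature, dim=-1)          # (k, vocab)
    entropy = -(probs * torch.log(probs.clamp_min(1e-9))).sum(dim=-1)     # (k,)
    threshold = torch.minimum(torch.full_like(entropy, epsilon),
                              delta * torch.exp(-entropy))
    candidate_probs = probs.gather(-1, candidate_ids.unsqueeze(-1)).squeeze(-1)
    accepted = candidate_probs > threshold
    # Keep only the longest accepted prefix so the output stays causally
    # consistent. In practice the first candidate can always be accepted
    # greedily so that every decoding step makes progress.
    num_accepted = int(accepted.long().cumprod(dim=0).sum().item())
    return candidate_ids[:num_accepted]
```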
Experimental Results
In their experiments, the researchers assessed Medusa across various model sizes and training configurations. The findings are significant: Medusa-1 achieves more than a 2.2x speedup in LLM inference with no loss in generation quality, while Medusa-2 pushes this further, reaching speedups of 2.3x to 3.6x. The method also scales across different models and is particularly effective at a batch size of one, which is representative of hosting LLMs locally for personal applications.
Conclusion
Medusa sets a new precedent for inference acceleration in LLMs without compromising generation quality. Its two training recipes cater to different computational budgets, and the proposed extensions address common obstacles encountered when deploying accelerated inference. The code for Medusa has been made publicly available, inviting collaborative efforts to refine the framework further and integrate it into different serving systems.