
Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads (2401.10774v3)

Published 19 Jan 2024 in cs.LG and cs.CL

Abstract: LLMs employ auto-regressive decoding that requires sequential computation, with each step reliant on the previous one's output. This creates a bottleneck as each step necessitates moving the full model parameters from High-Bandwidth Memory (HBM) to the accelerator's cache. While methods such as speculative decoding have been suggested to address this issue, their implementation is impeded by the challenges associated with acquiring and maintaining a separate draft model. In this paper, we present Medusa, an efficient method that augments LLM inference by adding extra decoding heads to predict multiple subsequent tokens in parallel. Using a tree-based attention mechanism, Medusa constructs multiple candidate continuations and verifies them simultaneously in each decoding step. By leveraging parallel processing, Medusa substantially reduces the number of decoding steps required. We present two levels of fine-tuning procedures for Medusa to meet the needs of different use cases: Medusa-1: Medusa is directly fine-tuned on top of a frozen backbone LLM, enabling lossless inference acceleration. Medusa-2: Medusa is fine-tuned together with the backbone LLM, enabling better prediction accuracy of Medusa heads and higher speedup but needing a special training recipe that preserves the backbone model's capabilities. Moreover, we propose several extensions that improve or expand the utility of Medusa, including a self-distillation to handle situations where no training data is available and a typical acceptance scheme to boost the acceptance rate while maintaining generation quality. We evaluate Medusa on models of various sizes and training procedures. Our experiments demonstrate that Medusa-1 can achieve over 2.2x speedup without compromising generation quality, while Medusa-2 further improves the speedup to 2.3-3.6x.

Introduction

LLM inference fails to make full use of the computational power and memory bandwidth of contemporary accelerators. The bottleneck lies in the sequential nature of auto-regressive decoding: each step must stream the full model parameters from High-Bandwidth Memory (HBM) into the accelerator's cache to produce a single token, leaving most of the compute capacity idle. Speculative decoding was introduced to address this inefficiency, but a significant roadblock has been the difficulty of acquiring, training, and serving a separate draft model whose proposed token sequences the larger LLM then verifies. This is exactly where the Medusa framework comes into play, offering a straightforward solution to the challenge of accelerating LLM inference.

Medusa Framework

The primary innovation in Medusa is the addition of multiple decoding heads to the backbone LLM, enabling it to predict several subsequent tokens in parallel. These heads are fine-tuned so that their predictions stay closely aligned with those of the backbone model, and the candidate continuations they propose are organized into a tree and verified simultaneously in each decoding step via a tree-based attention mechanism. Two procedures are outlined for training the heads: Medusa-1 and Medusa-2. In Medusa-1, the backbone LLM remains frozen during training, so its core capabilities are left unaltered while inference is accelerated. Medusa-2 is a more resource-intensive fine-tuning in which the heads are trained together with the backbone LLM, yielding more accurate head predictions and potentially higher speedups.
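To make the head design concrete, below is a minimal PyTorch sketch of the idea: each extra head is a small residual feed-forward block applied to the backbone's final hidden state, followed by its own vocabulary projection, with head k predicting the token k positions beyond the one produced by the original LM head. Class names, layer sizes, and initialization details here are illustrative assumptions, not the authors' exact implementation.

```python
import torch
import torch.nn as nn

class MedusaHead(nn.Module):
    """One extra decoding head: a residual SiLU block plus a vocabulary projection.

    A minimal sketch of the head design described in the paper; sizes and
    initialization are illustrative rather than taken from the released code.
    """
    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)              # residual branch
        self.act = nn.SiLU()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)  # per-head vocab logits

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        # The residual connection keeps the head close to the backbone's representation.
        return self.lm_head(hidden + self.act(self.proj(hidden)))

class MedusaHeads(nn.Module):
    """K heads; head k (1-indexed) targets the token k positions past the next token."""
    def __init__(self, hidden_size: int, vocab_size: int, num_heads: int = 4):
        super().__init__()
        self.heads = nn.ModuleList(
            MedusaHead(hidden_size, vocab_size) for _ in range(num_heads)
        )

    def forward(self, last_hidden: torch.Tensor) -> list[torch.Tensor]:
        # last_hidden: (batch, hidden_size) hidden state at the final position.
        # Returns one (batch, vocab_size) logit tensor per head.
        return [head(last_hidden) for head in self.heads]
```

Under Medusa-1, only these head parameters would receive gradients while the backbone stays frozen; under Medusa-2, the heads and the backbone are updated jointly.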

Addressing Challenges with Extensions

Several obstacles could impede the Medusa framework's widespread adoption, such as situations where no suitable training data is available. To tackle this, the researchers design a self-distillation protocol that uses the backbone LLM itself to generate training data for the Medusa heads. They also introduce a 'typical acceptance scheme' as an alternative to the rejection sampling used in speculative decoding for selecting plausible predictions from the Medusa heads. This approach maintains generation quality while increasing the rate at which candidate tokens are accepted during decoding.
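As a rough illustration of how an entropy-aware acceptance rule can work, the sketch below accepts a candidate token when its probability under the backbone model exceeds a threshold that depends on the entropy of the backbone's distribution at that position, and stops at the first rejection. The function name, tensor shapes, and threshold values are assumptions made for illustration; they are not taken from the paper's released code.

```python
import torch

def typical_acceptance(candidate_ids: torch.Tensor,
                       backbone_probs: torch.Tensor,
                       epsilon: float = 0.3,
                       delta: float = 0.09) -> int:
    """Return how many speculative tokens to accept from one candidate continuation.

    candidate_ids:  (num_steps,) token ids proposed by the Medusa heads.
    backbone_probs: (num_steps, vocab_size) backbone probabilities at each position,
                    obtained from the single verification forward pass.
    epsilon, delta: illustrative threshold parameters (assumed, not prescribed values).
    """
    accepted = 0
    for step, token_id in enumerate(candidate_ids.tolist()):
        probs = backbone_probs[step]
        # Entropy of the backbone's distribution at this position.
        entropy = -(probs * torch.log(probs.clamp_min(1e-9))).sum()
        # Accept if the token's probability clears an entropy-dependent threshold.
        threshold = min(epsilon, delta * torch.exp(-entropy).item())
        if probs[token_id].item() > threshold:
            accepted += 1   # token is "typical" enough; keep it and check the next one
        else:
            break           # stop at the first rejected token
    return accepted
```

Because the threshold is tighter where the backbone's distribution is peaked and looser where it is flat, a rule of this form admits plausible alternative continuations in high-entropy contexts while still filtering out tokens the backbone considers unlikely.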

Experimental Results

In their comprehensive experiments, the researchers assessed Medusa on models of various sizes and training configurations. Medusa-1 achieves more than a 2.2x speedup in LLM inference with no loss in generation quality, while Medusa-2 pushes this further, attaining speedups of 2.3x to 3.6x. The method also scales across different models and is particularly effective at a batch size of one, the setting typical of hosting an LLM locally for personal use.

Conclusion

Medusa sets a strong precedent for inference acceleration in LLMs without compromising generation quality. Its two training regimes cater to different computational-resource scenarios, and the proposed extensions address common obstacles encountered when deploying accelerated inference methods. The code for Medusa has been made publicly available, inviting collaborative efforts to refine the framework and integrate it into different serving systems.

Authors (7)
  1. Tianle Cai (34 papers)
  2. Yuhong Li (33 papers)
  3. Zhengyang Geng (17 papers)
  4. Hongwu Peng (27 papers)
  5. Jason D. Lee (151 papers)
  6. Deming Chen (62 papers)
  7. Tri Dao (47 papers)