Direct Alignment of Draft Model for Speculative Decoding with Chat-Fine-Tuned LLMs (2403.00858v4)

Published 29 Feb 2024 in cs.LG, cs.AI, and cs.CL

Abstract: Text generation with LLMs is known to be memory bound due to the combination of their auto-regressive nature, huge parameter counts, and limited memory bandwidths, often resulting in low token rates. Speculative decoding has been proposed as a solution for LLM inference acceleration. However, since draft models are often unavailable in the modern open-source LLM families, e.g., for Llama 2 7B, training a high-quality draft model is required to enable inference acceleration via speculative decoding. In this paper, we propose a simple draft model training framework for direct alignment to chat-capable target models. With the proposed framework, we train Llama 2 Chat Drafter 115M, a draft model for Llama 2 Chat 7B or larger, with only 1.64% of the original size. Our training framework consists only of pretraining, distillation dataset generation, and finetuning with knowledge distillation, with no additional alignment procedure. For the finetuning step, we use instruction-response pairs generated by the target model for distillation in a plausible data distribution, and propose a new Total Variation Distance++ (TVD++) loss that incorporates variance reduction techniques inspired by the policy gradient method in reinforcement learning. Our empirical results show that Llama 2 Chat Drafter 115M with speculative decoding achieves up to 2.3 block efficiency and 2.4× speed-up relative to autoregressive decoding on various tasks with no further task-specific fine-tuning.
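For context, a minimal sketch of the token-level acceptance rule used in standard speculative sampling (Chen et al., 2023; Leviathan et al., 2023), which the trained draft model accelerates. This is not the paper's code; function and variable names are illustrative, and the block efficiency metric in the abstract corresponds to the average number of tokens emitted per call of this routine.

```python
import torch

def speculative_accept(draft_tokens, p_draft, p_target):
    """
    draft_tokens: (gamma,) token ids proposed by the draft model
    p_draft:      (gamma, vocab) draft probabilities at each proposed position
    p_target:     (gamma + 1, vocab) target probabilities (one extra position
                  so a "bonus" token can be sampled when all drafts are accepted)
    Returns the tokens emitted for this block; its average length over many
    blocks is the block efficiency reported in the abstract.
    """
    gamma = len(draft_tokens)
    emitted = []
    for t in range(gamma):
        tok = int(draft_tokens[t])
        q = p_draft[t, tok]
        p = p_target[t, tok]
        # Accept with probability min(1, p/q); this rejection scheme keeps
        # the overall output distribution identical to the target model's.
        if torch.rand(()) < torch.clamp(p / q, max=1.0):
            emitted.append(tok)
        else:
            # On rejection, resample from the normalized residual (p - q)_+ .
            residual = torch.clamp(p_target[t] - p_draft[t], min=0.0)
            emitted.append(int(torch.multinomial(residual / residual.sum(), 1)))
            return emitted
    # All gamma draft tokens accepted: sample one bonus token from the target.
    emitted.append(int(torch.multinomial(p_target[gamma], 1)))
    return emitted
```

The better the draft model is aligned with the target's output distribution, the more often the min(1, p/q) test passes, which is exactly what the paper's distillation objective (the TVD++ loss) optimizes for.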

Authors (6)
  1. Raghavv Goel (7 papers)
  2. Mukul Gagrani (11 papers)
  3. Wonseok Jeon (14 papers)
  4. Junyoung Park (37 papers)
  5. Mingu Lee (16 papers)
  6. Christopher Lott (6 papers)
Citations (4)