Direct Alignment of Draft Model for Speculative Decoding with Chat-Fine-Tuned LLMs (2403.00858v4)
Abstract: Text generation with LLMs is known to be memory-bound due to the combination of their auto-regressive nature, huge parameter counts, and limited memory bandwidth, often resulting in low token rates. Speculative decoding has been proposed as a solution for LLM inference acceleration. However, since draft models are often unavailable in modern open-source LLM families, e.g., Llama 2 7B, training a high-quality draft model is required to enable inference acceleration via speculative decoding. In this paper, we propose a simple draft model training framework for direct alignment to chat-capable target models. With the proposed framework, we train Llama 2 Chat Drafter 115M, a draft model for Llama 2 Chat 7B or larger, at only 1.64\% of the target model's size. Our training framework consists only of pretraining, distillation dataset generation, and finetuning with knowledge distillation, with no additional alignment procedure. For the finetuning step, we use instruction-response pairs generated by the target model so that distillation is performed on a plausible data distribution, and we propose a new Total Variation Distance++ (TVD++) loss that incorporates variance reduction techniques inspired by policy gradient methods in reinforcement learning. Our empirical results show that Llama 2 Chat Drafter 115M with speculative decoding achieves a block efficiency of up to 2.3 and a 2.4$\times$ speed-up relative to autoregressive decoding on various tasks with no further task-specific fine-tuning.
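To make the distillation objective concrete, below is a minimal sketch of the token-level total variation distance (TVD) term between the draft and target next-token distributions. The function name and tensor shapes are illustrative assumptions, not the paper's implementation, and the proposed TVD++ loss additionally applies a policy-gradient-style variance reduction term that is not reproduced here.

```python
import torch
import torch.nn.functional as F

def tvd_distillation_loss(draft_logits: torch.Tensor,
                          target_logits: torch.Tensor) -> torch.Tensor:
    """Token-level total variation distance between draft and target
    next-token distributions, averaged over batch and sequence positions.

    draft_logits, target_logits: (batch, seq_len, vocab_size) tensors,
    where target_logits come from the frozen target (teacher) model.
    """
    p = F.softmax(target_logits.detach(), dim=-1)  # teacher distribution (no grad)
    q = F.softmax(draft_logits, dim=-1)            # draft (student) distribution
    # TVD(p, q) = 0.5 * sum_y |p(y) - q(y)| at each position
    tvd = 0.5 * (p - q).abs().sum(dim=-1)
    return tvd.mean()
```

Minimizing TVD rather than KL is a natural fit for speculative decoding: under the standard speculative sampling acceptance rule, the expected token acceptance probability equals one minus the TVD between the draft and target distributions, so lower TVD translates directly into higher block efficiency.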
Authors: Raghavv Goel, Mukul Gagrani, Wonseok Jeon, Junyoung Park, Mingu Lee, Christopher Lott