- The paper introduces an implicit chain-of-thought method that eliminates explicit intermediate steps using teacher hidden states.
- The technique employs a three-stage process—mind-reading, thought emulation, and joint optimization—to enhance LLM reasoning.
- Experiments on multi-digit multiplication and grade school math show improved inference efficiency and problem accuracy compared to traditional CoT methods.
Implicit Chain of Thought Reasoning via Knowledge Distillation
Introduction
The study proposes improving the reasoning capabilities of LLMs through implicit reasoning, without generating explicit intermediate steps. Traditional chain-of-thought (CoT) methods prompt models to articulate the reasoning steps leading to a final answer. This resembles human cognition but may not exploit the full computational capacity of LLMs. The authors instead use internal hidden states for reasoning, distilled from a teacher model trained on explicit CoT, so that reasoning proceeds vertically through the model's layers rather than horizontally across generated tokens.
Methodology
Implicit Chain-of-Thought Framework
The framework comprises three main steps:
- Mind-Reading the Teacher: A student model is trained to read the continuous hidden states a teacher model produces while performing explicit intermediate reasoning. The student consumes selected teacher states directly and predicts the final answer, bypassing the explicit reasoning steps.
- Thought Emulation: Knowledge distillation trains an emulator to predict the teacher's hidden states vertically, across layers, removing the need for horizontal, token-by-token reasoning. The emulator compresses the teacher's reasoning into a compact sequence of internal states that the student uses at inference.
- Couple and Optimize: The emulator and student are combined and optimized end-to-end. This lets the student refine its reasoning strategy, potentially diverging from the teacher's approach, and generate answers directly and efficiently.
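The three stages above can be sketched as follows. This is a minimal illustrative toy, not the paper's implementation: the emulator, student, shapes, and the tanh layers are all hypothetical stand-ins for the actual transformer components.

```python
import numpy as np

rng = np.random.default_rng(0)
h, n_layers, n_answers = 8, 4, 10  # hypothetical sizes

def emulator(x, W):
    """Stage-2 sketch: predict one 'thought' state per teacher layer,
    unrolled vertically from the question encoding alone (no
    intermediate reasoning tokens are generated)."""
    states, z = [], x
    for W_l in W:                # one illustrative weight per layer
        z = np.tanh(z @ W_l)     # vertical step: layer l -> layer l+1
        states.append(z)
    return np.stack(states)      # shape (n_layers, h)

def student(states, W_out):
    """Stage-1 sketch: read (real or emulated) teacher states and map
    them directly to answer logits."""
    return states.mean(axis=0) @ W_out

W_emul = [rng.normal(size=(h, h)) for _ in range(n_layers)]
W_out = rng.normal(size=(h, n_answers))

x = rng.normal(size=h)                           # encoded question
teacher_states = rng.normal(size=(n_layers, h))  # from a CoT teacher

# Stage 2 training signal: match the teacher's hidden states (MSE).
emulation_loss = np.mean((emulator(x, W_emul) - teacher_states) ** 2)
# Stage 3: couple emulator and student; optimize end-to-end on answers.
logits = student(emulator(x, W_emul), W_out)
```

In stage 1 the student would be trained on the teacher's real states; in stage 3 the gradient flows through both modules, which is what allows the coupled system to drift away from the teacher's exact reasoning.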
Experimental Setup
Experiments were conducted on two tasks: multi-digit multiplication and grade school math problems, utilizing datasets from BIG-bench and GSM8K. The implicit CoT approach was compared against baselines using no CoT and explicit CoT reasoning modes, with models such as GPT-2 Small, Medium, and Large.
Results
The results demonstrated the implicit CoT method's efficacy on tasks requiring multi-step reasoning. For instance, GPT-2 Medium achieved high accuracy on five-digit multiplication, a task the same models could not solve without explicit reasoning steps. On grade school math problems, implicit CoT significantly improved answer accuracy over the no-CoT baseline.
Implicit CoT was also notably faster at inference than explicit CoT, since it avoids generating verbose intermediate steps. However, its accuracy still trailed explicit CoT, suggesting that vertical reasoning does not yet transfer fully to larger-scale reasoning tasks.
Analysis and Discussion
The paper highlights several critical insights:
- Diagonal selection of hidden states from the teacher proved effective, highlighting the importance of strategic information extraction from model layers.
- A mixture model accounted for multiple valid reasoning pathways, essential for tasks like GSM8K where intermediate tokens are not unique.
- The "Couple and Optimize" stage enabled the student model to develop reasoning pathways of its own, improving prediction accuracy, albeit at the cost of interpretability.
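The diagonal selection mentioned in the first point can be pictured on the teacher's full grid of hidden states (layers × reasoning tokens): each successive layer contributes the state of a successively later reasoning token. The sketch below is an assumption-laden illustration; the grid shape, the even spacing of token positions, and the helper name `diagonal_states` are all hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
n_layers, n_tokens, h = 4, 9, 8  # illustrative sizes
# Teacher hidden states: one vector per (layer, reasoning token).
grid = rng.normal(size=(n_layers, n_tokens, h))

def diagonal_states(grid):
    """Pick one state per layer along a diagonal of the grid, so the
    bottom layer reads an early reasoning token and the top layer a
    late one (hypothetical even spacing)."""
    n_layers, n_tokens, _ = grid.shape
    positions = np.linspace(0, n_tokens - 1, n_layers).round().astype(int)
    return np.stack([grid[l, t] for l, t in enumerate(positions)])

selected = diagonal_states(grid)  # (n_layers, h) compact summary
```

The point of the diagonal is that a single forward pass through the student's layers can then mirror the teacher's progression through its reasoning tokens, which is what makes purely vertical reasoning plausible.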
Conclusion
The research introduces a compelling paradigm shift towards implicit reasoning in LLMs. By leveraging internal hidden states vertically, models can circumvent traditional human-like reasoning steps, achieving faster and potentially more autonomous decision-making processes. While the approach offers substantial promise, further research could refine implicit reasoning, exploring fully end-to-end training strategies and integrating such methods into pre-training processes. The study lays foundational work to inspire future exploration into the autonomous reasoning capacities of large-scale LLMs.