Implicit Chain of Thought Reasoning via Knowledge Distillation

Published 2 Nov 2023 in cs.CL, cs.AI, and cs.LG | arXiv:2311.01460v1

Abstract: To augment LLMs with the ability to reason, researchers usually prompt or finetune them to produce chain of thought reasoning steps before producing the final answer. However, although people use natural language to reason effectively, it may be that LMs could reason more effectively with some intermediate computation that is not in natural language. In this work, we explore an alternative reasoning approach: instead of explicitly producing the chain of thought reasoning steps, we use the LLM's internal hidden states to perform implicit reasoning. The implicit reasoning steps are distilled from a teacher model trained on explicit chain-of-thought reasoning, and instead of doing reasoning "horizontally" by producing intermediate words one-by-one, we distill it such that the reasoning happens "vertically" among the hidden states in different layers. We conduct experiments on a multi-digit multiplication task and a grade school math problem dataset and find that this approach enables solving tasks previously not solvable without explicit chain-of-thought, at a speed comparable to no chain-of-thought.


Summary

  • The paper introduces an implicit chain-of-thought method that eliminates explicit intermediate steps using teacher hidden states.
  • The technique employs a three-stage process—mind-reading, thought emulation, and joint optimization—to enhance LLM reasoning.
  • Experiments on multi-digit multiplication and grade school math show inference speed comparable to no-CoT while solving tasks that previously required explicit CoT.


Introduction

The study explores an innovative approach to improve reasoning capabilities in LLMs by employing implicit reasoning without generating explicit intermediate steps. Traditionally, chain-of-thought (CoT) methods prompt models to articulate reasoning steps leading to a final answer. This resembles human cognitive processes but may not utilize the full computational potential of LLMs. The authors propose using internal hidden states for reasoning, distilling this from a teacher model trained on explicit CoT reasoning, allowing for vertical reasoning within the model's layers.

Methodology

Implicit Chain-of-Thought Framework

The framework comprises three main steps:

  1. Mind-Reading the Teacher: A student model is trained to leverage the continuous hidden states generated by a teacher model during intermediate reasoning steps. This student model directly utilizes selected hidden states from the teacher, bypassing the explicit reasoning steps to produce the final answer.
  2. Thought Emulation: Knowledge distillation trains an emulator that predicts the teacher's hidden states vertically across layers, removing the need for explicit horizontal reasoning steps. The emulator compresses the teacher's reasoning into a sequence of compact internal states that the student model consumes at inference.
  3. Couple and Optimize: The emulator and student model are combined and optimized end-to-end. This lets the student refine its reasoning strategies, potentially diverging from the teacher's approach, and generate answers directly and efficiently.
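The three stages above can be sketched with toy linear models. Everything here is illustrative (names, shapes, and least-squares fitting as a stand-in for SGD training of deep networks), not the paper's implementation:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy linear "teacher": maps an input x to intermediate hidden states z
# (its "thoughts") and then to an answer y.
d = 8
W_teacher = rng.normal(size=(d, d))   # input -> intermediate hidden state
W_answer = rng.normal(size=(d, 1))    # intermediate hidden state -> answer

X = rng.normal(size=(256, d))
Z = X @ W_teacher                     # teacher's intermediate hidden states
Y = Z @ W_answer                      # teacher's final answers

# Stage 1 -- mind-reading the teacher: fit a student that reads the teacher's
# hidden states directly and produces the final answer.
W_student, *_ = np.linalg.lstsq(Z, Y, rcond=None)
pred_stage1 = Z @ W_student           # student answering from true teacher states

# Stage 2 -- thought emulation: fit an emulator that predicts the teacher's
# hidden states from the input alone, so no explicit thought tokens are needed.
W_emulator, *_ = np.linalg.lstsq(X, Z, rcond=None)

# Stage 3 -- couple and optimize: chain emulator and student, then refit the
# composition end-to-end against the final answers (in the paper, the coupled
# system is fine-tuned rather than refit from scratch).
W_coupled, *_ = np.linalg.lstsq(X @ W_emulator, Y, rcond=None)

pred = (X @ W_emulator) @ W_coupled   # implicit CoT: input -> answer directly
mse = float(np.mean((pred - Y) ** 2))
print(mse)
```

Because the toy teacher is linear, the emulator recovers its hidden states exactly and the coupled system reproduces the answers; with real nonlinear networks, each stage is an approximation.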

Experimental Setup

Experiments were conducted on two tasks: multi-digit multiplication and grade school math problems, utilizing datasets from BIG-bench and GSM8K. The implicit CoT approach was compared against baselines using no CoT and explicit CoT reasoning modes, with models such as GPT-2 Small, Medium, and Large.

Results

The results demonstrate the method's efficacy on tasks requiring multi-step reasoning. For instance, implicit CoT with GPT-2 Medium achieved high accuracy on five-digit multiplication, a task these models cannot solve without explicit reasoning steps. On grade school math problems, it substantially improved answer accuracy over the no-CoT baseline.
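As an illustration of what explicit CoT must spell out token by token (the paper's exact data format may differ), multi-digit multiplication decomposes into per-digit partial products and a running sum:

```python
def multiply_with_steps(a: int, b: int):
    """Return the partial products (the intermediate steps explicit CoT
    would emit as tokens) and the final product."""
    steps, total = [], 0
    for i, digit in enumerate(str(b)[::-1]):     # least-significant digit first
        partial = a * int(digit) * 10**i         # one partial product per digit
        total += partial
        steps.append(partial)
    return steps, total

steps, answer = multiply_with_steps(12345, 67890)
print(steps)   # the five partial products
print(answer)  # 838102050
```

Explicit CoT emits each partial product as text, costing one decoding pass per token; implicit CoT pushes this intermediate computation into hidden states, so only the final answer is generated.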

Implicit CoT was also markedly faster at inference than explicit CoT, since it produces no verbose intermediate tokens. Its accuracy, however, still trails explicit CoT, suggesting that vertical reasoning does not yet transfer fully to larger-scale reasoning tasks.

Analysis and Discussion

The paper highlights several critical insights:

  • Diagonal selection of hidden states from the teacher proved effective, underscoring the importance of strategic information extraction across model layers.
  • A mixture model accounted for multiple valid reasoning pathways, which is essential for tasks like GSM8K where intermediate steps are not unique.
  • The coupling-and-optimization stage let the student model develop its own reasoning pathways, improving prediction accuracy at the cost of interpretability.
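The diagonal selection in the first bullet can be sketched as follows, assuming the teacher yields one hidden state per (layer, reasoning-token) pair; the linear indexing scheme here is an illustrative guess, not the paper's exact one:

```python
def diagonal_states(teacher_grid):
    """teacher_grid[l][t] is the teacher's hidden state at layer l while
    processing intermediate reasoning token t. The student has L layers but
    emits no reasoning tokens, so each student layer reads one teacher state
    along a diagonal: later layers are paired with later reasoning steps."""
    L = len(teacher_grid)        # number of layers
    T = len(teacher_grid[0])     # number of intermediate reasoning tokens
    # Map layer index 0..L-1 linearly onto token index 0..T-1.
    return [teacher_grid[l][round(l * (T - 1) / max(L - 1, 1))]
            for l in range(L)]

# 4 layers, 7 reasoning tokens; states labeled by their (layer, token) position.
grid = [[(l, t) for t in range(7)] for l in range(4)]
print(diagonal_states(grid))  # [(0, 0), (1, 2), (2, 4), (3, 6)]
```

The diagonal gives each student layer a progressively later snapshot of the teacher's reasoning, so depth substitutes for sequence length.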

Conclusion

The research introduces a compelling paradigm shift towards implicit reasoning in LLMs. By leveraging internal hidden states vertically, models can circumvent traditional human-like reasoning steps, achieving faster and potentially more autonomous decision-making processes. While the approach offers substantial promise, further research could refine implicit reasoning, exploring fully end-to-end training strategies and integrating such methods into pre-training processes. The study lays foundational work to inspire future exploration into the autonomous reasoning capacities of large-scale LLMs.
