The paper introduces CODI (Continuous Chain-of-Thought via Self-Distillation), a framework designed to compress the chain-of-thought reasoning process into a continuous space, with the goal of improving efficiency while maintaining performance. CODI employs a self-distillation approach, where a shared model acts as both teacher and student, learning explicit and implicit CoT jointly, and aligning their hidden activations during the generation of the final answer.
The central idea is to move away from discrete, natural language representations of CoT, which may not be optimal for reasoning, towards dense, continuous representations. The paper highlights that while implicit CoT methods exist, they generally underperform explicit CoT. CODI addresses this gap by distilling knowledge from explicit CoT (teacher) to implicit CoT (student) within the same model.
The CODI framework involves two main tasks:
- Teacher Task: The teacher learns to generate explicit CoTs with a standard language-modeling objective, exposing the model to structured reasoning patterns. The loss function is defined as:

  $$\mathcal{L}_{\text{teacher}} = -\sum_{t} \log P_{\theta}\!\left(y_t \mid y_{<t}, Q\right)$$

  where:
  - $P_{\theta}$ is the probability distribution of the LLM,
  - $y$ refers to both the CoT and the answer labels,
  - $Q$ refers to the question tokens.
- Student Task: The student learns to generate continuous thoughts by autoregressively propagating hidden states and then predicts the final answer. The loss function is defined as:

  $$\mathcal{L}_{\text{student}} = -\sum_{t} \log P_{\theta}\!\left(y_t \mid y_{<t}, Z, Q\right)$$

  where:
  - $P_{\theta}$ is the probability distribution of the LLM,
  - $y$ refers to the answer label,
  - $Q$ refers to the question tokens,
  - $Z$ refers to the continuous thoughts.
The student task uses special tokens, `bot` and `eot`, to mark the start and end of continuous reasoning. A two-layer multi-layer perceptron (MLP) with layer normalization transforms the hidden representation of each continuous-thought token before it is fed back as the next input embedding; a minimal sketch of this loop is given below.
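To make the student's rollout concrete, the following PyTorch-style sketch shows how continuous thoughts could be generated: the last hidden state of each step is projected by a two-layer MLP with layer normalization and appended as the next input embedding. The class and function names (`ThoughtProjector`, `generate_continuous_thoughts`), the GELU activation, and the argument names are illustrative assumptions, not the paper's actual implementation.

```python
import torch
import torch.nn as nn

class ThoughtProjector(nn.Module):
    """Two-layer MLP with layer normalization that maps a hidden state back into an
    input embedding. The paper specifies a two-layer MLP + layer norm; the GELU
    activation and equal hidden sizes here are assumptions."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.LayerNorm(hidden_size),
        )

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.net(h)

def generate_continuous_thoughts(model, embed_fn, projector, question_ids, bot_id, num_thoughts=6):
    """Autoregressively roll out `num_thoughts` continuous thoughts after a `bot` token.

    `model` is assumed to be a decoder-only LM that accepts `inputs_embeds` and returns
    hidden states (e.g., a Hugging Face causal LM); `embed_fn` maps token ids to input
    embeddings. The number of thoughts is a hyperparameter.
    """
    # Embed the question followed by the `bot` marker.
    bot = torch.full((question_ids.size(0), 1), bot_id, dtype=torch.long, device=question_ids.device)
    inputs = embed_fn(torch.cat([question_ids, bot], dim=1))

    thoughts = []
    for _ in range(num_thoughts):
        out = model(inputs_embeds=inputs, output_hidden_states=True)
        last_hidden = out.hidden_states[-1][:, -1:, :]  # hidden state at the latest position
        thought = projector(last_hidden)                # continuous thought
        thoughts.append(thought)
        inputs = torch.cat([inputs, thought], dim=1)    # feed it back as the next input embedding

    # The caller appends the `eot` marker and the answer prompt, then computes the answer loss.
    return inputs, thoughts
```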
Knowledge distillation is achieved by aligning the hidden activations of a key token between the teacher and student tasks. Specifically, the hidden activation of the token immediately preceding the answer (e.g., the colon in "The answer is:") is used, because this token is believed to encode the crucial reasoning information. The alignment is enforced using an L1 loss:

$$\mathcal{L}_{\text{KD}} = \frac{1}{N} \sum_{l=1}^{N} \left\lVert \operatorname{sg}\!\left[h^{l}_{\text{teacher}}\right] - h^{l}_{\text{student}} \right\rVert_{1}$$

where:
- $N$ is the number of layers in the LLM,
- $\operatorname{sg}$ denotes stop gradient,
- $h^{l}$ is the hidden activation of the LLM's $l$-th layer at the selected token.
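A minimal sketch of this distillation term, assuming `teacher_hiddens` and `student_hiddens` are per-layer hidden states already gathered at the position of the token preceding the answer in the teacher and student passes (the function name and the mean-style L1 reduction are assumptions):

```python
import torch

def codi_kd_loss(teacher_hiddens, student_hiddens):
    """L1 alignment of per-layer hidden activations at the selected token.

    Each list element is a tensor of shape (batch, hidden_size) taken at the position
    of the token right before the answer. The teacher side is detached (stop-gradient)
    so that only the student is pulled toward the teacher.
    """
    num_layers = len(teacher_hiddens)
    loss = 0.0
    for h_teacher, h_student in zip(teacher_hiddens, student_hiddens):
        loss = loss + torch.abs(h_teacher.detach() - h_student).mean()
    return loss / num_layers
```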
The overall training objective is a weighted sum of the teacher loss, student loss, and knowledge distillation loss:

$$\mathcal{L} = \alpha \, \mathcal{L}_{\text{teacher}} + \beta \, \mathcal{L}_{\text{student}} + \gamma \, \mathcal{L}_{\text{KD}}$$

where $\alpha$, $\beta$, and $\gamma$ are hyperparameters.
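Glue code for the combined objective might look like the following, reusing `codi_kd_loss` from the sketch above; the helper name, the assumption that both forward passes return Hugging Face-style outputs with `.loss` and `.hidden_states`, and the default weights are all illustrative:

```python
def combine_codi_losses(teacher_out, student_out, teacher_pos, student_pos,
                        alpha=1.0, beta=1.0, gamma=1.0):
    """Weighted sum of the teacher, student, and distillation losses.

    `teacher_pos` / `student_pos` index the token immediately preceding the answer in
    each pass; `teacher_out.loss` and `student_out.loss` are the cross-entropy losses
    of the two tasks (labels are assumed to have been supplied to the forward passes).
    """
    teacher_h = [h[:, teacher_pos, :] for h in teacher_out.hidden_states]
    student_h = [h[:, student_pos, :] for h in student_out.hidden_states]
    kd = codi_kd_loss(teacher_h, student_h)
    return alpha * teacher_out.loss + beta * student_out.loss + gamma * kd
```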
The paper provides a theoretical justification for aligning the hidden activations of the selected token. Drawing upon observations from in-context learning, the authors posit that CoT tokens induce a shift in the hidden activation values of the target token. This "CoT shift" is formalized as:

$$h^{l}_{\text{cot}} = h^{l} + W_V R \, f\!\left((W_K R)^{\top} q\right)$$

where:
- $q$ is the query of this target token,
- $h^{l}_{\text{cot}}$ is the hidden activation at layer $l$ with CoT (equivalent to $h^{l}_{\text{teacher}}$),
- $h^{l}$ is the corresponding activation without CoT,
- $R$ is the CoT rationale,
- $W_V$ is the model's value parameters,
- $W_K$ is the model's key parameters,
- $f$ is a non-linear function.
This suggests that the target token's hidden activation encodes the influence of preceding reasoning steps, and the student can learn this shift by minimizing the L1 distance with the teacher's hidden activation.
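One way to see where this form comes from (a sketch following the linear-attention view of in-context learning that the paper draws on, not the paper's exact derivation): the attention output at the target token over the concatenation of the question tokens $X$ and the rationale $R$ approximately decomposes into a question-only term and a CoT term,

$$W_V [X; R]\, f\!\big((W_K [X; R])^{\top} q\big) \;\approx\; \underbrace{W_V X\, f\!\big((W_K X)^{\top} q\big)}_{\text{no-CoT activation } h^{l}} \;+\; \underbrace{W_V R\, f\!\big((W_K R)^{\top} q\big)}_{\text{CoT shift}}$$

where the approximation treats $f$ as roughly additive over the two blocks of keys, the same relaxation used in the in-context-learning analysis.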
The paper details experiments conducted on mathematical reasoning tasks, using the GSM8k dataset and its variants. The results show that CODI achieves performance comparable to explicit CoT methods while also demonstrating efficiency gains through compression. CODI achieves a 3.1x compression ratio and is the first implicit CoT method to match explicit CoT's performance on GSM8k, surpassing the previous state of the art by 28.2% in accuracy. The method is also shown to be robust, scalable, and generalizable to more complex CoT datasets.
Furthermore, the paper explores the interpretability of CODI by decoding its continuous thoughts and analyzing the attended tokens. The authors find that CODI can produce observable intermediate results within its continuous thoughts.
Ablation studies validate the design choices in CODI, including the use of a shared model for the teacher and student tasks, the importance of the distillation loss, and the impact of excluding the final step of the CoT chain during training.