
CODI: Compressing Chain-of-Thought into Continuous Space via Self-Distillation (2502.21074v2)

Published 28 Feb 2025 in cs.CL

Abstract: Chain-of-Thought (CoT) reasoning enhances LLMs by encouraging step-by-step reasoning in natural language. However, leveraging a latent continuous space for reasoning may offer benefits in terms of both efficiency and robustness. Prior implicit CoT methods attempt to bypass language completely by reasoning in continuous space but have consistently underperformed compared to the standard explicit CoT approach. We introduce CODI (Continuous Chain-of-Thought via Self-Distillation), a novel training framework that effectively compresses natural language CoT into continuous space. CODI jointly trains a teacher task (Explicit CoT) and a student task (Implicit CoT), distilling the reasoning ability from language into continuous space by aligning the hidden states of a designated token. Our experiments show that CODI is the first implicit CoT approach to match the performance of explicit CoT on GSM8k at the GPT-2 scale, achieving a 3.1x compression rate and outperforming the previous state-of-the-art by 28.2% in accuracy. CODI also demonstrates robustness, generalizability to complex datasets, and interpretability. These results validate that LLMs can reason effectively not only in natural language, but also in a latent continuous space.

The paper introduces CODI (Continuous Chain-of-Thought via Self-Distillation), a framework designed to compress the chain-of-thought reasoning process into a continuous space, with the goal of improving efficiency while maintaining performance. CODI employs a self-distillation approach, where a shared model acts as both teacher and student, learning explicit and implicit CoT jointly, and aligning their hidden activations during the generation of the final answer.

The central idea is to move away from discrete, natural language representations of CoT, which may not be optimal for reasoning, towards dense, continuous representations. The paper highlights that while implicit CoT methods exist, they generally underperform explicit CoT. CODI addresses this gap by distilling knowledge from explicit CoT (teacher) to implicit CoT (student) within the same model.

The CODI framework involves two main tasks:

  • Teacher Task: The teacher task learns to generate explicit CoTs using a standard language modeling objective. This provides the model with structured reasoning patterns. The loss function is defined as:

    \mathcal{L}_{\text{teacher}} = -\frac{1}{N} \sum_{i=1}^{N} \log P(y_i \mid y_{1:i-1}, Q)

    where:

    • P is the probability distribution of the LLM
    • y refers to both the CoT tokens and the answer labels
    • Q refers to the question tokens.
  • Student Task: The student task learns to generate continuous thoughts by autoregressively propagating hidden states and predicting the final answer. The loss function is defined as:

    \mathcal{L}_{\text{student}} = -\frac{1}{N} \sum_{i=1}^{N} \log P(y_i \mid y_{1:i-1}, Q, Z)

    where:

    • P is the probability distribution of the LLM
    • y refers to the answer labels
    • Q refers to the question tokens
    • Z refers to the continuous thoughts.

    The student task uses special tokens, bot and eot, to mark the start and end of continuous reasoning. A two-layer Multi-Layer Perceptron (MLP) with layer normalization transforms the hidden representations of continuous thought tokens.
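Both objectives above are plain token-level negative log-likelihoods, differing only in their targets and conditioning context. A minimal sketch in Python (toy probabilities and a hypothetical helper name; not the paper's code):

```python
import math

def nll_loss(correct_token_probs):
    """Average negative log-likelihood over N target tokens.

    Each entry is the probability the model assigns to the correct
    next token: P(y_i | y_{1:i-1}, Q) for the teacher task, or
    P(y_i | y_{1:i-1}, Q, Z) for the student task, where Z stands
    for the continuous thoughts.
    """
    n = len(correct_token_probs)
    return -sum(math.log(p) for p in correct_token_probs) / n

# Teacher targets span the CoT plus the answer; student targets are
# the answer tokens only (toy numbers, purely illustrative).
l_teacher = nll_loss([0.9, 0.8, 0.95])
l_student = nll_loss([0.85, 0.9])
```

The two tasks share the same model weights; only the supervision differs, which is what lets the distillation step below transfer reasoning from one to the other.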

Knowledge distillation is achieved by aligning the hidden activations of a key token between the teacher and student tasks. Specifically, the hidden activation of the token immediately preceding the answer (e.g., the colon in "The answer is:") is used. This token is chosen because it is believed to encode crucial reasoning information. The alignment is enforced using an L1 loss:

\mathcal{L}_{\text{KD}} = \frac{1}{M} \sum_{l=1}^{M} \left| \text{sg}\big[h_{\text{teacher}}^l\big] - h_{\text{student}}^l \right|

where:

  • M is the number of layers in the LLM
  • sg denotes the stop-gradient operator
  • h^l is the hidden activation of the LLM's l-th layer.
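In code, the stop-gradient on the teacher activations amounts to treating them as constants (e.g. `.detach()` in an autograd framework). A framework-free sketch, reading the per-layer term as a mean of element-wise absolute differences (one plausible reading of the notation, not the paper's implementation):

```python
def kd_loss(h_teacher, h_student):
    """L1 alignment loss between teacher and student hidden states.

    h_teacher, h_student: lists of M per-layer activations for the
    token preceding the answer, each a list of floats. The teacher
    side is treated as a constant (the sg[.] in the paper), so no
    gradient would flow through it during training.
    """
    m = len(h_teacher)
    per_layer = []
    for ht, hs in zip(h_teacher, h_student):
        per_layer.append(sum(abs(t - s) for t, s in zip(ht, hs)) / len(ht))
    return sum(per_layer) / m
```

The loss is zero exactly when the student reproduces the teacher's activation at every layer, which is the alignment the framework aims for.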

The overall training objective is a weighted sum of the teacher loss, student loss, and knowledge distillation loss:

\mathcal{L} = \alpha \mathcal{L}_{\text{teacher}} + \beta \mathcal{L}_{\text{student}} + \gamma \mathcal{L}_{\text{KD}}

where α, β, and γ are weighting hyperparameters.
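The combination is a straightforward weighted sum; the equal weights below are an illustrative default, not the paper's reported settings:

```python
def codi_loss(l_teacher, l_student, l_kd,
              alpha=1.0, beta=1.0, gamma=1.0):
    """Overall CODI objective: weighted sum of the teacher,
    student, and knowledge-distillation losses."""
    return alpha * l_teacher + beta * l_student + gamma * l_kd
```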

The paper provides a theoretical justification for aligning the hidden activations of the selected token. Drawing upon observations from in-context learning, the authors posit that CoT tokens induce a shift in the hidden activation values of the target token. This "CoT shift" is formalized as:

\mathbf{h}^l_{\text{CoT}} \approx \mathbf{h}^l_{\text{no-CoT}} + f\big(W_V R (W_K R)^T q\big)

where:

  • q is the query vector of the target token
  • h^l_CoT is the hidden activation at layer l with CoT (equivalent to h^l_teacher)
  • h^l_no-CoT is the corresponding activation without CoT
  • R is the matrix of CoT rationale representations
  • W_V is the model's value projection matrix
  • W_K is the model's key projection matrix
  • f is a non-linear function

This suggests that the target token's hidden activation encodes the influence of preceding reasoning steps, and the student can learn this shift by minimizing the L1 distance with the teacher's hidden activation.
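The shift term can be made concrete with toy numbers. Treating R as a collection of per-token rationale vectors, W_V R (W_K R)^T q decomposes into a sum of value vectors W_V r weighted by attention-like scores (W_K r)^T q. A dependency-free sketch with identity f (purely illustrative, not the paper's code):

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def matvec(m, v):
    return [dot(row, v) for row in m]

def cot_shift(w_v, w_k, rationale, q, f=lambda x: x):
    """Approximate the activation shift f(W_V R (W_K R)^T q).

    rationale: the columns of R, one vector per CoT token. Each
    token r contributes its value vector W_V r scaled by the score
    (W_K r)^T q; f defaults to identity for illustration.
    """
    shift = [0.0] * len(q)
    for r in rationale:
        score = dot(matvec(w_k, r), q)   # (W_K r)^T q
        value = matvec(w_v, r)           # W_V r
        shift = [s + score * v for s, v in zip(shift, value)]
    return [f(x) for x in shift]

identity = [[1.0, 0.0], [0.0, 1.0]]
# One rationale token [1, 0] with query [2, 0]: score 2, value [1, 0].
shift = cot_shift(identity, identity, [[1.0, 0.0]], [2.0, 0.0])
```

Each CoT token nudges the target token's activation in the direction of its value vector, which is the shift the student learns to reproduce without emitting the tokens themselves.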

The paper details experiments on mathematical reasoning tasks using the GSM8k dataset and its variants. CODI is the first implicit CoT method to match explicit CoT's performance on GSM8k at the GPT-2 scale, achieving a 3.1x compression ratio over explicit CoT and surpassing the previous state-of-the-art implicit method by 28.2% in accuracy. The method is also shown to be robust, scalable, and generalizable to more complex CoT datasets.

Furthermore, the paper explores the interpretability of CODI by decoding its continuous thoughts and analyzing the attended tokens. The authors find that CODI can produce observable intermediate results within its continuous thoughts.

Ablation studies validate the design choices in CODI, including the use of a shared model for the teacher and student tasks, the importance of the distillation loss, and the impact of excluding the final step of the CoT chain during training.

Authors (6)
  1. Zhenyi Shen
  2. Hanqi Yan
  3. Linhai Zhang
  4. Zhanghao Hu
  5. Yali Du
  6. Yulan He