Training Large Language Models to Reason in a Continuous Latent Space (2412.06769v2)

Published 9 Dec 2024 in cs.CL

Abstract: LLMs are restricted to reason in the "language space", where they typically express the reasoning process with a chain-of-thought (CoT) to solve a complex reasoning problem. However, we argue that language space may not always be optimal for reasoning. For example, most word tokens are primarily for textual coherence and not essential for reasoning, while some critical tokens require complex planning and pose huge challenges to LLMs. To explore the potential of LLM reasoning in an unrestricted latent space instead of using natural language, we introduce a new paradigm Coconut (Chain of Continuous Thought). We utilize the last hidden state of the LLM as a representation of the reasoning state (termed "continuous thought"). Rather than decoding this into a word token, we feed it back to the LLM as the subsequent input embedding directly in the continuous space. Experiments show that Coconut can effectively augment the LLM on several reasoning tasks. This novel latent reasoning paradigm leads to emergent advanced reasoning patterns: the continuous thought can encode multiple alternative next reasoning steps, allowing the model to perform a breadth-first search (BFS) to solve the problem, rather than prematurely committing to a single deterministic path like CoT. Coconut outperforms CoT in certain logical reasoning tasks that require substantial backtracking during planning, with fewer thinking tokens during inference. These findings demonstrate the promise of latent reasoning and offer valuable insights for future research.

Summary

  • The paper introduces Coconut, a novel method enabling large language models to reason in a continuous latent space instead of relying solely on discrete token generation.
  • It employs a dual-mode approach with a multi-stage curriculum, integrating latent and language modes using special tokens to transition between reasoning states.
  • Experimental results on tasks like GSM8k and ProsQA demonstrate that Coconut improves reasoning depth and planning efficiency while reducing computational cost.

This paper (2412.06769) introduces Coconut (Chain of Continuous Thought), a novel paradigm that enables LLMs to reason in a continuous latent space rather than being restricted to generating explicit language tokens like in Chain-of-Thought (CoT). The core idea is to directly feed the last hidden state of the LLM, representing a "continuous thought," back into the model as the subsequent input embedding, bypassing the standard LLM head and token embedding lookup.

The motivation behind Coconut stems from the observation that human reasoning is not solely reliant on language, and that language generation in CoT can be inefficient, dedicating similar computational effort to fluency tokens as to critical reasoning steps. The authors argue that an unrestricted latent space might be a more optimal environment for complex reasoning, allowing for richer representations of reasoning states and potentially more efficient computation.

Method and Implementation:

The Coconut method introduces a "latent mode" alongside the standard "language mode." Special tokens, <bot> and <eot>, mark the boundaries of the latent reasoning sequence.

  • Language Mode: Standard autoregressive decoding, where the model predicts the next token using the LLM head and the input is the embedding of the previous token.
  • Latent Mode: When the model is between <bot> and <eot>, the last hidden state from the previous step is directly used as the input embedding for the next step. This hidden state is not projected back to the vocabulary space via the LLM head. With x_i = <bot> and x_j = <eot> delimiting the latent segment, the input sequence becomes a mix of token embeddings and hidden states: E_t = [e(x_1), ..., e(x_i), h_i, h_{i+1}, ..., h_{t-1}] for i < t ≤ j (a minimal sketch of this feedback loop follows).
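
A minimal sketch of this feedback loop, assuming a Hugging Face causal LM (the model choice, prompt, and loop structure are illustrative, not the authors' released code; <bot>/<eot> handling is omitted):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative setup: any decoder-only LM with accessible input embeddings
# and hidden states can be driven this way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

question = "If Alice has 3 apples and buys 2 more, how many does she have?"
input_ids = tokenizer(question, return_tensors="pt").input_ids
embeds = model.get_input_embeddings()(input_ids)        # [1, seq_len, hidden]

num_thoughts = 3  # number of continuous thoughts to unroll
with torch.no_grad():
    for _ in range(num_thoughts):
        out = model(inputs_embeds=embeds, output_hidden_states=True)
        # Last layer's hidden state at the final position is one "continuous thought".
        thought = out.hidden_states[-1][:, -1:, :]       # [1, 1, hidden]
        # Feed it straight back as the next input embedding, bypassing both
        # the LM head and the token-embedding lookup.
        embeds = torch.cat([embeds, thought], dim=1)
```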

Training:

Coconut leverages existing language CoT data to train the latent reasoning capability through a multi-stage curriculum, inspired by iCoT (2405.14838).

  1. Initial Stage: Train on complete language CoT instances.
  2. Subsequent Stages: In stage k, the first k language reasoning steps of the CoT are replaced by k × c continuous thoughts, where c is a hyperparameter (e.g., c = 1 or c = 2); see the data-construction sketch after this list. The model is trained to predict the remaining language tokens after the latent sequence. The <bot> and <eot> tokens are inserted around the latent sequence.
  3. Loss: Standard negative log-likelihood loss is applied, but masked for the initial question tokens and the continuous latent thoughts themselves. The loss is only calculated on the remaining language tokens (the latter part of the original CoT and the final answer).
  4. Optimizer Reset: The optimizer state is reset when switching training stages, similar to (2405.14838).
  5. Differentiability: The continuous thoughts are fully differentiable, allowing end-to-end optimization via gradient descent.
  6. Computational Cost: Training involves multiple forward passes (n+1 for n latent thoughts) per sample, posing challenges for parallelism, although KV caching can help.
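
A rough sketch of how a stage-k training example could be assembled under this curriculum (the function and argument names are hypothetical; only the masking rule, replacing the first k steps with k × c latent slots and supervising the remaining language tokens, follows the description above):

```python
IGNORE_INDEX = -100  # label value ignored by the cross-entropy loss

def build_stage_example(question_ids, step_ids_list, answer_ids,
                        bot_id, eot_id, stage_k, c=1, pad_id=0):
    """Stage-k example: the first `stage_k` CoT steps become `stage_k * c`
    continuous-thought slots between <bot> and <eot>; the loss is computed
    only on the remaining language tokens and the final answer."""
    n_latent = stage_k * c
    remaining = [tok for step in step_ids_list[stage_k:] for tok in step]

    # `pad_id` positions are placeholders whose embeddings get overwritten by
    # hidden states during the (sequential) forward passes; they are never
    # supervised directly.
    input_ids = (question_ids + [bot_id] + [pad_id] * n_latent + [eot_id]
                 + remaining + answer_ids)
    labels = ([IGNORE_INDEX] * (len(question_ids) + 1 + n_latent + 1)
              + remaining + answer_ids)
    return input_ids, labels
```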

Inference:

Inference follows the same mode switching as training.

  • A <bot> token is typically inserted after the question.
  • Determining when to generate <eot> can be done by: a) Training a binary classifier on latent states to predict the <eot> token. b) Padding the latent sequence to a fixed length, inserting <eot> after a predetermined number of continuous thoughts. The authors use the second approach for simplicity in their experiments.
  • Decoding within the latent mode involves feeding the last hidden state back as the next input embedding for a fixed number of steps. Once <eot> is inserted, decoding reverts to language mode; a minimal inference loop is sketched after this list.
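
Putting the two modes together, a minimal greedy inference loop under option (b), with a fixed number of continuous thoughts before reverting to language mode, might look like the following (identifiers are illustrative, and the <bot>/<eot> markers are again omitted for brevity):

```python
import torch

def coconut_generate(model, tokenizer, question, num_thoughts=3, max_new_tokens=64):
    """Latent mode for `num_thoughts` steps, then ordinary greedy decoding."""
    input_ids = tokenizer(question, return_tensors="pt").input_ids
    embeds = model.get_input_embeddings()(input_ids)

    with torch.no_grad():
        # Latent mode: feed each last hidden state back as the next input embedding.
        for _ in range(num_thoughts):
            out = model(inputs_embeds=embeds, output_hidden_states=True)
            thought = out.hidden_states[-1][:, -1:, :]
            embeds = torch.cat([embeds, thought], dim=1)

        # Language mode: back to the LM head and ordinary token embeddings.
        generated = []
        for _ in range(max_new_tokens):
            out = model(inputs_embeds=embeds)
            next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)  # [1, 1]
            if next_id.item() == tokenizer.eos_token_id:
                break
            generated.append(next_id.item())
            embeds = torch.cat([embeds, model.get_input_embeddings()(next_id)], dim=1)
    return tokenizer.decode(generated)
```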

Experiments and Results:

The authors evaluate Coconut on GSM8k (math reasoning), ProntoQA (logical reasoning), and a newly proposed dataset, ProsQA (planning-intensive logical reasoning) [(2412.06769) Appendix A.2]. They compare Coconut to baselines like CoT, No-CoT, iCoT (2405.14838), and a Pause Token method (2310.02226), as well as variants of Coconut (trained without curriculum, without thoughts, or using pause tokens as thoughts).

Key findings:

  • Enhanced Reasoning: Coconut consistently improves performance over the No-CoT baseline across tasks, demonstrating that latent reasoning can effectively augment LLMs.
  • Efficiency: Coconut generates significantly fewer tokens during inference than CoT while often achieving comparable or better accuracy (Table 1, Appendix B). Wall-clock time is also generally lower than CoT's (Table 3).
  • "Chaining" Effect: On GSM8k, increasing the number of continuous thoughts (cc) correlates with improved performance (Figure 4), suggesting that chaining latent thoughts provides a similar benefit to chaining language tokens in CoT, increasing effective reasoning depth.
  • Advantage in Planning: Coconut and its variants, including iCoT, show substantial improvements over CoT on planning-intensive tasks like ProsQA (Table 1). This suggests latent reasoning is better suited for problems requiring search or backtracking.
  • Importance of Curriculum: Training Coconut without the multi-stage curriculum significantly degrades performance (Table 1), highlighting the need for structured guidance to learn effective latent reasoning.
  • Continuous Thought as Representation: Analysis suggests continuous thoughts can encode intermediate reasoning variables (Figure 3) and, unlike discrete tokens, maintain a distribution over multiple possible next steps (Figure 6).

Understanding Latent Reasoning:

Through experiments interpolating between latent and language reasoning on ProsQA (by forcing <eot> after k steps), the authors show that increasing latent steps improves accuracy and reduces errors like hallucination (Figure 5).

Interpreting the latent reasoning as a search tree, they analyze the probability distribution over subsequent language tokens if the model were forced to decode after k latent steps (a small probing sketch follows the list below). This "implicit value function" reveals that:

  • Latent reasoning explores multiple paths in parallel, akin to Breadth-First Search (BFS), especially in earlier steps (Figure 7).
  • The model prunes less promising paths over subsequent latent steps.
  • Evaluating potential next steps is easier when they have lower "height" (shorter distance to a leaf node) in the problem structure (Figure 8), which might explain why delaying hard decisions in latent space improves planning on complex graphs.
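
A small sketch of how such a probe might be implemented: after some number of latent steps, apply the LM head at the current position and read off how much probability mass falls on the first token of each candidate next step (candidate strings and names are illustrative, and first-token probability is only a rough proxy for the paper's scoring):

```python
import torch

def probe_next_step_distribution(model, tokenizer, embeds, candidates):
    """Distribution the model would decode from if forced to stop latent
    reasoning now; `embeds` already holds the question plus k continuous thoughts."""
    with torch.no_grad():
        out = model(inputs_embeds=embeds)
    probs = torch.softmax(out.logits[:, -1, :], dim=-1)   # over the vocabulary
    scores = {}
    for cand in candidates:
        first_id = tokenizer(cand, add_special_tokens=False).input_ids[0]
        scores[cand] = probs[0, first_id].item()
    return scores

# e.g., comparing two hypothetical frontier concepts in the search tree:
# probe_next_step_distribution(model, tokenizer, embeds, [" lempus", " grimpus"])
```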

Conclusion and Future Work:

The paper concludes that Coconut demonstrates the promise of continuous latent space reasoning, showing benefits in efficiency and performance, particularly for planning-intensive tasks. The emergent BFS-like behavior in latent space is a key finding. Future work includes developing more robust training strategies (potentially without explicit language supervision), refining the curriculum (e.g., incremental addition of latent thoughts), and potentially exploring pretraining with continuous thoughts for better generalization. The authors also suggest combining language and latent reasoning, generating a high-level plan in language and executing steps in latent space.
