- The paper introduces Coconut, a novel method enabling large language models to reason in a continuous latent space instead of relying solely on discrete token generation.
- It employs a dual-mode approach with a multi-stage curriculum, integrating latent and language modes using special tokens to transition between reasoning states.
- Experimental results on tasks like GSM8k and ProsQA demonstrate that Coconut improves reasoning depth and planning efficiency while reducing computational cost.
This paper (2412.06769) introduces Coconut (Chain of Continuous Thought), a novel paradigm that enables LLMs to reason in a continuous latent space rather than being restricted to generating explicit language tokens as in Chain-of-Thought (CoT). The core idea is to feed the LLM's last hidden state, representing a "continuous thought," directly back into the model as the next input embedding, bypassing the standard LM head and token embedding lookup.
The motivation behind Coconut stems from the observation that human reasoning is not solely reliant on language, and that language generation in CoT can be inefficient, dedicating similar computational effort to fluency tokens as to critical reasoning steps. The authors argue that an unrestricted latent space might be a more optimal environment for complex reasoning, allowing for richer representations of reasoning states and potentially more efficient computation.
Method and Implementation:
The Coconut method introduces a "latent mode" alongside the standard "language mode." Special tokens, <bot> and <eot>, mark the boundaries of the latent reasoning sequence.
- Language Mode: Standard autoregressive decoding, where the model predicts the next token using the LLM head and the input is the embedding of the previous token.
- Latent Mode: While the model is between <bot> and <eot>, the last hidden state from the previous step is used directly as the input embedding for the next step; it is never projected back to the vocabulary space via the LM head. The input sequence thus becomes a mix of token embeddings and hidden states, E_t = [e(x_1), ..., e(x_i), h_i, h_{i+1}, ..., h_{t-1}] for i < t < j, where x_i = <bot> and x_j = <eot>.
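To make the mode switch concrete, here is a minimal sketch of the latent-mode step using a HuggingFace-style GPT-2; the model, prompt, and loop count are illustrative assumptions rather than the authors' implementation, and the <bot>/<eot> special tokens are omitted:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Minimal sketch of the latent-mode mechanism (illustrative only).
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

ids = tokenizer("2 + 3 * 4 = ?", return_tensors="pt").input_ids
token_embeds = model.get_input_embeddings()(ids)          # e(x_1), ..., e(x_i)

thoughts = []                                              # h_i, h_{i+1}, ...
for _ in range(3):                                         # three continuous thoughts
    # E_t = [e(x_1), ..., e(x_i), h_i, ..., h_{t-1}]
    inputs_embeds = torch.cat([token_embeds] + thoughts, dim=1)
    out = model(inputs_embeds=inputs_embeds, output_hidden_states=True)
    # The last hidden state is not projected through the LM head; it is
    # appended directly as the next input "embedding".
    thoughts.append(out.hidden_states[-1][:, -1:, :])
```

Because the hidden size equals the embedding size in standard decoder-only models, the hidden state can be spliced into the input sequence without any extra projection.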
Training:
Coconut leverages existing language CoT data to train the latent reasoning capability through a multi-stage curriculum, inspired by iCoT (2405.14838).
- Initial Stage: Train on complete language CoT instances.
- Subsequent Stages: In stage k, the first k language reasoning steps in the CoT are replaced by k×c continuous thoughts, where c is a hyperparameter (e.g., c=1 or c=2). The model is trained to predict the remaining language tokens after the latent sequence. The <bot> and <eot> tokens are inserted around the latent sequence.
- Loss: The standard negative log-likelihood loss is applied, masked over the question tokens and the continuous thoughts themselves, so the loss is computed only on the remaining language tokens (the latter part of the original CoT and the final answer). A data-construction sketch follows after this list.
- Optimizer Reset: The optimizer state is reset when switching training stages, similar to (2405.14838).
- Differentiability: The continuous thoughts are fully differentiable, allowing end-to-end optimization via gradient descent.
- Computational Cost: Training involves multiple forward passes (n+1 for n latent thoughts) per sample, posing challenges for parallelism, although KV caching can help.
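As a rough illustration of how a stage-k training example could be assembled under this curriculum, the following sketch builds the mixed input sequence and the masked labels; the placeholder scheme, helper name, and argument layout are assumptions, not the released code:

```python
# Sketch of stage-k example construction and loss masking (assumed scheme).
IGNORE_INDEX = -100  # the ignore index used by PyTorch's cross-entropy loss

def build_stage_example(question_ids, cot_step_ids, answer_ids, k, c,
                        bot_id, eot_id, pad_id=0):
    """question_ids: list[int]; cot_step_ids: list of per-step token-id lists;
    answer_ids: list[int]. Returns (input_ids, labels) for curriculum stage k."""
    num_latent = k * c                              # k steps replaced by k*c thoughts
    remaining = [t for step in cot_step_ids[k:] for t in step]

    # Latent positions carry placeholder ids; during the forward pass their
    # embeddings are overwritten by the previous step's last hidden state.
    latent_span = [bot_id] + [pad_id] * num_latent + [eot_id]

    input_ids = question_ids + latent_span + remaining + answer_ids
    labels = ([IGNORE_INDEX] * len(question_ids)    # mask the question
              + [IGNORE_INDEX] * len(latent_span)   # mask the continuous thoughts
              + remaining + answer_ids)             # supervise remaining CoT + answer
    return input_ids, labels
```

During the forward pass, the embeddings at the placeholder positions would be filled in step by step from the previous position's last hidden state, which is why n latent thoughts require n+1 sequential passes per sample.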
Inference:
Inference follows the same mode-switching procedure as training.
- A <bot> token is typically inserted after the question.
- Determining when to generate <eot> can be done by:
a) Training a binary classifier on the latent states to decide when to emit <eot>.
b) Padding the latent sequence to a fixed length, inserting <eot> after a predetermined number of continuous thoughts. The authors use the second approach for simplicity in their experiments.
- Decoding in latent mode feeds the last hidden state back as the next input embedding for a fixed number of steps; once <eot> is inserted, decoding reverts to standard language mode (see the sketch below).
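Putting the pieces together, a sketch of inference with a fixed latent budget (option b above) might look as follows; model and tokenizer names are placeholders, and a real run would use a model fine-tuned with the Coconut curriculum and its <bot>/<eot> tokens:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Sketch of inference: fixed number of continuous thoughts, then greedy decoding.
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()
tokenizer = AutoTokenizer.from_pretrained("gpt2")

@torch.no_grad()
def coconut_generate(question, num_latent=6, max_new_tokens=32):
    ids = tokenizer(question, return_tensors="pt").input_ids
    embeds = model.get_input_embeddings()(ids)

    # Latent mode: feed the last hidden state back as the next input embedding.
    for _ in range(num_latent):
        out = model(inputs_embeds=embeds, output_hidden_states=True)
        thought = out.hidden_states[-1][:, -1:, :]
        embeds = torch.cat([embeds, thought], dim=1)

    # Language mode: after the (implicit) <eot>, decode tokens as usual.
    generated = []
    for _ in range(max_new_tokens):
        logits = model(inputs_embeds=embeds).logits[:, -1, :]
        next_id = logits.argmax(dim=-1, keepdim=True)          # greedy decoding
        generated.append(next_id.item())
        if next_id.item() == tokenizer.eos_token_id:
            break
        embeds = torch.cat([embeds, model.get_input_embeddings()(next_id)], dim=1)
    return tokenizer.decode(generated)
```

For brevity the sketch recomputes the full forward pass at every step; in practice KV caching keeps both the latent and language phases incremental.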
Experiments and Results:
The authors evaluate Coconut on GSM8k (math reasoning), ProntoQA (logical reasoning), and a newly proposed dataset, ProsQA (planning-intensive logical reasoning) [(2412.06769) Appendix A.2]. They compare Coconut to baselines like CoT, No-CoT, iCoT (2405.14838), and a Pause Token method (2310.02226), as well as variants of Coconut (trained without curriculum, without thoughts, or using pause tokens as thoughts).
Key findings:
- Enhanced Reasoning: Coconut consistently improves performance over the No-CoT baseline across tasks, demonstrating that latent reasoning can effectively augment LLMs.
- Efficiency: Coconut generates significantly fewer tokens during inference than CoT while often achieving comparable or better accuracy (Table 1, Appendix B). Wall-clock time is also generally lower than for CoT (Table 3).
- "Chaining" Effect: On GSM8k, increasing the number of continuous thoughts (c) correlates with improved performance (Figure 4), suggesting that chaining latent thoughts provides a similar benefit to chaining language tokens in CoT, increasing effective reasoning depth.
- Advantage in Planning: Coconut and its variants, including iCoT, show substantial improvements over CoT on planning-intensive tasks like ProsQA (Table 1). This suggests latent reasoning is better suited for problems requiring search or backtracking.
- Importance of Curriculum: Training Coconut without the multi-stage curriculum significantly degrades performance (Table 1), highlighting the need for structured guidance to learn effective latent reasoning.
- Continuous Thought as Representation: Analysis suggests continuous thoughts can encode intermediate reasoning variables (Figure 3) and, unlike discrete tokens, maintain a distribution over multiple possible next steps (Figure 6).
Understanding Latent Reasoning:
Through experiments interpolating between latent and language reasoning on ProsQA (by forcing <eot> after k steps), the authors show that increasing latent steps improves accuracy and reduces errors like hallucination (Figure 5).
Interpreting the latent reasoning as a search tree, they analyze the probability distribution over subsequent language tokens if the model were forced to decode after k latent steps. This "implicit value function" reveals that:
- Latent reasoning explores multiple paths in parallel, akin to Breadth-First Search (BFS), especially in earlier steps (Figure 7).
- The model prunes less promising paths over subsequent latent steps.
- Evaluating potential next steps is easier when they have lower "height" (shorter distance to a leaf node) in the problem structure (Figure 8), which might explain why delaying hard decisions in latent space improves planning on complex graphs.
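The probe described above can be approximated by the sketch below: run k continuous thoughts, force a switch to language mode, and read off the probability the model assigns to each candidate next node. The helper name and the choice to score only each candidate's first token are simplifying assumptions:

```python
import torch
import torch.nn.functional as F

# Sketch of an "implicit value function" probe after k latent steps.
@torch.no_grad()
def score_candidates(model, tokenizer, question, candidates, k):
    ids = tokenizer(question, return_tensors="pt").input_ids
    embeds = model.get_input_embeddings()(ids)
    for _ in range(k):                                   # k continuous thoughts
        out = model(inputs_embeds=embeds, output_hidden_states=True)
        embeds = torch.cat([embeds, out.hidden_states[-1][:, -1:, :]], dim=1)
    # Forced switch to language mode: softmax over the vocabulary at the next position.
    probs = F.softmax(model(inputs_embeds=embeds).logits[:, -1, :], dim=-1)
    return {c: probs[0, tokenizer(c, add_special_tokens=False).input_ids[0]].item()
            for c in candidates}
```

A flat distribution over candidates after few latent steps and a sharper one after more steps would correspond to the parallel-exploration-then-pruning behavior described above.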
Conclusion and Future Work:
The paper concludes that Coconut demonstrates the promise of continuous latent space reasoning, showing benefits in efficiency and performance, particularly for planning-intensive tasks. The emergent BFS-like behavior in latent space is a key finding. Future work includes developing more robust training strategies (potentially without explicit language supervision), refining the curriculum (e.g., incremental addition of latent thoughts), and potentially exploring pretraining with continuous thoughts for better generalization. The authors also suggest combining language and latent reasoning, generating a high-level plan in language and executing steps in latent space.