Reasoning by Superposition: A Theoretical Perspective on Chain of Continuous Thought (2505.12514v2)
Abstract: LLMs have demonstrated remarkable performance in many applications, including challenging reasoning problems via chain-of-thoughts (CoTs) techniques that generate ``thinking tokens'' before answering the questions. While existing theoretical works demonstrate that CoTs with discrete tokens boost the capability of LLMs, recent work on continuous CoTs lacks a theoretical understanding of why it outperforms discrete counterparts in various reasoning tasks such as directed graph reachability, a fundamental graph reasoning problem that includes many practical domain applications as special cases. In this paper, we prove that a two-layer transformer with $D$ steps of continuous CoTs can solve the directed graph reachability problem, where $D$ is the diameter of the graph, while the best known result of constant-depth transformers with discrete CoTs requires $O(n^2)$ decoding steps where $n$ is the number of vertices ($D<n$). In our construction, each continuous thought vector is a superposition state that encodes multiple search frontiers simultaneously (i.e., parallel breadth-first search (BFS)), while discrete CoTs must choose a single path sampled from the superposition state, which leads to sequential search that requires many more steps and may be trapped into local solutions. We also performed extensive experiments to verify that our theoretical construction aligns well with the empirical solution obtained via training dynamics. Notably, encoding of multiple search frontiers as a superposition state automatically emerges in training continuous CoTs, without explicit supervision to guide the model to explore multiple paths simultaneously.
Summary
- The paper demonstrates that a simple two-layer transformer employing continuous thought outperforms discrete CoT in solving directed graph reachability.
- It introduces continuous thought vectors as superposition states to implicitly perform parallel breadth-first search in reasoning tasks.
- Experimental validation on graph reasoning tasks confirms near-perfect accuracy for the two-layer continuous-thought model, well above discrete-CoT baselines of equal or greater depth.
This paper, "Reasoning by Superposition: A Theoretical Perspective on Chain of Continuous Thought," investigates the mechanisms behind the superior performance of continuous chain-of-thought (CoT) reasoning in LLMs compared to discrete CoT, particularly for tasks like graph reachability. The authors provide a theoretical proof that a simple two-layer transformer using continuous thoughts can solve the directed graph reachability problem more efficiently than known results for discrete CoT.
The core idea is that continuous thought vectors act as "superposition states," simultaneously encoding multiple potential search frontiers. This allows the model to perform an implicit parallel breadth-first search (BFS) on a graph. In contrast, discrete CoT requires serializing thoughts, forcing the model to pick a single path at each step, which can be less efficient and prone to local optima.
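To make this parallel-search intuition concrete, here is a minimal numpy sketch (our illustration, not the paper's transformer construction; the `reachable_within` helper and toy graph are ours): a "superposition" vector over vertices is expanded one hop at a time through the adjacency matrix, so after $c$ steps its support is exactly the set of vertices reachable within $c$ edges, mirroring what a continuous thought is claimed to encode.

```python
import numpy as np

def reachable_within(adj: np.ndarray, root: int, steps: int) -> np.ndarray:
    """Superposition over vertices reachable from `root` in at most `steps` edges.

    adj[i, j] = 1 iff there is a directed edge i -> j. Each iteration plays the role
    of one continuous thought: the current superposition is expanded along every
    outgoing edge of every vertex it covers (one parallel BFS step).
    """
    n = adj.shape[0]
    state = np.zeros(n)
    state[root] = 1.0                       # V_0 = {root}
    for _ in range(steps):
        state = state + state @ adj         # add all one-step successors of V_c
        state = (state > 0).astype(float)   # keep the support only
        state = state / state.sum()         # normalize: (1/|V_c|) * sum_{v in V_c} u_v
    return state

# Toy graph: 0 -> 1 -> 3 and 0 -> 2, with candidate destinations {3, 4}.
A = np.zeros((5, 5))
for s, t in [(0, 1), (0, 2), (1, 3)]:
    A[s, t] = 1.0
print(reachable_within(A, root=0, steps=2))  # mass on {0, 1, 2, 3}, none on 4
```

A discrete CoT, by contrast, would have to commit to a single vertex of this support at every step.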
Problem Formulation and Input Structure
The paper focuses on the directed graph reachability problem: given a graph, a starting node $r$, and two candidate destination nodes $c_1$ and $c_2$ (where exactly one is reachable from $r$), determine which candidate is reachable.
The input to the transformer is structured as follows:
- <s>: beginning-of-sentence token
- a sequence of edges, each represented as (source_node, target_node, <e>), where <e> is a special edge marker
- <Q>: special question token
- $c_1$: first candidate destination
- $c_2$: second candidate destination
- <R>: special reasoning token
- $r$: root/starting node

This initial prompt has length $t_0 = 3m + 6$ (for $m$ edges). The model then generates $C$ continuous thought vectors $[t]_1, [t]_2, \ldots, [t]_C$, where $[t]_c = \mathrm{TF}_\theta(h_1, \ldots, h_{t_0+c-1})$. Finally, a special answer token <A> is appended, and the model predicts the answer based on $\mathrm{TF}_\theta(h_1, \ldots, h_{t_0+C}, u_{\texttt{<A>}})$.
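For concreteness, a minimal sketch of this serialization (the `build_prompt` helper and token spellings are ours for illustration, not taken from the paper's code), including a check of the $t_0 = 3m + 6$ prompt length:

```python
def build_prompt(edges, c1, c2, root):
    """Serialize a reachability instance into the layout above:
    <s> (s_1, t_1, <e>) ... (s_m, t_m, <e>) <Q> c_1 c_2 <R> r
    """
    tokens = ["<s>"]
    for s, t in edges:
        tokens += [str(s), str(t), "<e>"]              # 3 tokens per edge
    tokens += ["<Q>", str(c1), str(c2), "<R>", str(root)]
    return tokens

prompt = build_prompt(edges=[(0, 1), (1, 3), (0, 2)], c1=3, c2=4, root=0)
assert len(prompt) == 3 * 3 + 6                        # t_0 = 3m + 6 with m = 3 edges
print(prompt)
```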
Theoretical Construction and Key Results
The authors prove that a two-layer transformer can solve graph reachability in $D$ steps of continuous thought, where $D$ is the graph's diameter. This is significantly more efficient than the $O(n^2)$ steps required by the best-known constant-depth transformers using discrete CoT on $n$ vertices. The embedding dimension is $d = 3d_{TE} + d_{PE}$, where $d_{TE}$ is for token content and $d_{PE}$ for positional encoding: each embedding is divided into a content slot, a buffer_1 slot, and a buffer_2 slot (each of $d_{TE}$ dimensions), plus an effective positional encoding ($d_{PE}$ dimensions).
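One way to picture this layout (purely illustrative; the toy slot sizes and the `slots` helper below are ours) is as four contiguous slices of each position's embedding vector:

```python
import numpy as np

d_te, d_pe = 4, 4                  # toy sizes; the construction only needs d = 3*d_TE + d_PE
d = 3 * d_te + d_pe

def slots(h: np.ndarray) -> dict:
    """Split one position's d-dimensional embedding into the four slots of the construction."""
    return {
        "content":  h[0:d_te],                # token identity
        "buffer_1": h[d_te:2 * d_te],         # scratch space (e.g., an edge's source node)
        "buffer_2": h[2 * d_te:3 * d_te],     # scratch space (e.g., an edge's target node)
        "pos":      h[3 * d_te:],             # (effective) positional encoding
    }

print(slots(np.arange(d, dtype=float)))
```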
- Attention Chooser (Lemma 4.1): A crucial building block is the "attention chooser" head. Using sinusoidal positional encodings, such a head can be constructed to attend to a specific relative position $i - \ell$ whenever the current token $h_i$ is a particular token <x>, and to a default position (e.g., the first token, acting as an attention sink) otherwise. This allows dynamic attention patterns based on the current processing context. The construction crafts the query ($Q$) and key ($K$) matrices to exploit properties of sinusoidal encodings such as $p_{i+\ell} = R(\ell)\, p_i$.
- Continuous Thought as Superposition (Lemma 4.2): The central theoretical claim is that the $c$-th continuous thought vector $[t]_c$ is a normalized superposition of all vertices $V_c$ reachable from the root $r$ within $c$ steps:

  $[t]_c = \frac{1}{|V_c|} \sum_{v \in V_c} u_v$
This is achieved through a two-layer transformer (a toy numerical rendering of the Layer 2 step is given after the main theorem below):
- Layer 1 (Attention & MLP):
  - Attention: Uses five attention-chooser heads. For an edge token <e> at position Idx(<e>, i), these heads copy the source node $s_i$'s embedding into that position's buffer_1 and the target node $t_i$'s embedding into its buffer_2. Similar operations store the candidate nodes $c_1, c_2$ with the <R> token and the last thought $[t]_C$ with the <A> token.
  - MLP: Acts as a filter that cleans up attention noise and normalizes the copied information. It uses a structure of the form $W_2\,\sigma(W_1 x)$, where $W_1$ projects into a basis in which each token is a coordinate, $\sigma$ is the indicator $\mathbbm{1}\{x \geq \varepsilon\}$, and $W_2$ projects back.
- Layer 2 (Attention & MLP):
  - Attention (thought generation, Fig. 4): The current thought $[t]_c$ (a superposition of $V_c$) forms a query. It attends to all edge tokens <e> whose source node $s_i$ (stored in their buffer_1) lies in $V_c$; the value vectors retrieve the corresponding target nodes $t_i$ (from buffer_2 of these attended edge tokens). The sum of these target-node embeddings is added to $[t]_c$, which performs one step of breadth-first-search expansion.
  - MLP: Filters the resulting sum so that $[t]_{c+1}$ is a clean, normalized superposition of all nodes in $V_{c+1}$.
  - Attention (prediction, Fig. 7 in Appendix): When the <A> token is processed, its query (containing the <A> token embedding) attends to the <R> token, whose buffer_2 (prepared by Layer 1) contains $u_{c_1} + u_{c_2}$. The value from <R> is combined with the final thought $[t]_C$ (stored in <A>'s buffer_1 by Layer 1). The final output embedding $h_T^{(L)}$ at the <A> position therefore has a larger component on the reachable candidate $c_{i^*}$, because $c_{i^*} \in V_C$ while $c_{3-i^*} \notin V_C$. The decoding matrix $W_O = U^\top$ (where $U$ is the matrix of token embeddings) then picks out $c_{i^*}$.
- Main Theorem (Theorem 4.1): A two-layer transformer with parameters independent of the specific graph (but dependent on the vocabulary size, $d = O(|\mathrm{Voc}|)$) can solve graph reachability in $D$ continuous thought steps, where $D$ is the graph diameter.
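To make the construction concrete, the sketch below imitates the Layer 2 thought-generation step with hard (0/1) attention in place of softmax: the current superposition queries the edge tokens through their buffer_1 (source) contents, sums the buffer_2 (target) values of the matched edges, adds the result residually, and an MLP-like filter thresholds and renormalizes the support. This is an idealized rendering of Lemma 4.2 under one-hot node embeddings (the `bfs_expand` helper is ours), not the paper's exact parameterization or the trained model.

```python
import numpy as np

def bfs_expand(thought, edges, node_emb, eps=1e-6):
    """One idealized Layer-2 step: [t]_c (superposition of V_c) -> [t]_{c+1}.

    thought:  current superposition (average of node embeddings in V_c)
    edges:    list of (source, target) pairs, standing in for <e> tokens whose
              buffer_1 holds the source embedding and buffer_2 the target embedding
    node_emb: (num_nodes, dim) matrix with rows u_v, assumed orthonormal here
    """
    expanded = thought.copy()
    for s, t in edges:
        # "Attention": the query [t]_c has a large inner product with buffer_1 = u_s
        # exactly when s is in the current frontier set V_c.
        if thought @ node_emb[s] > eps:
            expanded = expanded + node_emb[t]      # "value" copied from buffer_2 = u_t
    # "MLP filter": keep every node with a non-negligible coefficient, then renormalize,
    # i.e. f(x) = sum_v 1{lambda_v >= eps} u_v followed by normalization.
    coeffs = node_emb @ expanded
    support = coeffs > eps
    return node_emb[support].sum(axis=0) / support.sum()

# Toy instance with one-hot (orthonormal) node embeddings; edges 0->1, 1->3, 0->2.
U = np.eye(5)
t_c = U[0]                                         # [t]_0 = u_root
for step in range(2):
    t_c = bfs_expand(t_c, [(0, 1), (1, 3), (0, 2)], U)
    print(f"[t]_{step + 1}:", np.round(t_c, 2))    # uniform over nodes reachable in <= step+1 hops
```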
Implementation Considerations
- Embedding Space: The construction uses separate "buffer" spaces within the embedding to shuttle information. In practice, these could be projected into a more compact space.
- Positional Encoding: The construction works with standard sinusoidal positional encoding and is adaptable to Rotary Position Embeddings (RoPE), as discussed in Appendix A.3; a small numerical check of the rotation property it relies on appears after this list.
- MLP as Filter: The MLP layers play a critical role in selecting relevant signals and normalizing states, effectively implementing $f(x) = \sum_{v} \mathbbm{1}\{\lambda_v \geq \varepsilon\} u_v$ after normalization.
- Computational Cost: Generating $D$ continuous thoughts involves $D$ full transformer forward passes. However, $D$ is typically much smaller than $O(n^2)$.
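The rotation property used by the attention choosers, $p_{i+\ell} = R(\ell)\, p_i$, can be verified numerically. The sketch below builds a standard sinusoidal encoding arranged as (sin, cos) pairs together with the corresponding block-diagonal rotation; the arrangement and the `sinusoidal_pe`/`rotation` helpers are ours, chosen to match the usual definition.

```python
import numpy as np

def sinusoidal_pe(pos: int, d: int) -> np.ndarray:
    """Standard sinusoidal encoding, arranged as (sin, cos) pairs per frequency."""
    k = np.arange(d // 2)
    omega = 1.0 / (10000 ** (2 * k / d))
    return np.stack([np.sin(pos * omega), np.cos(pos * omega)], axis=1).reshape(-1)

def rotation(ell: int, d: int) -> np.ndarray:
    """Block-diagonal R(ell) with one 2x2 rotation per frequency, so p_{i+ell} = R(ell) p_i."""
    k = np.arange(d // 2)
    theta = ell / (10000 ** (2 * k / d))
    R = np.zeros((d, d))
    for j, th in enumerate(theta):
        R[2 * j:2 * j + 2, 2 * j:2 * j + 2] = [[np.cos(th), np.sin(th)],
                                               [-np.sin(th), np.cos(th)]]
    return R

d, i, ell = 16, 7, 3
assert np.allclose(rotation(ell, d) @ sinusoidal_pe(i, d), sinusoidal_pe(i + ell, d))
```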
Experimental Validation
Experiments were conducted on a subset of the ProsQA dataset (graph reasoning requiring 3-4 hops).
- Model: A 2-layer GPT-2-style decoder ($d_{\mathrm{model}} = 768$, $n_{\mathrm{heads}} = 8$) trained from scratch.
- Training: A multi-stage curriculum in which stage $i$ trains the model to use $i$ continuous thoughts.
- Results:
- The 2-layer Coconut model achieved near-perfect accuracy (Fig. 5), significantly outperforming 2-layer discrete CoT (~75%) and even a 12-layer discrete CoT (83%).
- Layer 1 Attention (Fig. 6): Visualizations confirmed that edge tokens <e> attend to their source and target nodes, implementing the designed copying mechanism.
- Layer 2 Attention (Table 1 in the paper, renumbered as Table 2 in this document): When generating the $i$-th thought, attention scores were highest for "reachable" edges (source node in the current search set), particularly "frontier" edges (source node exactly $i$ hops away) and "optimal" edges (on the solution path). This supports the BFS-like expansion.
- Continuous Thought Representation (Fig. 7): The inner product between the $i$-th thought $[t]_i$ and the node embeddings $u_v$ was high for nodes reachable within $i$ steps, especially high for "frontier" nodes, and highest for "optimal" nodes. This directly visualizes a superposition state that emphasizes the search frontier.
- Exploration Priority: The model's tendency to focus on optimal paths was observed even when trained with a "Coconut-BFS" strategy (supervision from random frontier nodes, not just the optimal path). This suggests the superposition mechanism and training dynamics inherently learn efficient search strategies.
The following pseudocode (illustrative only, mirroring the construction above rather than the paper's released code) sketches the information flow when one continuous thought is generated:

```python
def generate_next_thought(h_prompt, prev_thoughts, V_prev, transformer_params):
    # Current input sequence to the transformer.
    current_sequence = h_prompt + prev_thoughts
    current_thought_position = len(current_sequence)  # position for [t]_c

    # --- Layer 1 ---
    h_layer1_out = [None] * len(current_sequence)
    # For each position in current_sequence:
    #   Apply Layer 1 attention (attention choosers):
    #   - for edge tokens <e>: copy source to buffer_1, target to buffer_2, e.g.
    #       h_layer1_out[Idx(<e>, i)].buffer_1 = embedding(source_node_of_edge_i)
    #       h_layer1_out[Idx(<e>, i)].buffer_2 = embedding(target_node_of_edge_i)
    #   - for the previous thought [t]_{c-1}: its content encodes V_prev
    #   Apply Layer 1 MLP (filter & normalize):
    #       h_layer1_out[j] = MLP1(Attn1(current_sequence[j]))   # simplified

    # --- Layer 2 ---
    # Focus on current_thought_position, which will become [t]_c.
    # The query for [t]_c is derived from content([t]_{c-1}), a sum over V_prev.
    q_current_thought = project_query(h_layer1_out[current_thought_position - 1].content)

    attended_target_nodes_sum = zero_vector()
    # For each edge token <e> in h_layer1_out:
    #   key_edge = project_key(h_layer1_out[Idx(<e>)].buffer_1)       # key from the edge's source node
    #   attention_score = softmax_sim(q_current_thought, key_edge)
    #   if the source node is in V_prev (high attention_score):
    #       value_edge = project_value(h_layer1_out[Idx(<e>)].buffer_2)  # value from the target node
    #       attended_target_nodes_sum += attention_score * value_edge

    # Combine with the previous thought's content (residual connection).
    h_current_thought_after_attn = (
        h_layer1_out[current_thought_position - 1].content + attended_target_nodes_sum
    )

    # Layer 2 MLP (filter & normalize to get V_curr): it selects nodes that are in V_prev
    # or are new targets from attended_target_nodes_sum, and normalizes them into the
    # superposition for V_curr (nodes reachable within c steps).
    # LayerNorm is applied after the MLP in the paper's Algorithm 1:
    #   next_thought_embedding = LayerNorm(MLP2(h_current_thought_after_attn))

    # In the actual transformer, the next thought is simply the output at the *last*
    # position after processing the entire sequence:
    #   next_thought_embedding = Transformer(current_sequence, transformer_params)[-1]
    #   return next_thought_embedding
    #
    # The pseudocode above unrolls the attention mechanism to make the information flow
    # for generating one thought explicit.
    pass
```
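A minimal sketch of how such a routine would be driven end to end (the readout with $W_O = U^\top$ follows the construction above; the `coconut_decode` name, the `transformer` callable, and the array shapes are our assumptions, not the paper's code):

```python
import numpy as np

def coconut_decode(transformer, prompt_embeddings, num_thoughts, answer_token_embedding, U):
    """Generate `num_thoughts` continuous thoughts, then read out the answer.

    `transformer` is any callable mapping a list of input embeddings to one output
    embedding per position; `U` has one row per vocabulary token embedding.
    """
    sequence = list(prompt_embeddings)
    for _ in range(num_thoughts):               # D steps suffice, D = graph diameter
        thought = transformer(sequence)[-1]     # hidden state at the last position
        sequence.append(thought)                # feed the continuous thought back, undecoded
    sequence.append(answer_token_embedding)     # append <A>
    scores = U @ transformer(sequence)[-1]      # decode with W_O = U^T
    return int(np.argmax(scores))               # index of the predicted answer token
```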
Practical Implications
- Efficient Reasoning: Continuous CoT can solve complex reasoning problems with shallower networks and fewer steps compared to discrete CoT.
- Parallel Search: The superposition mechanism allows LLMs to implicitly explore multiple reasoning paths in parallel, making them more robust for problems with large search spaces or branching factors.
- Model Design: The findings suggest that even simple 2-layer transformers, if structured correctly (e.g., with mechanisms for information routing like buffers and filtering like MLPs), can perform sophisticated reasoning when continuous thoughts are employed.
- Training Strategies: The multi-stage training curriculum, even with supervision only on optimal paths, can lead to the emergence of this parallel, superpositional search behavior.
Conclusion
The paper provides a strong theoretical and empirical case for the power of continuous thoughts in LLMs. By demonstrating that continuous thought vectors can maintain a superposition of reachable states, it explains how these models can perform efficient parallel searches. This work offers valuable insights for designing more capable and efficient reasoning systems. Future directions include deriving lower bounds for discrete CoT on such problems and further understanding the training dynamics that lead to the emergence of these exploration behaviors.