This paper introduces Quiet-STaR, a method that enables language models (LMs) to learn to generate internal "thoughts" or "rationales" at each token to improve their prediction of future text. This contrasts with prior work like Self-Taught Reasoner (STaR), which focused on learning reasoning for specific question-answering tasks. Quiet-STaR aims to teach LMs to reason in a more general and scalable way by leveraging the diverse reasoning implicit in large, unstructured text corpora.
The core idea is that an LM can improve its ability to predict upcoming tokens if it first generates an internal rationale explaining why those tokens might appear. The process is framed as the LM learning to "think before speaking."
Key challenges addressed by Quiet-STaR include:
- Computational Cost: Generating rationales at every token position can be prohibitively expensive.
- Initial Inability: Pre-trained LMs don't initially know how to generate or use internal thoughts effectively.
- Beyond Next-Token Prediction: Useful thoughts often explain longer-term dependencies, not just the immediately next token.
Quiet-STaR operates in three main steps:
- Think (Parallel Rationale Generation):
- At each token position $j$ in an input sequence $X = (x_1, \dots, x_n)$, the model generates several candidate rationales (thoughts), each of length $t$.
- Learned special tokens, <|startofthought|> and <|endofthought|>, are inserted to mark the beginning and end of each rationale.
- A crucial contribution is a parallel sampling algorithm. This allows for generating thoughts for all token positions simultaneously within a batch. It works by constructing a special attention mask (Figure 3) in which each generated thought token attends to itself, to the preceding tokens of the same thought, and to the preceding context text, but not to other "counterfactual" thought paths. Each inference call generates one additional thought token for all positions.
```python
# Pseudocode for Parallel Rationale Generation (simplified).
# model.get_hidden_states, model.lm_head and the helper functions used below are
# hypothetical placeholders, not a real library API.
def generate_thoughts_parallel(model, text_tokens, num_thoughts_per_pos, thought_length):
    batch_size, seq_len = text_tokens.shape[0], text_tokens.shape[1]
    # Every position's thought begins with <|startofthought|>. Conceptually this is a
    # prepend; in practice it is handled by masking and input construction.
    generated_thoughts = empty_tensor_for_thoughts(
        batch_size, seq_len, num_thoughts_per_pos, thought_length)
    for t_step in range(thought_length):
        partial_thoughts = generated_thoughts[:, :, :, :t_step]
        # Each thought token attends to its own prefix and to the original text up to
        # the position where its thought starts, never to other thoughts.
        attention_mask = create_parallel_attention_mask(text_tokens, partial_thoughts)
        # One forward pass over the text plus all partial thoughts
        # (hidden states from earlier steps can be cached).
        hidden_states = model.get_hidden_states(
            concat_text_and_thoughts(text_tokens, partial_thoughts),
            attention_mask=attention_mask)
        # Next-token logits at the last token of every parallel thought.
        next_token_logits = model.lm_head(select_last_thought_states(hidden_states))
        # Sample one token per thought; shape: (batch_size, seq_len, num_thoughts_per_pos)
        generated_thoughts[:, :, :, t_step] = sample_from_logits(next_token_logits)
    # Each thought is then closed with <|endofthought|>.
    return generated_thoughts
```
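The attention pattern described above can be made concrete with a minimal NumPy sketch of the mask. It covers the simplest case: one thought per position and no batching. The memory layout is an assumption made for illustration and is not the paper's implementation: all text tokens come first, followed by one block of `n_text` slots per thought step. The function name `parallel_thought_mask` is also an assumption. The <|startofthought|> token is treated as the first thought token.

```python
import numpy as np


def parallel_thought_mask(n_text: int, thought_steps: int) -> np.ndarray:
    """Boolean attention mask (True = may attend) for one thought per text position.

    Assumed layout: n_text text tokens first, then `thought_steps` blocks of
    n_text slots each; slot i of block s holds token s of the thought that
    starts after text position i.
    """
    total = n_text * (1 + thought_steps)
    mask = np.zeros((total, total), dtype=bool)
    # Text tokens: ordinary causal attention among themselves only.
    mask[:n_text, :n_text] = np.tril(np.ones((n_text, n_text), dtype=bool))
    for s in range(thought_steps):
        for i in range(n_text):
            row = n_text * (1 + s) + i
            # Context: text tokens up to and including position i.
            mask[row, : i + 1] = True
            # Same thought: tokens 0..s at position i (the last one is the token itself).
            for prev in range(s + 1):
                mask[row, n_text * (1 + prev) + i] = True
    return mask


m = parallel_thought_mask(n_text=3, thought_steps=2)
print(m.astype(int))
```

Take 3 text tokens and 2 thought steps. Row 7 is the second thought token at position 1. It may attend only to text tokens 0 and 1, to the first thought token at the same position (row 4), and to itself. Thoughts at other positions stay invisible to it. That is what lets all positions advance by one thought token per forward pass.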
- Talk (Mixing Predictions):
- After a thought is generated, the LM predicts the next text tokens based on the context including the thought.
- Simultaneously, the LM also has a prediction for the next text tokens without the thought (base prediction).
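- A "mixing head," a shallow multi-layer perceptron (MLP), is trained. It takes as input the hidden state at the <|endofthought|> token and the hidden state of the original text token (before the thought), and outputs a weight $w_j$. This weight interpolates between the log-probabilities of the base prediction $p^{\text{init}}_j$ and the post-rationale prediction $p^{\text{thought}}_j$, giving the mixed ("talk") prediction:
$$\log p^{\text{talk}}_j = w_j \cdot \log p^{\text{init}}_j + (1 - w_j) \cdot \log p^{\text{thought}}_j$$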
* This mixing helps stabilize training, especially early on when generated thoughts might be out-of-distribution and harm performance.
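The mixing step can be sketched in NumPy. This is a hypothetical version, and several details are assumptions:

- the layer sizes,
- the ReLU and sigmoid choices,
- the names `MixingHead` and `mixed_log_probs`,
- the convention that $w$ weights the base prediction.

The weights are random rather than trained. Log-probabilities differ from logits only by a per-row constant, so interpolating the logits and then applying a log-softmax gives the same distribution as interpolating the two log-probability vectors and renormalizing.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


class MixingHead:
    """Hypothetical shallow MLP: [h_thought; h_init] -> weight w in (0, 1)."""

    def __init__(self, d_model: int, d_hidden: int, rng: np.random.Generator):
        self.w1 = rng.normal(0.0, 0.02, size=(2 * d_model, d_hidden))
        self.b1 = np.zeros(d_hidden)
        self.w2 = rng.normal(0.0, 0.02, size=d_hidden)
        self.b2 = 0.0

    def __call__(self, h_thought: np.ndarray, h_init: np.ndarray) -> np.ndarray:
        x = np.concatenate([h_thought, h_init], axis=-1)
        hidden = np.maximum(x @ self.w1 + self.b1, 0.0)             # ReLU
        return 1.0 / (1.0 + np.exp(-(hidden @ self.w2 + self.b2)))  # sigmoid


def mixed_log_probs(w, logits_init, logits_thought):
    # w weights the base (no-thought) prediction, 1 - w the post-thought one.
    mixed = w[..., None] * logits_init + (1.0 - w[..., None]) * logits_thought
    return log_softmax(mixed)


rng = np.random.default_rng(0)
head = MixingHead(d_model=8, d_hidden=16, rng=rng)
h_thought, h_init = rng.normal(size=(5, 8)), rng.normal(size=(5, 8))
w = head(h_thought, h_init)  # one weight per position
log_p_talk = mixed_log_probs(w, rng.normal(size=(5, 10)), rng.normal(size=(5, 10)))
```

With $w = 1$ the mixed prediction falls back to the base LM, and with $w = 0$ it uses the post-thought prediction alone. This explains why a head that starts near the base prediction can keep early, out-of-distribution thoughts from hurting performance.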
- Learn (Optimizing Rationale Generation):
- The model is trained to generate better rationales using a REINFORCE-based algorithm.
- Non-myopic Scoring and Teacher Forcing: Consider a thought inserted at position $j$. Its "goodness" is judged not only by how well it helps predict the immediate next token ($x_{j+1}$). It is also judged on a sequence of $n_{true}$ future ground-truth tokens ($x_{j+1}, \dots, x_{j+n_{true}}$). Teacher forcing is used here: when calculating the probability of $x_{j+k}$, the true tokens $x_{j+1}, \dots, x_{j+k-1}$ are assumed to have been generated. This is visualized in Figure 4.
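* Reward Definition: For each token position $j$, multiple rationales are generated. The reward for a specific rationale $T_j$ is how much it improves the log-likelihood of the next $n_{true}$ ground-truth tokens under the mixed ("talk") prediction $p^{\text{talk}}_j$. This improvement is measured relative to the average $\bar{p}^{\text{talk}}_j$ over all rationales generated for that position:
$$r_j = \log p^{\text{talk}}_j(x_{j+1:j+n_{true}}) - \log \bar{p}^{\text{talk}}_j(x_{j+1:j+n_{true}})$$
* REINFORCE Update: The gradient for the rationale generation is:
$$\nabla_\theta \mathcal{L}^{\text{REINFORCE}}_j = -r_j \cdot \nabla_\theta \log p_\theta(T_j \mid x_{1:j})$$
Here $p_\theta(T_j \mid x_{1:j})$ is the probability of generating rationale $T_j$ after the context $x_{1:j}$ and the <|startofthought|> token.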
The paper notes that only positive rewards were used for stability.
* Meta-Token Optimization: The embeddings for <|startofthought|> and <|endofthought|> are also learned. They are initialized with the embedding of "---" (em dash) to leverage the model's existing knowledge of pauses. Their gradients are weighted more heavily to speed up learning.
* Overall Loss: The total loss combines the NLL loss from the mixed prediction and the REINFORCE loss: $\mathcal{L} = \mathcal{L}^{\text{NLL}} + \mathcal{L}^{\text{REINFORCE}}$. The NLL loss trains the mixing head and provides a learning signal to the base LM.
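The reward and the positive-only REINFORCE objective for a single position can be sketched numerically. The sketch makes two assumptions. First, `future_token_logp[k, m]` holds the teacher-forced log-probability of the m-th true future token given rationale k. Second, the baseline is the mean log-likelihood over the rationales sampled for that position, which is one reading of "average". The function names are hypothetical. NumPy has no autograd, so `reinforce_loss` returns the scalar an autograd framework would differentiate, with rewards held constant.

```python
import numpy as np


def thought_rewards(future_token_logp: np.ndarray) -> np.ndarray:
    """Reward per rationale at one position.

    future_token_logp[k, m]: teacher-forced log p_talk of the m-th true future
    token given rationale k (shape: num_rationales x n_true).
    """
    seq_ll = future_token_logp.sum(axis=-1)  # non-myopic score of each rationale
    baseline = seq_ll.mean()                 # assumed baseline: mean over rationales
    return seq_ll - baseline


def reinforce_loss(rewards: np.ndarray, thought_logp: np.ndarray) -> float:
    """thought_logp[k]: summed log p_theta of rationale k's tokens.

    Rewards are treated as constants (no gradient flows through them), and only
    rationales with a positive reward contribute.
    """
    weights = np.where(rewards > 0, rewards, 0.0)
    return float(-(weights * thought_logp).sum())


# Three rationales at one position, n_true = 2 future tokens each.
future_logp = np.log(np.array([[0.5, 0.5], [0.25, 0.5], [0.5, 1.0]]))
r = thought_rewards(future_logp)
loss = reinforce_loss(r, thought_logp=np.array([-1.0, -2.0, -3.0]))
```

In this example the third rationale makes the true continuation twice as likely as the baseline, giving a reward of $\log 2$. The first rationale sits at the baseline and the second is worse, so only the third contributes. Minimizing the loss would raise its log-probability. The overall objective would add this term to the NLL of the mixed prediction.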
The algorithm is detailed in Algorithm 1 of the paper.
Experiments and Results:
- The method was applied to a Mistral 7B model.
- Training was primarily on OpenWebMath and also on C4.
- Downstream Task Performance: Quiet-STaR showed zero-shot improvements on:
- GSM8K (math word problems): Accuracy increased from 5.9% (baseline) to 10.9%.
- CommonsenseQA: Accuracy increased from 36.3% (baseline) to 47.2%.
- These improvements generally increased with the number of thought tokens used during Quiet-STaR training (Figure 2).
- Training on C4 also showed improvements, but to a lesser extent: GSM8K rose from 5.9% to 8.1% and CommonsenseQA from 36.3% to 42.6%.
- Improvement Distribution: Generated thoughts disproportionately helped predict difficult-to-predict tokens, while most tokens saw little change (Figure 5). Figure 6 visualizes where thoughts helped in an example text, suggesting benefits in recalling relevant information or structuring next steps.
- Comparison to Pause Tokens: Quiet-STaR's multi-token rationales were found to be more effective than single "pause" tokens, which showed minor gains or even performance degradation on the same tasks.
Discussion and Analysis:
- Training Instability: A key challenge is the co-adaptation of the thought generator and the mixing head: if the mixing head ignores thoughts, the generator gets no learning signal. Alternatives explored included Gumbel-Softmax, which suffered from vanishing gradients, and more complex RL methods, which suffered from unstable reward functions. The chosen approach, a simple mixing head with REINFORCE on positive rewards only, proved more stable.
- Interpretable Thoughts: While not explicitly optimized for human readability, generated thoughts were often partially understandable. Examples show thoughts recalling necessary preceding information or near-continuations of the target text.
- Quiet-STaR vs. Chain-of-Thought (CoT): Quiet-STaR is orthogonal to CoT. CoT is explicit, prompted reasoning "out loud." Quiet-STaR is implicit, internal thinking at each token. They could be complementary (e.g., using Quiet-STaR during CoT generation).
Limitations:
- The paper used a pre-trained model; performance when training from scratch is unknown.
- Only applied to a 7B parameter model; larger models might show greater benefits.
- Significant computational overhead due to generating many thought tokens.
- The current implementation doesn't dynamically decide when to think or for how long.
Conclusion:
Quiet-STaR demonstrates a promising approach for LMs to learn general reasoning skills from unstructured text. It improves downstream reasoning without task-specific fine-tuning. Future work could explore ensembling thoughts, dynamic computation allocation for thinking, and applying it to larger models.
Practical Implementation Considerations:
- Parallel Rationale Generation: This is key for making the approach scalable. Implementing the custom attention masks efficiently is important. Appendix B.1 suggests optimizations like elementwise dot-products for diagonal attention.
- Meta-Tokens: The <|startofthought|> and <|endofthought|> tokens are crucial. Initializing them thoughtfully (e.g., with "---") and applying a higher learning rate to their embeddings can accelerate training.
- Mixing Head: A simple MLP for the mixing head helps with stability. Its role is to smoothly integrate thoughts without disrupting the base LM's capabilities too early.
- Non-Myopic Loss: Using future tokens for the reward signal helps generate more meaningful, less noisy rationales.
- REINFORCE: Using only positive rewards ($r_j > 0$) and averaging rewards over multiple rationale samples per position can stabilize the REINFORCE training.
- Computational Resources: Training requires significant resources (e.g., 8x 80GB H100s mentioned for experiments). The overhead comes from generating thought tokens for (potentially) many positions in a sequence.
- Hyperparameters: The following need careful tuning: learning rates, batch size, thought length ($t$), the number of future tokens used for supervision ($n_{true}$), and the number of thoughts sampled per position. Appendix A provides some of the hyperparameters used.
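One plausible reading of the Appendix B.1 suggestion about diagonal attention: a thought token at position i attends to earlier thought tokens only at that same position i. Its scores against those keys are therefore per-position dot products, not a full n-by-n matrix product. The NumPy sketch below checks an elementwise version against the naive one. The function names and shapes are assumptions. In a full implementation these scores would be combined with the masked scores against the text keys before the softmax, which is not shown here.

```python
import numpy as np


def same_position_scores_naive(q: np.ndarray, k_steps: np.ndarray) -> np.ndarray:
    # Full n x n score matrix per earlier thought step; only the diagonal is kept.
    return np.stack([np.diag(q @ k.T) for k in k_steps])


def same_position_scores(q: np.ndarray, k_steps: np.ndarray) -> np.ndarray:
    """q: (n, d) queries of the current thought step, one per text position.
    k_steps: (s, n, d) keys of earlier thought steps at the same positions.
    Returns (s, n) scores with entry [j, i] = q[i] . k_steps[j, i], using
    O(s * n * d) work instead of O(s * n^2 * d)."""
    return np.einsum("nd,snd->sn", q, k_steps)


rng = np.random.default_rng(1)
q = rng.normal(size=(6, 4))
k_steps = rng.normal(size=(3, 6, 4))
scores = same_position_scores(q, k_steps)
```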
The paper provides a strong foundation for building LMs that can "think" more deeply about the text they process and generate, moving beyond simple pattern matching towards more robust reasoning.