
Quiet-STaR: Language Models Can Teach Themselves to Think Before Speaking (2403.09629v2)

Published 14 Mar 2024 in cs.CL, cs.AI, and cs.LG

Abstract: When writing and talking, people sometimes pause to think. Although reasoning-focused works have often framed reasoning as a method of answering questions or completing agentic tasks, reasoning is implicit in almost all written text. For example, this applies to the steps not stated between the lines of a proof or to the theory of mind underlying a conversation. In the Self-Taught Reasoner (STaR, Zelikman et al. 2022), useful thinking is learned by inferring rationales from few-shot examples in question-answering and learning from those that lead to a correct answer. This is a highly constrained setting -- ideally, a LLM could instead learn to infer unstated rationales in arbitrary text. We present Quiet-STaR, a generalization of STaR in which LMs learn to generate rationales at each token to explain future text, improving their predictions. We address key challenges, including 1) the computational cost of generating continuations, 2) the fact that the LM does not initially know how to generate or use internal thoughts, and 3) the need to predict beyond individual next tokens. To resolve these, we propose a tokenwise parallel sampling algorithm, using learnable tokens indicating a thought's start and end, and an extended teacher-forcing technique. Encouragingly, generated rationales disproportionately help model difficult-to-predict tokens and improve the LM's ability to directly answer difficult questions. In particular, after continued pretraining of an LM on a corpus of internet text with Quiet-STaR, we find zero-shot improvements on GSM8K (5.9%$\rightarrow$10.9%) and CommonsenseQA (36.3%$\rightarrow$47.2%) and observe a perplexity improvement of difficult tokens in natural text. Crucially, these improvements require no fine-tuning on these tasks. Quiet-STaR marks a step towards LMs that can learn to reason in a more general and scalable way.

This paper introduces Quiet-STaR, a method that enables language models (LMs) to learn to generate internal "thoughts" or "rationales" at each token to improve their prediction of future text. This contrasts with prior work such as the Self-Taught Reasoner (STaR), which focused on learning reasoning for specific question-answering tasks. Quiet-STaR aims to teach LMs to reason in a more general and scalable way by leveraging the diverse reasoning implicit in large, unstructured text corpora.

The core idea is that an LM can improve its ability to predict upcoming tokens if it first generates an internal rationale explaining why those tokens might appear. The process is framed as the LM learning to "think before speaking."

Key challenges addressed by Quiet-STaR include:

  1. Computational Cost: Generating rationales at every token position can be prohibitively expensive.
  2. Initial Inability: Pre-trained LMs don't initially know how to generate or use internal thoughts effectively.
  3. Beyond Next-Token Prediction: Useful thoughts often explain longer-term dependencies, not just the immediately next token.

Quiet-STaR operates in three main steps:

  1. Think (Parallel Rationale Generation):
    • At each token position $x_i$ in an input sequence $x_{0:n}$, the model generates $r$ candidate rationales (thoughts) of length $t$.
    • Learned special tokens, <|startofthought|> and <|endofthought|>, are inserted to mark the beginning and end of each rationale.
    • A crucial contribution is a parallel sampling algorithm, which generates thoughts for all token positions simultaneously within a batch. It works by constructing a special attention mask (Figure 3) in which each generated thought token attends to itself, to the preceding thought tokens within the same thought, and to the preceding context text, but not to other "counterfactual" thought paths. Each inference call generates one additional thought token for all positions; see the pseudocode and the mask sketch below.

# Pseudocode for Parallel Rationale Generation (simplified)
function generate_thoughts_parallel(model, text_tokens, num_thoughts_per_pos, thought_length):
    batch_size, seq_len = text_tokens.shape

    # Encode the original text once; its key/value states are cached and reused.
    _, kv_cache = model.get_hidden_states(text_tokens)

    # Every position's thought begins with the learned <|startofthought|> token.
    last_thought_tokens = full_tensor(
        batch_size, seq_len, num_thoughts_per_pos, fill=START_OF_THOUGHT_ID)

    # Thought tokens generated so far for every (position, sampled-thought) pair.
    generated_thoughts = empty_tensor_for_thoughts(
        batch_size, seq_len, num_thoughts_per_pos, thought_length)

    for t_step in range(thought_length):
        # Attention mask for parallel generation: each thought token attends to its
        # own thought prefix and to the text up to its start position, but not to
        # other positions' ("counterfactual") thought paths.
        attention_mask = create_parallel_attention_mask(
            text_tokens, generated_thoughts[:, :, :, :t_step])

        # Hidden states for the newest token of every thought, reusing the cache.
        current_hidden_states, kv_cache = model.get_hidden_states(
            last_thought_tokens, attention_mask=attention_mask, past_key_values=kv_cache)

        # Next-token logits for all thoughts at all positions in one forward pass.
        next_token_logits = model.lm_head(current_hidden_states)

        # Sample the next thought token for every (position, thought) pair in parallel.
        # Shape: (batch_size, seq_len, num_thoughts_per_pos)
        next_thought_tokens = sample_from_logits(next_token_logits)
        generated_thoughts[:, :, :, t_step] = next_thought_tokens
        last_thought_tokens = next_thought_tokens

    # Each thought is then terminated with the learned <|endofthought|> token.
    return generated_thoughts
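
The custom attention mask is what makes this parallel pass work. Below is a minimal, self-contained sketch of how such a mask could be built in PyTorch for a single decoding step; the function name build_parallel_thought_mask and the key layout (text keys followed by per-step thought keys) are illustrative assumptions, not the paper's exact implementation.

import torch

def build_parallel_thought_mask(seq_len: int, n_thought_steps: int) -> torch.Tensor:
    """Boolean attention mask (True = may attend) for one parallel decoding step.

    Queries: the thought token currently being generated at every text position.
    Keys:    the original text tokens (seq_len keys), followed by n_thought_steps
             blocks of seq_len keys holding the thought tokens generated so far.
    """
    # A thought started at position i may attend to text tokens 0..i (causal prefix).
    text_part = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # It may also attend to its *own* earlier thought tokens, but not to thoughts
    # generated for other positions ("counterfactual" paths); this produces the
    # diagonal structure that Appendix B.1's optimizations exploit.
    own_thoughts = torch.eye(seq_len, dtype=torch.bool)
    thought_part = own_thoughts.repeat(1, n_thought_steps)
    return torch.cat([text_part, thought_part], dim=1)

# Example: 5 text positions, 2 thought tokens already generated per position.
mask = build_parallel_thought_mask(seq_len=5, n_thought_steps=2)
print(mask.shape)  # torch.Size([5, 15])

In the full method the mask also has to cover multiple sampled thoughts per position and the batch dimension, which the paper handles with the optimizations described in Appendix B.1.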

  2. Talk (Mixing Predictions):
    • After a thought is generated, the LM predicts the next text tokens based on the context including the thought.
    • Simultaneously, the LM also has a prediction for the next text tokens without the thought (base prediction).
    • A "mixing head," a shallow Multi-Layer Perceptron (MLP), is trained. It takes the hidden state from the <|endofthought|> token and the hidden state of the original text token (before the thought) as input.
    • The mixing head outputs a weight $w_j$. This weight is used to interpolate between the logits of the base prediction and the post-rationale prediction:

      $\log p_j^{\mathrm{talk}} = w_j \cdot \log p_{j}^{\mathrm{init}} + (1 - w_j) \cdot \log p_{j}^{\mathrm{thought}}$

    • This mixing helps stabilize training, especially early on when generated thoughts might be out-of-distribution and harm performance.
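
A minimal sketch of such a mixing head and the log-space interpolation is below, assuming PyTorch; the hidden-state inputs (h_end_thought, h_base), the MLP width, and the tensor shapes are illustrative choices rather than the paper's reported architecture.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MixingHead(nn.Module):
    """Shallow MLP that outputs an interpolation weight w in [0, 1]."""

    def __init__(self, hidden_size: int, mlp_size: int = 256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * hidden_size, mlp_size),
            nn.ReLU(),
            nn.Linear(mlp_size, 1),
        )

    def forward(self, h_end_thought: torch.Tensor, h_base: torch.Tensor) -> torch.Tensor:
        # h_end_thought: hidden state at <|endofthought|>; h_base: hidden state of
        # the original text token. Both: (batch, hidden_size).
        return torch.sigmoid(self.mlp(torch.cat([h_end_thought, h_base], dim=-1)))  # (batch, 1)

def mix_predictions(logits_base: torch.Tensor, logits_thought: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """log p_talk = w * log p_init + (1 - w) * log p_thought, per the mixing rule above."""
    log_p_init = F.log_softmax(logits_base, dim=-1)
    log_p_thought = F.log_softmax(logits_thought, dim=-1)
    return w * log_p_init + (1.0 - w) * log_p_thought

Because the weight is produced per position, the model can fall back to the base prediction (w close to 1) wherever a thought is unhelpful, which is what stabilizes early training.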

  3. Learn (Optimizing Rationale Generation):
    • The model is trained to generate better rationales using a REINFORCE-based algorithm.
    • Non-myopic Scoring and Teacher Forcing: The "goodness" of a thought is judged not only by how well it helps predict the immediate next token ($X_{j+1}$) but by how well it helps predict a sequence of $n_{true}$ future ground-truth tokens ($X_{j+1:j+n_{true}+1}$). Teacher forcing is used here: when calculating the probability of $X_{j+k}$, the true tokens $X_{j+1}, \dots, X_{j+k-1}$ are assumed to have been generated. This is visualized in Figure 4.
    • Reward Definition: For each token position $j$, multiple rationales $T_j$ are generated. The reward $r_j$ for a specific rationale is its ability to improve the log-likelihood of the $n_{true}$ future tokens relative to the average over all rationales generated for that position:

      $r_j = \log p_{j:j+n_{true}}^{\mathrm{talk}}(X_{j+1:j+n_{true}+1}) - \log \overline{p}_{j:j+n_{true}}^{\mathrm{talk}}(X_{j+1:j+n_{true}+1})$

    • REINFORCE Update: The gradient for rationale generation is:

      $\nabla_\theta \mathcal{L}_j^{\mathrm{REINFORCE}} = -r_j \cdot \mathbb{1}[r_j > 0] \cdot \nabla_\theta \log p_{\theta}(T_j \mid [X_{:j}; \texttt{<|startofthought|>}])$

      The paper notes that only positive rewards were used, for stability.
    • Meta-Token Optimization: The embeddings for <|startofthought|> and <|endofthought|> are also learned. They are initialized with the embedding of "---" (em dash) to leverage existing knowledge of pauses, and their gradients are weighted more heavily to speed up learning.
    • Overall Loss: The total loss combines the NLL loss from the mixed prediction with the REINFORCE loss: $\nabla_\theta \mathcal{L}_j = \nabla_\theta \mathcal{L}_j^{\mathrm{NLL}} + \nabla_\theta \mathcal{L}_j^{\mathrm{REINFORCE}}$. The NLL loss trains the mixing head and also provides a signal to the base LM.
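
A minimal sketch of the reward and REINFORCE loss defined above, assuming PyTorch. The tensor names and shapes are illustrative, the teacher-forced log-likelihoods are assumed to be computed elsewhere, and the baseline is taken in log space for simplicity (the formula above averages the talk probabilities across rationales before taking the log).

import torch

def quiet_star_reinforce_loss(
    log_p_talk: torch.Tensor,            # (batch, num_thoughts, n_true): teacher-forced
                                         # log-likelihoods of the n_true true future tokens
                                         # under each sampled rationale's mixed prediction
    log_p_thought_tokens: torch.Tensor,  # (batch, num_thoughts): log p_theta of each
                                         # sampled rationale given the preceding text
) -> torch.Tensor:
    # Score each rationale by the total log-likelihood of the future true tokens.
    scores = log_p_talk.sum(dim=-1)                   # (batch, num_thoughts)
    # Baseline: average over the rationales sampled at this position.
    baseline = scores.mean(dim=-1, keepdim=True)      # (batch, 1)
    rewards = (scores - baseline).detach()            # r_j; no gradient through the reward
    # Keep only positive rewards, i.e. r_j * 1[r_j > 0], which the paper found more stable.
    rewards = rewards.clamp(min=0.0)
    # REINFORCE: increase the probability of rationales that earned a positive reward.
    return -(rewards * log_p_thought_tokens).mean()

In practice this loss would be added to the NLL loss of the mixed prediction, matching the overall gradient given above.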

The algorithm is detailed in Algorithm 1 of the paper.

Experiments and Results:

  • The method was applied to a Mistral 7B model.
  • Training was primarily on OpenWebMath and also on C4.
  • Downstream Task Performance: Quiet-STaR showed zero-shot improvements on:
    • GSM8K (math word problems): Accuracy increased from 5.9% (baseline) to 10.9%.
    • CommonsenseQA: Accuracy increased from 36.3% (baseline) to 47.2%.
    • These improvements generally increased with the number of thought tokens used during Quiet-STaR training (Figure 2).
    • Training on C4 also showed improvements, but to a lesser extent: GSM8K (5.9% → 8.1%) and CommonsenseQA (36.3% → 42.6%).
  • Improvement Distribution: Generated thoughts disproportionately helped predict difficult-to-predict tokens, while most tokens saw little change (Figure 5). Figure 6 visualizes where thoughts helped in an example text, suggesting benefits in recalling relevant information or structuring next steps.
  • Comparison to Pause Tokens: Quiet-STaR's multi-token rationales were found to be more effective than single "pause" tokens, which showed minor gains or even performance degradation on the same tasks.

Discussion and Analysis:

  • Training Instability: A key challenge is the co-adaptation of the thought generator and the mixing head: if the mixing head learns to ignore thoughts, the generator receives no learning signal. Alternatives explored included Gumbel-Softmax, which suffered from vanishing gradients, and more complex RL methods, whose reward functions were unstable. The chosen approach of a simple mixing head and REINFORCE with only positive rewards proved more stable.
  • Interpretable Thoughts: While not explicitly optimized for human readability, generated thoughts were often partially understandable. Examples show thoughts recalling necessary preceding information or near-continuations of the target text.
  • Quiet-STaR vs. Chain-of-Thought (CoT): Quiet-STaR is orthogonal to CoT. CoT is explicit, prompted reasoning "out loud." Quiet-STaR is implicit, internal thinking at each token. They could be complementary (e.g., using Quiet-STaR during CoT generation).

Limitations:

  • The paper used a pre-trained model; performance when training from scratch is unknown.
  • Only applied to a 7B parameter model; larger models might show greater benefits.
  • Significant computational overhead due to generating many thought tokens.
  • The current implementation doesn't dynamically decide when to think or for how long.

Conclusion:

Quiet-STaR demonstrates a promising approach for LMs to learn general reasoning skills from unstructured text. It improves downstream reasoning without task-specific fine-tuning. Future work could explore ensembling thoughts, dynamic computation allocation for thinking, and applying it to larger models.

Practical Implementation Considerations:

  • Parallel Rationale Generation: This is key for making the approach scalable. Implementing the custom attention masks efficiently is important. Appendix B.1 suggests optimizations like elementwise dot-products for diagonal attention.
  • Meta-Tokens: The <|startofthought|> and <|endofthought|> tokens are crucial. Initializing them thoughtfully (e.g., with "---") and applying a higher learning rate to their embeddings can accelerate training.
  • Mixing Head: A simple MLP for the mixing head helps with stability. Its role is to smoothly integrate thoughts without disrupting the base LM's capabilities too early.
  • Non-Myopic Loss: Using $n_{true} > 1$ future tokens for the reward signal helps generate more meaningful, less noisy rationales.
  • REINFORCE: Using only positive rewards ($r_j \cdot \mathbb{1}[r_j > 0]$) and averaging rewards over multiple rationale samples per position can stabilize the REINFORCE training.
  • Computational Resources: Training requires significant resources (e.g., 8x 80GB H100s mentioned for the experiments). The overhead comes from generating $t$ thought tokens for (potentially) many positions in a sequence.
  • Hyperparameters: Careful tuning of learning rates, batch size, thought length ($t$), number of future tokens for supervision ($n_{true}$), and the number of thoughts sampled per position is necessary. Appendix A provides some hyperparameters used.
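
As a concrete illustration of the meta-token points above, here is a minimal sketch of registering the two special tokens, initializing them from the "---" embedding, and upweighting their gradients, assuming the Hugging Face Transformers API. The model ID, the gradient multiplier, and the use of a gradient hook (rather than a separate optimizer group) are assumptions for illustration, not details taken from the paper.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model ID shown only for illustration; the paper uses Mistral 7B.
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

# 1) Register the two meta-tokens and grow the embedding matrix accordingly.
num_added = tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<|startofthought|>", "<|endofthought|>"]}
)
model.resize_token_embeddings(len(tokenizer))

# 2) Initialize both new embedding rows from the embedding of "---"
#    (averaged if the tokenizer splits it into several pieces).
embedding = model.get_input_embeddings().weight
dash_ids = tokenizer.encode("---", add_special_tokens=False)
with torch.no_grad():
    embedding[-num_added:] = embedding[dash_ids].mean(dim=0)

# 3) Upweight the gradients of the two new rows. The multiplier below is an
#    assumed value; the paper only states that these embeddings are given a
#    larger effective learning rate.
META_TOKEN_GRAD_SCALE = 100.0
new_token_ids = torch.arange(len(tokenizer) - num_added, len(tokenizer))

def scale_meta_token_grads(grad: torch.Tensor) -> torch.Tensor:
    grad = grad.clone()
    grad[new_token_ids] *= META_TOKEN_GRAD_SCALE
    return grad

embedding.register_hook(scale_meta_token_grads)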

The paper provides a strong foundation for building LMs that can "think" more deeply about the text they process and generate, moving beyond simple pattern matching towards more robust reasoning.

References (75)
  1. Thinking fast and slow with deep learning and tree search. Advances in neural information processing systems, 30, 2017.
  2. Fireact: Toward language agent fine-tuning. arXiv preprint arXiv:2310.05915, 2023.
  3. Scaling instruction-finetuned language models. arXiv preprint arXiv:2210.11416, 2022.
  4. Training verifiers to solve math word problems. arXiv preprint arXiv:2110.14168, 2021.
  5. Strategic reasoning with language models. arXiv preprint arXiv:2305.19165, 2023.
  6. Are we modeling the task or the annotator? an investigation of annotator bias in natural language understanding datasets. arXiv preprint arXiv:1908.07898, 2019.
  7. Think before you speak: Training language models with pause tokens. arXiv preprint arXiv:2310.02226, 2023.
  8. Reinforced self-training (rest) for language modeling. arXiv preprint arXiv:2308.08998, 2023.
  9. Textbooks are all you need. arXiv preprint arXiv:2306.11644, 2023.
  10. Language models can teach themselves to program better. In The Eleventh International Conference on Learning Representations, 2023. URL https://openreview.net/forum?id=SaRj2ka1XZ3.
  11. Backpack language models. arXiv preprint arXiv:2305.16765, 2023.
  12. Large language models are reasoning teachers. arXiv preprint arXiv:2212.10071, 2022.
  13. Training chain-of-thought via latent-variable inference. Advances in Neural Information Processing Systems, 36, 2024.
  14. V-star: Training verifiers for self-taught reasoners. arXiv preprint arXiv:2402.06457, 2024.
  15. Distilling step-by-step! outperforming larger language models with less training data and smaller model sizes. arXiv preprint arXiv:2305.02301, 2023.
  16. Large language models can self-improve. arXiv preprint arXiv:2210.11610, 2022.
  17. Categorical reparameterization with gumbel-softmax. arXiv preprint arXiv:1611.01144, 2016.
  18. Mistral 7b. arXiv preprint arXiv:2310.06825, 2023.
  19. Discrete prompt compression with reinforcement learning. arXiv preprint arXiv:2308.08758, 2023.
  20. Demonstrate-search-predict: Composing retrieval and language models for knowledge-intensive nlp. arXiv preprint arXiv:2212.14024, 2022.
  21. Dspy: Compiling declarative language model calls into self-improving pipelines. arXiv preprint arXiv:2310.03714, 2023.
  22. Large Language Models are Zero-Shot Reasoners, 2022. URL https://arxiv.org/abs/2205.11916.
  23. Can language models learn from explanations in context? arXiv preprint arXiv:2204.02329, 2022.
  24. Learning to reason and memorize with self-notes. Advances in Neural Information Processing Systems, 36, 2024.
  25. The power of scale for parameter-efficient prompt tuning. arXiv preprint arXiv:2104.08691, 2021.
  26. Solving quantitative reasoning problems with language models. Advances in Neural Information Processing Systems, 35:3843–3857, 2022.
  27. Automated statistical model discovery with language models. arXiv preprint arXiv:2402.17879, 2024.
  28. Explanations from large language models make small reasoners better. arXiv preprint arXiv:2210.06726, 2022.
  29. Prefix-tuning: Optimizing continuous prompts for generation. arXiv preprint arXiv:2101.00190, 2021.
  30. Compressing context to enhance inference efficiency of large language models. arXiv preprint arXiv:2310.06201, 2023.
  31. Crystal: Introspective reasoners reinforced with self-feedback. arXiv preprint arXiv:2310.04921, 2023.
  32. Wizardmath: Empowering mathematical reasoning for large language models via reinforced evol-instruct. arXiv preprint arXiv:2308.09583, 2023.
  33. Self-refine: Iterative refinement with self-feedback, 2023.
  34. Playing atari with deep reinforcement learning. arXiv preprint arXiv:1312.5602, 2013.
  35. Asynchronous methods for deep reinforcement learning. In International conference on machine learning, pp.  1928–1937. PMLR, 2016.
  36. Learning to compress prompts with gist tokens. Advances in Neural Information Processing Systems, 36, 2024.
  37. Show your work: Scratchpads for intermediate computation with language models. arXiv preprint arXiv:2112.00114, 2021.
  38. Feedback loops with language models drive in-context reward hacking. arXiv preprint arXiv:2402.06627, 2024.
  39. Openwebmath: An open dataset of high-quality mathematical web text. arXiv preprint arXiv:2310.06786, 2023.
  40. Training chain-of-thought via latent-variable inference. In Thirty-seventh Conference on Neural Information Processing Systems, 2023.
  41. Certified reasoning with language models. arXiv preprint arXiv:2306.04031, 2023.
  42. Generative language modeling for automated theorem proving. CoRR, abs/2009.03393, 2020. URL https://arxiv.org/abs/2009.03393.
  43. Why think step by step? reasoning emerges from the locality of experience. Advances in Neural Information Processing Systems, 36, 2024.
  44. Autoact: Automatic agent learning from scratch via self-planning. arXiv preprint arXiv:2401.05268, 2024.
  45. Phenomenal yet puzzling: Testing inductive reasoning capabilities of language models with hypothesis refinement. arXiv preprint arXiv:2310.08559, 2023.
  46. Language models are unsupervised multitask learners. OpenAI blog, 1(8):9, 2019.
  47. Exploring the limits of transfer learning with a unified text-to-text transformer. Journal of machine learning research, 21(140):1–67, 2020.
  48. Explain yourself! leveraging language models for commonsense reasoning. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pp.  4932–4942, 2019.
  49. Toolformer: Language models can teach themselves to use tools. Advances in Neural Information Processing Systems, 36, 2024.
  50. Proximal policy optimization algorithms. arXiv preprint arXiv:1707.06347, 2017.
  51. Programming Puzzles. In Thirty-fifth Conference on Neural Information Processing Systems, 2021. URL https://openreview.net/forum?id=fe_hCc4RBrg.
  52. Reflexion: Language agents with verbal reinforcement learning. arXiv preprint arXiv:2303.11366, 2023.
  53. Unsupervised commonsense question answering with self-talk. In Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP), pp.  4615–4629, 2020.
  54. Mastering chess and shogi by self-play with a general reinforcement learning algorithm. arXiv preprint arXiv:1712.01815, 2017.
  55. Commonsenseqa: A question answering challenge targeting commonsense knowledge. arXiv preprint arXiv:1811.00937, 2018.
  56. Function vectors in large language models. arXiv preprint arXiv:2310.15213, 2023.
  57. Solving math word problems with process-and outcome-based feedback. Neural Information Processing Systems (NeurIPS 2022) Workshop on MATH-AI, 2022.
  58. Hypothesis search: Inductive reasoning with language models. arXiv preprint arXiv:2309.05660, 2023.
  59. Chain-of-thought reasoning without prompting. arXiv preprint arXiv:2402.10200, 2024.
  60. Language modelling as a multi-task problem. arXiv preprint arXiv:2101.11287, 2021.
  61. Finetuned language models are zero-shot learners. In International Conference on Learning Representations, 2021a.
  62. Finetuned language models are zero-shot learners. arXiv preprint arXiv:2109.01652, 2021b.
  63. Emergent abilities of large language models, October 2022a. URL http://arxiv.org/abs/2206.07682.
  64. Chain of Thought Prompting Elicits Reasoning in Large Language Models, 2022b. URL https://arxiv.org/abs/2201.11903.
  65. Ronald J Williams. Simple statistical gradient-following algorithms for connectionist reinforcement learning. Machine learning, 8:229–256, 1992.
  66. React: Synergizing reasoning and acting in language models. International Conference on Learning Representations (ICLR 2023), 2022.
  67. Star: Bootstrapping reasoning with reasoning. Advances in Neural Information Processing Systems, 35:15476–15488, 2022.
  68. Parsel: Algorithmic reasoning with language models by composing decompositions, 2023a.
  69. Self-taught optimizer (stop): Recursively self-improving code generation. arXiv preprint arXiv:2310.02304, 2023b.
  70. Chain-of-thought reasoning is a policy improvement operator. arXiv preprint arXiv:2309.08589, 2023.
  71. In-context principle learning from mistakes. arXiv preprint arXiv:2402.05403, 2024.
  72. Automatic chain of thought prompting in large language models. arXiv preprint arXiv:2210.03493, 2022.
  73. Hop, union, generate: Explainable multi-hop reasoning without rationale supervision. arXiv preprint arXiv:2305.14237, 2023.
  74. Teaching algorithmic reasoning via in-context learning. arXiv preprint arXiv:2211.09066, 2022.
  75. Large language models can learn rules. arXiv preprint arXiv:2310.07064, 2023.
Authors (6)
  1. Eric Zelikman (20 papers)
  2. Georges Harik (3 papers)
  3. Yijia Shao (18 papers)
  4. Varuna Jayasiri (1 paper)
  5. Nick Haber (48 papers)
  6. Noah D. Goodman (83 papers)
Citations (57)