
Reasoning Models Know When They're Right: Probing Hidden States for Self-Verification (2504.05419v1)

Published 7 Apr 2025 in cs.AI and cs.CL

Abstract: Reasoning models have achieved remarkable performance on tasks like math and logical reasoning thanks to their ability to search during reasoning. However, they still suffer from overthinking, often performing unnecessary reasoning steps even after reaching the correct answer. This raises the question: can models evaluate the correctness of their intermediate answers during reasoning? In this work, we study whether reasoning models encode information about answer correctness through probing the model's hidden states. The resulting probe can verify intermediate answers with high accuracy and produces highly calibrated scores. Additionally, we find models' hidden states encode correctness of future answers, enabling early prediction of the correctness before the intermediate answer is fully formulated. We then use the probe as a verifier to decide whether to exit reasoning at intermediate answers during inference, reducing the number of inference tokens by 24% without compromising performance. These findings confirm that reasoning models do encode a notion of correctness yet fail to exploit it, revealing substantial untapped potential to enhance their efficiency.

Summary

  • The paper reveals that hidden states in reasoning models encode intermediate answer correctness, achieving high ROC-AUC scores in self-verification.
  • It employs a two-layer MLP probe with weighted binary cross-entropy to extract and assess correctness from each reasoning chunk.
  • The study demonstrates a confidence-based early exit strategy that reduces inference tokens by up to 24% without sacrificing accuracy.

This research investigates whether reasoning models internally encode information about the correctness of their intermediate answers during multi-step reasoning processes. The authors find that this information is indeed present in the model's hidden states and can be extracted using a simple probing mechanism. This "hidden verifier" can then be used to improve inference efficiency by enabling an early-exit strategy.

Core Problem and Motivation:

Reasoning models often employ a search-like process, generating multiple intermediate reasoning paths and answers before arriving at a final solution. While effective, this can lead to "overthinking," where models continue to perform unnecessary reasoning steps even after a correct intermediate answer has been found. The paper questions whether models can internally assess the correctness of these intermediate steps.

Methodology: Probing for Intermediate Answer Correctness

The core idea is to train a probe (a small classifier) to predict the correctness of intermediate answers based on the reasoning model's hidden states.

  1. Data Collection and Preparation:
    • Generate Reasoning Traces: Obtain long Chain-of-Thought (CoT) outputs from a reasoning model for a given task (e.g., math problems).
    • Segment Traces into Chunks: The CoT, typically enclosed in <think> ... </think> tokens, is split into paragraphs (using \n\n as a delimiter). Keywords like "wait," "double-check," or "alternatively" are used to identify the start of new reasoning paths, and paragraphs within the same path are merged into a "chunk." (A full list of keywords is in Appendix Table 1 of the paper.)
    • Extract Intermediate Answers and Labels: For each chunk, an external model (Gemini 2.0 Flash) is used to:
      • Extract the intermediate answer, if one exists.
      • Judge its correctness against the ground-truth answer, providing a binary label (correct/incorrect).
    • Handle Chunks without Answers: Adjacent chunks without intermediate answers are merged with the closest chunk that does contain an answer.
    • Obtain Hidden State Representations: For each chunk $c_i$, the last-layer hidden state at the position of that chunk's last token is taken as its representation $e_i$.
    • Create Probing Dataset: This process yields a dataset $\mathcal{D} = \{(e_i, y_i)\}_{i=1}^{N}$, where $e_i$ is the hidden state representation and $y_i$ is the correctness label.

  2. Training the Probe:
    • Architecture: A two-layer Multilayer Perceptron (MLP) is used as the probe:

      $$p_i = \sigma\big(\text{ReLU}(e_i \mathbf{W}_1 + \mathbf{b}_1)\,\mathbf{W}_2 + b_2\big)$$

      where $\sigma$ is the sigmoid function, $e_i$ is the input hidden state, and $\mathbf{W}_1, \mathbf{b}_1, \mathbf{W}_2, b_2$ are learnable parameters. The paper notes that a linear probe ($d = 0$, where $d$ is the MLP hidden size) often performs well.
    • Loss Function: A weighted binary cross-entropy loss is used to handle class imbalance (for strong models, more intermediate answers are correct than incorrect):

      $$\mathcal{L}(\mathbf{W}, \mathbf{b}) = -\frac{1}{N} \sum_{i=1}^{N} \big( w \alpha\, y_i \log p_i + (1 - y_i) \log (1 - p_i) \big)$$

      where $w$ is the ratio of negative to positive samples and $\alpha$ is a hyperparameter that scales the imbalance weight.

Implementation of Probing:

To implement this, you would:

  1. Use a reasoning LLM to generate detailed CoT solutions for your task.
  2. Parse these solutions:
    • Identify distinct reasoning "chunks" as described (a sketch of this segmentation step follows).
    • For each chunk, use a powerful model (like Gemini, or even GPT-4) or manual annotation to extract the intermediate answer and label its correctness.
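
Below is a minimal, hypothetical sketch of the segmentation step, i.e., the `segment_into_chunks` helper used in the data-collection pseudocode that follows. The keyword list is abbreviated (the paper's Appendix Table 1 has the full list), and treating a keyword at the start of a paragraph as the start of a new reasoning path is a simplification:

```python
# Hypothetical chunk segmentation: split the reasoning trace into paragraphs,
# then merge consecutive paragraphs until a keyword signals a new reasoning path.
NEW_PATH_KEYWORDS = ("wait", "double-check", "alternatively")  # abbreviated list

def segment_into_chunks(reasoning_trace: str) -> list[str]:
    paragraphs = [p.strip() for p in reasoning_trace.split("\n\n") if p.strip()]
    chunks, current = [], []
    for para in paragraphs:
        starts_new_path = any(para.lower().startswith(k) for k in NEW_PATH_KEYWORDS)
        if starts_new_path and current:
            chunks.append("\n\n".join(current))  # close the previous reasoning path
            current = []
        current.append(para)
    if current:
        chunks.append("\n\n".join(current))
    return chunks
```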
```python
# Conceptual pseudocode for collecting the probing dataset
def collect_probing_data(task_prompts, reasoning_model, labeling_model):
    probing_dataset = []
    for prompt in task_prompts:
        full_cot_output = reasoning_model.generate(prompt)  # contains <think>...</think>
        # 1. Extract the reasoning trace inside the <think> tokens
        reasoning_trace = extract_think_trace(full_cot_output)
        # 2. Segment it into chunks based on keywords and blank lines
        chunks = segment_into_chunks(reasoning_trace)
        ground_truth_answer = get_ground_truth(prompt)

        cot_prefix = ""  # reasoning text from earlier chunks
        for chunk_text in chunks:
            prefix_through_chunk = cot_prefix + chunk_text  # CoT up to this chunk's last token
            # 3. Extract the intermediate answer and its correctness label,
            #    e.g. by prompting another LLM (the paper uses Gemini 2.0 Flash)
            intermediate_answer, is_correct = labeling_model.evaluate_chunk(
                chunk_text, ground_truth_answer
            )

            if intermediate_answer is not None:
                # 4. Get the hidden state for the chunk: run the CoT prefix up to this
                #    chunk through the reasoning model and take the last layer's hidden
                #    state at the chunk's last token.
                hidden_state = get_hidden_state_for_chunk(prompt + prefix_through_chunk)
                probing_dataset.append({"hidden_state": hidden_state, "label": is_correct})

            cot_prefix = prefix_through_chunk + "\n\n"
    return probing_dataset
```
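
The `get_hidden_state_for_chunk` helper above is left abstract. With a HuggingFace transformers model it might look roughly like the sketch below; the model name is only an example, the model and tokenizer are loaded explicitly rather than passed in via the abstract `reasoning_model` handle, and re-encoding the full prefix is a simple (if not the most efficient) way to obtain the hidden state:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Example reasoning model; any long-CoT model with accessible hidden states works.
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)  # move to GPU in practice

@torch.no_grad()
def get_hidden_state_for_chunk(text_up_to_chunk_end: str) -> torch.Tensor:
    """Last-layer hidden state at the last token of the given prefix."""
    inputs = tokenizer(text_up_to_chunk_end, return_tensors="pt")
    outputs = model(**inputs, output_hidden_states=True)
    # outputs.hidden_states[-1] has shape (1, seq_len, hidden_size)
    return outputs.hidden_states[-1][0, -1, :].float().cpu()
```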

  3. Store these (hidden state, label) pairs.
  4. Train an MLP (e.g., using PyTorch or TensorFlow) on this dataset. The input dimension of the MLP is the hidden size of the reasoning LLM; the output is a single logit for binary classification.
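
A minimal PyTorch sketch of such a probe and the weighted binary cross-entropy loss defined earlier; the probe width, optimizer, and training loop here are illustrative choices, not the paper's exact settings:

```python
import torch
import torch.nn as nn

class CorrectnessProbe(nn.Module):
    """Two-layer MLP probe: p_i = sigmoid(ReLU(e_i W1 + b1) W2 + b2)."""
    def __init__(self, hidden_size: int, probe_dim: int = 1024):
        super().__init__()
        # The paper notes a purely linear probe (no hidden layer) often works as well.
        self.net = nn.Sequential(
            nn.Linear(hidden_size, probe_dim),
            nn.ReLU(),
            nn.Linear(probe_dim, 1),
        )

    def forward(self, e: torch.Tensor) -> torch.Tensor:
        return self.net(e).squeeze(-1)  # raw logits; apply sigmoid to get p_i

def train_probe(embeddings: torch.Tensor, labels: torch.Tensor,
                epochs: int = 20, lr: float = 1e-3, alpha: float = 1.0):
    """embeddings: (N, hidden_size) floats; labels: (N,) floats in {0, 1}."""
    probe = CorrectnessProbe(embeddings.shape[1])
    # Weighted BCE: the positive (correct) class is weighted by w * alpha,
    # where w is the ratio of negative to positive samples.
    n_pos = labels.sum().clamp(min=1.0)
    n_neg = (labels.numel() - labels.sum()).clamp(min=1.0)
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=(n_neg / n_pos) * alpha)
    optimizer = torch.optim.Adam(probe.parameters(), lr=lr)
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = loss_fn(probe(embeddings), labels)
        loss.backward()
        optimizer.step()
    return probe
```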

Key Experimental Findings:

  • Models Encode Correctness: Probes predict intermediate answer correctness with high ROC-AUC (often >0.7) and low Expected Calibration Error (ECE < 0.1), indicating the information is reliably encoded. Simpler linear probes often suffice.
  • Generalization: Probes generalize well within the same domain (e.g., trained on MATH, tested on GSM8K) but less so across different domains (e.g., math to logical reasoning).
  • Importance of Long CoT Training: Probes trained on hidden states of standard instruction-tuned models (which produce shorter CoTs) perform significantly worse than those trained on reasoning models fine-tuned for long CoT. This suggests the self-verification ability is strengthened during long-CoT supervised training.
  • Look-ahead Capability: Hidden states captured before an intermediate answer is fully generated already contain signals about its future correctness. Accuracy improves as generation approaches the answer, and calibration (ECE) can reach its minimum even before the answer is fully formed (around 60% of the way through the chunk in one experiment).

Application: Probe as a Verifier for Early Exit

The well-calibrated probability scores from the probe can be used to decide when to stop the reasoning process early, saving computation.

  1. Confidence-Based Early Exit:
    • During inference, the reasoning trace is processed chunk by chunk as the model generates it.
    • For each chunk that produces an intermediate answer, its hidden state is fed to the trained probe.
    • If the probe's confidence score $p_i$ (probability of correctness) exceeds a predefined threshold $\mathrm{Thr}$, reasoning is truncated and this intermediate answer is taken as the final answer.
```python
import torch

# Conceptual pseudocode for confidence-based early exit at inference time
def inference_with_early_exit(prompt, reasoning_model, probe, threshold):
    generated_tokens = []
    current_reasoning_path = ""

    # Stream the generation token by token (or chunk by chunk)
    for token in reasoning_model.stream_generate(prompt):
        generated_tokens.append(token)
        current_reasoning_path += token_to_text(token)

        # Check whether the text generated so far completes a chunk
        # that contains an intermediate answer
        if is_new_chunk_and_has_answer(current_reasoning_path):
            intermediate_answer = extract_intermediate_answer(current_reasoning_path)
            # Hidden state at the chunk's last token; in practice this can be read
            # off the model's current decoding state rather than re-encoded
            hidden_state = get_hidden_state_for_chunk(prompt + current_reasoning_path)

            probe_confidence = torch.sigmoid(probe(hidden_state)).item()  # probability of correctness

            if probe_confidence >= threshold:
                print(f"Early exit with probe confidence: {probe_confidence:.3f}")
                return intermediate_answer, generated_tokens  # answer and tokens used so far

    # No early exit: return the final answer from the full generation
    full_text = "".join(token_to_text(t) for t in generated_tokens)
    final_answer = extract_final_answer_from_full_cot(full_text)
    return final_answer, generated_tokens
```
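
The exit threshold is a hyperparameter (see the Threshold Tuning note under the practical considerations below). A simple, hypothetical way to choose it is to sweep candidate values on a validation set, using the `inference_with_early_exit` function above, and compare accuracy against average tokens used:

```python
# Hypothetical threshold sweep on a validation set: pick the most aggressive
# threshold whose accuracy matches the full-reasoning baseline.
def tune_threshold(val_prompts, val_answers, reasoning_model, probe,
                   candidates=(0.70, 0.75, 0.80, 0.85, 0.90, 0.95)):
    results = []
    for thr in candidates:
        correct, tokens_used = 0, 0
        for prompt, gold in zip(val_prompts, val_answers):
            answer, tokens = inference_with_early_exit(prompt, reasoning_model, probe, thr)
            correct += int(answer == gold)  # exact match; use a task-specific checker in practice
            tokens_used += len(tokens)
        results.append({
            "threshold": thr,
            "accuracy": correct / len(val_prompts),
            "avg_tokens": tokens_used / len(val_prompts),
        })
    return results
```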

  2. Results:
    • Using this strategy on the MATH dataset with R1-Distill-Llama-8B, a 24% reduction in inference tokens was achieved with no loss in accuracy (threshold 0.85), and a 19% token reduction with identical accuracy at a threshold of 0.9.
    • This dynamic, confidence-based early exit significantly outperforms static early exit (e.g., stopping after a fixed number of chunks), achieving higher accuracy for similar token savings.

Practical Implications and Implementation Considerations:

  • Efficiency Gains: This method offers a practical way to reduce the computational cost of inference for reasoning tasks by avoiding overthinking, without modifying the model architecture or retraining the base LLM; only a small probe is trained.
  • Lightweight Verifier: The probe is a small MLP, making it very fast to run compared to using another LLM as a verifier.
  • On-Policy Control: The verifier uses the internal states of the same model that is doing the reasoning, making it an "on-policy" control mechanism.
  • Data Annotation for Probe Training: A crucial step is obtaining the (chunk, hidden state, correctness label) data. The paper uses Gemini 2.0 Flash for labeling, but this could be a bottleneck or require careful prompt engineering in new domains or with new models; manual labeling or other annotation methods might be needed.
  • Choosing the Right Layer/Token: The paper uses the last-layer hidden state at the last token of a chunk. Other models or tasks may require experimentation to find the most informative representation.
  • Calibration is Key: The effectiveness of the early-exit strategy relies on the probe's output being well calibrated. The weighted loss helps, but calibration should be monitored.
  • Domain Specificity: Probes are most effective when trained and applied within the same domain. A new application will likely require collecting data and training a new probe specific to that domain and the reasoning model being used.
  • Threshold Tuning: The confidence threshold for early exit ($\mathrm{Thr}$) is a hyperparameter that must be tuned on a validation set to balance accuracy and token savings.
  • Applicability: This technique is most relevant for models and tasks that involve long, multi-step reasoning chains where intermediate conclusions are drawn.

In summary, the paper demonstrates that reasoning models possess a latent self-verification capability. By training a simple probe on their hidden states, one can extract signals about the correctness of intermediate reasoning steps. These signals can be leveraged in practice to implement an early-exit mechanism, significantly improving inference efficiency without sacrificing performance. This opens avenues for more adaptive and resource-aware deployment of reasoning LLMs.