Reasoning Models Know When They're Right: Probing Hidden States for Self-Verification (2504.05419v1)
Abstract: Reasoning models have achieved remarkable performance on tasks like math and logical reasoning thanks to their ability to search during reasoning. However, they still suffer from overthinking, often performing unnecessary reasoning steps even after reaching the correct answer. This raises the question: can models evaluate the correctness of their intermediate answers during reasoning? In this work, we study whether reasoning models encode information about answer correctness through probing the model's hidden states. The resulting probe can verify intermediate answers with high accuracy and produces highly calibrated scores. Additionally, we find models' hidden states encode correctness of future answers, enabling early prediction of the correctness before the intermediate answer is fully formulated. We then use the probe as a verifier to decide whether to exit reasoning at intermediate answers during inference, reducing the number of inference tokens by 24% without compromising performance. These findings confirm that reasoning models do encode a notion of correctness yet fail to exploit it, revealing substantial untapped potential to enhance their efficiency.
Summary
- The paper reveals that hidden states in reasoning models encode intermediate answer correctness, achieving high ROC-AUC scores in self-verification.
- It employs a two-layer MLP probe with weighted binary cross-entropy to extract and assess correctness from each reasoning chunk.
- The study demonstrates a confidence-based early exit strategy that reduces inference tokens by up to 24% without sacrificing accuracy.
This research investigates whether reasoning models internally encode information about the correctness of their intermediate answers during multi-step reasoning processes. The authors find that this information is indeed present in the model's hidden states and can be extracted using a simple probing mechanism. This "hidden verifier" can then be used to improve inference efficiency by enabling an early-exit strategy.
Core Problem and Motivation:
Reasoning models often employ a search-like process, generating multiple intermediate reasoning paths and answers before arriving at a final solution. While effective, this can lead to "overthinking," where models continue to perform unnecessary reasoning steps even after a correct intermediate answer has been found. The paper questions whether models can internally assess the correctness of these intermediate steps.
Methodology: Probing for Intermediate Answer Correctness
The core idea is to train a probe (a small classifier) to predict the correctness of intermediate answers based on the reasoning model's hidden states.
- Data Collection and Preparation:
- Generate Reasoning Traces: Obtain long Chain-of-Thought (CoT) outputs from a reasoning model for a given task (e.g., math problems).
- Segment Traces into Chunks: The CoT, typically encapsulated in `<think>...</think>` tokens, is split into paragraphs (using `\n\n` as a delimiter). Keywords like "wait," "double-check," or "alternatively" are used to identify the start of new reasoning paths, and paragraphs within the same path are merged into a "chunk." (A full list of keywords is in Appendix Table 1 of the paper.)
- Extract Intermediate Answers and Labels: For each chunk, an external model (Gemini 2.0 Flash) is used to:
  - Extract the intermediate answer, if one exists.
  - Judge its correctness against the ground-truth answer, providing a binary label (correct/incorrect).
- Handle Chunks without Answers: Adjacent chunks without intermediate answers are merged with the closest chunk that does contain an answer.
- Obtain Hidden State Representations: For each chunk c_i, the last-layer hidden state at the position of the last token of that chunk is taken as its representation e_i.
- Create Probing Dataset: This process yields a dataset D = {(e_i, y_i)}_{i=1}^{N}, where e_i is the hidden state representation and y_i is the correctness label.
- Training the Probe:
  - Architecture: A two-layer Multilayer Perceptron (MLP) is used as the probe:

    p_i = σ(ReLU(e_i W_1 + b_1) W_2 + b_2)

    where σ is the sigmoid function, e_i is the input hidden state, and W_1, b_1, W_2, b_2 are learnable parameters. The paper notes that a simple linear probe (d = 0, where d is the MLP hidden size) often performs comparably.
  - Loss Function: A weighted binary cross-entropy loss is used to handle class imbalance (for strong models, most intermediate answers are correct):

    L(W, b) = -(1/N) Σ_{i=1}^{N} [ w^α y_i log p_i + (1 − y_i) log(1 − p_i) ]

    where w is the ratio of negative to positive samples and α is a hyperparameter that scales the imbalance weight.

Implementation of Probing:
To implement this, you would:
1. Use a reasoning LLM to generate detailed CoT solutions for your task.
2. Parse these solutions:
   - Identify distinct reasoning "chunks" as described.
   - For each chunk, use a powerful model (like Gemini, or even GPT-4) or manual annotation to extract the intermediate answer and label its correctness.
3. Store the resulting (hidden state, label) pairs.
4. Train an MLP (e.g., in PyTorch or TensorFlow) on this dataset. The input dimension of the MLP is the hidden size of the reasoning LLM; the output is a single logit for binary classification. (A minimal training sketch follows below.)
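The probe architecture and weighted loss above are small enough to sketch directly. Below is a minimal PyTorch sketch, not the authors' released implementation; the probe hidden size (`probe_dim`), epoch count, and learning rate are illustrative placeholders you would tune on your own probing dataset.

```python
# Minimal sketch of the two-layer MLP probe with weighted BCE (illustrative, not the paper's code).
import torch
import torch.nn as nn

class CorrectnessProbe(nn.Module):
    def __init__(self, hidden_size: int, probe_dim: int = 256):
        super().__init__()
        # For the linear-probe variant mentioned above, replace self.net
        # with a single nn.Linear(hidden_size, 1).
        self.net = nn.Sequential(
            nn.Linear(hidden_size, probe_dim),
            nn.ReLU(),
            nn.Linear(probe_dim, 1),
        )

    def forward(self, e):  # e: (batch, hidden_size) last-token hidden states
        return self.net(e).squeeze(-1)  # raw logits; apply sigmoid for probabilities

def train_probe(probe, states, labels, w, alpha=1.0, epochs=20, lr=1e-3):
    """states: (N, hidden_size) float tensor; labels: (N,) float tensor of 0/1.
    w is the negative-to-positive sample ratio of the probing dataset."""
    # pos_weight implements the w^alpha factor on the positive term of the weighted BCE
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(w ** alpha))
    opt = torch.optim.Adam(probe.parameters(), lr=lr)
    for _ in range(epochs):
        opt.zero_grad()
        loss = loss_fn(probe(states), labels)
        loss.backward()
        opt.step()
    return probe
```

`BCEWithLogitsLoss(pos_weight=...)` multiplies only the positive-class term, which matches the w^α weighting in the loss above while keeping the training loop to a few lines.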
Key Experimental Findings:
- Models Encode Correctness: Probes predict intermediate answer correctness with high ROC-AUC (often above 0.7) and low Expected Calibration Error (ECE below 0.1), indicating the information is reliably encoded. Simple linear probes often suffice.
- Generalization: Probes generalize well within the same domain (e.g., trained on MATH, tested on GSM8K) but less so across domains (e.g., math to logical reasoning).
- Importance of Long CoT Training: Probes trained on hidden states of standard instruction-tuned models (which produce shorter CoTs) perform significantly worse than those trained on reasoning models fine-tuned for long CoT. This suggests the self-verification ability is enhanced during long CoT supervised training.
- Look-ahead Capability: Hidden states before an intermediate answer is fully generated already carry signal about its future correctness. Accuracy improves as generation gets closer to the answer, and calibration (ECE) can reach its minimum even before the answer is fully formed (around 60% of the way through a chunk in one experiment).

Application: Probe as a Verifier for Early Exit
The well-calibrated probability scores from the probe can be used to decide when to stop the reasoning process early, saving computation.

Confidence-Based Early Exit:
- During inference, the generated reasoning trace is processed chunk by chunk.
- For each chunk that produces an intermediate answer, its hidden state is fed to the trained probe.
- If the probe's confidence score p_i (probability of correctness) exceeds a predefined threshold Thr, reasoning is truncated and that intermediate answer is returned as the final answer.

For reference, the data-collection pipeline described earlier can be summarized as conceptual pseudocode:
```python
# Conceptual pseudocode for data collection
def collect_probing_data(task_prompts, reasoning_model, labeling_model):
    probing_dataset = []
    for prompt in task_prompts:
        full_cot_output = reasoning_model.generate(prompt)  # contains <think>...</think>

        # 1. Extract the reasoning trace within the <think> tokens
        reasoning_trace = extract_think_trace(full_cot_output)

        # 2. Segment it into chunks based on keywords and double newlines
        chunks = segment_into_chunks(reasoning_trace)
        ground_truth_answer = get_ground_truth(prompt)

        for i, chunk_text in enumerate(chunks):
            # 3. Extract the intermediate answer and a correctness label,
            #    e.g., by prompting another LLM (the paper uses Gemini 2.0 Flash)
            intermediate_answer, is_correct = labeling_model.evaluate_chunk(
                chunk_text, ground_truth_answer
            )
            if intermediate_answer is not None:
                # 4. Hidden state for the chunk: run the CoT prefix up to and including
                #    this chunk through reasoning_model and take the last layer's
                #    hidden state at the chunk's last token.
                cot_prefix = "".join(chunks[: i + 1])
                hidden_state = get_hidden_state_for_chunk(reasoning_model, cot_prefix)
                probing_dataset.append({"hidden_state": hidden_state, "label": is_correct})
    return probing_dataset
```
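The helper `get_hidden_state_for_chunk` is left abstract above. One way to implement it, sketched below with Hugging Face `transformers` (the model name is an example and the tokenizer is loaded at module level; this is not the paper's code), is to run the CoT prefix through the model with `output_hidden_states=True` and read off the last layer at the last token position.

```python
# Illustrative sketch: last-layer hidden state at the last token of a chunk's CoT prefix.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"  # example reasoning model
tokenizer = AutoTokenizer.from_pretrained(model_name)
reasoning_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
reasoning_model.eval()

@torch.no_grad()
def get_hidden_state_for_chunk(reasoning_model, cot_prefix):
    """Final layer's hidden state at the last token of the (prompt +) CoT prefix so far."""
    inputs = tokenizer(cot_prefix, return_tensors="pt").to(reasoning_model.device)
    outputs = reasoning_model(**inputs, output_hidden_states=True)
    # hidden_states is a tuple (embedding layer + one tensor per transformer layer),
    # each of shape (batch, seq_len, hidden_size); take the last layer, last position.
    return outputs.hidden_states[-1][0, -1, :]  # the chunk representation e_i
```

In practice the text fed in should include the original problem prompt as well as the CoT prefix up to the chunk's final token, so the hidden state matches what the model saw during generation.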
Results:
- Using this strategy on the MATH dataset with R1-Distill-Llama-8B, a 24% reduction in inference tokens was achieved with no loss in accuracy at a threshold of 0.85, and a 19% reduction with identical accuracy at a threshold of 0.9.
- This dynamic, confidence-based early exit significantly outperforms static early exit (e.g., stopping after a fixed number of chunks), achieving higher accuracy for similar token savings.

Practical Implications and Implementation Considerations:
- Efficiency Gains: This method offers a practical way to reduce the computational cost of inference for reasoning tasks by avoiding overthinking, without modifying the model architecture or retraining the base LLM; only a small probe is trained.
- Lightweight Verifier: The probe is a small MLP, making it far cheaper to run than using another LLM as a verifier.
- On-Policy Control: The verifier uses the internal states of the same model that is doing the reasoning, making it an "on-policy" control mechanism.
- Data Annotation for Probe Training: A crucial step is obtaining the (chunk, hidden state, correctness label) data. The paper uses Gemini 2.0 Flash for labeling; applying this to new domains or models may require careful prompt engineering, manual labeling, or other annotation strategies.
- Choosing the Right Layer/Token: The paper uses the last-layer hidden state at the last token of a chunk. Other models or tasks may require experimentation to find the most informative representation.
- Calibration is Key: The effectiveness of the early-exit strategy relies on the probe's output being well calibrated. The weighted loss helps, but calibration should be monitored (a minimal ECE check is sketched after this list).
- Domain Specificity: Probes are most effective when trained and applied within the same domain. For a new application, you would likely need to collect data and train a new probe specific to that domain and reasoning model.
- Threshold Tuning: The confidence threshold for early exit (Thr) is a hyperparameter that must be tuned on a validation set to balance accuracy against token savings (see the tuning sketch at the end of this summary).
- Applicability: This technique is most relevant for models and tasks that involve long, multi-step reasoning chains where intermediate conclusions are drawn.
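To monitor calibration as mentioned above, a standard Expected Calibration Error computation suffices. The sketch below is illustrative (it assumes you have the probe's predicted probabilities and the binary correctness labels for a held-out set); it is not taken from the paper.

```python
# Sketch of checking probe calibration with Expected Calibration Error (ECE).
import numpy as np

def expected_calibration_error(probs, labels, n_bins=10):
    """probs: predicted probability of correctness; labels: 0/1 ground-truth correctness."""
    probs, labels = np.asarray(probs, dtype=float), np.asarray(labels, dtype=float)
    bin_edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(bin_edges[:-1], bin_edges[1:]):
        mask = (probs > lo) & (probs <= hi)
        if mask.any():
            avg_confidence = probs[mask].mean()   # mean predicted probability in the bin
            avg_accuracy = labels[mask].mean()    # empirical accuracy in the bin
            ece += mask.mean() * abs(avg_confidence - avg_accuracy)
    return ece
```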
In summary, the paper demonstrates that reasoning models possess a latent self-verification capability. By training a simple probe on their hidden states, one can extract signals about the correctness of intermediate reasoning steps. These signals can be practically leveraged to implement an early-exit mechanism, significantly improving inference efficiency without sacrificing performance. This opens avenues for more adaptive and resource-aware deployment of reasoning LLMs.

Finally, conceptual pseudocode for the confidence-based early-exit loop:

```python
# Conceptual pseudocode for early exit
def inference_with_early_exit(prompt, reasoning_model, probe, threshold):
    generated_tokens = []
    current_reasoning_path = ""

    # Stream generation (or generate chunk by chunk)
    for token in reasoning_model.stream_generate(prompt):
        generated_tokens.append(token)
        current_reasoning_path += token_to_text(token)

        # Check whether the text generated so far closes a chunk
        # that contains an intermediate answer
        if is_new_chunk_and_has_answer(current_reasoning_path):
            intermediate_answer = extract_intermediate_answer(current_reasoning_path)

            # Hidden state of the chunk's last token, taken from the model's current state
            hidden_state = get_hidden_state_for_chunk(reasoning_model, generated_tokens)
            probe_confidence = probe.predict_proba(hidden_state)[1]  # probability of being correct

            if probe_confidence >= threshold:
                print(f"Early exit at chunk with confidence: {probe_confidence}")
                return intermediate_answer, generated_tokens  # answer and tokens used

            # Otherwise continue reasoning and start accumulating the next chunk
            current_reasoning_path = ""

    # No early exit triggered: return the final answer from the full generation
    full_text = "".join(token_to_text(t) for t in generated_tokens)
    final_answer = extract_final_answer_from_full_cot(full_text)
    return final_answer, generated_tokens
```
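Since the exit threshold trades accuracy against token savings, it is worth sweeping it on a validation set before deployment. The sketch below is illustrative only; `validation_set` (a list of dicts with `prompt` and `ground_truth`) and `is_correct_answer` are assumed helpers from your own evaluation setup, and `inference_with_early_exit` is the function sketched above.

```python
# Sketch of tuning the early-exit threshold Thr on a validation set (illustrative).
def tune_threshold(validation_set, reasoning_model, probe,
                   candidate_thresholds=(0.7, 0.8, 0.85, 0.9, 0.95)):
    results = []
    for thr in candidate_thresholds:
        num_correct, total_tokens = 0, 0
        for example in validation_set:
            answer, tokens = inference_with_early_exit(
                example["prompt"], reasoning_model, probe, thr
            )
            num_correct += int(is_correct_answer(answer, example["ground_truth"]))
            total_tokens += len(tokens)
        results.append({
            "threshold": thr,
            "accuracy": num_correct / len(validation_set),
            "avg_tokens": total_tokens / len(validation_set),
        })
    # Inspect the accuracy/token trade-off, e.g., pick the lowest threshold
    # whose accuracy matches that of full (non-truncated) reasoning.
    return results
```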