Chain-of-Thought Reasoning is a Policy Improvement Operator (2309.08589v2)
Abstract: LLMs have astounded the world with fascinating new capabilities. However, they currently lack the ability to teach themselves new skills, relying instead on large amounts of human-generated training data. We introduce SECToR (Self-Education via Chain-of-Thought Reasoning), a proof-of-concept demonstration that LLMs can teach themselves new skills using chain-of-thought reasoning. During the self-learning loop, SECToR asks models to solve addition problems using chain-of-thought reasoning before training the next version of the model to solve those same problems directly without using such reasoning. This process often results in an improved model which can, when again augmented with chain-of-thought reasoning, solve even harder problems than the original model, allowing the self-learning loop to continue. LLMs trained via SECToR autonomously learn to add up to 29-digit numbers without access to any ground truth examples beyond an initial supervised fine-tuning phase consisting only of numbers with 6 or fewer digits. Our central hypothesis is that chain-of-thought reasoning can act as a policy improvement operator, similarly to how Monte-Carlo Tree Search is used in AlphaZero (Silver et al., 2017). We hope that this research can lead to new directions in which LLMs can learn to teach themselves without the need for human demonstrations.
Summary
- The paper introduces SECToR, showing that using chain-of-thought reasoning as a policy improvement operator allows language models to self-improve on multi-digit addition.
- It outlines a two-phase method with supervised fine-tuning followed by self-training using model-generated data and robust consistency checks.
- Experimental results show that the approach extends addition from the 6-digit supervised range to 29-digit numbers without ground-truth labels, reducing dependence on human-curated data.
The paper "Chain-of-Thought Reasoning is a Policy Improvement Operator" (2309.08589) introduces SECToR (Self-Education via Chain-of-Thought Reasoning), a method demonstrating that LLMs can teach themselves new skills, specifically multi-digit addition, without continuous human-provided data. The core idea is that chain-of-thought (CoT) reasoning acts as a "policy improvement operator," analogous to how Monte-Carlo Tree Search (MCTS) improves policies in systems like AlphaZero.
Core Concept: CoT as a Policy Improvement Operator
The central hypothesis is that prompting an LLM to use step-by-step CoT reasoning allows it to solve problems it couldn't solve directly. SECToR leverages this by:
- Having the current model (Model_t) use CoT to generate solutions for problems slightly beyond its direct capabilities.
- Training the next iteration of the model (Model_t+1) to produce these CoT-generated solutions directly, without explicit reasoning steps.
- Using the improved Model_t+1, again augmented with CoT, to tackle even more complex problems, so the self-learning loop can continue.
This process is visualized in Figure 1 of the paper and further detailed in an appendix figure (reproduced below for clarity):
```
Initial Problems (e.g., 167+708) --> Model_t + CoT
                                          |
                                          V
Model_t solves with CoT (e.g., 167+708 = [CoT steps...] A: 875) --> Train Model_t+1 to solve directly
                                                                              |
                                                                              V
                                                                          Model_t+1
                                                                              |
                                                                              V
                          Repeat on harder problems (e.g., 2632+8647) <-------+
```
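The loop in the diagram can be caricatured in a few lines of Python. This is a toy, not the paper's training code. `ToyAdder` stands in for the language model, and its only "parameter" is the longest digit length it can add directly. `distill` is a hypothetical stand-in for fine-tuning that assumes training always succeeds. The sketch only illustrates the control flow: CoT extends the fast policy by one digit, and distillation makes that extension direct.

```python
from __future__ import annotations

import random


class ToyAdder:
    """Toy policy: answers directly ("fast") only up to `fast_digits` digits."""

    def __init__(self, fast_digits: int):
        self.fast_digits = fast_digits

    def fast_add(self, a: int, b: int, carry: int = 0) -> int | None:
        # Direct answer, or None when the problem is too long for the fast policy.
        if max(len(str(a)), len(str(b))) > self.fast_digits:
            return None
        return a + b + carry

    def cot_add(self, a: int, b: int) -> int | None:
        # One CoT step: write the last digit, carry, then fast-add the shorter subproblem.
        s = a % 10 + b % 10
        rest = self.fast_add(a // 10, b // 10, s // 10)
        return None if rest is None else rest * 10 + s % 10


def distill(examples: list[tuple[int, int, int]]) -> ToyAdder:
    # Hypothetical stand-in for fine-tuning on (a, b, CoT answer) triples:
    # the toy assumes training succeeds, so the new fast policy covers the
    # longest problems it saw.
    longest = max(len(str(max(a, b))) for a, b, _ in examples)
    return ToyAdder(longest)


def self_improve(model: ToyAdder, rounds: int, n_examples: int = 50, seed: int = 0) -> ToyAdder:
    rng = random.Random(seed)
    for _ in range(rounds):
        n = model.fast_digits + 1  # one digit beyond the current direct ability
        lo, hi = 10 ** (n - 1), 10 ** n - 1
        data = []
        for _ in range(n_examples):
            a, b = rng.randint(lo, hi), rng.randint(lo, hi)
            answer = model.cot_add(a, b)  # CoT acts as the policy improvement step
            if answer is not None:
                data.append((a, b, answer))
        model = distill(data)  # Model_t+1 learns to answer directly
    return model


print(ToyAdder(2).cot_add(167, 708))  # 875
print(self_improve(ToyAdder(6), rounds=4).fast_digits)  # 10
```

In the real system, distillation is imperfect, and that imperfection is why the consistency checks described later are needed.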
Methodology: SECToR for Addition
The paper demonstrates SECToR using multi-digit addition as a benchmark task. The process involves two main phases:
- Supervised Fine-Tuning:
- Initial Training: A pre-trained LLM (ByT5, chosen for its byte-level tokenization to avoid arithmetic tokenization issues) is first fine-tuned on addition problems with a small number of digits (e.g., 1 to 6 digits).
- Two Task Types:
- Fast Addition (without CoT): The model is trained to output the sum directly (e.g., "Q: 141 + 123 = ? A: 264.").
- Slow Addition (with CoT): The model is trained to perform one step of simplification, akin to how children learn addition (e.g., "Q: 141 + 123 = ? A: The first number's last digit is 1... The next subproblem is 14 + 12.").
- Curriculum Learning: The model must achieve satisfactory performance on N-digit problems before (N+1)-digit problems are introduced. This is crucial for building foundational skills. Satisfactory performance was defined as ≥75% accuracy on "fast" N-digit addition and 100% on "slow" N-digit addition.
- Transition to Self-Training: This phase ends when the model demonstrates strong generalization to (N+1)-digit "slow" addition using CoT, even though it was only trained up to N-digit problems. For the 582M parameter ByT5 model, this occurred after training on 1-6 digit addition, showing generalization to 7-digit "slow" addition.
- Self-Training:
- Model-Generated Data: All new training data is generated by the model itself, without access to ground truth answers.
- Generating "Slow" Examples: Since CoT-augmented models generalize well, "slow" addition examples for (N+1)-digits are generated by directly sampling from the current model using greedy decoding.
- Generating "Fast" Examples (Simplify-then-Guess): "Fast" addition doesn't generalize as well. To generate these examples, SECToR uses a method called "simplify-then-guess":
- The model is asked to simplify an (N+1)-digit problem K times using its "slow" CoT ability (e.g., an 8-digit problem becomes a 7-digit problem, then a 6-digit, etc.).
- After each simplification step, the model directly guesses the final solution to the original (N+1)-digit problem. It fast-adds the remaining sub-problem with its current "fast" addition capability and combines that result with the digits already produced by the simplification steps.
- The final answer for the (N+1)-digit "fast" addition training example is determined by a majority vote over these K intermediate guesses. The paper used K=5. This process is illustrated in Figure 3 of the paper.
- Mitigating Error Avalanching: A key challenge in self-training is "error avalanching," where small errors in model-generated data compound over iterations. SECToR employs consistency checks:
- Simplify-then-guess itself: the majority vote over K guesses provides some robustness.
- Commutativity checks: for any problem a+b, the model also solves b+a.
- For "fast" addition, the answers for a+b and b+a (generated via simplify-then-guess) must be an exact string match.
- For "slow" addition, performing one CoT simplification step and then fast-adding the resulting subproblem must yield identical final sums for a+b and b+a.
- If a check fails, the problem-solution pair is discarded. This significantly reduces the amount of incorrect data introduced (as shown in Figure 5).
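For the supervised phase, the two example formats and the curriculum gate can be sketched in Python. The fast format, and the first and last sentences of the slow format, follow the examples in the text. The middle sentence of the slow format, which verbalizes the digit sum and carry, is a hypothetical placeholder and not the paper's wording. The sketch also does not specify how a carry enters the next subproblem. Only the first simplification step (no incoming carry) is shown.

```python
def fast_example(a: int, b: int) -> str:
    # "Fast" format from the text, e.g. "Q: 141 + 123 = ? A: 264."
    return f"Q: {a} + {b} = ? A: {a + b}."


def slow_example(a: int, b: int) -> str:
    # One "slow" simplification step (first step only, so no incoming carry).
    # The first and last sentences follow the text's example; the middle
    # sentence verbalizing the digit sum and carry is a hypothetical placeholder.
    s = a % 10 + b % 10
    return (
        f"Q: {a} + {b} = ? A: The first number's last digit is {a % 10}. "
        f"Adding the last digits gives {s}, so write {s % 10} and carry {s // 10}. "
        f"The next subproblem is {a // 10} + {b // 10}."
    )


def ready_for_next_length(fast_acc: float, slow_acc: float) -> bool:
    # Curriculum gate from the text: >= 75% "fast" and 100% "slow" accuracy on
    # N-digit problems before (N+1)-digit problems are introduced.
    return fast_acc >= 0.75 and slow_acc == 1.0


print(fast_example(141, 123))  # Q: 141 + 123 = ? A: 264.
print(slow_example(167, 708))
print(ready_for_next_length(0.80, 1.0))  # True
```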
Implementation Details and Considerations
- Models: ByT5 models (582M and 300M parameters) were used. The byte-level nature of ByT5 helps avoid tokenization artifacts common with arithmetic tasks in other LLMs.
- Training:
- Adam optimizer, DeepSpeed library, constant learning rate of 10^-4, bfloat16 training.
- Batch sizes: 2048 (300M model), 1024 (582M model).
- Data generation during supervised phase (per N-digit step): 10,000 unique N-digit CoT examples, 1,000 for each smaller digit length (anti-forgetting); 30,000 unique N-digit fast examples, 3,000 for smaller.
- Data generation during self-training phase: example counts reduced by a factor of 10 due to the higher generation cost.
- Computational Cost: The self-training phase, especially data generation with simplify-then-guess and consistency checks, is computationally intensive.
- Curriculum Learning is Key: The paper notes (Appendix C.7) that curriculum learning is vital. In an ablation, a 582M model trained on 1-6 digits in a single step (without a curriculum) generalized to 9-digit slow addition. The curriculum is nonetheless essential for the step-by-step self-improvement process, in which N-digit capabilities are used to generate (N+1)-digit data.
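The data budget and optimizer settings listed above can be written down as a small Python sketch, with two assumptions. First, `examples_per_step` assumes that each smaller digit length 1..N-1 gets its own 1,000 CoT and 3,000 fast examples, and that the self-training reduction divides every count by 10. Second, `deepspeed_config` is a hypothetical mapping of the stated hyperparameters onto DeepSpeed JSON config keys; it is not the authors' actual file.

```python
from __future__ import annotations


def examples_per_step(n: int, self_training: bool = False) -> dict[str, int]:
    """Training examples for the N-digit step, using the counts given in the text."""
    scale = 10 if self_training else 1  # assumption: every count is divided by 10
    slow = 10_000 // scale + (n - 1) * (1_000 // scale)  # N-digit + each smaller length
    fast = 30_000 // scale + (n - 1) * (3_000 // scale)
    return {"slow": slow, "fast": fast}


def deepspeed_config(model_size: str) -> dict:
    # Hypothetical DeepSpeed config built from the stated hyperparameters.
    batch = {"300M": 2048, "582M": 1024}[model_size]
    return {
        "train_batch_size": batch,
        "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
        "bf16": {"enabled": True},
        # No "scheduler" entry, so the learning rate stays constant.
    }


print(examples_per_step(6))  # {'slow': 15000, 'fast': 45000}
print(examples_per_step(7, self_training=True))  # {'slow': 1600, 'fast': 4800}
```

Each step re-includes smaller lengths, so the per-step dataset grows linearly with N. This mix is the anti-forgetting measure mentioned above.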
Sketch of the SECToR self-training data generation for (N+1)-digit problems, in Python. The model interface (`slow_step`, `fast_add`), the `Step` record, and `ExactToyModel` are hypothetical stand-ins. The real system prompts ByT5 with text such as "Q: a + b = ? A:", decodes greedily, and parses the output. Fine-tuning itself is left abstract.

```python
from __future__ import annotations

import random
from collections import Counter
from dataclasses import dataclass
from typing import Protocol


@dataclass(frozen=True)
class Step:
    """One "slow" CoT step: output digit, carry, and the remaining subproblem."""
    digit: int
    carry: int
    rest_a: int
    rest_b: int


class AdditionModel(Protocol):
    # Hypothetical interface; a real model would generate and parse text.
    def slow_step(self, a: int, b: int, carry: int) -> Step: ...
    def fast_add(self, a: int, b: int, carry: int) -> int: ...


def with_low_digits(high: int, low_digits: list[int]) -> int:
    # Append the digits written so far (least significant first) below `high`.
    value = high
    for d in reversed(low_digits):
        value = value * 10 + d
    return value


def one_step_then_fast(model: AdditionModel, a: int, b: int) -> int:
    # For the "slow" commutativity check: one CoT step, then fast-add the rest.
    s = model.slow_step(a, b, 0)
    return with_low_digits(model.fast_add(s.rest_a, s.rest_b, s.carry), [s.digit])


def simplify_then_guess(model: AdditionModel, a: int, b: int, k: int = 5) -> int:
    # After each of up to k simplification steps, fast-add the remaining
    # subproblem to form a full guess for a + b; return the majority vote.
    guesses, low_digits = [], []
    x, y, carry = a, b, 0
    for _ in range(k):
        if x == 0 and y == 0:  # no digits left to simplify
            break
        s = model.slow_step(x, y, carry)
        low_digits.append(s.digit)
        x, y, carry = s.rest_a, s.rest_b, s.carry
        guesses.append(with_low_digits(model.fast_add(x, y, carry), low_digits))
    if not guesses:
        return model.fast_add(a, b, 0)
    return Counter(guesses).most_common(1)[0][0]


def generate_round(model: AdditionModel, n_digits: int, num_examples: int,
                   k: int = 5, seed: int = 0) -> tuple[list, list]:
    """Generate (slow, fast) training pairs for n_digits-digit problems, no ground truth."""
    rng = random.Random(seed)
    lo, hi = 10 ** (n_digits - 1), 10 ** n_digits - 1
    slow_data, fast_data = [], []
    for _ in range(num_examples):
        a, b = rng.randint(lo, hi), rng.randint(lo, hi)
        prompt = f"Q: {a} + {b} = ? A:"
        # Slow example (a real model would emit CoT text, not a Step): keep it
        # only if a+b and b+a agree after one step plus fast addition.
        if one_step_then_fast(model, a, b) == one_step_then_fast(model, b, a):
            slow_data.append((prompt, model.slow_step(a, b, 0)))
        # Fast example: simplify-then-guess answers for a+b and b+a must match.
        answer = simplify_then_guess(model, a, b, k)
        if answer == simplify_then_guess(model, b, a, k):
            fast_data.append((prompt, answer))
    return slow_data, fast_data


class ExactToyModel:
    """Toy stand-in that always answers correctly (only so the sketch runs)."""

    def slow_step(self, a: int, b: int, carry: int) -> Step:
        s = a % 10 + b % 10 + carry
        return Step(s % 10, s // 10, a // 10, b // 10)

    def fast_add(self, a: int, b: int, carry: int) -> int:
        return a + b + carry


slow, fast = generate_round(ExactToyModel(), n_digits=8, num_examples=3)
print(len(slow), len(fast))  # 3 3 (an exact model passes every check)
# Model_N_plus_1 = fine_tune(Model_N, slow + fast + old_data)  # hypothetical helper
```
Results
- 582M ByT5 Model: After supervised fine-tuning on 1-6 digit addition, it self-trained to accurately (98%+) perform up to 29-digit addition. This involved 22 steps of self-improvement. The final model could add 30-digit numbers with 88% accuracy without CoT.
- 300M ByT5 Model: Supervised training up to 8 digits, then self-trained up to 24-digit addition.
- Generalization: Models showed poor length generalization for "fast" addition. However, with CoT ("slow" addition), generalization to N+1 digits occurred much earlier (e.g., after training up to N=4 or N=6 digits for the 582M model, as seen in Figure 2). This strong CoT generalization is what enables the self-training.
- Error Avalanching: While mitigated, it eventually caused training to terminate. The 582M model failed to continue after 29-digit addition.
Practical Applications and Implications
- Reducing Human Data Dependency: Demonstrates a path towards models that can improve and acquire new skills with less reliance on vast, human-curated datasets. This is significant given concerns about exhausting high-quality training data.
- Compute-Driven Scaling: If self-learning can be broadly applied, it could lead to scaling laws driven more by computational power than data availability.
- Improving Reasoning: The core mechanism could potentially be applied to more complex reasoning tasks beyond arithmetic, such as mathematics, programming, or logical deduction, provided effective CoT-like processes and consistency checks can be formulated for those domains.
- System Design for Self-Improving AI:
- Curriculum: A structured curriculum seems essential.
- Dual Capabilities: Training models for both direct ("fast") and step-by-step ("slow") problem-solving is beneficial.
- Self-Correction/Consistency: Robust mechanisms to filter out self-generated errors are critical. Commutativity is a domain-specific example; more general consistency principles would be needed for other tasks (e.g., logical consistency, consistency with known facts).
- Model Architecture: Byte-level models like ByT5 may be advantageous for tasks sensitive to tokenization.
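Commutativity is one instance of a more general pattern: label a self-generated problem only if the model answers every equivalent restatement the same way. A minimal, task-agnostic sketch in Python follows. `consistent_label`, `commutative_twins`, and the deliberately buggy `flaky_add` toy are all hypothetical names. Finding good `equivalents` for other domains is the open problem the text describes.

```python
from __future__ import annotations

from typing import Callable, Iterable, TypeVar

P = TypeVar("P")  # problem type
A = TypeVar("A")  # answer type


def consistent_label(solve: Callable[[P], A], problem: P, equivalents: Iterable[P]) -> A | None:
    """Return the model's answer only if every equivalent restatement agrees."""
    answer = solve(problem)
    for alt in equivalents:
        if solve(alt) != answer:
            return None  # inconsistent: discard this self-generated example
    return answer


def commutative_twins(problem: tuple[int, int]) -> list[tuple[int, int]]:
    # For addition, the restatement used in the text is b + a.
    a, b = problem
    return [(b, a)]


def flaky_add(problem: tuple[int, int]) -> int:
    # Toy "model" with an order-dependent bug that commutativity can catch.
    a, b = problem
    return a + b + (1 if a > b else 0)


print(consistent_label(flaky_add, (3, 5), commutative_twins((3, 5))))  # None
print(consistent_label(lambda p: p[0] + p[1], (3, 5), commutative_twins((3, 5))))  # 8
```

A check like this only catches errors that break the chosen invariant. A model that is consistently wrong in both orders still passes.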
Limitations
- Task Specificity: Success is shown on addition. Generalization to more complex, less formally verifiable tasks is an open question.
- Computational Inefficiency: Generating data via simplify-then-guess and consistency checks is resource-intensive.
- Error Accumulation: Error avalanching eventually halts self-training; the 582M model could not continue past 29-digit addition.
- Safety and Control: Self-learning models could amplify biases or develop unpredictable capabilities, requiring research into safety and control mechanisms.
In summary, SECToR provides a proof-of-concept that LLMs can autonomously extend their capabilities on a structured task like addition. It operationalizes the idea of CoT reasoning as a policy improvement mechanism within a self-training loop, using curriculum learning and novel data generation/filtering techniques (simplify-then-guess, commutativity checks) to manage the process and mitigate error accumulation. The approach offers insights into building more data-efficient and continuously improving AI systems.