- The paper introduces AbstRaL, a framework that strengthens LLM reasoning through abstract problem decomposition and symbolic derivations.
- The paper shows that reinforcement learning, applied on top of supervised fine-tuning, is more effective than SFT alone at teaching LLMs to generate faithful abstractions.
- The paper demonstrates practical improvements in handling input perturbations and reducing interference from distracting conditions.
The paper "AbstRaL: Augmenting LLMs' Reasoning by Reinforcing Abstract Thinking" (2506.07751) introduces a novel framework, AbstRaL, designed to enhance the reasoning robustness of LLMs, particularly smaller ones, when faced with distribution shifts like numerical/nominal variable changes or the insertion of distracting clauses. Instead of generating more synthetic instances of reasoning problems, AbstRaL focuses on teaching LLMs to "abstract" these problems. This approach aims to make reasoning invariant to contextual changes and facilitate connections with symbolic tools for solution derivation. The authors find that Reinforcement Learning (RL) is more effective than Supervised Fine-Tuning (SFT) alone for acquiring faithful abstractions.
AbstRaL Framework
The AbstRaL framework operates in four main steps:
- Condition Recognition: The input question X is parsed to identify relevant conditions C and formulate them with abstract symbols (e.g., `in0`, `in1`). This creates an abstract input question X^ where specific values are replaced by these symbols. This step can be performed using symbolic tools (e.g., regex for numerical values) or a prompted LLM. For mathematical reasoning, the paper uses a Llama-3.3-70B-Instruct model, prompted with few-shot examples, to label numerical values and convert implicit numbers (e.g., "one hundred" to "100") before symbolization (a regex-based sketch of this step appears after this list).
- Abstract Reasoning: The core of AbstRaL, where an LLM is trained to generate an abstract answer Y^ from the abstract question X^. This involves quoting input symbols and introducing new abstract symbols for derived outputs (e.g., `out0`). This learning process uses SFT followed by RL on specially constructed "Granularly-decomposed Abstract Reasoning" (GranulAR) data.
- Abstraction Retrieval: Based on the abstract answer Y^, a de-contextualized abstraction A (e.g., a set of symbolic equations) is retrieved. This can be done with regex matching or a prompted LLM. The paper uses a regex script to extract math derivations enclosed in double angle brackets (`<< >>`).
- Symbolic Derivation: The abstraction A, along with the input conditions C, is used to derive the final answer. This can be done by a rule-based symbolic parser (e.g., an equation solver) or a neural symbolic reasoner. The paper uses the SymPy equation solver for this.
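To make step 1 concrete, here is a minimal regex-based sketch of condition recognition (the function name and regex are illustrative assumptions, not the paper's exact script):

```python
import re

def recognize_conditions(question: str):
    """Symbolic-tool variant of Condition Recognition: replace each
    explicit number with an abstract symbol [inK] and record the
    original values as conditions C. (Illustrative sketch; the paper
    additionally uses a prompted LLM to handle implicit numbers
    such as "one hundred".)"""
    conditions = {}

    def symbolize(match: re.Match) -> str:
        symbol = f"in{len(conditions)}"
        conditions[symbol] = match.group(0)
        return f"[{symbol}]"

    abstract_question = re.sub(r"\d+(?:\.\d+)?", symbolize, question)
    return abstract_question, conditions

x_hat, c = recognize_conditions(
    "Each tray can hold 24 eggs. If he has 64 eggs and 2 trays, "
    "how many eggs won't he be able to place on the tray?"
)
# x_hat == "Each tray can hold [in0] eggs. If he has [in1] eggs and [in2] trays, ..."
# c == {"in0": "24", "in1": "64", "in2": "2"}
```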
GranulAR Data
A key component is the GranulAR training data, designed to integrate abstract reasoning with existing LLM capabilities like Chain-of-Thought (CoT) and Socratic problem decomposition.
- Format: The GranulAR answer Y^ first decomposes the question X^ into sub-questions. Each sub-question is then answered using CoT and abstract symbols, quoting relevant input conditions or previous sub-answers, and deriving the current sub-answer with symbols. Finally, a conclusion clarifies the final abstract output symbol.
- Construction:
1. Perform Condition Recognition on an original question X to get X^ and conditions C.
2. Prompt an oracle LLM (Llama-3.3-70B-Instruct) with X^ and the gold Socratic CoT answer Y to rewrite Y into the abstract GranulAR format Y^.
3. Perform Abstraction Retrieval on Y^ to get the abstraction A.
4. Verify if A and C can derive the correct final answer stated in Y. Only verified instances are kept.
The paper uses the Socratic version of the GSM8K training set as seed data, resulting in 6,386 problems after rewriting and filtering.
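The verification in step 4 might look like the following sketch, which assumes each derivation's right-hand side names exactly one new output symbol (helper names are hypothetical):

```python
from sympy import sympify

def derive_final_answer(abstraction, conditions):
    """Substitute conditions C into abstraction A step by step, binding
    each derivation's output symbol (assumed to be the right-hand side)."""
    env = {k: sympify(v) for k, v in conditions.items()}
    for step in abstraction:                       # e.g. "in0 * in2 = out0"
        expr, out = (s.strip() for s in step.split("="))
        env[out] = sympify(expr, locals=env)       # evaluate with known symbols
    return env[out]                                # last derived symbol

def keep_instance(abstraction, conditions, gold_answer) -> bool:
    """Step 4: retain a GranulAR instance only if A and C reproduce
    the final answer stated in the gold CoT answer Y."""
    try:
        return derive_final_answer(abstraction, conditions) == sympify(gold_answer)
    except Exception:
        return False  # malformed derivations fail verification

# keep_instance(["in0 * in2 = out0", "in1 - out0 = out1"],
#               {"in0": "24", "in1": "64", "in2": "2"}, "16")  -> True
```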
Learning Abstract Reasoning
AbstRaL employs a two-stage learning process:
- Supervised Fine-Tuning (SFT): LLMs are fine-tuned to auto-regressively generate the GranulAR answer Y^ given the abstract question X^, using a standard causal language modeling loss.
- SFT Hyperparameters: Batch size 8, learning rate 5e-6 (AdamW), 2 epochs. This took less than 1 hour on 4 A100-80GB GPUs.
- Reinforcement Learning (RL): To improve the faithfulness of generated abstractions beyond SFT, RL is used with novel model-free rewards. The GRPO algorithm is adopted.
Abstraction Rewards:
- Answer Correctness Reward (r_answer): A positive reward r_correct (a hyperparameter, set to 2.5) is given if the abstraction A^ (retrieved from the model's generated Y^) and gold conditions C derive the correct final answer; 0 otherwise.
- Symbolic Distance Reward (r_symbolic): Measures the similarity between the model's tokenized abstraction A^ and the gold abstraction A (from the GranulAR data):

$$r_{\text{symbolic}}(\hat{A}, A) = r_{\max} \cdot \left(1 - \frac{\mathrm{EditDistance}(\hat{A}, A)}{\max_{a \in \{\hat{A}, A\}} \mathrm{Len}(a)}\right)$$

where r_max is a hyperparameter (set to 1.5), EditDistance is the list-wise edit distance between the tokenized abstractions, and Len is list length.
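A compact sketch of this reward, assuming abstractions are tokenized into lists of strings (the paper's exact tokenization is not specified):

```python
def edit_distance(a: list, b: list) -> int:
    """List-wise Levenshtein distance between two token sequences,
    computed with a single rolling DP row."""
    dp = list(range(len(b) + 1))
    for i, x in enumerate(a, 1):
        prev, dp[0] = dp[0], i
        for j, y in enumerate(b, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1,        # deletion
                                     dp[j - 1] + 1,    # insertion
                                     prev + (x != y))  # substitution / match
    return dp[-1]

def r_symbolic(pred_tokens: list, gold_tokens: list, r_max: float = 1.5) -> float:
    """Symbolic distance reward: r_max * (1 - normalized edit distance)."""
    denom = max(len(pred_tokens), len(gold_tokens)) or 1  # avoid division by zero
    return r_max * (1 - edit_distance(pred_tokens, gold_tokens) / denom)

# r_symbolic("in0 * in2 = out0".split(), "in0 * in2 = out0".split())  -> 1.5
```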
- GRPO Implementation: The reference policy π_ref is the SFT model. The group relative advantage R_i is calculated from the sum of r_answer and r_symbolic (see the sketch after these bullets).
- RL Hyperparameters: r_correct = 2.5, r_max = 1.5. GRPO: β = 0.04, ε = 0.2, group size G = 16. Sampling: temperature 0.9, top_p 1.0, top_k 50. Learning rate 5e-7 (AdamW).
- RL Training Time: 8 epochs on 8 A100-80GB GPUs took about 3-5 days per LLM.
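Putting the two rewards together, the group-relative advantage could be computed as in this sketch, following the standard GRPO normalization over each group of G rollouts (the paper's exact implementation may differ):

```python
import statistics

def group_relative_advantages(total_rewards, eps=1e-6):
    """GRPO-style advantage: z-normalize each rollout's total reward
    (r_answer + r_symbolic) against its group of G samples."""
    mu = statistics.mean(total_rewards)
    sigma = statistics.pstdev(total_rewards)
    return [(r - mu) / (sigma + eps) for r in total_rewards]

# For one abstract question with G = 16 sampled completions:
# advantages = group_relative_advantages([r_ans + r_sym for r_ans, r_sym in scores])
```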
Experimental Setup and Results
- Task: Mathematical reasoning on GSM8K-derived benchmarks.
- Datasets:
- GSM-Symbolic: 100 GSM8K test samples with templates to vary numbers, names, or both.
- GSM-Plus: Full GSM8K test set with variations like digit expansion, integer-decimal-fraction (int-dec-fra) conversion, numerical substitution, rephrasing, and distractor insertion.
- Models Tested: Llama-3.2 (1B, 3B), Llama-3.1 (8B), Qwen2.5 (0.5B, 1.5B, 3B, 7B, Math-7B), Mathstral-7B.
- Baselines: CoT-8S (8-shot prompting), CoT-RL (SFT+RL on Socratic CoT data with only r_answer), CoA (SFT-only abstract reasoning without input abstraction or GranulAR).
Key Findings:
- AbstRaL consistently improves accuracy on perturbed samples (Vary Both on GSM-Symbolic, various perturbations on GSM-Plus) and reduces the performance drop compared to original samples.
- It significantly mitigates interference from distracting conditions (Distract set in GSM-Plus), attributed to the GranulAR data format which encourages planning.
- For larger models, AbstRaL can outperform 8-shot prompting on perturbed data even when it scores slightly lower on original data, suggesting it mitigates overfitting to input conditions, possibly caused by pre-training data contamination.
- Ablation studies confirmed the importance of:
- Learning abstract reasoning (vs. just using tools).
- Constructing abstraction within contexts (X→X^→Y^→A) rather than direct abstraction generation (X→X^→A).
- The RL stage, especially the r_symbolic reward.
- The GranulAR data format for handling distractors.
Implementation Considerations
- Condition Recognition & GranulAR Data Generation: Leverages a powerful oracle LLM (Llama-3.3-70B-Instruct) with few-shot prompting. This is a one-time offline process but requires significant compute (e.g., 36 hours on 4 A100s for condition recognition on all data, similar times for GranulAR steps). Tables with prompt examples are provided in the appendix.
- Symbolic Tools: Relies on regex for abstraction retrieval and SymPy for symbolic derivation. These are generally efficient.
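Abstraction retrieval reduces to a short regex pass, roughly as below (an illustrative sketch, not the paper's exact script):

```python
import re

def retrieve_abstraction(granular_answer: str):
    """Collect the symbolic derivations enclosed in << >> from a
    GranulAR answer, yielding the de-contextualized abstraction A."""
    return [m.strip() for m in re.findall(r"<<(.*?)>>", granular_answer)]

# retrieve_abstraction("... <<in0 * in2 = out0>> ... <<in1 - out0 = out1>>")
# -> ["in0 * in2 = out0", "in1 - out0 = out1"]
```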
- Model Training: Full SFT and RL fine-tuning are required. RL training is computationally intensive (days on multiple high-end GPUs).
- Inference: Involves the four-step pipeline. Condition recognition can be handled by the trained LLM itself or by an external tool/oracle (the paper used an oracle LLM in its experiments); the trained LLM performs abstract reasoning, and regex plus SymPy handle the last two steps. Greedy decoding was used.
- Scalability: While AbstRaL aims to improve smaller models, the data generation uses a very large model. The training itself is standard fine-tuning.
Practical Application Example (Mathematical Word Problem)
Consider the problem: "Jaime places eggs on some trays. Each tray can hold 24 eggs. If he has 64 eggs and 2 trays, how many eggs won't he be able to place on the tray?"
- Condition Recognition:
  - `in0 = 24` (eggs per tray)
  - `in1 = 64` (total eggs)
  - `in2 = 2` (number of trays)
- Abstract Question X^: "...Each tray can hold [in0] eggs. If he has [in1] eggs and [in2] trays..."
- Abstract Reasoning (LLM generates GranulAR output Y^):
  - Q1: How many eggs can Jaime place on the trays?
    - Each tray holds [in0], Jaime has [in2] trays. Total capacity: `<< in0 * in2 = out0 >>`
  - Q2: How many eggs won't he be able to place?
    - Jaime has [in1] eggs, capacity is [out0]. Unplaced: `<< in1 - out0 = out1 >>`
  - Final answer is `out1`.
- Abstraction Retrieval: From Y^, extract A:
  - `in0 * in2 = out0`
  - `in1 - out0 = out1`
- Symbolic Derivation (using SymPy with C and A):
  - `24 * 2 = out0` => `out0 = 48`
  - `64 - out0 = out1` => `64 - 48 = out1` => `out1 = 16`
- Final Answer: 16
If the input numbers change (e.g., 39 crackers/tray, 302 total, 7 trays), the abstract reasoning steps and abstraction A remain the same. Only the values in C change, leading SymPy to calculate a new correct answer, demonstrating robustness.
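A minimal SymPy sketch of this invariance, solving the same abstraction under both condition sets (the symbol bindings follow the worked example; the equation-solver setup is an assumption, not the paper's exact code):

```python
from sympy import Eq, solve, symbols

in0, in1, in2, out0, out1 = symbols("in0 in1 in2 out0 out1")
abstraction = [Eq(in0 * in2, out0),   # total tray capacity
               Eq(in1 - out0, out1)]  # items that don't fit

def derive(conditions):
    """Substitute a concrete condition set C and solve for the outputs."""
    return solve([eq.subs(conditions) for eq in abstraction], [out0, out1])[out1]

print(derive({in0: 24, in1: 64, in2: 2}))   # 16 (original problem)
print(derive({in0: 39, in1: 302, in2: 7}))  # 29 (perturbed problem)
```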
Limitations
- Tested primarily on grade school math in English.
- Assumes perturbations don't change the underlying abstract reasoning steps.
- Requires full model tuning (LoRA or other PEFT methods could be explored).
- Used greedy decoding; advanced decoding strategies might yield further improvements.
In summary, AbstRaL provides a structured method to teach LLMs abstract reasoning by decomposing problems, using symbolic representations, and leveraging RL with specific rewards tied to abstraction quality and answer correctness. This leads to improved robustness against various input perturbations, especially for mathematical reasoning tasks.