This paper introduces Multi-Granularity Direct Preference Optimization (MDPO), a method designed to enhance the mathematical reasoning capabilities of LLMs. The core problem addressed is that LLMs often struggle with long-chain mathematical reasoning, where errors in any single step can lead to incorrect final answers. While Supervised Fine-Tuning (SFT) can improve these abilities, it often leads to hallucinations and doesn't effectively suppress incorrect outputs. Direct Preference Optimization (DPO) has been effective for general alignment but shows limited benefits in complex mathematical reasoning, as it struggles to pinpoint specific errors in long solution chains and its training objective can be inconsistent with generation metrics.
MDPO proposes to optimize LLMs at three distinct granularities, providing more targeted supervision signals:
- Solution2Solution (Sol2Sol): This level operates on the entire reasoning chain (solution). It provides coarse-grained supervision by comparing a complete correct solution ($y_w$) with a complete incorrect solution ($y_l$) for a given problem ($q$). This is similar to standard DPO.
- Inference2Inference (Infer2Infer): This level focuses on the logical transitions between individual reasoning steps. An "inference" is the generation of step $s_{i+1}$ from step $s_i$. If a particular inference leads to a higher error rate or onto an incorrect path, it is labeled as the rejected continuation $y_l$; a corrected or alternative successful inference serves as the preferred continuation $y_w$. This provides fine-grained supervision over the reasoning process.
- Step2Step: This level targets computational errors within a single reasoning step. If a step contains a calculation mistake, it is treated as the rejected step $y_l$ and its corrected version is provided as the preferred step $y_w$. This aims to directly improve the model's computational accuracy.
A key aspect of MDPO is its unified training objective, inspired by Simple Preference Optimization (SimPO) (Meng et al., 2024). The objective aligns the fine-tuning process with the downstream generation metric. The mathematical reasoning task is framed as a text-completion task: given a problem $q$ and the preceding correct steps $s_{1:i}$, the model must generate the remaining steps $s_{i+1:n}$ to reach the correct answer. This framing applies to all three granularities:
- Sol2Sol: $i = 0$ (no preceding steps); the model generates the entire solution starting from "Let’s think step by step."
- Infer2Infer & Step2Step: given $q$ and $s_{1:i}$, the model generates $s_{i+1}$ and the subsequent steps.
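To make the shared completion framing concrete, here is a minimal sketch of how the conditioning context could be assembled for each granularity. The `build_context` helper and the exact prompt wording are assumptions; the "[Step i]" tagging follows the data-construction description below.

```python
def build_context(problem: str, preceding_steps: list[str]) -> str:
    """Context the model conditions on before generating a preferred/rejected continuation.

    Sol2Sol: preceding_steps is empty, so the model completes the whole solution
    after "Let's think step by step."; Infer2Infer and Step2Step pass the verified
    prefix s_1..s_i, and the model completes from step i+1 onward.
    (Hypothetical prompt format, not taken verbatim from the paper.)
    """
    context = f"{problem}\nLet's think step by step.\n"
    for i, step in enumerate(preceding_steps, start=1):
        context += f"[Step {i}] {step}\n"
    return context
```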
The MDPO loss function is:

$$\mathcal{L}_{\text{MDPO}} = -\,\mathbb{E}_{(q,\,s_{1:i},\,y_w,\,y_l)}\Big[\log \sigma\big(\beta\, r(y_w \mid q, s_{1:i}) - \beta\, r(y_l \mid q, s_{1:i}) - \gamma\big)\Big]$$

where $q$ is the problem, $s_{1:i}$ are the correct preceding steps, $y_w$ is the preferred continuation, $y_l$ is the rejected continuation, $\beta$ is a scaling factor for the reward difference, and $\gamma$ is a target reward margin. The reward for a sequence $y$ given context $x$ is its average log-likelihood, normalized by length:

$$r(y \mid x) = \frac{1}{|y|} \sum_{t=1}^{|y|} \log \pi_\theta\big(y_t \mid x, y_{<t}\big).$$
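A minimal PyTorch sketch of this objective, assuming per-token log-probabilities for the preferred and rejected continuations have already been gathered; the tensor shapes, masking convention, and default `beta`/`gamma` values are placeholders, not the paper's settings.

```python
import torch
import torch.nn.functional as F

def length_normalized_reward(token_logps: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Average log-likelihood of a continuation, normalized by its length.

    token_logps: (batch, seq_len) log-probabilities of the continuation tokens.
    mask:        (batch, seq_len) 1 for continuation tokens, 0 for context/padding.
    """
    return (token_logps * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)

def mdpo_loss(logps_w, mask_w, logps_l, mask_l, beta=2.0, gamma=0.5):
    """SimPO-style preference loss applied at every granularity.

    beta scales the reward difference and gamma is the target reward margin;
    the default values here are illustrative, not those used in the paper.
    """
    r_w = length_normalized_reward(logps_w, mask_w)  # reward of preferred continuation
    r_l = length_normalized_reward(logps_l, mask_l)  # reward of rejected continuation
    return -F.logsigmoid(beta * (r_w - r_l) - gamma).mean()
```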
The paper also outlines a pipeline for automatically constructing the multi-granularity preference data without manual annotation:
- Sol2Sol Data:
1. Use an LLM to generate multiple reasoning paths for each problem, prepending "[Step i]" to each step.
2. Verify the paths against the dataset labels.
3. Select paths with correct final answers as $y_w$ and incorrect ones as $y_l$.
4. Prioritize problems for which the model generates both correct and incorrect solutions.
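A sketch of this selection logic under stated assumptions: `sample_paths` and `extract_answer` are hypothetical helpers standing in for the LLM sampling and answer-verification steps.

```python
import random

def build_sol2sol_pairs(problem, gold_answer, sample_paths, extract_answer, k=8):
    """Pair a correct solution (y_w) with an incorrect one (y_l) for the same problem.

    sample_paths(problem, k) -> list of k step-tagged reasoning paths (hypothetical helper).
    extract_answer(path)     -> final answer parsed from a path (hypothetical helper).
    """
    paths = sample_paths(problem, k)
    correct = [p for p in paths if extract_answer(p) == gold_answer]
    wrong = [p for p in paths if extract_answer(p) != gold_answer]
    # Prioritize problems where the model produces both correct and incorrect solutions.
    if not correct or not wrong:
        return []
    return [(problem, random.choice(correct), random.choice(wrong))]
```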
- Infer2Infer Data:
1. Use erroneous reasoning paths from Sol2Sol.
2. Segment each path into steps and create prefix windows $c_i = (q, s_1, \ldots, s_i)$.
3. For each window $c_i$, have the LLM generate and sample $k$ reasoning paths.
4. Calculate the error rate $e_i$ for each window: the fraction of the $k$ sampled paths whose final answer is incorrect.
5. Identify an unreliable step when the error rate increases ($e_{i+1} > e_i$); the transition from $s_i$ to $s_{i+1}$ is then labeled $y_l$.
6. Generate from $c_i$ again, sampling a reliable continuation with a correct final answer as $y_w$.
7. Construct preference pairs $(q, s_{1:i}, y_w, y_l)$.
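A sketch of the error-rate criterion as described above; `sample_continuations` and `extract_answer` are hypothetical helpers, and the threshold-free test `e[i+1] > e[i]` is one reading of "higher error rate".

```python
def find_unreliable_inference(problem, steps, gold_answer, sample_continuations, extract_answer, k=8):
    """Locate the first inference whose inclusion makes the path less reliable.

    sample_continuations(problem, prefix_steps, k) -> k completed solutions (hypothetical helper).
    Returns the prefix length i such that appending the next step raises the error rate,
    i.e. the transition s_i -> s_{i+1} is labeled as the rejected inference y_l.
    """
    error_rates = []
    for i in range(len(steps) + 1):
        completions = sample_continuations(problem, steps[:i], k)
        wrong = sum(extract_answer(c) != gold_answer for c in completions)
        error_rates.append(wrong / k)  # e_i: fraction of sampled paths with a wrong final answer
    for i in range(len(steps)):
        if error_rates[i + 1] > error_rates[i]:  # conditioning on one more step made things worse
            return i
    return None
```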
- Step2Step Data:
1. Use the selected problems, including new ones with complex calculations (numbers in the original problems replaced with more complex values).
2. The LLM generates reasoning paths, which are sampled and segmented into steps.
3. Use GPT-4 with dedicated prompts to find the first step containing a calculation error, $y_l$.
4. GPT-4 corrects it to $y_w$ and generates the rest of the solution.
5. Verify the modified solution via answer checking.
6. Construct preference data $(q, s_{1:i-1}, y_w, y_l)$, where $s_{1:i-1}$ are the steps preceding the erroneous step.
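A sketch of the resulting pair construction; `correct_first_calc_error` is a hypothetical stand-in for the GPT-4 prompting and correction steps, and `extract_answer` again handles answer checking.

```python
def build_step2step_pair(problem, steps, gold_answer, correct_first_calc_error, extract_answer):
    """Pair the first erroneous step (y_l) with its corrected version (y_w).

    correct_first_calc_error(problem, steps) -> (err_idx, corrected_steps), where
    corrected_steps replaces steps[err_idx:] starting from the fixed calculation
    (a hypothetical stand-in for the GPT-4 prompting step).
    """
    err_idx, corrected_steps = correct_first_calc_error(problem, steps)
    # Keep the pair only if the corrected solution verifiably reaches the gold answer.
    if extract_answer(corrected_steps[-1]) != gold_answer:
        return None
    prefix = steps[:err_idx]     # shared correct context s_1..s_{i-1}
    y_l = steps[err_idx]         # step containing the calculation error
    y_w = corrected_steps[0]     # corrected step; the regenerated tail is used only for verification here
    return (problem, prefix, y_w, y_l)
```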
Experiments were conducted on Qwen2-7B-Instruct and Llama3-8B-Instruct, evaluated on the GSM8K and MATH datasets. Training used 30,000 preference data pairs for 8 epochs with a global batch size of 128 and a learning rate of 5e-7.
Key Results:
- Main Performance:
- Qwen2-7B-Instruct + MDPO: +1.7% on GSM8K, +2.3% on MATH.
- Llama3-8B-Instruct + MDPO: +0.9% on GSM8K, +1.2% on MATH.
- Comparison with other methods (on Qwen2-7B-Instruct):
- MDPO (GSM8K: 83.4%, MATH: 56.5%) outperformed DPO (GSM8K: 81.9%, MATH: 54.6%), SimPO (GSM8K: 82.1%, MATH: 54.9%), and Step-DPO (GSM8K: 82.1%, MATH: 55.1%).
- The improvement over Step-DPO was notable on MATH (1.4% absolute), attributed to MDPO's additional focus on computational capabilities via Step2Step.
- Ablation Study (on Qwen2-7B-Instruct, GSM8K):
- Base: 81.7%
- + Sol2Sol: 82.5%
- + Sol2Sol + Infer2Infer: 83.2% (Infer2Infer contributed most to reasoning improvement)
- + Sol2Sol + Infer2Infer + Step2Step (Full MDPO): 83.4%
- Computational Ability (Step2Step only on Qwen2-7B-Instruct):
- On GSM-HARD: +3.4% (45.5% vs 42.1% base)
- On MATH: +1.7% (55.9% vs 54.2% base)
- This demonstrated Step2Step's effectiveness in enhancing computational skills, outperforming DPO and Step-DPO on these complex datasets.
- Training Objective Alignment: MDPO significantly increased the proportion of instances where the model assigns a higher probability to the preferred answer ($y_w$) than to the rejected answer ($y_l$), unlike DPO and Step-DPO. This is attributed to aligning the reward function with generation metrics and unifying fine-tuning with the downstream task. (See Figure 2 in the paper for the win-rate comparison.)
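A sketch of how this alignment could be measured, reusing the length-normalized reward from the loss sketch above; `reward_fn` and the held-out `eval_pairs` are assumptions.

```python
def preference_win_rate(eval_pairs, reward_fn):
    """Fraction of held-out pairs where the model scores y_w above y_l.

    eval_pairs: iterable of (context, y_w, y_l) triples (hypothetical evaluation set).
    reward_fn(context, continuation) -> length-normalized average log-likelihood
    under the fine-tuned model, i.e. the reward used in the MDPO loss sketch.
    """
    wins = sum(reward_fn(ctx, y_w) > reward_fn(ctx, y_l) for ctx, y_w, y_l in eval_pairs)
    return wins / len(eval_pairs)
```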
Practical Implementation Considerations:
- Data Construction: The automated data pipeline is a significant practical advantage, reducing reliance on manual annotation. However, it involves multiple LLM generation and verification steps (including calls to GPT-4 for Step2Step error correction), which can be computationally intensive and may depend on the quality of the LLMs used in the pipeline.
- Computational Resources: While experiments were on 7B/8B models, the training (8 epochs, batch size 128) still requires considerable GPU resources. The authors suggest greater improvements may be seen on larger models.
- Hyperparameter Tuning: The $\beta$ (reward scaling) and $\gamma$ (reward margin) parameters in the loss function may require tuning for optimal performance on different models or datasets.
- Granularity Trade-offs: While all three granularities contribute, Infer2Infer seems most crucial for general reasoning. For tasks heavy on computation, Step2Step becomes more important. The mix of data from these granularities in the 30,000 pairs was not specified but could be a factor to optimize.
- Base Model Choice: The method uses instruct-tuned models as a starting point, which is standard in RLHF pipelines. The quality of this initial SFT model can impact MDPO's effectiveness.
In conclusion, MDPO offers a promising approach to improving mathematical reasoning in LLMs by providing more detailed, multi-level supervision signals and aligning the training objective with generation metrics. Its automated data construction pipeline makes it more feasible to implement. The results indicate consistent gains over existing DPO variants, particularly on complex reasoning and computation.