Self-Training with Direct Preference Optimization Improves Chain-of-Thought Reasoning
The paper "Self-Training with Direct Preference Optimization Improves Chain-of-Thought Reasoning" explores methods to enhance the mathematical reasoning capabilities of small-scale LLMs (LMs). The core idea revolves around improving traditional self-training frameworks using a technique called Direct Preference Optimization (DPO). This approach leverages the preference data to refine the model training process, guiding LMs during pseudo-label generation, making their outputs both more accurate and diverse.
The paper emphasizes two main aspects: reinforcing the reasoning capabilities of smaller LMs, and doing so more efficiently than with large proprietary models. The investigation is motivated by the high computational and economic costs of using large models such as Codex, PaLM, and GPT-4 for reasoning tasks. Smaller models offer a more cost-effective alternative but require methods that boost their inherent capabilities without significant additional resources.
Methodology
The authors introduce DPO-augmented self-training as an enhancement of traditional self-training. Each iteration of the method alternates between two primary steps (a code sketch follows the list):
- DPO Step: The model is refined to produce higher-quality rationales using a DPO-based objective. This step relies on a preference dataset built from multiple outputs sampled from the model itself, with correct reasoning paths labeled as preferred and incorrect ones as rejected.
- SFT (Supervised Fine-Tuning) Step: Using the improved model from the DPO step, new pseudo-labeled data are generated. The correct and unique rationales are then added to the training set for further fine-tuning.
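The following Python sketch shows how one iteration might be organized. The helper functions (`sample_rationales`, `extract_answer`, `dpo_update`, `sft_update`) and the way correct rationales are paired with incorrect ones are illustrative assumptions, not the authors' actual implementation:

```python
def dpo_augmented_self_training_iteration(model, ref_model, train_questions, answers,
                                          num_samples=8):
    """One iteration of DPO-augmented self-training (illustrative sketch).

    `answers` maps each question to its gold final answer; the helper functions
    are assumed to exist and are not part of the paper's released code.
    """
    # --- DPO step: build preference pairs from the model's own samples ---
    preference_pairs = []
    for q in train_questions:
        rationales = sample_rationales(model, q, n=num_samples)
        correct = [r for r in rationales if extract_answer(r) == answers[q]]
        wrong = [r for r in rationales if extract_answer(r) != answers[q]]
        # Pair each correct rationale with an incorrect one: (prompt, chosen, rejected).
        for chosen, rejected in zip(correct, wrong):
            preference_pairs.append((q, chosen, rejected))
    model = dpo_update(model, ref_model, preference_pairs)  # optimize the DPO objective

    # --- SFT step: regenerate pseudo-labels with the improved model ---
    pseudo_labeled, seen = [], set()
    for q in train_questions:
        for r in sample_rationales(model, q, n=num_samples):
            if extract_answer(r) == answers[q] and (q, r) not in seen:  # correct and unique
                seen.add((q, r))
                pseudo_labeled.append((q, r))
    model = sft_update(model, pseudo_labeled)  # standard supervised fine-tuning
    return model
```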
Additionally, to boost performance on arithmetic, the researchers integrate an external calculator into the reasoning process. They propose a method for scaling calculator use to batched inference, overcoming the limitations of existing approaches that handle only one sequence at a time.
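The summary does not describe the calculator interface, so the sketch below is a minimal illustration under two assumptions: rationales mark arithmetic with GSM8K-style `<<expression=result>>` annotations, and decoding pauses after `=` so the result can be spliced into each sequence of a batch independently:

```python
import re

# Assumed GSM8K-style calculator annotation: arithmetic is written as <<expression=result>>.
# Assumption: the model emits "<<expression=" and the decoder then injects the result.
CALC_PATTERN = re.compile(r"<<([0-9+\-*/(). ]+)=$")

def inject_calculator_results(batch_texts):
    """Append the evaluated result to any sequence in the batch whose partially
    decoded text ends with an open calculator call; leave the others unchanged."""
    updated = []
    for text in batch_texts:
        match = CALC_PATTERN.search(text)
        if match:
            try:
                result = eval(match.group(1), {"__builtins__": {}})  # arithmetic only
                updated.append(f"{text}{result}>>")
            except (SyntaxError, ZeroDivisionError):
                updated.append(text)  # malformed expression: leave it untouched
        else:
            updated.append(text)
    return updated

# Two sequences from the same batch; only the first has an open calculator call.
print(inject_calculator_results([
    "She has 3 boxes with 12 pens each, so <<3*12=",
    "First, find the total cost of the apples.",
]))
# -> ['She has 3 boxes with 12 pens each, so <<3*12=36>>',
#     'First, find the total cost of the apples.']
```

In a real decoder, a check like this would run at every generation step, so sequences not currently inside a calculator call keep decoding unchanged.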
Experiments and Results
The authors conducted experiments with Flan-T5 and Llama models on three datasets for training and evaluation: GSM8K, MultiArith, and ASDiv. The results show marked improvements for models trained with DPO-augmented self-training over both traditional self-training and supervised fine-tuning. For instance, on GSM8K the Flan-T5-Large model improved from 35.6% accuracy with standard self-training to 37.4% with the DPO-augmented approach.
A key observation is the additional performance boost from integrating the external calculator in smaller models: the Flan-T5-Large model reached 40% accuracy on GSM8K, surpassing other reported results for comparably sized models. The iterative training regime delivered consistent gains across iterations, underscoring the robustness of the proposed method.
Discussion
The integration of DPO into self-training frameworks illustrates an efficient paradigm for enhancing small-scale LMs without the substantial costs tied to larger models. The empirical results suggest that models fine-tuned with DPO can generate higher-quality pseudo-labeled data, leading to continuous improvement with each iteration. This iterative refinement is particularly useful in scenarios with limited access to large annotated datasets.
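Under the same assumptions as the sketch above, the iterative regime amounts to repeatedly handing each round's improved model to the next round; whether the DPO reference model is refreshed between rounds is a detail not stated in this summary and is marked as an assumption below:

```python
# Run several rounds of DPO-augmented self-training; each round's model
# generates the preference data and pseudo-labels for the next round.
for _ in range(num_iterations):
    model = dpo_augmented_self_training_iteration(model, ref_model,
                                                  train_questions, answers)
    ref_model = model  # assumption: refresh the DPO reference model each round
```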
The research also underscores the impact of computational tools at inference time. Incorporating an external calculator improved performance by reducing arithmetic errors, a common shortfall of smaller models. This adaptability could have broader implications for improving the precision of LMs on tasks that require intricate, multi-step reasoning beyond arithmetic, such as code generation and complex problem-solving.
Implications and Future Directions
From a practical standpoint, the demonstrated effectiveness of DPO-augmented self-training offers a scalable and economical path to stronger reasoning in LMs. The method reduces reliance on large-scale annotation and proprietary large models, balancing performance with resource efficiency.
Theoretically, the success of DPO in fine-tuning models using self-generated data offers insights into preference-guided learning. Future research could explore the application of this framework across different domains and tasks. Additionally, integrating knowledge distillation into the iterative DPO-self-training process may further refine model performance, creating a more synergistic approach that leverages both self-improvement and external expert models.
In conclusion, the paper provides valuable contributions by proposing a novel and effective method for improving the chain-of-thought reasoning in small-scale LMs. This work is meaningful both for its immediate practical benefits and for setting a foundational approach that can be built upon in future AI developments.