- The paper introduces PromptPG, a reinforcement learning approach that selects optimal in-context examples to enhance GPT-3's performance on math word problems.
- It presents the TabMWP dataset with over 38K grade-level problems integrating text and tabular data for multi-step reasoning.
- Experiments show PromptPG improves accuracy by 5.31% over baselines, stabilizing predictions on semi-structured problems.
Dynamic Prompt Learning via Policy Gradient for Semi-structured Mathematical Reasoning
The paper investigates the ability of large language models (LLMs), specifically GPT-3, to solve math word problems (MWPs) that incorporate semi-structured data. Building on the recent success of LLMs in natural language processing tasks, this research explores an approach called PromptPG, which harnesses policy gradient methods to dynamically select prompts for problem-solving.
Overview and Dataset
The authors introduce a new dataset, Tabular Math Word Problems (TabMWP), which consists of 38,431 open-domain, grade-level problems requiring mathematical reasoning over textual and tabular data. Each problem comprises a question and a corresponding tabular context presented in multiple formats: image, semi-structured text, and structured table. This dataset adds complexity to the standard MWP tasks as it demands integration of heterogeneous information.
TabMWP consists of free-text and multiple-choice questions, annotated with gold solutions to elucidate the multi-step reasoning required to solve each problem. The ability to reason over diverse data types in this context represents a significant evolution beyond traditional MWP datasets that primarily involve unstructured text.
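To make the problem format concrete, the following is an illustrative record in the spirit of a TabMWP item, paired with the multi-step arithmetic its gold solution walks through. The field names and the example itself are hypothetical, not the dataset's actual schema:

```python
# Illustrative TabMWP-style problem (field names are hypothetical,
# not the dataset's actual schema).
problem = {
    "table_title": "Fruit stand prices",
    "table": [["Item", "Price"],
              ["apple", "$0.50"],
              ["banana", "$0.25"],
              ["orange", "$0.75"]],
    "question": "How much would 3 apples and 2 bananas cost?",
    "question_type": "free_text",   # TabMWP also has multiple-choice items
    "solution": ("3 apples cost 3 * $0.50 = $1.50. "
                 "2 bananas cost 2 * $0.25 = $0.50. "
                 "Total: $1.50 + $0.50 = $2.00."),
    "answer": "$2.00",
}

def price(table, item):
    """Look up an item's price in the semi-structured table (skip header row)."""
    for row in table[1:]:
        if row[0] == item:
            return float(row[1].lstrip("$"))
    raise KeyError(item)

# Reproduce the gold solution's reasoning steps programmatically.
total = 3 * price(problem["table"], "apple") + 2 * price(problem["table"], "banana")
```

Answering such a question requires first locating the relevant cells in the table and then chaining the arithmetic, which is exactly the heterogeneous, multi-step reasoning the dataset is designed to test.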
Existing Challenges
While GPT-3 and its few-shot capabilities are significant advancements, their performance on complex problems such as those in TabMWP can be unstable, often deteriorating to near chance levels depending on the selection of in-context examples. This instability arises mainly when models need to handle multi-faceted data derived from varied question types and table structures.
Proposed Solution: PromptPG
To address this, the paper proposes PromptPG, a novel method using reinforcement learning. By applying policy gradient techniques, PromptPG learns to select the most suitable in-context examples to include in the prompts fed to the model. This contrasts with the random selection strategy traditionally employed, which leaves performance unstable and heavily dependent on which examples happen to be drawn.
The approach trains an agent that interacts with GPT-3: the agent chooses in-context examples, and the correctness of GPT-3's resulting prediction serves as the reward signal. PromptPG builds on BERT-generated embeddings and applies a policy gradient strategy to guide the agent's decisions, ultimately reducing prediction variance and improving overall accuracy.
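The selection loop described above can be sketched as a REINFORCE-style policy over candidate embeddings. This is a minimal toy sketch, not the paper's implementation: random vectors stand in for BERT embeddings, a simple linear bilinear score replaces the paper's policy network, and a mock reward (+1 when a designated "good" candidate is picked) stands in for querying GPT-3 and checking its answer:

```python
import numpy as np

rng = np.random.default_rng(0)
EMB_DIM, N_CANDIDATES, K = 8, 5, 2   # select K in-context examples per prompt

# Placeholder embeddings; in PromptPG these would come from BERT.
cand_embs = rng.normal(size=(N_CANDIDATES, EMB_DIM))
prob_emb = rng.normal(size=EMB_DIM)

W = np.zeros((EMB_DIM, EMB_DIM))     # linear policy parameters

def select_examples(problem_emb, candidate_embs, rng):
    """Score candidates and sample K without replacement from a softmax policy."""
    scores = candidate_embs @ W @ problem_emb
    probs = np.exp(scores - scores.max())
    probs /= probs.sum()
    chosen = rng.choice(len(candidate_embs), size=K, replace=False, p=probs)
    return chosen, probs

def reinforce_update(problem_emb, candidate_embs, chosen, probs, reward, lr=0.01):
    """REINFORCE: scale the log-probability gradient of each chosen example by the reward.
    (Sampling without replacement is treated as independent draws for simplicity.)"""
    global W
    expected = probs @ candidate_embs          # policy-weighted mean embedding
    for i in chosen:
        # grad of log softmax prob of candidate i w.r.t. W
        W += lr * reward * np.outer(candidate_embs[i] - expected, problem_emb)

# Mock training loop: reward +1 if candidate 0 (the "good" example) is chosen.
for _ in range(100):
    chosen, probs = select_examples(prob_emb, cand_embs, rng)
    reward = 1.0 if 0 in chosen else -1.0
    reinforce_update(prob_emb, cand_embs, chosen, probs, reward)

_, final_probs = select_examples(prob_emb, cand_embs, rng)
```

After training, the policy assigns candidate 0 a much higher selection probability than the initial uniform 1/5, which mirrors how PromptPG shifts probability mass toward examples that yield correct GPT-3 predictions.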
Experimental Validation
The authors conduct extensive experiments benchmarking several existing QA methods, GPT-3 in various settings, and PromptPG on the TabMWP dataset. The results indicate that PromptPG surpasses the best baseline by 5.31% in accuracy, a noteworthy performance increase, and that it stabilizes prediction outcomes significantly better than random selection. Additionally, the research explores how factors such as training set size and the number of candidate examples affect the learning algorithm, identifying configurations that yield the best results.
Implications and Future Directions
This work demonstrates significant implications for the QA and AI fields. First, it establishes a new benchmark with TabMWP for evaluating models on problems requiring reasoning across structured and unstructured data modalities. Furthermore, PromptPG's framework underscores the potential of integrating reinforcement learning with LLMs to enhance their problem-solving capabilities in semi-structured environments.
Prospective research could explore scaling PromptPG to handle even more complex datasets or adapting it to other domains involving structured data. Additionally, further refining extraction methodologies or integrating holistic reasoning strategies could drive improvements in downstream applications utilizing AI for semi-structured data reasoning tasks.
In summary, this paper contributes an innovative dataset and a reinforcement learning-based approach that together push the boundaries of machine reasoning in handling complex mathematical problems interwoven with multi-modal data.