PromptPG: RL-Based Dynamic Prompt Construction
- PromptPG is a reinforcement learning-based method that optimizes the selection of in-context examples to improve few-shot mathematical reasoning in LLMs.
- It formalizes prompt selection as a policy optimization problem using a BERT-based encoder and policy gradient to maximize GPT-3 accuracy.
- Experiments on TabMWP show a 5.31% accuracy gain and reduced variance compared to heuristic and random selection strategies.
PromptPG is a reinforcement learning-based method for dynamic prompt construction in LLMs, specifically targeting the robust selection of in-context examples to improve few-shot performance on mathematical reasoning tasks involving semi-structured data such as tables. The approach formalizes the prompt selection process as a policy optimization problem and employs the policy gradient algorithm for learning. PromptPG demonstrates substantial improvements over heuristic or random selection, particularly in terms of both accuracy and prediction stability on benchmarks such as TabMWP (Lu et al., 2022).
1. Problem Formulation and Motivation
PromptPG addresses the instability and poor performance observed in few-shot GPT-3 prompting for semi-structured mathematical reasoning tasks. This instability is exacerbated for tasks requiring multi-modal comprehension (e.g., table understanding) and compositional reasoning, where naive random or fixed heuristic selection of prompt examples often causes accuracy to degrade to near chance. PromptPG casts in-context example selection as a reinforcement learning (RL) problem: it parameterizes the selection policy and optimizes it so that the LLM (here, GPT-3) produces correct answers as often as possible. The target benchmark is TabMWP, a dataset of 38,431 problems that require reasoning over both tabular and natural language data (Lu et al., 2022).
2. Reinforcement Learning Framework
Within the RL formalization, each test math problem $p_i$ is a state, and the action is the selection of a $K$-element subset $e_i$ of in-context examples from a candidate set $E_{\mathrm{cand}}$. The policy $\pi_\theta$, parameterized by $\theta$, is a neural network that assigns probabilities over subsets of candidates given the test problem. Rewards are obtained by constructing a prompt from the selected examples and the test problem, feeding it to a frozen GPT-3 model (specifically, text-davinci-002 at zero temperature), and assigning a reward of $+1$ if the output matches the gold answer $a_i$ and $-1$ otherwise. Matching is normalized by answer type: free-text answers are compared at two decimal places, and multi-choice outputs are mapped to the nearest option string (Lu et al., 2022).
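The reward check described above can be sketched in a few lines of Python. This is a minimal illustration rather than the authors' code: it assumes a reward of +1 for a match and -1 otherwise, and the function names `normalize_answer` and `reward`, the comma stripping, and the use of `difflib` for nearest-string matching are assumptions made for the example.

```python
from difflib import get_close_matches
from typing import Optional, Sequence


def normalize_answer(text: str, choices: Optional[Sequence[str]] = None) -> str:
    """Map a raw answer string to a canonical form for comparison."""
    text = text.strip()
    if choices:
        # Multi-choice: snap the output to the closest option string.
        match = get_close_matches(text, list(choices), n=1, cutoff=0.0)
        return match[0] if match else text
    try:
        # Free-text numeric answers: compare at two decimal places.
        # Stripping thousands separators is an assumption of this sketch.
        return f"{float(text.replace(',', '')):.2f}"
    except ValueError:
        return text.lower()


def reward(prediction: str, gold: str, choices: Optional[Sequence[str]] = None) -> int:
    """Return +1 if prediction and gold agree after normalization, else -1."""
    same = normalize_answer(prediction, choices) == normalize_answer(gold, choices)
    return 1 if same else -1
```

With this normalization, `reward("3.50", "3.5")` counts as a match, while a multi-choice output such as `"apples"` is first mapped to the option `"apple"` before comparison.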
The objective is the maximization of
$J(\theta) = \mathbb{E}_{p_i} \mathbb{E}_{e_i \sim \pi_\theta(\cdot|p_i)} \left[ R(\mathrm{GPT\mbox{-}3}(e_i, p_i)) \right].$
The policy is updated using REINFORCE:
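$\nabla_\theta J(\theta) = \mathbb{E}_{p_i} \mathbb{E}_{e_i \sim \pi_\theta(\cdot|p_i)} \left[ \nabla_\theta \log \pi_\theta(e_i|p_i) \, R(\mathrm{GPT\mbox{-}3}(e_i, p_i)) \right] \approx \frac{1}{N} \sum_{i=1}^{N} \nabla_\theta \log \pi_\theta(e_i|p_i) \, R(\mathrm{GPT\mbox{-}3}(e_i, p_i)), \quad e_i \sim \pi_\theta(\cdot|p_i).$
Here $N$ is the number of problems in a mini-batch.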
In practice, this is implemented via mini-batch sampling from the training set: each batch yields one stochastic gradient step on the negative reward-weighted log-likelihood of the sampled examples. The policy network uses a frozen BERT encoder; scores are computed via dot products in the embedding space, followed by a softmax over candidate examples (Lu et al., 2022).
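For a softmax policy, the score-function term has a closed form: the gradient of $\log \pi(a)$ with respect to logit $z_j$ is $\mathbb{1}[j = a] - \pi_j$. The sketch below uses that identity to compute the surrogate loss and its gradient for one mini-batch. It is a toy illustration under simplifying assumptions: the policy is a bare logit vector rather than a BERT-based scorer, and `reinforce_batch` is a hypothetical helper name.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_batch(
    logits: Sequence[float],
    samples: Sequence[Tuple[int, float]],
) -> Tuple[float, List[float]]:
    """Return the surrogate loss and its gradient w.r.t. the logits.

    `samples` holds (chosen_index, reward) pairs drawn under the current policy.
    Loss = -mean(reward * log pi(chosen)).
    """
    probs = softmax(logits)
    n = len(samples)
    loss = 0.0
    grad = [0.0] * len(logits)
    for action, rew in samples:
        loss -= rew * math.log(probs[action]) / n
        for j, p in enumerate(probs):
            indicator = 1.0 if j == action else 0.0
            # d(-rew * log pi(action)) / d z_j = -rew * (1[j == action] - pi_j)
            grad[j] -= rew * (indicator - p) / n
    return loss, grad


logits = [0.0, 0.0, 0.0]
loss, grad = reinforce_batch(logits, [(0, 1.0), (2, -1.0)])
# One gradient-descent step on the loss raises the logit of the rewarded
# choice (index 0) and lowers the logit of the penalized choice (index 2).
updated = [z - 0.5 * g for z, g in zip(logits, grad)]
```

Minimizing this loss by gradient descent is equivalent to gradient ascent on the expected reward, which is why the rewarded choice becomes more likely after the step.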
3. Prompt Construction and Data Representation
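Selected in-context examples $e_i$ are concatenated to build the prompt. Each example's table is serialized as a plain text structure (rows joined by newlines, columns by the pipe symbol), followed by the corresponding question, the (optional) choices, the solution (a step-by-step chain-of-thought description), and the answer. The final prompt for GPT-3 consists of the selected demonstrations in order, followed by the test problem's table, question, and (optional) choices, with the solution and answer left for the model to generate.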
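The prompt layout described above can be illustrated with a short Python sketch. The field labels (`Table:`, `Question:`, `Options:`, `Answer:`), the "The answer is ..." suffix, the dictionary keys, and the function names are all assumptions made for this example; the exact template used by Lu et al. (2022) may differ.

```python
from typing import Any, Dict, List, Sequence


def serialize_table(rows: Sequence[Sequence[Any]]) -> str:
    """Join cells with ' | ' and rows with newlines."""
    return "\n".join(" | ".join(str(cell) for cell in row) for row in rows)


def format_problem(problem: Dict[str, Any], with_solution: bool) -> str:
    """Render one problem; demonstrations also carry the solution and answer."""
    parts: List[str] = [
        f"Table:\n{serialize_table(problem['table'])}",
        f"Question: {problem['question']}",
    ]
    if problem.get("choices"):
        parts.append("Options: " + ", ".join(problem["choices"]))
    if with_solution:
        parts.append(f"Answer: {problem['solution']} The answer is {problem['answer']}.")
    else:
        # The test problem ends with an open slot for the model to complete.
        parts.append("Answer:")
    return "\n".join(parts)


def build_prompt(examples: Sequence[Dict[str, Any]], test_problem: Dict[str, Any]) -> str:
    """Concatenate the demonstrations, then append the unsolved test problem."""
    blocks = [format_problem(example, with_solution=True) for example in examples]
    blocks.append(format_problem(test_problem, with_solution=False))
    return "\n\n".join(blocks)
```

Keeping serialization, per-problem formatting, and concatenation in separate functions makes it easy to swap the template without touching the selection policy.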
This explicit inclusion of gold chain-of-thoughts in the demonstration phase guides the LLM to provide stepwise reasoning (Lu et al., 2022).
4. Experimental Protocol and Results
PromptPG is trained and evaluated on the TabMWP dataset, partitioned into 23,059 training, 7,686 development, and 7,686 test examples. Each training batch consists of randomly selected training problems; for each problem, the $K$ in-context slots ($K = 2$ in the 2-shot setting reported below) are filled from a pool of pre-sampled candidates that excludes the problem itself. The policy is a trainable linear transformation atop frozen BERT representations.
The method is benchmarked against several baselines:
| Method | Accuracy (%) | Std. Dev. (%) |
|---|---|---|
| Random (2-shot) | 62.92 | ≈2.30 |
| Heuristic (type) | 66–68 | Higher |
| Retrieval (NN) | 68.2 | – |
| PromptPG | 68.23 | 1.27 |
PromptPG achieves an absolute gain of 5.31 percentage points in accuracy over the random 2-shot baseline (68.23% vs. 62.92%), with variance (measured as standard deviation across reruns) reduced from ≈4.0% (random) to ≈1.3% (Lu et al., 2022).
Ablation studies indicate optimal candidate pool sizes (20–40) and a peak in training performance at 160 RL training examples, beyond which stochastic instability increases (Lu et al., 2022).
5. Policy Architecture and Optimization
The prompt selection policy operates in embedding space, with:
- Input representation using BERT [CLS] embeddings for both the test problem $p_i$ and each candidate example.
- Scoring each candidate $e$ by the dot product $h(e)^\top h(p_i)$, where $h(x) = W\,\mathrm{BERT}(x) + b$ with trainable $W$ and $b$.
- Softmax normalization to obtain selection probabilities.
- Two independent draws for 2-shot prompting, with replacement.
- Optimization with Adam (learning rate as reported in Lu et al., 2022), early stopping, and optional reward baselines to reduce gradient variance.
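The scoring and sampling steps listed above can be sketched in plain Python. This is a simplified illustration, not the reference implementation: the embeddings are small hand-written vectors standing in for frozen BERT [CLS] outputs, and the names `project`, `selection_probs`, and `sample_examples` are hypothetical.

```python
import math
import random
from typing import List, Optional, Sequence

Vector = Sequence[float]


def project(W: Sequence[Vector], b: Vector, x: Vector) -> List[float]:
    """Apply the trainable linear map h(x) = W x + b to an embedding."""
    return [sum(w * xi for w, xi in zip(row, x)) + bias for row, bias in zip(W, b)]


def selection_probs(
    W: Sequence[Vector], b: Vector, problem_emb: Vector, candidate_embs: Sequence[Vector]
) -> List[float]:
    """Softmax over dot products between projected candidates and the problem."""
    h_p = project(W, b, problem_emb)
    scores = [
        sum(u * v for u, v in zip(project(W, b, cand), h_p)) for cand in candidate_embs
    ]
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def sample_examples(
    probs: Sequence[float], k: int = 2, rng: Optional[random.Random] = None
) -> List[int]:
    """Draw k candidate indices independently, with replacement."""
    rng = rng or random.Random()
    return rng.choices(range(len(probs)), weights=probs, k=k)


# Toy usage: identity projection, three candidates in a 2-d embedding space.
W = [[1.0, 0.0], [0.0, 1.0]]
b = [0.0, 0.0]
probs = selection_probs(W, b, [1.0, 0.0], [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]])
picked = sample_examples(probs, k=2, rng=random.Random(0))
```

In the toy usage, the candidate aligned with the problem embedding receives the highest selection probability and the opposed candidate the lowest.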
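The core training loop repeats the following steps for each mini-batch: sample training problems, score every candidate against each problem, draw the in-context examples from the resulting distribution, query the frozen GPT-3 with the assembled prompt, compare the prediction with the gold answer to obtain the reward, and take a gradient step on the reward-weighted negative log-likelihood (Lu et al., 2022).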
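A toy end-to-end version of such a training loop is sketched below. Several things are assumptions made purely so the example is self-contained: the policy is one logit per candidate instead of a BERT-based scorer, plain gradient ascent replaces Adam, and `mock_llm_reward` is a hypothetical stand-in for querying the frozen LLM and grading its answer (it returns +1 only when one designated "helpful" demonstration is selected).

```python
import math
import random
from typing import List, Sequence

NUM_CANDIDATES = 5
HELPFUL = 3  # hypothetical: the only demonstration that helps the mock LLM


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mock_llm_reward(selected: Sequence[int]) -> int:
    """Stand-in for calling the frozen LLM and comparing with the gold answer."""
    return 1 if HELPFUL in selected else -1


def train(steps: int = 200, batch_size: int = 16, shots: int = 2,
          lr: float = 0.5, seed: int = 0) -> List[float]:
    """Run REINFORCE on the toy problem and return final selection probabilities."""
    rng = random.Random(seed)
    logits = [0.0] * NUM_CANDIDATES
    for _ in range(steps):
        probs = softmax(logits)
        grad = [0.0] * NUM_CANDIDATES
        for _ in range(batch_size):
            # Draw the in-context examples with replacement under the policy.
            selected = rng.choices(range(NUM_CANDIDATES), weights=probs, k=shots)
            rew = mock_llm_reward(selected)
            # Accumulate reward * grad log pi for every draw in the selection.
            for action in selected:
                for j in range(NUM_CANDIDATES):
                    indicator = 1.0 if j == action else 0.0
                    grad[j] += rew * (indicator - probs[j]) / batch_size
        # Gradient ascent on the estimated expected reward.
        logits = [z + lr * g for z, g in zip(logits, grad)]
    return softmax(logits)


final_probs = train()
```

After training, the probability mass concentrates on the helpful candidate, which mirrors in miniature how the learned policy comes to prefer demonstrations that lead the LLM to correct answers.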
6. Analysis, Ablations, and Baselines
PromptPG outperforms random, heuristic, and simple nearest-neighbor selection methods by shifting from "semantic similarity" to empirical optimization for downstream accuracy. Heuristics that match question/answer type or grade level attain 66–68% but exhibit higher variance, while nearest-neighbor semantic retrieval reaches 68.2% with similar variance. Through its learned RL policy, PromptPG discovers in-context demonstration sets that empirically maximize reward for GPT-3, outperforming all fixed or hand-crafted strategies in both mean accuracy and stability (Lu et al., 2022).
7. Implications and Scope
PromptPG demonstrates that the critical leverage in few-shot LLM-based reasoning over semi-structured data lies in the automated, data-driven selection of demonstration exemplars. The method is simple (a single trainable linear layer, with weight $W$ and bias $b$, atop frozen BERT), lightweight, model-agnostic in principle (it only requires a frozen LLM with in-context learning capability), and needs little training data (peak performance at ≈160 RL problems). The reduced variance and strong empirical gains suggest that RL-driven prompt construction can supplant ad hoc heuristic prompt assembly, especially in multi-modal or multi-step reasoning scenarios (Lu et al., 2022).
A plausible implication is that as LLMs are increasingly deployed in complex data environments, scalable RL-based prompting policies such as PromptPG will become essential for extracting consistent, high-quality model behavior. Future extensions may address tasks with more complex or higher K-shot settings, or explore alternative or more granular reward structures.