
Training a Generally Curious Agent (2502.17543v3)

Published 24 Feb 2025 in cs.LG, cs.AI, and cs.CL

Abstract: Efficient exploration is essential for intelligent systems interacting with their environment, but existing LLMs often fall short in scenarios that require strategic information gathering. In this paper, we present Paprika, a fine-tuning approach that enables LLMs to develop general decision-making capabilities that are not confined to particular environments. By training on synthetic interaction data from different tasks that require diverse strategies, Paprika teaches models to explore and adapt their behavior on a new task based on environment feedback in-context without more gradient updates. Experimental results show that models fine-tuned with Paprika can effectively transfer their learned decision-making capabilities to entirely unseen tasks without additional training. Unlike traditional training, our approach's primary bottleneck lies in sampling useful interaction data instead of model updates. To improve sample efficiency, we propose a curriculum learning strategy that prioritizes sampling trajectories from tasks with high learning potential. These results suggest a promising path towards AI systems that can autonomously solve novel sequential decision-making problems that require interactions with the external world.

Summary

  • The paper introduces PAPRIKA, a fine-tuning framework that enables LLMs to perform adaptive in-context reinforcement learning for interactive tasks.
  • It employs synthetic trajectory generation combined with curriculum learning and a multi-armed bandit strategy to enhance sample efficiency.
  • The approach achieved an average relative improvement of 47% in success rate across 10 task groups, demonstrating robust generalization over the base model.

The paper "Training a Generally Curious Agent" (2502.17543) introduces PAPRIKA, a fine-tuning methodology designed to imbue LLMs with generalized sequential decision-making capabilities, enabling them to perform in-context reinforcement learning (ICRL). The central objective is to train models that can adaptively solve novel interactive tasks by leveraging environmental feedback received during interaction, without necessitating further gradient-based updates or task-specific fine-tuning. This addresses the limitation of standard LLMs, which often lack robust strategies for exploration and information gathering in partially observable environments.

PAPRIKA Fine-tuning Framework

The PAPRIKA framework integrates several key components to achieve its goal of training generally capable interactive agents.

Task Design and Environment Simulation

A crucial element is the curation of a diverse set of training tasks that inherently require strategic interaction and information acquisition. The paper utilizes 10 distinct task groups, including games (e.g., Twenty Questions, Wordle, Minesweeper, Battleship), reasoning problems (Cellular Automata), simulated interactions (Customer Service, Murder Mystery), and classic reinforcement learning paradigms (Bandit Best Arm Selection). These tasks are characterized by partial observability, forcing the agent to balance exploration (querying the environment for information) and exploitation (using gathered information to achieve the task objective efficiently).

Task environments are simulated using two primary methods:

  1. LLM-based Simulation: For tasks requiring world knowledge or nuanced interaction (e.g., Twenty Questions, Guess My City, Customer Service, Murder Mystery), another LLM (GPT-4o-mini in the experiments) simulates the environment.
  2. Programmatic Simulation: For tasks governed by strict rules (e.g., Wordle, Mastermind, Battleship, Minesweeper, Cellular Automata, Bandit), hardcoded programs serve as the environment simulator. This ensures reliability, prevents reward hacking, and controls computational costs. Chain-of-Thought (CoT) prompting is employed for the agent in more complex tasks to facilitate reasoning. A minimal sketch of both simulator styles follows this list.
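
As a rough illustration of this split, the sketch below puts a toy LLM-simulated environment and a toy programmatic one behind a shared text-in/text-out interface. The class names, the `step` signature, and the scoring details are assumptions made for illustration, not the paper's actual harness.

```python
from abc import ABC, abstractmethod


class TaskEnvironment(ABC):
    """Shared text-in/text-out interface assumed for all task environments."""

    @abstractmethod
    def step(self, agent_message: str) -> tuple[str, bool, float]:
        """Return (environment_reply, done, score)."""


class TwentyQuestionsEnv(TaskEnvironment):
    """LLM-simulated environment: a judge model answers the agent's questions."""

    def __init__(self, judge_llm, secret_topic: str):
        self.judge_llm = judge_llm        # e.g. a thin wrapper around GPT-4o-mini
        self.secret_topic = secret_topic

    def step(self, agent_message: str) -> tuple[str, bool, float]:
        prompt = (
            f"The secret topic is '{self.secret_topic}'. Answer the player's question "
            f"truthfully with Yes or No, or confirm a correct guess.\n"
            f"Player: {agent_message}"
        )
        reply = self.judge_llm(prompt)
        solved = self.secret_topic.lower() in agent_message.lower()
        return reply, solved, 1.0 if solved else 0.0


class WordleEnv(TaskEnvironment):
    """Programmatic environment: hardcoded rules, so reward hacking is impossible."""

    def __init__(self, answer: str, max_turns: int = 6):
        self.answer = answer.lower()
        self.max_turns = max_turns
        self.turn = 0

    def step(self, agent_message: str) -> tuple[str, bool, float]:
        guess = agent_message.strip().lower()[: len(self.answer)]
        self.turn += 1
        # G = right letter/right spot, Y = letter elsewhere in the word, B = absent
        # (duplicate-letter handling simplified for brevity).
        feedback = "".join(
            "G" if g == a else ("Y" if g in self.answer else "B")
            for g, a in zip(guess, self.answer)
        )
        solved = guess == self.answer
        done = solved or self.turn >= self.max_turns
        return feedback, done, 1.0 if solved else 0.0
```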

Synthetic Interaction Data Generation

PAPRIKA relies on synthetically generated interaction trajectories rather than potentially costly or risky real-world data collection. A base LLM (Llama-3.1-8B-Instruct in the experiments) interacts with the simulated task environments to produce multiple trajectories for each task. To foster diversity in strategies and prevent the model from overfitting to a single mode of interaction, trajectory generation employs a high sampling temperature (e.g., 1.5) and min-p sampling (e.g., p = 0.3).
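
A minimal sketch of the rollout loop, assuming a chat-style agent callable and the environment interface sketched above; the helper names and message format are assumptions, while the temperature of 1.5 and min-p of 0.3 follow the settings quoted in the text.

```python
def rollout(agent_llm, env, system_prompt: str, max_turns: int = 20):
    """Run one episode with the current policy and return (trajectory, score)."""
    history = [{"role": "system", "content": system_prompt}]
    score, done, turn = 0.0, False, 0
    while not done and turn < max_turns:
        # High temperature plus min-p sampling encourages diverse exploration strategies.
        action = agent_llm(history, temperature=1.5, min_p=0.3)
        history.append({"role": "assistant", "content": action})
        reply, done, score = env.step(action)
        history.append({"role": "user", "content": reply})
        turn += 1
    return history, score


def collect_trajectories(agent_llm, make_env, system_prompt: str, n_rollouts: int = 8):
    """Generate several trajectories per task instance for later preference pairing."""
    return [rollout(agent_llm, make_env(), system_prompt) for _ in range(n_rollouts)]
```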

Preference Dataset Construction

Following trajectory generation for a given task, each trajectory is evaluated based on predefined scoring criteria, typically incorporating task success and efficiency (e.g., number of interaction turns). A preference dataset $\mathcal{D}$ is then constructed. For each task, the highest-scoring trajectory, denoted $h^w$ (the winning/preferred trajectory), is identified. This $h^w$ is paired with a randomly selected trajectory $h^l$ that achieved a lower score (the losing/dispreferred trajectory). The resulting dataset consists of pairs $(h^w, h^l)$. Selecting $h^l$ randomly, rather than strictly the worst-performing trajectory, is intended to enhance the diversity of the preference signals within the dataset.
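
The pairing rule reduces to a few lines of code. The sketch below assumes each task instance yields a list of (trajectory, score) tuples, as in the rollout sketch above; the function name is illustrative.

```python
import random


def build_preference_pair(scored_trajectories):
    """scored_trajectories: list of (trajectory, score) tuples for one task instance.

    Returns (h_w, h_l), or None when every trajectory scored the same (no contrastive signal).
    """
    ranked = sorted(scored_trajectories, key=lambda pair: pair[1], reverse=True)
    h_w, best_score = ranked[0]
    losers = [traj for traj, score in ranked[1:] if score < best_score]
    if not losers:
        return None
    h_l = random.choice(losers)  # a random (not the worst) loser, for preference-signal diversity
    return h_w, h_l
```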

Optimization via Reasoning Preference Optimization (RPO)

The base LLM is fine-tuned using the constructed preference dataset $\mathcal{D}$. The primary optimization technique employed is Reasoning Preference Optimization (RPO), an approach that integrates Supervised Fine-Tuning (SFT) with Direct Preference Optimization (DPO).

  1. Supervised Fine-Tuning (SFT): This component maximizes the likelihood of the agent's actions $a_t^w$ within the preferred trajectories $h^w$. The SFT loss is the standard negative log-likelihood of the agent's action tokens in $h^w$.

    $$\mathcal{L}_\text{SFT} = - \mathbb{E}_{h^w \sim \mathcal{D}} \left[ \sum_{t} \log \pi_\theta(a_t^w \mid c_t^w) \right]$$

    where $c_t^w$ is the context at timestep $t$ in trajectory $h^w$.

  2. Direct Preference Optimization (DPO): A multi-turn variant of DPO is used. The objective is to increase the relative probability of the preferred trajectory $h^w$ over the dispreferred trajectory $h^l$, compared to a reference policy $\pi_\text{ref}$ (typically the initial base model before fine-tuning). The DPO loss is computed only on the agent's action tokens, masking out tokens corresponding to environment responses.

    $$\mathcal{L}_\text{DPO} = - \mathbb{E}_{(h^w, h^l) \sim \mathcal{D}} \left[ \log \sigma \left( \beta \sum_{t} \log \frac{\pi_\theta(a_t^w \mid c_t^w)}{\pi_\text{ref}(a_t^w \mid c_t^w)} - \beta \sum_{t} \log \frac{\pi_\theta(a_t^l \mid c_t^l)}{\pi_\text{ref}(a_t^l \mid c_t^l)} \right) \right]$$

    where $\beta$ is a hyperparameter controlling the deviation from the reference policy, and $\sigma$ is the sigmoid function.

The combined RPO loss is a weighted sum:

$$\mathcal{L}_\text{RPO} = \mathcal{L}_\text{DPO} + \alpha \, \mathcal{L}_\text{SFT}$$

The paper uses $\alpha = 1.0$. The rationale for using RPO is to leverage the preference signal effectively while potentially mitigating "unintentional unalignment," where pure DPO might inadvertently decrease the likelihood of desirable actions present in $h^w$. Offline methods like DPO/RPO were chosen primarily for their reduced memory footprint compared to online RL methods, although online RL is acknowledged as a potentially more potent alternative if computational resources permit.
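
The combined objective is straightforward to express once per-trajectory log-probabilities have been summed over the agent's action tokens only (environment tokens masked out). The sketch below assumes that summation has already happened; the batch shape and the default β value are assumptions, with α = 1.0 as in the paper.

```python
import torch.nn.functional as F


def rpo_loss(logp_w, logp_l, ref_logp_w, ref_logp_l, beta: float = 0.1, alpha: float = 1.0):
    """All inputs: shape (batch,) tensors of action-token log-probs summed per trajectory.

    logp_*     -- under the trainable policy pi_theta
    ref_logp_* -- under the frozen reference policy pi_ref
    """
    # DPO term: margin between preferred and dispreferred log-probability ratios.
    margin = beta * (logp_w - ref_logp_w) - beta * (logp_l - ref_logp_l)
    dpo = -F.logsigmoid(margin).mean()

    # SFT term: negative log-likelihood of the preferred trajectory's action tokens.
    sft = -logp_w.mean()

    return dpo + alpha * sft
```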

Curriculum Learning for Sample Efficiency

A significant bottleneck identified in the PAPRIKA training process is not the computational cost of model parameter updates but the cost associated with generating useful interaction data, particularly when using powerful LLMs as simulators or judges. Many generated trajectories might provide limited learning signal if the task is currently too easy (low variance in outcomes) or too difficult (few successful attempts) for the agent policy. To enhance sample efficiency, a curriculum learning strategy is proposed.

Motivation and Learning Potential Metric

The core idea is to prioritize sampling trajectories from tasks where the current agent policy $\pi$ exhibits the highest potential for learning. This potential is quantified using the coefficient of variation $\nu$ of the episode scores (rewards) obtained for a task $\tau$:

$$\nu_\pi(\tau) = \frac{\sqrt{\sigma^2_\pi(\tau)}}{R_\pi(\tau)}$$

Here, $R_\pi(\tau)$ is the mean score and $\sigma^2_\pi(\tau)$ is the variance of scores achieved by policy $\pi$ on task $\tau$. The coefficient of variation provides a normalized measure of score variability relative to the mean score, allowing learning potential to be compared across tasks with potentially different reward scales. A high $\nu_\pi(\tau)$ indicates that the policy generates diverse outcomes for the task (a mix of high and low scores), suggesting that trajectories from this task will provide strong contrastive signals for preference-based learning methods like DPO/RPO.
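
A direct translation of this metric; the small epsilon is an assumption added here to guard against a zero mean score, not something specified in the paper.

```python
import math


def coefficient_of_variation(scores, eps: float = 1e-8):
    """Learning-potential proxy: std of a task's episode scores divided by their mean."""
    n = len(scores)
    mean = sum(scores) / n
    variance = sum((s - mean) ** 2 for s in scores) / n
    return math.sqrt(variance) / (mean + eps)  # eps avoids division by zero when all scores are 0
```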

Multi-Armed Bandit Formulation with UCB

The problem of selecting which task group to sample from at each stage of data generation is framed as a Multi-Armed Bandit (MAB) problem. Each task group (e.g., "Wordle", "Twenty Questions") represents an "arm" of the bandit. The goal is to dynamically allocate sampling effort to the arms (task groups) that are currently estimated to yield the highest learning potential.

The Upper Confidence Bound (UCB) algorithm is employed to manage the exploration-exploitation trade-off in selecting task groups. The procedure (Algorithm 1 in the paper) operates iteratively:

  1. For each task group (arm) $k$, calculate a UCB score $\theta_k$. This score typically combines the current estimate of the group's mean learning potential, $\hat{\nu}_k$, with an exploration bonus that encourages sampling from less-frequently chosen groups. The formula might resemble $\theta_k = \hat{\nu}_k + c \sqrt{\frac{\log N}{n_k}}$, where $N$ is the total number of sampling rounds so far, $n_k$ is the number of times group $k$ has been sampled, and $c$ is an exploration hyperparameter.
  2. Select the task group $k^\star$ with the highest UCB score $\theta_{k^\star}$.
  3. Uniformly sample a specific task instance $\tau$ from the chosen group $k^\star$.
  4. Generate a batch of $C$ interaction trajectories for task $\tau$ using the current agent policy $\pi$.
  5. Estimate the learning potential $\hat{\nu}_\pi(\tau)$ from the scores of the $C$ generated trajectories (the paper uses the number of turns as a proxy for the score/reward).
  6. Update the statistics for the selected group $k^\star$: increment its sample count $n_{k^\star}$ and update its estimated mean learning potential $\hat{\nu}_{k^\star}$ using the newly computed $\hat{\nu}_\pi(\tau)$.
  7. Repeat steps 1-6 for a predetermined number of rounds $T$.
  8. Construct the final training dataset $\mathcal{D}$ using the trajectories collected across all rounds.
  9. Fine-tune the policy $\pi$ using $\mathcal{D}$ via RPO.

This curriculum strategy dynamically focuses the computationally expensive data generation process on task groups that offer the most informative trajectories for the current state of the model, thereby improving overall sample efficiency compared to uniform sampling across all tasks.
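
Putting the pieces together, the sketch below paraphrases the UCB-driven curriculum loop described above; the exploration constant, the optimistic warm start, and the helper callables are assumptions rather than values specified by the paper.

```python
import math
import random


def ucb_curriculum(task_groups, estimate_nu, total_rounds: int, c: float = 1.0):
    """task_groups: dict mapping group name -> list of task instances.
    estimate_nu: callable(task) -> (nu_estimate, trajectories) under the current policy.
    Returns every collected trajectory, for preference-dataset construction."""
    names = list(task_groups)
    counts = {k: 1 for k in names}    # optimistic warm start: pretend one pull per arm
    nu_hat = {k: 1.0 for k in names}
    collected = []

    for _ in range(total_rounds):
        # Steps 1-2: pick the group with the highest UCB score.
        total_pulls = sum(counts.values())
        k_star = max(
            names,
            key=lambda k: nu_hat[k] + c * math.sqrt(math.log(total_pulls) / counts[k]),
        )

        # Steps 3-5: sample a task from that group, roll out, estimate learning potential.
        task = random.choice(task_groups[k_star])
        nu, trajectories = estimate_nu(task)
        collected.extend(trajectories)

        # Step 6: running-mean update of the group's estimated learning potential.
        counts[k_star] += 1
        nu_hat[k_star] += (nu - nu_hat[k_star]) / counts[k_star]

    # Steps 7-9: the collected trajectories then feed preference-pair construction and RPO.
    return collected
```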

Implementation Details and Experimental Validation

The experiments primarily utilize Llama-3.1-8B-Instruct as the base model. Key findings validate the PAPRIKA approach:

  • Performance on Training Tasks: PAPRIKA fine-tuning yielded a significant average relative improvement of 47% in success rates across the 10 task groups compared to the base Llama-3.1-8B-Instruct model, using approximately 22,500 trajectories in total for training.
  • Generalization to Unseen Tasks: The central claim of generalization