Inference-Aware Fine-Tuning for Best-of-N Sampling in Large Language Models (2412.15287v1)

Published 18 Dec 2024 in cs.CL, cs.AI, and cs.LG

Abstract: Recent studies have indicated that effectively utilizing inference-time compute is crucial for attaining better performance from LLMs. In this work, we propose a novel inference-aware fine-tuning paradigm, in which the model is fine-tuned in a manner that directly optimizes the performance of the inference-time strategy. We study this paradigm using the simple yet effective Best-of-N (BoN) inference strategy, in which a verifier selects the best out of a set of LLM-generated responses. We devise the first imitation learning and reinforcement learning (RL) methods for BoN-aware fine-tuning, overcoming the challenging, non-differentiable argmax operator within BoN. We empirically demonstrate that our BoN-aware models implicitly learn a meta-strategy that interleaves best responses with more diverse responses that might be better suited to a test-time input -- a process reminiscent of the exploration-exploitation trade-off in RL. Our experiments demonstrate the effectiveness of BoN-aware fine-tuning in terms of improved performance and inference-time compute. In particular, we show that our methods improve the Bo32 performance of Gemma 2B on Hendrycks MATH from 26.8% to 30.8%, and pass@32 from 60.0% to 67.0%, as well as the pass@16 on HumanEval from 61.6% to 67.1%.

Summary

  • The paper presents an inference-aware training paradigm that optimizes best-of-N sampling directly during the training phase.
  • It introduces novel imitation learning and reinforcement learning approaches to address the non-differentiability challenges in output selection.
  • Empirical results show significant improvements, with Gemma 2B Bo32 performance rising from 26.8% to 30.8% and pass@32 accuracy from 60.0% to 67.0%.

Inference-Aware Fine-Tuning for Best-of-N Sampling in LLMs

Introduction

The paper "Inference-Aware Fine-Tuning for Best-of-N Sampling in LLMs" proposes fine-tuning LLMs so that the inference-time strategy used at deployment is optimized directly. The authors term this methodology inference-aware fine-tuning and focus on the Best-of-N (BoN) sampling strategy, in which the model generates multiple candidate outputs and a verifier selects the best one according to its scoring function. The paper advances the use of BoN sampling by integrating it into the training process through imitation learning and reinforcement learning (RL) techniques.
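
As a concrete illustration, the sketch below shows plain BoN inference; the `generate` and `verifier_score` callables are hypothetical placeholders, not the authors' implementation.

```python
from typing import Callable, List


def best_of_n(
    prompt: str,
    generate: Callable[[str, float], str],        # hypothetical: sample one response at a given temperature
    verifier_score: Callable[[str, str], float],  # hypothetical: score a (prompt, response) pair
    n: int = 32,
    temperature: float = 1.0,
) -> str:
    """Sample n candidate responses and return the one the verifier scores highest."""
    candidates: List[str] = [generate(prompt, temperature) for _ in range(n)]
    scores = [verifier_score(prompt, c) for c in candidates]
    # This argmax selection is the non-differentiable step that inference-aware
    # fine-tuning must account for during training.
    best_index = max(range(n), key=lambda i: scores[i])
    return candidates[best_index]
```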

Key Contributions

The authors highlight several key contributions, providing methodologies and insights for training and deploying LLMs with BoN inference strategies:

  1. Inference-Aware Training Paradigm: The paper introduces an inference-aware training paradigm, specifically designed to optimize the performance of LLMs by considering the inference mechanism during the training phase. This focus contrasts with traditional methods that treat inference-time computation as a separate, post-hoc decision.
  2. BoN-Aware Imitation and Reinforcement Learning Approaches: To overcome the non-differentiable argmax operator at the heart of BoN selection, the authors develop novel imitation learning and RL approaches (the objective is sketched after this list). These methods account for the exploration-exploitation trade-off crucial for optimizing inference outcomes in LLMs.
  3. Empirical Evaluation: Empirical studies performed by the authors demonstrate significant improvements in model performance. Notably, their methods increased the Bo32 performance of the Gemma 2B model on the Hendrycks MATH benchmark from 26.8% to 30.8%, and pass@32 accuracy from 60.0% to 67.0%. Similar improvements were observed on coding benchmarks, reinforcing the potential of their proposed techniques.
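
To make the second contribution concrete, one common way to write the BoN-aware objective is sketched below; the notation is illustrative rather than the paper's exact formulation, with π_θ the policy, r the verifier score, and R the task reward.

```latex
% Illustrative form of the BoN-aware objective (notation ours, not necessarily the paper's):
% N responses are drawn i.i.d. from the policy, the verifier picks one by argmax,
% and training should maximize the task reward of that pick.
\[
  J_{\mathrm{BoN}}(\theta)
    = \mathbb{E}_{x}\,\mathbb{E}_{y_1,\dots,y_N \sim \pi_\theta(\cdot \mid x)}
      \bigl[ R\bigl(x, y_{i^*}\bigr) \bigr],
  \qquad
  i^* = \operatorname*{arg\,max}_{1 \le i \le N} r(x, y_i).
\]
```

The argmax selection is non-differentiable, so gradients cannot be propagated through it directly; this is what motivates the imitation-learning and RL surrogates developed in the paper.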

Methodology

The paper outlines a BoN inference strategy in which the LLM generates multiple candidate responses and a verifier selects the best one according to a scoring function. The research highlights a co-scaling relationship between the sampling temperature and the number of samples N, which together govern exploration and output quality. The authors devise both supervised fine-tuning and RL frameworks that align the model's distribution with the BoN policy distribution, thereby improving inference-time performance.
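
For intuition about the BoN policy distribution mentioned above, the distribution induced by drawing N i.i.d. samples from the policy and keeping the verifier's argmax has a standard closed form (ignoring ties; this is the usual order-statistics identity, offered only as an illustration rather than the paper's exact formulation).

```latex
% Distribution of the verifier-selected response under BoN sampling
% (standard order-statistics identity, assuming no ties in verifier scores).
\[
  \pi_{\mathrm{BoN}}(y \mid x)
    = N \, \pi_\theta(y \mid x)
      \Bigl[ \Pr_{y' \sim \pi_\theta(\cdot \mid x)}\bigl( r(x, y') \le r(x, y) \bigr) \Bigr]^{N-1}.
\]
```

Under this view, raising the sampling temperature flattens π_θ while increasing N sharpens selection toward high-scoring responses, which is one way to read the co-scaling relationship between temperature and N.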

  • BoN-SFT (Supervised Fine-Tuning): Using a variational inference approach, the authors derive a supervised fine-tuning procedure that adapts the likelihood of expert responses and uses regularization to keep the model's policy exploratory.
  • BoN-RL (Reinforcement Learning): This approach uses RL to optimize the LLM directly for BoN inference, training the model to maximize the expected reward of the response selected by BoN sampling (a simplified sketch follows this list), while accounting for verifier error tolerance and learning from environment rewards.
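
A simplified sketch of a BoN-aware RL update in the spirit of the BoN-RL bullet is shown below; it uses a plain REINFORCE-style estimator on the verifier-selected response, with hypothetical `model.sample`, `model.log_prob`, and `verifier.score` interfaces, and is not the authors' exact algorithm.

```python
import torch


def bon_rl_step(model, verifier, prompt_ids, n=8, temperature=1.0):
    """One simplified BoN-aware policy-gradient step: sample n responses,
    let the verifier pick the best, and reinforce the log-likelihood of
    that response, weighted by its baseline-subtracted reward."""
    with torch.no_grad():
        # Hypothetical sampling / scoring interfaces.
        responses = [model.sample(prompt_ids, temperature=temperature) for _ in range(n)]
        rewards = torch.tensor([verifier.score(prompt_ids, r) for r in responses])

    best = int(torch.argmax(rewards))  # non-differentiable selection, kept outside the autodiff graph
    log_prob = model.log_prob(prompt_ids, responses[best])  # differentiable log-likelihood of the winner
    loss = -(rewards[best] - rewards.mean()) * log_prob     # REINFORCE with a mean-reward baseline
    loss.backward()
    return loss.item()
```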

Implications and Prospects

The implications of this research are substantial, both practically and theoretically. Practically, inference-aware fine-tuning of LLMs promises more efficient use of computational resources, emphasizing the value of refining inference-time strategies over brute-force increases in model size. Theoretically, aligning training with inference through BoN-aware strategies opens new pathways in model optimization that go beyond the conventional separation of training from inference-time computation.

The findings provoke further inquiry into inference-optimized neural network training, suggesting that these methods could extend to other complex neural architectures and application domains. Future developments might also explore extending these techniques to varying downstream tasks with distinct verification processes or leveraging more sophisticated verification models to enhance BoN accuracy.

Overall, this paper contributes meaningful advances in optimizing LLM deployment, improving performance under inference-compute constraints, and providing a structured framework for further exploration of inference-aware training and optimization strategies.