ShiQ: Bringing back Bellman to LLMs (2505.11081v1)

Published 16 May 2025 in cs.LG

Abstract: The fine-tuning of pre-trained LLMs using reinforcement learning (RL) is generally formulated as direct policy optimization. This approach was naturally favored as it efficiently improves a pretrained LLM, seen as an initial policy. Another RL paradigm, Q-learning methods, has received far less attention in the LLM community while demonstrating major success in various non-LLM RL tasks. In particular, Q-learning effectiveness comes from its sample efficiency and ability to learn offline, which is particularly valuable given the high computational cost of sampling with LLMs. However, naively applying a Q-learning-style update to the model's logits is ineffective due to the specificity of LLMs. Our core contribution is to derive theoretically grounded loss functions from Bellman equations to adapt Q-learning methods to LLMs. To do so, we carefully adapt insights from the RL literature to account for LLM-specific characteristics, ensuring that the logits become reliable Q-value estimates. We then use this loss to build a practical algorithm, ShiQ for Shifted-Q, that supports off-policy, token-wise learning while remaining simple to implement. Finally, we evaluate ShiQ on both synthetic data and real-world benchmarks, e.g., UltraFeedback and BFCL-V3, demonstrating its effectiveness in both single-turn and multi-turn LLM settings.

Summary

  • The paper introduces ShiQ, an off-policy RL algorithm that adapts Bellman equations for fine-tuning LLMs using a token-level objective.
  • It reparameterizes the Q-function into a scoring function, enabling direct softmax inference without the need for extra reference model adjustments.
  • ShiQ accelerates learning with a multi-step consistency update that efficiently propagates sparse rewards, outperforming baselines in multi-turn tasks.

This paper introduces ShiQ (Shifted-Q), a novel offline reinforcement learning algorithm for fine-tuning LLMs based on adapting Q-learning and Bellman equations. Unlike traditional RL methods for LLMs like PPO or REINFORCE, which are on-policy and computationally expensive due to requiring fresh samples for each update, ShiQ is off-policy, allowing training on fixed datasets without costly model rollouts. Furthermore, while direct alignment methods (like DPO, SLiC) are off-policy, they are typically limited to preference data and aggregate sequence-level rewards. ShiQ is designed to handle arbitrary reward functions and leverage token-wise signals when available.

The core contribution of ShiQ lies in carefully adapting standard Q-learning concepts to the specific characteristics of LLMs, which are autoregressive models whose logits are typically interpreted as unnormalized log-probabilities. The authors identify three main challenges:

  1. Easing sampling: A standard Q-learning policy samples proportionally to $\exp(Q(s,a)/\beta)$. Applying this directly to LLM logits would require loading both the learned Q-model and the reference model (for the KL term) during inference and managing the temperature parameter. ShiQ addresses this by reparameterizing the Q-function $q(s,a)$ into a new scoring function $g(s,a)$ such that $\beta g(s,a) = q(s,a) + \beta \ln \mathrm{ref}(a|s)$. The optimal policy is then proportional to $\exp(g(s,a))$, meaning a direct softmax over the learned logits $g$ gives the optimal policy, simplifying inference (see the sketch after this list).
  2. Improved initialization: LLMs are fine-tuned from a pre-trained reference model, and the fine-tuning algorithm should ideally leverage this strong initialization. Simply interpreting the reference logits as initial Q-values or $g$ values is not theoretically grounded for the derived Bellman equations, leading to non-zero gradients even when the objective is already optimal (e.g., zero reward). ShiQ introduces a reward-shaping technique based on the log-partition function of the reference model ($\phi(s) = -\mathrm{ref}(s)$) to modify the Bellman equation. The new target function $\ell(s,a)$ is related to $g(s,a)$ by $\ell(s,a) = g(s,a) + \phi(s)$. The Bellman equation is then derived for $\ell$, making initialization with the reference logits $\mathrm{ref}(s,a)$ a more natural starting point for learning the optimal policy's logits.
  3. Multi-step extension: LLM rewards are often sparse (e.g., a single reward at the end of the sequence), and a one-step Bellman update backpropagates this signal very slowly. ShiQ adopts a multi-step consistency equation (similar to Path Consistency Learning) that relates the value difference $\beta(v_\ell(s_t) - \mathrm{ref}(s_t))$ at state $s_t$ to the sum of future regularized rewards from $t$ to the end of the trajectory. This allows faster and more effective propagation of rewards, especially sparse ones.
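To make the reparameterization in step 1 concrete, here is a minimal sketch (illustrative only; the toy tensors and variable names are not from the paper) checking that a plain softmax over the shifted logits $g = q/\beta + \ln \mathrm{ref}$ recovers the KL-regularized optimal policy $\pi(a|s) \propto \mathrm{ref}(a|s)\exp(q(s,a)/\beta)$, so neither the reference model nor the temperature is needed at decoding time:

```python
import torch

torch.manual_seed(0)
vocab, beta = 8, 0.5

q = torch.randn(vocab)                              # toy Q-values for a single state
ref_p = torch.softmax(torch.randn(vocab), dim=-1)   # toy reference-policy probabilities

# KL-regularized optimal policy written the usual way: pi(a|s) ∝ ref(a|s) * exp(q(s,a)/beta)
unnorm = ref_p * torch.exp(q / beta)
pi_standard = unnorm / unnorm.sum()

# ShiQ-style shifted logits: beta * g = q + beta * ln ref  =>  g = q/beta + ln ref
g = q / beta + torch.log(ref_p)
pi_shifted = torch.softmax(g, dim=-1)               # plain softmax over the learned logits

assert torch.allclose(pi_standard, pi_shifted, atol=1e-5)
```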

Combining these three steps, ShiQ optimizes a token-level loss function based on the multi-step Bellman equation for the logits $\ell$:

$$\mathcal{L}_{\mathrm{ShiQ}}(\ell) = \mathbb{E}_{x,y\sim\mathcal{D}}\Bigg[ \sum_{t=1}^{|y|} \bigg( \sum_{k=t}^{|y|} \gamma^{k-t}\Big(r(s_k^{xy},a_k^{xy}) - \beta \ln \frac{\pi_\ell(a_k^{xy}\mid s_k^{xy})}{\mathrm{ref}(a_k^{xy}\mid s_k^{xy})}\Big) - \beta\big(v_\ell(s_t^{xy}) - \mathrm{ref}(s_t^{xy})\big) \bigg)^2 \Bigg]$$
In the common case of a sequence-level reward $R(x,y)$ given only at the end ($r(s_t, a_t) = R(x,y)$ if $t=|y|$, $0$ otherwise) and $\gamma=1$, this simplifies in LLM notation to:
$$\mathcal{L}_{\mathrm{ShiQ}}(\ell) = \mathbb{E}_{x,y\sim\mathcal{D}}\Bigg[ \sum_{t=1}^{|y|} \bigg( R(x,y) - \beta \ln \frac{\pi_\ell(y_{\geq t}\mid x, y_{<t})}{\mathrm{ref}(y_{\geq t}\mid x, y_{<t})} - \beta\big(v_\ell(x \oplus y_{<t}) - \mathrm{ref}(x \oplus y_{<t})\big) \bigg)^2 \Bigg]$$
where $y_{\geq t}$ denotes the subsequence from token $t$ to the end. The loss is computed for every token in the sampled trajectory, making it a token-level objective. This allows it to potentially leverage token-wise rewards and accelerate learning for sparse rewards compared to sequence-level losses. ShiQ is off-policy, meaning it can train on any dataset $\mathcal{D}$ of trajectories, whether collected on-policy, offline, or from a replay buffer.
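The token-level structure of this simplified loss is easy to express in code. The sketch below is a rough illustration rather than the authors' implementation: it assumes per-position logits from the trained policy and the frozen reference model, a single scalar reward per sequence, and, as an additional assumption, takes $v_\ell(s_t)$ and $\mathrm{ref}(s_t)$ to be log-sum-exps over the vocabulary of the learned and reference logits; batching and padding are omitted.

```python
import torch

def shiq_loss_sequence_reward(
    policy_logits: torch.Tensor,   # (T, V) learned logits over the response tokens
    ref_logits: torch.Tensor,      # (T, V) reference-model logits at the same positions
    tokens: torch.Tensor,          # (T,) sampled response token ids
    reward: torch.Tensor,          # scalar sequence-level reward R(x, y)
    beta: float,
) -> torch.Tensor:
    """Simplified ShiQ objective (gamma = 1, single terminal reward); illustrative only."""
    T = tokens.shape[0]

    # Per-token log-ratio: ln pi_ell(y_k | x, y_<k) - ln ref(y_k | x, y_<k)
    policy_logp = torch.log_softmax(policy_logits, dim=-1)[torch.arange(T), tokens]
    ref_logp = torch.log_softmax(ref_logits, dim=-1)[torch.arange(T), tokens]
    log_ratio = policy_logp - ref_logp

    # Suffix sums give ln pi_ell(y_>=t | ...) - ln ref(y_>=t | ...) for every position t
    suffix_log_ratio = torch.flip(torch.cumsum(torch.flip(log_ratio, [0]), 0), [0])

    # One plausible choice (an assumption here, not stated in the summary): v_ell(s_t) and
    # ref(s_t) as log-partition functions of the learned and reference logits at position t.
    v_ell = torch.logsumexp(policy_logits, dim=-1)
    v_ref = torch.logsumexp(ref_logits, dim=-1)

    residual = reward - beta * suffix_log_ratio - beta * (v_ell - v_ref)
    return (residual ** 2).sum() / T   # normalized by the number of tokens
```

The final division by the number of tokens mirrors the normalization noted under the implementation considerations below.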

The paper evaluates ShiQ on synthetic bandit and gridworld tasks, demonstrating its theoretical properties, ability to handle fine-grained rewards, and faster convergence compared to baselines like DPO and CoPG. On LLM benchmarks (Anthropic-Harmless/Helpful, UltraFeedback, BFCL-V3) using a 7B parameter model, ShiQ shows competitive performance in single-turn settings (matching or exceeding CoPG and DRO in reward optimization while maintaining comparable KL divergence) and superior performance in multi-turn scenarios (BFCL-V3), attributed to its ability to leverage fine-grained, per-turn information.

Practical Implementation Details:

  • Single Network: ShiQ trains only the logits of the LLM (the $\ell$ function), avoiding the need for multiple large networks (actor, critic, target networks) common in other RL approaches, which saves significant memory and computational resources.
  • Off-Policy: Training can be done entirely offline on a fixed dataset of prompts and responses. This is crucial for reducing the high cost of online sampling with large LLMs.
  • Token-wise Loss: The loss function is computed and summed over each token position in the sequence. This provides denser gradient signals, potentially speeding up learning, especially with fine-grained or multi-turn rewards.
  • Reward Flexibility: Can handle sequence-level rewards, turn-level rewards, or token-level rewards.
  • Simple Inference: Once trained, the policy is simply the softmax over the learned logits, allowing standard decoding methods without any modifications or dependence on the reference model during inference.
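To illustrate the "Simple Inference" point, decoding from a ShiQ-trained model is just ordinary autoregressive sampling from the softmax of its logits; the sketch below is generic (the model object and its call signature are placeholders, not a specific API):

```python
import torch

@torch.no_grad()
def sample(model, prompt_ids: torch.Tensor, max_new_tokens: int, eos_id: int) -> torch.Tensor:
    """Plain autoregressive sampling from softmax(logits); no reference model, no temperature shift."""
    ids = prompt_ids.clone()
    for _ in range(max_new_tokens):
        next_logits = model(ids)[-1]   # placeholder call: logits for the next token given the prefix
        next_id = torch.multinomial(torch.softmax(next_logits, dim=-1), num_samples=1)
        ids = torch.cat([ids, next_id])
        if next_id.item() == eos_id:
            break
    return ids
```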

Implementation Considerations:

  • Computational Cost: While off-policy, computing the loss still involves forward and backward passes through the LLM for potentially long sequences and requires summing over tokens, which can be computationally intensive.
  • Data Requirements: Assumes access to trajectories with associated rewards. The theory guarantees optimality under full support of the data distribution matching the reference policy's support, but empirical performance relies on the quality and coverage of the offline dataset.
  • Reward Model: Like many alignment methods, ShiQ relies on a reward model to provide feedback. The performance is inherently limited by the quality of the reward model.
  • Hyperparameters: The temperature parameter $\beta$ and the discount factor $\gamma$ (the latter often set to 1 in finite-horizon LLM tasks) are important hyperparameters requiring tuning.
  • Normalization: In practice, the loss is normalized by the number of tokens in the sequence, as in supervised fine-tuning (see the sketch below).
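A small sketch of this normalization for batched, padded sequences (the tensor shapes and mask convention are assumptions for illustration, not taken from the paper):

```python
import torch

def normalize_token_loss(per_token_sq: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """per_token_sq: (B, T) squared ShiQ residuals; mask: (B, T), 1 for real tokens, 0 for padding."""
    return (per_token_sq * mask).sum() / mask.sum().clamp(min=1)
```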

ShiQ offers a theoretically grounded, practical alternative for RL fine-tuning of LLMs, particularly promising for settings with offline data, fine-grained rewards, or multi-turn interactions. Future work includes evaluating it on a wider range of tasks and datasets, potentially incorporating online data collection, and developing mechanisms to mitigate the risks associated with optimizing a potentially flawed learned reward function.
