- The paper introduces ShiQ, an off-policy RL algorithm that adapts Bellman equations for fine-tuning LLMs using a token-level objective.
- It reparameterizes the Q-function into a scoring function, enabling direct softmax inference without needing the reference model or a temperature correction at decode time.
- ShiQ accelerates learning with a multi-step consistency update that efficiently propagates sparse rewards, outperforming baselines in multi-turn tasks.
This paper introduces ShiQ (Shifted-Q), a novel off-policy reinforcement learning algorithm for fine-tuning LLMs based on adapting Q-learning and the Bellman equations. Unlike traditional RL methods for LLMs like PPO or REINFORCE, which are on-policy and computationally expensive due to requiring fresh samples for each update, ShiQ is off-policy, allowing training on fixed datasets without costly model rollouts. Furthermore, while direct alignment methods (like DPO, SLiC) are also off-policy, they are typically limited to preference data and aggregate sequence-level rewards. ShiQ is designed to handle arbitrary reward functions and leverage token-wise signals when available.
The core contribution of ShiQ lies in carefully adapting standard Q-learning concepts to the specific characteristics of LLMs, which are autoregressive models whose logits are typically interpreted as unnormalized log-probabilities. The authors identify three main challenges:
- Easing sampling: A standard Q-learning policy samples proportional to exp(Q(s,a)/β). Applying this directly to LLM logits would require loading both the learned Q-model and the reference model (for the KL term) during inference and managing the temperature parameter. ShiQ addresses this by reparameterizing the Q-function q(s,a) into a new scoring function g(s,a) such that β g(s,a) = q(s,a) + β ln π_ref(a|s). The optimal policy then becomes proportional to exp(g(s,a)), meaning a direct softmax over the learned logits g gives the optimal policy, simplifying inference.
- Improved initialization: LLMs are fine-tuned from a pre-trained reference model, and ideally the fine-tuning algorithm should leverage this strong initialization. Simply interpreting the reference logits as initial Q-values or g-values is not theoretically grounded for the derived Bellman equations, leading to non-zero gradients even when the objective is already optimal (e.g., zero reward). ShiQ introduces a reward-shaping technique based on the log-partition function of the reference model (ϕ(s) = −v_ref(s)) to modify the Bellman equation. The new target function ℓ(s,a) is related to g(s,a) by ℓ(s,a) = g(s,a) + ϕ(s). The Bellman equation is then derived for ℓ, making initialization with the reference logits ℓ_ref(s,a) a natural starting point for learning the optimal policy's logits.
- Multi-step extension: LLM rewards are often sparse (e.g., a single reward at the end of the sequence), and a one-step Bellman update backpropagates this signal very slowly. ShiQ adopts a multi-step consistency equation (similar to Path Consistency Learning) that relates the value difference β(v_ℓ(s_t) − v_ref(s_t)) at state s_t to the sum of future regularized rewards from t to the end of the trajectory. This allows faster and more effective propagation of rewards, especially sparse ones. The three constructions are collected in the equation block after this list.
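Collecting the three constructions in one place (the placement of β and the reading of v_ref as the log-partition of the reference logits follow this summary's conventions and are not guaranteed to match the paper's exact notation):

$$
\beta\, g(s,a) = q(s,a) + \beta \ln \pi_{\mathrm{ref}}(a \mid s), \qquad \pi^*(a \mid s) \propto \exp\!\big(g(s,a)\big),
$$

$$
\ell(s,a) = g(s,a) + \phi(s), \qquad \phi(s) = -\,v_{\mathrm{ref}}(s),
$$

$$
\beta\big(v_\ell(s_t) - v_{\mathrm{ref}}(s_t)\big) \;=\; \sum_{k \ge t} \gamma^{\,k-t}\Big( r(s_k, a_k) - \beta \ln \tfrac{\pi_\ell(a_k \mid s_k)}{\pi_{\mathrm{ref}}(a_k \mid s_k)} \Big).
$$

The last identity is the multi-step consistency condition; the ShiQ loss below is its squared residual, summed over token positions.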
Combining these three steps, ShiQ optimizes a token-level loss function based on the multi-step Bellman consistency equation for the logits ℓ:
$$
\mathcal{L}_{\text{ShiQ}}(\ell) = \mathbb{E}_{x, y \sim \mathcal{D}}\!\left[\, \sum_{t=1}^{|y|} \left( \sum_{k=t}^{|y|} \gamma^{\,k-t} \left( r(s_k^{xy}, a_k^{xy}) - \beta \ln \frac{\pi_\ell(a_k^{xy} \mid s_k^{xy})}{\pi_{\mathrm{ref}}(a_k^{xy} \mid s_k^{xy})} \right) - \beta \left( v_\ell(s_t^{xy}) - v_{\mathrm{ref}}(s_t^{xy}) \right) \right)^{\!2} \,\right]
$$
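For concreteness, here is a minimal PyTorch-style sketch of this token-level loss for a batch of right-padded responses. The function name `shiq_loss`, the (B, T, V) tensor layout, and the reading of v_ℓ / v_ref as the logsumexp of the learned / reference logits are this sketch's assumptions, not the authors' reference implementation:

```python
# Minimal PyTorch sketch of the token-level ShiQ loss (general, per-token reward case).
# Assumptions of this sketch: right-padded responses, and v_l / v_ref taken as the
# log-partition (logsumexp) of the learned / reference logits over the vocabulary.
import torch
import torch.nn.functional as F

def shiq_loss(logits, ref_logits, tokens, rewards, mask, beta=0.1, gamma=1.0):
    """
    logits, ref_logits: (B, T, V) logits of the trained and reference model
                        at each response-token position.
    tokens:             (B, T) generated token ids a_t.
    rewards:            (B, T) per-token rewards r(s_t, a_t) (zeros except where
                        a reward is emitted, e.g. end of sequence or end of turn).
    mask:               (B, T) 1.0 for real response tokens, 0.0 for padding.
    """
    logp = F.log_softmax(logits, dim=-1).gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
    ref_logp = F.log_softmax(ref_logits, dim=-1).gather(-1, tokens.unsqueeze(-1)).squeeze(-1)

    # Per-token KL-regularized reward: r - beta * ln(pi_l / pi_ref)
    u = (rewards - beta * (logp - ref_logp)) * mask

    # Discounted reward-to-go: sum_{k>=t} gamma^{k-t} u_k (backward recursion).
    B, T = u.shape
    togo = torch.zeros_like(u)
    running = torch.zeros(B, device=u.device)
    for t in reversed(range(T)):
        running = u[:, t] + gamma * running * mask[:, t]
        togo[:, t] = running

    # beta * (v_l(s_t) - v_ref(s_t)), with v taken as logsumexp over the vocabulary.
    v_l = torch.logsumexp(logits, dim=-1)
    v_ref = torch.logsumexp(ref_logits, dim=-1)
    baseline = beta * (v_l - v_ref)

    residual = (togo - baseline) * mask
    # Normalize by the number of response tokens, as mentioned in the notes below.
    return (residual ** 2).sum() / mask.sum()
```

With token-wise or turn-wise rewards, the same function applies unchanged; only the `rewards` tensor changes.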
In the common case of a sequence-level reward R(x,y) given only at the end (r(s_t, a_t) = R(x,y) if t = |y|, and 0 otherwise) and γ = 1, this simplifies in LLM notation to:
$$
\mathcal{L}_{\text{ShiQ}}(\ell) = \mathbb{E}_{x, y \sim \mathcal{D}}\!\left[\, \sum_{t=1}^{|y|} \left( R(x, y) - \beta \ln \frac{\pi_\ell(y_{\geq t} \mid x, y_{<t})}{\pi_{\mathrm{ref}}(y_{\geq t} \mid x, y_{<t})} - \beta \left( v_\ell(x \oplus y_{<t}) - v_{\mathrm{ref}}(x \oplus y_{<t}) \right) \right)^{\!2} \,\right]
$$
where y_{≥t} denotes the subsequence from token t to the end. The loss is computed for every token in the sampled trajectory, making it a token-level objective. This allows it to potentially leverage token-wise rewards and accelerate learning for sparse rewards compared to sequence-level losses. ShiQ is off-policy, meaning it can train on any dataset D of trajectories, whether collected on-policy, offline, or from a replay buffer.
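As a usage note for the sketch above, this sequence-level case corresponds to placing R(x, y) on the final response token and setting γ = 1. The helper below is hypothetical and assumes `mask` marks response tokens with at least one non-zero entry per row:

```python
# Hypothetical helper for the sequence-level reward case of the sketch above:
# place R(x, y) on the last response token and keep all other rewards at zero.
import torch

def terminal_rewards(mask, seq_reward):
    """mask: (B, T) response-token mask; seq_reward: (B,) sequence-level rewards R(x, y)."""
    rewards = torch.zeros_like(mask)
    # Index of the last response token in each row (last position where mask == 1).
    last = mask.size(1) - 1 - mask.flip(dims=[1]).argmax(dim=1)
    rewards[torch.arange(mask.size(0), device=mask.device), last] = seq_reward
    return rewards

# loss = shiq_loss(logits, ref_logits, tokens,
#                  terminal_rewards(mask, R), mask, beta=0.1, gamma=1.0)
```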
The paper evaluates ShiQ on synthetic bandit and gridworld tasks, demonstrating its theoretical properties, ability to handle fine-grained rewards, and faster convergence compared to baselines like DPO and CoPG. On LLM benchmarks (Anthropic-Harmless/Helpful, UltraFeedback, BFCL-V3) using a 7B parameter model, ShiQ shows competitive performance in single-turn settings (matching or exceeding CoPG and DRO in reward optimization while maintaining comparable KL divergence) and superior performance in multi-turn scenarios (BFCL-V3), attributed to its ability to leverage fine-grained, per-turn information.
Practical Implementation Details:
- Single Network: ShiQ trains only the logits of the LLM (the ℓ function), avoiding the need for multiple large networks (actor, critic, target networks) common in other RL approaches, which saves significant memory and computational resources.
- Off-Policy: Training can be done entirely offline on a fixed dataset of prompts and responses. This is crucial for reducing the high cost of online sampling with large LLMs.
- Token-wise Loss: The loss function is computed and summed over each token position in the sequence. This provides denser gradient signals, potentially speeding up learning, especially with fine-grained or multi-turn rewards.
- Reward Flexibility: Can handle sequence-level rewards, turn-level rewards, or token-level rewards.
- Simple Inference: Once trained, the policy is simply the softmax over the learned logits, allowing standard decoding methods without any modifications or dependence on the reference model during inference (see the sketch after this list).
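A minimal sketch of what simple inference looks like in practice, assuming a Hugging Face-style `model(input_ids).logits` interface (an assumption of this sketch, not something prescribed by the paper); the fine-tuned logits are used exactly like an ordinary language-model head:

```python
# Sketch of decoding after ShiQ training (assumes a HF-style model interface).
# No reference model and no temperature beta are needed at inference time.
import torch

@torch.no_grad()
def sample_next_token(model, input_ids):
    logits = model(input_ids).logits[:, -1, :]      # learned logits l(s, .) at the current state
    probs = torch.softmax(logits, dim=-1)           # optimal policy = softmax over l
    return torch.multinomial(probs, num_samples=1)  # ordinary ancestral sampling
```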
Implementation Considerations:
- Computational Cost: While off-policy, computing the loss still involves forward and backward passes through the LLM for potentially long sequences and requires summing over tokens, which can be computationally intensive.
- Data Requirements: Assumes access to trajectories with associated rewards. The theory guarantees optimality when the data distribution has full support matching the reference policy's support, but empirical performance relies on the quality and coverage of the offline dataset.
- Reward Model: Like many alignment methods, ShiQ relies on a reward model to provide feedback. The performance is inherently limited by the quality of the reward model.
- Hyperparameters: The temperature parameter β and the discount factor γ (though often set to 1 in finite-horizon LLM tasks) are important hyperparameters requiring tuning.
- Normalization: Normalizing the loss by the number of response tokens (as in supervised fine-tuning) is a practical implementation detail that keeps long and short sequences on a comparable scale.
ShiQ offers a theoretically grounded, practical alternative for RL fine-tuning of LLMs, particularly promising for settings with offline data, fine-grained rewards, or multi-turn interactions. Future work includes evaluating it on a wider range of tasks and datasets, potentially incorporating online data collection, and developing mechanisms to mitigate the risks associated with optimizing a potentially flawed learned reward function.