ST-PPO: Stabilized PPO for LLMs
- ST-PPO is a reinforcement learning algorithm that adapts PPO for multi-turn LLM tasks by incorporating turn-level importance sampling.
- It corrects clipping bias through gradient normalization to reduce high-variance updates and promote training stability.
- Empirical evaluations in multi-hop and medical QA tasks demonstrate that ST-PPO outperforms token-level PPO with improved success rates and accuracy.
ST-PPO (Stabilized Off-Policy Proximal Policy Optimization) is a reinforcement learning algorithm designed to stabilize and enhance the training of LLMs acting as multi-turn agents. ST-PPO addresses instability in Proximal Policy Optimization (PPO) that arises when applying token-level optimization in multi-turn tasks such as multi-hop question answering, search, and reasoning. By introducing turn-level importance sampling and clipping-bias correction, ST-PPO aligns the optimization granularity with the structure of multi-turn environments and normalizes high-variance gradients from off-policy samples, resulting in improved stability and performance in large-model training contexts (Li et al., 25 Nov 2025).
1. Algorithmic Structure and Definitions
ST-PPO operates within a multi-turn Markov decision process (MDP), where interaction is decomposed into discrete “turns” (contiguous sequences of agent-generated tokens bounded by tool calls or special markers such as <eot>). The full trajectory consists of $N$ turns, the $i$-th turn defined by its boundary indices $(b_i, e_i)$. Let $x$ denote the user query, $y_t$ the $t$-th output token, and $\tau_i$ the $i$-th turn.
For each turn $i$, a state $s_i$ and an action $a_i$ are defined. The policy $\pi_\theta$, parameterized as an auto-regressive LLM, outputs $a_i \sim \pi_\theta(\cdot \mid s_i)$. The critic $V_\phi$ estimates the turn-level state value. Token-level advantages $\hat{A}_t$ are computed using Generalized Advantage Estimation (GAE) with discount $\gamma$ and GAE parameter $\lambda$. The PPO clipping parameter is set to $\epsilon = 0.2$.
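For concreteness, the following is a minimal sketch of token-level GAE as just described, assuming per-token rewards and critic values for a single trajectory are available as 1-D tensors; the function name `compute_gae` and the tensor layout are illustrative, not the paper's implementation.

```python
import torch

def compute_gae(rewards, values, gamma, lam):
    """Token-level Generalized Advantage Estimation (GAE).

    rewards: (T,) per-token rewards for one trajectory
    values:  (T + 1,) critic estimates, including the bootstrap value
    Returns per-token advantages of shape (T,).
    """
    T = rewards.shape[0]
    advantages = torch.zeros(T)
    last_gae = 0.0
    for t in reversed(range(T)):
        # One-step TD error, then the exponentially weighted recursion.
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        last_gae = delta + gamma * lam * last_gae
        advantages[t] = last_gae
    return advantages
```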
The standard token-level PPO surrogate objective is

$$\mathcal{L}^{\text{token}}(\theta) = \mathbb{E}_t\!\left[\min\!\Big(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\big(r_t(\theta),\,1-\epsilon,\,1+\epsilon\big)\,\hat{A}_t\Big)\right],$$

where $r_t(\theta) = \pi_\theta(y_t \mid s_t)\,/\,\pi_{\theta_{\text{old}}}(y_t \mid s_t)$ is the token-level importance ratio.
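This clipped surrogate translates directly into a few lines of PyTorch; the sketch below assumes per-token log-probabilities under the current and behavior policies and is illustrative rather than the paper's code.

```python
import torch

def token_ppo_loss(logp_new, logp_old, advantages, eps=0.2):
    """Standard clipped token-level PPO surrogate (returned as a loss to minimize)."""
    ratio = torch.exp(logp_new - logp_old)                      # r_t(theta)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantages
    # PPO maximizes the element-wise minimum of the two terms; negate for a loss.
    return -torch.mean(torch.min(unclipped, clipped))
```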
ST-PPO integrates two modifications:
- Turn-level importance sampling
- Clipping-bias correction
Algorithmic steps (Algorithm 1) include trajectory rollout, turn detection via the loss mask or end-of-turn tokens, GAE advantage computation, gradient formation using turn-level ratios, calculation of the token-level and turn-level clipping-bias norms $\|b_{\text{token}}\|$ and $\|b_{\text{turn}}\|$, surrogate gradient normalization, and updates to the policy and critic.
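As a sketch of the turn-detection step, the helper below groups contiguous agent tokens (loss_mask=1) into turns; the function name `detect_turns` and its return format are assumptions for illustration.

```python
def detect_turns(loss_mask):
    """Group contiguous runs of agent tokens (loss_mask == 1) into turns.

    Returns a list of (start, end) index pairs, with end exclusive.
    """
    turns, start = [], None
    for t, m in enumerate(loss_mask):
        if m == 1 and start is None:
            start = t                      # a new turn begins
        elif m == 0 and start is not None:
            turns.append((start, t))       # turn ends at an environment token
            start = None
    if start is not None:                  # trajectory ends inside a turn
        turns.append((start, len(loss_mask)))
    return turns

# Example: two turns separated by environment (tool-output) tokens.
# detect_turns([1, 1, 0, 0, 1, 1, 1]) -> [(0, 2), (4, 7)]
```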
2. Turn-Level Importance Sampling
In multi-turn tasks, the dialogue trajectory is segmented by grouping agent tokens (loss_mask=1) into turns. Each turn’s importance sampling ratio is defined as the geometric mean of the token-level ratios:

$$r_i^{\text{turn}}(\theta) = \Bigg(\prod_{t \in \mathcal{T}_i} r_t(\theta)\Bigg)^{1/|\mathcal{T}_i|},$$

where $\mathcal{T}_i$ denotes the set of agent-token indices in turn $i$.
This reduces variance compared to full-product sequence-level ratios, yet preserves sub-goal credit assignment. The turn-level PPO surrogate objective replaces the token ratio with the turn ratio:

$$\mathcal{L}^{\text{turn}}(\theta) = \mathbb{E}\!\left[\sum_{i}\sum_{t \in \mathcal{T}_i}\min\!\Big(r_i^{\text{turn}}(\theta)\,\hat{A}_t,\ \operatorname{clip}\big(r_i^{\text{turn}}(\theta),\,1-\epsilon,\,1+\epsilon\big)\,\hat{A}_t\Big)\right].$$
Under inactive clipping, the gradient aggregates the token advantages of each turn, weighted by $r_i^{\text{turn}}(\theta)$. Lemma 1 formalizes the resulting clean turn-level credit assignment.
This approach matches the natural decomposition of multi-turn tasks into reasoning and tool-call stages. Token-level importance sampling is overly noisy—variance increases as off-policy drift grows—while sequence-level sampling discards useful sub-goal structure. Turn-level importance sampling strikes a balance.
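A minimal sketch of the geometric-mean turn ratio follows, computed in log space (the geometric mean of ratios is the exponential of the mean log-ratio); it reuses the hypothetical `detect_turns` output from the earlier sketch.

```python
import torch

def turn_ratios(logp_new, logp_old, turns):
    """Geometric mean of token-level importance ratios within each turn.

    logp_new, logp_old: (T,) per-token log-probabilities for agent tokens.
    turns: list of (start, end) index pairs from turn detection.
    Returns one importance ratio per turn.
    """
    log_ratio = logp_new - logp_old
    ratios = []
    for start, end in turns:
        # Geometric mean = exp(mean of the log-ratios over the turn's tokens).
        ratios.append(torch.exp(log_ratio[start:end].mean()))
    return torch.stack(ratios)
```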
3. Clipping-Bias Correction
PPO’s clipped surrogate discards tokens or turns whose importance ratio falls outside $[1-\epsilon,\,1+\epsilon]$, which introduces a systematic bias term in the gradient. Gradient decomposition (Lemma 2) splits the surrogate gradient into an unclipped part and a clipping-bias part. For token-level PPO:

$$\nabla_\theta \mathcal{L}^{\text{token}}(\theta) = \sum_{t \in \mathcal{I}} \hat{A}_t\,\nabla_\theta r_t(\theta) \;+\; b_{\text{token}}(\theta),$$

where $\mathcal{I}$ is the set of tokens with inactive clipping and $b_{\text{token}}(\theta)$ collects the contributions associated with the clipped tokens.
During large-model training, the bias norm $\|b_{\text{token}}(\theta)\|$ grows and oscillates, indicating unreliable critic estimation and off-policy drift. To dampen such gradients, S-PPO rescales the surrogate gradient by $\|b_{\text{token}}(\theta)\|$:

$$\tilde{g}_{\text{S-PPO}}(\theta) = \frac{\nabla_\theta \mathcal{L}^{\text{token}}(\theta)}{\|b_{\text{token}}(\theta)\|}.$$
For turn-level PPO, a similar bias correction applies:

$$\tilde{g}_{\text{turn}}(\theta) = \frac{\nabla_\theta \mathcal{L}^{\text{turn}}(\theta)}{\|b_{\text{turn}}(\theta)\|},$$

with $b_{\text{turn}}(\theta)$ collecting the contributions of turns whose ratio $r_i^{\text{turn}}(\theta)$ falls outside the clipping interval.
Samples associated with high clipping bias are down-weighted, resulting in stabilized gradient variance.
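One plausible realization of this correction is sketched below: the clipping-bias term is proxied by the advantage-weighted surrogate mass that clipping discards, and the loss is divided by its detached norm. Both the estimator and the clamp-to-one guard are assumptions for illustration; the paper's exact formulation may differ.

```python
import torch

def clipping_bias_norm(ratio, advantages, eps=0.2):
    """Scalar proxy for the clipping-bias term: the norm of the advantage-weighted
    surrogate mass that clipping discards (unclipped samples contribute exactly zero)."""
    discarded = (ratio - torch.clamp(ratio, 1 - eps, 1 + eps)) * advantages
    return torch.linalg.vector_norm(discarded)

def s_ppo_loss(logp_new, logp_old, advantages, eps=0.2):
    """Token-level clipped surrogate rescaled by the clipping-bias norm (S-PPO style)."""
    ratio = torch.exp(logp_new - logp_old)
    surrogate = torch.min(ratio * advantages,
                          torch.clamp(ratio, 1 - eps, 1 + eps) * advantages)
    # Dividing by a positive, detached scalar leaves the gradient direction unchanged;
    # clamping to >= 1 guards against a near-zero norm (this guard is assumed here).
    scale = clipping_bias_norm(ratio, advantages, eps).detach().clamp(min=1.0)
    return -surrogate.mean() / scale
```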
4. ST-PPO: Combined Approach
ST-PPO synthesizes turn-level importance sampling and clipping-bias correction. The turn-level ratio $r_i^{\text{turn}}(\theta)$ is used for credit assignment, and gradients are normalized by the turn-level clipping-bias norm $\|b_{\text{turn}}(\theta)\|$:

$$\tilde{g}_{\text{ST-PPO}}(\theta) = \frac{\nabla_\theta \mathcal{L}^{\text{turn}}(\theta)}{\|b_{\text{turn}}(\theta)\|}.$$
Since division by a positive scalar preserves the gradient direction, ST-PPO has the same fixed points as turn-level PPO but applies more conservative updates when samples are risky. Diagnostics reported in the figures show that this procedure effectively stabilizes training.
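Putting the two pieces together, a minimal ST-PPO loss sketch might look as follows, reusing the hypothetical `turn_ratios` and `clipping_bias_norm` helpers from the earlier sketches; whether advantages are aggregated per turn or kept per token under the turn ratio is an implementation choice assumed here.

```python
import torch

def st_ppo_loss(logp_new, logp_old, advantages, turns, eps=0.2):
    """ST-PPO surrogate sketch: turn-level ratios plus clipping-bias normalization.
    Relies on the hypothetical turn_ratios() and clipping_bias_norm() helpers above."""
    ratios = turn_ratios(logp_new, logp_old, turns)                  # one ratio per turn
    # Aggregate token advantages within each turn (Lemma-1-style credit assignment);
    # an alternative is to keep per-token terms each weighted by the turn ratio.
    turn_adv = torch.stack([advantages[s:e].sum() for s, e in turns])
    surrogate = torch.min(ratios * turn_adv,
                          torch.clamp(ratios, 1 - eps, 1 + eps) * turn_adv)
    scale = clipping_bias_norm(ratios, turn_adv, eps).detach().clamp(min=1.0)
    return -surrogate.mean() / scale
```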
5. Theoretical Properties and Stability Analysis
Lemma 1 proves that turn-level importance sampling yields correct credit assignment and aggregates token-level advantages proportionally to the turn’s geometric mean ratio. Lemma 2 decomposes PPO’s gradient, identifying the clipping-bias term’s contribution to instability as off-policy drift intensifies.
Down-weighting the clipping-bias term using ST-PPO’s normalization controls both variance and bias in the learning signal. Experimental diagnostics (e.g., gradient norm and clipping ratio curves in Figures 2–5) show that ST-PPO maintains lower gradient magnitudes and reduced clipping rates, preventing training collapse. Although no closed-form variance bounds are provided, the observed clipping-bias-norm trends empirically support the improved stability.
6. Empirical Evaluation
Experiments examine general QA (Natural Questions), multi-hop QA (HotpotQA), and medical multiple-choice QA tasks (AlphaMed19K, MedQA, MedMCQA, PubMedQA, MMLU-M, MedXpert). Models use a dense retriever over Wikipedia (top-3 passages per query) and a Qwen-2.5-7B base policy. Evaluation metrics include Exact Match (EM), success rate, and accuracy.
Findings are summarized as:
- Token-level PPO and GRPO collapse mid-training and require early stopping.
- Turn-level PPO improves stability but still collapses on larger models.
- S-PPO prevents collapse and improves peak performance.
- ST-PPO achieves smooth, stable learning curves and superior success rates.
Stability metrics indicate ST-PPO and S-PPO achieve 10–20% clipping ratios (versus 40–60% for token-level PPO), consistently lower KL divergence to the behavior policy, and reduced gradient norms. Ablations show complementary effects: turn-level IS lowers gradient norm and boosts performance, bias-correction alone stabilizes training, and ST-PPO outperforms both. In medical QA (Table 2), ST-PPO attains 49.90% average accuracy, exceeding Search-R1 (token-level PPO RL, 45.37%) and baseline retrieval-augmented generation (RAG) and chain-of-thought (CoT) models.
7. Implementation Guidance and Practical Recommendations
Key hyperparameters are as follows (a minimal config sketch appears after the list):
- Hardware: 8 × NVIDIA H100, FSDP with offloading, gradient checkpointing.
- Policy learning rate: ; Critic learning rate:
- Warm-up ratios: 0.285 (policy), 0.015 (critic)
- Effective batch size: 512; Mini-batches: 256; Micro-batches: 64 (policy), 8 (critic)
- GAE parameters: discount $\gamma$ and smoothing $\lambda$
- PPO KL penalty coefficient: 0.001; Clipping parameter $\epsilon = 0.2$
- Maximum tokens: 4096; response ≤ 500 tokens; context ≤ 2048 tokens; retrieved passages ≤ 500 tokens
- Sampling through vLLM (tensor-parallel size 4), GPU memory utilization 0.6, temperature 1.0, top-$p$ 1.0
- Token grouping for turn detection: agent tokens (loss_mask=1) segmented between environment tokens
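A minimal sketch of these settings as a plain Python config dict is shown below; the keys are illustrative and do not correspond to the paper's or any framework's exact schema, and the learning-rate values omitted above are left out here as well.

```python
# Illustrative ST-PPO training configuration mirroring the list above.
# Keys are hypothetical; learning rates are omitted where not stated in the source.
st_ppo_config = {
    "hardware": {"gpus": "8x NVIDIA H100", "fsdp_offload": True, "grad_checkpointing": True},
    "warmup_ratio": {"policy": 0.285, "critic": 0.015},
    "batch": {"effective": 512, "mini": 256, "micro_policy": 64, "micro_critic": 8},
    "ppo": {"kl_coef": 0.001, "clip_eps": 0.2},
    "lengths_tokens": {"max": 4096, "response": 500, "context": 2048, "retrieved": 500},
    "sampling": {"engine": "vLLM", "tp_size": 4, "gpu_mem_util": 0.6,
                 "temperature": 1.0, "top_p": 1.0},
}
```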
Practically, monitoring the clipping ratio and the clipping-bias norm is recommended. If training instability arises or clipping ratios exceed 50%, clipping-bias normalization should be introduced. The turn-level (geometric-mean) ratio empirically balances off-policy weighting against variance. If critic reliability deteriorates, a cold restart may be employed, but ST-PPO largely obviates this need. A surging clipping-bias norm can signal the need for more frequent critic updates or smaller learning rates.
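As a small illustration of the recommended monitoring, the snippet below computes the clipping ratio (the fraction of importance ratios falling outside the clip interval) so it can be compared against the 10–20% and >50% thresholds quoted above; it is a sketch, not the paper's tooling.

```python
import torch

def clip_fraction(ratio, eps=0.2):
    """Fraction of importance ratios falling outside the clip interval [1-eps, 1+eps]."""
    outside = (ratio < 1 - eps) | (ratio > 1 + eps)
    return outside.float().mean().item()

# Monitoring rule from the guidance above: if the clip fraction stays above ~0.5,
# enable clipping-bias normalization (and consider smaller learning rates if the
# clipping-bias norm keeps surging).
```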
A plausible implication is that in large, multi-turn agent tasks, applying both turn-level importance sampling and clipping-bias correction together offers substantial robustness against training collapse, outperforming baseline PPO and its single-modification counterparts.
In summary, ST-PPO extends standard PPO by matching optimization granularity with task structure and counteracting high-variance, unreliable updates. Empirical and theoretical results confirm that ST-PPO yields stable training dynamics and superior performance for multi-turn LLM agent tasks, without necessitating early stopping or intricate manual intervention (Li et al., 25 Nov 2025).