
ST-PPO: Stabilized PPO for LLMs

Updated 29 November 2025
  • ST-PPO is a reinforcement learning algorithm that adapts PPO for multi-turn LLM tasks by incorporating turn-level importance sampling.
  • It corrects clipping bias through gradient normalization to reduce high-variance updates and promote training stability.
  • Empirical evaluations in multi-hop and medical QA tasks demonstrate that ST-PPO outperforms token-level PPO with improved success rates and accuracy.

ST-PPO (Stabilized Off-Policy Proximal Policy Optimization) is a reinforcement learning algorithm designed to stabilize and enhance the training of LLMs acting as multi-turn agents. ST-PPO addresses instability in Proximal Policy Optimization (PPO) that arises when applying token-level optimization in multi-turn tasks such as multi-hop question answering, search, and reasoning. By introducing turn-level importance sampling and clipping-bias correction, ST-PPO aligns the optimization granularity with the structure of multi-turn environments and normalizes high-variance gradients from off-policy samples, resulting in improved stability and performance in large-model training contexts (Li et al., 25 Nov 2025).

1. Algorithmic Structure and Definitions

ST-PPO operates within a multi-turn Markov decision process (MDP), where interaction is decomposed into discrete "turns": contiguous sequences of agent-generated tokens bounded by tool calls or special markers such as <eot>. The full trajectory $y = (y^1; y^2; \ldots; y^K)$ consists of $K$ turns, each defined by its boundary $(t_k^{\text{start}}, t_k^{\text{end}})$. Let $x \in \mathcal{D}$ denote the user query, $y_t$ the $t$-th output token, and $y^k = (y_{t_k^{\text{start}}}, \ldots, y_{t_k^{\text{end}}})$ the $k$-th turn.
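
As a concrete illustration of this turn decomposition, the minimal Python sketch below recovers turn boundaries from a per-token agent mask. The loss_mask convention (1 for agent tokens, 0 for environment/tool tokens) matches the description in Section 2, but the function name and return format are illustrative assumptions, not the paper's implementation.

```python
from typing import List, Tuple

def segment_turns(loss_mask: List[int]) -> List[Tuple[int, int]]:
    """Group contiguous agent tokens (loss_mask == 1) into turns.

    Returns inclusive (t_start, t_end) index pairs, one per turn,
    mirroring the (t_k^start, t_k^end) boundaries in the text.
    """
    turns, start = [], None
    for t, m in enumerate(loss_mask):
        if m == 1 and start is None:        # a new turn begins
            start = t
        elif m == 0 and start is not None:  # environment token closes the turn
            turns.append((start, t - 1))
            start = None
    if start is not None:                   # trajectory ends inside a turn
        turns.append((start, len(loss_mask) - 1))
    return turns

# Example: two turns separated by tool/environment tokens
# segment_turns([1, 1, 1, 0, 0, 1, 1]) -> [(0, 2), (5, 6)]
```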

For each turn $k$, the state $s_k = (x, y^1, \ldots, y^{k-1})$ and action $a_k = y^k$ are defined. The policy $\pi_\theta$, parameterized as an auto-regressive LLM, outputs $\pi_\theta(y_t \mid x, y_{<t})$. The critic $V_\phi(s_k)$ estimates the turn-level state value. Token-level advantages $\hat{A}_t$ are computed using Generalized Advantage Estimation (GAE) with discount $\gamma$ and GAE parameter $\lambda$. The PPO clipping parameter $\epsilon$ is set to $0.2$.
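
For reference, the token-level advantages $\hat{A}_t$ can be computed with the standard GAE recursion. The sketch below is generic GAE, not code from the paper; the reward and value shapes are chosen for illustration.

```python
import torch

def gae_advantages(rewards: torch.Tensor, values: torch.Tensor,
                   gamma: float = 1.0, lam: float = 1.0) -> torch.Tensor:
    """Standard GAE over one trajectory of T steps.

    rewards: [T]; values: [T + 1] (bootstrap value appended, 0 if terminal).
    With gamma = lam = 1, as in the paper's setup, this reduces to
    A_hat_t = sum_{s >= t} r_s - V(s_t).
    """
    T = rewards.shape[0]
    adv = torch.zeros(T)
    running = 0.0
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        running = delta + gamma * lam * running
        adv[t] = running
    return adv
```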

The standard token-level PPO surrogate objective is:

$$J_\text{PPO}(\theta) = \mathbb{E}_{x, y \sim \pi_{\theta_\text{old}}} \left[ \frac{1}{|y|} \sum_{t=1}^{|y|} \min\left( w_t(\theta)\, \hat{A}_t,\; \text{clip}\left(w_t(\theta), 1-\epsilon, 1+\epsilon\right) \hat{A}_t \right) \right]$$

where $w_t(\theta) = \frac{\pi_\theta(y_t \mid x, y_{<t})}{\pi_{\theta_\text{old}}(y_t \mid x, y_{<t})}$.
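
A minimal PyTorch sketch of this clipped surrogate, assuming per-token log-probabilities and advantages are already available (tensor names are illustrative):

```python
import torch

def ppo_token_loss(logp_new: torch.Tensor, logp_old: torch.Tensor,
                   adv: torch.Tensor, eps: float = 0.2) -> torch.Tensor:
    """Token-level clipped PPO surrogate, negated so it can be minimized.

    logp_new / logp_old: [T] log-probs of the sampled tokens under the
    current and behavior policies; adv: [T] token advantages A_hat_t.
    """
    w = torch.exp(logp_new - logp_old)                  # w_t(theta)
    unclipped = w * adv
    clipped = torch.clamp(w, 1 - eps, 1 + eps) * adv
    return -torch.min(unclipped, clipped).mean()        # (1/|y|) * sum of min(...)
```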

ST-PPO integrates two modifications:

  1. Turn-level importance sampling
  2. Clipping-bias correction

Algorithmic steps (Algorithm 1) include trajectory rollout, turn detection via loss mask or end-of-turn tokens, GAE advantage computation, gradient formation using turn-level ratios, calculation of clipping-bias norms $C_\text{turn}$ and $C_\text{token}$, surrogate gradient normalization, and updates to the policy and critic.

2. Turn-Level Importance Sampling

In multi-turn tasks, the dialogue trajectory is segmented by grouping agent tokens (loss_mask=1) as turns. Each turn’s importance sampling ratio is defined as the geometric mean of token-level ratios:

$$w^k_\text{turn}(\theta) = \left( \frac{\pi_\theta(y^k \mid x, y^{<k})}{\pi_{\theta_\text{old}}(y^k \mid x, y^{<k})} \right)^{1/|y^k|} = \exp\left[ \frac{1}{|y^k|} \sum_{t = t_k^{\text{start}}}^{t_k^{\text{end}}} \log \frac{\pi_\theta(y_t \mid x, y_{<t})}{\pi_{\theta_\text{old}}(y_t \mid x, y_{<t})} \right]$$

This reduces variance compared to sequence-level ratios (the full product of token ratios), yet preserves sub-goal credit assignment. The turn-level PPO surrogate objective is:

$$J_\text{Turn-PPO}(\theta) = \mathbb{E}_{x, y \sim \pi_{\theta_\text{old}}} \left[ \frac{1}{|y|} \sum_{k=1}^{K} \sum_{t = t_k^{\text{start}}}^{t_k^{\text{end}}} \min\left( w^k_\text{turn}(\theta)\, \hat{A}_t,\; \text{clip}\left(w^k_\text{turn}(\theta), 1-\epsilon, 1+\epsilon\right) \hat{A}_t \right) \right]$$

When clipping is inactive, the gradient aggregates the token advantages within each turn into a turn-level advantage $\hat{A}^k$, weighted by $w^k_\text{turn}$. Lemma 1 formalizes the resulting clean turn-level credit assignment.

This approach matches the natural decomposition of multi-turn tasks into reasoning and tool-call stages. Token-level importance sampling is overly noisy—variance increases as off-policy drift grows—while sequence-level sampling discards useful sub-goal structure. Turn-level importance sampling strikes a balance.
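Putting the two formulas above together, a hedged sketch of the turn-level ratio and the Turn-PPO surrogate might look as follows. It reuses the inclusive turn boundaries from the segment_turns sketch above and broadcasts each turn's geometric-mean ratio to every token inside that turn; all names are illustrative.

```python
import torch

def turn_ppo_loss(logp_new, logp_old, adv, turns, eps=0.2):
    """Turn-level clipped surrogate (negated): one geometric-mean ratio per
    turn, shared by all tokens inside that turn.

    turns: inclusive (start, end) pairs, e.g. from segment_turns above.
    """
    log_ratio = logp_new - logp_old                      # [T] token log-ratios
    terms = []
    for (s, e) in turns:
        # w_turn^k = exp(mean log-ratio over the turn), i.e. the geometric mean
        w_k = torch.exp(log_ratio[s:e + 1].mean())
        unclipped = w_k * adv[s:e + 1]
        clipped = torch.clamp(w_k, 1 - eps, 1 + eps) * adv[s:e + 1]
        terms.append(torch.min(unclipped, clipped))
    # normalize by the total response length |y|, as in J_Turn-PPO
    return -torch.cat(terms).sum() / logp_new.shape[0]
```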

3. Clipping-Bias Correction

PPO’s clipped surrogate discards tokens or turns whose importance ratio falls outside $[1-\epsilon, 1+\epsilon]$, introducing a systematic bias term in the gradient. Gradient decomposition (Lemma 2) yields:

$$\nabla_\theta J_\text{PPO} = \text{(off-policy term)} + \text{(advantage-estimation error term)} - \text{(clipping-bias term } C(\theta))$$

For token-level PPO:

$$C_\text{token}(\theta) = \mathbb{E}\left[ \frac{1}{|y|} \sum_{t=1}^{|y|} \mathbb{1}_{t \notin \mathcal{B}_\text{token}} \cdot w_t \hat{A}_t \right]$$

where $\mathcal{B}_\text{token}$ is the set of tokens with inactive clipping.

During large-model training, $\Vert C_\text{token}(\theta) \Vert_2$ grows and oscillates, indicating unreliable critic estimation and off-policy drift. To dampen such gradients, S-PPO rescales by $1 / \Vert C_\text{token}(\theta) \Vert_2$:

θJS-PPO(θ)1Ctoken(θ)2θJPPO(θ)\nabla_\theta J_\text{S-PPO}(\theta) \equiv \frac{1}{\Vert C_\text{token}(\theta) \Vert_2} \cdot \nabla_\theta J_\text{PPO}(\theta)

For turn-level PPO, similar bias correction applies:

$$C_\text{turn}(\theta) = \mathbb{E}\left[ \frac{1}{|y|} \sum_{t=1}^{|y|} \mathbb{1}_{t \notin \mathcal{B}_\text{turn}} \cdot w^{k(t)}_\text{turn} \hat{A}_t \right]$$

where $\mathcal{B}_\text{turn}$ is the set of tokens whose turn-level ratio is not clipped and $k(t)$ denotes the turn containing token $t$, with

$$\nabla_\theta J_\text{ST-PPO}(\theta) = \frac{1}{\Vert C_\text{turn}(\theta) \Vert_2} \cdot \nabla_\theta J_\text{Turn-PPO}(\theta).$$

Samples associated with high clipping bias are down-weighted, resulting in stabilized gradient variance.
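One possible reading of this correction, as a sketch: estimate the clipping-bias term on the batch, take its norm, and divide the surrogate loss by that detached scalar, which rescales the gradient without changing its direction. The paper may instead compute the norm of the bias term's gradient; the scalar reading below, the floor guard, and the reuse of ppo_token_loss from the earlier sketch are assumptions.

```python
import torch

def clipping_bias_norm(logp_new, logp_old, adv, eps=0.2):
    """Estimate ||C_token(theta)||_2 on one batch of tokens.

    Tokens whose ratio leaves [1 - eps, 1 + eps] are exactly those whose
    contribution w_t * A_hat_t the clipped surrogate discards.
    """
    w = torch.exp(logp_new - logp_old)
    clip_active = (w < 1 - eps) | (w > 1 + eps)        # t not in B_token
    c_token = (clip_active.float() * w * adv).mean()   # (1/|y|) * sum over clipped tokens
    return c_token.abs().detach()                      # scalar, so L2 norm = |.|

def sppo_loss(logp_new, logp_old, adv, eps=0.2, floor=1e-3):
    """S-PPO sketch: token-level PPO loss rescaled by 1 / ||C_token||_2.

    Dividing the loss by a detached positive scalar rescales the gradient
    without changing its direction; the floor (an assumption, not from the
    paper) avoids division by ~0 when almost nothing is clipped.
    """
    base = ppo_token_loss(logp_new, logp_old, adv, eps)   # from the earlier sketch
    norm = clipping_bias_norm(logp_new, logp_old, adv, eps)
    return base / torch.clamp(norm, min=floor)
```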

4. ST-PPO: Combined Approach

ST-PPO synthesizes turn-level importance sampling and clipping-bias correction. The turn-level ratio $w^k_\text{turn}$ is used for credit assignment, and gradients are normalized by the turn-level clipping-bias norm $\Vert C_\text{turn} \Vert_2$:

$$\nabla_\theta J_\text{ST-PPO}(\theta) = \frac{1}{\Vert C_\text{turn}(\theta) \Vert_2} \cdot \nabla_\theta J_\text{Turn-PPO}(\theta), \quad \text{with } \epsilon = 0.2$$

Since division by a positive scalar preserves the gradient direction, ST-PPO has fixed points identical to Turn-PPO, but applies more conservative updates when samples are risky. Figure diagnostics demonstrate that this procedure effectively stabilizes the training process.
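Combining the two pieces, an illustrative ST-PPO loss under the same assumptions as the earlier sketches (turn boundaries from segment_turns, turn ratios as geometric means, a scalar reading of $\Vert C_\text{turn} \Vert_2$, and a small floor to avoid division by zero) could look like:

```python
import torch

def st_ppo_loss(logp_new, logp_old, adv, turns, eps=0.2, floor=1e-3):
    """ST-PPO sketch: Turn-PPO surrogate rescaled by 1 / ||C_turn||_2.

    Reuses turn_ppo_loss and the inclusive (start, end) turn boundaries
    from the earlier sketches; the floor and the no-clipping fallback are
    assumptions, not details from the paper.
    """
    log_ratio = logp_new - logp_old
    # Turn-level clipping-bias term: contributions of tokens whose turn
    # ratio falls outside the trust region [1 - eps, 1 + eps].
    c_terms = []
    for (s, e) in turns:
        w_k = float(torch.exp(log_ratio[s:e + 1].mean()))
        if w_k < 1 - eps or w_k > 1 + eps:
            c_terms.append(w_k * adv[s:e + 1])
    if c_terms:
        c_turn = torch.cat(c_terms).sum() / logp_new.shape[0]
        norm = torch.clamp(c_turn.abs(), min=floor).detach()
    else:
        norm = torch.tensor(1.0)   # nothing clipped: apply no damping
    return turn_ppo_loss(logp_new, logp_old, adv, turns, eps) / norm
```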

5. Theoretical Properties and Stability Analysis

Lemma 1 proves that turn-level importance sampling yields correct credit assignment and aggregates token-level advantages proportionally to the turn’s geometric mean ratio. Lemma 2 decomposes PPO’s gradient, identifying the clipping-bias term’s contribution to instability as off-policy drift intensifies.

Down-weighting the clipping-bias term using ST-PPO’s normalization controls both variance and bias in the learning signal. Experimental diagnostics (e.g., gradient norm and clipping ratio curves in Figures 2–5) show that ST-PPO maintains lower gradient magnitudes and reduced clipping rates, preventing training collapse. Although no closed-form variance bounds are provided, $L_2$-norm trends empirically validate improved stability.

6. Empirical Evaluation

Experiments examine general QA (Natural Questions), multi-hop QA (HotpotQA), and medical multiple-choice QA tasks (AlphaMed19K, MedQA, MedMCQA, PubMedQA, MMLU-M, MedXpert). Models use a 3-passage dense retriever on Wikipedia and the Qwen-2.5-7B base policy. Evaluation metrics include Exact Match (EM), success rate, and accuracy.

Findings are summarized as:

  • Token-level PPO and GRPO collapse mid-training and require early stopping.
  • Turn-level PPO improves stability but still collapses on larger models.
  • S-PPO prevents collapse and improves peak performance.
  • ST-PPO achieves smooth, stable learning curves and superior success rates.

Stability metrics indicate ST-PPO and S-PPO achieve 10–20% clipping ratios (versus 40–60% for token-level PPO), consistently lower KL divergence to the behavior policy, and reduced gradient norms. Ablations show complementary effects: turn-level IS lowers gradient norm and boosts performance, bias-correction alone stabilizes training, and ST-PPO outperforms both. In medical QA (Table 2), ST-PPO attains 49.90% average accuracy, exceeding Search-R1 (token-level PPO RL, 45.37%) and baseline retrieval-augmented generation (RAG) and chain-of-thought (CoT) models.

7. Implementation Guidance and Practical Recommendations

Key hyperparameters and settings are as follows (a minimal configuration sketch follows the list):

  • Hardware: 8 × NVIDIA H100, FSDP with offloading, gradient checkpointing.
  • Policy learning rate: $1 \times 10^{-6}$; critic learning rate: $1 \times 10^{-5}$
  • Warm-up ratios: 0.285 (policy), 0.015 (critic)
  • Effective batch size: 512; Mini-batches: 256; Micro-batches: 64 (policy), 8 (critic)
  • GAE $\lambda = 1$, discount $\gamma = 1$
  • PPO KL penalty coefficient: 0.001; clipping parameter $\epsilon = 0.2$
  • Maximum tokens: 4096; response ≤ 500 tokens; context ≤ 2048 tokens; retrieved content ≤ 500 tokens
  • Sampling through vLLM (TP size 4), GPU memory utilization 0.6, temperature 1.0, top-$p$ 1.0
  • Token grouping for turn detection: agent tokens (loss_mask=1) segmented between environment tokens
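
These settings could be gathered into a single configuration object; the sketch below simply mirrors the bullets above, and the field names are illustrative assumptions rather than the paper's actual configuration schema.

```python
# Illustrative ST-PPO training configuration; field names are assumptions,
# values mirror the hyperparameters listed above.
st_ppo_config = {
    "policy_lr": 1e-6,
    "critic_lr": 1e-5,
    "policy_warmup_ratio": 0.285,
    "critic_warmup_ratio": 0.015,
    "train_batch_size": 512,
    "mini_batch_size": 256,
    "policy_micro_batch_size": 64,
    "critic_micro_batch_size": 8,
    "gae_lambda": 1.0,
    "gamma": 1.0,
    "kl_coef": 0.001,
    "clip_eps": 0.2,
    "max_tokens": 4096,
    "max_response_tokens": 500,
    "max_context_tokens": 2048,
    "max_retrieved_tokens": 500,
    "rollout": {"backend": "vLLM", "tensor_parallel_size": 4,
                "gpu_memory_utilization": 0.6, "temperature": 1.0, "top_p": 1.0},
}
```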

Practically, monitoring the clipping ratio and clipping-bias norm is recommended. If training becomes unstable or clipping ratios exceed 50%, clipping-bias normalization should be introduced. $\epsilon \approx 0.2$ empirically balances off-policy weight and variance. If critic reliability deteriorates, a cold restart may be employed, but ST-PPO largely obviates this need. A surging $\Vert C_\text{turn} \Vert_2$ can signal the need for more frequent critic updates or smaller learning rates.
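
As a small illustration of this monitoring advice, the helper below flags the two warning signs mentioned above; the threshold on the bias norm is an assumption, only the 50% clipping-ratio threshold comes from the text.

```python
def training_health(clip_ratio: float, c_turn_norm: float) -> list:
    """Flag the warning signs described above (bias-norm threshold is illustrative)."""
    warnings = []
    if clip_ratio > 0.5:
        warnings.append("clipping ratio > 50%: enable clipping-bias normalization")
    if c_turn_norm > 1.0:   # threshold is an assumption, not from the paper
        warnings.append("||C_turn||_2 surging: consider more frequent critic "
                        "updates or a smaller learning rate")
    return warnings
```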

A plausible implication is that in large, multi-turn agent tasks, applying both turn-level importance sampling and clipping-bias correction together offers substantial robustness against training collapse, outperforming baseline PPO and its single-modification counterparts.


In summary, ST-PPO extends standard PPO by matching optimization granularity with task structure and counteracting high-variance, unreliable updates. Empirical and theoretical results confirm that ST-PPO yields stable training dynamics and superior performance for multi-turn LLM agent tasks, without necessitating early stopping or intricate manual intervention (Li et al., 25 Nov 2025).
