
Prefix Grouper for Scalable GRPO

Updated 1 January 2026
  • Prefix Grouper is an efficient algorithm that restructures self-attention to eliminate redundant encoding of shared prefixes in GRPO.
  • It concatenates all roll-outs in a group into a single sequence, reducing computational overhead by up to a factor of the group size.
  • Empirical evaluations confirm that it achieves equivalent forward activations and backward gradients with significant FLOP and memory savings.

Prefix Grouper is an efficient training algorithm for Group Relative Policy Optimization (GRPO), specifically designed to address the computational bottlenecks posed by long shared prefixes in sequence modeling tasks. By restructuring self-attention to eliminate redundant encoding of shared prefixes, Prefix Grouper facilitates scalable GRPO for large models and long-context environments while retaining equivalence to standard GRPO in both forward activations and backward gradients (Liu et al., 5 Jun 2025).

1. Computational Bottleneck in Group Relative Policy Optimization

Group Relative Policy Optimization (GRPO) enhances policy learning through relative comparisons among $G$ candidate outputs originating from a common input prefix $P$. In the canonical approach, known as "Repeated-Prefix Forward," each output sequence $x_i = [P; R_i]$ is independently processed by self-attention, requiring redundant re-encoding of the shared prefix $P$ for all $G$ group members. When the prefix length $L_p$ is much greater than the suffix length $L_r$, the computational and memory costs scale linearly with $G$, which presents a major bottleneck in long-context and multi-modal settings.
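For intuition, a small illustrative calculation (the lengths below are assumed for illustration, not taken from the paper) compares how many tokens each strategy pushes through the model per group:

# Illustrative token counts per group (assumed lengths, not from the paper).
L_p, L_r, G = 16384, 128, 8

repeated_prefix_tokens = G * (L_p + L_r)   # baseline: prefix re-encoded G times
shared_prefix_tokens = L_p + G * L_r       # shared-prefix: prefix encoded once

print(repeated_prefix_tokens / shared_prefix_tokens)  # ≈ 7.6x fewer tokens processed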

2. Shared-Prefix Forward Strategy and Attention Restructuring

Prefix Grouper eliminates redundant prefix encoding by concatenating all roll-outs into a single sequence:

$X_{\text{ours}} = [P; R_1; R_2; \dots; R_G] \in \mathbb{R}^{(L_p + \sum_i L_{r,i}) \times D}$

Self-attention is decomposed per layer into two operations:

  • Prefix-only attention: Computes attention among prefix tokens, $O_{\text{prefix}} = \text{Attn}(Q_{\text{prefix}}, K_{\text{prefix}}, V_{\text{prefix}}, \text{mask}_{\text{prefix}})$.
  • Suffix attention: For each suffix, computes attention over the entire prefix and its own suffix, $O_{\text{suffix}} = \text{Attn}(Q_{\text{suffix}}, [K_{\text{prefix}}; K_{\text{suffix}}], [V_{\text{prefix}}; V_{\text{suffix}}], \text{mask}_{\text{full}})$.

The outputs are concatenated along the sequence dimension. Masking schemes ensure causal behavior, with $\text{mask}_{\text{prefix}}$ applied to the prefix-only attention and $\text{mask}_{\text{full}}$ enforcing causality over the prefix and each suffix. This design reuses $K_{\text{prefix}}$ and $V_{\text{prefix}}$ exactly once for all $G$ suffixes, removing the $O(G)$ duplication.
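One way to realize these masks is as boolean matrices where True marks an allowed query-key pair, assuming suffix queries are laid out contiguously after the prefix; this is a rough sketch, and the helper name build_masks and the exact layout are illustrative rather than taken from the paper:

import torch

def build_masks(prefix_len, suffix_lengths):
    # Prefix-only attention: standard causal mask over the prefix tokens.
    mask_pref = torch.tril(torch.ones(prefix_len, prefix_len, dtype=torch.bool))
    total_suf = sum(suffix_lengths)
    # Suffix attention: queries are all suffix tokens, keys are [prefix; suffixes].
    mask_full = torch.zeros(total_suf, prefix_len + total_suf, dtype=torch.bool)
    mask_full[:, :prefix_len] = True  # every suffix token attends to the full prefix
    start = 0
    for L in suffix_lengths:
        block = mask_full[start:start + L, prefix_len + start:prefix_len + start + L]
        block.copy_(torch.tril(torch.ones(L, L, dtype=torch.bool)))  # causal within own suffix
        start += L  # tokens of other suffixes stay masked out
    return mask_pref, mask_full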

3. Implementation Details and Autograd Semantics

The algorithm is instantiated with a PyTorch-style grouped attention function and a custom autograd function, PrefixGrouper. In the forward pass, queries, keys, and values are split into prefix and suffix components, each processed by the corresponding attention call. In the backward pass, gradients received from the output are propagated separately through the suffix and prefix attention paths. Gradients with respect to $K_{\text{prefix}}$ and $V_{\text{prefix}}$ are aggregated from both the prefix-only and suffix attention flows, carefully summing contributions to mirror the original GRPO gradient accumulation.

import torch

def grouped_attention(q, k, v, prefix_len, suffix_lengths, attention_fn,
                      mask_pref, mask_full):
    # q, k, v: [batch, heads, seq, head_dim] with seq = prefix_len + sum(suffix_lengths)
    suf_len = sum(suffix_lengths)
    q_pref, q_suf = q.split([prefix_len, suf_len], dim=2)
    k_pref, k_suf = k.split([prefix_len, suf_len], dim=2)
    v_pref, v_suf = v.split([prefix_len, suf_len], dim=2)
    # Prefix-only attention: the shared prefix is encoded exactly once per group.
    out_pref, _ = attention_fn(q_pref, k_pref, v_pref, mask_pref)
    # Suffix attention: each suffix token attends to the full prefix and,
    # causally, to its own suffix; mask_full blocks attention across suffixes.
    k_full = torch.cat([k_pref, k_suf], dim=2)
    v_full = torch.cat([v_pref, v_suf], dim=2)
    out_suf, _ = attention_fn(q_suf, k_full, v_full, mask_full)
    return torch.cat([out_pref, out_suf], dim=2)

This ensures full differentiability and compatibility with standard end-to-end training frameworks.
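As a minimal usage sketch, assuming the build_masks helper sketched in Section 2 and PyTorch 2.x's scaled_dot_product_attention as the underlying kernel (the shapes are illustrative, not from the paper):

import torch
import torch.nn.functional as F

def sdpa_fn(q, k, v, mask):
    # Adapter so PyTorch's scaled_dot_product_attention matches the
    # (output, aux) return convention expected by grouped_attention.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask), None

prefix_len, suffix_lengths = 32, [8, 8]          # illustrative lengths
seq = prefix_len + sum(suffix_lengths)
q, k, v = (torch.randn(1, 4, seq, 64) for _ in range(3))
mask_pref, mask_full = build_masks(prefix_len, suffix_lengths)
out = grouped_attention(q, k, v, prefix_len, suffix_lengths, sdpa_fn,
                        mask_pref, mask_full)
print(out.shape)  # torch.Size([1, 4, 48, 64])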

4. Formal Equivalence to Standard GRPO

Lemma 3.1 establishes theoretical equivalence: for any policy loss $J$,

$\nabla_\theta J_{\text{ours}}(X_{\text{ours}}, A) \equiv \nabla_\theta J_{\text{base}}(X_{\text{base}}, A)$

Token-wise, forward outputs remain identical, as every token receives the same attention and feed-forward subcomputations. For gradients, the loss depends solely on suffix tokens, and each prefix key/value participates in both the prefix-only and all relevant suffix attention calls, so each prefix token accumulates the same gradient sum as in the baseline method.
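Continuing the usage sketch from Section 3, forward equivalence can be checked numerically by running each rollout through an ordinary causal attention over $[P; R_i]$ and comparing the suffix outputs (the check reuses the tensors defined there; tolerance and shapes are illustrative):

# Baseline: repeated-prefix forward, one full causal attention per rollout.
base_out = []
for i, L in enumerate(suffix_lengths):
    s = prefix_len + sum(suffix_lengths[:i])                 # start of suffix i
    idx = list(range(prefix_len)) + list(range(s, s + L))    # positions of [P; R_i]
    causal = torch.tril(torch.ones(len(idx), len(idx), dtype=torch.bool))
    o_i, _ = sdpa_fn(q[:, :, idx], k[:, :, idx], v[:, :, idx], causal)
    base_out.append(o_i[:, :, prefix_len:])                  # keep suffix outputs only
same = torch.allclose(torch.cat(base_out, dim=2), out[:, :, prefix_len:], atol=1e-5)
print(same)  # expected: True (up to floating-point tolerance)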

5. Computational and Memory Complexity

Let $L_p$ denote the prefix length, $L_r$ the uniform suffix length, $G$ the group size, $n$ the number of attention heads, and $d$ the head dimension. The attention operation FLOP counts are:

Algorithm                    Attention FLOPs
Baseline (Repeated Prefix)   $G (L_p + L_r)^2 n d$
Prefix Grouper (Ours)        $L_p^2 n d + G L_r (2 L_p + L_r) n d$

For $L_p \gg L_r$, the ratio $C_{\text{ours}} / C_{\text{base}} \rightarrow 1/G$, realizing up to a $G$-fold reduction in computational cost and memory consumption. Pointwise operations (FFN and QKV projections) scale similarly, since $K_{\text{prefix}}$ and $V_{\text{prefix}}$ are computed and stored only once per group.
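For intuition, the table's formulas can be evaluated directly; the configuration below is assumed for illustration, and $n$ and $d$ cancel in the ratio:

# Attention-FLOP comparison using the formulas from the table above.
def flops_base(L_p, L_r, G, n, d):
    return G * (L_p + L_r) ** 2 * n * d

def flops_ours(L_p, L_r, G, n, d):
    return L_p ** 2 * n * d + G * L_r * (2 * L_p + L_r) * n * d

L_p, L_r, G, n, d = 16384, 128, 8, 32, 128   # illustrative configuration
ratio = flops_base(L_p, L_r, G, n, d) / flops_ours(L_p, L_r, G, n, d)
print(ratio)  # ≈ 7.2, approaching the G = 8 bound as L_p grows relative to L_r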

6. Integration with Existing GRPO Pipelines

Prefix Grouper is designed for seamless integration with existing GRPO-based architectures:

  • Data loader: Inputs switch from $G$ separate sequences $[P; R_i]$ to the single sequence $[P; R_1; \ldots; R_G]$, with $L_p$ and the suffix lengths recorded (see the sketch after this list).
  • Attention wrapper: Every self-attention call is replaced with its "prefix grouper" variant.
  • Positional encoding: Tokens use the same absolute positional indices as in the baseline $[P; R_i]$ layout, maintaining compatibility with schemes such as RoPE.
  • No changes are required to model weights, optimizer, or non-attention layers. Modifications are limited to data construction and attention computation calls.
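A hypothetical data-loader helper illustrating the first and third bullets; the name build_group_input, the tensor layout, and the position handling are assumptions based on the bullets above, not code from the paper:

import torch

def build_group_input(prefix_ids, rollout_ids):
    # prefix_ids: 1-D LongTensor of prompt tokens; rollout_ids: list of G 1-D LongTensors.
    input_ids = torch.cat([prefix_ids, *rollout_ids])      # [P; R_1; ...; R_G]
    prefix_len = prefix_ids.numel()
    suffix_lengths = [r.numel() for r in rollout_ids]
    # Each suffix keeps the positions it would have had in the baseline [P; R_i]
    # layout, so RoPE and other absolute-position schemes behave as before.
    positions = torch.cat([torch.arange(prefix_len)] +
                          [torch.arange(prefix_len, prefix_len + L)
                           for L in suffix_lengths])
    return input_ids, positions, prefix_len, suffix_lengths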

7. Empirical Evaluation and Observed Performance

Experiments measure FLOPs and GPU memory usage for group sizes $G = 2, 4, 8, 16$ and fixed prefix lengths $L_p \in \{4096, 8192, 16384\}$:

  • FLOPs are reduced by almost $G$-fold for large $L_p$, with observed values such as a $7.8\times$ reduction for $G = 8$, $L_p = 16384$.
  • Memory savings align with theoretical predictions, enabling the use of larger batch or group sizes.
  • Policy learning curves and final rewards are empirically identical between Prefix Grouper and standard GRPO on toy reasoning benchmarks, confirming gradient and activation equivalence.

8. Principal Properties and Application Scope

Prefix Grouper delivers $O(G)$ savings in FLOPs and memory for $L_p \gg L_r$, with plug-and-play compatibility requiring minimal changes to existing pipelines. No architectural changes or additional trainable parameters are necessary. Its benefits are maximized when the shared prefix dominates the input, and index-splitting incurs only negligible computational overhead. Typical application domains include long-context RL, multi-QA judge models, and multimodal reasoning where shared prefixes are prevalent.

Prefix Grouper thus enables scalable GRPO for complex tasks and large models while retaining fidelity in both optimization and empirical performance (Liu et al., 5 Jun 2025).
