Adaptive Grouped Speculative Decoding (AGSD)

Updated 16 December 2025
  • Adaptive Grouped Speculative Decoding (AGSD) is a method that dynamically selects candidate token groups for LLM inference, using a fast draft model and a slower high-fidelity target model.
  • It models group selection as a Markov Decision Process with an optimal threshold policy, balancing longer speculative groups against the risk of wasted computation.
  • AGSD incorporates a learned acceptance prediction head and shows significant throughput improvements, recovering over 99% of optimal inference speed.

Adaptive Grouped Speculative Decoding (AGSD) refers to a class of algorithms that increase the throughput of LLM inference by adaptively selecting the length of speculative token candidate groups, in contrast to static-parameter speculative decoding. The framework leverages a fast draft model and a slower, high-fidelity target model; at each decoding step, the draft model generates a batch of candidate tokens, which are then validated by the target model. AGSD, as instantiated in SpecDec++ (Huang et al., 2024), formalizes candidate group selection as a Markov Decision Process (MDP) with an optimal threshold-based stopping rule, supplementing the draft model with a learned acceptance-prediction head to maximize inference speedup while minimizing wasted computation.

1. Problem Setting and MDP Formulation

The central problem in speculative decoding is to dynamically determine, at each generation round, whether to extend the candidate group by one more token (action “continue”) or hand off the current group to the target model for verification (action “stop”). This process is modeled as an MDP $(S, A, P, c)$:

  • State $s \in S$: $s = (\mathbf{x}_{\text{prefix}}, (Y_1, \ldots, Y_k))$, where $\mathbf{x}_{\text{prefix}}$ is the prompt concatenated with all accepted tokens, and $(Y_1, \ldots, Y_k)$ are candidate tokens sampled auto-regressively from the draft model $q$.
  • Actions $a \in A = \{\rightsquigarrow, \checkmark\}$: $\rightsquigarrow$ denotes “continue” (draft one more token); $\checkmark$ denotes “stop” (submit the candidates for verification).
  • Transitions $P(s' \mid s, a)$: Continuing samples $Y_{k+1} \sim q(\cdot \mid \mathbf{x}_{\text{prefix}}, Y_1, \ldots, Y_k)$, while stopping hands the candidate batch to the target model and transitions the state according to how many tokens were accepted.
  • Immediate cost $c(s, a, s')$: Compute cost, where $t_{\text{draft}}$ is the draft-model forward time and $t_{\text{target}}$ the target-model forward time; cost accrues for each draft step (wasted work whenever a token in the current batch is eventually rejected) and for each full verification step.

The MDP formulation captures the tension between longer speculative groups (which can amortize target model cost) and the risk of wasted computation if the group is too long and many candidates are rejected (Huang et al., 2024).
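
For concreteness, the state, actions, and per-step cost can be sketched in a few lines of Python. The timing constants and class names below are illustrative assumptions, not the paper's implementation:

```python
from dataclasses import dataclass, field

# Hypothetical per-forward-pass times (seconds); real values depend on
# hardware and the chosen draft/target model pair.
T_DRAFT = 0.01
T_TARGET = 0.10

@dataclass
class State:
    """MDP state s = (x_prefix, (Y_1, ..., Y_k))."""
    prefix: list = field(default_factory=list)      # prompt + all accepted tokens
    candidates: list = field(default_factory=list)  # unverified draft tokens

CONTINUE, STOP = "continue", "stop"  # the two actions of the MDP

def immediate_cost(action: str) -> float:
    """Immediate compute cost c(s, a, s') of taking an action.

    Continuing always pays one draft forward pass (wasted work whenever an
    earlier candidate in the batch is destined to be rejected); stopping
    pays one full target-model verification pass.
    """
    return T_DRAFT if action == CONTINUE else T_TARGET
```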

2. Threshold Policy as Optimal Solution

The candidate group selection policy aims to minimize expected total inference cost. Let $E_1$ denote the event that at least one candidate in the current group $(Y_1, \ldots, Y_k)$ will be rejected, and let $p_{E_1} \equiv \Pr(E_1 \mid \mathbf{x}_{\text{prefix}}, (Y_1, \ldots, Y_k))$. The optimal policy, proven in Theorem 1 of (Huang et al., 2024), is:

$$\pi^*(s) = \begin{cases} \checkmark & \text{if } p_{E_1} \geq \tau, \\ \rightsquigarrow & \text{otherwise}, \end{cases}$$

where the threshold is $\tau = \frac{c_2 + \Delta}{c_1 + c_2 + \Delta}$, with $c_1 = t_{\text{draft}}$, $c_2 = t_{\text{target}} - t_{\text{draft}}$, and $\Delta$ an upper bound on downstream cost differentials. This result establishes that greedy thresholding of the predicted group rejection probability is globally optimal under bounded time horizons. Thus, AGSD differs fundamentally from fixed-$K$ heuristics or statically tuned strategies (Huang et al., 2024).
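
A minimal numeric sketch of the threshold, assuming illustrative forward-pass times (the $\Delta$ value here is a placeholder, not a measured quantity):

```python
def stopping_threshold(t_draft: float, t_target: float, delta: float) -> float:
    """Compute tau = (c2 + delta) / (c1 + c2 + delta) from Theorem 1,
    where c1 = t_draft and c2 = t_target - t_draft."""
    c1 = t_draft
    c2 = t_target - t_draft
    return (c2 + delta) / (c1 + c2 + delta)

# Example: a draft model ~10x faster than the target; delta = 0 for simplicity.
tau = stopping_threshold(t_draft=0.01, t_target=0.10, delta=0.0)
print(f"stop when predicted rejection probability >= {tau:.2f}")  # 0.90
```

Intuitively, the cheaper the draft model relative to the target, the larger $\tau$ becomes, so drafting continues until the predicted rejection risk is high.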

3. Acceptance Prediction Head: Architecture and Training

To operationalize the threshold policy, AGSD augments the draft model with an acceptance-prediction head $f_\theta$ that estimates per-token conditional acceptance probabilities:

  • Architecture: $f_\theta$ is a $(D+1)$-layer ResNet ($D \in \{0,1,2,3,4\}$) with SiLU activations, mapping the last hidden state $h_i$ of each draft token $Y_i$ to $\hat{p}_i = \sigma(f_\theta(h_i))$, the predicted conditional acceptance probability.
  • Training data: Samples $x \sim \mathcal{D}$ are constructed from prompts paired with target-model responses $(X_1, \ldots, X_N)$. The draft model is run in teacher-forcing mode, and for each position the true acceptance probability is $p_i = \min(1, p(Y_i \mid \cdot)/q(Y_i \mid \cdot))$.
  • Random mixing: To counteract class imbalance and sparsity, sequences $Z_1, \ldots, Z_N$ are formed by sampling each $Z_i$ from $X_i$ (with probability $r$) or $Y_i$ (with probability $1-r$), with the loss computed only at positions where $Z_i = Y_i$.
  • Loss: Weighted binary cross-entropy,

$$L(\theta) = \mathbb{E}_{i : Z_i = Y_i}\left[-w_{\text{acc}}\, p_i \log \hat{p}_i - w_{\text{rej}}\, (1-p_i)\log(1-\hat{p}_i)\right],$$

with $w_{\text{rej}} \in \{1,3,6,12\}$ used to upweight rare rejections.

Calibration proceeds via held-out KL divergence and throughput tuning on validation data, securing robust probability estimates across candidate group lengths and distributions (Huang et al., 2024).
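
A minimal PyTorch sketch of such a head and its weighted loss, assuming residual MLP blocks over the draft model's last hidden states; module names, block structure, and dimensions are illustrative assumptions rather than the released implementation:

```python
import torch
import torch.nn as nn

class ResBlock(nn.Module):
    """One residual block with SiLU activation (illustrative structure)."""
    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.act(self.linear(x))

class AcceptanceHead(nn.Module):
    """(D+1)-layer ResNet head: D residual blocks plus a final linear layer.
    sigma(logit) estimates the conditional acceptance probability p_i."""
    def __init__(self, hidden_dim: int, depth: int = 3):  # depth = D
        super().__init__()
        self.blocks = nn.Sequential(*[ResBlock(hidden_dim) for _ in range(depth)])
        self.out = nn.Linear(hidden_dim, 1)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.out(self.blocks(h)).squeeze(-1)  # per-token logits

def weighted_bce(logits, p_true, mask, w_acc=1.0, w_rej=6.0):
    """Weighted BCE against soft targets p_i = min(1, p/q), averaged only
    over positions where Z_i = Y_i (boolean mask)."""
    p_hat = torch.sigmoid(logits).clamp(1e-6, 1 - 1e-6)
    loss = -w_acc * p_true * torch.log(p_hat) - w_rej * (1 - p_true) * torch.log(1 - p_hat)
    return loss[mask].mean()
```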

4. AGSD Online Algorithm: Candidate Length Adaptation

The adaptive candidate-length selection is performed online using the acceptance head’s predictions. The core procedure is:

  1. Initialize the cumulative acceptance probability $P_{\text{acc}} \leftarrow 1$.
  2. For $i = 1, 2, \ldots, K_{\max}$:
    • For $i > 1$, extract the hidden state $h_{i-1}$ and compute $\hat{p}_{i-1} = \sigma(f_\theta(h_{i-1}))$.
    • Update $P_{\text{acc}} \leftarrow P_{\text{acc}} \cdot \hat{p}_{i-1}$.
    • If $1 - P_{\text{acc}} > h$ (for a threshold $h$), break; otherwise sample the next candidate $Y_i$.
  3. Let $K = i$ and submit $(Y_1, \ldots, Y_K)$ to the target model for batch verification.
  4. Post-process by sampling which tokens are accepted and applying the standard speculative-decoding correction where a token is rejected (see Algorithm SpecDec++ in Huang et al., 2024).

This adaptive procedure generalizes conventional speculative decoding by dynamically tailoring KK to the instantaneous predicted risk, achieving close to the oracle optimal throughput across both in-distribution and out-of-distribution tasks.
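
A schematic sketch of this loop; `draft_step`, `acceptance_logit`, and `verify` are hypothetical stand-ins for the draft forward pass, the prediction head, and target verification:

```python
import math

def adaptive_draft(prefix, draft_step, acceptance_logit, verify,
                   h: float = 0.7, k_max: int = 32):
    """SpecDec++-style adaptive drafting loop (sketch).

    Hypothetical callable signatures:
      draft_step(prefix, candidates) -> (token, hidden_state)
      acceptance_logit(hidden_state) -> float logit from f_theta
      verify(prefix, candidates)     -> accepted tokens (+ correction token)
    """
    candidates, hiddens = [], []
    p_acc = 1.0  # running product of predicted acceptance probabilities
    for i in range(1, k_max + 1):
        if i > 1:
            # sigma(f_theta(h_{i-1})): predicted acceptance of the last draft token
            p_hat = 1.0 / (1.0 + math.exp(-acceptance_logit(hiddens[-1])))
            p_acc *= p_hat
        if 1.0 - p_acc > h:  # predicted group-rejection risk exceeds threshold
            break
        token, hidden = draft_step(prefix, candidates)
        candidates.append(token)
        hiddens.append(hidden)
    # Batch verification by the target model, per the speculative decoding protocol.
    return verify(prefix, candidates)
```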

5. Throughput Analysis and Empirical Results

The total inference cost, where $N_{\text{draft}}$ and $N_{\text{target}}$ count draft and target forward passes, is:

$$T_{\text{total}} = t_{\text{draft}} \cdot N_{\text{draft}} + t_{\text{target}} \cdot N_{\text{target}}$$

With $N$ output tokens and $N_{\text{discarded}}$ rejected draft tokens, the per-token latency becomes:

$$\text{Latency} = T_{\text{total}}/N = t_{\text{draft}} + t_{\text{draft}} \cdot \frac{N_{\text{discarded}}}{N} + (t_{\text{target}} - t_{\text{draft}}) \cdot \frac{N_{\text{target}}}{N}$$

The resulting speedup over vanilla target-only decoding is:

$$\text{Speedup} = \frac{t_{\text{target}}}{\text{Latency}_{\text{Spec}}}$$
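
A worked example of these formulas under assumed counts and timings (all numbers below are illustrative, not the paper's measurements):

```python
# Illustrative timings and counters; not measured values from the paper.
t_draft, t_target = 0.01, 0.10   # seconds per forward pass
N = 1000                         # output tokens generated
N_discarded = 150                # rejected draft tokens
N_target = 300                   # target verification passes

latency = t_draft + t_draft * N_discarded / N + (t_target - t_draft) * N_target / N
speedup = t_target / latency
print(f"per-token latency: {latency * 1000:.2f} ms, speedup: {speedup:.2f}x")
# per-token latency: 38.50 ms, speedup: 2.60x
```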

The following table summarizes experimental measurements with a llama-2-chat 7B draft model and a 70B target model, using $D=3$, $w_{\text{rej}}=6$, and $h=0.7$, across datasets:

| Dataset   | Vanilla Target (tok/s) | SpecDec, fixed $K$ (tok/s, speedup) | SpecDec++, adaptive (tok/s, speedup) | Rel. Impr. |
|-----------|------------------------|-------------------------------------|--------------------------------------|------------|
| Alpaca    | 9.26                   | 17.58 (1.90×)                       | 18.88 (2.04×)                        | +7.2%      |
| HumanEval | 9.26                   | 18.52 (2.00×)                       | 20.65 (2.23×)                        | +11.1%     |
| GSM8K     | 9.26                   | 19.17 (2.07×)                       | 20.93 (2.26×)                        | +9.4%      |

AGSD achieves strictly lower discard and verification rates, yielding additional speedup without increasing wasted computation. SpecDec++ recovers over 99.3% of the optimal achievable throughput on all tested datasets (Huang et al., 2024).

6. Context and Comparisons

Prior speculative decoding approaches generally fix the candidate group size KK using simplistic heuristics or optimize under IID assumptions. AGSD, by modeling group selection as an MDP and leveraging learned acceptance probabilities, guarantees theoretically optimal cost-minimization policies. This framework allows robust generalization across data distributions (including out-of-distribution), and is compatible with LLMs modified with a small, efficiently trainable prediction head. A plausible implication is that online adaptive speculative decoding may become the new baseline for high-throughput LLM inference in production settings, given the consistent Pareto-improving performance demonstrated in (Huang et al., 2024).

References

  1. Huang, K., Guo, X., & Wang, M. (2024). SpecDec++: Boosting Speculative Decoding via Adaptive Candidate Lengths. arXiv:2405.19715.
