Adaptive Grouped Speculative Decoding (AGSD)
- Adaptive Grouped Speculative Decoding (AGSD) is a method that dynamically selects candidate token groups for LLM inference, using a fast draft model and a slower high-fidelity target model.
- It models group selection as a Markov Decision Process with an optimal threshold policy, balancing longer speculative groups against the risk of wasted computation.
- AGSD incorporates a learned acceptance prediction head and shows significant throughput improvements, recovering over 99% of optimal inference speed.
Adaptive Grouped Speculative Decoding (AGSD) refers to a class of algorithms that increase the throughput of LLM inference by adaptively selecting the length of speculative token candidate groups, in contrast to static-parameter speculative decoding. The framework leverages a fast draft model and a slower, high-fidelity target model; at each decoding step, the draft model generates a batch of candidate tokens, which are then validated by the target model. AGSD, as instantiated in SpecDec++ (Huang et al., 2024), formalizes candidate group selection as a Markov Decision Process (MDP) with an optimal threshold-based stopping rule, supplementing the draft model with a learned acceptance-prediction head to maximize inference speedup while minimizing wasted computation.
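To make the draft-then-verify loop concrete, the following is a minimal sketch of the standard speculative-decoding verification step over toy categorical distributions; `draft_probs`, `target_probs`, and the sampled candidates are illustrative stand-ins for real model outputs, not the paper's implementation.

```python
# Minimal sketch of the speculative-decoding verification step.
# draft_probs[t] / target_probs[t]: toy categorical distributions standing in
# for the draft and target models' next-token distributions at position t.
import numpy as np

rng = np.random.default_rng(0)

def verify_candidates(draft_probs, target_probs, candidates):
    """Accept draft token x with probability min(1, p(x)/q(x)); on the first
    rejection, resample from the normalized residual max(p - q, 0) and stop."""
    accepted = []
    for t, x in enumerate(candidates):
        p, q = target_probs[t][x], draft_probs[t][x]
        if rng.random() < min(1.0, p / q):
            accepted.append(x)           # token passes the acceptance test
        else:
            residual = np.maximum(target_probs[t] - draft_probs[t], 0.0)
            accepted.append(int(rng.choice(len(residual), p=residual / residual.sum())))
            break                        # all later candidates are discarded
    return accepted

V, K = 5, 3                              # toy vocabulary size and group length
draft = rng.dirichlet(np.ones(V), size=K)
target = rng.dirichlet(np.ones(V), size=K)
candidates = [int(rng.choice(V, p=draft[t])) for t in range(K)]
print(verify_candidates(draft, target, candidates))
```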
1. Problem Setting and MDP Formulation
The central problem in speculative decoding is to dynamically determine, at each generation round, whether to extend the candidate group by one more token (action “continue”) or hand off the current group to the target model for verification (action “stop”). This process is modeled as an MDP:
- State: $s_t = (x_{1:n},\, \tilde{x}_{1:t})$, where $x_{1:n}$ is the prompt concatenated with all accepted tokens, and $\tilde{x}_{1:t}$ are candidate tokens sampled auto-regressively from the draft model $q$.
- Actions: $a_t \in \{\textsf{continue}, \textsf{stop}\}$, where $\textsf{continue}$ drafts one more token and $\textsf{stop}$ submits the current candidates for verification.
- Transitions: continuing samples $\tilde{x}_{t+1} \sim q(\cdot \mid x_{1:n}, \tilde{x}_{1:t})$, while stopping hands the candidate batch to the target model $p$ and transitions the state based on how many tokens were accepted.
- Immediate cost: continuing costs one draft forward pass $t_{\text{draft}}$ and stopping costs one target forward pass $t_{\text{target}}$; draft costs are wasted whenever a token in the current batch is eventually rejected, while every verification step incurs the full target cost.
The MDP formulation captures the tension between longer speculative groups (which can amortize target model cost) and the risk of wasted computation if the group is too long and many candidates are rejected (Huang et al., 2024).
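This tension is easy to quantify in a toy model. Assuming (purely for illustration) an i.i.d. per-token acceptance probability $\beta$ and fixed forward-pass times, the expected number of tokens emitted per round is $1 + \beta + \cdots + \beta^K$ (including the bonus token from verification), while the round cost grows linearly in $K$:

```python
# Toy throughput model for a fixed candidate length K, assuming an i.i.d.
# per-token acceptance probability beta; T_DRAFT and T_TARGET are assumed
# forward-pass times in arbitrary units, not measured values.
T_DRAFT, T_TARGET = 1.0, 10.0

def tokens_per_unit_time(K: int, beta: float) -> float:
    expected_tokens = sum(beta**i for i in range(K + 1))  # 1 + beta + ... + beta^K
    round_cost = K * T_DRAFT + T_TARGET                   # K drafts + 1 verification
    return expected_tokens / round_cost

# The best fixed K depends sharply on beta -- the motivation for adapting it.
for beta in (0.6, 0.8, 0.95):
    best_K = max(range(1, 33), key=lambda K: tokens_per_unit_time(K, beta))
    print(beta, best_K, round(tokens_per_unit_time(best_K, beta), 3))
```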
2. Threshold Policy as Optimal Solution
The candidate-group selection policy aims to minimize expected total inference cost. Let $A_t$ denote the event that at least one candidate in the current group will be rejected, and let $h_t = \Pr(A_t \mid s_t)$. The optimal policy, proven in Theorem 1 of (Huang et al., 2024), is:

$$\pi^\ast(s_t) = \begin{cases} \textsf{continue}, & h_t \le h^\ast, \\ \textsf{stop}, & h_t > h^\ast, \end{cases}$$
where the threshold $h^\ast$ is a constant determined by the draft-forward time $t_{\text{draft}}$, the target-forward time $t_{\text{target}}$, and a term that upper-bounds downstream cost differentials. This result establishes that greedy thresholding of the predicted group rejection probability is globally optimal under bounded time horizons. Thus, AGSD differs fundamentally from fixed-$K$ heuristics or statically tuned strategies (Huang et al., 2024).
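Concretely, the quantity being thresholded composes multiplicatively from per-token conditional acceptance estimates; a minimal sketch (with illustrative numbers) follows:

```python
# Group-level rejection probability from per-token conditional acceptance
# estimates: Pr(at least one rejection) = 1 - prod_i beta_hat_i.
import math

def group_rejection_prob(accept_probs) -> float:
    return 1.0 - math.prod(accept_probs)

def should_stop(accept_probs, h_star: float) -> bool:
    """Threshold rule: stop drafting once predicted group risk exceeds h*."""
    return group_rejection_prob(accept_probs) > h_star

# Illustrative values: 1 - 0.95*0.90*0.85 ~= 0.273 > 0.25, so stop.
assert should_stop([0.95, 0.90, 0.85], h_star=0.25)
```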
3. Acceptance Prediction Head: Architecture and Training
To operationalize the threshold policy, AGSD augments the draft model with an acceptance-prediction head $f_\theta$, which estimates per-token conditional acceptance probabilities:
- Architecture: $f_\theta$ is a small multi-layer residual network (ResNet-style MLP) with SiLU activations, mapping the last hidden state of each draft token to a value in $[0, 1]$, i.e., the predicted conditional acceptance probability (see the sketch after this list).
- Training data: samples are constructed from prompts $x$ with target-model responses $y$. The draft model is run in teacher-forcing mode, and for each position $i$ the ground-truth acceptance probability is $\beta_i = \min\!\big(1,\, p(y_i \mid x, y_{<i}) / q(y_i \mid x, y_{<i})\big)$.
- Random mixing: to counteract class imbalance and sparsity, training sequences are formed by sampling each token from the target model (with probability $r$) or the draft model (with probability $1-r$), with the loss computed only at draft-sampled positions.
- Loss: weighted binary cross-entropy with soft labels,

$$\mathcal{L}(\theta) = -\sum_{i} \Big[\, \beta_i \log f_\theta(h_i) + w\,(1-\beta_i)\log\big(1 - f_\theta(h_i)\big) \,\Big],$$

with $w > 1$ to upweight rare rejections.
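As a concrete reference, the following PyTorch sketch pairs a residual MLP head with the weighted soft-label BCE above; the layer count, hidden size, and weight $w$ are illustrative assumptions, not the paper's exact configuration.

```python
# Sketch of an acceptance-prediction head and its training objective.
# Hidden size, depth, and the rejection weight w are assumed values.
import torch
import torch.nn as nn

class AcceptanceHead(nn.Module):
    def __init__(self, hidden_dim: int = 4096, num_layers: int = 3):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.SiLU())
            for _ in range(num_layers)
        )
        self.out = nn.Linear(hidden_dim, 1)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            h = h + block(h)                        # residual connection
        return torch.sigmoid(self.out(h)).squeeze(-1)  # predicted accept prob

def weighted_bce(pred: torch.Tensor, accept_prob: torch.Tensor, w: float = 5.0):
    """Soft-label BCE with rejections (1 - accept_prob) upweighted by w."""
    eps = 1e-6
    pred = pred.clamp(eps, 1 - eps)
    return -(accept_prob * pred.log()
             + w * (1 - accept_prob) * (1 - pred).log()).mean()
```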
Calibration proceeds via held-out KL divergence and throughput tuning on validation data, yielding robust probability estimates across candidate-group lengths and distributions (Huang et al., 2024).
4. AGSD Online Algorithm: Candidate Length Adaptation
The adaptive candidate-length selection is performed online using the acceptance head’s predictions. The core procedure is as follows (a code sketch follows the list):
- Initialize the cumulative acceptance probability $P_0 = 1$.
- For $t = 1, 2, \dots, K_{\max}$:
  - Extract the last hidden state $h_t$ of the current candidate $\tilde{x}_t$ and compute $\hat{\beta}_t = f_\theta(h_t)$.
  - Update $P_t = P_{t-1}\,\hat{\beta}_t$.
  - If $1 - P_t > h^\ast$ (for threshold $h^\ast$), break; else sample the next candidate $\tilde{x}_{t+1} \sim q(\cdot \mid x_{1:n}, \tilde{x}_{1:t})$.
- Let $K$ be the final candidate count and submit $\tilde{x}_{1:K}$ to the target model for batch verification.
- Post-processing involves sampling which tokens are accepted and, on rejection, resampling a correction token, as per the standard speculative decoding protocol (see Algorithm SpecDec++ in (Huang et al., 2024)).
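A compact sketch of this loop, assuming hypothetical `draft_model.sample_next` and `accept_head` interfaces (a real system would batch these calls):

```python
# Sketch of online candidate-length adaptation. `draft_model` and
# `accept_head` are hypothetical interfaces: sample_next returns the next
# token plus its last hidden state; accept_head maps that state to a
# predicted conditional acceptance probability beta_hat in [0, 1].
def adaptive_draft(draft_model, accept_head, prefix, h_star: float, k_max: int = 32):
    candidates, cum_accept = [], 1.0
    for _ in range(k_max):
        token, hidden = draft_model.sample_next(prefix + candidates)
        candidates.append(token)
        cum_accept *= accept_head(hidden)     # P_t = P_{t-1} * beta_hat_t
        if 1.0 - cum_accept > h_star:         # predicted group rejection risk
            break                             # stop drafting; go verify
    return candidates                         # submitted for batch verification
```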
This adaptive procedure generalizes conventional speculative decoding by dynamically tailoring the candidate length $K$ to the instantaneous predicted risk, achieving close to oracle-optimal throughput on both in-distribution and out-of-distribution tasks.
5. Throughput Analysis and Empirical Results
The total inference cost is

$$T_{\text{total}} = N_{\text{draft}}\, t_{\text{draft}} + N_{\text{verify}}\, t_{\text{target}},$$

where $N_{\text{draft}}$ is the total number of draft-model forward passes and $N_{\text{verify}}$ the number of target-model verification calls. Per-token latency becomes $T_{\text{total}} / N$ for $N$ generated tokens. The resulting speedup over vanilla target-only decoding is

$$S = \frac{N\, t_{\text{target}}}{N_{\text{draft}}\, t_{\text{draft}} + N_{\text{verify}}\, t_{\text{target}}}.$$
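As a worked example of these formulas with assumed, not measured, quantities ($t_{\text{target}} = 10\,t_{\text{draft}}$, 1000 output tokens, 1300 draft forwards, 220 verifications):

```python
# Worked example of the throughput formulas; all inputs are illustrative.
def speedup(n_tokens, n_draft, n_verify, t_draft=1.0, t_target=10.0):
    total = n_draft * t_draft + n_verify * t_target   # T_total
    per_token = total / n_tokens                      # per-token latency
    return n_tokens * t_target / total, per_token     # S, latency

s, lat = speedup(n_tokens=1000, n_draft=1300, n_verify=220)
print(f"speedup ~ {s:.2f}x, per-token latency ~ {lat:.2f} draft-units")
# -> speedup ~ 2.86x, per-token latency ~ 3.50 draft-units
```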
The following summarizes experimental measurements on a llama-2-chat 7B draft model and a 70B target model across datasets:
| Dataset | Vanilla Target (tok/s) | SpecDec (fixed $K$) (tok/s, ×) | SpecDec++ (adaptive) (tok/s, ×) | Rel. Impr. |
|---|---|---|---|---|
| Alpaca | 9.26 | 17.58 (1.90×) | 18.88 (2.04×) | +7.2% |
| HumanEval | 9.26 | 18.52 (2.00×) | 20.65 (2.23×) | +11.1% |
| GSM8K | 9.26 | 19.17 (2.07×) | 20.93 (2.26×) | +9.4% |
AGSD achieves strictly lower discard and verification rates, yielding additional speedup while reducing wasted computation. SpecDec++ recovers in excess of 99.3% of the optimal achievable throughput on all tested datasets (Huang et al., 2024).
6. Context and Comparisons
Prior speculative decoding approaches generally fix the candidate group size $K$ using simple heuristics or optimize it under i.i.d. assumptions. AGSD, by modeling group selection as an MDP and leveraging learned acceptance probabilities, yields theoretically optimal cost-minimization policies. The framework generalizes robustly across data distributions (including out-of-distribution inputs) and is compatible with existing LLMs via a small, efficiently trainable prediction head. A plausible implication is that online adaptive speculative decoding may become the baseline for high-throughput LLM inference in production settings, given the consistent Pareto-improving performance demonstrated in (Huang et al., 2024).