Draft-based Approximate Inference for LLMs (2506.08373v1)

Published 10 Jun 2025 in cs.CL and cs.AI

Abstract: Optimizing inference for long-context LLMs is increasingly important due to the quadratic compute and linear memory complexity of Transformers. Existing approximation methods, such as key-value (KV) cache dropping, sparse attention, and prompt compression, typically rely on rough predictions of token or KV pair importance. We propose a novel framework for approximate LLM inference that leverages small draft models to more accurately predict the importance of tokens and KV pairs. Specifically, we introduce two instantiations of our proposed framework: (i) SpecKV, which leverages a draft output to accurately assess the importance of each KV pair for more effective KV cache dropping, and (ii) SpecPC, which uses the draft model's attention activations to identify and discard unimportant prompt tokens. To the best of our knowledge, this is the first work to use draft models for approximate LLM inference acceleration, extending their utility beyond traditional lossless speculative decoding. We motivate our methods with theoretical and empirical analyses, and show a strong correlation between the attention patterns of draft and target models. Extensive experiments on long-context benchmarks show that our methods consistently achieve higher accuracy than existing baselines, while preserving the same improvements in memory usage, latency, and throughput. Our code is available at https://github.com/furiosa-ai/draft-based-approx-LLM.

Summary

  • The paper introduces a novel framework that leverages a small draft model to generate approximate output tokens, enabling more precise estimation of prompt-token and key-value (KV) pair importance.
  • It presents two methods—SpecKV for KV cache dropping and SpecPC for prompt compression—that enhance performance and reduce computational overhead.
  • Experiments on benchmarks like RULER and LongBench demonstrate up to 25-point accuracy gains and reduced latency, confirming the framework’s practical benefits.

This paper introduces a draft-based framework for approximate inference that optimizes long-context LLM inference. The core problem addressed is the quadratic compute and linear memory complexity of Transformers, which becomes prohibitive for very long sequences. Existing approximation methods (e.g., KV cache dropping, sparse attention, prompt compression) often rely on imprecise estimates of token or key-value (KV) pair importance. The proposed framework improves these estimates by using a small "draft" model to generate approximate future tokens: this "lookahead" information provides a better basis for deciding which parts of the input or KV cache matter most for generating subsequent tokens.
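
To make the scale of the problem concrete, here is a rough back-of-the-envelope sketch. The model dimensions below are assumptions approximating a Llama-3.1-8B-class target in fp16; they are not figures from the paper.

# Rough cost model for long-context prefill (illustrative assumptions only;
# dimensions approximate a Llama-3.1-8B-class model, not values from the paper)
n_ctx = 128_000          # prompt length in tokens
n_layers = 32
n_kv_heads = 8           # grouped-query attention
head_dim = 128
bytes_per_elem = 2       # fp16/bf16

# Linear memory: the KV cache grows with context length
kv_bytes = 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem * n_ctx
print(f"KV cache: {kv_bytes / 2**30:.1f} GiB")   # ~15.6 GiB at 128k tokens

# Quadratic compute: attention scores alone need O(n_ctx^2) work per layer/head
score_elems = n_ctx ** 2
print(f"Attention-score entries per head per layer: {score_elems:.2e}")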

The authors present two specific instantiations of this framework:

  1. SpecKV (Speculative KV Dropping): This method aims to improve KV cache dropping and enable sparse prefilling.
    • Importance Estimation: Instead of relying solely on attention patterns from input tokens (like SnapKV), SpecKV uses the draft model's output as a proxy for future target model outputs. The importance of an input KV pair is defined as the average attention activation from the draft model's output queries to the input keys. Theorem 1 provides theoretical justification, showing that if the draft model's output embeddings are close to the target model's, the estimated importance scores will also be close.
    • Algorithm:

    1. A draft output of length $n_{\text{lookahead}}$ is generated with the small draft model.
    2. Both the original input sequence and the draft output are fed through the target model during prefilling.
    3. For each attention head, cross-attention is computed: the queries come from the last $n_{\text{window}}$ input tokens and all $n_{\text{lookahead}}$ draft output tokens, while the keys and values come from the remaining input tokens (excluding the $n_{\text{window}}$ most recent ones).
    4. The attention scores are aggregated (e.g., max reduction) per input key and smoothed using 1D average pooling.
    5. These scores guide two optimizations:
      • Sparse Prefilling: a Vertical-Slash attention pattern is applied, using the top-$k$ most important tokens as global tokens.
      • KV Cache Dropping: after prefilling, the top $C_{\text{cache}} - n_{\text{window}}$ KV pairs (by importance score) are retained, along with the KV pairs from the final $n_{\text{window}}$ input tokens.
    • Implementation Detail: Figure 3 visually contrasts SpecKV with SnapKV, highlighting SpecKV's use of draft tokens for richer context in importance estimation. Figure 4 shows SpecKV identifying important tokens more effectively than SnapKV in a needle retrieval task.
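
One way to write the importance estimate described above (a paraphrase of the description, not the paper's exact notation): let $Q$ be the set of queries from the last $n_{\text{window}}$ input tokens and the $n_{\text{lookahead}}$ draft tokens, and let $A_{ij}$ be the attention weight from query $i$ to input key $j$. Then

$$ s_j = \max_{i \in Q} A_{ij} \quad \text{(or } \tfrac{1}{|Q|} \textstyle\sum_{i \in Q} A_{ij} \text{ for mean aggregation)}, $$

followed by 1D average pooling over $j$; the conceptual pseudocode below follows the same steps.
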

# Pseudocode for SpecKV Importance Score Calculation (conceptual sketch; shown
# for a single layer and head, with a placeholder for access to the target
# model's query/key projections; batch dimension omitted for clarity)
import torch
import torch.nn.functional as F

def get_speckv_importance_scores(target_model, draft_model, input_ids,
                                 n_lookahead, n_window, pool_kernel_size=7):
    # 1. Generate a draft output with the small draft model
    input_len = input_ids.shape[1]
    draft_ids = draft_model.generate(input_ids, max_new_tokens=n_lookahead)

    # 2. Feed the input sequence plus the draft output through the target model
    #    during prefilling and collect query/key projections (placeholder call;
    #    q and k are (seq_len, head_dim) tensors for one head)
    q, k = target_model_queries_and_keys(target_model, draft_ids)

    # Keys: all input tokens except the n_window most recent ones
    keys = k[: input_len - n_window]
    # Queries: the last n_window input tokens plus all n_lookahead draft tokens
    queries = q[input_len - n_window :]

    # 3. Cross-attention from the queries to the remaining input keys
    attn = torch.softmax(queries @ keys.T / keys.shape[-1] ** 0.5, dim=-1)

    # 4. Aggregate per key (max over queries), then smooth with 1D average
    #    pooling (odd kernel size assumed)
    scores = attn.amax(dim=0)
    scores = F.avg_pool1d(scores[None, None, :], kernel_size=pool_kernel_size,
                          stride=1, padding=pool_kernel_size // 2).squeeze()

    # 5. The scores guide sparse prefilling and KV cache dropping
    return scores
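
The cache-dropping step in point 5 above can then be expressed as a small helper. This is a hedged sketch, not the paper's implementation: the per-layer, per-head cache tensor layout and the `C_cache` budget name are assumptions.

import torch

def drop_kv_cache(kv, scores, C_cache, n_window):
    # kv: one (seq_len, head_dim) cache tensor for a single layer/head
    # scores: importance scores for the seq_len - n_window older positions
    seq_len = kv.shape[0]
    top = torch.topk(scores, C_cache - n_window).indices    # most important older KV pairs
    window = torch.arange(seq_len - n_window, seq_len)      # always keep the recent window
    keep = torch.cat([top, window]).unique(sorted=True)
    return kv[keep]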

  2. SpecPC (Speculative Prompt Compression): This method uses the draft model's internal attention activations directly to compress the input prompt.

    • Motivation: The paper posits that if draft and target models produce similar outputs (a core assumption of speculative decoding), their attention patterns should also be similar. Theorem 2 (under Restricted Isometry Property conditions) supports this, showing that the error in approximate attention activations is proportional to the error in approximate outputs. Figure 5 empirically demonstrates high correlation between draft (Llama-3.2-1B) and target (Llama-3.1-8B) model attention activations.
    • Algorithm:

    1. The input prompt is fed to the draft model.
    2. The draft model's attention activations $A \in \mathbb{R}^{n_{\text{layer}} \times n_{\text{head}} \times (n_{\text{input}}+n_{\text{draft\_output}}-1) \times n_{\text{input}}}$ are extracted. Typically, only activations from later layers (skipping the first $l_{\text{skip}}$ layers) are used, focusing on queries from the final tokens (e.g., the last input token if $n_{\text{lookahead}}=1$).
    3. Queries are reweighted to give more importance to those closer to the end of the prompt.
    4. Attention scores are aggregated across layers, heads, and queries to get a single importance score per input token (e.g., using max reduction, which empirically works better for retrieval); see the aggregation written out below.
    5. The aggregated scores are smoothed with 1D average pooling, followed by 1D max pooling so that selected tokens also bring along their local neighbors, preserving context.
    6. The top-$C_{\text{prompt}}$ tokens with the highest scores are selected, always including a fixed window of recent tokens (e.g., the last $n_{\text{window}}$ tokens), to form the compressed prompt, which is then fed to the target model.
    • Advantages: Unlike some methods requiring sentence-level processing, SpecPC is modality-agnostic and works with various input types. It reduces computation for both attention and MLP layers during prefill and decoding.
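
One compact way to write the aggregation in steps 3 and 4 (a paraphrase of the description above, not the paper's exact notation; $w_i$ denotes the query reweighting):

$$ s_j = \max_{l > l_{\text{skip}}} \; \max_{h} \; \max_{i} \; w_i \, A_{l,h,i,j}, $$

followed by the average- and max-pooling of step 5 and the top-$C_{\text{prompt}}$ selection of step 6.
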

# Pseudocode for SpecPC Token Selection (conceptual sketch; assumes a Hugging
# Face-style draft model that can return attention maps, and n_lookahead = 1 so
# only the last query position is used)
import torch
import torch.nn.functional as F

def get_specpc_compressed_indices(draft_model, input_ids, C_prompt, n_window,
                                  l_skip, kernel_size, n_neighbor):
    # 1. Get the draft model's attention activations: (layer, head, query, key)
    out = draft_model(input_ids, output_attentions=True)
    attn = torch.stack(out.attentions)[:, 0]      # drop the batch dimension

    num_keys = attn.shape[-1]
    m = num_keys - n_window                       # number of non-window keys

    # 2. Skip early layers and keep the last query's attention to non-window keys
    relevant = attn[l_skip:, :, -1, :m]

    # 3. (Query reweighting would be applied here if several queries were used.)

    # 4. Aggregate with a max reduction over layers and heads -> one score per key
    scores = relevant.amax(dim=(0, 1))

    # 5. Smooth with 1D average pooling, then 1D max pooling so selected tokens
    #    bring their local neighbors along (odd kernel sizes assumed)
    scores = F.avg_pool1d(scores[None, None, :], kernel_size, stride=1,
                          padding=kernel_size // 2).squeeze()
    scores = F.max_pool1d(scores[None, None, :], n_neighbor, stride=1,
                          padding=n_neighbor // 2).squeeze()

    # 6. Keep the top-(C_prompt - n_window) scored tokens plus the final
    #    n_window tokens of the prompt
    top = torch.topk(scores, C_prompt - n_window).indices.tolist()
    window = list(range(m, num_keys))
    return sorted(set(top + window))
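
For completeness, a minimal usage sketch of how the selected indices would build the compressed prompt for the target model; a Hugging Face-style `generate` API is assumed, and the parameter values are illustrative rather than the paper's settings.

# Illustrative usage (input_ids comes from a tokenizer; values are arbitrary)
selected = get_specpc_compressed_indices(draft_model, input_ids, C_prompt=2048,
                                         n_window=32, l_skip=2,
                                         kernel_size=7, n_neighbor=7)
compressed_ids = input_ids[:, selected]            # keep only the selected tokens
output = target_model.generate(compressed_ids, max_new_tokens=256)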

Experiments and Results:

The methods were evaluated on the RULER and LongBench benchmarks using Qwen2.5 (0.5B draft, 14B target) and Llama (Llama-3.2-1B draft, Llama-3.1-8B target) model pairs.

  • Accuracy: Both SpecKV and SpecPC consistently outperformed baselines (StreamingLLM, H2O, SnapKV, PyramidKV, AdaKV for SpecKV; LLMLingua-2, CPC, R2C for SpecPC). SpecKV showed up to 25 points improvement on RULER. SpecPC nearly matched the full target model's performance. The methods performed particularly well on few-shot learning and code completion tasks. Combining SpecKV with AdaKV (Ada-SpecKV) often yielded the best KV dropping results.

  • Latency (TTFT, time to first token; Qwen2.5-14B target, Qwen2.5-0.5B draft, H100 GPU):

    • SpecKV (with sparse prefill) was faster than SnapKV. The draft model overhead was minimal.
    • SpecPC outperformed other prompt compression baselines (CPC, R2C), which can have higher overhead from text preprocessing for long sequences.
    • Prompt compression methods were generally faster than KV dropping methods for TTFT.
  • Memory:
    • SpecKV's memory usage was comparable to SnapKV, with a small, constant overhead for draft model weights (which can be offloaded to CPU if needed).
    • SpecPC was more memory-efficient than R2C for the auxiliary model stage.
  • Ablations:
    • Better/larger draft models improved SpecKV performance.
    • Increasing the draft output length ($n_{\text{lookahead}}$) generally boosted SpecKV accuracy but had marginal impact on SpecPC. For the experiments, $n_{\text{lookahead}}$ was set to the maximum number of generated tokens for SpecKV and to 1 for SpecPC.
    • The paper also includes results for multimodal models (Qwen2.5-VL on MileBench) and larger target LLMs (e.g., Llama-3.1-70B), showing consistent benefits.

Discussion and Limitations:

The paper demonstrates that leveraging draft models for approximate inference significantly improves accuracy over existing methods while maintaining efficiency gains.

  • Limitations:
    • For SpecKV, a large $n_{\text{lookahead}}$ (i.e., a long draft output) could increase latency.
    • For SpecPC, further work is needed to better leverage longer draft outputs.
  • Future Work: Extending the framework to lookahead-based sparse decoding or iterative KV cache dropping.

Code: The implementation is available at https://github.com/furiosa-ai/draft-based-approx-LLM.

In summary, the paper presents a well-motivated framework that cleverly extends the utility of draft models beyond lossless speculative decoding to the field of approximate inference for long-context LLMs. By using draft model outputs or attention patterns, SpecKV and SpecPC achieve more accurate importance estimation, leading to superior performance in KV cache dropping and prompt compression respectively, with practical benefits in latency and memory usage.