TokenSkip: Controllable Chain-of-Thought Compression in LLMs
(2502.12067v2)
Published 17 Feb 2025 in cs.CL and cs.AI
Abstract: Chain-of-Thought (CoT) has been proven effective in enhancing the reasoning capabilities of LLMs. Recent advancements, such as OpenAI's o1 and DeepSeek-R1, suggest that scaling up the length of CoT sequences during inference could further boost LLM reasoning performance. However, due to the autoregressive nature of LLM decoding, longer CoT outputs lead to a linear increase in inference latency, adversely affecting user experience, particularly when the CoT exceeds 10,000 tokens. To address this limitation, we analyze the semantic importance of tokens within CoT outputs and reveal that their contributions to reasoning vary. Building on this insight, we propose TokenSkip, a simple yet effective approach that enables LLMs to selectively skip less important tokens, allowing for controllable CoT compression. Extensive experiments across various models and tasks demonstrate the effectiveness of TokenSkip in reducing CoT token usage while preserving strong reasoning performance. Notably, when applied to Qwen2.5-14B-Instruct, TokenSkip reduces reasoning tokens by 40% (from 313 to 181) on GSM8K, with less than a 0.4% performance drop.
Summary
The paper introduces TokenSkip, a method to compress Large Language Model Chain-of-Thought outputs by selectively skipping tokens to reduce computational overhead and latency.
TokenSkip scores token importance with a bidirectional model (LLMLingua-2), prunes CoT sequences to a target compression ratio, and then fine-tunes the LLM on the compressed data.
Evaluations demonstrate that TokenSkip can reduce CoT token usage by 30-40% with minimal performance drop (under 4%) and achieve up to a 1.4x inference speedup on mathematical reasoning tasks.
The paper "TokenSkip: Controllable Chain-of-Thought Compression in LLMs" introduces TokenSkip, a method for compressing Chain-of-Thought (CoT) outputs from LLMs by selectively skipping tokens deemed less important. This approach addresses the computational overhead associated with long CoT sequences, which can lead to increased inference latency and memory usage due to the autoregressive nature of LLM decoding.
The authors begin by analyzing the semantic importance of tokens within CoT outputs, revealing that their contributions to reasoning performance vary. They draw upon existing research in prompt compression, particularly the concept of token redundancy. The paper references Selective Context, which measures token importance based on the semantic confidence of LLMs using the equation:
$$I_1(x_i) = -\log P(x_i \mid x_{<i}; \theta_{M_L})$$
Where:
$I_1(x_i)$ is the token importance for token $x_i$
$x_i$ represents a token within the text $x = \{x_i\}_{i=1}^{n}$
$x_{<i}$ denotes the tokens preceding $x_i$
$\theta_{M_L}$ represents the parameters of the LLM $M_L$ used to compute the token's confidence.
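The Selective Context score is simply the self-information of each token. The sketch below is a minimal illustration, assuming the conditional probabilities P(x_i | x_<i) have already been read off a causal LLM; the probability values and the function name selective_context_importance are hypothetical.

```python
import math


def selective_context_importance(token_probs: list[float]) -> list[float]:
    """Self-information I_1(x_i) = -log P(x_i | x_<i) for every token.

    token_probs[i] is assumed to be the probability a causal LM assigned to
    token i given only the tokens before it (hypothetical inputs here).
    """
    return [-math.log(p) for p in token_probs]


# Hypothetical per-token probabilities for the tokens of "so 3 + 4 = 7".
tokens = ["so", "3", "+", "4", "=", "7"]
probs = [0.9, 0.2, 0.6, 0.25, 0.95, 0.5]
for tok, score in zip(tokens, selective_context_importance(probs)):
    print(f"{tok!r}: {score:.3f}")
```

Tokens the model found hard to predict (here "3" and "4") receive the highest scores, which is why this measure depends on left-context position, the weakness LLMLingua-2 points out.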
The authors also discuss the limitations of this approach, as highlighted by LLMLingua-2, which argues that LLM perplexity can lead to lower importance scores for tokens at the end of sentences due to position dependency. Furthermore, the unidirectional attention mechanism in causal LLMs may fail to capture all essential information for token importance. To address these limitations, LLMLingua-2 utilizes a bidirectional BERT-like LLM for token importance measurement, trained with a token classification objective. Token importance is then measured by the predicted probability of each token using the equation:
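$$I_2(x_i) = P(x_i \mid x_{\le n}; \theta_{M_B})$$
Where:
$I_2(x_i)$ is the token importance for token $x_i$
$x$ is the given text, and $x_{\le n}$ denotes all $n$ of its tokens (the full bidirectional context)
$\theta_{M_B}$ represents the parameters of the bidirectional LLM $M_B$.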
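A token-classification importance score of this kind can be sketched as follows, assuming a binary keep/drop labeling in which each token gets a (drop, keep) logit pair from a classifier that has read the whole text. The logits and the function name preserve_probabilities are hypothetical stand-ins, not LLMLingua-2's actual API.

```python
import math


def preserve_probabilities(logits: list[tuple[float, float]]) -> list[float]:
    """I_2(x_i): probability of the 'keep' class for each token.

    Each entry is a (drop_logit, keep_logit) pair that a bidirectional token
    classifier is assumed to produce after reading the WHOLE text x_{<=n}.
    """
    probs = []
    for drop_logit, keep_logit in logits:
        # Numerically stable two-class softmax.
        m = max(drop_logit, keep_logit)
        e_drop = math.exp(drop_logit - m)
        e_keep = math.exp(keep_logit - m)
        probs.append(e_keep / (e_drop + e_keep))
    return probs


# Hypothetical classifier logits for the tokens of "so , x = 7".
tokens = ["so", ",", "x", "=", "7"]
logits = [(1.5, -0.5), (2.0, -1.0), (-1.0, 2.0), (0.0, 1.0), (-2.0, 3.0)]
for tok, p in zip(tokens, preserve_probabilities(logits)):
    print(f"{tok!r}: {p:.3f}")
```

Unlike the Selective Context score, each probability here can depend on tokens to the right as well as to the left.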
The paper applies LLMLingua-2 as the token importance measurement to LLM CoT outputs and observes that mathematical equations tend to have a greater contribution to the final answer, while semantic connectors contribute less.
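TokenSkip involves three main steps: token pruning, training, and inference. During token pruning, given a target LLM $M$, a CoT trajectory $c = \{c_i\}_{i=1}^{m}$, and a desired compression ratio $\gamma \in [0, 1]$, TokenSkip computes the semantic importance $I(c_i)$ of each CoT token using the LLMLingua-2 measure $I_2$ (Eq. 2 in the paper). The tokens are ranked in descending order of importance, and the threshold $I_\gamma$ is the importance value at the $\gamma$-th percentile of that ranking, so that roughly a fraction $\gamma$ of the tokens meets it ($\gamma = 1$ leaves the CoT uncompressed). CoT tokens whose importance is greater than or equal to $I_\gamma$ are retained, in their original order, to form the compressed trajectory:
$$\tilde{c} = \{c_i \mid I(c_i) \ge I_\gamma\}, \quad 1 \le i \le m$$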
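The pruning rule can be sketched as below, assuming importance scores are already available for each CoT token. Choosing k = round(γ·m) tokens to define the threshold is our own rounding convention; the function name prune_cot and the scores are hypothetical.

```python
def prune_cot(tokens: list[str], importance: list[float], gamma: float) -> list[str]:
    """Return the compressed CoT: tokens with I(c_i) >= I_gamma, in original order.

    I_gamma is the importance of the k-th most important token, with
    k = round(gamma * m) (assumed convention), so roughly a fraction gamma of
    the m tokens is kept. Ties at the threshold can keep a few extra tokens.
    """
    if not 0.0 <= gamma <= 1.0:
        raise ValueError("gamma must lie in [0, 1]")
    if len(tokens) != len(importance):
        raise ValueError("need one importance score per token")
    k = round(gamma * len(tokens))  # number of tokens to retain
    if k == 0:
        return []
    threshold = sorted(importance, reverse=True)[k - 1]  # I_gamma
    return [tok for tok, score in zip(tokens, importance) if score >= threshold]


tokens = ["So", ",", "3", "+", "4", "=", "7", "apples"]
importance = [0.2, 0.05, 0.9, 0.7, 0.85, 0.6, 0.95, 0.4]  # hypothetical I(c_i)
print(prune_cot(tokens, importance, 0.5))  # → ['3', '+', '4', '7']
```

Filtering against the threshold instead of sorting the tokens themselves keeps the surviving tokens in their original order, which the compressed CoT needs in order to remain readable as a sequence.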
For training, a dataset D with N samples is used to obtain N CoT trajectories with the target LLM M. Trajectories with incorrect answers are filtered out. The remaining CoT trajectories are pruned with a randomly selected compression ratio γ. The training samples are formatted as Q [EOS] γ [EOS] Compressed CoT A, where Q and A represent the question and answer pair. The target LLM M is fine-tuned by minimizing the loss function:
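$$\mathcal{L} = -\sum_{i=1}^{l} \log P(y_i \mid x, \gamma, y_{<i}; \theta_M)$$
Where:
$\mathcal{L}$ is the loss function to be minimized (the negative log-likelihood of the output)
$y_i$ is the $i$-th token in the output sequence $y = \{y_i\}_{i=1}^{l}$ (the compressed CoT followed by the answer), which has length $l$
$x$ is the input question
$\gamma$ is the compression ratio
$y_{<i}$ are the tokens preceding $y_i$ in the output sequence
$\theta_M$ represents the parameters of the target LLM $M$.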
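The data-construction and loss steps can be sketched as follows. This is a minimal illustration under stated assumptions: the literal string "[EOS]" stands in for the tokenizer's real end-of-sequence token, the pool GAMMAS of candidate ratios is assumed (the text only says γ is chosen at random), per-token probabilities are taken as given, and all function names are hypothetical.

```python
import math
import random

EOS = "[EOS]"  # placeholder; a real setup would use the tokenizer's EOS token
GAMMAS = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]  # assumed pool of compression ratios


def prune(tokens, importance, gamma):
    """Keep roughly the gamma fraction of most important tokens, in order."""
    k = round(gamma * len(tokens))
    if k == 0:
        return []
    threshold = sorted(importance, reverse=True)[k - 1]
    return [t for t, s in zip(tokens, importance) if s >= threshold]


def build_samples(trajectories, rng):
    """Turn (question, cot_tokens, importance, predicted, gold) tuples into
    (prompt, target) pairs laid out as Q [EOS] gamma [EOS] Compressed-CoT A."""
    samples = []
    for question, cot, importance, predicted, gold in trajectories:
        if predicted != gold:  # filter out trajectories with wrong answers
            continue
        gamma = rng.choice(GAMMAS)  # random ratio per trajectory
        compressed = " ".join(prune(cot, importance, gamma))
        prompt = f"{question} {EOS} {gamma} {EOS} "
        samples.append((prompt, f"{compressed} {gold}"))
    return samples


def nll_loss(target_token_probs):
    """L = -sum_i log P(y_i | x, gamma, y_<i; theta_M).

    Only the output (compressed CoT + answer) tokens are scored; the prompt
    tokens Q [EOS] gamma [EOS] act purely as conditioning.
    """
    return -sum(math.log(p) for p in target_token_probs)


data = [
    ("What is 3 + 4?", ["So", "3", "+", "4", "=", "7"],
     [0.1, 0.9, 0.7, 0.85, 0.6, 0.95], "7", "7"),
    ("What is 2 * 5?", ["2", "*", "5", "=", "11"],
     [0.9, 0.8, 0.9, 0.5, 0.9], "11", "10"),  # wrong answer: filtered out
]
for prompt, target in build_samples(data, random.Random(0)):
    print(repr(prompt + target))
```

Conditioning on γ in the prompt is what later lets the user choose the compression ratio at inference time.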
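The inference process in TokenSkip follows standard autoregressive decoding. Given a question $x$ and a compression ratio $\gamma$, the input prompt is formatted as Q [EOS] γ [EOS], and the LLM $M$ predicts the output sequence $\hat{y}$ (the compressed CoT followed by the answer):
$$\hat{y} = \arg\max_{y^*} \sum_{j=1}^{l'} \log P(y^*_j \mid x, \gamma, y^*_{<j}; \theta_M)$$
where $l'$ is the length of the output sequence.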
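Inference can be sketched as greedy decoding from the Q [EOS] γ [EOS] prompt. Greedy search is only a common approximation of the arg max over whole sequences. The function names and toy_model (a scripted stand-in for the fine-tuned LLM that ignores its inputs) are hypothetical.

```python
import math
from typing import Callable, Sequence

EOS = "[EOS]"  # placeholder end-of-sequence marker


def build_prompt(question: str, gamma: float) -> str:
    """Inference prompt Q [EOS] gamma [EOS]; the model continues after it."""
    return f"{question} {EOS} {gamma} {EOS}"


def greedy_decode(
    next_token_logprobs: Callable[[str, Sequence[str]], dict[str, float]],
    prompt: str,
    max_new_tokens: int = 32,
) -> list[str]:
    """Append the highest log-probability token at each step until EOS.

    This approximates the sequence-level arg max; it does not guarantee the
    globally most likely output.
    """
    output: list[str] = []
    for _ in range(max_new_tokens):
        scores = next_token_logprobs(prompt, output)
        token = max(scores, key=scores.get)
        if token == EOS:
            break
        output.append(token)
    return output


def toy_model(prompt: str, output: Sequence[str]) -> dict[str, float]:
    # Hypothetical stand-in: ignores the prompt (including gamma) and emits a
    # fixed short CoT, then the answer, then EOS.
    script = ["3", "+", "4", "=", "7", "Answer:", "7", EOS]
    return {script[len(output)]: math.log(0.9), "filler": math.log(0.1)}


print(greedy_decode(toy_model, build_prompt("What is 3 + 4?", 0.6)))
# → ['3', '+', '4', '=', '7', 'Answer:', '7']
```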
The method was evaluated using LLaMA-3.1-8B-Instruct and the Qwen2.5-Instruct series on the GSM8K and MATH datasets. For training, the respective training sets from both datasets were used. For the MATH dataset, the method was assessed on MATH-500. LLMLingua-2 was used as the token importance metric, and LoRA was adopted for training. The rank r was set to 8, and the scaling parameter α was set to 16.
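To show what r = 8 and α = 16 control, here is a minimal pure-Python sketch of a LoRA-adapted linear layer, following the common LoRA formulation y = Wx + (α/r)·B(Ax) with B initialized to zero. The class and helper names are hypothetical, and the source does not say which weight matrices received adapters.

```python
import random

R, ALPHA = 8, 16   # rank r and scaling alpha as reported for TokenSkip
SCALE = ALPHA / R  # the low-rank update is scaled by alpha / r = 2.0


def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


class LoRALinear:
    """y = W x + (alpha / r) * B (A x); W stays frozen, only A and B train."""

    def __init__(self, weight, r=R, alpha=ALPHA, seed=0):
        rng = random.Random(seed)
        d_out, d_in = len(weight), len(weight[0])
        self.weight = weight  # frozen base weight, d_out x d_in
        # A: r x d_in, small random init; B: d_out x r, zero init.
        self.a = [[rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(r)]
        self.b = [[0.0] * r for _ in range(d_out)]
        self.scale = alpha / r

    def __call__(self, x):
        base = matvec(self.weight, x)
        update = matvec(self.b, matvec(self.a, x))
        return [y + self.scale * u for y, u in zip(base, update)]

    def trainable_params(self):
        return len(self.a) * len(self.a[0]) + len(self.b) * len(self.b[0])


d = 64
w = [[float(i == j) for j in range(d)] for i in range(d)]  # toy frozen weight
layer = LoRALinear(w)
x = [float(i) for i in range(d)]
print(layer(x) == matvec(w, x))  # True: B = 0, so the adapter starts as a no-op
print(layer.trainable_params(), "trainable vs", d * d, "frozen")
```

With B = 0 at initialization, fine-tuning starts from the unmodified model, and only r·(d_in + d_out) parameters per adapted layer are updated.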
The baselines used for comparison were prompt-based reduction, where the LLM is instructed to reduce a fixed proportion of output tokens, and truncation, where the maximum number of output tokens is restricted. Evaluation metrics included accuracy, the number of CoT tokens, and inference latency per sample.
The results showed that Qwen2.5-14B-Instruct exhibited almost no performance drop (less than 0.4%) with a 40% reduction in token usage on GSM8K. On the MATH-500 dataset, LLaMA-3.1-8B-Instruct effectively reduced CoT token usage by 30% with a performance decline of less than 4%, resulting in a 1.4x inference speedup.
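As a quick arithmetic check on these figures: the GSM8K token counts from the abstract give the reduction below, and if latency scales roughly linearly with output length (the premise stated in the abstract), ignoring prompt processing and other fixed costs, a 30% token reduction predicts a speedup close to the reported 1.4x.

```latex
\text{GSM8K (Qwen2.5-14B-Instruct):}\quad 1 - \frac{181}{313} \approx 1 - 0.578 = 0.422 \quad (\approx 42\%\ \text{fewer CoT tokens})

\text{MATH-500 (LLaMA-3.1-8B-Instruct):}\quad \text{speedup} \approx \frac{1}{1 - 0.30} \approx 1.43
```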
Further analysis explored the impact of different token importance metrics, including GPT-4o, and the effect of varying the maximum length constraints during inference. The paper concludes by presenting case studies that illustrate how TokenSkip enables LLMs to learn shortcuts between critical reasoning tokens.
The limitations section acknowledges that experiments with larger LLMs, such as Qwen2.5-32B-Instruct and Qwen2.5-72B-Instruct, were not conducted due to computational constraints. The authors also note that the token importance measurement used in their paper was not specifically trained on mathematical data, which may limit compression effectiveness.