- The paper introduces a novel method that mixes latent tokens with text tokens to improve language model reasoning.
- It uses a VQ-VAE to compress chain-of-thought tokens, reducing sequence length and computational cost while enhancing accuracy.
- Experiments on synthetic and mathematical benchmarks show significant gains in both prediction accuracy and token efficiency over standard approaches.
The paper "Token Assorted: Mixing Latent and Text Tokens for Improved LLM Reasoning" introduces a novel approach to enhance reasoning in LLMs by integrating discrete latent tokens, derived from a VQ-VAE, with textual representations of reasoning processes. The core idea revolves around abstracting the initial steps of CoT reasoning traces using these latent tokens, thereby compressing lengthy input sequences and reducing computational costs. The approach is evaluated in two primary scenarios: training models from scratch on tasks such as Keys-Finding Maze and fine-tuning pre-trained LLMs (specifically, Llama-3.1 and Llama-3.2 variants) on logical and mathematical reasoning benchmarks.
The methodology involves a two-stage training procedure. First, a VQ-VAE is employed to map CoT tokens into a compressed sequence of discrete latent tokens, achieving a length reduction governed by a pre-set compression rate $r = t_c / t_z$, where:
- $r$: compression rate
- $t_c$: length of the CoT token sequence
- $t_z$: length of the discrete latent token sequence
The VQ-VAE architecture encompasses several key components: a codebook $E$ containing $K$ vectors in $\mathbb{R}^d$; an encoder $f_e: \mathbb{R}^L \mapsto \mathbb{R}^{d \times L/r}$ that maps a sequence of $L$ text tokens to $L/r$ latent embedding vectors $\bar{X} = \{\bar{x}_1, \ldots, \bar{x}_{L/r}\}$; a quantization operator $q: \mathbb{R}^d \mapsto E$ that replaces an encoded embedding $\bar{x}$ with its nearest neighbor in $E$, defined as $q(\bar{x}) = \arg\min_{e_i \in E} \|e_i - \bar{x}\|_2^2$; an embedding function $g: \mathbb{R}^K \mapsto \mathbb{R}^d$ that maps the $K$ text tokens of the prompt to a $d$-dimensional embedding vector; and a decoder $f_d: \mathbb{R}^{d \times L/r} \times \mathbb{R}^K \mapsto \mathbb{R}^L$ that decodes latent embeddings back into text tokens, conditioned on the prompt embedding.
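To make the quantization step concrete, below is a minimal PyTorch sketch of the nearest-neighbor operator $q$; the function name, tensor shapes, and use of `torch.cdist` are illustrative choices, not details taken from the paper.

```python
import torch

def quantize(x_bar: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    """Map each latent embedding to its nearest codebook vector.

    x_bar:    (num_latents, d) encoder outputs, i.e. the rows of X-bar
    codebook: (K, d)           codebook E with K vectors in R^d
    returns:  (num_latents, d) quantized embeddings q(x-bar)
    """
    # Squared L2 distance between every latent embedding and every codebook entry.
    dists = torch.cdist(x_bar, codebook, p=2) ** 2   # shape (num_latents, K)
    nearest = dists.argmin(dim=-1)                   # index of argmin ||e_i - x||^2
    return codebook[nearest]
```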
The VQ-VAE is trained using a composite loss function $\mathcal{L}(X)$, which includes a reconstruction loss $\log p\big(X \mid f_d(q(\bar{X}) \mid g(P))\big)$, a VQ loss $\sum_{i=1}^{L/r} \big\| [\bar{x}_i] - q(\bar{x}_i) \big\|_2^2$, and a commitment loss $\beta \big\| \bar{x}_i - [q(\bar{x}_i)] \big\|_2^2$ applied to each latent embedding $\bar{x}_i$, where:
- $\mathcal{L}(X)$: loss function
- $X$: input sequence
- $P$: prompt
- $\bar{X}$: latent embedding vectors
- $q(\bar{X})$: quantization operator applied to the latent embedding vectors
- $f_d$: decoder function
- $g(P)$: embedding of the prompt
- $[\cdot]$: stop-gradient operator
- $\beta$: hyperparameter controlling the strength of the commitment loss.
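The three terms combine in the usual VQ-VAE fashion, with the stop-gradient implemented via `detach`. Below is a minimal sketch, assuming the reconstruction term is realized as a token-level cross-entropy over the decoder logits and that $\beta$ takes a typical value such as 0.25; both are assumptions, not values reported in the paper.

```python
import torch
import torch.nn.functional as F

def vqvae_loss(decoder_logits, text_tokens, x_bar, x_q, beta: float = 0.25):
    """Composite VQ-VAE training loss: reconstruction + VQ + commitment.

    decoder_logits: (L, vocab) logits of f_d(q(X-bar) | g(P))
    text_tokens:    (L,)       original text tokens X
    x_bar:          (L/r, d)   encoder embeddings before quantization
    x_q:            (L/r, d)   quantized embeddings q(X-bar)
    """
    # Reconstruction loss: -log p(X | f_d(q(X-bar) | g(P)))
    recon = F.cross_entropy(decoder_logits, text_tokens)
    # VQ loss: ||[x_bar] - q(x_bar)||^2, gradient flows into the codebook only.
    vq = F.mse_loss(x_q, x_bar.detach())
    # Commitment loss: beta * ||x_bar - [q(x_bar)]||^2, gradient flows into the encoder only.
    commit = F.mse_loss(x_bar, x_q.detach())
    return recon + vq + beta * commit
```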
In the second stage, the pre-trained VQ-VAE is used to construct modified samples where the initial m tokens of the CoT sequence are replaced by their latent abstractions. A randomized replacement strategy is employed, where the value of m is randomly varied during training to facilitate adaptation to the new latent tokens.
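A minimal sketch of this sample-construction step is shown below, assuming the latent tokens are represented as extra vocabulary ids and that m is drawn uniformly in multiples of the compression rate $r$; the helper name and the uniform sampling are illustrative assumptions, not the paper's implementation.

```python
import random

def mix_latent_and_text(cot_tokens: list[int], latent_tokens: list[int], r: int) -> list[int]:
    """Replace the first m CoT text tokens with their latent abstractions.

    cot_tokens:    text-token ids of the chain-of-thought
    latent_tokens: latent-token ids from the VQ-VAE encoder
                   (one latent token per r text tokens)
    r:             compression rate, r = t_c / t_z
    """
    # Sample m per training example so the model sees many mixing ratios;
    # m is a multiple of r so it aligns with whole latent tokens.
    max_chunks = len(cot_tokens) // r
    num_chunks = random.randint(0, max_chunks)
    m = num_chunks * r
    # The first m text tokens become m / r latent tokens; the rest stay as text.
    return latent_tokens[:num_chunks] + cot_tokens[m:]
```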
The approach was evaluated on a range of benchmarks. For synthetic datasets, including Keys-Finding Maze, ProntoQA, and ProsQA, T5 or GPT-2 models were trained from scratch. On Keys-Finding Maze, the approach achieved a 1-Feasible-10 score of 62.8%, significantly outperforming the CoT baseline (43%). For ProntoQA, the approach reached 100% accuracy, compared to 98.8% for the CoT baseline. On the more challenging ProsQA, the approach achieved 96.2% accuracy, compared to 77.5% for the CoT baseline. The method also exhibited improved token efficiency, generating shorter reasoning traces than the CoT baseline while maintaining or improving performance.
For mathematical reasoning, Llama-3.1 and Llama-3.2 models of varying sizes (1B, 3B, and 8B parameters) were fine-tuned on the MetaMathQA dataset and evaluated on in-domain (MATH, GSM8K) and out-of-domain (Fresh-Gaokao-Math-2023, DeepMind-Math, College-Math, OlympiadBench-Math, TheoremQA) benchmarks. The approach consistently outperformed baseline methods, including Sol-Only, CoT, iCoT, and Pause Token, across nearly all tasks and model sizes. For instance, with the Llama-3.1-8B model, the approach achieved an average accuracy of 37.9% across all datasets, compared to 33.4% for the CoT baseline. Gains were especially large on specific datasets such as Fresh-Gaokao-Math-2023, where the approach reached 30.0% accuracy versus 16.7% for the CoT baseline. The approach also demonstrated improved token efficiency, generating on average 17% fewer tokens than the CoT baseline while simultaneously improving prediction accuracy.
Ablation studies analyzed the impact of different replacement strategies. The results indicated that the partial, left-to-right (AR) replacement strategy, in which the value of m varies per sample, outperformed alternatives such as All-Replace, Curriculum-Replace, and Poisson-Replace. The AR-Replace strategy likely works best because the remaining text tokens serve as anchors, helping the model interpret and integrate the latent representations more effectively.
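For contrast, the sketch below illustrates one plausible reading of how m could be chosen under each strategy; the schedules for All-Replace, Curriculum-Replace, and Poisson-Replace are inferred from their names and are not taken from the paper's implementation.

```python
import random
import numpy as np

def choose_m(strategy: str, cot_len: int, r: int, progress: float = 0.0) -> int:
    """Number of leading CoT text tokens to replace with latent tokens.

    `progress` in [0, 1] is the fraction of training completed (used only by
    the curriculum variant). All schedules here are illustrative guesses.
    """
    max_chunks = cot_len // r
    if strategy == "AR-Replace":            # partial, left-to-right; m varies per sample
        chunks = random.randint(0, max_chunks)
    elif strategy == "All-Replace":         # abstract the entire CoT
        chunks = max_chunks
    elif strategy == "Curriculum-Replace":  # replace more as training progresses
        chunks = int(max_chunks * progress)
    elif strategy == "Poisson-Replace":     # chunk count drawn from a Poisson distribution
        chunks = min(max_chunks, int(np.random.poisson(max_chunks / 2)))
    else:
        raise ValueError(f"unknown strategy: {strategy}")
    return chunks * r
```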
Analysis of attention weights revealed that the approach enables the model to focus more effectively on important tokens, such as numbers and words representing mathematical operations.
In summary, the paper presents a method for enhancing reasoning capabilities in LLMs through the integration of discrete latent tokens and textual representations. The approach demonstrates strong performance across a range of tasks and model architectures, offering improvements in both accuracy and token efficiency.