
Token Assorted: Mixing Latent and Text Tokens for Improved Language Model Reasoning (2502.03275v1)

Published 5 Feb 2025 in cs.CL, cs.AI, cs.LG, and cs.LO

Abstract: LLMs excel at reasoning and planning when trained on chain-of-thought (CoT) data, where the step-by-step thought process is explicitly outlined by text tokens. However, this results in lengthy inputs where many words support textual coherence rather than core reasoning information, and processing these inputs consumes substantial computation resources. In this work, we propose a hybrid representation of the reasoning process, where we partially abstract away the initial reasoning steps using latent discrete tokens generated by VQ-VAE, significantly reducing the length of reasoning traces. We explore the use of latent trace abstractions in two scenarios: 1) training the model from scratch for the Keys-Finding Maze problem, 2) fine-tuning LLMs on this hybrid data with an extended vocabulary including unseen latent tokens, for both logical and mathematical reasoning problems. To facilitate effective learning, we introduce a simple training procedure that randomly mixes latent and text tokens, which enables fast adaptation to new latent tokens. Our approach consistently outperforms the baseline methods in various benchmarks.

Summary

  • The paper introduces a novel method that mixes latent tokens with text tokens to improve language model reasoning.
  • It uses a VQ-VAE to compress chain-of-thought tokens, reducing sequence length and computational cost while enhancing accuracy.
  • Experiments on synthetic and mathematical benchmarks show significant gains in both prediction accuracy and token efficiency over standard approaches.

The paper "Token Assorted: Mixing Latent and Text Tokens for Improved LLM Reasoning" introduces a novel approach to enhance reasoning in LLMs by integrating discrete latent tokens, derived from a VQ-VAE, with textual representations of reasoning processes. The core idea revolves around abstracting the initial steps of CoT reasoning traces using these latent tokens, thereby compressing lengthy input sequences and reducing computational costs. The approach is evaluated in two primary scenarios: training models from scratch on tasks such as Keys-Finding Maze and fine-tuning pre-trained LLMs (specifically, Llama-3.1 and Llama-3.2 variants) on logical and mathematical reasoning benchmarks.

The methodology involves a two-stage training procedure. First, a VQ-VAE is employed to map CoT tokens into a compressed sequence of discrete latent tokens, achieving a reduction in length governed by a pre-set compression rate, $r = t_c / t_z$, where:

  • $r$: compression rate
  • $t_c$: length of the CoT token sequence
  • $t_z$: length of the discrete latent token sequence
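
As a concrete illustration (the numbers here are chosen for this summary, not taken from the paper): with a compression rate of $r = 16$, a CoT segment of $t_c = 256$ text tokens is abstracted into $t_z = t_c / r = 16$ latent tokens.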

The VQ-VAE architecture encompasses several key components: a codebook $\mathcal{E}$ containing $K$ vectors in $\mathbb{R}^d$; an encoder $f_e: \mathbb{R}^L \mapsto \mathbb{R}^{d \times \frac{L}{r}}$ that maps a sequence of $L$ text tokens to $\frac{L}{r}$ latent embedding vectors $\bar{X} = \{\bar{x}_1, \ldots, \bar{x}_{L/r}\}$; a quantization operator $q: \mathbb{R}^d \mapsto \mathcal{E}$ that replaces an encoded embedding $\bar{x}$ by its nearest neighbor in $\mathcal{E}$, defined as $q(\bar{x}) = \arg\min_{e_i \in \mathcal{E}} \|e_i - \bar{x}\|_2^2$; an embedding function $g: \mathbb{R}^K \mapsto \mathbb{R}^d$ that maps $K$ text tokens to a $d$-dimensional embedding vector; and a decoder $f_d: \mathbb{R}^{d \times \frac{L}{r}} \times \mathbb{R}^K \mapsto \mathbb{R}^L$ that decodes latent embeddings back to text tokens, conditioned on the prompt embedding.
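
To make the quantization step concrete, the following is a minimal PyTorch-style sketch of nearest-neighbor quantization against a codebook; the function and tensor names are illustrative and not taken from the paper's implementation.

```python
import torch

def quantize(x_bar: torch.Tensor, codebook: torch.Tensor):
    """Nearest-neighbor quantization q(x_bar) against a codebook E.

    x_bar:    (L/r, d) encoder outputs
    codebook: (K, d)   codebook of K vectors in R^d
    Returns the codebook indices (the discrete latent tokens) and the
    quantized embeddings q(x_bar).
    """
    # Squared L2 distance from every encoder output to every code vector.
    dists = torch.cdist(x_bar, codebook, p=2) ** 2   # (L/r, K)
    indices = dists.argmin(dim=-1)                   # (L/r,)
    quantized = codebook[indices]                    # (L/r, d)
    return indices, quantized
```

In practice, gradients are usually propagated through the non-differentiable nearest-neighbor lookup with a straight-through estimator; that detail is omitted from this sketch.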

The VQ-VAE is trained using a composite loss function $\mathcal{L}(X)$, which includes a reconstruction loss ($\log p(X \mid f_d(q(\bar{X}) \mid g(P)))$), a VQ loss ($\sum_{i=1}^{L} \|[\bar{X}_i] - q(\bar{X}_i)\|_2^2$), and a commitment loss ($\beta \|\bar{X}_i - [q(\bar{X}_i)]\|_2^2$), where:

  • $\mathcal{L}(X)$: loss function
  • $X$: input sequence
  • $P$: prompt
  • $\bar{X}$: latent embedding vectors
  • $q(\bar{X})$: quantization operator applied to the latent embedding vectors
  • $f_d$: decoder function
  • $g(P)$: embedding of the prompt
  • $[\cdot]$: stop-gradient operator
  • $\beta$: hyperparameter controlling the strength of the commitment loss.
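
A minimal sketch of how these three terms could be combined, assuming PyTorch and using `detach()` as the stop-gradient operator $[\cdot]$; the names, and the use of mean rather than summed squared error, are simplifications rather than the paper's code.

```python
import torch
import torch.nn.functional as F

def vqvae_loss(x_bar, quantized, recon_logits, target_ids, beta=0.25):
    """Composite objective: reconstruction + VQ + commitment terms.

    x_bar:        (L/r, d) encoder outputs before quantization
    quantized:    (L/r, d) nearest codebook vectors q(x_bar)
    recon_logits: (L, V)   decoder logits over the text vocabulary
    target_ids:   (L,)     original CoT token ids
    beta:         weight of the commitment term
    """
    # Reconstruction: negative log-likelihood of the original CoT tokens.
    recon = F.cross_entropy(recon_logits, target_ids)
    # VQ term: pull codebook vectors toward the (stop-gradient) encoder outputs.
    vq = F.mse_loss(quantized, x_bar.detach())
    # Commitment term: keep encoder outputs close to the (stop-gradient) codes.
    commit = F.mse_loss(x_bar, quantized.detach())
    return recon + vq + beta * commit
```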

In the second stage, the pre-trained VQ-VAE is used to construct modified samples in which the initial $m$ tokens of the CoT sequence are replaced by their latent abstractions. A randomized replacement strategy is employed, where the value of $m$ is varied randomly during training to facilitate adaptation to the new latent tokens.
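
A minimal sketch of this randomized sample construction, assuming token ids stored as Python lists and latent tokens drawn from the extended vocabulary (all names are illustrative):

```python
import random

def build_mixed_trace(cot_ids, latent_ids, r, rng=random):
    """Replace the first m CoT text tokens with their latent abstractions.

    cot_ids:    text token ids of the full CoT trace
    latent_ids: latent token ids from the VQ-VAE for that trace
                (roughly len(cot_ids) // r of them, taken from an
                extended vocabulary disjoint from the text tokens)
    r:          compression rate (text tokens per latent token)

    m is drawn per sample in multiples of r, so the model sees traces
    with varying amounts of latent abstraction during training.
    """
    max_chunks = len(cot_ids) // r
    num_chunks = rng.randint(0, max_chunks)   # number of compressed chunks
    m = num_chunks * r                        # leading text tokens replaced
    return latent_ids[:num_chunks] + cot_ids[m:]
```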

The approach was evaluated on a range of benchmarks. For synthetic datasets, including Keys-Finding Maze, ProntoQA, and ProsQA, T5 or GPT-2 models were trained from scratch. On Keys-Finding Maze, the approach achieved a 1-Feasible-10 score of 62.8%, significantly outperforming the CoT baseline (43%). For ProntoQA, the approach reached 100% accuracy, compared to 98.8% for the CoT baseline. On the more challenging ProsQA, the approach achieved 96.2% accuracy, compared to 77.5% for the CoT baseline. The method also exhibited improved token efficiency, generating shorter reasoning traces than the CoT baseline while maintaining or improving performance.

For mathematical reasoning, Llama-3.1 and Llama-3.2 models of varying sizes (1B, 3B, and 8B parameters) were fine-tuned on the MetaMathQA dataset and evaluated on in-domain (MATH, GSM8K) and out-of-domain (Fresh-Gaokao-Math-2023, DeepMind-Math, College-Math, OlympiadBench-Math, TheoremQA) benchmarks. The approach consistently outperformed baseline methods, including Sol-Only, CoT, iCoT, and Pause Token, across nearly all tasks and model sizes. For instance, with the Llama-3.1-8B model, the approach achieved an average accuracy of 37.9% across all datasets, compared to 33.4% for the CoT baseline. Significant gains were observed on specific datasets such as Fresh-Gaokao-Math-2023, where the approach achieved 30.0% accuracy compared to 16.7% for the CoT baseline. The approach also demonstrated improved token efficiency, achieving an average reduction of 17% in the number of generated tokens compared to the CoT baseline, while simultaneously improving prediction accuracy.

Ablation studies were conducted to analyze the impact of different replacement strategies. The results indicated that the partial, left-to-right (AR) replacement strategy, where the value of $m$ varies for each sample, outperformed alternative strategies such as All-Replace, Curriculum-Replace, and Poisson-Replace. The AR-Replace strategy performs better because the remaining text tokens serve as anchors, helping the model interpret and integrate the latent representations more effectively.
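
The alternative strategies are not defined in detail in this summary, so the sketch below is only an assumption-laden illustration of how the replacement amount $m$ might be chosen under each of them, for contrast with the randomized AR scheme; the samplers are guesses based on the strategy names, not the paper's definitions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Each function returns m, the number of leading CoT text tokens to replace,
# for a trace of cot_len tokens with compression rate r.

def ar_replace(cot_len, r):
    # Partial, left-to-right replacement: m varies uniformly per sample.
    return int(rng.integers(0, cot_len // r + 1)) * r

def all_replace(cot_len, r):
    # Assumed: the entire CoT is abstracted into latent tokens.
    return (cot_len // r) * r

def curriculum_replace(cot_len, r, step, total_steps):
    # Assumed: the replaced prefix grows as training progresses.
    return int((step / total_steps) * (cot_len // r)) * r

def poisson_replace(cot_len, r, lam=3.0):
    # Assumed: the number of replaced chunks is Poisson-distributed, capped.
    return min(int(rng.poisson(lam)), cot_len // r) * r
```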

Analysis of attention weights revealed that the approach enables the model to focus more effectively on important tokens, such as numbers and words representing mathematical operations.

In summary, the paper presents a method for enhancing reasoning capabilities in LLMs through the integration of discrete latent tokens and textual representations. The approach demonstrates strong performance across a range of tasks and model architectures, offering improvements in both accuracy and token efficiency.
