Single-step multi-token prediction capability of LLM-based decoders

Establish whether an auto-regressive decoder initialized from a large language model, such as LLaMA in the Speech-LLaMA architecture, can predict K future tokens in a single decoding step. Under a modified left-to-right factorization of P(y | X) into blocks of K tokens, this would reduce the total number of decoding steps required to generate a sequence of length U from U to U/K.
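
Concretely, one way to write this block-wise factorization (the notation below is ours and assumes, for simplicity, that U is a multiple of K; the paper's exact formulation may differ) is to predict the K tokens of each block from the same prefix:

$$
P(y \mid X) \;\approx\; \prod_{b=0}^{U/K - 1} \prod_{k=1}^{K} P\bigl(y_{bK+k} \mid y_{\le bK},\, X\bigr),
$$

so that only U/K left-to-right steps are needed, one per block of K tokens.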

Background

The paper studies accelerating inference for Speech-LLaMA, a decoder-only ASR system that attaches a speech encoder and a modality adapter to a pre-trained LLM. Auto-regressive decoding in such systems incurs high latency because every output token requires a full forward pass through the large transformer decoder.
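
To make the latency argument concrete, the following is a minimal sketch of standard greedy auto-regressive decoding; the toy decoder and function names are ours, not the paper's, and the conditioning on the encoded speech X is omitted. Generating U tokens costs U forward passes through the decoder.

```python
# Minimal sketch of standard auto-regressive (one-token-per-step) decoding.
# The toy decoder stands in for the Speech-LLaMA decoder and ignores the
# speech input X; it only illustrates the cost of one forward pass per token.
import torch
import torch.nn as nn

vocab_size, hidden = 100, 32
decoder = nn.Sequential(
    nn.Embedding(vocab_size, hidden),
    nn.Linear(hidden, vocab_size),
)

def greedy_decode(bos: int, eos: int, max_len: int = 20) -> list[int]:
    tokens = [bos]
    for _ in range(max_len):                    # one full decoder pass per token
        logits = decoder(torch.tensor(tokens))  # (t, vocab_size)
        next_tok = int(logits[-1].argmax())     # only the next token is predicted
        tokens.append(next_tok)
        if next_tok == eos:
            break
    return tokens

print(greedy_decode(bos=1, eos=2))
```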

To reduce the number of inference steps, the authors consider predicting multiple next tokens simultaneously. They introduce architectures that add multiple decoding heads or expand the latent space to enable parallel multi-token prediction while keeping model size under control. The central premise motivating these designs is the conjecture that a sufficiently powerful LLM decoder can generate several subsequent tokens in one step, lowering the number of decoding steps required.
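
Below is a hedged sketch of one such design, assuming K independent output heads applied to the last hidden state; this is our own simplification, and the paper's architectures may arrange the heads and latent space differently. Each forward pass emits K tokens, so a length-U hypothesis needs roughly U/K passes instead of U.

```python
# Hedged sketch of multi-token prediction with K output heads (our simplification,
# not the paper's exact architecture): each decoder pass emits K tokens.
import torch
import torch.nn as nn

vocab_size, hidden, K = 100, 32, 4
embed = nn.Embedding(vocab_size, hidden)
backbone = nn.Linear(hidden, hidden)  # stand-in for the shared LLM decoder body
heads = nn.ModuleList([nn.Linear(hidden, vocab_size) for _ in range(K)])

def multi_token_decode(bos: int, eos: int, max_len: int = 20) -> list[int]:
    tokens = [bos]
    while len(tokens) < max_len:
        h = backbone(embed(torch.tensor(tokens)))[-1]      # last-position hidden state
        block = [int(head(h).argmax()) for head in heads]  # K tokens from one pass
        for tok in block:
            tokens.append(tok)
            if tok == eos or len(tokens) >= max_len:
                return tokens
    return tokens

print(multi_token_decode(bos=1, eos=2))
```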

References

We conjecture that a complex decoder (such as an LLM) should be able to predict multiple tokens (say, $K$) in a single step of the decoding process, thus reducing the required number of decoding steps to $\frac{U}{K}$.

Faster Speech-LLaMA Inference with Multi-token Prediction (2409.08148 - Raj et al., 12 Sep 2024) in Section 3 (Multi-token Prediction), first paragraph