Faster Speech-LLaMA Inference with Multi-token Prediction (2409.08148v1)

Published 12 Sep 2024 in eess.AS and cs.SD

Abstract: LLMs have become proficient at solving a wide variety of tasks, including those involving multi-modal inputs. In particular, instantiating an LLM (such as LLaMA) with a speech encoder and training it on paired data imparts speech recognition (ASR) abilities to the decoder-only model, hence called Speech-LLaMA. Nevertheless, due to the sequential nature of auto-regressive inference and the relatively large decoder, Speech-LLaMA models require relatively high inference time. In this work, we propose to speed up Speech-LLaMA inference by predicting multiple tokens in the same decoding step. We explore several model architectures that enable this, and investigate their performance using threshold-based and verification-based inference strategies. We also propose a prefix-based beam search decoding method that allows efficient minimum word error rate (MWER) training for such models. We evaluate our models on a variety of public benchmarks, where they reduce the number of decoder calls by ~3.2x while maintaining or improving WER performance.

Speech-LLaMA models, which integrate a speech encoder with an LLM decoder for Automatic Speech Recognition (ASR), often face high inference times due to the auto-regressive nature of the LLM decoder and its substantial computational cost per step. This limits their real-world applicability despite their strong accuracy on various ASR tasks.

This paper proposes accelerating Speech-LLaMA inference by enabling the model to predict multiple tokens ($K$ at a time) in a single decoding step, aiming to reduce the total number of decoder calls required for a full transcription. This approach draws inspiration from similar techniques used to speed up LLM inference.

The core idea is to modify the decoder to output predictions for the next $K$ tokens conditioned on the current context, rather than just one. The probability of a sequence of $K$ tokens $\mathbf{y}_{u+1:u+K}$ given the previous context $\mathbf{y}_{\leqslant u}$ and speech input $\mathbf{X}$ is approximated by assuming conditional independence among the $K$ tokens:

$$P(\mathbf{y}_{u+1:u+K}\mid \mathbf{y}_{\leqslant u},\mathbf{X}) \approx \prod_{k=1}^K P_k(y_{u+k}\mid \mathbf{y}_{\leqslant u},\mathbf{X}).$$

Each $P_k$ is estimated by a separate prediction head.
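As a concrete illustration of this factorization, the minimal sketch below (toy vocabulary and random stand-in distributions; not the authors' code) scores one specific $K$-token continuation by multiplying the per-head probabilities, all conditioned on the same prefix $\mathbf{y}_{\leqslant u}$.

```python
# Sketch only: per-head distributions are random stand-ins for P_k(. | y_<=u, X).
import torch

K, V = 4, 32                           # hypothetical: 4 heads, toy vocabulary of 32 tokens
torch.manual_seed(0)

head_probs = torch.randn(K, V).softmax(dim=-1)   # (K, V): one next-token distribution per head
candidate = torch.tensor([5, 17, 3, 29])         # a candidate continuation y_{u+1:u+K}

# Conditional-independence approximation: product of per-head probabilities.
joint_prob = head_probs[torch.arange(K), candidate].prod()
print(f"approx P(y_{{u+1:u+K}} | y_<=u, X) = {joint_prob.item():.3e}")
```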

The paper explores two main model architectures for implementing these multiple prediction heads:

  1. Independent Projection Heads (Medusa-style): This architecture uses a shared transformer trunk ($f_{\mathrm{trf}}$) from the original decoder and adds $K$ independent linear projection heads ($\mathbf{H}_1, \ldots, \mathbf{H}_K$), each mapping the trunk output to a vocabulary distribution. This directly corresponds to $K$ parallel copies of the final output layer. While conceptually simple, it adds $(K-1) \times D \times V$ parameters, where $D$ is the model dimension and $V$ is the vocabulary size.
  2. Latent-space Expansion: To mitigate the parameter increase, this architecture factorizes each head matrix $\mathbf{H}_k$ into a full-rank matrix $\mathbf{L}_k$ and a shared un-embedding matrix $\mathbf{H}$ (initialized from the original model's un-embedding). This adds $K \times D^2$ extra parameters, which is independent of the vocabulary size $V$ and can be significantly smaller than the projection-heads overhead when $D \ll V$. The latent-space approach also yields a lower decoder Real-Time Factor (RTF) than the projection heads, indicating better efficiency, even though both variants are larger than the baseline decoder (234M parameters for the 4-head latent model and 306M for the projection-heads model, against a 125M baseline). A minimal sketch of both head designs follows this list.
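The following PyTorch sketch is illustrative only (class and variable names are mine, not the paper's): it shows the two head designs on top of a trunk that emits hidden states of dimension $D$, and prints their parameter counts to make the $K \times D^2$ versus $(K-1) \times D \times V$ trade-off concrete.

```python
# A minimal sketch of the two multi-token head designs described above.
import torch
import torch.nn as nn

class ProjectionHeads(nn.Module):
    """Medusa-style: K independent D x V output projections over a shared trunk."""
    def __init__(self, d_model: int, vocab_size: int, num_heads: int):
        super().__init__()
        self.heads = nn.ModuleList(
            [nn.Linear(d_model, vocab_size, bias=False) for _ in range(num_heads)]
        )

    def forward(self, h: torch.Tensor) -> torch.Tensor:   # h: (B, D) trunk output
        return torch.stack([head(h) for head in self.heads], dim=1)   # (B, K, V)

class LatentExpansionHeads(nn.Module):
    """Latent-space expansion: K full-rank D x D maps feeding one shared un-embedding."""
    def __init__(self, d_model: int, vocab_size: int, num_heads: int):
        super().__init__()
        self.latent = nn.ModuleList(
            [nn.Linear(d_model, d_model, bias=False) for _ in range(num_heads)]
        )
        # Shared un-embedding; the paper initializes it from the base model's.
        self.unembed = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return torch.stack([self.unembed(lk(h)) for lk in self.latent], dim=1)

# Toy sizes; head 1 plays the role of the original output layer, so the extra
# cost over the base decoder is ~(K-1)*D*V for projection heads vs ~K*D^2 here.
D, V, K = 768, 32000, 4
print(sum(p.numel() for p in ProjectionHeads(D, V, K).parameters()))      # 98,304,000
print(sum(p.numel() for p in LatentExpansionHeads(D, V, K).parameters())) # 26,935,296
```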

To translate multi-token predictions into an output sequence while maintaining accuracy, the paper discusses and proposes different inference strategies:

  1. Verification-based: This follows a "predict-verify-accept" paradigm similar to speculative decoding. At each step $u$, $K$ tokens are predicted based on $\mathbf{y}_{\leqslant u}$. These predictions $\widehat{y}_{u+1}, \ldots, \widehat{y}_{u+K}$ are then verified sequentially using the probability distribution from the main head (head 1): the tokens $\widehat{y}_{u+1}, \ldots, \widehat{y}_{u+\widehat{k}}$ are accepted as long as each $\widehat{y}_{u+r}$ is the most probable token predicted by head 1 given the correct prefix $\mathbf{y}_{<u+r}$. A loosened version instead checks whether $\widehat{y}_{u+r}$ is within the top-$M$ predictions of head 1. For $M=1$ this guarantees the output is identical to standard auto-regressive decoding, but the acceptance rate varies unpredictably.
  2. Threshold-based: At each step $u$, $K$ tokens $\widehat{y}_{u+1}, \ldots, \widehat{y}_{u+K}$ are predicted using heads $1, \ldots, K$ respectively. The tokens $\widehat{y}_{u+1}, \ldots, \widehat{y}_{u+\widehat{k}}$ are accepted if the probability $P_r(\widehat{y}_{u+r}\mid \mathbf{y}_{\leqslant u},\mathbf{X})$ assigned by head $r$ to its predicted token exceeds a hyperparameter threshold $\tau$ for all $r \leqslant \widehat{k}$. This offers explicit control over the speed-performance trade-off via $\tau$ (see the sketch after this list).
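A minimal sketch of the threshold-based acceptance rule is shown below. It is not the paper's code; it assumes each head proposes one (token, probability) pair per step and that the first token is always emitted so decoding advances by at least one token per decoder call.

```python
# Illustrative threshold-based acceptance: keep the longest prefix of proposed
# tokens whose per-head probabilities exceed tau.
from typing import List, Tuple

def accept_by_threshold(proposals: List[Tuple[int, float]], tau: float) -> List[int]:
    accepted = [proposals[0][0]]           # head 1's token is always emitted (assumption)
    for token, prob in proposals[1:]:      # heads 2..K, in order
        if prob < tau:
            break                          # first low-confidence head stops acceptance
        accepted.append(token)
    return accepted

# Example: 4 heads propose tokens with these per-head probabilities.
proposals = [(101, 0.92), (7, 0.81), (55, 0.43), (9, 0.77)]
print(accept_by_threshold(proposals, tau=0.6))   # -> [101, 7]; decoding resumes after them
```

Intuitively, a $\tau$ close to 1 accepts few speculative tokens and approaches one-token-per-step decoding, while a lower $\tau$ accepts more tokens per call at the risk of more errors.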

For training, the standard cross-entropy loss is extended to sum the losses over all $K$ prediction heads, optionally weighted by per-head coefficients $\alpha_k$; a sketch of this objective is shown below. A modified Minimum Word Error Rate (MWER) training objective is also introduced: the multi-token decoding process (specifically, the threshold-based method) is used to generate N-best lists, and the sequence probability for the MWER loss is computed by multiplying the probabilities assigned by the specific head that generated each token in the sequence. To make batched MWER training efficient, a prefix-based beam search groups hypotheses by prefix length so that KV-cache computation can be batched.
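The sketch below shows the multi-head cross-entropy objective under assumed tensor shapes (it is not the authors' training code): head $k$ at position $u$ is trained to predict token $y_{u+k}$, and the $K$ per-head losses are summed with weights $\alpha_k$.

```python
# Illustrative multi-head cross-entropy; `targets` are assumed to already be the
# next-token labels for head 1, so head k reads them shifted by k-1 positions.
import torch
import torch.nn.functional as F

def multi_token_ce(head_logits: torch.Tensor, targets: torch.Tensor, alphas) -> torch.Tensor:
    """head_logits: (B, T, K, V); targets: (B, T) token ids; alphas: K loss weights."""
    B, T, K, V = head_logits.shape
    total = head_logits.new_zeros(())
    for k in range(K):
        # Head k at position t predicts targets[:, t + k]; the last k positions
        # have no target and are dropped.
        logits_k = head_logits[:, : T - k, k, :]      # (B, T-k, V)
        target_k = targets[:, k:]                     # (B, T-k)
        total = total + alphas[k] * F.cross_entropy(
            logits_k.reshape(-1, V), target_k.reshape(-1)
        )
    return total

# Toy usage with random tensors.
B, T, K, V = 2, 8, 4, 32
loss = multi_token_ce(torch.randn(B, T, K, V), torch.randint(0, V, (B, T)),
                      alphas=[1.0, 0.5, 0.5, 0.5])
print(loss.item())
```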

Experiments on LibriSpeech show that both multi-token architectures can improve WER over the baseline Speech-LLaMA. The latent-space expansion model is more efficient in terms of decoder RTF for the same number of heads ($K=4$). Increasing $K$ generally reduces the number of decoder calls ($\eta$), but WER starts degrading beyond $K=4$ or $K=5$. The threshold-based inference strategy provides better control over the speed-performance trade-off than verification-based methods. For large-scale multi-lingual ASR, the 4-head latent model significantly reduces $\eta$ (roughly a 3.2x reduction, i.e. a 62.8% average reduction), making inference faster, although it shows a slight increase in WER with the 125M decoder, suggesting a larger decoder might be beneficial. MWER training improves performance, especially in the multi-lingual setting. The multi-token approach also reduces the variability of $\eta$ across languages.

In summary, the paper demonstrates that multi-token prediction is a viable and effective technique for speeding up Speech-LLaMA inference. The latent-space expansion architecture is more parameter-efficient than independent projection heads, and threshold-based inference allows fine-grained control over the speed-accuracy balance. While larger decoders might be needed for optimal performance in complex multi-lingual scenarios, the method significantly reduces the number of decoder calls and improves inference speed for ASR with LLM-based decoders.

Authors (5)
  1. Desh Raj
  2. Gil Keren
  3. Junteng Jia
  4. Jay Mahadeokar
  5. Ozlem Kalinli