Speech-LLaMA models, which integrate a speech encoder with a LLM decoder for Automatic Speech Recognition (ASR), often face high inference times due to the auto-regressive nature of the LLM decoder and its substantial computational cost per step. This limits their real-world applicability despite their strong accuracy on various ASR tasks.
This paper proposes accelerating Speech-LLaMA inference by enabling the model to predict multiple tokens ($K$ at a time) in a single decoding step, aiming to reduce the total number of decoder calls required for a full transcription. This approach draws inspiration from similar techniques used to speed up LLM inference.
The core idea is to modify the decoder to output predictions for the next $K$ tokens conditioned on the current context, rather than just one. The probability of a block of $K$ tokens given the previous context $y_{\le t}$ and speech input $X$ is approximated by assuming conditional independence among the tokens: $P(y_{t+1}, \ldots, y_{t+K} \mid y_{\le t}, X) \approx \prod_{k=1}^{K} P_k(y_{t+k} \mid y_{\le t}, X)$. Each $P_k$ is estimated by a separate prediction head.
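As a concrete illustration, here is a minimal PyTorch sketch (hypothetical shapes and random distributions) of the approximation above: one greedy token is taken from each head, and the block probability is the product of the per-head token probabilities.

```python
import torch

# Hypothetical per-head distributions at decoding step t:
# probs[k] is head (k+1)'s softmax output over the vocabulary,
# all conditioned on the same prefix y_<=t and the speech input.
K, vocab_size = 4, 32_000
probs = torch.softmax(torch.randn(K, vocab_size), dim=-1)

# Greedy block prediction: one token per head.
block = probs.argmax(dim=-1)  # shape (K,)

# Conditional-independence approximation of the block probability:
# P(y_{t+1..t+K} | y_<=t, X) ~= prod_k P_k(y_{t+k} | y_<=t, X)
block_prob = probs[torch.arange(K), block].prod()
print(block.tolist(), block_prob.item())
```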
The paper explores two main model architectures for implementing these multiple prediction heads:
- Independent Projection Heads (Medusa-style): This architecture keeps the shared transformer trunk of the original decoder and adds $K$ independent linear projection heads $W_1, \ldots, W_K \in \mathbb{R}^{d \times V}$, each mapping the trunk output to a distribution over the vocabulary. This directly corresponds to $K$ parallel copies of the final output layer. While conceptually simple, it adds on the order of $K \cdot d \cdot V$ parameters, where $d$ is the model dimension and $V$ is the vocabulary size.
- Latent-space Expansion: To mitigate the parameter increase, this architecture factorizes each head matrix $W_k$ into a full-rank matrix $A_k \in \mathbb{R}^{d \times d}$ followed by a shared un-embedding matrix (initialized from the original model's un-embedding). This adds only on the order of $K \cdot d^2$ extra parameters, which is independent of the vocabulary size and can be significantly smaller than the projection-heads approach when $d \ll V$. The latent-space approach also achieved a lower decoder Real-Time Factor (RTF) than the projection heads, indicating better efficiency, even though both 4-head models are larger than the 125M baseline decoder (234M vs. 306M total parameters). A sketch of both head variants follows this list.
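The following PyTorch sketch contrasts the two head designs. The class names, dimensions, and the way the shared un-embedding is passed in are illustrative assumptions, not the paper's implementation.

```python
import torch
import torch.nn as nn

class ProjectionHeads(nn.Module):
    """Medusa-style: K independent d-by-V output layers on the shared trunk.
    Adds on the order of K * d * V parameters."""
    def __init__(self, d_model: int, vocab_size: int, num_heads: int):
        super().__init__()
        self.heads = nn.ModuleList(
            [nn.Linear(d_model, vocab_size, bias=False) for _ in range(num_heads)]
        )

    def forward(self, trunk_out: torch.Tensor) -> torch.Tensor:
        # trunk_out: (batch, seq, d_model) -> logits: (num_heads, batch, seq, vocab)
        return torch.stack([head(trunk_out) for head in self.heads])


class LatentExpansionHeads(nn.Module):
    """Latent-space expansion: K full-rank d-by-d matrices followed by a single
    shared un-embedding (d-by-V), initialized from the original output layer.
    Adds on the order of K * d^2 parameters, independent of vocabulary size."""
    def __init__(self, d_model: int, shared_unembedding: nn.Linear, num_heads: int):
        super().__init__()
        self.expansions = nn.ModuleList(
            [nn.Linear(d_model, d_model, bias=False) for _ in range(num_heads)]
        )
        self.unembed = shared_unembedding  # shared across all heads

    def forward(self, trunk_out: torch.Tensor) -> torch.Tensor:
        return torch.stack([self.unembed(exp(trunk_out)) for exp in self.expansions])


# Usage sketch with hypothetical sizes:
d, V, K = 768, 32_000, 4
unembed = nn.Linear(d, V, bias=False)   # stands in for the original un-embedding
heads = LatentExpansionHeads(d, unembed, K)
logits = heads(torch.randn(2, 10, d))   # (K, batch=2, seq=10, V)
```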
To translate multi-token predictions into an output sequence while maintaining accuracy, the paper discusses and proposes different inference strategies:
- Verification-based: This follows a "predict-verify-accept" paradigm similar to speculative decoding. At each step $t$, $K$ tokens $\hat{y}_{t+1}, \ldots, \hat{y}_{t+K}$ are predicted from the prefix $y_{\le t}$. These tokens are then verified sequentially against the probability distribution of the main head (head 1): $\hat{y}_{t+i}$ is accepted as long as it is the most probable token under head 1 given the correct prefix $(y_{\le t}, \hat{y}_{t+1}, \ldots, \hat{y}_{t+i-1})$. A loosened version only requires $\hat{y}_{t+i}$ to fall within the top-$k$ predictions of head 1. The strict (top-1) version guarantees the output sequence is identical to standard auto-regressive decoding, but the acceptance rate, and hence the speed-up, varies unpredictably.
- Threshold-based: At each step $t$, $K$ tokens $\hat{y}_{t+1}, \ldots, \hat{y}_{t+K}$ are predicted by heads $1, \ldots, K$, respectively. The first $j$ tokens are accepted if, for every $i \le j$, the probability assigned by head $i$ to its predicted token exceeds a hyperparameter threshold $\tau$. This offers explicit control over the speed-performance trade-off via $\tau$ (see the sketch after this list).
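Below is a minimal sketch of the threshold-based acceptance rule, assuming the $K$ per-head distributions for the current step are already computed; always emitting head 1's token (to guarantee progress) and the function name are assumptions of this sketch.

```python
import torch

def accept_by_threshold(head_probs: torch.Tensor, tau: float) -> list[int]:
    """head_probs: (K, vocab) per-head distributions at the current step.
    Greedily pick one token per head and keep the longest prefix of heads
    whose predicted-token probability exceeds tau."""
    accepted = []
    for i in range(head_probs.size(0)):
        prob, token = head_probs[i].max(dim=-1)
        # Head 1's token is emitted regardless of tau so decoding always
        # advances (an assumption of this sketch); later heads must clear tau.
        if i > 0 and prob.item() < tau:
            break
        accepted.append(int(token))
    return accepted

# Example: with tau = 0.9, only confidently predicted tokens beyond the first
# are kept, so the accepted block length adapts per decoding step.
```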
For training, the standard cross-entropy loss is extended to a sum of losses over all $K$ prediction heads, optionally with per-head weights. A modified Minimum Word Error Rate (MWER) training objective is also introduced. It uses the multi-token decoding process (specifically, the threshold-based method) to generate N-best lists, and the sequence probability in the MWER loss is computed by multiplying the probabilities assigned by the specific head that generated each token in the sequence. To make batched MWER training efficient, a prefix-based beam search groups hypotheses by prefix length so the KV-cache can be computed in batches.
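A rough sketch of the summed multi-head cross-entropy objective, not the paper's exact code: logits are assumed to be stacked per head, targets are assumed already next-token-aligned for head 1, and the per-head weights and padding convention are illustrative; the MWER extension is omitted.

```python
import torch
import torch.nn.functional as F

def multi_head_ce(logits: torch.Tensor, targets: torch.Tensor,
                  weights: list[float], pad_id: int = -100) -> torch.Tensor:
    """logits:  (K, batch, seq, vocab); head k+1 predicts k extra steps ahead.
    targets: (batch, seq) token ids, aligned so targets[:, t] is head 1's label.
    weights: hypothetical per-head loss weights."""
    total = logits.new_zeros(())
    for k in range(logits.size(0)):
        # Head k+1's label at position t is the token k steps further ahead:
        # shift the head-1 targets left by k and pad the tail so it is ignored.
        shifted = torch.full_like(targets, pad_id)
        if k < targets.size(1):
            shifted[:, : targets.size(1) - k] = targets[:, k:]
        total = total + weights[k] * F.cross_entropy(
            logits[k].flatten(0, 1), shifted.flatten(), ignore_index=pad_id
        )
    return total
```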
Experiments on LibriSpeech show that both multi-token architectures can improve WER over the baseline Speech-LLaMA. The latent-space expansion model is more efficient in terms of decoder RTF for the same number of heads. Increasing $K$ generally reduces the number of decoder calls, but WER starts to degrade once $K$ grows too large. The threshold-based inference strategy provides better control over the speed-performance trade-off than the verification-based methods. For large-scale multi-lingual ASR, the 4-head latent model substantially reduces the number of decoder calls (roughly a 3.2x, or 62.8\% average, reduction), making inference faster, although it showed a slight WER increase with the 125M decoder, suggesting a larger decoder might be beneficial. MWER training improved performance, especially in the multi-lingual setting. The multi-token approach also reduced the variability in the number of decoder calls across languages.
In summary, the paper demonstrates that multi-token prediction is a viable and effective technique for speeding up Speech-LLaMA inference. The latent-space expansion architecture is more parameter-efficient than independent projection heads, and threshold-based inference allows fine-grained control over the speed-accuracy balance. While larger decoders might be needed for optimal performance in complex multi-lingual scenarios, the method significantly reduces the number of decoder calls and improves inference speed for ASR with LLM-based decoders.