Mask-Predict: Iterative Non-Autoregressive Decoding
- Mask-Predict is a non-autoregressive sequence generation method that uses conditional masked language modeling to iteratively refine token predictions.
- It masks low-confidence tokens and re-predicts them in parallel, striking a balance between decoding speed and model accuracy.
- The algorithm applies to tasks like neural machine translation and speech recognition, integrating techniques such as CTC and SMART for enhanced performance.
Mask-Predict is a non-autoregressive, parallel decoding algorithm for sequence generation built upon conditional masked language models (CMLMs). Originally developed for neural machine translation, Mask-Predict combines the speed of fully parallel decoding with the ability to iteratively refine outputs, reducing the performance gap between non-autoregressive and autoregressive generation. The approach has subsequently been extended to other structured prediction tasks, notably automatic speech recognition (ASR), via integration with objectives such as connectionist temporal classification (CTC).
1. Conditional Masked Language Modeling for Sequence Generation
Mask-Predict relies on training a conditional masked language model to enable arbitrary, order-agnostic token prediction. Given a source sequence $X$ and a target sequence $Y = (y_1, \dots, y_N)$, a random subset $Y_{\text{mask}}$ of target positions is replaced with a special [MASK] token during training. The loss is computed over the masked positions only:
$$\mathcal{L}_{\text{CMLM}} = -\sum_{y_i \in Y_{\text{mask}}} \log P(y_i \mid X, Y_{\text{obs}}),$$
where $Y_{\text{obs}} = Y \setminus Y_{\text{mask}}$ contains the unmasked positions. This objective permits the model to fill in any subset of positions, conditioning on the source input and a partially observed target sequence (Ghazvininejad et al., 2019).
An auxiliary loss is commonly used to predict the target sequence length $N$, often via a special LENGTH token; however, the primary focus remains on the loss above.
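The training objective can be sketched in a few lines of Python. This is a minimal sketch, assuming string tokens, a hypothetical `<mask>` placeholder, and per-position dictionaries of log-probabilities standing in for the decoder's output distribution; the number of masked tokens is drawn uniformly from 1 to $N$, as in the original CMLM recipe.

```python
import math
import random
from typing import Dict, List, Sequence, Tuple

MASK = "<mask>"  # hypothetical placeholder for the special mask token


def mask_target(target: Sequence[str], rng: random.Random) -> Tuple[List[str], List[int]]:
    """Replace a random subset of target positions with MASK."""
    n = len(target)
    num_masked = rng.randint(1, n)  # uniform over 1..N (both ends inclusive)
    masked_positions = sorted(rng.sample(range(n), num_masked))
    corrupted = list(target)
    for i in masked_positions:
        corrupted[i] = MASK
    return corrupted, masked_positions


def masked_nll(
    log_probs: List[Dict[str, float]],
    target: Sequence[str],
    masked_positions: List[int],
) -> float:
    """CMLM loss: negative log-likelihood summed over masked positions only.

    log_probs[i] maps tokens to log P(token | X, Y_obs) at position i; here it is
    supplied directly instead of being computed by a real decoder.
    """
    return -sum(log_probs[i][target[i]] for i in masked_positions)


rng = random.Random(0)
target = ["the", "cat", "sat", "down"]
corrupted, positions = mask_target(target, rng)
print(corrupted, positions)
```

Unmasked positions contribute nothing to the loss; they only serve as observed context $Y_{\text{obs}}$.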
2. Mask-Predict Decoding Algorithm
Mask-Predict runs a fixed number $T$ of decoding iterations, each of which is a single parallel pass. The first pass predicts all tokens non-autoregressively; subsequent passes identify low-confidence tokens, mask them, and re-predict those positions in parallel:
- Initialization ($t = 0$): All $N$ positions are masked, and every token is predicted in parallel.
- Iteration ($t = 1, \dots, T-1$): The $n = N \cdot \frac{T - t}{T}$ least-confident tokens, as measured by the probability $p_i$ the model assigned to the current token at position $i$, are masked and re-predicted; the remaining observed tokens are left unchanged.
- Final output: After $T$ iterations, no masked tokens remain.
Each iteration thus combines confidence-based selection of the positions to mask with a number of masked tokens that decays linearly over iterations (Ghazvininejad et al., 2019, Ghazvininejad et al., 2020). Every step is fully parallel across positions, enabling efficient inference.
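The decoding loop can be sketched as follows. The `predict` callable is a hypothetical stand-in for a trained CMLM already conditioned on the source; it returns a (token, probability) pair for every position. Rounding the linear schedule down with integer division is an assumption of this sketch.

```python
from typing import Callable, List, Tuple

MASK = "<mask>"  # hypothetical mask token
# Hypothetical CMLM call (source already bound): one (token, probability) per position.
Predictor = Callable[[List[str]], List[Tuple[str, float]]]


def mask_predict(predict: Predictor, length: int, iterations: int) -> Tuple[List[str], List[float]]:
    # t = 0: every position is masked and predicted in parallel.
    preds = predict([MASK] * length)
    tokens = [tok for tok, _ in preds]
    probs = [p for _, p in preds]
    for t in range(1, iterations):
        # Linear decay: re-mask n = floor(N * (T - t) / T) tokens.
        n = length * (iterations - t) // iterations
        if n == 0:
            break  # the schedule only shrinks, so no later step masks anything
        # Select the n least-confident positions.
        masked = sorted(range(length), key=lambda i: probs[i])[:n]
        for i in masked:
            tokens[i] = MASK
        preds = predict(tokens)
        # Update only the re-masked positions; observed tokens keep their values and scores.
        for i in masked:
            tokens[i], probs[i] = preds[i]
    return tokens, probs
```

Because only the re-masked positions receive new scores, a token that survives an iteration keeps the confidence it was assigned when it was last predicted.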
3. Performance and Empirical Analysis
Mask-Predict strikes a balance between the efficiency of fully non-autoregressive inference and the quality of autoregressive models:
- With $T = 4$ iterations, Mask-Predict achieves 25.94 BLEU on WMT'14 En–De, outperforming all prior non-autoregressive models by more than 4 BLEU.
- With $T = 10$ iterations, performance rises to 27.03 BLEU, within about 0.7 BLEU of the left-to-right autoregressive Transformer baseline (27.74 BLEU).
- Decoding time is drastically reduced: with fewer iterations, Mask-Predict delivers a substantial speed-up over standard Transformers at the cost of about two BLEU points (Ghazvininejad et al., 2019).
These results generalize to other translation directions and corpora. The number of iterations $T$ and the number of candidate sequence lengths $\ell$ can be tuned to trade off speed against accuracy.
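In the original paper, each of the top $\ell$ predicted lengths is decoded with Mask-Predict, and the hypothesis with the highest average log-probability is kept. A minimal sketch of that selection step, assuming a hypothetical `decode(n)` callable that runs Mask-Predict at length `n` and returns tokens with their probabilities, and a length distribution given as a plain dictionary:

```python
import math
from typing import Callable, Dict, List, Tuple

# Hypothetical: runs Mask-Predict at a fixed length, returns tokens and probabilities.
Decoder = Callable[[int], Tuple[List[str], List[float]]]


def best_over_lengths(
    length_probs: Dict[int, float], decode: Decoder, num_candidates: int
) -> List[str]:
    # Keep the num_candidates (l) most probable target lengths.
    top_lengths = sorted(length_probs, key=lambda n: length_probs[n], reverse=True)[:num_candidates]
    best_tokens: List[str] = []
    best_score = -math.inf
    for n in top_lengths:
        tokens, probs = decode(n)  # in practice the candidates are decoded in parallel
        score = sum(math.log(p) for p in probs) / n  # average log-probability
        if score > best_score:
            best_tokens, best_score = tokens, score
    return best_tokens
```

Averaging rather than summing log-probabilities keeps longer candidates from being penalized simply for having more tokens.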
4. Advances and Training Enhancements: SMART and Semi-Autoregressive Methods
A limitation of the original CMLM training is its mismatch with inference: during training, the observed (unmasked) context always consists of ground-truth tokens, whereas at inference it consists of the model's own predictions. The SMART training method addresses this by introducing prediction-based contexts into training:
- SMART procedure: For each training example, the model first predicts the masked positions of a partially masked gold target; a new mask is then applied to this prediction-filled sequence, and the model is trained to recover the original target from the result, which exposes it to its own likely errors.
- SMART yields consistent BLEU improvements over standard non-autoregressive training (NART) with Mask-Predict decoding and, given enough iterations, substantially narrows the gap to autoregressive approaches on multiple benchmarks.
- Analyses indicate that re-predicting all tokens, rather than only the lowest-confidence ones, can further improve results, suggesting that repeatedly refining even seemingly correct tokens is beneficial (Ghazvininejad et al., 2020).
Because its training inputs contain injected model errors, SMART's two-pass recipe better matches the inference conditions of Mask-Predict, reducing train–test mismatch and further narrowing the gap with fully sequential decoding.
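The two-pass recipe can be sketched as a function that builds one training example. This is an illustration under stated assumptions: `predict` is a hypothetical model call returning one token per position, the same uniform masking scheme is reused for both passes, and positions left observed by the first mask keep their gold tokens. The returned target is the full gold sequence, so a loss can be taken over every position, including unmasked ones that may hold model errors.

```python
import random
from typing import Callable, List, Sequence, Tuple

MASK = "<mask>"  # hypothetical mask token
Predictor = Callable[[List[str]], List[str]]  # hypothetical: one token per position


def random_mask(tokens: Sequence[str], rng: random.Random) -> List[str]:
    """Mask a random subset of positions whose size is uniform over 1..N."""
    n = len(tokens)
    chosen = set(rng.sample(range(n), rng.randint(1, n)))
    return [MASK if i in chosen else tok for i, tok in enumerate(tokens)]


def smart_example(
    gold: Sequence[str], predict: Predictor, rng: random.Random
) -> Tuple[List[str], List[str]]:
    # Pass 1: mask the gold target and let the current model fill the masks.
    first_input = random_mask(gold, rng)
    predictions = predict(first_input)
    filled = [p if tok == MASK else tok for tok, p in zip(first_input, predictions)]
    # Pass 2: mask the prediction-filled sequence; the observed slots that
    # remain may now hold the model's own errors.
    second_input = random_mask(filled, rng)
    # The training target is always the original gold sequence.
    return second_input, list(gold)
```

In a real training loop, the first pass would typically run without gradient tracking, and only the second pass would be used for the update.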
5. Mask-Predict Variants: Non-Autoregressive Speech Recognition with Mask CTC
Mask-Predict has been adapted for automatic speech recognition (ASR) by integrating it with CTC, an alignment-free sequence loss:
- Joint optimization: The model jointly minimizes a CTC loss and a CMLM-style Mask-Predict loss, weighted by a tunable coefficient $\alpha \in [0, 1]$:
$$\mathcal{L} = \alpha \, \mathcal{L}_{\text{CTC}} + (1 - \alpha) \, \mathcal{L}_{\text{CMLM}}$$
- Inference: Generation begins with greedy CTC decoding, which fixes the length of the initial hypothesis. Tokens whose CTC confidence falls below a threshold are masked and then iteratively filled in by a Mask-Predict-style procedure over $K$ iterations, committing the most confident predictions first (easy-first).
- Architectural enhancements: The Conformer encoder, which integrates convolutional local modeling into a Transformer structure, further improves recognition performance compared to vanilla Transformers (Higuchi et al., 2020).
- Length prediction: An auxiliary dynamic length-prediction head enables the decoder to insert or delete tokens during iterative refinement, supporting more flexible sequence corrections.
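The inference steps can be sketched as below. Assumptions: frame-level outputs arrive as (symbol, probability) pairs along the greedy CTC path, each collapsed token's confidence is the maximum frame probability in its run, `predict` is a hypothetical CMLM decoder call, and each iteration fills ceil(remaining masks / remaining iterations) positions so that every mask is filled within $K$ iterations. Dynamic length prediction is omitted.

```python
from typing import Callable, List, Sequence, Tuple

BLANK = "<blank>"  # CTC blank symbol (name is an assumption)
MASK = "<mask>"    # hypothetical mask token
Predictor = Callable[[List[str]], List[Tuple[str, float]]]


def greedy_ctc(frames: Sequence[Tuple[str, float]]) -> Tuple[List[str], List[float]]:
    """Collapse repeated symbols, then drop blanks, keeping a confidence per token."""
    tokens: List[str] = []
    confs: List[float] = []
    prev = None
    for sym, p in frames:
        if sym != BLANK and sym != prev:
            tokens.append(sym)  # start of a new token run
            confs.append(p)
        elif sym != BLANK:
            confs[-1] = max(confs[-1], p)  # same run: keep the max frame probability
        prev = sym
    return tokens, confs


def mask_ctc_decode(
    frames: Sequence[Tuple[str, float]],
    predict: Predictor,
    threshold: float,
    iterations: int,
) -> List[str]:
    tokens, confs = greedy_ctc(frames)
    # Mask every token whose CTC confidence falls below the threshold.
    tokens = [MASK if c < threshold else t for t, c in zip(tokens, confs)]
    for k in range(iterations):
        masked = [i for i, t in enumerate(tokens) if t == MASK]
        if not masked:
            break
        preds = predict(tokens)
        # Easy-first: commit the most confident predictions; ceiling division
        # guarantees all masks are filled by the last iteration.
        num_fill = -(-len(masked) // (iterations - k))
        for i in sorted(masked, key=lambda j: preds[j][1], reverse=True)[:num_fill]:
            tokens[i] = preds[i][0]
    return tokens
```

High-confidence CTC tokens are never revisited in this sketch, so the decoder only has to resolve the uncertain positions.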
Empirical results demonstrate substantial gains:
- On WSJ eval92, Mask CTC reduces WER from 17.9% (CTC-only) to 12.1% after iterative refinement. Conformer-based Mask CTC with dynamic length prediction reaches 9.1% WER, matching or closely approaching autoregressive models while delivering a substantial real-time speed-up (Higuchi et al., 2020).
- Mask CTC generalizes favorably to end-to-end speech translation, with some configurations outperforming autoregressive baselines in BLEU.
6. Computational Complexity and Practical Considerations
Autoregressive left-to-right decoders require $N$ sequential steps, where $N$ is the target sequence length. Mask-Predict, in contrast, executes exactly $T$ full parallel decoding passes regardless of $N$, so the number of sequential decoder steps is constant with respect to sequence length (each pass still processes all $N$ positions). Empirical timing shows that Mask-Predict and its variants achieve substantial speed-ups over autoregressive systems in both MT and ASR, with minimal loss in prediction quality (Ghazvininejad et al., 2019, Higuchi et al., 2020).
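A small worked example of this arithmetic, counting sequential decoder calls for a target of length $N$ and the number of token updates under the linear masking schedule (floor rounding assumed; length candidates are ignored because they are decoded in parallel):

```python
from typing import Dict


def decoding_cost(n: int, t: int) -> Dict[str, int]:
    """Sequential decoder calls and token updates for a length-n target."""
    return {
        # Left-to-right decoding: one decoder call per generated token.
        "autoregressive_steps": n,
        # Mask-Predict: one parallel pass per iteration, independent of n.
        "mask_predict_passes": t,
        # Tokens (re)predicted under the schedule floor(n * (t - i) / t), i = 0..t-1.
        "mask_predict_updates": sum(n * (t - i) // t for i in range(t)),
    }


print(decoding_cost(20, 4))
```

For $N = 20$ and $T = 4$, the autoregressive decoder needs 20 sequential calls while Mask-Predict needs 4, updating 20 + 15 + 10 + 5 = 50 tokens in total.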
Tuning the number of iterations and the length candidate search is central to the speed/accuracy trade-off, and the modularity of Mask-Predict enables integration with enhanced training, new model architectures, and hybrid objectives.
7. Contributions, Extensions, and Applications
- Mask-Predict introduced CMLMs with a masked-LM loss to support arbitrary-order prediction.
- It established a simple, parallelizable, and iterative decoding algorithm bridging fully non-autoregressive and fully autoregressive methods.
- The approach provides empirical control over inference latency versus output quality, making it adaptable for deployment scenarios where speed or accuracy must be prioritized.
- Extensions such as SMART training, Conformer encoders, and dynamic length-prediction heads further improve the method and extend it beyond machine translation to tasks such as ASR and direct speech translation (Ghazvininejad et al., 2019, Ghazvininejad et al., 2020, Higuchi et al., 2020, Higuchi et al., 2020).
A key insight is that Mask-Predict, augmented with techniques that better align training and inference, achieves performance rivaling classical sequential decoders while exposing parallelism that benefits both research and production environments.