AXE Loss: Aligned Cross Entropy in Non-Autoregressive (NA) Models
- AXE Loss is a training criterion that uses dynamic programming for monotonic alignment in non-autoregressive models to reward lexical accuracy despite token misplacements.
- It improves performance in machine translation and speech recognition by mitigating harsh penalties for minor misalignments, yielding notable gains in BLEU and WER metrics.
- AXE employs a structured alignment mechanism with a blank token for unaligned predictions, offering an effective and robust alternative to standard cross entropy.
Aligned Cross Entropy (AXE) Loss is a training criterion for non-autoregressive sequence generation models, designed to address the issues posed by position-sensitive cross entropy in settings where predictions of target tokens are made in parallel. By leveraging a differentiable dynamic programming alignment, AXE loss aligns model outputs to targets according to lexical correctness and relative order, mitigating harsh penalties for minor positional misalignments. This refinement establishes AXE as an effective alternative to standard cross entropy, delivering substantial gains on machine translation and speech recognition tasks.
1. Motivation and Conceptual Foundation
Traditional cross entropy loss enforces a rigid position-by-position correspondence between model predictions and ground truth sequences. In non-autoregressive models, which predict whole sequences in parallel rather than token by token, even slight misalignments, such as a shifted or inserted token, can incur substantial loss. This is especially problematic for models like Conditional Masked Language Models (CMLMs), where autoregressive conditioning is absent and word order variability is pronounced. AXE loss was introduced to soften this strictness by rewarding models for producing the right lexical items regardless of their exact placement, so long as the alignment remains monotonic. In this framework, partial credit is preserved for position-shifted but correct predictions, providing a more meaningful gradient for training parallel decoders (Ghazvininejad et al., 2020, Zhang et al., 2023).
2. Formal Mechanism and Dynamic Programming Alignment
AXE operates by determining the best monotonic alignment between the gold target sequence $Y = (y_1, \ldots, y_n)$ and the sequence of predicted token distributions $P = (P_1, \ldots, P_m)$. Alignment is formalized by the mapping $\alpha: \{1, \ldots, n\} \to \{1, \ldots, m\}$, such that $\alpha$ is non-decreasing ($\alpha(i) \le \alpha(j)$ for $i < j$).
Given a particular alignment $\alpha$, the conditional AXE loss is:

$$\mathrm{AXE}(Y, P \mid \alpha) = -\sum_{i=1}^{n} \log P_{\alpha(i)}(y_i) \;-\; \sum_{k \notin \alpha(\{1,\ldots,n\})} \log P_k(\varepsilon),$$

where $\varepsilon$ is a special blank token that penalizes unaligned model predictions.
The final AXE loss is the minimum over all possible monotonic alignments:

$$\mathrm{AXE}(Y, P) = \min_{\alpha} \mathrm{AXE}(Y, P \mid \alpha).$$
This minimization is performed efficiently via dynamic programming over a table $A_{i,k}$, the minimal partial loss of aligning the first $i$ target tokens with the first $k$ predictions, using three primary local update operators:
- Align: $A_{i,k} = A_{i-1,\,k-1} - \log P_k(y_i)$ (aligns $y_i$ to $P_k$)
- Skip Prediction: $A_{i,k} = A_{i,\,k-1} - \log P_k(\varepsilon)$ (skips $P_k$)
- Skip Target: $A_{i,k} = A_{i-1,\,k} - \delta \log P_k(y_i)$ (skips $y_i$, with extra penalty $\delta$)
This algorithm allows realignment of sequences during training, offering tolerance to position shifts and structural variations (Ghazvininejad et al., 2020, Zhang et al., 2023).
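To make the recurrences concrete, the following is a minimal NumPy sketch of the dynamic program described above, not the reference implementation from the cited papers. The function name `axe_loss`, the `(m, V)` layout of the log-probability matrix, the `blank_id` argument, and the default value of `delta` are illustrative assumptions; only the three update operators follow the formulation above.

```python
import numpy as np

def axe_loss(log_probs, target, blank_id, delta=1.0):
    """Minimal AXE sketch. log_probs: (m, V) array of per-position
    log-probabilities; target: list of n token ids; blank_id: index of
    the blank token; delta: skip-target penalty hyperparameter."""
    m, _ = log_probs.shape          # number of prediction positions
    n = len(target)                 # number of target tokens
    INF = float("inf")

    # A[i, k]: minimal loss of aligning the first i target tokens
    # with the first k predictions.
    A = np.full((n + 1, m + 1), INF)
    A[0, 0] = 0.0
    # With no target tokens consumed, only "skip prediction" applies.
    for k in range(1, m + 1):
        A[0, k] = A[0, k - 1] - log_probs[k - 1, blank_id]

    for i in range(1, n + 1):
        for k in range(1, m + 1):
            tok = target[i - 1]
            align = A[i - 1, k - 1] - log_probs[k - 1, tok]         # align y_i to P_k
            skip_pred = A[i, k - 1] - log_probs[k - 1, blank_id]    # P_k emits the blank
            skip_tgt = A[i - 1, k] - delta * log_probs[k - 1, tok]  # skip y_i, penalty delta
            A[i, k] = min(align, skip_pred, skip_tgt)

    return A[n, m]
```

The table has $(n+1)\times(m+1)$ entries, each filled in constant time, so computing the loss costs $O(nm)$ per sequence pair; during training, gradients flow through the log-probability terms selected by the minimizing alignment.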
3. Empirical Performance and Impact on Non-Autoregressive Models
The application of AXE to CMLMs in non-autoregressive machine translation leads to marked improvements on major benchmarks. For example, on WMT’14 English–German, CMLMs trained with AXE outperform those trained with standard cross entropy by approximately 5 BLEU points (Ghazvininejad et al., 2020). The dynamic alignment method enhances “position confidence,” resulting in sharply peaked output probability distributions that suppress spurious multimodal hypotheses and substantially decrease token repetition rates—reported to drop by a factor greater than ten.
In speech recognition domains, AXE also demonstrates efficacy. Experiments on the WSJ dataset show significant WER reduction when substituting cross entropy with AXE in Mask CTC models. The effect is further enhanced by the “dynamic rectification” method, which simulates inference-time conditions by introducing challenging sample variations during training (Zhang et al., 2023).
4. Comparative Analysis to Standard and Alternative Loss Functions
AXE loss distinguishes itself from standard cross entropy and other sequence alignment criteria, such as Connectionist Temporal Classification (CTC) and latent-variable approaches, in several respects:
| Loss Function | Alignment Flexibility | Position Sensitivity | Reported Benchmark Performance |
|---|---|---|---|
| Cross Entropy | Strict, absolute | High | Baseline |
| CTC | Allows flexible alignment | Moderate | Moderate improvement |
| AXE | Best monotonic alignment | Low (lexical focus) | State of the art for NA models |
AXE advances the state of the art among purely non-autoregressive models, as demonstrated by superior BLEU and WER metrics in experimental studies. Ablation trials reveal the benefit of hyperparameter tuning (e.g., the skip-target penalty $\delta$) and consistent outperformance compared to both semi-autoregressive and hint-based approaches (Ghazvininejad et al., 2020, Zhang et al., 2023).
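As a concrete illustration of the comparison above, the toy snippet below contrasts position-wise cross entropy with AXE on a prediction that is shifted right by one position. It reuses the hypothetical `axe_loss` sketch from Section 2; the vocabulary ids, the 0.9-peaked toy distributions, and the convention of padding the cross-entropy target with the blank token are illustrative assumptions, not values from the cited papers.

```python
import numpy as np

A_, B_, C_, BLANK = 0, 1, 2, 3   # toy vocabulary ids
V = 4

def one_hotish(token_id, peak=0.9):
    """Log-probs that put `peak` mass on token_id, the rest uniform."""
    p = np.full(V, (1.0 - peak) / (V - 1))
    p[token_id] = peak
    return np.log(p)

def positional_ce(log_probs, padded_target):
    """Standard cross entropy: position k must predict padded_target[k]."""
    return -sum(log_probs[k, t] for k, t in enumerate(padded_target))

target = [A_, B_, C_]
# Model output shifted right by one position: blank, a, b, c.
shifted = np.stack([one_hotish(t) for t in (BLANK, A_, B_, C_)])

print(positional_ce(shifted, target + [BLANK]))   # large: every position counted as wrong
print(axe_loss(shifted, target, blank_id=BLANK))  # small: the alignment absorbs the shift
```

Under these toy distributions, cross entropy penalizes all four positions even though every target token was predicted, whereas AXE skips the leading blank prediction and aligns the remaining tokens, yielding a loss close to zero.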
5. Extension: Dynamic Rectification and Robust Training Pipelines
A key adjunct to AXE is the dynamic rectification technique, particularly relevant for Mask CTC models in speech recognition (Zhang et al., 2023). During training, after masking parts of the ground truth input, preliminary predictions are generated, and high-confidence predictions are selectively re-masked to produce more realistic training inputs. This method models the errors encountered during greedy CTC inference and narrows the gap between training and deployment, leading to further improvements in recognition accuracy.
6. Connections to Structured Entropy and Other Generalized Losses
AXE shares conceptual similarities with structured entropy loss functions, which also seek to soften the “strict error” property of cross entropy by considering inherent structure or similarity between targets (Lucena, 2022). Whereas structured entropy averages cross entropy over partitions (thus encoding prior knowledge about class similarity), AXE focuses on monotonic sequence alignment. Both methods aim to reward errors that are closer in meaning or structure to the gold reference, though AXE is tailored to sequence generation and dynamic temporal alignment rather than fixed-label structured classification.
7. Applications and Future Research Directions
Beyond non-autoregressive machine translation and speech recognition, AXE loss is applicable to conditional sequence generation tasks where relative token order is of greater interest than absolute positional alignment—such as text summarization or morphological analysis. Future research may explore optimal integration of AXE with auxiliary signals, modifications of the alignment mechanism, and extensions to alternative model architectures and decoding strategies (Ghazvininejad et al., 2020, Zhang et al., 2023).
Possible limitations of AXE include increased computational overhead from dynamic programming alignment during training and the complexity introduced by rectification methods. However, empirical results suggest these costs are justified by the resultant improvements in lexical accuracy, fluency, and robustness.
In summary, AXE loss offers a theoretically principled and empirically validated refinement to standard training losses for non-autoregressive sequence models, leveraging monotonic alignment to address the challenges of position-sensitive criteria while facilitating state-of-the-art performance across diverse generation tasks.