- The paper introduces scheduled sampling as a curriculum learning strategy that mitigates the training-inference gap in RNN-based sequence prediction tasks.
- It employs a decay schedule to gradually shift from using true tokens to model predictions during training, significantly improving metrics in tasks like image captioning and constituency parsing.
- Experimental results show that scheduled sampling enhances performance without extra training time, demonstrating its potential for real-world applications such as machine translation and speech recognition.
Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks
Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks presents a novel curriculum learning strategy to address the discrepancy between training and inference in sequence prediction tasks using Recurrent Neural Networks (RNNs). During training the model is conditioned on the true previous tokens, but during inference it must generate each token from its own previous predictions; this mismatch allows early mistakes to compound into cascading errors.
Introduction
RNNs, particularly Long Short-Term Memory (LSTM) networks, have shown impressive results in sequence prediction tasks such as machine translation and image captioning. Typically, these models are trained to predict each token in a sequence by maximizing the likelihood given the current state and the previous true token. However, during inference, true previous tokens are not available and are replaced by tokens generated by the model itself, introducing a training-inference mismatch that can lead to cascading errors.
Proposed Approach
The proposed method, termed Scheduled Sampling, aims to incrementally bridge the gap between training and inference. The idea is to gradually transition the training process from always using the true previous token to using the model's own predictions. This is achieved by flipping a coin at each time step during training to decide whether the input should be the true previous token or a token sampled from the model's output at the previous step. The probability of selecting the true token, denoted ϵ_i (where i indexes the training batch), is decreased over the course of training according to a predefined decay schedule.
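To make the per-step coin flip concrete, here is a minimal, self-contained Python sketch of how the inputs for one training sequence could be assembled; `predict_next`, the start token `bos`, and the token IDs are illustrative stand-ins, not the paper's implementation (the paper samples the fed-back token from the model's output distribution, which `predict_next` abstracts away).

```python
import random

def scheduled_inputs(true_sequence, predict_next, epsilon_i, bos=0):
    """Assemble decoder inputs for one training sequence: at each time step,
    feed the true previous token with probability epsilon_i, otherwise feed
    the token the model itself produced at the previous step."""
    inputs = [bos]                                  # first input is a start token
    for true_prev in true_sequence[:-1]:
        model_prev = predict_next(inputs[-1])       # stand-in for one RNN step
        use_truth = random.random() < epsilon_i     # the per-step coin flip
        inputs.append(true_prev if use_truth else model_prev)
    return inputs

# Toy usage with a dummy "model" that always predicts token 0.
print(scheduled_inputs([5, 8, 2, 9], predict_next=lambda prev: 0, epsilon_i=0.75))
```

In a full training loop, ϵ_i would be recomputed from the chosen decay schedule at each batch i.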
Examples of decay schedules include (sketched in code after this list):
- Linear decay: ϵ_i = max(ϵ, k − c·i), where 0 ≤ ϵ < 1 is the minimum probability of using the true token and k, c set the offset and slope of the decay
- Exponential decay: ϵ_i = k^i, with k < 1
- Inverse sigmoid decay: ϵ_i = k / (k + exp(i/k)), with k ≥ 1
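A minimal sketch of the three schedules as plain functions of the training batch index i; the hyperparameter values below are illustrative defaults, not values reported in the paper.

```python
import math

def linear_decay(i, k=1.0, c=1e-4, eps_min=0.05):
    """epsilon_i = max(eps_min, k - c*i): decreases linearly, with a floor
    eps_min giving the minimum probability of feeding the true token."""
    return max(eps_min, k - c * i)

def exponential_decay(i, k=0.999):
    """epsilon_i = k**i, with k < 1, so the probability shrinks geometrically."""
    return k ** i

def inverse_sigmoid_decay(i, k=1000.0):
    """epsilon_i = k / (k + exp(i/k)), with k >= 1: stays near 1 early in
    training, then falls off smoothly."""
    return k / (k + math.exp(i / k))
```

With these illustrative settings, for example, inverse_sigmoid_decay(0) ≈ 0.999 while inverse_sigmoid_decay(10000) ≈ 0.04, so early batches rely almost entirely on ground-truth tokens and later batches mostly on the model's own predictions.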
The paper situates its contribution within prior work on training-inference discrepancies in sequential tasks, most notably SEARN, which iteratively retrains a model on examples generated from a mixture of the reference policy and the model's current policy. Unlike SEARN, which operates in batch mode, the proposed approach is online and computationally cheaper. Earlier parsing work has also used beam search during training to align training and inference, but beam search is difficult to apply to RNNs because their hidden state is continuous, so competing partial hypotheses cannot easily be merged.
Experimental Results
Three different tasks were used to evaluate the effectiveness of Scheduled Sampling: image captioning, constituency parsing, and speech recognition.
- Image Captioning:
- Dataset: MSCOCO
- Model: LSTM with 512 hidden units
- Metrics: BLEU-4, METEOR, CIDEr
- Results showed significant improvements across all metrics compared to the baseline. Scheduled Sampling also outperformed a uniform-sampling variant that feeds back model predictions with a fixed probability, and the method was part of the winning entry in the 2015 MSCOCO image captioning challenge.
- Constituency Parsing:
- Dataset: WSJ (evaluated on section 22)
- Model: LSTM with Attention Mechanism
- Metric: F1 Score
- Scheduled Sampling's gains were additive with dropout, improving F1 from 86.54 for the baseline LSTM to 88.68 when scheduled sampling and dropout were combined.
- Speech Recognition:
- Dataset: TIMIT
- Model: LSTM with 250 cells
- Metrics: Frame Error Rate (FER) for Next Step Prediction and Decoding
- The method achieved a lower decoding FER, while the baseline retained a better next-step FER, as expected: next-step evaluation feeds the true previous token, which matches the baseline's training condition.
Conclusion
Scheduled Sampling successfully mitigates the training-inference discrepancy in sequence prediction tasks with RNNs by introducing a curriculum learning-based approach. This methodology demonstrated empirical improvements across diverse tasks without incurring additional training time. Future directions include integrating back-propagation through the sampling decisions and exploring confidence-based sampling strategies.
Implications and Future Directions
The practical implications of this research are substantial, offering a robust method to enhance the performance of RNNs in real-world applications such as machine translation, image captioning, and speech recognition. Theoretically, it opens new avenues for addressing the divergence between training and inference distributions in sequential models. Future advancements may target more sophisticated scheduling and sampling mechanisms, leveraging model confidence to further tailor the training regime.