Scheduled Sampling in Sequence Prediction
- Scheduled sampling is a curriculum learning strategy that mitigates exposure bias by gradually transitioning from ground-truth tokens to the model’s own predictions during training.
- It employs various decay schedules, including linear, exponential, and inverse sigmoid, to control the mix of teacher forcing and self-sampling.
- Empirical results in tasks like image captioning, parsing, and acoustic modeling show that scheduled sampling significantly improves model robustness and accuracy.
Scheduled sampling is a curriculum learning strategy for sequence prediction in neural networks, devised to address the mismatch between training and inference conditions, commonly known as exposure bias. Standard autoregressive models (RNNs, Transformers, encoder–decoder architectures) are trained by maximizing the likelihood of the next token given the ground-truth prefix (teacher forcing). At inference, however, the model conditions on its own previously generated predictions, resulting in a distribution shift. Scheduled sampling gradually transitions the training regime from consuming strictly ground-truth inputs to consuming the model's own predictions, aiming to align training and inference dynamics and improve sequence model robustness (Bengio et al., 2015).
1. Motivation and Formulation: Addressing Exposure Bias
Exposure bias arises due to the difference between training (conditioned on the true history) and inference (conditioned on the model's own history). Formally, given a dataset of input–output pairs $(X, Y^*)$ with target sequence $Y^* = (y_1^*, \dots, y_T^*)$, the hidden state of an autoregressive model (e.g., RNN, LSTM) at step $t$ is $h_t = f_\theta(h_{t-1}, y_{t-1})$, where $y_{t-1}$ denotes the token fed as input at step $t$.
In standard teacher forcing, $y_{t-1} = y_{t-1}^*$, the ground-truth token. During inference, the network must instead run on its own prediction $\hat{y}_{t-1} \sim p_\theta(\cdot \mid h_{t-1})$, producing a risk of error accumulation not encountered during training (Bengio et al., 2015).
Scheduled sampling stochastically replaces some ground-truth tokens with model predictions at training time. At each step $t$, a Bernoulli coin with success probability $\epsilon_i$ (indexed by the training iteration $i$) is flipped: if heads, the true token $y_{t-1}^*$ is provided; otherwise, the model's own prediction $\hat{y}_{t-1}$ (greedy or sampled) is supplied. This probability is scheduled: it typically decays over training steps from 1 (pure teacher forcing) toward 0 (pure self-sampling) (Bengio et al., 2015). In the cross-entropy loss for each sequence, one thus optimizes with respect to a prefix that is a mixture of gold and generated tokens.
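Written out, the mixed-prefix objective takes the following form (a compact restatement in the notation above, not a verbatim equation from Bengio et al., 2015); gradients are usually taken only through the log-likelihood term, treating the sampled prefix as a constant:

```latex
\tilde{y}_{t-1} =
\begin{cases}
  y^{*}_{t-1} & \text{with probability } \epsilon_i \quad \text{(teacher forcing)} \\
  \hat{y}_{t-1} \sim p_\theta(\cdot \mid h_{t-1}) & \text{with probability } 1 - \epsilon_i \quad \text{(self-sampling)}
\end{cases}
\qquad
\mathcal{L}(\theta) = -\sum_{t=1}^{T} \log p_\theta\!\left(y^{*}_{t} \mid \tilde{y}_{<t},\, X\right).
```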
2. Scheduling Strategies and Algorithmic Variants
The choice of scheduling function for $\epsilon_i$ is critical. Canonical schedules from Bengio et al. (2015) include the following (a minimal implementation sketch follows the list):
- Linear Decay: $\epsilon_i = \max(\epsilon_{\min},\, k - c\,i)$, where $0 \le \epsilon_{\min} < 1$ is the minimum ground-truth probability and $k$, $c$ set the offset and slope of the decay.
- Exponential Decay: $\epsilon_i = k^{\,i}$, with $k < 1$.
- Inverse Sigmoid Decay: $\epsilon_i = k / \left(k + \exp(i/k)\right)$, with $k \ge 1$.
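A minimal Python sketch of the three schedules is given below; the default values of k, c, and eps_min are illustrative placeholders, not settings prescribed in the original paper.

```python
import math

def linear_schedule(i, k=1.0, c=1e-5, eps_min=0.1):
    """Linear decay: eps_i = max(eps_min, k - c*i)."""
    return max(eps_min, k - c * i)

def exponential_schedule(i, k=0.9999):
    """Exponential decay: eps_i = k**i, with k < 1."""
    return k ** i

def inverse_sigmoid_schedule(i, k=1000.0):
    """Inverse sigmoid decay: eps_i = k / (k + exp(i/k)), with k >= 1."""
    return k / (k + math.exp(i / k))
```

Each function maps the global training step i to the probability of feeding the ground-truth token at that step.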
Pseudocode for a stepwise scheduled sampling loop in vanilla RNNs:
```
for batch_idx, (X, Y*) in enumerate(training_data):
    i = global_step
    ε = schedule(i)                      # probability of feeding the gold token
    h = f_θ.init_state(X)
    loss = 0
    for t in 1..T:
        if uniform(0, 1) < ε:
            input_token = Y*[t-1]        # teacher forcing: ground-truth token
        else:
            logits = W·h + b
            input_token = argmax(logits) # or sample(logits): model's own prediction
        h = f_θ(h, input_token)
        loss += -log p_θ(Y*[t] | h)
    backpropagate(loss)
    global_step += 1
```
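For concreteness, the following is one possible runnable PyTorch counterpart of the pseudocode above, assuming a toy GRU decoder over integer token ids; the class name, hidden size, and the per-step coin flip shared across the batch are illustrative choices rather than a reference implementation.

```python
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScheduledSamplingDecoder(nn.Module):
    """Toy GRU decoder trained with scheduled sampling (illustrative sketch)."""

    def __init__(self, vocab_size, hidden_size=256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.cell = nn.GRUCell(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, targets, eps):
        """targets: (batch, T) gold token ids, column 0 acting as BOS; eps: teacher-forcing prob."""
        batch, T = targets.shape
        h = torch.zeros(batch, self.cell.hidden_size, device=targets.device)
        input_tok = targets[:, 0]
        loss = 0.0
        for t in range(1, T):
            h = self.cell(self.embed(input_tok), h)
            logits = self.out(h)
            loss = loss + F.cross_entropy(logits, targets[:, t])
            if random.random() < eps:          # coin flip, here shared across the batch
                input_tok = targets[:, t]      # feed the gold token
            else:
                input_tok = logits.argmax(-1)  # feed the model's greedy prediction
        return loss / (T - 1)
```

A training loop would call `loss = decoder(targets, schedule(global_step))` and backpropagate as usual; a per-example Bernoulli mask instead of a single scalar test per step is an equally valid design choice.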
Variants and extensions are required for non-RNN architectures:
- Transformer Models: Because self-attention exposes all prefix tokens simultaneously, two-pass decoding is necessary: the first pass processes the full gold sequence under teacher forcing; the second mixes gold embeddings and model-predicted embeddings (e.g., softmax or Gumbel-softmax mixtures) according to the schedule (Mihaylova et al., 2019); see the sketch after this list.
- Confidence-Aware Schedules: Instead of a fixed decay, the input at each position is selected from the model's per-token confidence: high-confidence positions are fed model samples or controlled noise, while low-confidence positions keep the gold token, enabling fine-grained adaptation (Liu et al., 2021).
- Dynamic Schedules: Curriculum decisions driven automatically by training accuracy, not training step count (Lin et al., 2023).
- Decoding-Step-Based Scheduling: The gold-token probability is varied over the decoding position $t$, simulating the empirically observed growth of error rates at later positions (Liu et al., 2021).
- Parallel Scheduled Sampling: Scheduled sampling across timesteps is parallelized with vectorized mixing and multiple passes, enabling accelerator-optimized implementations (Duckworth et al., 2019).
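The two-pass Transformer variant can be sketched as follows; `decoder`, `embed`, and `generator` are assumed interfaces (a causal decoder stack, its input embedding table, and the output projection), and masking and positional details are omitted, so this illustrates the mixing step rather than the exact procedure of Mihaylova et al. (2019).

```python
import torch

def two_pass_scheduled_sampling(decoder, embed, generator, gold_ids, memory, eps):
    """gold_ids: (batch, T) target ids; memory: encoder outputs; eps: prob. of keeping gold."""
    gold_emb = embed(gold_ids)                               # (batch, T, d_model)

    # Pass 1: ordinary teacher-forced decoding over the gold prefix.
    first_logits = generator(decoder(gold_emb, memory))      # (batch, T, vocab)

    # Build "model" embeddings from the first-pass softmax: an expected embedding,
    # one differentiable alternative to feeding hard argmax tokens. Detaching here
    # (first_logits.detach()) would block gradients through the first pass.
    probs = first_logits.softmax(dim=-1)
    model_emb = probs @ embed.weight                         # (batch, T, d_model)

    # Per-position coin flips decide which embedding the second pass sees.
    keep_gold = (torch.rand(gold_ids.shape, device=gold_ids.device) < eps).unsqueeze(-1)
    mixed_emb = torch.where(keep_gold, gold_emb, model_emb)

    # Pass 2: decode on the mixed prefix; the caller computes cross-entropy
    # against the gold targets from these logits as usual.
    return generator(decoder(mixed_emb, memory))
```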
3. Empirical Impact and Application Domains
Scheduled sampling has demonstrated substantial empirical gains in diverse sequence generation tasks. In the original work (Bengio et al., 2015):
- MSCOCO Image Captioning: BLEU-4 improves 28.8 → 30.6, METEOR 24.2 → 24.3, CIDEr 89.5 → 92.1. "Always sampling" (no ground truth) fails catastrophically.
- Constituency Parsing: F1 improves from 86.54 to 88.08, and to 88.68 when combined with dropout.
- Frame-Level Acoustic Modeling: Frame Error Rate decreases from 46.0% → 34.5%.
Subsequent work applied scheduled sampling to neural machine translation (Liu et al., 2021, Liu et al., 2021, Korakakis et al., 2021), speech recognition with RNNTs (Moriya et al., 2023), vision–language pretraining (Li et al., 2021), video captioning (Chen et al., 2019), and even denoising diffusion models, where it mitigates compound drift in the latent denoising trajectory (Deng et al., 2022).
Application-specific extensions have included:
- Dialogue generation: Action-Tree based scheduled sampling at the action sequence (rather than token) level using tree similarity, improving robustness to policy errors in dialogue systems (Liu et al., 28 Jan 2024).
- Bilevel scheduled sampling: Combines word-level model confidence and sentence-level quality metrics (BLEU, cosine similarity of embeddings) via a smooth mapping for per-token sampling decisions (Liu et al., 2023).
4. Algorithmic and Theoretical Considerations
While scheduled sampling empirically reduces exposure bias, several drawbacks and caveats have been identified:
- Impossibility of Pure Sampling: Training with only model predictions (ε=0) leads to convergence failures (Bengio et al., 2015).
- Non-differentiable Sampling: Vanilla scheduled sampling requires sampling or argmax, which is not differentiable with respect to the model parameters. Continuous relaxations (e.g., soft argmax, Gumbel-softmax, straight-through Gumbel estimators) enable gradient flow through the selection (Goyal et al., 2017, Xu et al., 2019); a minimal sketch appears after this list.
- Alignment Assumptions: Scheduled sampling assumes gold and generated prefixes are positionally aligned, which is inappropriate for tasks with flexible output order (e.g., translation); soft-alignment objectives (e.g., SAML) are superior in that setting (Xu et al., 2019).
- Catastrophic Forgetting: Scheduled sampling can cause the model to neglect probability mass on gold prefixes, degrading performance when the prefix at inference is correct. Adding an Elastic Weight Consolidation (EWC) penalty mitigates this by preserving parameters important for teacher-forced training (Korakakis et al., 2021).
- Improper Objective: The scheduled sampling objective is an improper scoring rule and can be inconsistent: its loss function may not be uniquely minimized by the true distribution P. In the limit of pure self-sampling ($\epsilon \to 0$), scheduled sampling encourages independence between consecutive tokens (a factorized sequence distribution), a pathology shown formally in (Huszár, 2015).
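As a sketch of the differentiable-selection idea referenced above, the snippet below uses PyTorch's straight-through Gumbel-softmax to build an embedding of the model's "prediction" that still passes gradients to the logits; it is in the spirit of Goyal et al. (2017) and Gumbel-based variants, not the exact formulation of any one paper.

```python
import torch
import torch.nn.functional as F

def soft_prediction_embedding(logits, embedding, tau=1.0, hard=True):
    """Differentiable stand-in for feeding the model's own predicted token.

    logits:    (batch, vocab) unnormalized scores at the current step
    embedding: nn.Embedding holding the input embedding table
    """
    # With hard=True, F.gumbel_softmax applies the straight-through estimator:
    # the forward pass is one-hot, the backward pass flows through the soft sample.
    y = F.gumbel_softmax(logits, tau=tau, hard=hard)   # (batch, vocab)
    return y @ embedding.weight                        # (batch, d_model)
```

The returned embedding can replace the argmax-based input lookup in the earlier pseudocode, at the cost of keeping the selection inside the computation graph.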
5. Methodological Extensions and Practical Implementations
Scheduled sampling is highly configurable and has motivated a range of practical enhancements:
- Scheduling by Model Competence: Switch sampling frequency dynamically based on real-time model confidence (Liu et al., 2021), observed generation accuracy (Lin et al., 2023), or position in the output sequence (Liu et al., 2021); a confidence-based selection sketch follows this list.
- Noisy and Denoising Schedules: Injecting controlled noise (random target tokens) at high-confidence positions prevents collapse to pure teacher-forcing (Liu et al., 2021, Liu et al., 2023).
- Action-Tree Sampling: For structured output spaces such as dialogue actions, scheduled sampling over similar negative policies (measured by tree edit distance) improves response robustness (Liu et al., 28 Jan 2024).
- Parallelization: Vectorized, multi-pass mixing avoids the sequential bottleneck of token-by-token sampling, yielding major speedups on modern hardware (Duckworth et al., 2019).
- Vision–Language Two-Pass Models: Replace [MASK] tokens with model-sampled tokens in the second pass (cf. BERT-style pretraining for multimodal models) (Li et al., 2021).
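A hypothetical sketch of confidence-based input selection is shown below; it follows the general idea of confidence-aware scheduled sampling (Liu et al., 2021), but the threshold, the teacher-forced first pass, and the treatment of mid-confidence positions are illustrative assumptions.

```python
import torch

def confidence_mixed_inputs(first_pass_logits, gold_ids, hi=0.9):
    """Choose per-token decoder inputs from a teacher-forced first pass.

    first_pass_logits: (batch, T, vocab) logits from a teacher-forced pass
    gold_ids:          (batch, T) ground-truth token ids
    """
    probs = first_pass_logits.softmax(dim=-1)
    conf = probs.gather(-1, gold_ids.unsqueeze(-1)).squeeze(-1)  # p(gold token | prefix)
    preds = probs.argmax(dim=-1)

    mixed = gold_ids.clone()
    confident = conf > hi
    mixed[confident] = preds[confident]   # confident positions get the model's own token
    # Low-confidence positions keep the gold token; mid-confidence positions could
    # instead receive sampled or noisy tokens, as in the noisy-schedule variants above.
    return mixed
```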
Typical hyperparameters—schedule shape, decay rate, minimum ground-truth probability, per-token confidence thresholds—must be tuned for each task and architecture. Validation on held-out data is standard.
6. Limitations, Critiques, and Theoretical Analyses
Critical examination of scheduled sampling reveals several fundamental limitations:
- Improper and Inconsistent Objective: Huszár (Huszár, 2015) demonstrates that scheduled sampling's implicit training objective is not proper; it does not guarantee convergence to the ground-truth data distribution even with infinite data, and tends to drive models to ignore sequential dependencies as grounding is removed.
- Alignment Pathologies: In tasks where reference and output sequences are not positionally aligned, scheduled sampling may yield misleading gradients. Soft-alignment or matched-sequence scoring objectives (e.g., SAML) directly address this (Xu et al., 2019).
- Missing Mode-Seeking Behavior: Scheduled sampling still fundamentally optimizes a likelihood-based (or KL[P||Q]-based) criterion, which does not guarantee high perceptual sample quality. Generalized adversarial objectives (e.g., the JS_π family of divergences) better align with perceptual quality metrics (Huszár, 2015).
- Sensitivity to Schedule and Hyperparameters: Empirical performance is highly sensitive to the probability schedule, minimum ground-truth floor, and, in advanced methods, confidence/quality combination strategies (Liu et al., 2021, Liu et al., 2021, Liu et al., 2023).
- Computational Overhead: Some variants (e.g., two-pass, dynamic, tree-based) increase wall-time by 50–100% per step, though this may be mitigated via parallelization strategies (Duckworth et al., 2019, Li et al., 2021).
7. Summary and Outlook
Scheduled sampling constitutes a widely adopted curriculum-based strategy to reduce the exposure bias endemic to sequence prediction under standard cross-entropy (teacher-forced) training. Variants now cover confidence-driven selection, position-aware schedules, parallelized implementations, and structured-output (action-tree) sampling. These methods have shown consistent improvements in language modeling, speech recognition, multimodal generation, and structured prediction benchmarks, often with double-digit relative error reductions or BLEU improvements (Bengio et al., 2015, Liu et al., 2021, Liu et al., 28 Jan 2024, Liu et al., 2021). Scheduled sampling is not a panacea: it admits theoretical pathologies, can be inconsistent, and is best viewed as a practical curriculum device rather than a principled surrogate for the intractable true sample generation objective. Ongoing research aims to merge its practical strengths with more theoretically consistent learning algorithms, particularly in domains requiring long-range structural dependencies or perceptual sample fidelity.