Anticipated Reweighted Truncated BPTT
- The paper introduces ARTBP, achieving unbiased gradient estimates for RNNs by randomizing truncation points and applying compensation factors during backpropagation.
- It maintains the computational efficiency and memory advantages of truncated BPTT while reliably capturing long-term dependencies in sequential data.
- Empirical results on synthetic tasks and language modeling demonstrate ARTBP’s improved convergence and stability over traditional truncated BPTT.
Anticipated Reweighted Truncated Backpropagation (ARTBP) is a stochastic gradient estimation algorithm for training recurrent neural networks (RNNs) on long sequences. It preserves the computational and memory efficiency of truncated Backpropagation Through Time (BPTT) while providing unbiased gradient estimates, thereby enabling reliable learning of long-term dependencies. ARTBP achieves unbiasedness by randomizing truncation points and applying compensation factors within the backward recursion. The method was introduced by Tallec and Ollivier to address the convergence issues associated with the biased gradients of truncated BPTT (Tallec & Ollivier, 2017).
1. Background: BPTT and Truncated BPTT
Standard Backpropagation Through Time (BPTT) computes gradients for RNNs by unrolling the entire recurrent system and backpropagating gradients through every timestep. For a system defined by $s_{t+1} = F(s_t, x_{t+1}, \theta)$ and stepwise loss $\ell_t = \ell(s_t, y_t)$, the total sequence loss is $\mathcal{L} = \sum_{t=1}^{T} \ell_t$. The exact gradient $\partial \mathcal{L} / \partial \theta$ requires storing all recurrent states and backpropagating through all $T$ unrolled layers, resulting in $O(T)$ space and $O(T)$ compute requirements.
Truncated BPTT alleviates this cost by dividing the full sequence into consecutive fixed-length blocks of length $L$. Gradients are computed only within each block, and the recurrence graph is cut at block boundaries: no gradient signals propagate across those boundaries. The backward recursion for the approximate signals simply sets the recurrent term to zero every $L$ steps. While this provides $O(L)$ memory and $O(L)$ update time, the resulting estimator is biased: dependencies across block boundaries are omitted, resulting in unreliable learning of long-term dependencies and possible divergence during training.
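In the notation above, the exact backward recursion that full BPTT computes is

$$
\delta_t = \frac{\partial \ell_t}{\partial s_t} + \delta_{t+1}\,\frac{\partial F}{\partial s}(s_t, x_{t+1}, \theta), \qquad \delta_T = \frac{\partial \ell_T}{\partial s_T}, \qquad \frac{\partial \mathcal{L}}{\partial \theta} = \sum_{t=1}^{T} \delta_t\,\frac{\partial F}{\partial \theta}(s_{t-1}, x_t, \theta),
$$

and truncated BPTT replaces the carried term $\delta_{t+1}\,\partial F/\partial s$ with zero at every block boundary.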
2. ARTBP Formalism
ARTBP removes the bias in truncated BPTT by introducing randomness in the truncation points and adjusting the backpropagation equations with compensation factors.
Notation and Recursion
- Let $X_{t+1} \in \{0, 1\}$ indicate whether a truncation occurs between timesteps $t$ and $t+1$.
- $c_{t+1} = P(X_{t+1} = 1)$ is the (possibly time-varying) truncation probability.
The backward recursion for the adjusted signals $\tilde{\delta}_t$ is

$$
\tilde{\delta}_T = \frac{\partial \ell_T}{\partial s_T}, \qquad \tilde{\delta}_t = \frac{\partial \ell_t}{\partial s_t} + \frac{1 - X_{t+1}}{1 - c_{t+1}}\,\tilde{\delta}_{t+1}\,\frac{\partial F}{\partial s}(s_t, x_{t+1}, \theta),
$$

so the continuation of the gradient is dropped when $X_{t+1} = 1$ and rescaled by $1/(1 - c_{t+1})$ otherwise. The unbiased gradient estimator is then

$$
\tilde{g} = \sum_{t=1}^{T} \tilde{\delta}_t\,\frac{\partial F}{\partial \theta}(s_{t-1}, x_t, \theta).
$$
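A minimal NumPy sketch of this recursion, assuming the per-step loss gradients and Jacobians have already been computed (the function name, argument layout, and shapes are illustrative, not from the paper):

```python
import numpy as np

def artbp_backward(dl_ds, dF_ds, dF_dtheta, X, c):
    """Compensated ARTBP backward pass.

    dl_ds:     length-T list; dl_ds[t] is the gradient of the step-t loss w.r.t. s_t (shape [d])
    dF_ds:     length-(T-1) list; dF_ds[t] is the Jacobian of s_{t+1} w.r.t. s_t (shape [d, d])
    dF_dtheta: length-T list; dF_dtheta[t] is the local Jacobian of s_t w.r.t. theta (shape [d, p])
    X:         length-(T-1) list; X[t] = 1 if a truncation occurs between t and t+1
    c:         length-(T-1) list; c[t] is the probability of that truncation
    Returns the unbiased gradient estimate (shape [p]).
    """
    T = len(dl_ds)
    g = np.zeros(dF_dtheta[0].shape[1])
    delta = dl_ds[T - 1]                               # compensated signal at the final step
    g += delta @ dF_dtheta[T - 1]
    for t in range(T - 2, -1, -1):                     # walk backward in time
        # drop the continuation (X=1) or rescale it by 1/(1-c) (X=0)
        carry = 0.0 if X[t] else (delta @ dF_ds[t]) / (1.0 - c[t])
        delta = dl_ds[t] + carry
        g += delta @ dF_dtheta[t]
    return g
```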
Truncation Distribution
- Geometric: A constant $c_t = c$ yields geometrically (approximately exponentially) distributed block lengths with mean $1/c$.
- Heavy-tailed: For variance control, $c_{\delta t}$ is made to depend on $\delta t$, the time since the last truncation, decaying roughly like $\alpha / \delta t$ for large $\delta t$ with $\alpha > 1$, and scaled so that the mean block length matches a target $L$. Block lengths then have polynomially decaying tails, and the compensation factors $\prod 1/(1 - c_t)$ grow only polynomially in the block length rather than exponentially; a sketch of both schedules follows below.
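A sketch of the two schedules and an on-the-fly length sampler, assuming a simple functional form for the heavy-tailed hazard (the formula below is an illustrative choice, not necessarily the one used in the paper):

```python
import numpy as np

rng = np.random.default_rng(0)

def c_geometric(dt, L):
    """Constant truncation probability: geometric lengths with mean roughly L."""
    return 1.0 / L

def c_heavy_tailed(dt, L, alpha=2.0):
    """Illustrative heavy-tailed schedule (assumed form): the hazard decays like
    alpha/dt for large dt, giving polynomial tails and a mean length near L."""
    return min(1.0, alpha / (dt + (alpha - 1.0) * L))

def sample_segment_length(c_fn, L, max_len=100_000):
    """Draw per-step Bernoulli truncation decisions until one fires."""
    dt = 1
    while dt < max_len and rng.random() >= c_fn(dt, L):
        dt += 1
    return dt

for c_fn in (c_geometric, c_heavy_tailed):
    lengths = [sample_segment_length(c_fn, L=16) for _ in range(10_000)]
    print(c_fn.__name__, "empirical mean length:", np.mean(lengths))
```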
3. ARTBP Algorithm and Computational Properties
The typical ARTBP workflow includes:
- Sampling subsequence lengths on the fly via per-step Bernoulli draws with probability $c_t$.
- Forward pass computing the states $s_t$ and stepwise losses $\ell_t$ over the chosen segment.
- Backward pass computing the compensated signals $\tilde{\delta}_t$, with the $1/(1 - c_t)$ factors applied as above.
- Gradient accumulation: $\tilde{g} \leftarrow \tilde{g} + \tilde{\delta}_t\,\partial F/\partial \theta(s_{t-1}, x_t, \theta)$ for each $t$ in the segment.
- Parameter update: $\theta \leftarrow \theta - \eta\,\tilde{g}$ (a full sketch of this loop follows below).
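A compact end-to-end sketch of this workflow for a small tanh RNN trained online with plain SGD (the model, data, loss, learning rate, and truncation schedule below are illustrative assumptions, not the paper's experimental setup):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy model: s_{t+1} = tanh(W s_t + U x_{t+1} + b), stepwise loss 0.5 * ||s_t - y_t||^2.
d, dx, T = 8, 4, 5000
W = 0.1 * rng.standard_normal((d, d))
U = 0.1 * rng.standard_normal((d, dx))
b = np.zeros(d)
xs = rng.standard_normal((T, dx))
ys = rng.standard_normal((T, d))
lr, L = 0.01, 16

def c_fn(dt):
    """Truncation probability given the time since the last cut (assumed heavy-tailed form)."""
    return min(1.0, 2.0 / (dt + L))

s = np.zeros(d)
seg_states, seg_x, seg_y, seg_c = [s], [], [], []     # buffers for the current segment
total_loss = 0.0
for t in range(T):
    s = np.tanh(W @ seg_states[-1] + U @ xs[t] + b)   # forward step
    total_loss += 0.5 * np.sum((s - ys[t]) ** 2)
    seg_states.append(s)
    seg_x.append(xs[t])
    seg_y.append(ys[t])
    seg_c.append(c_fn(len(seg_x)))                    # probability used at this boundary
    if rng.random() < seg_c[-1] or t == T - 1:        # Bernoulli truncation draw
        # Compensated backward pass over the buffered segment.
        gW, gU, gb = np.zeros_like(W), np.zeros_like(U), np.zeros_like(b)
        delta = np.zeros(d)                           # the carry beyond the cut is dropped
        for k in range(len(seg_x) - 1, -1, -1):
            sk, sk_prev = seg_states[k + 1], seg_states[k]
            delta = (sk - seg_y[k]) + delta           # local loss gradient plus carry
            dpre = delta * (1.0 - sk ** 2)            # backprop through tanh
            gW += np.outer(dpre, sk_prev)
            gU += np.outer(dpre, seg_x[k])
            gb += dpre
            if k > 0:                                 # rescale the carry across an uncut boundary
                delta = (dpre @ W) / (1.0 - seg_c[k - 1])
        W -= lr * gW                                  # SGD parameter update
        U -= lr * gU
        b -= lr * gb
        seg_states, seg_x, seg_y, seg_c = [s], [], [], []   # start a new segment from the current state

print("mean loss per step:", total_loss / T)
```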
| Method | Space per Update | Time per Update | Unbiased? |
|---|---|---|---|
| BPTT (full) | $O(T)$ | $O(T)$ | Yes |
| Truncated BPTT | $O(L)$ | $O(L)$ | No |
| ARTBP | $O(\ell)$ (mean $L$) | $O(\ell)$ (mean $L$) | Yes |

Here $\ell$ denotes the random subsequence length, with mean $L$. The overhead incurred by sampling truncation points and rescaling by $1/(1 - c_t)$ is negligible.
4. Theoretical Unbiasedness
Tallec & Ollivier prove that

$$
\mathbb{E}_{X}\big[\tilde{g}\big] = \frac{\partial \mathcal{L}}{\partial \theta},
$$

where $\mathbb{E}_{X}$ denotes the expectation over the stochastic truncation schedule. The proof uses backward induction, showing that the expected backward signal satisfies $\mathbb{E}_{X}[\tilde{\delta}_t] = \partial \mathcal{L} / \partial s_t$, the exact BPTT signal. At each time $t$, the continuation of the gradient is either dropped with probability $c_{t+1}$ or rescaled by $1/(1 - c_{t+1})$ with probability $1 - c_{t+1}$, so that $\mathbb{E}\big[(1 - X_{t+1})/(1 - c_{t+1})\big] = 1$ and each step contributes exactly one continuation in expectation, matching the BPTT recursion. No structural assumptions are required beyond $c_t < 1$ wherever continuation is possible, together with the Markov property of the truncation process.
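This identity can be checked numerically. The sketch below uses an illustrative scalar linear RNN with a constant truncation probability (all names, sizes, and values are assumptions for the demonstration): it averages the ARTBP estimate over many sampled truncation schedules and compares the result with the exact BPTT gradient.

```python
import numpy as np

rng = np.random.default_rng(1)

# Tiny scalar RNN: s_{t+1} = a * s_t + x_{t+1}, loss sum_t 0.5 * (s_t - y_t)^2.
T, a, c = 20, 0.9, 0.2                                  # constant (geometric) truncation probability
xs, ys = rng.standard_normal(T), rng.standard_normal(T)

def forward(a):
    s, states = 0.0, []
    for x in xs:
        s = a * s + x
        states.append(s)
    return np.array(states)

def full_bptt_grad(a):
    """Exact gradient of the total loss w.r.t. a via full BPTT."""
    states = forward(a)
    g, delta = 0.0, 0.0
    for t in range(T - 1, -1, -1):
        delta = (states[t] - ys[t]) + a * delta         # exact backward recursion
        g += delta * (states[t - 1] if t > 0 else 0.0)  # local derivative of s_t w.r.t. a
    return g

def artbp_grad(a):
    """One ARTBP estimate with randomly sampled truncations and 1/(1-c) compensation."""
    states = forward(a)
    X = rng.random(T - 1) < c                           # truncation indicators between steps
    g, delta = 0.0, 0.0
    for t in range(T - 1, -1, -1):
        delta = (states[t] - ys[t]) + a * delta
        g += delta * (states[t - 1] if t > 0 else 0.0)
        if t > 0:
            delta = 0.0 if X[t - 1] else delta / (1.0 - c)   # drop or compensate the carry
    return g

exact = full_bptt_grad(a)
estimates = [artbp_grad(a) for _ in range(100_000)]
print("exact BPTT gradient:    ", exact)
print("mean of ARTBP estimates:", np.mean(estimates))   # agrees up to Monte Carlo error
```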
5. Empirical Evaluations
Influence-Balancing Synthetic Task
- Setup: A linear chain of positive and negative agents whose signals reach the loss only after a delay.
- Methods: Truncated BPTT with several fixed truncation lengths; ARTBP with a matched mean truncation length and its compensated stochastic schedule.
- Findings:
- Truncated BPTT diverges for the smaller truncation lengths due to bias; only the largest converges, and only slowly.
- ARTBP converges reliably across all random seeds.
- Unbiasedness is necessary to balance multi-scale temporal dependencies.
Penn Treebank Character-Level Language Modeling
- Model: Single-layer LSTM, batch size 64, Adam optimizer.
- Schedules: A fixed truncation length for truncated BPTT; ARTBP uses the heavy-tailed truncation schedule described above with a matched mean length.
- Results:
- Truncated BPTT test bpc: 1.43.
- ARTBP test bpc: 1.40.
- ARTBP provides a small but observable improvement in validation and test metrics.
- Varying the truncation-length parameter ($4$ vs $6$) had minor impact; smaller values decrease memory usage but increase gradient variance.
6. Practical Considerations and Limitations
- Truncation distribution: For a known memory budget of roughly $L$ timesteps, set the mean truncation length to $L$. A constant $c_t$ (geometric schedule) is simplest but amplifies gradient variance; a heavy-tailed schedule, with $c_{\delta t}$ depending on the time since the last truncation, keeps the variance of the compensation factors manageable.
- Variance vs. memory trade-off: Lower $c_t$ (longer segments) reduces variance at the cost of memory; higher $c_t$ (shorter blocks) increases stochasticity and may require gradient-norm monitoring or clipping.
- Operational modes: ARTBP supports both online streaming (stepwise) and mini-batch operation. Batch samples should not cross truncation boundaries to preserve the schedule.
- Limitations: ARTBP introduces gradient noise from the stochastic schedule, which may slow convergence in deterministic settings relative to full BPTT. Compensation factors can become large if $c_t$ approaches unity, potentially destabilizing updates. Monitoring gradient norms (see the sketch after this list) and increasing the mean truncation length are recommended if instability or high variance is observed.
- Recommended procedure: Initialize the mean truncation length to match the available memory budget and start with a heavy-tailed schedule; adjust based on observed gradient variance and available memory.
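As a minimal illustration of the monitoring advice above (a sketch with arbitrary thresholds, assuming the ARTBP gradient estimate is available as a NumPy array):

```python
import numpy as np
from collections import deque

recent_norms = deque(maxlen=100)            # rolling window of recent gradient norms

def clip_and_monitor(grad, max_norm=5.0):
    """Clip an ARTBP gradient estimate and flag sustained high variance.

    The threshold and window size are arbitrary placeholders, not values from the paper."""
    norm = float(np.linalg.norm(grad))
    recent_norms.append(norm)
    if norm > max_norm:                     # rescale overly large compensated gradients
        grad = grad * (max_norm / norm)
    if len(recent_norms) == recent_norms.maxlen and np.std(recent_norms) > np.mean(recent_norms):
        print("high gradient variance: consider a longer mean truncation length")
    return grad
```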
7. Significance and Applications
ARTBP solves the longstanding issue of gradient bias in truncated BPTT for RNN training on long sequences. The unbiasedness of ARTBP’s gradient estimates makes it suitable for tasks where accurate credit assignment across long temporal horizons is essential. Its memory and computational complexity are comparable to those of traditional truncated BPTT, with the added practical consideration of managing gradient variance via hyperparameter choices. ARTBP provides a principled solution without imposing architectural changes on the underlying recurrent model or loss functions (Tallec & Ollivier, 2017).