- The paper proposes DSF, a novel approach that replaces time-dependent Jacobians with a fixed diagonal matrix to reduce gradient computation complexity.
- It reformulates the backward pass as a convolution-based, time-invariant system that enables efficient parallel computation using FFTs or prefix-sum algorithms.
- Experimental results on language modeling tasks show DSF achieves near-BPTT performance, significantly outperforming FT-BPTT while enhancing scalability.
This paper introduces Diagonal State Feedbacks (DSF), a novel method for training Recurrent Neural Networks (RNNs) more efficiently by approximating the gradient computation of the Backpropagation Through Time (BPTT) algorithm. The core problem DSF addresses is the computational bottleneck caused by the recursive and sequential nature of BPTT, which limits the scalability of RNNs despite their recent resurgence and strong performance in sequence modeling.
The key idea behind DSF is to replace the time-dependent Jacobian of the recurrent cell, $J_t = \frac{\partial h_{t+1}}{\partial h_t}$, used in the backward pass of BPTT, with a fixed, time-invariant, diagonal matrix $A$.
The standard BPTT recurrence for the gradient $g_t = \frac{\partial L}{\partial h_t}$ is:
$$g_t = e_t + g_{t+1} J_t,$$
with $g_T = e_T$, where $e_t = \frac{\partial L_t}{\partial h_t}$ is the error coming from the output layer at time $t$.
DSF simplifies this to:
$$g_t = e_t + g_{t+1} A.$$
This transforms the backward gradient propagation into a linear, time-invariant dynamical system, which can be viewed as a State-Space Model (SSM) operating in reverse time. This formulation allows all $g_t$ terms to be computed with a single convolution:
$$(g_T, g_{T-1}, \dots, g_1) = (e_T, e_{T-1}, \dots, e_1) * (I, A, A^2, \dots, A^{T-1}).$$
This convolution can be computed efficiently, especially because $A$ is diagonal: each hidden dimension reduces to an independent scalar convolution.
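As an illustration, here is a minimal NumPy sketch of this per-channel FFT convolution; the function name `dsf_backward_fft` and the array shapes are my own choices for illustration, not from the paper:

```python
import numpy as np

def dsf_backward_fft(E, a):
    """Compute all g_t via the reverse-time convolution, one channel at a time.

    E : (T, d) array of output errors e_1..e_T (rows).
    a : (d,) diagonal of the fixed feedback matrix A.
    Returns G : (T, d) with G[t-1] = g_t.
    """
    T, d = E.shape
    E_rev = E[::-1]                                  # (e_T, ..., e_1): reverse time
    K = a[None, :] ** np.arange(T)[:, None]          # per-channel kernel (1, a, a^2, ..., a^{T-1})
    n = 2 * T                                        # zero-pad so the FFT gives a linear, not circular, convolution
    G_rev = np.fft.irfft(np.fft.rfft(E_rev, n=n, axis=0) *
                         np.fft.rfft(K, n=n, axis=0), n=n, axis=0)[:T]
    return G_rev[::-1]                               # back to (g_1, ..., g_T) order
```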
Implementation and Advantages of DSF:
- Diagonal Feedback Matrix ($A$):
  - $A$ is initialized once (e.g., from a uniform distribution on $[0, 1]$) and remains fixed throughout training. This avoids optimizing $A$, unlike in many SSMs.
  - Being diagonal, the matrix-vector product $g_{t+1} A$ becomes an element-wise vector operation, reducing the complexity per time step from $O(d^2)$ (for a dense $J_t$) to $O(d)$, where $d$ is the hidden-state dimension.
  - Storing $A$ requires $O(d)$ memory instead of $O(d^2)$.
- Computational Efficiency:
  - Naive Implementation: computing $g_t$ recurrently with the diagonal $A$ has total complexity $O(dT)$ but remains sequential ($O(T)$ sequential steps).
  - Parallel Implementation (FFT/Prefix-Sum): the convolution formulation allows parallel computation using FFTs or prefix-sum (associative scan) algorithms. This results in a complexity of $O(dT \log T)$ but reduces the sequential depth to $O(\log T)$, making it well suited to GPUs (see the scan sketch after the comparison table below).
- Comparison to BPTT and FT-BPTT:
  - BPTT (Backpropagation Through Time): Complexity $O(d^2 T)$, Sequentiality $O(T)$. Computes exact gradients.
  - DSF (FFT/Prefix-Sum): Complexity $O(dT \log T)$, Sequentiality $O(\log T)$. Approximates gradients.
  - FT-BPTT (Fully Truncated BPTT): the Jacobian is assumed to be zero ($A = 0$), so $g_t = e_t$. Complexity $O(d)$ per step ($O(dT)$ total), Sequentiality $O(1)$. Ignores temporal dependencies in the backward pass.
| Method | Complexity | Sequentiality |
| --- | --- | --- |
| BPTT | $O(d^2 T)$ | $O(T)$ |
| DSF (naive implementation) | $O(dT)$ | $O(T)$ |
| DSF (FFT/prefix-sum implementation) | $O(dT \log T)$ | $O(\log T)$ |
| Fully Truncated BPTT (FT-BPTT) | $O(dT)$ total ($O(d)$ per step, for $e_t$ and parameter gradients) | $O(1)$ |
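As referenced above, here is a minimal NumPy sketch of the prefix-sum idea: a Hillis–Steele-style inclusive scan over the per-channel linear recurrence. The function name and the doubling loop are illustrative; NumPy still runs the $O(\log T)$ iterations on the host, so this only demonstrates the logarithmic depth, whereas a GPU implementation would use a parallel scan primitive.

```python
import numpy as np

def dsf_backward_scan(E, a):
    """DSF backward pass as an inclusive scan over the reverse-time recurrence
    y_s = a * y_{s-1} + e_rev[s], using log-depth doubling (Hillis-Steele style).

    E : (T, d) output errors e_1..e_T;  a : (d,) diagonal of A.
    Returns G : (T, d) with G[t-1] = g_t.
    """
    T, d = E.shape
    # Each position holds an affine operator (A_acc, B_acc): y -> A_acc * y + B_acc.
    A_acc = np.broadcast_to(a, (T, d)).copy()
    B_acc = E[::-1].copy()                     # start from (e_T, ..., e_1)
    shift = 1
    while shift < T:
        A_prev = A_acc[:-shift].copy()         # operators covering the earlier segment
        B_prev = B_acc[:-shift].copy()
        # Compose: (A2, B2) o (A1, B1) = (A2*A1, A2*B1 + B2), element-wise per channel.
        B_acc[shift:] += A_acc[shift:] * B_prev
        A_acc[shift:] *= A_prev
        shift *= 2
    return B_acc[::-1]                         # B_acc[s] = g_{T-s}; reorder to (g_1, ..., g_T)
```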
Connection to Direct Feedback Alignment (DFA): DSF extends the idea of DFA (which uses fixed random matrices for layer-wise feedback in feedforward networks) to the temporal domain in RNNs. Instead of approximating layer-to-layer Jacobians, DSF approximates the time-step-to-time-step Jacobian Jt.
Stability: Using a fixed diagonal A can help mitigate exploding/vanishing gradients often associated with the product of Jacobians in BPTT. The values in A (e.g., initialized between 0 and 1) can control the decay of error signals propagated backward.
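Concretely, unrolling the DSF recurrence (this is just the convolution above written per time step) makes the decay explicit:
$$g_t = \sum_{k=0}^{T-t} e_{t+k}\, A^{k}, \qquad \big\|A^{k}\big\|_2 = \max_j |a_j|^{k} \le 1 \quad \text{for } a_j \in [0, 1],$$
so the factor multiplying an error $k$ steps in the future is bounded and decays geometrically, whereas the corresponding BPTT factor $J_{t+k-1} \cdots J_{t+1} J_t$ can blow up or vanish.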
Experimental Results:
Experiments were conducted on language modeling tasks using the Penn Treebank (PTB) and Wikitext-103 datasets, comparing DSF with BPTT and FT-BPTT across various RNN architectures (vanilla RNN, GRU, LSTM), network depths, hidden dimensions, and sequence lengths.
DSF Gradient Computation (Reverse Time), naive sequential implementation:

```python
import numpy as np

def dsf_backward_naive(E, a):
    """Sequential DSF backward: E is (T, d) errors e_1..e_T, a is the diagonal of A."""
    T = E.shape[0]
    G = np.empty_like(E)
    G[T - 1] = E[T - 1]              # g_T = e_T (0-based array indexing)
    for t in range(T - 2, -1, -1):
        G[t] = E[t] + G[t + 1] * a   # '*' is element-wise because A is diagonal
    return G
```
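As a sanity check on the illustrative helpers sketched above, the recurrence and the convolutional form should agree up to FFT round-off: `np.allclose(dsf_backward_naive(E, a), dsf_backward_fft(E, a))`.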
The final parameter gradients are then computed as
$$\frac{\partial L}{\partial \theta} = \sum_{t=1}^{T} g_t \,\frac{\partial f_\theta^t}{\partial \theta},$$
where $f_\theta^t$ denotes the recurrent cell update that produces $h_t$.
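For concreteness, here is a hedged sketch of this accumulation for a hypothetical vanilla tanh RNN cell $h_t = \tanh(W_h h_{t-1} + W_x x_t)$; the cell, names, and shapes are assumptions for illustration, not the paper's setup:

```python
import numpy as np

def vanilla_rnn_param_grads(G, H, X):
    """Accumulate dL/dW_h and dL/dW_x for the assumed cell h_t = tanh(W_h h_{t-1} + W_x x_t),
    given the DSF state gradients G.

    G : (T, d)   state gradients g_1..g_T (rows)
    H : (T+1, d) hidden states h_0..h_T
    X : (T, k)   inputs x_1..x_T
    """
    D = G * (1.0 - H[1:] ** 2)    # gradient w.r.t. pre-activations: g_t * tanh'(pre_t)
    dW_h = D.T @ H[:-1]           # sum_t d_t h_{t-1}^T  -> (d, d)
    dW_x = D.T @ X                # sum_t d_t x_t^T      -> (d, k)
    return dW_h, dW_x
```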
Discussion and Future Directions:
- DSF offers a practical trade-off between computational cost and model performance.
- Future work could explore:
  - Adaptive Feedback Matrices: Periodically updating $A$ or using meta-learning.
  - Structured Non-Diagonal Extensions: Low-rank plus diagonal, or other structures for $A$, to potentially capture more complex dependencies.
  - Initialization Strategies: Investigating more sophisticated initializations for $A$, possibly inspired by SSM parameterizations such as HiPPO or S4.
  - Theoretical Analysis: A deeper understanding of the gradient approximation and its impact.
In conclusion, DSF is presented as an efficient and effective method for training RNNs, making them more scalable for long sequences and large models by simplifying the BPTT bottleneck with a fixed, diagonal feedback mechanism for temporal gradient propagation. It achieves performance competitive with BPTT while being significantly faster and more memory-efficient.