
Fast Training of Recurrent Neural Networks with Stationary State Feedbacks (2503.23104v1)

Published 29 Mar 2025 in cs.LG and cs.AI

Abstract: Recurrent neural networks (RNNs) have recently demonstrated strong performance and faster inference than Transformers at comparable parameter budgets. However, the recursive gradient computation with the backpropagation through time (or BPTT) algorithm remains the major computational bottleneck. In this work, we propose a novel method that replaces BPTT with a fixed gradient feedback mechanism, yielding an efficient approximation of the exact gradient propagation based on the assumption of time stationarity. Our approach leverages state-space model (SSM) principles to define a structured feedback matrix that directly propagates gradients from future time steps. This formulation bypasses the need for recursive gradient backpropagation, significantly reducing training overhead while preserving the network's ability to capture long-term dependencies. The experiments on language modeling benchmarks exhibit competitive perplexity scores, while significantly reducing the training costs. These promising results suggest that designing a feedback method like an SSM can fully exploit the efficiency advantages of RNNs for many practical applications.

Summary

  • The paper proposes DSF, a novel approach that replaces time-dependent Jacobians with a fixed diagonal matrix to reduce gradient computation complexity.
  • It reformulates the backward pass as a convolution-based, time-invariant system that enables efficient parallel computation using FFTs or prefix-sum algorithms.
  • Experimental results on language modeling tasks show DSF achieves near-BPTT performance, significantly outperforming FT-BPTT while enhancing scalability.

This paper introduces Diagonal State Feedbacks (DSF), a novel method for training Recurrent Neural Networks (RNNs) more efficiently by approximating the gradient computation of the Backpropagation Through Time (BPTT) algorithm. The core problem DSF addresses is the computational bottleneck caused by the recursive and sequential nature of BPTT, which limits the scalability of RNNs despite their recent resurgence and strong performance in sequence modeling.

The key idea behind DSF is to replace the time-dependent Jacobian matrix of the recurrent cell, $\mathbf{J}_t = \frac{\partial \mathbf{h}_{t+1}}{\partial \mathbf{h}_t}$, which is used in the backward pass of BPTT, with a fixed, time-invariant, diagonal matrix $\mathbf{A}$. The standard BPTT gradient recurrence for $\mathbf{g}_t = \frac{\partial \mathcal{L}}{\partial \mathbf{h}_t}$ is:

$$\mathbf{g}_t = \mathbf{e}_t + \mathbf{g}_{t+1} \mathbf{J}_t$$

with $\mathbf{g}_T = \mathbf{e}_T$, where $\mathbf{e}_t = \frac{\partial \mathcal{L}_t}{\partial \mathbf{h}_t}$ is the error from the output layer at time $t$.
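
For concreteness, here is a minimal NumPy sketch of this exact backward recurrence, assuming the output-layer errors $\mathbf{e}_t$ and the per-step Jacobians $\mathbf{J}_t$ have already been materialized (the function name and array layout are illustrative, not from the paper):

import numpy as np

def bptt_backward(errors, jacobians):
    """Exact BPTT backward pass: g_t = e_t + g_{t+1} J_t.

    errors:    (T, d) array of the e_t, stored as rows (0-based steps).
    jacobians: (T-1, d, d) array; jacobians[t][i, j] = d h[t+1][i] / d h[t][j].
    Returns the state gradients g of shape (T, d).
    """
    T, d = errors.shape
    g = np.zeros_like(errors)
    g[T - 1] = errors[T - 1]                        # boundary condition g_T = e_T
    for t in range(T - 2, -1, -1):                  # sequential: T steps, O(d^2) work each
        g[t] = errors[t] + g[t + 1] @ jacobians[t]  # row vector times Jacobian
    return g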

DSF simplifies this to:

$$\mathbf{g}_t = \mathbf{e}_t + \mathbf{g}_{t+1} \mathbf{A}$$

This transforms the backward gradient propagation into a linear, time-invariant dynamical system, which can be viewed as a State-Space Model (SSM) operating in reverse time. This formulation allows the computation of all $\mathbf{g}_t$ terms using a convolution:

$$(\mathbf{g}_T, \mathbf{g}_{T-1}, \dots, \mathbf{g}_1) = (\mathbf{e}_T, \mathbf{e}_{T-1}, \dots, \mathbf{e}_1) * (\mathbf{I}, \mathbf{A}, \mathbf{A}^2, \dots, \mathbf{A}^{T-1})$$

This convolution can be computed efficiently, especially because $\mathbf{A}$ is diagonal.
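
As a check on this formulation, the sketch below (illustrative, not the paper's code) computes the DSF gradients both with the element-wise recurrence and with the equivalent direct convolution $\mathbf{g}_t = \sum_{k \geq t} \mathbf{e}_k \mathbf{A}^{k-t}$; the two functions agree, and the efficient FFT/prefix-sum variants discussed below exploit exactly this convolutional structure:

import numpy as np

def dsf_backward_recurrent(errors, a):
    """DSF backward pass, g_t = e_t + g_{t+1} * a, with a the fixed diagonal of A.

    errors: (T, d) array of the e_t; a: (d,) vector, e.g. drawn once from U[0, 1].
    O(dT) total work, but still T sequential steps.
    """
    g = np.empty_like(errors)
    g[-1] = errors[-1]
    for t in range(len(errors) - 2, -1, -1):
        g[t] = errors[t] + g[t + 1] * a              # element-wise, since A is diagonal
    return g

def dsf_backward_convolution(errors, a):
    """Same gradients written out as g_t = sum_{k >= t} e_k * a**(k - t) (O(dT^2); for checking only)."""
    T, d = errors.shape
    g = np.empty_like(errors)
    for t in range(T):
        powers = a[None, :] ** np.arange(T - t)[:, None]   # kernel rows 1, a, a^2, ...
        g[t] = (errors[t:] * powers).sum(axis=0)
    return g

Running both on random inputs and comparing with np.allclose confirms the equivalence of the recurrent and convolutional views.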

Implementation and Advantages of DSF:

  1. Diagonal Feedback Matrix ($\mathbf{A}$):
    • $\mathbf{A}$ is initialized once (e.g., from a uniform distribution on $[0, 1]$) and remains fixed throughout training. This avoids optimizing $\mathbf{A}$, unlike in many SSMs.
    • Being diagonal, the matrix-vector products $\mathbf{g}_{t+1}\mathbf{A}$ reduce to element-wise vector operations, lowering the per-time-step complexity from $\mathcal{O}(d^2)$ (for a dense $\mathbf{J}_t$) to $\mathcal{O}(d)$, where $d$ is the hidden state dimension.
    • The memory needed to store $\mathbf{A}$ is $\mathcal{O}(d)$ instead of $\mathcal{O}(d^2)$.
  2. Computational Efficiency:
    • Naive implementation: the recurrent computation of $\mathbf{g}_t$ with the diagonal $\mathbf{A}$ has a total complexity of $\mathcal{O}(dT)$ but remains sequential ($\mathcal{O}(T)$ steps).
    • Parallel implementation (FFT/prefix-sum): the convolution formulation allows parallel computation using FFTs or prefix-sum algorithms. This raises the total work to $\mathcal{O}(dT \log T)$ but reduces the sequential depth to $\mathcal{O}(\log T)$, making it well suited to GPUs (see the FFT-based sketch after this list).
  3. Comparison to BPTT and FT-BPTT:
    • BPTT (Backpropagation Through Time): complexity $\mathcal{O}(d^2 T)$, sequentiality $\mathcal{O}(T)$. Computes exact gradients.
    • DSF (FFT/prefix-sum): complexity $\mathcal{O}(dT \log T)$, sequentiality $\mathcal{O}(\log T)$. Approximates gradients.
    • FT-BPTT (Fully Truncated BPTT): the Jacobian is assumed to be zero ($\mathbf{A} = \mathbf{0}$), so $\mathbf{g}_t = \mathbf{e}_t$. Complexity $\mathcal{O}(d)$ per step, sequentiality $\mathcal{O}(1)$. Ignores temporal dependencies in the backward pass.
    | Method | Complexity | Sequentiality |
    |---|---|---|
    | BPTT | $\mathcal{O}(d^2 T)$ | $\mathcal{O}(T)$ |
    | DSF (naive implementation) | $\mathcal{O}(dT)$ | $\mathcal{O}(T)$ |
    | DSF (FFT/prefix-sum impl.) | $\mathcal{O}(dT \log T)$ | $\mathcal{O}(\log T)$ |
    | Fully Truncated BPTT (FT-BPTT) | $\mathcal{O}(1)$ per step ($\mathcal{O}(dT)$ total for $\mathbf{e}_t$ and parameter gradients) | $\mathcal{O}(1)$ |
  4. Connection to Direct Feedback Alignment (DFA): DSF extends the idea of DFA (which uses fixed random matrices for layer-wise feedback in feedforward networks) to the temporal domain in RNNs. Instead of approximating layer-to-layer Jacobians, DSF approximates the time-step-to-time-step Jacobian $\mathbf{J}_t$.

  5. Stability: Using a fixed diagonal $\mathbf{A}$ can help mitigate the exploding/vanishing gradients often associated with the product of Jacobians in BPTT. The values in $\mathbf{A}$ (e.g., initialized between 0 and 1) can control the decay of error signals propagated backward.
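
A minimal FFT-based sketch of the parallel implementation referenced in item 2 is given below. It assumes the errors are stored as a (T, d) array and the diagonal of $\mathbf{A}$ as a length-$d$ vector; it illustrates the idea rather than reproducing the paper's implementation:

import numpy as np

def dsf_backward_fft(errors, a):
    """DSF backward pass via FFT-based convolution.

    Computes g_t = sum_{k >= t} e_k * a**(k - t) for all t at once by convolving the
    time-reversed errors with the kernel (1, a, a^2, ..., a^{T-1}) channel by channel.
    O(dT log T) work with O(log T) sequential depth on parallel hardware.
    """
    T, d = errors.shape
    kernel = a[None, :] ** np.arange(T)[:, None]   # (T, d): powers of the diagonal entries
    n = 2 * T                                      # zero-padding so the convolution is linear, not circular
    e_rev = errors[::-1]                           # reverse time: e_T, ..., e_1
    conv = np.fft.irfft(np.fft.rfft(e_rev, n=n, axis=0) *
                        np.fft.rfft(kernel, n=n, axis=0), n=n, axis=0)
    return conv[:T][::-1]                          # undo the reversal: g_1, ..., g_T

An associative prefix-sum (parallel scan) over the first-order recurrence $\mathbf{g}_t = \mathbf{e}_t + \mathbf{g}_{t+1}\mathbf{A}$ is an alternative route to the same $\mathcal{O}(\log T)$ sequential depth.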

Experimental Results:

Experiments were conducted on language modeling tasks using the Penn Treebank (PTB) and Wikitext-103 datasets, comparing DSF with BPTT and FT-BPTT across various RNN architectures (vanilla RNN, GRU, LSTM), network depths, hidden dimensions, and sequence lengths.

  • Performance: DSF consistently and significantly outperformed FT-BPTT, demonstrating its ability to capture temporal dependencies effectively. It achieved performance remarkably close to BPTT across most configurations.

    • For example, on Wikitext-103 (3 GRU layers, 512 hidden units, sequence length 256), perplexities were: BPTT (28.32), DSF (31.54), FT-BPTT (46.60).
  • Scalability: DSF maintained its effectiveness relative to BPTT even with increasing network depth, width, and sequence length. The performance gap between DSF and BPTT did not widen substantially for larger models.
  • Architecture Agnostic: DSF proved effective across vanilla RNNs, GRUs, and LSTMs.
  • Comparison with other Architectures: On Wikitext-103 (approx. 28M parameters):
    • Transformer (GPT-2 style): 24.58 perplexity
    • RNN (GRU) + BPTT: 28.32 perplexity
    • RNN (GRU) + DSF: 31.54 perplexity
    • SSM (MEGA-style diagonal kernel): 33.19 perplexity
    • DSF-trained RNNs were competitive, even outperforming a comparable SSM, highlighting that efficient gradient approximation in RNNs can be a strong alternative.

Pseudocode for DSF Gradient Computation (Reverse Time):

import numpy as np

G = np.zeros_like(E)             # E has shape (T, d); A is the length-d diagonal of the feedback matrix
G[T - 1] = E[T - 1]              # g_T = e_T (using 0-based indexing for arrays)

for t in range(T - 2, -1, -1):
    G[t] = E[t] + G[t + 1] * A   # '*' is element-wise since A is diagonal
                                 # and the g vectors are rows
The final parameter gradients $\frac{\partial \mathcal{L}}{\partial \theta}$ are then computed as $\sum_{t=1}^{T} \mathbf{g}_t \frac{\partial f_{\theta}^t}{\partial \theta}$.
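
As an illustration of this last accumulation step, the sketch below assumes a vanilla tanh RNN cell, $\mathbf{h}_t = \tanh(\mathbf{W}_{hh}\mathbf{h}_{t-1} + \mathbf{W}_{xh}\mathbf{x}_t + \mathbf{b})$ (an assumption for concreteness; the paper also covers GRU and LSTM cells, whose local derivatives differ). Only the single-step local derivatives of the cell are needed here, since the temporal credit assignment has already been folded into the approximate $\mathbf{g}_t$:

import numpy as np

def accumulate_param_grads(g, h, x, hidden_dim, input_dim):
    """Accumulate dL/dtheta = sum_t g_t * d f_theta^t / d theta for a vanilla tanh RNN cell.

    g: (T, d) approximate state gradients from DSF; h: (T+1, d) hidden states, with h[0]
    the initial state and h[t+1] the state after step t; x: (T, p) inputs.
    """
    T = g.shape[0]
    dW_hh = np.zeros((hidden_dim, hidden_dim))
    dW_xh = np.zeros((hidden_dim, input_dim))
    db = np.zeros(hidden_dim)
    for t in range(T):
        delta = g[t] * (1.0 - h[t + 1] ** 2)       # backprop through the tanh of step t
        dW_hh += np.outer(delta, h[t])             # local dependence on the previous state
        dW_xh += np.outer(delta, x[t])
        db += delta
    return dW_hh, dW_xh, db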

Discussion and Future Directions:

  • DSF offers a practical trade-off between computational cost and model performance.
  • Future work could explore:
    • Adaptive Feedback Matrices: Periodically updating $\mathbf{A}$ or using meta-learning.
    • Structured Non-Diagonal Extensions: Low-rank plus diagonal, or other structures for $\mathbf{A}$ to potentially capture more complex dependencies.
    • Initialization Strategies: Investigating more sophisticated initializations for $\mathbf{A}$, possibly inspired by SSM parameterizations like HiPPO or S4.
    • Theoretical Analysis: Deeper understanding of the gradient approximation and its impact.

In conclusion, DSF is presented as an efficient and effective method for training RNNs, making them more scalable for long sequences and large models by simplifying the BPTT bottleneck with a fixed, diagonal feedback mechanism for temporal gradient propagation. It achieves performance competitive with BPTT while being significantly faster and more memory-efficient.
