Fundamentals of Recurrent Neural Network (RNN) and Long Short-Term Memory (LSTM) Network
(1808.03314v10)
Published 9 Aug 2018 in cs.LG and stat.ML
Abstract: Because of their effectiveness in broad practical applications, LSTM networks have received a wealth of coverage in scientific journals, technical blogs, and implementation guides. However, in most articles, the inference formulas for the LSTM network and its parent, RNN, are stated axiomatically, while the training formulas are omitted altogether. In addition, the technique of "unrolling" an RNN is routinely presented without justification throughout the literature. The goal of this paper is to explain the essential RNN and LSTM fundamentals in a single document. Drawing from concepts in signal processing, we formally derive the canonical RNN formulation from differential equations. We then propose and prove a precise statement, which yields the RNN unrolling technique. We also review the difficulties with training the standard RNN and address them by transforming the RNN into the "Vanilla LSTM" network through a series of logical arguments. We provide all equations pertaining to the LSTM system together with detailed descriptions of its constituent entities. Albeit unconventional, our choice of notation and the method for presenting the LSTM system emphasizes ease of understanding. As part of the analysis, we identify new opportunities to enrich the LSTM system and incorporate these extensions into the Vanilla LSTM network, producing the most general LSTM variant to date. The target reader has already been exposed to RNNs and LSTM networks through numerous available resources and is open to an alternative pedagogical approach. A Machine Learning practitioner seeking guidance for implementing our new augmented LSTM model in software for experimentation and research will find the insights and derivations in this tutorial valuable as well.
The paper derives core RNN equations using differential equations and unrolling principles to justify practical training and inference strategies.
It investigates training challenges such as vanishing and exploding gradients using Back Propagation Through Time and demonstrates LSTM's gated mechanism to mitigate these issues.
The study extends the Vanilla LSTM to an Augmented LSTM by incorporating input context windows, a recurrent projection layer, and an extra control gate to enhance sequence modeling.
This paper provides a fundamental explanation of Recurrent Neural Networks (RNNs) and Long Short-Term Memory (LSTM) networks, focusing on their underlying principles, derivations, and practical implementation details. The authors motivate the work by highlighting the lack of comprehensive resources that derive the core equations and justify common techniques like "unrolling."
The paper begins by deriving the canonical RNN formulation from differential equations, drawing on concepts from signal processing. It starts with a general first-order nonlinear non-homogeneous ordinary differential equation for a state signal $s(t)$. By considering a specific form of the function governing the state change, linearizing the resulting operators, and applying the backward Euler discretization method with the time delay set equal to the sampling step, the authors arrive at the canonical discrete-time RNN equations:
$$s[n] = W_s\, s[n-1] + W_r\, r[n-1] + W_x\, x[n] + \theta_s$$
$$r[n] = G(s[n])$$
where $s[n]$ is the state vector at step $n$, $r[n]$ is the readout vector (a warped version of the state via the activation function $G$), $x[n]$ is the input vector, $W_s, W_r, W_x$ are weight matrices, and $\theta_s$ is a bias vector. The standard RNN definition is then presented as a simplification in which $W_s$ is effectively zeroed out, based on stability arguments derived from the differential-equation counterpart, resulting in $s[n] = W_r\, r[n-1] + W_x\, x[n] + \theta_s$.
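As a minimal sketch of the discretization step (an informal reconstruction under the assumptions just listed, not the paper's full derivation), the backward Euler rule with sampling period $T$ and the delay in the state argument set to one sample gives:

```latex
\begin{aligned}
\frac{ds(t)}{dt}\bigg|_{t = t_n} &\approx \frac{s[n] - s[n-1]}{T}
  && \text{(backward Euler approximation)} \\
s[n] &= s[n-1] + T \, f\big(s[n-1],\, r[n-1],\, x[n]\big)
  && \text{(delay in the state argument equal to } T\text{)} \\
s[n] &= W_s\, s[n-1] + W_r\, r[n-1] + W_x\, x[n] + \theta_s
  && \text{(after linearizing } f \text{ and absorbing } T \text{ into the weights)}
\end{aligned}
```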
The technique of "unrolling" an RNN is formally addressed. The paper defines an RNN "cell" and explains that unrolling it for a finite number of steps (K) transforms its recurrent graph into a directed acyclic graph (DAG). A proposition is introduced and proved, stating that if a long sequence of ground truth outputs can be partitioned into mutually independent, finite-length segments, then a single, reusable RNN cell, unrolled for the segment length, is sufficient for optimizing parameters and performing inference. This formalizes the practical approach of processing data in batches of fixed or variable-length sequences. The proof relies on showing the independence of segment-level state signal subsequences when initialized appropriately (e.g., to zero). Practical considerations of unrolling, such as the artificial boundaries introduced by truncation and the inability to capture dependencies longer than the unrolling window, are discussed.
The paper then explores the well-known training difficulties of standard RNNs, namely vanishing and exploding gradients. Using Back Propagation Through Time (BPTT), the authors derive the equations for $\psi[n]$, the total partial derivative of the objective function $E$ with respect to the state signal $s[n]$. They show that the magnitude of the gradient propagated backward through time, $\frac{\partial \psi[n]}{\partial \psi[l]}$ for $l \gg n$, depends on a product of Jacobian matrices over $l - n$ steps. In the standard RNN, this product typically leads to the gradient decaying exponentially (vanishing) if the spectral radius of the relevant matrix product is less than 1, or growing exponentially (exploding) if it is greater than 1, hindering the network's ability to learn long-range dependencies.
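This behavior is easy to reproduce numerically. Below is a minimal NumPy illustration (toy matrices and a synthetic state trajectory, not the paper's analysis) of how the norm of the product of state-to-state Jacobians over $l - n$ steps collapses or explodes depending on the scale of $W_r$:

```python
import numpy as np

def jacobian_product_norm(W_r, states):
    """2-norm of prod_k W_r @ diag(G'(s[k])), the factor scaling the
    back-propagated state gradient across len(states) steps (G = tanh)."""
    G_prime = lambda s: 1.0 - np.tanh(s) ** 2       # derivative of the tanh warping
    J = np.eye(W_r.shape[0])
    for s in states:
        J = W_r @ np.diag(G_prime(s)) @ J           # one step of the backward chain
    return np.linalg.norm(J, 2)

rng = np.random.default_rng(0)
states = 0.1 * rng.normal(size=(50, 8))             # toy 50-step trajectory, kept near 0 so G'(s) ~ 1
for scale in (0.5, 1.0, 2.0):                       # spectral radius of W_r roughly 0.5, 1, 2
    W_r = scale * np.linalg.qr(rng.normal(size=(8, 8)))[0]   # scaled orthogonal matrix
    print(scale, jacobian_product_norm(W_r, states))
# Typical output: the norm vanishes for scale < 1 and explodes for scale > 1.
```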
To address these issues, the paper logically evolves the canonical RNN into the Vanilla LSTM network. The core idea is to introduce multiplicative gates that control the flow of information. The state update equation is modified to include gates for the previous state contribution ($g_{cs}$) and the new update candidate contribution ($g_{cu}$):
$$s[n] = g_{cs}[n] \odot s[n-1] + g_{cu}[n] \odot u[n]$$
where $u[n]$ is a new update candidate signal, derived from the previous readout and the current input and passed through a data warping function $G_d$. A control readout gate ($g_{cr}$) is introduced to modulate the output signal, $v[n] = g_{cr}[n] \odot r[n]$. The gates themselves are computed by applying the logistic sigmoid function ($G_c$) to linear combinations of the available signals (input, previous state, previous value signal). The paper presents the Constant Error Carousel (CEC) mode as an idealized illustration in which the gates are set so that the error gradient propagates backward unattenuated.
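A minimal NumPy sketch of one forward step in this notation is given below; the exact set of signals feeding each gate and the choice of $\tanh$ for $G_d$ are simplifying assumptions, not the paper's complete equation set:

```python
import numpy as np

def sigmoid(z):                             # G_c: control (gate) warping
    return 1.0 / (1.0 + np.exp(-z))

def vanilla_lstm_step(s_prev, v_prev, x, P):
    """One Vanilla LSTM forward step; P is a dict of illustrative weights/biases."""
    # Gate accumulations squashed by G_c; here each gate sees x[n], s[n-1], v[n-1].
    g_cs = sigmoid(P["Wcs_x"] @ x + P["Wcs_s"] @ s_prev + P["Wcs_v"] @ v_prev + P["bcs"])
    g_cu = sigmoid(P["Wcu_x"] @ x + P["Wcu_s"] @ s_prev + P["Wcu_v"] @ v_prev + P["bcu"])
    # Update candidate squashed by the data warping G_d (tanh assumed).
    u = np.tanh(P["Wdu_x"] @ x + P["Wdu_v"] @ v_prev + P["bdu"])
    # Gated state update: retain part of the old state, admit part of the candidate.
    s = g_cs * s_prev + g_cu * u
    # Readout, its control gate, and the value signal v[n] that is emitted and recurs.
    r = np.tanh(s)
    g_cr = sigmoid(P["Wcr_x"] @ x + P["Wcr_s"] @ s + P["Wcr_v"] @ v_prev + P["bcr"])
    v = g_cr * r
    return s, v
```

Setting $g_{cs}[n] \approx 1$ and $g_{cu}[n] \approx 0$ in this step reproduces the idealized CEC behavior, carrying the state (and hence the error gradient) forward essentially unchanged.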
The paper provides a detailed, self-contained explanation of the Vanilla LSTM mechanism, including notation, data standardization (mean 0, standard deviation 1), the two warping functions ($G_c$ and $G_d$), a list of all 15 sets of parameters, and the full set of equations for both the forward pass (inference) and the backward pass (derivatives for training). The backward pass equations include the total partial derivatives of the objective function with respect to the value signal ($\chi[n]$), readout signal ($\rho[n]$), gate signals ($\gamma$), accumulation nodes ($\alpha$), and state signal ($\psi[n]$), culminating in the derivatives with respect to all model parameters, $\frac{dE}{d\Theta}$. It is shown that the parameter updates depend on the accumulation derivatives, which in turn depend on $\psi[n]$. The analysis of the error gradient flow in the LSTM shows that the $g_{cs}[n]$ gate directly controls a key term in the $\psi[n]$ update, allowing the gradient to be sustained over many steps and mitigating the vanishing gradient problem.
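That key term can be read off the gated state update directly; as a sketch of the direct path only (the paper's full recursion also carries contributions through the gates and the candidate $u[n]$):

```latex
\psi[n-1] \;=\; \underbrace{g_{cs}[n] \odot \psi[n]}_{\text{direct path through } s[n] \,=\, g_{cs}[n] \odot s[n-1] + \cdots}
\;+\; \big(\text{terms flowing through the gates and } u[n]\big)
```

When $g_{cs}[n]$ is close to 1, this term passes the gradient from step $n$ back to step $n-1$ nearly unattenuated, which is the Constant Error Carousel behavior described above.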
Finally, the paper proposes extensions to the Vanilla LSTM network to create an "Augmented LSTM." These extensions, illustrated in the sketch after the list, include:
External Input Context Windows: Replacing the single input sample $x[n]$ in the accumulation node computations with a convolution over a small window of input samples, including "future" samples (a non-causal filter), to incorporate local context.
Recurrent Projection Layer: Adding a linear transformation $W_{qdr}$ after the gated readout $q[n] = g_{cr}[n] \odot r[n]$ to produce the value signal $v[n] = W_{qdr}\, q[n]$. This allows reducing the dimensionality of the recurrent connection ($d_v < d_s$), trading off model capacity for computational efficiency.
Controlling External Input with a New Gate: Introducing an additional gate, the control input gate $g_{cx}[n]$, to modulate the contribution of the composite external input signal (after the context window convolution) to the data update accumulation node $a_{du}[n]$.
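A hedged NumPy sketch of one Augmented LSTM forward step, showing where the three extensions enter, is given below; the weight names, gate wiring, and context-window handling are illustrative assumptions rather than the paper's exact equations:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def augmented_lstm_step(s_prev, v_prev, x_window, P):
    """One Augmented LSTM forward step (sketch).

    x_window: input samples around step n (past, current, and a few "future"
    samples), realizing the non-causal input context window.
    P: dict of illustrative weights/biases.
    """
    # (1) External input context window: convolve over neighbouring input samples.
    x_ctx = sum(W_k @ x_k for W_k, x_k in zip(P["W_ctx"], x_window))

    # (3) New control input gate g_cx modulating the composite external input's
    #     contribution to the data update accumulation node a_du[n].
    g_cx = sigmoid(P["Wcx_s"] @ s_prev + P["Wcx_v"] @ v_prev + P["bcx"])
    a_du = g_cx * x_ctx + P["Wdu_v"] @ v_prev + P["bdu"]
    u = np.tanh(a_du)                                   # update candidate (data warping G_d)

    # Vanilla-LSTM style gated state update and gated readout.
    g_cs = sigmoid(P["Wcs_s"] @ s_prev + P["Wcs_v"] @ v_prev + P["bcs"])
    g_cu = sigmoid(P["Wcu_s"] @ s_prev + P["Wcu_v"] @ v_prev + P["bcu"])
    s = g_cs * s_prev + g_cu * u
    r = np.tanh(s)
    g_cr = sigmoid(P["Wcr_s"] @ s + P["Wcr_v"] @ v_prev + P["bcr"])
    q = g_cr * r

    # (2) Recurrent projection layer: v[n] = W_qdr q[n], allowing dim(v) < dim(s).
    v = P["W_qdr"] @ q
    return s, v
```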
The full forward and backward pass equations for this Augmented LSTM are provided, detailing how these extensions integrate into the model and how their parameters are learned via BPTT. The paper concludes by suggesting future work, including implementing and benchmarking the Augmented LSTM on tasks like question answering and customer support automation, comparing its performance against baselines (Vanilla LSTM, bidirectional LSTM, Transformer), and evaluating the specific contributions of the new input context windows and gate.