mLSTM with Sigmoid Input Gate
- mLSTM with Sigmoid Input Gate is a gated recurrent neural network that leverages a multiplicative intermediate state and sigmoid activation to enable dynamic, context-sensitive memory updates.
- The architecture simplifies recurrence by eliminating auxiliary states, reducing computation and memory overhead while achieving performance speedups of up to 30%.
- Practical evaluations demonstrate that mLSTM variants excel in long-context language modeling, delivering competitive results in benchmarks with efficient compute and memory usage.
The multiplicative Long Short-Term Memory (mLSTM) architecture with sigmoid input gate is a gated recurrent neural network that synthesizes the expressive hidden-state transitions of multiplicative RNNs with the robust memory control mechanisms of LSTMs. Distinguished by their use of input-dependent multiplicative paths and a sigmoid-activated input gate, mLSTM and its modern variants such as mLSTMsig demonstrate enhanced modeling capabilities for sequence data, while enabling efficient implementation and kernel optimization for long-context scenarios (Krause et al., 2016, Beck et al., 18 Mar 2025).
1. Mathematical Formulation
The mLSTM family is characterized by the introduction of an intermediate multiplicative state $m_t$ that modulates both the candidate cell update and all gating signals. The classic recurrence equations for a single-layer mLSTM with sigmoid input gate are

$$
\begin{aligned}
m_t &= (W_{mx} x_t) \odot (W_{mh} h_{t-1}), \\
\hat{h}_t &= W_{hx} x_t + W_{hm} m_t, \\
i_t &= \sigma(W_{ix} x_t + W_{im} m_t), \\
f_t &= \sigma(W_{fx} x_t + W_{fm} m_t), \\
o_t &= \sigma(W_{ox} x_t + W_{om} m_t), \\
c_t &= f_t \odot c_{t-1} + i_t \odot \hat{h}_t, \\
h_t &= o_t \odot \tanh(c_t),
\end{aligned}
$$

with $x_t$ the input, $\odot$ the elementwise product, and $\sigma$ the elementwise sigmoid. All gates are thus functions of both the input and a multiplicative interaction between the input and previous hidden state (Krause et al., 2016).
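The recurrence above maps directly onto a per-step cell. The following is a minimal, unoptimized PyTorch sketch of a single step; the weight names are illustrative and biases are omitted for brevity:

```python
import torch


def mlstm_step(x_t, h_prev, c_prev, W):
    """One step of the classic mLSTM recurrence with sigmoid gates (illustrative sketch).

    x_t: (batch, d_in); h_prev, c_prev: (batch, d_hid); W: dict of weight matrices.
    """
    # Multiplicative intermediate state: elementwise product of two projections.
    m_t = (x_t @ W["mx"]) * (h_prev @ W["mh"])

    # Candidate update and all gates are driven by x_t and m_t.
    h_hat = x_t @ W["hx"] + m_t @ W["hm"]
    i_t = torch.sigmoid(x_t @ W["ix"] + m_t @ W["im"])   # sigmoid input gate
    f_t = torch.sigmoid(x_t @ W["fx"] + m_t @ W["fm"])
    o_t = torch.sigmoid(x_t @ W["ox"] + m_t @ W["om"])

    c_t = f_t * c_prev + i_t * h_hat
    h_t = o_t * torch.tanh(c_t)
    return h_t, c_t


# Smoke test with random weights.
d_in, d_hid, batch = 16, 32, 4
W = {n: 0.1 * torch.randn(d_in, d_hid) for n in ("mx", "hx", "ix", "fx", "ox")}
W.update({n: 0.1 * torch.randn(d_hid, d_hid) for n in ("mh", "hm", "im", "fm", "om")})
x_t = torch.randn(batch, d_in)
h = c = torch.zeros(batch, d_hid)
h, c = mlstm_step(x_t, h, c, W)
```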
Modern efficient mLSTM variants such as mLSTMsig reformulate the cell state as a learned matrix memory $C_t \in \mathbb{R}^{d \times d}$:

$$
\begin{aligned}
C_t &= \sigma(\tilde{f}_t)\, C_{t-1} + \sigma(\tilde{i}_t)\, v_t k_t^{\top}, \\
h_t &= o_t \odot \left( C_t\, q_t \right).
\end{aligned}
$$

Here, the queries, keys, and values $q_t, k_t, v_t$ and the gate pre-activations $\tilde{i}_t, \tilde{f}_t$ are linearly projected from $x_t$, and the plain sigmoid gating removes the need for auxiliary stabilizer states (Beck et al., 18 Mar 2025).
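A minimal step-by-step reading of this matrix-memory recurrence is sketched below for a single head with per-head scalar gates; it is a reference formulation, not the chunkwise-parallel form used in optimized kernels, and the argument layout is illustrative:

```python
import torch


def mlstm_sig_step(q_t, k_t, v_t, i_pre, f_pre, o_t, C_prev):
    """One recurrent step of mLSTMsig with matrix memory C (illustrative sketch).

    q_t, k_t, v_t, o_t: (batch, d); i_pre, f_pre: (batch, 1) scalar gate
    pre-activations; C_prev: (batch, d, d).
    """
    i_t = torch.sigmoid(i_pre)[:, :, None]   # sigmoid input gate, no normalizer needed
    f_t = torch.sigmoid(f_pre)[:, :, None]   # sigmoid forget gate

    # Rank-1 write: outer product of value and key, scaled by the input gate.
    C_t = f_t * C_prev + i_t * torch.einsum("bd,be->bde", v_t, k_t)

    # Read-out: query the matrix memory, then apply the output gate.
    h_t = o_t * torch.einsum("bde,be->bd", C_t, q_t)
    return h_t, C_t
```

Note that only $C_t$ is carried between steps; the exponential-gated variant discussed below additionally carries a normalizer and a max-stabilizer state.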
2. Distinguishing Features and Theoretical Implications
Unlike standard LSTMs, which drive gating by additive combinations of input and previous hidden state, mLSTM employs an elementwise product, giving each input symbol the potential to induce a unique hidden-to-hidden transition. This expands the effective capacity of the model's state space. The dependence of the input, forget, and output gates on $m_t$ introduces strong input-state interactions, which permit context-dependent modulation of memory writing and retention.
Sigmoid input gating constrains the valid update range to $(0, 1)$, ensuring numerically stable input modulation and obviating the need for normalization or max-rescaling routines present in earlier exponential-gated architectures.
3. Comparison with Exponential-Gated and Other Linear RNNs
The transition from exponential-gated input (mLSTMexp) to sigmoid-gated input (mLSTMsig) eliminates two auxiliary states: the numerical normalizer and max state. This simplification yields several practical consequences:
- Reduces the forward compute and memory footprint by 20–30% per chunk.
- Simplifies the recurrence, removing the need for rescaling and auxiliary tracking in long-context runs.
- Yields kernel designs that halve the number of global memory barriers and significantly decrease non-tensor-core compute (Beck et al., 18 Mar 2025).
Relative to linear RNNs omitting input gates (e.g., FlashLinearAttention), mLSTM with sigmoid input gate maintains a richer gating structure while retaining competitive computational properties.
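To make the contrast concrete, the following sketch shows the extra bookkeeping carried by a recurrent formulation of the exponential-gated variant: a max-stabilizer state and a normalizer vector, both of which disappear under sigmoid gating. The stabilization follows the commonly published log-space max trick and is a schematic reading of the recurrence, not a kernel-level implementation:

```python
import torch
import torch.nn.functional as F


def mlstm_exp_step(q_t, k_t, v_t, i_pre, f_pre, o_t, C_prev, n_prev, m_prev):
    """Exponential-gated mLSTM step with max state m and normalizer n (schematic).

    q_t, k_t, v_t, o_t: (batch, d); i_pre, f_pre, m_prev: (batch, 1);
    C_prev: (batch, d, d); n_prev: (batch, d).
    """
    # Log-space max stabilizer keeps the exponential input gate bounded.
    m_t = torch.maximum(F.logsigmoid(f_pre) + m_prev, i_pre)
    i_t = torch.exp(i_pre - m_t)                                   # stabilized exp input gate
    f_t = torch.exp(F.logsigmoid(f_pre) + m_prev - m_t)            # stabilized forget gate

    C_t = f_t[:, :, None] * C_prev + i_t[:, :, None] * torch.einsum("bd,be->bde", v_t, k_t)
    n_t = f_t * n_prev + i_t * k_t                                  # normalizer state

    # Read-out is divided by a normalizer term; mLSTMsig drops m_t, n_t, and this division.
    denom = torch.clamp(torch.abs(torch.einsum("bd,bd->b", n_t, q_t)), min=1.0)
    h_t = o_t * torch.einsum("bde,be->bd", C_t, q_t) / denom[:, None]
    return h_t, C_t, n_t, m_t
```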
4. Practical Performance and Implementation
Empirical results establish mLSTM and its variants as strong sequence modelers. Key benchmarks are:
- On text8, character-level bits per char: .
- On Hutter Prize: bits/char (with 46M parameters).
- On WikiText-2: byte-level entropy of $1.26$ bits/char and word-level perplexity of $88.8$, on par with optimized word-level LSTMs (Krause et al., 2016).
- In next-token prediction tasks (DCLM) up to 1.4B parameters, mLSTMsig and mLSTMexp match perplexities to within $0.1$ PPL for all head configs (Beck et al., 18 Mar 2025).
The resource-optimized TFLA kernel for mLSTMsig achieves:
- Inference (65k tokens): faster than mLSTMexp, up to faster than FlashAttention 3 for .
- Training: faster than Mamba 2 for all , and $20$– faster than FlashAttention 3 at .
- Memory–runtime trade-off: at , batch 8, embedding 4096, optimal chunking yields lower GPU memory use than competing kernels, consistent with roofline/runtime analysis (Beck et al., 18 Mar 2025).
5. Kernel Design and Optimization
The Tiled Flash Linear Attention (TFLA) kernel for mLSTMsig adopts a two-level sequence parallelism:
- Sequence is chunked (Level 1); each chunk's initial state is computed and materialized by a recurrent kernel.
- Within each chunk (Level 2), the matrix-multiply operations involving the chunk's queries, keys, and values and the chunk-initial state $C$ are block-tiled for efficient tensor-core utilization and SRAM management.
The absence of auxiliary normalizers/max-states in mLSTMsig enables fusion of intra- and inter-chunk computations. In the Triton implementation, a thread block loads and processes blocks of $Q$, $K$, and $V$, applies the sigmoid gates, and accumulates both the attention-derived and memory contributions in a fully fused forward pass with fewer memory barriers (Beck et al., 18 Mar 2025).
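The two-level idea can be illustrated outside of Triton with a plain PyTorch sketch: an outer loop carries the matrix memory across chunks (Level 1), while each chunk's outputs are computed with causal matrix multiplies (Level 2). Chunk size and tensor layouts are illustrative, and a real kernel would accumulate the cumulative gates in log space and fuse these steps into a single pass:

```python
import torch


def mlstm_sig_chunkwise(q, k, v, i_pre, f_pre, o, C0, chunk=64):
    """Chunkwise (two-level) evaluation of mLSTMsig; readability-oriented sketch.

    q, k, v, o: (T, d); i_pre, f_pre: (T,); C0: (d, d) initial matrix memory,
    stored here in row convention (sum of k v^T), so read-outs are q^T C.
    """
    T, d = q.shape
    i_g, f_g = torch.sigmoid(i_pre), torch.sigmoid(f_pre)
    h = torch.empty_like(v)
    C = C0.clone()

    for s in range(0, T, chunk):
        e = min(s + chunk, T)
        qc, kc, vc, oc = q[s:e], k[s:e], v[s:e], o[s:e]
        ic, fc = i_g[s:e], f_g[s:e]

        b = torch.cumprod(fc, dim=0)               # forget-gate decay from chunk start
        # Inter-chunk contribution: query the chunk-initial memory, decayed per position.
        inter = b[:, None] * (qc @ C)

        # Intra-chunk contribution: causal, gate-weighted "attention" matrix.
        decay = b[:, None] / b[None, :]            # prod of forget gates over (l, j] for l <= j
        G = torch.tril(decay * ic[None, :])        # write strength per (query j, key l)
        intra = (G * (qc @ kc.T)) @ vc

        h[s:e] = oc * (inter + intra)

        # Level 1 recurrence: update the chunk-initial memory for the next chunk.
        # (A real kernel tracks these cumulative gates in log space for stability.)
        w = (b[-1] / b) * ic                       # decay to chunk end, times input gate
        C = b[-1] * C + (w[:, None] * kc).T @ vc

    return h, C
```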
6. Training Stability, Hyperparameters, and Best Practices
Experimental best practices for mLSTM recurrent architectures, including mLSTMsig, are summarized below (a schematic training-setup sketch follows the list):
- Hidden dimension (or up to for large-scale models).
- Embedding layer of size $400$ preceding first-layer projections.
- Adam optimizer, learning rate decaying from $0.001$ to .
- Scaled orthogonal initialization for recurrent matrices.
- Initial forget-gate bias (classic) or input-gate bias (for mLSTMsig, to suppress early gradient spikes).
- Truncated BPTT: length $200$–$250$.
- Variational dropout $0.2$–$0.5$ on embeddings and hidden paths; weight normalization on all recurrent matrices (Krause et al., 2016, Beck et al., 18 Mar 2025).
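A schematic setup along these lines is sketched below. It uses a stand-in recurrent module, and values flagged as placeholders (the hidden size, gate-bias magnitude, and learning-rate decay factor) are illustrative rather than figures from the cited papers:

```python
import torch
import torch.nn as nn


class TrainConfig:
    """Illustrative hyperparameters; values marked as placeholders are not from the cited papers."""
    embedding_dim = 400     # embedding size preceding first-layer projections
    hidden_dim = 1024       # placeholder hidden size
    lr_start = 1e-3         # Adam starting learning rate, decayed during training
    dropout = 0.3           # variational dropout within the 0.2-0.5 range
    tbptt_len = 200         # truncated BPTT window
    gate_bias_init = -3.0   # placeholder; would initialize the forget-/input-gate bias of the cell


def init_recurrent_weights(module: nn.Module, gain: float = 1.0) -> None:
    """Scaled orthogonal initialization for hidden-to-hidden weight matrices."""
    for name, p in module.named_parameters():
        if p.dim() == 2 and "weight_hh" in name:
            nn.init.orthogonal_(p, gain=gain)


# Stand-in recurrent module for the sketch; a real setup would use an mLSTM(sig) cell.
model = nn.LSTM(TrainConfig.embedding_dim, TrainConfig.hidden_dim)
init_recurrent_weights(model)

# Adam with a decaying learning rate (the schedule form here is illustrative).
opt = torch.optim.Adam(model.parameters(), lr=TrainConfig.lr_start)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.98)
```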
Both classical and sigmoid-gated mLSTM variants exhibit stable optimization and non-degrading accuracy up to large parameter and context scales.
7. Applications and Modeling Power
mLSTM with sigmoid input gate excels in character-level and byte-level language modeling due to the flexibility of input-dependent gating and the capacity to instantiate complex, context-sensitive transition functions. The architecture’s capacity for long-range dependency modeling is maintained through its extended gating structure. The expressiveness introduced by the multiplicative input path, without incurring additional instabilities of exponential gating, renders mLSTMsig a preferred primitive for efficient, scalable sequence modeling, particularly in long-context and high-throughput deployment scenarios (Krause et al., 2016, Beck et al., 18 Mar 2025).