Blockwise Flash Kernel for TFLA
- Blockwise Flash Kernel is a GPU algorithm for efficient computation in linear RNNs using two-level tiling and parallel processing.
- It mitigates memory-I/O bottlenecks and raises arithmetic intensity through fused softmax-and-matmul operations and tunable chunk sizes.
- Performance benchmarks show significant speedups over state-of-the-art methods, enhancing both inference and training on modern accelerators.
Tiled Flash Linear Attention (TFLA), commonly referred to as the Blockwise Flash Kernel, is a two-level, sequence-parallel GPU kernel algorithm engineered for efficient computation in linear recurrent neural networks (RNNs) such as matrix-memory LSTMs (mLSTM/xLSTM). TFLA extends the chunkwise-parallel principles of Flash Linear Attention (FLA) by introducing an additional level of fine-grained tiling within each chunk, addressing bottlenecks in arithmetic intensity and memory I/O that hamper the scaling of long-context sequence models. This kernel enables both high arithmetic intensity and arbitrary large chunk sizes, resulting in significant performance improvements over prior state-of-the-art kernels, including Flash Attention, Linear Attention, and Mamba, on modern accelerators (Beck et al., 18 Mar 2025).
1. Mathematical Formulation
TFLA generalizes to any linear RNN with gating, but is illustrated concretely for the mLSTM/xLSTM cell. At each timestep $t$, the cell maintains:
- Hidden state $h_t \in \mathbb{R}^{d_{hv}}$
- Matrix memory $C_t \in \mathbb{R}^{d_{hv} \times d_{qk}}$ (together with the normalizer $n_t \in \mathbb{R}^{d_{qk}}$ and max-state $m_t$ for the exponential-gate variant)
- Per-timestep scalar gates: $i_t$ (input), $f_t$ (forget), with optional output gate $o_t$
For the standard "mLSTMexp" recurrence with exponential input gate, the (stabilized) update equations are:

$$
\begin{aligned}
m_t &= \max\bigl(\log\sigma(\tilde f_t) + m_{t-1},\ \tilde i_t\bigr), \qquad i_t = \exp(\tilde i_t - m_t), \qquad f_t = \exp\bigl(\log\sigma(\tilde f_t) + m_{t-1} - m_t\bigr),\\
C_t &= f_t\, C_{t-1} + i_t\, v_t\, k_t^{\top}, \qquad n_t = f_t\, n_{t-1} + i_t\, k_t,\\
h_t &= o_t \odot \frac{C_t\, q_t}{\max\{|n_t^{\top} q_t|,\ \exp(-m_t)\}}, \qquad o_t = \sigma(\tilde o_t),
\end{aligned}
$$

where $\sigma$ denotes the sigmoid, $\tilde i_t, \tilde f_t, \tilde o_t$ are the pre-activation gate values, and the cell output is subsequently passed through $\mathrm{Norm}(\cdot)$, which is either RMS- or Layer-Norm. The running max-state $m_t$ stabilizes the exp-input gate, ensuring numerical safety analogous to the safe softmax.
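The recurrence can be written as a short reference step in PyTorch; the following is a didactic sketch (the function name, argument layout, and the omission of the Norm layer and query scaling are illustrative choices, not the paper's kernel code):

```python
import torch

def mlstm_exp_step(C, n, m, q, k, v, i_pre, f_pre, o_pre):
    """One stabilized mLSTMexp recurrence step (didactic sketch, not the fused kernel).

    C: (d_hv, d_qk) matrix memory   n: (d_qk,) normalizer   m: 0-dim max-state tensor
    q, k: (d_qk,)   v: (d_hv,)      i_pre, f_pre, o_pre: 0-dim pre-activation gates
    """
    log_f = torch.nn.functional.logsigmoid(f_pre)     # log of the sigmoid forget gate
    m_new = torch.maximum(log_f + m, i_pre)           # running max keeps exponents <= 0
    i = torch.exp(i_pre - m_new)                      # stabilized exponential input gate
    f = torch.exp(log_f + m - m_new)                  # forget gate rescaled to the new max

    C_new = f * C + i * torch.outer(v, k)             # matrix-memory update
    n_new = f * n + i * k                             # normalizer update

    # Stabilized readout: the denominator uses exp(-m) because C and n carry a 1/exp(m) factor.
    h_tilde = (C_new @ q) / torch.maximum(torch.abs(n_new @ q), torch.exp(-m_new))
    h = torch.sigmoid(o_pre) * h_tilde                # output gate; Norm layer omitted
    return h, C_new, n_new, m_new
```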
A simplified variant, "mLSTMsig," replaces the unbounded exponential input gate with a bounded sigmoid gate,

$$
i_t = \sigma(\tilde i_t), \qquad f_t = \sigma(\tilde f_t), \qquad C_t = f_t\, C_{t-1} + i_t\, v_t\, k_t^{\top}, \qquad h_t = o_t \odot C_t\, q_t,
$$

where, by construction, no extra max-state or normalizer is required to prevent overflow.
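A corresponding single-step sketch for mLSTMsig, under the same illustrative conventions, shows that the bounded gates remove the need for the max-state and normalizer:

```python
import torch

def mlstm_sig_step(C, q, k, v, i_pre, f_pre, o_pre):
    """One mLSTMsig recurrence step (didactic sketch; Norm layer omitted)."""
    f = torch.sigmoid(f_pre)                # bounded forget gate in (0, 1)
    i = torch.sigmoid(i_pre)                # bounded input gate in (0, 1): nothing can overflow
    C_new = f * C + i * torch.outer(v, k)   # matrix-memory update
    h = torch.sigmoid(o_pre) * (C_new @ q)  # output gate applied to the readout
    return h, C_new
```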
2. Blockwise Tiling and Kernel Workflow
The TFLA kernel partitions the input sequence of length $T$ into $N_c = T/L$ chunks, each of length $L$. Within each chunk, the attention-style matrix $S = Q K^{\top} \in \mathbb{R}^{L \times L}$ is further divided into tiles of size $B_{Lq} \times B_{Lkv}$. The procedure in each chunk comprises:
- Recurrent Pass: Materializing the chunk-level recurrent state $C^{(k-1)}$ carried over from the previous chunk.
- Parallel Tiled Computation: For every tile within the chunk, the kernel accumulates the score tile of $S = Q K^{\top}$, constructs cumulative-forget and input-gate exponents, applies a numerically safe softmax-style rescaling in blockwise fashion, fuses the result, and performs the final matmul with the value matrix block $V$.
- Output Rescaling: Intra-chunk outputs are rescaled to align with the scale of the inter-chunk contribution.
- Inter-chunk Contribution: The state $C^{(k-1)}$ carried across chunks contributes via a single matmul with the gate-weighted queries of the current chunk.
- Output Combination: The final chunk output is a weighted sum of intra- and inter-chunk outputs, normalized for stability.
Pseudocode implementing this tiling and fusion appears explicitly in the kernel’s specification (Beck et al., 18 Mar 2025).
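To make the chunkwise decomposition concrete, the following non-fused PyTorch reference reproduces the same arithmetic for the simpler mLSTMsig variant at the level of whole chunks; the function name, tensor shapes, and omission of the output gate and Norm are illustrative simplifications, not the TFLA kernel itself:

```python
import torch

def mlstm_sig_chunkwise(q, k, v, i_pre, f_pre, chunk_size=64):
    """Non-fused chunkwise-parallel mLSTMsig forward pass (reference sketch).

    q, k: (T, d_qk)   v: (T, d_hv)   i_pre, f_pre: (T,) pre-activation gates
    Returns h: (T, d_hv); the output gate and Norm layer are omitted for brevity.
    """
    T, d_qk = q.shape
    d_hv = v.shape[1]
    L = chunk_size
    assert T % L == 0, "T must be a multiple of the chunk size"

    i = torch.sigmoid(i_pre)                          # bounded input gate
    log_f = torch.nn.functional.logsigmoid(f_pre)     # log of the sigmoid forget gate

    C = q.new_zeros(d_hv, d_qk)                       # state carried across chunks
    h = torch.empty_like(v)
    causal = torch.tril(torch.ones(L, L, dtype=torch.bool, device=q.device))

    for start in range(0, T, L):                      # recurrent pass over chunks
        sl = slice(start, start + L)
        qc, kc, vc, ic = q[sl], k[sl], v[sl], i[sl]
        a = torch.cumsum(log_f[sl], dim=0)            # cumulative log-forget inside the chunk

        # Intra-chunk part: causal, gate-weighted score matrix S fused with the value matmul.
        D = (a[:, None] - a[None, :]).masked_fill(~causal, float("-inf"))
        S = (qc @ kc.T) * torch.exp(D) * ic[None, :]
        h_intra = S @ vc

        # Inter-chunk part: contribution of the carried state, decayed to each position.
        h_inter = torch.exp(a)[:, None] * (qc @ C.T)
        h[sl] = h_intra + h_inter                     # combine intra- and inter-chunk outputs

        # Update the carried state for the next chunk.
        w = ic * torch.exp(a[-1] - a)                 # remaining decay up to the chunk boundary
        C = torch.exp(a[-1]) * C + (vc * w[:, None]).T @ kc

    return h
```

A fused kernel performs exactly this arithmetic but tiles $S$ and the value matmul on-chip, so that no $L \times L$ intermediate ever reaches HBM.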
3. GPU Implementation Details
TFLA exploits the modern GPU architecture using a three-dimensional thread-block grid:
- $(\text{batch} \times \text{heads} \times \text{chunks})$ for coarse parallelism
- $B_{Lq}$-sized tiles for the “rows” (query positions) of the intra-chunk computation
- $B_{dhv}$-sized tiles along the head/value dimension
Inside each thread-block (which produces one $B_{Lq} \times B_{dhv}$ output tile), the kernel:
- Loads $Q$, $K$, $V$ blocks into registers or shared memory
- Accumulates $S = Q K^{\top}$ and computes the safe-softmax-style statistics (running max and normalizer) with maximal data reuse
- Fuses the elementwise multiplications and the final value matrix multiplication (via tensor cores; e.g., Triton’s `tl.dot` or CUDA’s WMMA)
- Synchronizes only within tiles using `__syncthreads()`, avoiding the need for global barrier synchronization
On-chip SRAM buffers ($B_{Lq} \times B_{dhv}$ floats per accumulator tile) are used for blockwise accumulators, reducing global memory traffic and maximizing arithmetic intensity by (a) reusing $Q$ and $K$ reads, (b) fusing the safe-softmax-style rescaling with block accumulation, and (c) streaming $V$ efficiently.
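The grid layout and intra-chunk fusion can be illustrated with a stripped-down Triton sketch that computes only $H_{\text{intra}} = \mathrm{tril}(Q K^{\top})\, V$ for one chunk, without gates, normalizer, or the inter-chunk term; the kernel and helper names, block-size constants, and the contiguous $(\text{batch}\cdot\text{heads}\cdot\text{chunks},\, L,\, d)$ layout are assumptions for illustration, not the released TFLA kernels:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def intra_chunk_tile_kernel(Q, K, V, H,
                            L: tl.constexpr, DQK: tl.constexpr, DHV: tl.constexpr,
                            BLQ: tl.constexpr, BLKV: tl.constexpr, BDHV: tl.constexpr):
    # Grid axis 0: batch*heads*chunks, axis 1: query-row tiles, axis 2: value-dim tiles.
    pid_bhc = tl.program_id(0)
    pid_q = tl.program_id(1)
    pid_v = tl.program_id(2)

    offs_q = pid_q * BLQ + tl.arange(0, BLQ)        # query rows of this tile
    offs_d = tl.arange(0, DQK)                      # full key/query dimension (not tiled here)
    offs_v = pid_v * BDHV + tl.arange(0, BDHV)      # slice of the value/head dimension
    base = pid_bhc * L                              # first row of this (batch, head, chunk)

    # Load the query tile once; it is reused against every key/value tile below.
    q = tl.load(Q + (base + offs_q)[:, None] * DQK + offs_d[None, :])

    acc = tl.zeros((BLQ, BDHV), dtype=tl.float32)   # on-chip accumulator tile
    for kv_start in range(0, L, BLKV):
        offs_kv = kv_start + tl.arange(0, BLKV)
        k = tl.load(K + (base + offs_kv)[:, None] * DQK + offs_d[None, :])
        v = tl.load(V + (base + offs_kv)[:, None] * DHV + offs_v[None, :])
        s = tl.dot(q, tl.trans(k))                  # (BLQ, BLKV) score tile on tensor cores
        s = tl.where(offs_q[:, None] >= offs_kv[None, :], s, 0.0)  # causal mask inside chunk
        acc += tl.dot(s.to(v.dtype), v)             # fuse the score tile with the value matmul
    tl.store(H + (base + offs_q)[:, None] * DHV + offs_v[None, :], acc)


def intra_chunk_forward(q, k, v, BLQ=32, BLKV=32, BDHV=64):
    """q, k: (B, L, DQK), v: (B, L, DHV), float32 and contiguous, with B = batch*heads*chunks.
    Block sizes and DQK should be multiples of 16 so tl.dot maps to tensor cores."""
    B, L, DQK = q.shape
    DHV = v.shape[-1]
    h = torch.empty(B, L, DHV, device=q.device, dtype=q.dtype)
    grid = (B, L // BLQ, DHV // BDHV)               # three-dimensional launch grid
    intra_chunk_tile_kernel[grid](q, k, v, h, L=L, DQK=DQK, DHV=DHV,
                                  BLQ=BLQ, BLKV=BLKV, BDHV=BDHV)
    return h
```

The actual kernels additionally apply the cumulative gate factors to the score tiles, carry the safe-softmax statistics across the key/value loop, and add the inter-chunk contribution before the final store.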
4. Complexity and Memory Analysis
For a chunk of length $L$ and intra-chunk tile length $B_{Lq}$, with key/query dimension $d_{qk}$ and value dimension $d_{hv}$:
- FLOPs/chunk (mLSTMsig):
  - Recurrent pass: $O(L\, d_{qk}\, d_{hv})$ for updating the chunk state
  - Intra-chunk: $O\!\bigl(L^2\,(d_{qk} + d_{hv})\bigr)$ for the score matrix and its value matmul
  - Inter-chunk matmul: $O(L\, d_{qk}\, d_{hv})$
  - Total: $O\!\bigl(L^2\,(d_{qk} + d_{hv}) + L\, d_{qk}\, d_{hv}\bigr)$
- Memory I/O/chunk:
  - Recurrent: reads $C^{(k-1)}$ and writes $C^{(k)}$, $d_{qk}\, d_{hv}$ values each
  - Parallel: reads $Q, K, V$ ($L\,(2 d_{qk} + d_{hv})$ values), writes $H$ ($L\, d_{hv}$ values)
  - Total: $O\!\bigl(L\,(d_{qk} + d_{hv}) + d_{qk}\, d_{hv}\bigr)$
Comparison with other kernels:
- Flash Attention requires $O(T^2 d)$ FLOPs and $O(T^2 d^2 / M)$ HBM I/O (for on-chip memory of size $M$), both quadratic in the sequence length $T$.
- FLA requires only $O\!\bigl(T L\,(d_{qk}+d_{hv}) + T\, d_{qk}\, d_{hv}\bigr)$ FLOPs, but suffers from extra I/O due to intermediate state storage of size $d_{qk} \times d_{hv}$ for each of the $T/L$ chunks.
- TFLA eliminates this memory bottleneck via intra-chunk tiling and maximally fused compute.
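The transition between the memory-bound and compute-bound regimes can be made concrete with a back-of-the-envelope cost model; the constant factors, head dimensions, and element size in the sketch below are illustrative assumptions, not measured values from the paper:

```python
def chunk_cost_model(L, d_qk=256, d_hv=512, bytes_per_el=2):
    """Rough FLOP / DRAM-byte tallies per chunk for a chunkwise linear-RNN kernel.

    The leading terms follow the big-O analysis above; the constant factors and the
    chosen head dimensions are illustrative assumptions, not measurements.
    """
    # FLOPs: recurrent state update, intra-chunk S = Q K^T and S @ V, inter-chunk query matmul.
    flops_recurrent = 2 * L * d_qk * d_hv
    flops_intra = 2 * L * L * (d_qk + d_hv)
    flops_inter = 2 * L * d_qk * d_hv
    flops = flops_recurrent + flops_intra + flops_inter

    # DRAM traffic: read Q, K, V and the carried state, write H and the new state.
    io_elements = L * (2 * d_qk + 2 * d_hv) + 2 * d_qk * d_hv
    io_bytes = bytes_per_el * io_elements

    return flops, io_bytes, flops / io_bytes          # last value = arithmetic intensity


if __name__ == "__main__":
    for L in (64, 128, 256, 512, 1024):
        flops, io, intensity = chunk_cost_model(L)
        print(f"L={L:5d}  GFLOP/chunk={flops / 1e9:6.2f}  "
              f"MiB/chunk={io / 2**20:6.2f}  FLOP/byte={intensity:7.1f}")
```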
5. Performance Benchmarks
Empirical evaluations were conducted with TFLA-mLSTMexp and TFLA-mLSTMsig on NVIDIA H100 GPUs for long-context benchmarks across a range of sequence lengths, batch sizes, head counts, and head dimensions:
- Inference (forward only):
- TFLA-mLSTMsig is approximately 30% faster than TFLA-mLSTMexp.
- TFLA-mLSTMsig is 20–40% faster than FlashAttention 3 and over 3× faster than Mamba 2.
- Training (forward+backward):
- TFLA-mLSTMsig achieves a 2× speedup compared to Mamba 2.
- TFLA-mLSTMsig matches or surpasses FlashAttention 3 performance for sequence lengths above 4k tokens.
Varying the chunk size $L$ controls the trade-off between memory usage and runtime: smaller $L$ yields more stored chunk states (higher memory, lower compute), while larger $L$ gives fewer stored states (lower memory, higher compute). On H100, optimal performance occurs at intermediate chunk sizes (Beck et al., 18 Mar 2025).
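In symbols, restating the Section 4 analysis rather than adding a new result, the two terms pull in opposite directions as $L$ grows:

$$
\underbrace{\tfrac{T}{L}\, d_{qk}\, d_{hv}}_{\text{stored chunk states}} \;\downarrow
\qquad \text{while} \qquad
\underbrace{T\, L\,(d_{qk}+d_{hv})}_{\text{intra-chunk FLOPs}} \;\uparrow
\qquad \text{as } L \text{ increases,}
$$

so doubling $L$ halves the number of materialized chunk states but doubles the quadratic intra-chunk work; the best chunk size sits where the kernel crosses from memory-bound to compute-bound.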
6. Comparative Evaluation and Trade-offs
TFLA achieves several critical advancements:
- Support for arbitrarily large chunks without proportional increase in GPU state storage
- High arithmetic intensity by block-fusing softmax–matmul operations for peak tensor-core utilization
- A tunable chunk size parameter for direct control over DRAM-I/O versus compute-bound performance and peak memory
- Broad applicability as an efficient drop-in building block for long-context RNN architectures
A direct comparison is summarized below:
| Kernel | Arithmetic Complexity | Memory I/O | Bottleneck |
|---|---|---|---|
| Flash Attention | $O(T^2 d)$ | $O(T^2 d^2 / M)$ | Quadratic time/memory |
| Linear FLA | $O\!\bigl(T L\,(d_{qk}+d_{hv}) + T\, d_{qk}\, d_{hv}\bigr)$ | Chunk states of size $d_{qk} \times d_{hv}$ materialized for all $T/L$ chunks | Intermediate state materialization |
| TFLA (Blockwise) | $O\!\bigl(T L\,(d_{qk}+d_{hv}) + T\, d_{qk}\, d_{hv}\bigr)$ | Chunk states only; large $L$ keeps state I/O small | None (maximal tiling) |
The table is strictly constructed from information present in (Beck et al., 18 Mar 2025).
7. Context and Applications
TFLA advances the practical deployment of linear RNNs for long-context sequence modeling. By enabling high levels of parallelism and efficient hardware utilization, it supports accelerated training and inference for language-modeling tasks at scales previously prohibitive due to memory and throughput constraints. The kernel’s two-level tiling design and flexible configuration make it a key enabling technology for integrating mLSTM or xLSTM units into pipelines intended for tasks requiring efficient long-sequence processing, particularly where state-of-the-art GPU acceleration is available (Beck et al., 18 Mar 2025).