
Blockwise Flash Kernel for TFLA

Updated 10 December 2025
  • Blockwise Flash Kernel is a GPU algorithm for efficient computation in linear RNNs using two-level tiling and parallel processing.
  • It mitigates bottlenecks in arithmetic intensity and memory I/O through fused softmax-matmul operations and tunable chunk sizes.
  • Performance benchmarks show significant speedups over state-of-the-art methods, enhancing both inference and training on modern accelerators.

Tiled Flash Linear Attention (TFLA), commonly referred to as the Blockwise Flash Kernel, is a two-level, sequence-parallel GPU kernel algorithm engineered for efficient computation in linear recurrent neural networks (RNNs) such as matrix-memory LSTMs (mLSTM/xLSTM). TFLA extends the chunkwise-parallel principles of Flash Linear Attention (FLA) by introducing an additional level of fine-grained tiling within each chunk, addressing the bottlenecks in arithmetic intensity and memory I/O that hamper the scaling of long-context sequence models. This kernel enables both high arithmetic intensity and arbitrarily large chunk sizes, resulting in significant performance improvements over prior state-of-the-art kernels, including FlashAttention, Flash Linear Attention, and Mamba, on modern accelerators (Beck et al., 18 Mar 2025).

1. Mathematical Formulation

TFLA generalizes to any linear RNN with gating, but is illustrated concretely for the mLSTM/xLSTM cell. At each timestep $t$, the cell maintains:

  • Hidden state $h_t \in \mathbb{R}^d$
  • Matrix memory $M_t \in \mathbb{R}^{d_q \times d}$
  • Per-timestep scalar gates: $i_t$ (input), $f_t$ (forget), with optional output gate $o_t$

For the standard "mLSTMexp" recurrence with exponential input gate, the update equations are:

$$m_t = \max\left(\log \sigma(\tilde f_t) + m_{t-1},\; \tilde i_t\right)$$

$$f_t = \exp\left(\log \sigma(\tilde f_t) + m_{t-1} - m_t\right), \quad i_t = \exp\left(\tilde i_t - m_t\right)$$

$$M_t = f_t \cdot M_{t-1} + i_t \cdot (k_t v_t^\top)$$

$$\tilde h_t = M_t^\top (q_t/\sqrt{d_q}), \quad h_t = \sigma(\tilde o_t) \odot \mathrm{NORM}(\tilde h_t)$$

where $\sigma$ denotes the sigmoid and $\mathrm{NORM}(\cdot)$ is either RMS- or LayerNorm. The running max-state $m_t$ stabilizes the exponential input gate, ensuring numerical safety analogous to the safe softmax.
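
The recurrence can be written down directly as a sequential reference. Below is a minimal PyTorch sketch (not the fused TFLA kernel), assuming the pre-activations $\tilde i_t$, $\tilde f_t$, $\tilde o_t$ are given as tensors and taking the RMS variant of $\mathrm{NORM}$; the function names and the zero initialization of $M$ and $m$ are illustrative assumptions, not the paper's API.

```python
import torch
import torch.nn.functional as F

def rms_norm(x, eps=1e-6):
    # NORM(.) from the text, RMS variant.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

def mlstm_exp_step(M, m, q, k, v, i_pre, f_pre, o_pre):
    """One timestep of the mLSTMexp recurrence as written above (sequential reference).

    Shapes: M (d_q, d); m scalar max-state; q, k (d_q,); v, o_pre (d,);
    i_pre, f_pre scalar pre-activation tensors. Start with M = 0, m = 0.
    """
    d_q = q.shape[-1]
    log_f = F.logsigmoid(f_pre)
    m_new = torch.maximum(log_f + m, i_pre)            # running max-state stabilizer
    f_gate = torch.exp(log_f + m - m_new)              # stabilized forget gate
    i_gate = torch.exp(i_pre - m_new)                  # stabilized exponential input gate
    M_new = f_gate * M + i_gate * torch.outer(k, v)    # matrix-memory update
    h_tilde = M_new.T @ (q / d_q ** 0.5)               # query the matrix memory
    h = torch.sigmoid(o_pre) * rms_norm(h_tilde)       # output gate + normalization
    return h, M_new, m_new
```

Stepping this function over a sequence produces the reference outputs that the chunkwise kernel reproduces in parallel.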

A simplified variant, "mLSTMsig," replaces the unbounded exponential input gate with bounded sigmoid gates,

$$f_t = \sigma(\tilde f_t), \qquad i_t = \sigma(\tilde i_t)$$

$$M_t = f_t \cdot M_{t-1} + i_t \cdot (k_t v_t^\top)$$

$$h_t = \sigma(\tilde o_t) \odot \mathrm{NORM}\!\left(M_t^\top (q_t/\sqrt{d_q})\right)$$

where, by construction, no extra max-state or normalizer is required to prevent overflow.
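
For comparison, a corresponding sketch of the sigmoidal variant (reusing `rms_norm` and the conventions from the previous snippet) drops the max-state entirely:

```python
def mlstm_sig_step(M, q, k, v, i_pre, f_pre, o_pre):
    """One timestep of the mLSTMsig recurrence: bounded sigmoid gates, no max-state."""
    d_q = q.shape[-1]
    f_gate = torch.sigmoid(f_pre)                      # bounded forget gate
    i_gate = torch.sigmoid(i_pre)                      # bounded input gate
    M_new = f_gate * M + i_gate * torch.outer(k, v)    # matrix-memory update
    h = torch.sigmoid(o_pre) * rms_norm(M_new.T @ (q / d_q ** 0.5))
    return h, M_new
```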

2. Blockwise Tiling and Kernel Workflow

The TFLA kernel partitions the input sequence of length $T$ into $N_c = \lceil T/C \rceil$ chunks, each of length $C$. Within each chunk, the $C \times C$ attention-style matrix is further divided into $(C/B)^2$ tiles of size $B \times B$, where $B$ denotes the tile length of the second tiling level. The procedure in each chunk comprises:

  1. Recurrent Pass: Materializing the recurrent state $M_{k-1}$ from the previous chunk.
  2. Parallel Tiled Computation: For every $B \times B$ tile within the chunk, the kernel accumulates $QK^{\top}$, constructs cumulative-forget and input-gate exponents, applies the numerically safe (softmax-style) exponential rescaling in blockwise fashion, fuses the result, and performs the final matmul with the value-matrix block $V$.
  3. Output Rescaling: Intra-chunk outputs are rescaled to align with the scale of the inter-chunk contribution.
  4. Inter-chunk Contribution: The state carried in from previous chunks contributes via $H_\mathrm{inter} = \overline{Q} \cdot M_{k-1}$.
  5. Output Combination: The final chunk output is a weighted sum of intra- and inter-chunk outputs, normalized for stability.

Pseudocode implementing this tiling and fusion appears explicitly in the kernel’s specification (Beck et al., 18 Mar 2025).
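
To make the chunkwise structure concrete, the following is a plain PyTorch emulation of the workflow for the simpler mLSTMsig gating (single head, single sequence). It computes the same intra-chunk and inter-chunk contributions described above, but materializes the full $C \times C$ gate matrix per chunk and therefore omits the second $B \times B$ tiling level and the max-state rescaling needed on the mLSTMexp path; all function and variable names are illustrative.

```python
import torch
import torch.nn.functional as F

def mlstm_sig_chunkwise(q, k, v, i_pre, f_pre, o_pre, chunk_size=64, eps=1e-6):
    """Chunkwise-parallel reference for mLSTMsig (single head).

    q, k: (T, d_q); v, o_pre: (T, d); i_pre, f_pre: (T,).
    """
    T, d_q = q.shape
    d = v.shape[-1]
    C = chunk_size
    assert T % C == 0, "for brevity, T is assumed to be a multiple of the chunk size"

    M = torch.zeros(d_q, d, dtype=q.dtype)               # recurrent inter-chunk state
    outputs = []
    for start in range(0, T, C):
        sl = slice(start, start + C)
        qc = q[sl] / d_q ** 0.5                           # pre-scaled queries, (C, d_q)
        kc, vc = k[sl], v[sl]
        i_gate = torch.sigmoid(i_pre[sl])                 # (C,) input gates
        cum = torch.cumsum(F.logsigmoid(f_pre[sl]), 0)    # cumulative log forget gates

        # Intra-chunk: causal, gate-weighted attention-style matrix.
        causal = torch.tril(torch.ones(C, C, dtype=torch.bool))
        decay = (cum[:, None] - cum[None, :]).masked_fill(~causal, float("-inf"))
        D = torch.exp(decay) * i_gate[None, :]            # D[t, s] = (prod of forgets) * i_s, s <= t
        h_intra = (qc @ kc.T * D) @ vc                    # gating fused with the value matmul

        # Inter-chunk: contribution of the carried-in state, decayed within the chunk.
        h_inter = torch.exp(cum)[:, None] * (qc @ M)

        # Combine contributions, normalize, apply the output gate.
        h_tilde = h_intra + h_inter
        h_norm = h_tilde * torch.rsqrt(h_tilde.pow(2).mean(-1, keepdim=True) + eps)
        outputs.append(torch.sigmoid(o_pre[sl]) * h_norm)

        # Recurrent pass: carry the matrix memory to the next chunk.
        w = torch.exp(cum[-1] - cum) * i_gate             # decay of each step to the chunk end
        M = torch.exp(cum[-1]) * M + (w[:, None] * kc).T @ vc
    return torch.cat(outputs, dim=0)                      # (T, d)
```

The GPU kernel computes exactly these intra- and inter-chunk terms, but streams the $C \times C$ matrix through $B \times B$ tiles so that it never leaves on-chip memory.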

3. GPU Implementation Details

TFLA exploits modern GPU architectures using a three-dimensional thread-block grid:

  • $N_\mathrm{batch} \times N_\mathrm{head} \times N_\mathrm{chunk}$ for coarse parallelism
  • $C/B$ tiles for the “rows” of the intra-chunk computation
  • $d_h/B$ tiles along the head/value dimension

Inside each thread-block (of size $B \times B$), the kernel:

  • Loads $Q$, $K$, $V$ blocks into registers or shared memory
  • Accumulates $QK^\top$ and computes softmax statistics with maximal data reuse
  • Fuses the elementwise multiplications and the final value matrix multiplication (via tensor cores; e.g., Triton’s tl.dot or CUDA’s WMMA)
  • Synchronizes only within tiles using __syncthreads(), avoiding the need for global barrier synchronization

On-chip SRAM buffers ($\sim B^2$ floats per tile) are used for blockwise accumulators, reducing global memory traffic and maximizing arithmetic intensity by (a) reusing $Q$ and $K$ reads, (b) fusing the safe-softmax with block accumulation, and (c) streaming $V$ efficiently.
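
The second tiling level can be emulated on the host for the intra-chunk part of a single chunk, again with mLSTMsig gating for brevity (the mLSTMexp path additionally carries a running row-wise max for rescaling, in the spirit of the online softmax). The sketch below reuses the per-chunk quantities `qc`, `kc`, `vc`, `i_gate`, and `cum` from the previous section's snippet; it illustrates the data flow only and is not GPU code.

```python
import torch

def intra_chunk_tiled(qc, kc, vc, i_gate, cum, tile=16):
    """Intra-chunk output computed tile by tile, never materializing the full C x C matrix.

    qc: (C, d_q) pre-scaled queries; kc: (C, d_q); vc: (C, d); i_gate, cum: (C,).
    """
    C, d = qc.shape[0], vc.shape[-1]
    B = tile
    assert C % B == 0
    out = torch.zeros(C, d, dtype=qc.dtype)
    for qs in range(0, C, B):                          # query (row) tiles stay resident
        q_blk, cum_q = qc[qs:qs + B], cum[qs:qs + B]
        acc = torch.zeros(B, d, dtype=qc.dtype)        # per-tile accumulator ("SRAM buffer")
        for ks in range(0, qs + B, B):                 # stream key/value tiles up to the diagonal
            k_blk, v_blk = kc[ks:ks + B], vc[ks:ks + B]
            s_tile = q_blk @ k_blk.T                   # partial QK^T for this B x B tile
            decay = (cum_q[:, None] - cum[ks:ks + B][None, :]).clamp(max=0.0)
            gate = torch.exp(decay) * i_gate[ks:ks + B][None, :]
            if ks == qs:                               # diagonal tile: enforce causality
                mask = torch.tril(torch.ones(B, B, dtype=torch.bool))
                gate = gate.masked_fill(~mask, 0.0)
            acc += (s_tile * gate) @ v_blk             # fuse gating with the value matmul
        out[qs:qs + B] = acc
    return out
```

Its output matches the `h_intra` of the chunkwise reference up to floating-point error; on the GPU, `q_blk` and the accumulator stay resident in registers or shared memory while key and value tiles stream through, which is the data-reuse pattern described above.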

4. Complexity and Memory Analysis

For a chunk of length $C$ and tile length $B$:

  • FLOPs/chunk (mLSTMsig):
    • Recurrent pass: $O(d_{qk} d_h + C\, d_{qk} d_h)$
    • Intra-chunk: $\sum_{\mathrm{tiles}} \left[ 2B\, d_{qk} d_h + 3B^2 + O(B^2) \right] = O(C\, d_{qk} d_h + C^2)$
    • Inter-chunk matmul: $O(C\, d_{qk} d_h)$
    • Total: $O(C\, d_{qk} d_h + C^2)$
  • Memory I/O/chunk:
    • Recurrent pass: reads $C(d_{qk}+d_h)$, writes $d_{qk} d_h$
    • Parallel pass: reads $C(2d_{qk}+d_h)$, writes $C d_h$
    • Total: $O(C d + d^2)$
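
As a back-of-the-envelope illustration, the leading-order expressions above can be plugged into a small helper. The constant factors and the bf16 element size below are assumptions for illustration, not measured values from the paper.

```python
def tfla_chunk_costs(C, d_qk, d_h, elem_bytes=2):
    """Rough per-chunk FLOP and memory-traffic counts from the expressions above.

    Leading terms only; the constant factors are illustrative, not measured.
    """
    flops_recurrent = C * d_qk * d_h            # state update over the chunk
    flops_intra = C * d_qk * d_h + C * C        # tiled QK^T, gating, and value-matmul terms
    flops_inter = C * d_qk * d_h                # H_inter = Qbar @ M_{k-1}
    reads = C * (3 * d_qk + 2 * d_h)            # Q once; K and V in both passes
    writes = C * d_h + d_qk * d_h               # chunk outputs + recurrent state
    return {
        "flops": flops_recurrent + flops_intra + flops_inter,
        "bytes": elem_bytes * (reads + writes),
    }

# Arithmetic intensity (FLOP/byte) rises with the chunk size C, which is the
# lever TFLA exposes for trading DRAM traffic against compute.
costs = tfla_chunk_costs(C=256, d_qk=256, d_h=256)
print(costs["flops"] / costs["bytes"], "FLOP/byte")
```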

Comparison with other kernels:

  • Flash Attention requires $O(T^2 d)$ FLOPs and $O(T^2)$ memory I/O.
  • FLA requires $O(C d^2 + C^2)$ FLOPs but suffers from $O((T/C)\, d^2)$ extra I/O due to intermediate state storage for each chunk.
  • TFLA eliminates this memory bottleneck via intra-chunk tiling and maximally fused compute.

5. Performance Benchmarks

Empirical evaluations were conducted with TFLA-mLSTMexp and TFLA-mLSTMsig on NVIDIA H100 GPUs for long-context benchmarks (embedding dimension 4096, sequence length 65536, head dimension 256, $\text{batch} \times T = 65536$ tokens):

  • Inference (forward only):
    • TFLA-mLSTMsig is approximately 30% faster than TFLA-mLSTMexp.
    • TFLA-mLSTMsig is 20–40% faster than FlashAttention 3 and over 3× faster than Mamba 2.
  • Training (forward+backward):
    • TFLA-mLSTMsig achieves a 2× speedup compared to Mamba 2.
    • TFLA-mLSTMsig matches or surpasses FlashAttention 3 performance for sequence lengths above 4k tokens.

Varying the chunk size $C$ controls the trade-off between memory usage and runtime: smaller $C$ yields more stored states (higher memory, lower compute), while larger $C$ gives fewer stored states (lower memory, higher compute). On H100, optimal performance occurs for $C = 128\ldots256$ and $B = 32\ldots64$ (Beck et al., 18 Mar 2025).
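
The memory side of this trade-off can be made concrete: one $d_{qk} \times d_h$ state is materialized per chunk and head. The numbers below (fp32 states, 16 heads, batch size 1) are illustrative assumptions rather than the paper's measured configuration.

```python
def materialized_state_bytes(T, C, d_qk=256, d_h=256, n_heads=16, elem_bytes=4):
    """Memory for the inter-chunk states materialized in GPU memory:
    one d_qk x d_h matrix per chunk and head (batch size 1, fp32)."""
    n_states = T // C                        # smaller C -> more chunks -> more stored states
    return n_states * n_heads * d_qk * d_h * elem_bytes

print(materialized_state_bytes(T=65536, C=128) / 2**30)   # 2.0 GiB
print(materialized_state_bytes(T=65536, C=256) / 2**30)   # 1.0 GiB
```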

6. Comparative Evaluation and Trade-offs

TFLA achieves several critical advancements:

  • Support for arbitrarily large chunks without proportional increase in GPU state storage
  • High arithmetic intensity by block-fusing softmax–matmul operations for peak tensor-core utilization
  • A tunable chunk-size parameter $C$ for direct control over DRAM-I/O versus compute-bound performance and peak memory
  • Broad applicability as an efficient drop-in building block for long-context RNN architectures

A direct comparison is summarized below:

| Kernel | Arithmetic complexity | Memory I/O | Bottleneck |
| --- | --- | --- | --- |
| Flash Attention | $O(T^2 d)$ | $O(T^2)$ | Quadratic time/memory |
| Linear FLA | $O(C d^2 + C^2)$ | $O((T/C)\, d^2)$ | Intermediate state materialization |
| TFLA (Blockwise) | $O(C\, d_{qk} d_h + C^2)$ | $O(C d + d^2)$ | None (maximal tiling) |

The table is strictly constructed from information present in (Beck et al., 18 Mar 2025).

7. Context and Applications

TFLA advances the practical deployment of linear RNNs for long-context sequence modeling. By enabling high levels of parallelism and efficient hardware utilization, it supports accelerated training and inference for language-modeling tasks at scales previously prohibitive due to memory and throughput constraints. The kernel's two-level tiling design and flexible configuration make it a key enabling technology for integrating mLSTM or xLSTM units into pipelines intended for tasks requiring efficient long-sequence processing, particularly where state-of-the-art GPU acceleration is available (Beck et al., 18 Mar 2025).

References

  • Beck et al. Tiled Flash Linear Attention: More Efficient Linear RNN and xLSTM Kernels. arXiv preprint, 18 March 2025.