Tiled Flash Linear Attention (TFLA)

Updated 3 March 2026
  • Tiled Flash Linear Attention (TFLA) is a hardware-optimized algorithm for linear attention and RNNs that uses a two-level tiling strategy to maximize compute and memory efficiency.
  • It generalizes tiling principles from FlashAttention and FlashLinearAttention by introducing intra-chunk tiling, thus removing chunk size limits and enhancing arithmetic intensity.
  • TFLA achieves state-of-the-art performance on modern GPUs, reducing kernel runtimes by over 25% and improving throughput for long-context sequence modeling.

Tiled Flash Linear Attention (TFLA) is a class of hardware-optimized kernel algorithms for linear attention and linear recurrent neural networks (RNNs) that achieves both high arithmetic intensity and memory efficiency for long-context sequence modeling. TFLA generalizes the tiling and chunking principles of FlashAttention and FlashLinearAttention, extending them with a second level of intra-chunk tiling to remove chunk size limits and further reduce memory input/output (I/O). TFLA has enabled state-of-the-art kernel runtimes for large-memory RNNs such as mLSTM and xLSTM, and allows linear attention models to realize their theoretical O(T) compute scaling on modern GPU and accelerator hardware (Beck et al., 18 Mar 2025, Kao et al., 2021, Hua et al., 2022).

1. Precedent Architectures: From Quadratic Attention to Flash Linear Attention

Classic attention mechanisms have quadratic complexity in sequence length due to the computation and storage of the full T × T attention matrix. FlashAttention introduced a tiling-based kernel fused across queries and keys that never materializes the full attention matrix in high-bandwidth memory (HBM), minimizing I/O but retaining O(T²) compute (Kao et al., 2021).

Linear attention methods—including kernelized attention and linear RNN formalisms—achieve O(T) compute by exploiting associativity to reformulate attention as a streaming or chunkwise linear update, as in (Yang et al., 2023). FlashLinearAttention applies chunkwise parallelism: it splits the full sequence into N = ⌈T/L⌉ chunks, materializes per-chunk RNN states, and parallelizes intra-chunk computation, but is practically limited by shared SRAM and memory bandwidth. When the chunk size L is small, arithmetic intensity (FLOPs per byte transferred) remains low and DRAM traffic becomes the primary bottleneck (Beck et al., 18 Mar 2025, Yang et al., 2023).
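The chunkwise scheme can be made concrete with a short reference model. The following is an illustrative NumPy sketch of plain (ungated) linear attention — the function name and default chunk size are ours, and a real FlashLinearAttention kernel fuses these loops on-chip rather than materializing per-chunk intermediates in DRAM:

```python
import numpy as np

def chunkwise_linear_attention(Q, K, V, L=64):
    """Chunkwise linear attention: O(T) in sequence length.

    Q, K, V have shape (T, d); the sequence is split into T/L chunks.
    Earlier chunks contribute through the running d x d state S; the
    current chunk is handled with ordinary causal attention.
    """
    T, d = Q.shape
    assert T % L == 0, "illustrative sketch: T must be a multiple of L"
    S = np.zeros((d, d))                   # running sum of k_s v_s^T
    out = np.empty_like(V)
    causal = np.tril(np.ones((L, L)))      # causal mask within a chunk
    for c in range(0, T, L):
        q, k, v = Q[c:c+L], K[c:c+L], V[c:c+L]
        inter = q @ S                      # contribution of all earlier chunks
        intra = ((q @ k.T) * causal) @ v   # causal attention inside the chunk
        out[c:c+L] = inter + intra
        S += k.T @ v                       # carry state to the next chunk
    return out
```

The result matches the naive O(T²) computation ((QKᵀ ⊙ causal mask)V) exactly; only the cost and memory-traffic profile differs.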

2. Core Algorithmic Principles of TFLA

TFLA removes the hardware-imposed ceiling on chunk size L by introducing a second level of sequence parallelism—tile parallelism—inside each chunk. This makes TFLA a two-level parallel architecture:

  • Level 1 (Chunk-level parallelization): The sequence is partitioned into chunks of size L. Each chunk's starting RNN state is persisted in DRAM.
  • Level 2 (Tile-level parallelization within a chunk): Each chunk's intra-chunk operations (self-attention, gating, matmul) are parallelized across a 2D grid of thread blocks, processing the L × L local attention or recurrence matrix via tiled GEMMs. This enables L to be set arbitrarily large, maximizing data reuse, tensor-core occupancy, and arithmetic intensity (Beck et al., 18 Mar 2025).

This two-level structure applies to both standard linear attention and gated/forgetful RNNs. All per-chunk states (e.g., C_k, n_k, m_k for mLSTM variants) are updated via a recurrent kernel and then consumed by the intra-chunk parallel kernel. The tiling within each chunk enables all critical matrix multiplications (e.g., QK^T, QV, KV) to be performed on accelerator tensor cores while reading each data block exactly once from HBM into shared SRAM.
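A minimal sketch of the Level-2 access pattern, assuming hypothetical tile sizes BL and Bd: in a real TFLA kernel the outer output-row tile loop is the axis mapped to parallel thread blocks and the inner tiles live in SRAM, whereas here plain NumPy loops only illustrate which blocks are touched.

```python
import numpy as np

def intra_chunk_tiled_matmul(q, k, v, BL=32, Bd=16):
    """Tiled computation of the causal intra-chunk product ((q k^T) * M) v.

    q, k: (L, d), v: (L, dv). Output-row tiles (i0) are independent and
    would run in parallel thread blocks; only tiles on or below the
    diagonal of the L x L matrix are ever computed.
    """
    L, d = q.shape
    assert L % BL == 0 and d % Bd == 0
    out = np.zeros((L, v.shape[1]))
    for i0 in range(0, L, BL):                 # parallel axis in a real kernel
        acc = np.zeros((BL, v.shape[1]))
        for j0 in range(0, i0 + BL, BL):       # skip tiles above the diagonal
            s = np.zeros((BL, BL))
            for d0 in range(0, d, Bd):         # reduction over feature tiles
                s += q[i0:i0+BL, d0:d0+Bd] @ k[j0:j0+BL, d0:d0+Bd].T
            if j0 == i0:
                s *= np.tril(np.ones((BL, BL)))  # mask only the diagonal tile
            acc += s @ v[j0:j0+BL]
        out[i0:i0+BL] = acc
    return out
```

Because each i0 tile writes a disjoint slice of the output, the outer loop needs no synchronization — this independence is what lets L grow without limiting parallelism.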

3. Mathematical Formulation and Implementation

For a generic linear RNN (e.g., mLSTM) under TFLA, the algorithm proceeds as follows:

  • Recurrent update (per chunk k):

C_k = \tilde{r}_k C_{k-1} + (A_k \odot K^{(k)})^\top V^{(k)}, \qquad n_k = \tilde{r}_k n_{k-1} + (A_k \odot K^{(k)})^\top \mathbf{1}

with r̃_k and A_k determined by gate preactivations and normalization.

  • Tile-parallel intra-chunk computation:

Partition Q^{(k)}, K^{(k)}, V^{(k)} into tiles sized to the SRAM budget (e.g., B_{Lhq} × B_{dqk}) and process all output rows via batched matmuls:

\hat{S}^{(k)} = (Q K^\top / \sqrt{d_{qk}}) \odot D^{(k)}

H^{(k)} = \underbrace{(Q^{(k)}/\sqrt{d_{qk}} \odot \tilde{r}_k)\, C_{k-1}}_{\text{inter-chunk}} + \underbrace{\hat{S}^{(k)} V^{(k)}}_{\text{intra-chunk}}

\text{Final output:}\quad H^{(k)}_{\text{out}} = H_{\text{inter}} + \exp(m_{\text{old}} - m_k)\, H_{\text{intra}}

H^{(k)} = H^{(k)}_{\text{out}} / \text{Normalizer}

TFLA kernels are implemented in either Triton or custom CUDA, with all intensive loops scheduled tile-parallel over the L dimension (Beck et al., 18 Mar 2025, Yang et al., 2023).
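As a structural sketch of the forward pass above: the NumPy model below replaces the data-dependent gates r̃_k and A_k with a single scalar decay γ, and therefore omits the max-state m_k and the normalizer entirely; it preserves only the inter-/intra-chunk dataflow, and all names are illustrative.

```python
import numpy as np

def decayed_chunkwise_forward(Q, K, V, gamma=0.9, L=32):
    """Chunkwise forward for h_t = sum_{s<=t} gamma^(t-s) (q_t . k_s) v_s.

    A scalar decay gamma stands in for the mLSTM gates: gamma^L plays the
    role of the chunk gate r~_k, and the decay matrix D plays the role of
    A_k / D^(k) inside the chunk.
    """
    T, d = Q.shape
    S = np.zeros((d, V.shape[1]))              # inter-chunk state C_{k-1}
    out = np.empty_like(V)
    i = np.arange(L)
    # D[a, b] = gamma^(a-b) for b <= a, 0 above the diagonal
    D = np.where(i[:, None] >= i[None, :],
                 gamma ** (i[:, None] - i[None, :]), 0.0)
    for c in range(0, T, L):
        q, k, v = Q[c:c+L], K[c:c+L], V[c:c+L]
        inter = (gamma ** (i + 1))[:, None] * (q @ S)   # earlier chunks via state
        intra = ((q @ k.T) * D) @ v                     # decayed in-chunk attention
        out[c:c+L] = inter + intra
        # state update: decay old state by gamma^L, add decayed K^T V of this chunk
        S = gamma ** L * S + ((gamma ** (L - 1 - i))[:, None] * k).T @ v
    return out
```

The per-position scaling factors on the inter-chunk term and on K in the state update are exactly where the full algorithm inserts its gate-derived quantities.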

4. Hardware Efficiency, Complexity, and Practical Performance

A distinguishing feature of TFLA is its raised arithmetic intensity I_alg = FLOPs / Bytes_IO, which grows with chunk size L as I_alg ∼ O(L). As a result, for sufficiently large L, TFLA transitions from memory-bound to compute-bound on modern accelerators, matching or exceeding the performance roofline determined by peak FLOPs and effective memory bandwidth (Beck et al., 18 Mar 2025).
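The I_alg ∼ O(L) behavior follows from a back-of-the-envelope model. The FLOP and byte counts below cover only the dominant matmuls and HBM tile traffic; the constants are illustrative simplifications, not the exact counts from the paper.

```python
def arithmetic_intensity(L, d, bytes_per_elem=2):
    """Simplified FLOPs/byte for one chunk of chunkwise linear attention.

    FLOPs: the four dominant matmuls (Q K^T, S-hat V, Q C_{k-1}, K^T V).
    Bytes: loading Q, K, V and storing H, plus reading/writing the
    d x d state, in bf16 (2 bytes/element) by default.
    """
    flops = 2 * L * L * d      # Q K^T
    flops += 2 * L * L * d     # S-hat @ V
    flops += 2 * L * d * d     # inter-chunk: Q @ C_{k-1}
    flops += 2 * L * d * d     # state update: K^T @ V
    bytes_io = bytes_per_elem * (4 * L * d + 2 * d * d)
    return flops / bytes_io
```

Per-chunk FLOPs grow as O(L²d + Ld²) while bytes grow as O(Ld + d²), so the ratio grows roughly linearly in L until other limits (SRAM capacity, occupancy) intervene.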

  • Memory footprint: Only O(d²) state per chunk is materialized; intra-chunk operations require O(L d²) working SRAM. DRAM traffic is reduced by a factor of L compared to chunkwise parallel schemes without tiling.
  • Optimal L selection: L_opt is dictated by the ratio of device bandwidth to peak compute, the model dimensions, and the tile sizes fitting in SRAM:

L_\text{opt} \approx \sqrt{\frac{2 d^2 p_{qk} + \ldots}{2 F_{\text{causal}} (d(1+p_{qk})+3) + 1}}

For NVIDIA H100 GPUs, L ≈ 256 is typical for d_hv = 512, p_qk = 0.5.

| Kernel | Scaling | I/O Limit | Peak Speedup (8K→65K) |
| --- | --- | --- | --- |
| FlashAttention | O(T²) | moderate | < 1× |
| FlashLinearAttention | O(T d²) | I/O bound | 1–2× |
| Tiled FLA (TFLA) | O(T d²) | compute bound | 2–6× |

On NVIDIA H100, TFLA kernels for mLSTM/xLSTM achieve state-of-the-art speed, outperforming both FlashAttention and chunkwise FlashLinearAttention at long context lengths (e.g., a 25–30% kernel runtime reduction going from L = 64 to L = 128; over 2× faster than Mamba-2 at 8k–65k token contexts) (Beck et al., 18 Mar 2025).

5. Applications to RNNs and Attention Variants

TFLA is applicable to a range of linear-time sequence models:

  • mLSTM and xLSTM: Matrix-memory LSTMs with exponential and sigmoid input gate variants (the latter, mLSTMsig, omits the max-state logic and normalization, further simplifying TFLA kernels). All memory and arithmetic savings of TFLA are realized without numerical instability or loss of performance in language modeling (e.g., mLSTMsig matches or slightly exceeds Llama2 perplexity at a fixed parameter budget) (Beck et al., 18 Mar 2025).
  • Linear Attention Transformers: TFLA subsumes the two-level tiling of Gated Linear Attention (GLA) Transformers, providing an efficient implementation for hardware- and memory-bound settings (Yang et al., 2023).
  • FLASH/FLAT architectures: TFLA encapsulates and extends the tiling and gating architecture used in FLASH for single-head “weak” attention, and the fusion of streaming, reduction, and tiling operations in FLAT (Hua et al., 2022, Kao et al., 2021).

6. Implementation and Practical Considerations

Implementing TFLA requires careful tile/block size selection, management of on-chip SRAM constraints, and parallel loop scheduling:

  • SRAM fitting: Tiles (blocks) of ~4–8 KiB should fit in shared memory, with all Q/K/V tiles and accumulators loaded once per block (Yang et al., 2023, Beck et al., 18 Mar 2025).
  • Compiler and fusion constraints: Current Triton thread block models may make deep loop fusion and asynchronous prefetching challenging; custom CUDA versions can further improve throughput.
  • Gradient computation: Backward pass under TFLA requires four separate tiled matmuls, with parallelization axes reversed, entailing additional kernel engineering.
  • Integration: TFLA is not yet packaged as a generic library for arbitrary RNN cells, but foundational support exists for mLSTM, xLSTM, and linear attention layers.
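A rough tile-size feasibility check along the lines of the SRAM-fitting point above, under assumptions introduced here for illustration (bf16 tiles for Q/K/V, an fp32 output accumulator, and a ~227 KiB per-thread-block shared-memory ceiling as on H100): real kernels must also budget for double buffering and the S tile, so passing this check is necessary rather than sufficient.

```python
def tiles_fit_in_sram(BL, Bdqk, Bdhv, bytes_per_elem=2,
                      sram_budget=227 * 1024):
    """Check that one thread block's working set fits in shared memory.

    BL: tile length along the sequence; Bdqk / Bdhv: feature tile sizes
    for Q/K and V/H respectively (parameter names are illustrative).
    """
    q_tile = BL * Bdqk * bytes_per_elem
    k_tile = BL * Bdqk * bytes_per_elem
    v_tile = BL * Bdhv * bytes_per_elem
    acc = BL * Bdhv * 4                    # fp32 output accumulator
    return q_tile + k_tile + v_tile + acc <= sram_budget
```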

7. Performance Benchmarks and Empirical Results

Empirical measurements confirm TFLA’s efficiency and scaling properties:

  • Edge TPU/FPGA/A100 GPU: End-to-end inference achieves 1.5×–6.8× latency reduction and 20–60% energy savings per token over quadratic attention, with working memory and bandwidth scaling as O(N d) (Kao et al., 2021).
  • Language Modeling: mLSTMsig TFLA on H100 achieves perplexity (PPL) of 21.03 vs. 21.05 for Llama2 baselines at 4k context, maintaining or exceeding quality at strong throughput.
  • Typical kernel speedups: TFLA mLSTMexp (L = 128) is 25% faster than the fixed-chunk (L = 64) kernel; mLSTMsig further improves runtime by ~30% (Beck et al., 18 Mar 2025).

A plausible implication is that TFLA, by maximizing arithmetic intensity and supporting arbitrary chunk sizes, positions linear RNNs and attention models to exploit hardware scaling trends in FLOPs relative to bandwidth and to serve as efficient sequence modeling primitives in extremely long-context settings.

