Tiled Flash Linear Attention (TFLA)
- Tiled Flash Linear Attention (TFLA) is a hardware-optimized algorithm for linear attention and RNNs that uses a two-level tiling strategy to maximize compute and memory efficiency.
- It generalizes tiling principles from FlashAttention and FlashLinearAttention by introducing intra-chunk tiling, thus removing chunk size limits and enhancing arithmetic intensity.
- TFLA achieves state-of-the-art performance on modern GPUs, reducing kernel runtimes by over 25% and improving throughput for long-context sequence modeling.
Tiled Flash Linear Attention (TFLA) is a class of hardware-optimized kernel algorithms for linear attention and linear recurrent neural networks (RNNs) that achieves both high arithmetic intensity and memory efficiency for long-context sequence modeling. TFLA generalizes the tiling and chunking principles of FlashAttention and FlashLinearAttention, extending them with a second level of intra-chunk tiling to remove chunk size limits and further reduce memory input/output (I/O). TFLA has enabled state-of-the-art kernel runtimes for large-memory RNNs such as mLSTM and xLSTM, and allows linear attention models to realize their theoretical compute scaling on modern GPU and accelerator hardware (Beck et al., 18 Mar 2025, Kao et al., 2021, Hua et al., 2022).
1. Precedent Architectures: From Quadratic Attention to Flash Linear Attention
Classic attention mechanisms are characterized by quadratic complexity in sequence length due to the computation and storage of the full attention matrix. FlashAttention introduced a tiled kernel fused over query and key blocks, never materializing the full attention matrix in high-bandwidth memory (HBM), thereby minimizing I/O while retaining the quadratic compute cost; FLAT applies a related operator-fusion dataflow directly on accelerators (Kao et al., 2021).
Linear attention methods, including kernelized attention and linear RNN formalisms, reduce compute to linear in sequence length by exploiting the associativity of matrix products to reformulate attention as a streaming or chunkwise linear update (Yang et al., 2023), as illustrated in the sketch below. FlashLinearAttention applies chunkwise parallelism: it splits the sequence into chunks, materializes per-chunk RNN states, and parallelizes intra-chunk computation, but in practice it is limited by shared SRAM capacity and memory bandwidth. When the chunk size is small, arithmetic intensity (FLOPs per byte transferred) remains low and DRAM traffic becomes the primary bottleneck (Beck et al., 18 Mar 2025, Yang et al., 2023).
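As a minimal illustration of the associativity argument, the following NumPy sketch (illustrative only, not the FlashLinearAttention kernel; shapes and names are assumptions) compares the two orderings of an unmasked linear attention product. Causal masking breaks this global reordering, which is exactly what the chunkwise formulations below recover.

```python
# Associativity of linear attention: (Q K^T) V == Q (K^T V) when no causal
# mask is applied. The right-hand ordering never materializes the T x T matrix.
import numpy as np

T, d = 1024, 64                       # sequence length, head dimension
rng = np.random.default_rng(0)
Q = rng.standard_normal((T, d))
K = rng.standard_normal((T, d))
V = rng.standard_normal((T, d))

out_quadratic = (Q @ K.T) @ V         # O(T^2 d) FLOPs, T x T intermediate
out_linear = Q @ (K.T @ V)            # O(T d^2) FLOPs, only a d x d "state"

assert np.allclose(out_quadratic, out_linear)
```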
2. Core Algorithmic Principles of TFLA
TFLA removes the hardware-imposed ceiling on chunk size by introducing a second level of sequence parallelism, tile parallelism, inside each chunk. This yields a two-level parallel scheme:
- Level 1 (Chunk-level parallelization): The sequence is partitioned into chunks of size $L$. Each chunk's starting RNN state is persisted in DRAM.
- Level 2 (Tile-level parallelization within chunk): Each chunk's intra-chunk operations (self-attention-like scores, gating, matmuls) are parallelized across a 2D grid of thread blocks, processing the local attention or recurrence matrix via tiled GEMMs. This enables $L$ to be set arbitrarily large, maximizing data reuse, tensor-core occupancy, and arithmetic intensity (Beck et al., 18 Mar 2025).
This two-level structure applies to both standard linear attention and gated/forgetful RNNs. All per-chunk states (e.g., the matrix memory $C_k$ and normalizer state $n_k$ for mLSTM variants) are updated via a recurrent kernel and then consumed by the intra-chunk parallel kernel. The tiling within each chunk enables all critical matrix multiplications (e.g., $Q_k K_k^{\top}$ and the subsequent product with $V_k$) to be performed on accelerator tensor cores while reading each data block exactly once from HBM into shared SRAM; a structural sketch of the two-level decomposition follows below.
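The following plain-Python skeleton (illustrative only; the chunk and tile sizes `L`, `B_q`, `B_kv` are assumed names, and gating is omitted) shows how the two levels nest: a sequential state pass over chunks, and a tile grid over each chunk.

```python
# Structural sketch of the two-level TFLA decomposition (ungated linear attention).
# On a GPU the chunk loop is a recurrent state pass, while each chunk's tile grid
# runs as parallel thread blocks; here both are plain Python loops for clarity.
import numpy as np

def tfla_skeleton(Q, K, V, L=256, B_q=64, B_kv=64):
    T, d = Q.shape
    state = np.zeros((d, d))                     # inter-chunk RNN state ("in HBM")
    out = np.empty_like(V)
    for start in range(0, T, L):                 # level 1: chunks (sequential)
        q, k, v = Q[start:start+L], K[start:start+L], V[start:start+L]
        for i in range(0, len(q), B_q):          # level 2: tile grid over the chunk
            acc = q[i:i+B_q] @ state             # contribution of previous chunks
            for j in range(0, i + B_q, B_kv):    # causal key/value tiles in the chunk
                s = q[i:i+B_q] @ k[j:j+B_kv].T
                if j + B_kv > i:                 # diagonal tile: apply causal mask
                    rows = np.arange(i, i + s.shape[0])[:, None]
                    cols = np.arange(j, j + s.shape[1])[None, :]
                    s = np.where(cols <= rows, s, 0.0)
                acc += s @ v[j:j+B_kv]
            out[start+i:start+i+B_q] = acc
        state = state + k.T @ v                  # chunk-level state update (no gating)
    return out
```

On a GPU, the level-2 loops map to a 2D grid of thread blocks and the matmuls to tensor-core GEMMs; the Python loops exist only to make the decomposition explicit.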
3. Mathematical Formulation and Implementation
For a generic linear RNN (e.g., mLSTM) under TFLA, the algorithm proceeds as follows:
- Recurrent update (per chunk $k$):

$$C_k \;=\; \bar{f}_k\, C_{k-1} \;+\; \bar{K}_k^{\top} V_k,$$

with $\bar{f}_k$, $\bar{K}_k$ determined by the gate preactivations and normalization.
- Tile-parallel intra-chunk computation:

Partition $Q_k$, $K_k$, $V_k$, and $H_k$ into tiles sized to the SRAM budget and process all output rows via batched matmuls:

$$H_k \;=\; \bar{Q}_k\, C_{k-1} \;+\; \big(D_k \odot Q_k K_k^{\top}\big)\, V_k,$$

where $\bar{Q}_k$ and the decay matrix $D_k$ are likewise built from the cumulative gates within the chunk.
TFLA kernels are implemented in either Triton or custom CUDA, with all compute-intensive loops scheduled tile-parallel over the intra-chunk sequence dimension (Beck et al., 18 Mar 2025, Yang et al., 2023). The chunkwise math above is spelled out in the reference sketch below.
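The sketch below (a NumPy reference under simplifying assumptions: scalar per-step forget and input gates, no normalizer or max state; function and variable names are illustrative) implements this chunkwise form and checks it against the step-by-step recurrence.

```python
# Chunkwise gated linear attention: inter-chunk state recurrence plus
# intra-chunk decay-masked "attention", checked against a stepwise reference.
import numpy as np

def stepwise(q, k, v, fg, ig):
    T, dk = k.shape
    dv = v.shape[1]
    C = np.zeros((dk, dv))
    H = np.empty((T, dv))
    for t in range(T):
        C = fg[t] * C + ig[t] * np.outer(k[t], v[t])   # C_t = f_t C_{t-1} + i_t k_t v_t^T
        H[t] = C.T @ q[t]
    return H

def chunkwise(q, k, v, fg, ig, L=64):
    T, dk = k.shape
    dv = v.shape[1]
    C = np.zeros((dk, dv))                              # inter-chunk state C_{k-1}
    H = np.empty((T, dv))
    for s in range(0, T, L):
        qc, kc, vc = q[s:s+L], k[s:s+L], v[s:s+L]
        F = np.cumprod(fg[s:s+L])                       # cumulative forget gates in chunk
        ic = ig[s:s+L]
        # inter-chunk term: Qbar_k C_{k-1}, with Qbar_k = diag(F) Q_k
        H_inter = (F[:, None] * qc) @ C
        # intra-chunk term: elementwise D_k * (Q_k K_k^T) times V_k,
        # with D_k[t, u] = (F_t / F_u) * i_u for u <= t, else 0
        D = np.tril(F[:, None] / F[None, :]) * ic[None, :]
        H[s:s+L] = H_inter + (D * (qc @ kc.T)) @ vc
        # state update: C_k = fbar_k C_{k-1} + Kbar_k^T V_k
        a = (F[-1] / F) * ic
        C = F[-1] * C + (a[:, None] * kc).T @ vc
    return H

rng = np.random.default_rng(0)
T, dk, dv = 256, 32, 32
q, k = rng.standard_normal((2, T, dk))
v = rng.standard_normal((T, dv))
fg = 0.9 + 0.1 * rng.random(T)                          # forget gates in (0.9, 1.0)
ig = rng.random(T)                                      # input gates in (0, 1)
assert np.allclose(stepwise(q, k, v, fg, ig), chunkwise(q, k, v, fg, ig))
```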
4. Hardware Efficiency, Complexity, and Practical Performance
A distinguishing feature of TFLA is its raised arithmetic intensity (FLOPs per byte transferred), which grows with the chunk size $L$. As a result, for sufficiently large $L$, TFLA transitions from being memory-bound to compute-bound on modern accelerators, matching or exceeding the performance roofline determined by peak FLOPs and effective memory bandwidth (Beck et al., 18 Mar 2025).
- Memory footprint: Only the inter-chunk RNN state (e.g., $C_k$) is materialized in DRAM per chunk; intra-chunk operations require only on-chip SRAM working tiles. DRAM traffic is reduced relative to chunkwise parallel schemes without tiling, because larger chunks amortize state reads and writes over more tokens.
- Optimal chunk size selection: The best $L$ is dictated by the ratio of peak compute to memory bandwidth (the machine balance), the model dimensions, and the tile sizes that fit in SRAM; roughly, $L$ should be large enough that per-chunk FLOPs per byte of HBM traffic reach the machine balance $\pi_{\text{peak}}/\beta_{\text{mem}}$. For NVIDIA H100-class GPUs, this typically favors chunk sizes beyond what single-level chunkwise kernels can accommodate in SRAM; a rough cost model is sketched below.
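As a rough illustration of this roofline argument, the following sketch uses an assumed cost model that counts only the dominant GEMMs and HBM transfers of a chunkwise forward pass (the constants are not the paper's exact accounting).

```python
# Per-chunk arithmetic intensity (FLOPs per byte of HBM traffic) as a function
# of chunk size L, under a simplified cost model of a chunkwise forward pass.
def chunk_arithmetic_intensity(L, d_qk=128, d_v=128, bytes_per_el=2):
    flops = 2 * L * L * d_qk           # S   = Q_k K_k^T
    flops += 2 * L * L * d_v           # H  += S V_k              (intra-chunk)
    flops += 2 * L * d_qk * d_v        # H  += Qbar_k C_{k-1}     (inter-chunk)
    flops += 2 * L * d_qk * d_v        # C_k = ... + Kbar_k^T V_k (state update)
    hbm_bytes = bytes_per_el * (L * (2 * d_qk + 2 * d_v)   # load Q, K, V; store H
                                + 2 * d_qk * d_v)          # read and write the state
    return flops / hbm_bytes

# Machine balance of an H100 SXM is roughly 989e12 / 3.35e12 ≈ 295 FLOP/byte
# (dense BF16 tensor-core peak over HBM3 bandwidth); larger chunks close the gap.
for L in (64, 128, 256, 512):
    print(f"L={L:4d}  intensity={chunk_arithmetic_intensity(L):6.1f} FLOP/byte")
```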
| Kernel | Scaling in sequence length | I/O limit |
|---|---|---|
| FlashAttention | $O(T^2)$ compute | moderate |
| FlashLinearAttention | $O(T)$ compute | I/O bound at small chunk sizes |
| Tiled FLA (TFLA) | $O(T)$ compute | compute bound at large chunk sizes |
On NVIDIA H100, TFLA kernels for mLSTM/xLSTM achieve state-of-the-art speed, outperforming both FlashAttention and chunkwise FlashLinearAttention at long context lengths, with 25–30% kernel runtime reductions and faster runtimes than Mamba-2 kernels at long token contexts (Beck et al., 18 Mar 2025).
5. Applications to RNNs and Attention Variants
TFLA is applicable to a range of linear-time sequence models:
- mLSTM and xLSTM: Matrix-memory LSTMs with either exponential or sigmoid input gates (the sigmoid variant, mLSTMsig, omits the max state and normalizer, further simplifying the TFLA kernels; see the sketch after this list). All memory and arithmetic savings of TFLA are realized without numerical instability or loss of performance in language modeling (e.g., mLSTM matches or slightly exceeds Llama2 PPL at a fixed parameter budget) (Beck et al., 18 Mar 2025).
- Linear Attention Transformers: TFLA subsumes the two-level tiling of Gated Linear Attention (GLA) Transformers, providing an efficient implementation for hardware- and memory-bound settings (Yang et al., 2023).
- FLASH/FLAT architectures: TFLA encapsulates and extends the tiling and gating architecture used in FLASH for single-head “weak” attention, and the fusion of streaming, reduction, and tiling operations in FLAT (Hua et al., 2022, Kao et al., 2021).
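For reference, the per-step recurrences of the two mLSTM gate variants can be written as follows (a minimal single-head sketch; the exponential-gate stabilization is written in the commonly used max-state form and may differ in detail from the released kernels, which evaluate the equivalent chunkwise formulation).

```python
# Per-step recurrences of the two mLSTM input-gate variants (reference sketch,
# single head, no q/k scaling). mLSTMexp needs a max state m and normalizer n
# to stabilize its exponential gate; mLSTMsig drops both, simplifying kernels.
import numpy as np

def mlstm_exp_step(C, n, m, q, k, v, f_pre, i_pre):
    m_new = max(f_pre + m, i_pre)                  # running max of log-gate terms
    f = np.exp(f_pre + m - m_new)                  # stabilized (exponential) forget gate
    i = np.exp(i_pre - m_new)                      # stabilized exponential input gate
    C = f * C + i * np.outer(k, v)                 # matrix memory update
    n = f * n + i * k                              # normalizer update
    h = (C.T @ q) / max(abs(n @ q), np.exp(-m_new))
    return C, n, m_new, h

def mlstm_sig_step(C, q, k, v, f_pre, i_pre):
    f = 1.0 / (1.0 + np.exp(-f_pre))               # sigmoid forget gate
    i = 1.0 / (1.0 + np.exp(-i_pre))               # sigmoid input gate
    C = f * C + i * np.outer(k, v)                 # matrix memory update
    h = C.T @ q                                    # no max state, no normalizer
    return C, h
```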
6. Implementation and Practical Considerations
Implementing TFLA requires careful tile/block size selection, management of on-chip SRAM constraints, and parallel loop scheduling:
- SRAM fitting: Tiles (blocks) of roughly 4-8 KiB should fit in shared memory, with the Q/K/V tiles and accumulators loaded once per block (Yang et al., 2023, Beck et al., 18 Mar 2025); a rough budgeting sketch follows this list.
- Compiler and fusion constraints: Triton's current thread-block programming model can make deep loop fusion and asynchronous prefetching challenging; custom CUDA implementations can further improve throughput.
- Gradient computation: Backward pass under TFLA requires four separate tiled matmuls, with parallelization axes reversed, entailing additional kernel engineering.
- Integration: TFLA is not yet packaged as a generic library for arbitrary RNN cells, but foundational support exists for mLSTM, xLSTM, and linear attention layers.
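A rough shared-memory budgeting helper along these lines can guide tile-size selection (the byte counts and the per-block limit are illustrative assumptions; real occupancy also depends on registers, pipeline stages, and the Triton/CUDA compiler).

```python
# Estimate the shared-memory footprint of one intra-chunk thread block for given
# tile sizes and check it against a per-block budget. Real kernels may keep the
# output accumulator in registers rather than shared memory.
def smem_bytes(B_q, B_kv, d_qk, d_v, bytes_per_el=2, acc_bytes=4):
    input_tiles = B_q * d_qk + B_kv * d_qk + B_kv * d_v     # Q, K, V tiles
    accumulator = B_q * d_v * acc_bytes                     # FP32 output accumulator
    return input_tiles * bytes_per_el + accumulator

SMEM_PER_BLOCK = 227 * 1024        # assumed opt-in shared-memory limit per block (H100)
for B_q, B_kv in [(64, 64), (128, 64), (128, 128)]:
    used = smem_bytes(B_q, B_kv, d_qk=128, d_v=128)
    print(f"B_q={B_q:3d} B_kv={B_kv:3d}  {used / 1024:6.1f} KiB  fits={used <= SMEM_PER_BLOCK}")
```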
7. Performance Benchmarks and Empirical Results
Empirical measurements confirm TFLA’s efficiency and scaling properties:
- Edge TPU/FPGA/A100 GPU (FLAT dataflow): End-to-end inference achieves latency reductions and 20–60% energy savings per token over quadratic attention, with substantially reduced working memory and off-chip bandwidth requirements (Kao et al., 2021).
- Language Modeling: mLSTM TFLA on H100 achieves perplexity (PPL) of 21.03 vs. 21.05 for Llama2 baselines at 4k context, maintaining or exceeding quality at strong throughput.
- Typical kernel speedups: The TFLA mLSTMexp kernel with chunk size $L=128$ is faster than the fixed-chunk ($L=64$) kernel; mLSTMsig further improves runtime by about 30% (Beck et al., 18 Mar 2025).
A plausible implication is that TFLA, by maximizing arithmetic intensity and supporting arbitrary chunk sizes, positions linear RNNs and attention models to exploit the hardware trend of FLOPs growing faster than memory bandwidth, and to serve as efficient sequence-modeling primitives in extremely long-context settings.
Key References:
- "Tiled Flash Linear Attention: More Efficient Linear RNN and xLSTM Kernels" (Beck et al., 18 Mar 2025)
- "Gated Linear Attention Transformers with Hardware-Efficient Training" (Yang et al., 2023)
- "Transformer Quality in Linear Time" (Hua et al., 2022)
- "FLAT: An Optimized Dataflow for Mitigating Attention Bottlenecks" (Kao et al., 2021)