
FlashDecoding++: GPU LLM Inference Engine

Updated 1 March 2026
  • FlashDecoding++ is a GPU-centric LLM inference engine that optimizes decoding and prefill stages through three integrated kernel-level techniques.
  • It employs asynchronized softmax, flat GEMM with double buffering, and heuristic dataflow adaptation to overcome GPU memory and compute bottlenecks.
  • It achieves up to 4.86× speedup on NVIDIA GPUs and 2.18× on AMD GPUs over Hugging Face baselines without requiring LLM-specific source code changes, and it is regarded as a state-of-the-art engine for LLM inference.

FlashDecoding++ is a GPU-centric LLM inference engine specifically designed to accelerate the decode and prefill stages of Transformer-based models through three tightly integrated kernel-level optimizations: asynchronized softmax, flat GEMM with double buffering, and dynamic, heuristic dataflow adaptation. It targets the inherent computational and memory bottlenecks in standard LLM inference, addressing under-utilization caused by flat matrix shapes and minimizing synchronization overheads typically associated with partial reductions in softmax operations. FlashDecoding++ achieves up to 4.86× speedup on NVIDIA GPUs and 2.18× on AMD GPUs compared to Hugging Face baselines without requiring LLM-specific source code changes. Its architecture and methodology are now referenced as state-of-the-art for high-throughput LLM inference, serving as the foundation for subsequent engine optimizations and integration with advanced decoding strategies (Hong et al., 2023).

1. Core Design and Rationale

FlashDecoding++ was motivated by three persistent challenges in LLM inference on GPUs:

  • Synchronized Softmax Reduction: Conventional implementations of the softmax operation over long sequences require synchronization (e.g., block-wise or grid-wise barriers) to correctly aggregate partial maxima and exponentiated sums. This synchronization consumes approximately 20% of the attention computation time.
  • Under-Utilized Flat GEMM Operations: During decoding (especially when the batch size $M$ is much smaller than the feature dimensions $K, N$), the main matrix multiplications are extremely imbalanced (“flat”, $M \ll K, N$). Most vendor libraries (cuBLAS, CUTLASS) pad $M$ up to 64, leaving over 80% of Tensor Core compute units idle for $M = 1$ or $M = 4$. This results in more than 50% performance loss.
  • Inefficiency from Static Dataflows: The optimal execution kernel depends on the shape $(M, K, N)$, the batch size, and the GPU architecture. Relying on a single GEMM implementation for all cases incurs up to 50% performance loss for certain matrix shapes seen during LLM inference (Hong et al., 2023).
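The padding overhead described in the second bullet can be checked with a simple row-occupancy proxy. The sketch below assumes that the fraction of wasted Tensor Core work equals the fraction of padded rows in the $M$ dimension. This ignores $K$/$N$ tiling, memory effects, and scheduling. The function names are hypothetical.

```python
def pad_to_multiple(m: int, multiple: int) -> int:
    """Round m up to the next multiple of `multiple` (ceiling division)."""
    return -(-m // multiple) * multiple


def idle_fraction(m: int, padded_m: int) -> float:
    """Fraction of the padded M rows that are pure zero padding (proxy for idle compute)."""
    if not 1 <= m <= padded_m:
        raise ValueError("require 1 <= m <= padded_m")
    return 1.0 - m / padded_m


if __name__ == "__main__":
    # Compare padding M to a 64-row tile with padding it to an 8-row tile.
    for m in (1, 4, 8):
        for tile in (64, 8):
            padded = pad_to_multiple(m, tile)
            print(f"M={m}, padded to {padded}: {idle_fraction(m, padded):.1%} padded rows")
```

Under this proxy, $M = 1$ leaves $1 - 1/64 \approx 98.4\%$ of the rows as padding with a 64-row tile and 87.5% with an 8-row tile. For $M = 4$, the figure drops from 93.75% to 50%.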

2. Asynchronized Softmax with Unified Max Value

The asynchronized softmax kernel removes the synchronization otherwise needed for correct row-wise softmax normalization. It does this by exploiting the identity $\mathrm{softmax}(x) = \mathrm{softmax}(x - \phi)$, which holds for any scalar $\phi$ because the common factor $e^{-\phi}$ cancels between numerator and denominator. Instead of performing a two-phase reduction to find the global maximum $m(x)$, the kernel uses a fixed, empirically safe value $\phi = 6.5$ in its place (a choice that covers $>99.99\%$ of attention logits for Llama2-7B). Each tile $j$ computes local accumulations:

$$\alpha_j = \sum_{i \in \text{tile } j} \exp\left(x_i^{(j)} - \phi\right), \qquad \beta_j = \sum_{i \in \text{tile } j} v_i^{(j)} \exp\left(x_i^{(j)} - \phi\right)$$

The accumulators are finalized by a single atomic or warp-level reduction, and the output is $\left(\sum_j \beta_j\right) / \left(\sum_j \alpha_j\right)$. If any logit in a tile falls outside a preset range, the kernel reverts to the standard synchronized softmax. This approach speeds up the prefill stage by 1.18× and the decode stage by 1.14× over the best previous two-stage partial-softmax implementations (Hong et al., 2023).
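The following is a minimal scalar sketch of the asynchronized softmax idea. It assumes one query row, one value channel, and Python floats in place of FP16. Only $\phi = 6.5$ comes from the text. The tile size and the safe range (`safe_min`, `safe_max`) are hypothetical placeholders. The fallback is also simplified: it recomputes the whole row with the synchronized version. Tiles are processed one after another here, whereas a GPU kernel would compute each $(\alpha_j, \beta_j)$ pair independently and combine them in one final reduction.

```python
import math
from typing import Sequence


def synchronized_attention(scores: Sequence[float], values: Sequence[float]) -> float:
    """Reference: softmax using the true row maximum, then a weighted sum of values."""
    m = max(scores)  # needs a full reduction over the row before any exp()
    weights = [math.exp(x - m) for x in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def async_softmax_attention(
    scores: Sequence[float],
    values: Sequence[float],
    phi: float = 6.5,         # unified max value quoted in the text
    tile: int = 4,            # hypothetical tile size
    safe_min: float = -20.0,  # hypothetical lower bound of the safe logit range
    safe_max: float = 20.0,   # hypothetical upper bound of the safe logit range
) -> float:
    """Tile-wise partial sums with a fixed phi; no per-tile maximum is exchanged."""
    if len(scores) != len(values) or not scores:
        raise ValueError("scores and values must be non-empty and equally long")
    alpha = 0.0  # accumulates sum_j alpha_j
    beta = 0.0   # accumulates sum_j beta_j
    for start in range(0, len(scores), tile):
        xs = scores[start:start + tile]
        vs = values[start:start + tile]
        if any(x < safe_min or x > safe_max for x in xs):
            # Simplified fallback: redo the whole row with the synchronized version.
            return synchronized_attention(scores, values)
        # Every tile uses the same offset phi, so partial sums can be added directly.
        alpha += sum(math.exp(x - phi) for x in xs)
        beta += sum(v * math.exp(x - phi) for x, v in zip(xs, vs))
    return beta / alpha
```

Because every tile subtracts the same $\phi$, the partial sums need no rescaling when they are combined. This is what removes the inter-tile synchronization on the maximum.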

3. Flat GEMM Optimization with Double Buffering

In the decode phase, FlashDecoding++ uses the native $8 \times 8 \times 4$ multiply-accumulate units of NVIDIA Ampere/Hopper Tensor Cores by padding the batch dimension only up to $M = 8$, rather than the default $M = 64$. For $M \leq 8$, it applies double buffering at the shared-memory level. One buffer loads the next tiles of $A$ ($M \times K$) and $B$ ($K \times N$) while the other is consumed by computation, so memory loads overlap with computation. Empirically, this achieves over 90% of the theoretical Tensor Core throughput in this regime (Hong et al., 2023).

The operational intensity (OI) is tuned through the tile sizes along $K$ and $N$ (the latter denoted $B_N$). The double-buffered approach yields up to 52% higher GEMV/GEMM throughput than the zero-padded ($M = 64$) baseline.
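The double-buffering schedule can be sketched in pure Python, assuming row-major lists for $A$ ($M \times K$) and $B$ ($K \times N$). In this sketch, `load` is a hypothetical stand-in for the global-to-shared-memory copy, and `tile_k` is an arbitrary placeholder. Python runs the prefetch and the compute sequentially, so the sketch shows only the ping-pong buffer order. It does not model the actual overlap or Tensor Core execution.

```python
from typing import List, Optional, Tuple

Matrix = List[List[float]]
Tile = Tuple[Matrix, Matrix]


def pad_rows(a: Matrix, multiple: int = 8) -> Matrix:
    """Zero-pad the M dimension of A up to a multiple of `multiple` (8 in the text)."""
    k = len(a[0])
    padded_m = -(-len(a) // multiple) * multiple
    return a + [[0.0] * k for _ in range(padded_m - len(a))]


def flat_gemm_double_buffered(a: Matrix, b: Matrix, tile_k: int = 4, pad_to: int = 8) -> Matrix:
    """C = A @ B for a flat A (small M), walking K in tiles with two ping-pong buffers."""
    m, k, n = len(a), len(b), len(b[0])
    if len(a[0]) != k:
        raise ValueError("inner dimensions do not match")
    a_p = pad_rows(a, pad_to)
    c = [[0.0] * n for _ in range(len(a_p))]
    num_tiles = -(-k // tile_k)

    def load(t: int) -> Tile:
        # Stands in for copying one K-tile of A and B into shared memory.
        k0, k1 = t * tile_k, min((t + 1) * tile_k, k)
        return [row[k0:k1] for row in a_p], b[k0:k1]

    buffers: List[Optional[Tile]] = [load(0), None]
    for t in range(num_tiles):
        cur = t % 2
        if t + 1 < num_tiles:
            # Prefetch the next tile into the other buffer; on a GPU this copy
            # would be issued asynchronously so that it overlaps the compute below.
            buffers[1 - cur] = load(t + 1)
        a_tile, b_tile = buffers[cur]
        for i, a_row in enumerate(a_tile):
            for kk, a_val in enumerate(a_row):
                for j in range(n):
                    c[i][j] += a_val * b_tile[kk][j]
    return c[:m]  # drop the zero-padded rows
```

Padding $M$ only to a multiple of 8 keeps the number of wasted rows small for flat shapes. Alternating between two buffers ensures that a tile is never overwritten while it is still being consumed.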

4. Heuristic Dataflow and Runtime Adaptation

FlashDecoding++ introduces dynamic, lookup-table (LUT)-indexed scheduling that selects among three GEMM kernels at runtime. The choice is based on the observed batch size $M$ and the distinct projection shapes of Transformer layers. Each kernel is best in one regime:

  • FastGEMV (CUDA Core): optimal for $M < M_1(K,N)$, including $M = 1$
  • Flat-GEMM Double Buffered (Tensor Core): optimal for $M_1(K,N) \leq M < M_2(K,N)$
  • CUTLASS GEMM (Tensor Core): optimal for $M \geq M_2(K,N)$

Offline profiling establishes the inflection points $M_1(K,N)$ and $M_2(K,N)$ at which the engine switches kernels. A single LUT lookup at inference time yields up to 29% additional speedup compared with using a single static kernel (Hong et al., 2023).
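The runtime selection step amounts to one dictionary lookup followed by two comparisons. In the sketch below, the LUT contents, the default thresholds, and the kernel name strings are hypothetical placeholders, not profiled values from the paper. The $(K, N)$ keys use Llama2-7B-like projection shapes purely for illustration.

```python
from typing import Dict, Tuple

# Hypothetical offline-profiled inflection points (M1, M2), keyed by the (K, N)
# shape of a projection. These numbers are placeholders, not measurements.
INFLECTION_LUT: Dict[Tuple[int, int], Tuple[int, int]] = {
    (4096, 4096): (2, 32),    # e.g. an attention projection (hypothetical thresholds)
    (4096, 11008): (2, 64),   # e.g. an FFN up-projection (hypothetical thresholds)
}
DEFAULT_POINTS: Tuple[int, int] = (2, 64)  # hypothetical fallback for unprofiled shapes


def select_kernel(m: int, k: int, n: int) -> str:
    """Pick a GEMM implementation for an (M, K) x (K, N) product with one LUT lookup."""
    m1, m2 = INFLECTION_LUT.get((k, n), DEFAULT_POINTS)
    if m < m1:
        return "FastGEMV"                 # CUDA Core GEMV path
    if m < m2:
        return "FlatGEMM-DoubleBuffered"  # Tensor Core flat GEMM with double buffering
    return "CUTLASS-GEMM"                 # conventional Tensor Core GEMM
```

A lookup keyed only on $(K, N)$ suffices because a model has a small, fixed set of projection shapes. The table can therefore be filled once offline and consulted cheaply on every GEMM call.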

5. Quantitative Performance and Integration

Performance Summary Table

Metric | NVIDIA A100 | RTX 3090 | AMD RX7900XTX | AMD MI210 | vs. SOTA engines
Decode-phase speedup (vs. HF) | up to 4.86× | 4.2–4.5× | up to 2.27× | up to 3.93× | 1.37×

Metric | Reported speedup
Prefill-phase speedup (vs. HF) | up to 1.40×; 1.08–1.10×
Asynchronized softmax kernel | 1.18× (prefill), 1.14× (decode)
Flat GEMM + double buffering | up to 52%
Heuristic dataflow adaptation | up to 29%

Available as a PyTorch extension with C++/CUDA (NVIDIA) and ROCm (AMD) backends, FlashDecoding++ is a drop-in replacement for TransformerDecoderLayer and is compatible with Hugging Face APIs. It manages its own KV cache and requires no changes to upstream model source code, facilitating easy adoption across mainstream models such as Llama2-7B/13B, OPT-6.7B, and ChatGLM2-6B (Hong et al., 2023).

6. Limitations and Future Directions

FlashDecoding++ currently runs on a single GPU and does not support multi-node or pipeline-parallel deployment. The unified-$\phi$ softmax relies on the logit distribution staying within empirically established ranges, so extreme inputs may require fallbacks. At present, there is no special handling for sparse-weight or structured-pruned models.

Planned extensions include:

  • Multi-GPU support with allreduce or NCCL pipelining of attention operations
  • Hardware-accelerated sparse/dense kernel mixing for pruned transformer variants
  • Adaptive precision inside the asynch softmax kernel (FP16, BF16, INT8)
  • Compiler-driven kernel fusion via Triton or MLIR-based autotuning for softmax/GEMM

A plausible implication is that further gains may arise from tighter operator fusion and distributed attention kernel integration.

7. Relationship to Successor and Parallel Techniques

FlashDecoding++ serves as a core decode-stage kernel for modern LLM inference, and complementary approaches extend efficiency in other directions. FlashForge (Wang et al., 23 May 2025) targets prefix sharing. By merging key-value loads across requests in shared-prefix trees, it reaches up to 120.9× reduction in global memory access and 3.8× end-to-end throughput improvement compared to vLLM. In the multimodal and speculative decoding domain, frameworks such as FLASH (Wang et al., 19 May 2025) and its projected “++” variant use semi-autoregressive block speculation, adaptive draft composition, and dynamic acceptance thresholds to achieve up to 2.68× speedup in video captioning. These results suggest that high-performance LLM and LMM inference will increasingly combine low-level kernel engineering with higher-level dynamic and speculative execution methods.

References
