TMA-Adaptive FP8 Grouped GEMM
- The paper introduces a dynamic TMA descriptor pool that eliminates padding overhead in FP8 grouped GEMM on NVIDIA Hopper GPUs.
- It employs a dual-phase TMA load/store mechanism to handle variable matrix sizes while maintaining strict memory alignment without runtime allocation.
- Experimental results demonstrate up to 20.4% speedup and 23.8% memory savings, crucial for efficient low-precision MoE model training and inference.
TMA-Adaptive FP8 Grouped GEMM is a kernel-level optimization for low-precision matrix multiplication on NVIDIA Hopper GPUs, which eliminates the padding overhead typically required for grouped general matrix-matrix multiplication (GEMM) using FP8 precision. By introducing a logarithmic-sized pool of preconfigured Tensor Memory Accelerator (TMA) descriptors and dual-phase memory load/store operations, this method dynamically adapts to variable group matrix dimensions without incurring extra memory or compute associated with conventional padding, while preserving strict alignment constraints imposed by Hopper’s hardware (Su et al., 7 Aug 2025).
1. Limitations of Conventional FP8 Grouped GEMM with Padding
Traditional FP8 grouped GEMM implementations, such as DeepGEMM, mandate that each expert group’s input and output matrices be padded to align their row count to a fixed multiple (commonly 128). This is necessitated by two hardware constraints: (1) TMA descriptors are static and not designed for group-wise variation in row counts, and (2) Hopper’s TMA enforces alignment requirements of 16 bytes for global memory addresses and 128 bytes for shared memory addresses during multidimensional transfers. As a consequence, each group can incur up to 127 padded rows per matrix, resulting in wasteful memory consumption (up to 23.8%) and computational slowdowns of up to 20% in extreme cases. The computational and bandwidth inefficiency becomes particularly acute as the number of groups increases and as per-group sizes shrink, due to the compounded effect of redundant reads, writes, and zero-element computation [(Su et al., 7 Aug 2025), Section 1, Figure 1].
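As a concrete illustration of the padding cost (a hypothetical worst case, not a figure from the paper): with a 128-row alignment target, a group’s padded row count is $\tilde{M}_g = \lceil M_g / 128 \rceil \cdot 128$, so a group holding $M_g = 129$ valid rows is stored and computed as 256 rows, 127 of which are zeros that must still be read, multiplied, and written back.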
2. Structure and Function of the TMA Descriptor Pool
TMA-Adaptive FP8 Grouped GEMM introduces an efficient method for handling arbitrary per-group matrix sizes through a logarithmic pool of TMA descriptors. For a tile shape with row dimension $\text{BLOCK\_M}$ (128 in the reference configuration), the descriptor pool is initialized at kernel launch:

$$\mathcal{D} = \{\, \text{desc}(2^{k}) \mid k = 0, 1, \dots, \log_2 \text{BLOCK\_M} \,\} \quad \text{(Section 2.2, Equation (1))}$$

For each group $g$ during execution, the residual row count is computed as:

$$r_g = M_g \bmod \text{BLOCK\_M} \quad \text{(Section 2.2, Equation (2))}$$

The optimal descriptor for the residual is then dynamically selected via a single lookup into $\mathcal{D}$ at index $\lfloor \log_2 r_g \rfloor$, ensuring minimal overhead and memory traffic independent of group size. This construction guarantees comprehensive coverage for all possible residual sizes with merely $\log_2(\text{BLOCK\_M}) + 1$ descriptors (8 for $\text{BLOCK\_M} = 128$) [(Su et al., 7 Aug 2025), Section 2.2].
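A minimal sketch of how such a logarithmic descriptor pool and residual lookup could be organized is shown below; `TmaDesc`, `make_tma_desc`, and the pool layout are illustrative placeholders rather than the paper’s actual descriptor interface, and only the indexing logic follows Equations (1) and (2).

```cuda
// Illustrative sketch (not the paper's code): a constant-memory pool with one
// descriptor per power-of-two row count, plus the per-group residual lookup.
#include <cuda_runtime.h>

constexpr int BLOCK_M   = 128;
constexpr int POOL_SIZE = 8;                  // log2(BLOCK_M) + 1 descriptors for 1,2,4,...,128 rows

struct TmaDesc { unsigned char bytes[128]; }; // opaque placeholder for a TMA descriptor
TmaDesc make_tma_desc(int rows);              // placeholder: builds a descriptor covering `rows` rows

__constant__ TmaDesc g_desc_pool[POOL_SIZE];  // persistent pool, built once at launch

// Host side: one descriptor per power-of-two residual size (Equation (1)).
void init_desc_pool() {
    TmaDesc host_pool[POOL_SIZE];
    for (int k = 0; k < POOL_SIZE; ++k)
        host_pool[k] = make_tma_desc(1 << k); // 2^k rows; 2^(POOL_SIZE-1) == BLOCK_M
    cudaMemcpyToSymbol(g_desc_pool, host_pool, sizeof(host_pool));
}

// Device side: select the descriptor for a group's residual r_g = M_g % BLOCK_M (Equation (2)).
__device__ const TmaDesc& select_desc(int res /* 1 <= res < BLOCK_M */) {
    int k = 31 - __clz(res);                  // floor(log2(res)), as reported in the paper
    return g_desc_pool[k];
}
```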
3. Dual-Phase Load–Store Mechanism and Dynamic Descriptor Selection
Each group’s residual matrix portion is handled in exactly two TMA operations, regardless of the residual size $r_g$:
- The first phase (Phase A) copies the largest possible power-of-two block of $p = 2^{\lfloor \log_2 r_g \rfloor}$ rows from shared to global memory, starting at the group’s residual offset.
- The second phase (Phase B) covers the remainder by mapping the last $p$ shared rows onto the end of the group’s global matrix portion.
This ensures full coverage without holes or out-of-bounds accesses, as the overlapping region of $2p - r_g$ rows resolves any boundary cases. The method relies on prebuilt descriptors, avoiding any runtime allocation. The compute phase itself (FP8 TensorCore GEMM with tiling across $M$, $N$, and $K$) proceeds as standard, interleaved with the adapted TMA data transfers [(Su et al., 7 Aug 2025), Section 2.2, Appendix B, Figure 2].
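The row-range arithmetic of the two phases can be sketched as follows; `tma_store_rows` is a placeholder standing in for the descriptor-driven TMA store, and only the offset computation reflects the scheme described above.

```cuda
// Sketch of the dual-phase residual store (placeholder API, illustrative only).
__device__ void tma_store_rows(int smem_row, int gmem_row, int rows);   // placeholder TMA store

__device__ void dual_phase_store(int res /* residual rows r_g, 1 <= r_g < BLOCK_M */,
                                 int gmem_res_row /* first residual row of this group in global memory */) {
    int p = 1 << (31 - __clz(res));          // largest power of two <= res (Phase A/B block size)

    // Phase A: shared rows [0, p) -> global rows [gmem_res_row, gmem_res_row + p)
    tma_store_rows(/*smem_row=*/0, /*gmem_row=*/gmem_res_row, /*rows=*/p);

    // Phase B: shared rows [res - p, res) -> the last p global rows of the group.
    // The two ranges overlap by 2p - res rows, so together they cover exactly res rows
    // with no holes and no out-of-bounds accesses.
    tma_store_rows(/*smem_row=*/res - p, /*gmem_row=*/gmem_res_row + res - p, /*rows=*/p);
}
```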
4. Alignment-Compliant Memory Management
Strict adherence to Hopper TMA alignment restrictions is achieved through two key mechanisms (Section 2.3, Appendix A):
- Global Memory: Every matrix row stride is enforced to be a multiple of 16 bytes. If the starting address of a residual block is misaligned (not divisible by 16), extra upstream rows are fetched until an aligned row is reached:

$$\text{row}_{\text{start}}' = \max\{\, r \le \text{row}_{\text{start}} : (\text{base} + r \cdot \text{stride}) \bmod 16 = 0 \,\} \quad \text{(Equation (3))}$$
- Shared Memory: Each block keeps its shared-memory tile stride a multiple of 128 bytes, so the destination address of every transfer is always 128-byte aligned. Both phases of TMA therefore always land on legal 128-byte boundaries, irrespective of the value of $r_g$.
These mechanisms obviate the need for zero-padding on both the global and shared memory levels [(Su et al., 7 Aug 2025), Section 2.3, Appendix A].
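The two alignment rules can be expressed as short checks, sketched below under the assumption of 1-byte FP8 elements; the helper names are illustrative, not taken from the paper.

```cuda
// Illustrative alignment helpers for the constraints above (1-byte FP8 elements,
// so a row's byte stride equals its element count). Names are placeholders.
#include <cstdint>
#include <cassert>

constexpr int GMEM_ALIGN = 16;    // Hopper TMA global-memory address alignment (bytes)
constexpr int SMEM_ALIGN = 128;   // Hopper TMA shared-memory address alignment (bytes)

// Global memory: walk the start row upstream until its address is 16-byte aligned;
// the extra upstream rows fetched this way are ignored by the consumer.
int aligned_start_row(std::uintptr_t base, int row_stride_bytes, int start_row) {
    int row = start_row;
    while (row > 0 && (base + std::uintptr_t(row) * row_stride_bytes) % GMEM_ALIGN != 0)
        --row;
    return row;
}

// Shared memory: keeping the tile stride a multiple of 128 bytes guarantees every
// per-row TMA destination lands on a legal 128-byte boundary, for any residual size.
void check_smem_tile(std::uintptr_t smem_base, int smem_stride_bytes) {
    assert(smem_base % SMEM_ALIGN == 0);
    assert(smem_stride_bytes % SMEM_ALIGN == 0);
}
```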
5. Implementation Details on NVIDIA Hopper Architecture
TMA-Adaptive FP8 Grouped GEMM is deployed as a CUDA kernel tailored for NVIDIA H800 GPUs, using CUDA 12.6 and PyTorch 2.6.0. Each group maps to a single threadblock, often organized as 4 × 2 warps (8 warps) to utilize warp-group TMA. The descriptor pool is persistently stored as a constant-memory array. The runtime logarithm (for descriptor selection) is computed efficiently with the `31 - __clz(res)` intrinsic. TMA API invocations include:
- `tmaDescCreate(&desc, ...)` for descriptor setup
- `tmaMemcpyAsync(...)` for memory transfers
- `__syncwarp()` and `tmaWaitAll()` for synchronization and ordering
FP8 TensorCore GEMM uses 1×128 and 128×128 scaling for compute [(Su et al., 7 Aug 2025), Section 3.1, Table 1].
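A minimal launch-configuration sketch of the threadblock/warp mapping described above follows; the kernel name and argument list are placeholders and the kernel body is elided.

```cuda
#include <cuda_runtime.h>

// Placeholder kernel: one threadblock per expert group, 4 x 2 warps (256 threads)
// cooperating on the group's tiles via warp-group TMA, as described above.
__global__ void grouped_gemm_fp8_kernel(/* group metadata, inputs, outputs, scales ... */) {
    // kernel body elided in this sketch
}

void launch_grouped_gemm(int num_groups, cudaStream_t stream) {
    dim3 grid(num_groups);        // each group maps to a single threadblock
    dim3 block(8 * 32);           // 8 warps = 256 threads (4 x 2 warp layout)
    grouped_gemm_fp8_kernel<<<grid, block, 0, stream>>>();
}
```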
6. Experimental Evaluation and Performance Outcomes
Evaluation against the DeepGEMM baseline (with explicit per-group padding to the 128-row boundary) covered a comprehensive workload sweep: varying $N$ and $K$, multiple group counts, and total row counts $M$ up to 65k, with random per-group sizes (Appendix C). Key empirical findings include:
- Acceleration: Measured speedup over the baseline ranged from 1.7% to 20.4%, with stronger effects for smaller per-group sizes and larger group counts; speedup correlated positively with the number of groups and negatively with per-group matrix size.
- Memory Reduction: DRAM usage for the row-padded matrices was reduced by up to 23.8% (notably in the 32-group configurations). Memory savings correlated inversely with per-group size and positively with group count.
- Numerical Equivalence: The result matrices exactly matched baseline results after removal of padded rows, confirming strict numerical fidelity of the dual-phase TMA approach (bitwise identical results for valid entries).
Summary table of key performance results:
| Metric | Observed Value | Correlation |
|---|---|---|
| Speedup | 1.7% – 20.4% | positive with group count, negative with per-group size |
| Memory Saving | up to 23.8% | negative with per-group size, positive with group count |
| Numerical Error | Bitwise identical | – |
[(Su et al., 7 Aug 2025), Section 3, Figure 1]
7. Applications in Low-Precision MoE Training and Inference
Grouped GEMM with a variable row count $M_g$ per expert is fundamental for Mixture-of-Experts (MoE) architectures, particularly in recent LLMs where sequences are dynamically routed into specialist subnets. Key deployment scenarios include:
- Inference: Dynamic batching produces widely differing per-expert row counts across requests.
- Training: Pipeline- and tensor-parallel strategies generate residual-sized (non-multiple-of-128) matrices per group.
By removing the need for host- or kernel-side padding (which in the worst cases consumed up to 2000 GB/s of DRAM bandwidth), TMA-Adaptive FP8 Grouped GEMM provides the following operational advantages:
- Reduction in end-to-end inference latency due to fewer extra matrix rows.
- Lower DRAM pressure, enabling increased maximum batch sizes or larger expert counts per GPU.
- Drop-in compatibility with extant FP8 GEMM libraries, requiring only replacement of the grouped GEMM kernel with no changes to upstream host routing logic.
This architecture-compliant, zero-padding solution advances the state-of-the-art in low-precision MoE model training and inference for NVIDIA Hopper platforms, offering consistent improvements in both throughput (1.7%–20.4%) and memory utilization (up to 23.8%) without compromising accuracy [(Su et al., 7 Aug 2025), Section 4].