SplitK Partial Fusion in Quantized GEMM
- The technique fuses on-the-fly dequantization with GEMM into a single kernel, significantly improving compute resource utilization and memory bandwidth for quantized inference.
- SplitK partial fusion decomposes the K-dimension into multiple segments, enabling efficient processing of 'skinny' matrices typical in large foundation models like LLaMA.
- Benchmark results on NVIDIA A100 and H100 GPUs demonstrate speedups up to 295% compared to conventional tiling, though challenges include atomic contention and sensitive parameter tuning.
SplitK partial fusion is a technique for accelerating matrix multiplication involving quantized weights, specifically targeting W4A16 quantized inference workloads. It fuses on-the-fly dequantization and general matrix multiplication (GEMM) into a single kernel using a SplitK work decomposition within the Triton programming model. This approach is particularly effective for "skinny" matrix multiplications, such as those found in large foundation models (e.g., LLaMA), where the activation matrix is thin (its row count $M$ is much smaller than $N$ and $K$). SplitK partial fusion improves compute resource utilization and memory bandwidth, delivering speedups of up to 295% compared to conventional data-parallel GEMM tiling, with average boosts of 65% on NVIDIA A100 GPUs and 124% on H100 GPUs (Hoque et al., 2024).
1. Core Algorithm: SplitK Partial Fusion
SplitK partial fusion modifies the standard matrix multiplication workflow for $C = AB$ with:
- $A$ as an FP16 matrix of shape $M \times K$,
- $B$ as a W4A16-quantized matrix (eight 4-bit weights packed into each int32), of logical shape $K \times N$,
- Per-column scale ($s$) and zero-point ($z$) parameters for $B$.
Instead of conventional 2D tiling over $(M, N)$, SplitK launches a 3D grid spanning $(\lceil M/B_m \rceil, \lceil N/B_n \rceil, S)$, where $S$ is the number of splits along the $K$-axis (the "split_k" parameter). Each thread block computes a partial sum over a disjoint segment of $K$, accumulating into a $B_m \times B_n$ tile and performing an atomic add to the output matrix $C$.
The fused kernel executes the following for each tile:
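- Parameters: block sizes $B_m$, $B_n$, $B_k$ and split factor $S$.
- Launches a grid: $\lceil M/B_m \rceil \cdot \lceil N/B_n \rceil$ programs over the output tiles (the two tile dimensions are flattened into one launch axis), $S$ programs over the $K$-splits.
- Each kernel instance determines its tile position and range in $K$ using $P = \lceil K / (B_k \cdot S) \rceil$ and $k_{\mathrm{off}}(p, \mathit{pid}_k) = (p \cdot S + \mathit{pid}_k) \cdot B_k$, where $P$ is the number of iterations.
- For each $p \in \{0, \dots, P-1\}$:
  - Loads an $A$ tile (FP16) and a packed, quantized $B$ tile (int32).
  - Dequantizes $B$ in-register using the provided scales and zero-points.
  - Performs a fused matrix multiply-accumulate.
- After summing over its $K$-segment, atomically adds the local accumulator into $C$.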
Sample pseudocode:
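```python
import triton
import triton.language as tl

@triton.jit
def splitk_fused_matmul(A_ptr, B_ptr, C_ptr, scales_ptr, zeros_ptr,
                        M, N, K,
                        block_m: tl.constexpr, block_n: tl.constexpr,
                        block_k: tl.constexpr, split_k: tl.constexpr):
    # Axis 0 enumerates output tiles; axis 1 enumerates the K-splits.
    pid = tl.program_id(0)
    pid_k = tl.program_id(1)
    n_tiles = tl.cdiv(N, block_n)
    m_tile = pid // n_tiles
    n_tile = pid % n_tiles
    m_off = m_tile * block_m
    n_off = n_tile * block_n
    offs_m = m_off + tl.arange(0, block_m)
    offs_n = n_off + tl.arange(0, block_n)

    P = tl.cdiv(K, block_k * split_k)  # iterations per program
    acc = tl.zeros((block_m, block_n), tl.float32)
    for p in range(P):
        k_off = (p * split_k + pid_k) * block_k
        offs_k = k_off + tl.arange(0, block_k)
        # FP16 tile of A. The bounds mask is elided in the source; with it,
        # slices past K load zeros, so no early `break` is needed.
        a = tl.load(A_ptr + offs_m[:, None] * K + offs_k[None, :], mask=…, other=0.0)
        # int32 tile holding the packed 4-bit weights of B.
        packed = tl.load(B_ptr + offs_k[:, None] * N + offs_n[None, :], mask=…, other=0)
        # dequantize_4bit is a helper that is not shown in the source.
        b = dequantize_4bit(packed, scales_ptr + n_off, zeros_ptr + n_off)
        acc += tl.dot(a, b)
    # Merge this program's partial tile into C with one atomic add per element.
    c_ptrs = C_ptr + offs_m[:, None] * N + offs_n[None, :]
    tl.atomic_add(c_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
```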
2. Mathematical Formulation: Dequantization and SplitK Slicing
Dequantization from 4-bit weights to FP16 is performed per element via $w_{\mathrm{fp}} = s_n \cdot (w_q - z_n)$, where:
- $w_q$ is the quantized weight,
- $s_n$ is the scale for output column $n$,
- $z_n$ is the zero-point for column $n$.
For each 32-bit packed word $P_i$ holding 8 lanes ($j = 0, \dots, 7$):
\begin{align}
w^{(j)}_q &= (P_i \gg 4j) \land \mathrm{0xF} \\
w^{(j)}_{\mathrm{fp}} &= s_{n_{\mathrm{start}} + 8i + j} \cdot (w^{(j)}_q - z_{n_{\mathrm{start}} + 8i + j})
\end{align}
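The lane extraction and affine dequantization above can be checked on the CPU with a few lines of plain Python. This is a minimal sketch, not the kernel's code: the helper names `pack_int4`, `unpack_int4`, and `dequantize_word` are hypothetical, and the scale and zero-point values are made up for illustration.

```python
def pack_int4(values):
    """Pack eight 4-bit values (0..15) into one word, lane j at bits 4j..4j+3."""
    assert len(values) == 8 and all(0 <= v <= 15 for v in values)
    word = 0
    for j, v in enumerate(values):
        word |= v << (4 * j)
    return word


def unpack_int4(word):
    """Recover the eight lanes: w_q^(j) = (word >> 4j) & 0xF."""
    return [(word >> (4 * j)) & 0xF for j in range(8)]


def dequantize_word(word, scales, zeros):
    """Apply w_fp = s * (w_q - z) lane by lane (per-column scale and zero-point)."""
    return [s * (q - z) for q, s, z in zip(unpack_int4(word), scales, zeros)]


# Illustrative values only: eight quantized weights sharing one scale/zero pair.
quantized = [0, 1, 7, 8, 9, 15, 3, 12]
word = pack_int4(quantized)
scales = [0.5] * 8
zeros = [8] * 8
weights = dequantize_word(word, scales, zeros)
```

Masking with `0xF` after the shift also discards any sign-extension bits, so the same expression works when the packed word is stored as a signed int32.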
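SplitK slicing along the $K$-dimension:
\begin{align}
S &= \text{split\_k} \\
P &= \lceil K / (B_k \cdot S) \rceil \\
k_{\mathrm{off}}(p, \mathit{pid}_k) &= (p \cdot S + \mathit{pid}_k) \cdot B_k
\end{align}
Each block computes the partial tile
\begin{align}
C^{(\mathit{pid}_k)} = \sum_{p=0}^{P-1} A\big[\,m_{\mathrm{off}} : m_{\mathrm{off}} + B_m,\; k_{\mathrm{off}} : k_{\mathrm{off}} + B_k\,\big] \cdot \hat{B}\big[\,k_{\mathrm{off}} : k_{\mathrm{off}} + B_k,\; n_{\mathrm{off}} : n_{\mathrm{off}} + B_n\,\big]
\end{align}
with $k_{\mathrm{off}} = k_{\mathrm{off}}(p, \mathit{pid}_k)$ and $\hat{B}$ the dequantized weights, prior to the final atomic reduction into the output matrix $C$.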
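The slicing rule can be exercised with a small CPU reference to confirm that the $S$ programs cover every $k$ index exactly once and that their partial sums add up to the full product. The sketch below is an illustration under simplifying assumptions: `splitk_matmul` is a hypothetical name, matrices are plain nested lists that are already dequantized, and an ordinary `+=` stands in for the atomic add.

```python
def cdiv(a, b):
    """Ceiling division, mirroring tl.cdiv."""
    return (a + b - 1) // b


def splitk_matmul(A, B, block_k, split_k):
    """Compute C = A @ B from per-split partial sums; also return visited k indices."""
    M, K, N = len(A), len(B), len(B[0])
    P = cdiv(K, block_k * split_k)
    C = [[0.0] * N for _ in range(M)]
    visited = []
    for pid_k in range(split_k):  # each pid_k would be a separate program
        acc = [[0.0] * N for _ in range(M)]
        for p in range(P):
            k_off = (p * split_k + pid_k) * block_k
            # min(...) plays the role of the bounds mask on the K axis.
            for k in range(k_off, min(k_off + block_k, K)):
                visited.append(k)
                for m in range(M):
                    for n in range(N):
                        acc[m][n] += A[m][k] * B[k][n]
        for m in range(M):  # stands in for the atomic add into C
            for n in range(N):
                C[m][n] += acc[m][n]
    return C, visited


# Small integer-valued example so that floating-point sums are exact.
M, K, N = 2, 10, 3
A = [[float(m + k) for k in range(K)] for m in range(M)]
B = [[float(k - n) for n in range(N)] for k in range(K)]
C, visited = splitk_matmul(A, B, block_k=2, split_k=4)
reference = [[sum(A[m][k] * B[k][n] for k in range(K)) for n in range(N)]
             for m in range(M)]
```

Note that each program walks strided, not contiguous, slices of $K$: with `block_k=2` and `split_k=4`, program 0 handles $k \in \{0, 1, 8, 9\}$ while program 1 handles $k \in \{2, 3\}$.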
3. Triton Kernel Micro-Architecture
The kernel micro-architecture employs:
- Grid dimensions:
  - $\lceil M/B_m \rceil \cdot \lceil N/B_n \rceil$ (over $(M, N)$ tiles)
  - $S$ (over $K$-axis splits)
- Thread blocks: each block typically maps to one CUDA cooperative thread array (CTA), commonly using 4 warps (128 threads).
- Thread layout: threads each handle one or more elements of the accumulator. Tiling values such as , are typical; warps are subdivided to load 16×16 subtiles from $A$ or $B$.
- Data movement:
  - $A$ tiles are loaded from FP16 global memory into registers.
  - $B$ tiles are loaded as int32 and dequantized in registers using per-column scales/zero-points.
  - Matrix multiplications and accumulations are performed in registers (using Tensor Core-compatible instructions).
  - No shared memory buffering is used; tile sizes are tuned for register-level reuse.
  - The accumulator remains in FP32 registers until atomic addition into $C$.
- Control flow: the main tile/segment loop iterates over $K$-slices, loading tiles, dequantizing $B$, computing dot-products, and, upon completion, atomically merging the block result into the output.
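The grid shape and the mapping from a flat program id to tile offsets can be sketched on the host side as follows. The function names `launch_grid` and `tile_offsets` and the tile sizes used in the example are assumptions chosen for illustration, not values taken from the source.

```python
def cdiv(a, b):
    """Ceiling division, mirroring tl.cdiv."""
    return (a + b - 1) // b


def launch_grid(M, N, block_m, block_n, split_k):
    """Axis 0 enumerates (M, N) output tiles; axis 1 enumerates the K-splits."""
    return (cdiv(M, block_m) * cdiv(N, block_n), split_k)


def tile_offsets(pid, N, block_m, block_n):
    """Map a flat axis-0 program id to the (row, column) offset of its output tile."""
    n_tiles = cdiv(N, block_n)
    m_tile, n_tile = pid // n_tiles, pid % n_tiles
    return m_tile * block_m, n_tile * block_n


# Hypothetical configuration: M=16, N=4096, 16x64 output tiles, split_k=4.
grid = launch_grid(M=16, N=4096, block_m=16, block_n=64, split_k=4)
offsets = tile_offsets(pid=5, N=4096, block_m=16, block_n=64)
```

Under these assumed tile sizes the grid is `(64, 4)`: 64 output tiles, each written by 4 programs that cover different $K$-slices.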
4. Benchmark Results and Performance Profile
Performance has been systematically benchmarked for skinny problem sizes (small $M$; the table below reports $M = 1$ and $M = 16$) on NVIDIA A100 (PCIe/SXM) and Hopper H100 (PCIe) GPUs.
Key results, comparing SplitK to classic data-parallel (DP) tiling:
| Platform (split_k) | M=1: SplitK vs DP TFLOPS (Δ) | M=16: SplitK vs DP TFLOPS (Δ) | Peak speedup |
|---|---|---|---|
| A100 80GB (split_k=4) | 0.15 vs 0.09 (65%) | 4.5 vs 3.5 (28%) | Up to 295% (small dims) |
| H100 (split_k=8) | 2.46 vs 1.10 (124%) | 4.1 vs 1.8 (128%) | Up to 295% (N=1024) |
Further breakdown:
- For M=1, N=2048 on H100: SplitK delivers 1.85 TFLOPS vs DP's 0.62 TFLOPS (195% increase).
- Nsight Compute (A100, M=16, N=K=4096):
Superior gains are observed for the "llama-style" regime (small $M$ with much larger $N$ and $K$) due to:
- Low $M$ causing memory-bound conditions for DP.
- SplitK boosting kernel occupancy by multiplying the number of CTAs in flight, hiding memory latency.
- On large-SM H100 architectures, raising split_k to 8 rebalances the load distribution and reduces the "wave" quantization losses of data-parallel tiling.
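The occupancy and wave-quantization argument can be made concrete with a deliberately simplified model that counts CTAs against available SMs. The sketch below assumes one resident CTA per SM and ignores atomic contention and memory effects; `wave_stats` is a hypothetical helper, the CTA counts reuse an assumed 16×64 output tile for $M = 16$, $N = 4096$, and the SM count of 108 corresponds to an A100.

```python
import math


def wave_stats(num_ctas, num_sms, ctas_per_sm=1):
    """Return (waves, utilization) for a launch of num_ctas blocks.

    A wave is one batch of CTAs that fills every available slot; utilization is
    the fraction of slot-waves that actually hold a CTA.
    """
    slots = num_sms * ctas_per_sm
    waves = math.ceil(num_ctas / slots)
    utilization = num_ctas / (waves * slots)
    return waves, utilization


A100_SMS = 108  # SM count of an A100

# Data-parallel tiling: 64 output tiles -> 64 CTAs, so 44 SMs sit idle.
dp_waves, dp_util = wave_stats(64, A100_SMS)

# SplitK with split_k=4: 256 CTAs spread over three waves.
sk_waves, sk_util = wave_stats(256, A100_SMS)
```

In this toy model, utilization rises from about 59% for data-parallel tiling to about 79% with `split_k=4`. Real kernels keep several CTAs resident per SM, so the numbers are only indicative of the trend.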
5. Limitations and Trade-Offs
SplitK partial fusion presents several operational limitations:
- Atomic addition contention: Increasing split_k raises the number of CTAs writing to the same $C$-tile, potentially incurring significant atomic update queuing and serialization.
- Parameter tuning sensitivity: Optimal split_k is architecture-specific (e.g., A100: 4; H100: 8). Larger values can degrade performance due to contention.
- Resource constraints: Increasing $B_m$, $B_n$, or $B_k$ to exploit more parallelism also drives up register pressure and restricts hardware occupancy.
- Supported quantization: The implementation targets W4A16 (4-bit weights to FP16) only. Adapting to other quantization levels (W2, W8) requires a reimplementation of the unpack and dequantization logic.
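To give a sense of what adapting the unpack step to other bit widths involves, the sketch below parameterizes lane extraction by the number of bits per weight. It is an assumption-laden illustration: `unpack_lanes` is a hypothetical helper, and the least-significant-lane-first ordering simply mirrors the 4-bit formula; an actual W2 or W8 format may order its lanes differently.

```python
def unpack_lanes(word, bits):
    """Split a 32-bit word into 32 // bits unsigned lanes, lowest bits first."""
    assert bits in (2, 4, 8)
    mask = (1 << bits) - 1
    return [(word >> (bits * j)) & mask for j in range(32 // bits)]


word = 0x76543210
w4 = unpack_lanes(word, 4)  # eight 4-bit lanes
w8 = unpack_lanes(word, 8)  # four 8-bit lanes
w2 = unpack_lanes(word, 2)  # sixteen 2-bit lanes
```

The number of weights per word changes with the bit width (16, 8, or 4), which in turn changes how packed columns map to scale and zero-point indices in the dequantization step.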
6. Future Extensions
Potential extensions for SplitK partial fusion include:
- StreamK decomposition: Incorporating techniques such as K-streaming and double-buffered pipelines to further improve $K$-axis parallelism and hide memory latencies.
- Generalized dequantization: Supporting per-group or per-channel quantization, requiring multiple scale/zero arrays and adaptable unpacking logic.
- Autotuning: Employing automated search over configuration space using Triton, to optimize for specific hardware and workload profiles.
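As an illustration of the generalized dequantization item, the sketch below applies group-wise parameters, where each block of `group_size` consecutive rows along $K$ shares one scale and zero-point per column. The layout (scale and zero arrays of shape $(K / \text{group\_size}) \times N$), the name `dequantize_grouped`, and the sample values are assumptions for illustration; the weights are taken as already unpacked.

```python
def dequantize_grouped(q, scales, zeros, group_size):
    """Dequantize a K x N matrix of unpacked integer weights.

    scales and zeros have shape (K // group_size) x N: row g applies to
    weight rows g * group_size .. (g + 1) * group_size - 1.
    """
    K, N = len(q), len(q[0])
    assert K % group_size == 0
    out = []
    for k in range(K):
        g = k // group_size  # group index along the K axis
        out.append([scales[g][n] * (q[k][n] - zeros[g][n]) for n in range(N)])
    return out


# Illustrative values: K=4, N=2, two groups of two rows each.
q = [[0, 15], [8, 8], [4, 12], [15, 0]]
scales = [[0.5, 0.25], [2.0, 1.0]]
zeros = [[8, 8], [4, 0]]
w = dequantize_grouped(q, scales, zeros, group_size=2)
```

Compared with the per-column scheme, the only structural change is the extra group index `g`, but inside a fused kernel it means the scale and zero-point tiles must be reloaded whenever a $K$-slice crosses a group boundary.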
SplitK partial fusion constitutes a hardware-efficient, highly parallel kernel design for W4A16 inference, demonstrating substantial performance gains for memory-bound, skinny GEMM layers found in contemporary foundation model deployments (Hoque et al., 2024).