
SplitK Partial Fusion in Quantized GEMM

Updated 4 January 2026
  • The technique fuses on-the-fly dequantization with GEMM into a single kernel, significantly improving compute resource utilization and memory bandwidth for quantized inference.
  • SplitK partial fusion decomposes the K-dimension into multiple segments, enabling efficient processing of 'skinny' matrices typical in large foundation models like LLaMA.
  • Benchmark results on NVIDIA A100 and H100 GPUs demonstrate speedups up to 295% compared to conventional tiling, though challenges include atomic contention and sensitive parameter tuning.

SplitK partial fusion is a technique for accelerating matrix multiplication involving quantized weights, specifically targeting W4A16 quantized inference workloads. It fuses on-the-fly dequantization and general matrix multiplication (GEMM) into a single kernel using a SplitK work decomposition within the Triton programming model. This approach is particularly effective for "skinny" matrix multiplications, such as those found in large foundation models (e.g., LLaMA), where the activation matrix is thin ($m \ll n = k$). SplitK partial fusion improves compute resource utilization and memory bandwidth, delivering speedups of up to 295% compared to conventional data-parallel GEMM tiling, with average boosts of 65% on NVIDIA A100 GPUs and 124% on H100 GPUs (Hoque et al., 2024).

1. Core Algorithm: SplitK Partial Fusion

SplitK partial fusion modifies the standard matrix multiplication workflow for $C \leftarrow AB$ with:

  • $A$ as an FP16 matrix of shape $(M \times K)$,
  • $B$ as a W4A16-quantized matrix (eight 4-bit weights packed into each int32), of shape $(K \times N)$,
  • Per-column scale ($s_n$) and zero-point ($z_n$) parameters for $B$.

Instead of conventional 2D tiling over $(M, N)$, SplitK launches a 3D grid spanning $(M, N, S)$, where $S$ is the number of splits along the $K$-axis (the "split_k" parameter). Each thread block $(\mathit{pid}, \mathit{pid}_k)$ computes a partial sum over a disjoint segment of $K$, accumulating into a tile of size $(\text{block}_m \times \text{block}_n)$ and performing an atomic add to the output matrix $C$.

The fused kernel executes the following for each tile:

  1. Takes the block sizes $(\text{block}_m, \text{block}_n, \text{block}_k)$ and $S = \text{split}_k$ as parameters.
  2. Launches a grid with $\text{grid\_dim}(0) = \lceil M / \text{block}_m \rceil \cdot \lceil N / \text{block}_n \rceil$ and $\text{grid\_dim}(1) = S$.
  3. Each kernel instance determines its tile position $(m_{\mathrm{off}}, n_{\mathrm{off}})$ and its range in $K$ using

$$k_{\mathrm{off}}(p,\,\mathit{pid}_k) = (p \cdot S + \mathit{pid}_k) \cdot \text{block}_k, \qquad p = 0, \dots, P-1$$

where $P = \lceil K/(\text{block}_k \cdot S) \rceil$ is the number of iterations.

  4. For each $k_{\mathrm{off}}$:
    • Loads an $A$ tile (FP16) and a packed, quantized $B$ tile (int32).
    • Dequantizes $B$ in-register using the provided scales and zero-points.
    • Performs a fused matrix multiply-accumulate.
  5. After summing over its $K$-segment, atomically adds the local accumulator into $C$.
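The grid shape and the interleaved K-slice schedule from the steps above can be sketched in plain Python. This is an illustrative host-side model only, not part of the kernel; the helper names `splitk_grid` and `k_offsets` and the example sizes are assumptions chosen for the demonstration.

```python
import math

def splitk_grid(M, N, block_m, block_n, split_k):
    """Return (dim0, dim1): flattened (m, n) tiles and the number of K-splits."""
    return math.ceil(M / block_m) * math.ceil(N / block_n), split_k

def k_offsets(K, block_k, split_k, pid_k):
    """K offsets visited by the block with split index pid_k (interleaved schedule)."""
    P = math.ceil(K / (block_k * split_k))
    offs = [(p * split_k + pid_k) * block_k for p in range(P)]
    return [k for k in offs if k < K]  # slices starting past K carry no work

grid = splitk_grid(M=16, N=4096, block_m=16, block_n=64, split_k=4)
schedule = {pid_k: k_offsets(K=1024, block_k=128, split_k=4, pid_k=pid_k)
            for pid_k in range(4)}
print(grid)      # (64, 4)
print(schedule)  # {0: [0, 512], 1: [128, 640], 2: [256, 768], 3: [384, 896]}
```

In this example every K-slice is visited by exactly one split, which is the property that lets the partial sums be combined by simple addition.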

Sample pseudocode:

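import triton
import triton.language as tl

@triton.jit
def splitk_fused_matmul(A_ptr, B_ptr, C_ptr, scales_ptr, zeros_ptr,
                        M, N, K,
                        block_m: tl.constexpr, block_n: tl.constexpr,
                        block_k: tl.constexpr, split_k: tl.constexpr):
    # Assumed layout (not fixed by the text): A is (M, K) FP16 row-major;
    # B is (K, N // 8) int32 with eight 4-bit weights per word along N;
    # C is (M, N) FP32 and must be zero-initialized by the caller.
    # tl.dot expects every block dimension to be at least 16.
    pid = tl.program_id(0)      # flattened (m, n) tile index
    pid_k = tl.program_id(1)    # index of this block's K-split
    n_tiles = tl.cdiv(N, block_n)
    offs_m = (pid // n_tiles) * block_m + tl.arange(0, block_m)
    offs_n = (pid % n_tiles) * block_n + tl.arange(0, block_n)
    mask_m = offs_m < M
    mask_n = offs_n < N
    # Per-column scale and zero-point, loaded once per tile.
    scales = tl.load(scales_ptr + offs_n, mask=mask_n, other=0.0)
    zeros = tl.load(zeros_ptr + offs_n, mask=mask_n, other=0)
    shift = (offs_n % 8) * 4    # bit offset of each column's nibble in its word
    acc = tl.zeros((block_m, block_n), dtype=tl.float32)
    for p in range(0, tl.cdiv(K, block_k * split_k)):
        offs_k = (p * split_k + pid_k) * block_k + tl.arange(0, block_k)
        mask_k = offs_k < K     # masking replaces the early exit past the end of K
        a = tl.load(A_ptr + offs_m[:, None] * K + offs_k[None, :],
                    mask=mask_m[:, None] & mask_k[None, :], other=0.0)
        packed = tl.load(B_ptr + offs_k[:, None] * (N // 8) + (offs_n // 8)[None, :],
                         mask=mask_k[:, None] & mask_n[None, :], other=0)
        # In-register dequantization: w_fp = s_n * (w_q - z_n).
        w_q = (packed >> shift[None, :]) & 0xF
        b = (w_q.to(tl.float32) - zeros[None, :]) * scales[None, :]
        acc += tl.dot(a, b.to(tl.float16))
    # Merge this split's partial tile into C.
    tl.atomic_add(C_ptr + offs_m[:, None] * N + offs_n[None, :], acc,
                  mask=mask_m[:, None] & mask_n[None, :])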

2. Mathematical Formulation: Dequantization and SplitK Slicing

Dequantization from 4-bit weights to FP16 is performed per element via:

$$w_{\mathrm{fp}} = s_n \cdot (w_q - z_n)$$

where:

  • $w_q \in \{0,\dots,15\}$ is the quantized weight,
  • $s_n$ is the scale for output column $n$,
  • $z_n$ is the zero-point for column $n$.

For each 32-bit packed word $P_i$ holding 8 lanes ($j = 0, \dots, 7$):
\begin{align}
w^{(j)}_q &= (P_i \gg 4j) \land \mathrm{0xF} \\
w^{(j)}_{\mathrm{fp}} &= s_{n_{\mathrm{start}} + 8i + j} \cdot (w^{(j)}_q - z_{n_{\mathrm{start}} + 8i + j})
\end{align}
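The lane arithmetic can be checked with a few lines of plain Python. This is a minimal sketch assuming lane $j$ occupies bits $4j$ to $4j+3$, as in the formula above; the function names and the example scale and zero-point values are made up for illustration.

```python
def pack_int4(nibbles):
    """Pack eight values in 0..15 into one 32-bit word, lane j at bits 4j..4j+3."""
    assert len(nibbles) == 8 and all(0 <= w <= 15 for w in nibbles)
    word = 0
    for j, w in enumerate(nibbles):
        word |= w << (4 * j)
    return word

def unpack_int4(word):
    """Recover the eight 4-bit lanes with a shift and a mask."""
    return [(word >> (4 * j)) & 0xF for j in range(8)]

def dequantize(word, scales, zeros):
    """Apply w_fp = s * (w_q - z) to each lane of one packed word."""
    return [s * (w - z) for w, s, z in zip(unpack_int4(word), scales, zeros)]

word = pack_int4([0, 1, 2, 3, 12, 13, 14, 15])
print(hex(word))  # 0xfedc3210
print(dequantize(word, scales=[0.5] * 8, zeros=[8] * 8))
# [-4.0, -3.5, -3.0, -2.5, 2.0, 2.5, 3.0, 3.5]
```

Python integers are unbounded, so the word prints as an unsigned value; stored as a signed int32 the same bit pattern is negative, and the `& 0xF` mask is what keeps the extracted lane correct after an arithmetic right shift.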

SplitK slicing along the $K$-dimension:
\begin{align}
S &= \text{split}_k \\
P &= \lceil K / (B_k \cdot S) \rceil \\
k_{\mathrm{off}}(p,\mathit{pid}_k) &= (p \cdot S + \mathit{pid}_k) \cdot B_k
\end{align}
Here $B_m$, $B_n$, $B_k$ denote the block sizes $\text{block}_m$, $\text{block}_n$, $\text{block}_k$. With $k_{\mathrm{off}} = k_{\mathrm{off}}(p, \mathit{pid}_k)$, each block computes the partial result

$$C_{m_{\mathrm{tile}},n_{\mathrm{tile}}}^{(\mathit{pid}_k)} = \sum_{p=0}^{P-1} A[m_{\mathrm{off}}:m_{\mathrm{off}}+B_m,\,k_{\mathrm{off}}:k_{\mathrm{off}}+B_k] \times B[k_{\mathrm{off}}:k_{\mathrm{off}}+B_k,\,n_{\mathrm{off}}:n_{\mathrm{off}}+B_n]$$

prior to the final atomic reduction into the output matrix $C$.
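That the per-split partial results add up to the full product can be checked with a small CPU emulation. The sketch below is a plain-Python model under simplifying assumptions: a single $(m, n)$ tile covering the whole output, unquantized integer-valued inputs, and a sequential loop standing in for concurrent blocks and atomic adds. The function names are hypothetical.

```python
import random

def matmul(A, B):
    """Reference product of two list-of-lists matrices."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def splitk_matmul(A, B, block_k, split_k):
    """Emulate SplitK: each pid_k sums its interleaved K-slices, then adds into C."""
    M, K, N = len(A), len(B), len(B[0])
    P = -(-K // (block_k * split_k))       # ceil(K / (block_k * split_k))
    C = [[0.0] * N for _ in range(M)]      # the output must start at zero
    for pid_k in range(split_k):           # on a GPU these blocks run concurrently
        partial = [[0.0] * N for _ in range(M)]
        for p in range(P):
            k_off = (p * split_k + pid_k) * block_k
            for k in range(k_off, min(k_off + block_k, K)):
                for m in range(M):
                    for n in range(N):
                        partial[m][n] += A[m][k] * B[k][n]
        for m in range(M):                 # stands in for the atomic add into C
            for n in range(N):
                C[m][n] += partial[m][n]
    return C

random.seed(0)
A = [[random.randint(-3, 3) for _ in range(20)] for _ in range(2)]
B = [[random.randint(-3, 3) for _ in range(5)] for _ in range(20)]
print(splitk_matmul(A, B, block_k=4, split_k=3) == matmul(A, B))  # True
```

The example uses $K = 20$, which is not a multiple of $\text{block}_k \cdot S = 12$, so the last iteration of some splits falls past the end of $K$ and contributes nothing.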

3. Triton Kernel Micro-Architecture

The kernel micro-architecture employs:

  • Grid dimensions:
    • $\text{dim}_0 = M_{\mathrm{tiles}} \cdot N_{\mathrm{tiles}}$ (over $(m, n)$ tiles)
    • $\text{dim}_1 = \text{split}_k$ (over $K$-axis splits)
  • Thread blocks: Each block typically maps to one cooperative thread array (CTA), commonly using 4 warps (128 threads).
  • Thread layout: Threads each handle one or more elements of the $(\text{block}_m \times \text{block}_n)$ accumulator. Tiling values such as $\text{block}_m=16$, $\text{block}_n=64$ are typical; warps are subdivided to load 16×16 subtiles from $A$ or $B$.
  • Data movement:
    • $A$ tiles loaded from FP16 global memory into registers.
    • $B$ tiles loaded as int32, dequantized in registers using per-column scales/zero-points.
    • Matrix multiplications and accumulations are performed in registers (using Tensor Core-compatible instructions).
    • No shared memory buffering is used; tile sizes are tuned for register-level reuse.
    • The accumulator remains in FP32 registers until atomic addition into $C$.
  • Control flow: The main tile/segment loop iterates over $P$ $K$-slices, loading tiles, dequantizing $B$, and computing dot products; upon completion, the block result is atomically merged into the output.

4. Benchmark Results and Performance Profile

Performance has been systematically benchmarked for $M\in\{1,16\}$, $N=K\in\{512,1024,2048,4096,8192,16384\}$ on NVIDIA A100 (PCIe/SXM) and Hopper H100 GPUs (PCIe).

Key results, comparing SplitK to classic data-parallel (DP) tiling:

| Platform (split_k) | M=1: SplitK vs DP, TFLOPS (Δ) | M=16: SplitK vs DP, TFLOPS (Δ) | Peak speedup |
| --- | --- | --- | --- |
| A100 80GB (split_k=4) | 0.15 vs 0.09 (65%) | 4.5 vs 3.5 (28%) | Up to 295% (small dims) |
| H100 (split_k=8) | 2.46 vs 1.10 (124%) | 4.1 vs 1.8 (128%) | Up to 295% (N=1024) |

Further breakdown:

  • For M=1, N=2048 on H100: SplitK delivers 1.85 TFLOPS vs DP's 0.62 TFLOPS (195% increase).
  • Nsight Compute (A100, M=16, N=K=4096):
    • Kernel latency: 27.9 μs vs 52.9 μs (48% lower).
    • Achieved DRAM BW: 313 GB/s vs 161 GB/s.
    • Occupancy: 27.8 vs 7.6 warps per SM.
    • SM utilization: 43% vs 21%.
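The bandwidth figures can be put in context with a back-of-the-envelope estimate of the data each kernel must read. The sketch below is an assumption-laden lower bound: it counts only the packed 4-bit weights and the FP16 activations, each read once, and ignores scales, zero-points, and output traffic. The helper names are hypothetical.

```python
def min_bytes_moved(M, N, K, weight_bits=4, act_bytes=2):
    """Lower bound on DRAM reads: packed weights plus FP16 activations, read once."""
    return K * N * weight_bits // 8 + M * K * act_bytes

def implied_bandwidth_gbs(M, N, K, latency_us):
    """Bandwidth needed to move the minimum byte count within the given latency."""
    return min_bytes_moved(M, N, K) / (latency_us * 1e-6) / 1e9

def flops_per_byte(M, N, K):
    """Arithmetic intensity of the GEMM relative to the minimum bytes read."""
    return 2 * M * N * K / min_bytes_moved(M, N, K)

print(round(implied_bandwidth_gbs(16, 4096, 4096, 27.9)))  # 305 (SplitK latency)
print(round(implied_bandwidth_gbs(16, 4096, 4096, 52.9)))  # 161 (DP latency)
print(round(flops_per_byte(16, 4096, 4096), 1))            # 63.0
print(round(flops_per_byte(1, 4096, 4096), 1))             # 4.0
```

Under these assumptions the estimates land close to the profiled 313 GB/s and 161 GB/s, which suggests that both kernels read roughly the same data and differ mainly in how quickly they stream it. The low FLOP-per-byte ratio at $M=1$ is one way to see why this regime tends to be memory-bound.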

Superior gains are observed for the "llama-style" regime ($m \ll n = k$) due to:

  • Low $M$ causing memory-bound conditions for DP.
  • SplitK boosting kernel occupancy by multiplying the number of CTAs in flight, hiding memory latency.
  • On large-SM H100 architectures, raising split_k to 8 rebalances the load distribution and reduces the "wave" quantization losses of data-parallel tiling.
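The occupancy argument can be illustrated with a rough CTA count. The sketch below assumes the typical tile sizes mentioned earlier ($\text{block}_m = 16$, $\text{block}_n = 64$) and 108 SMs for an A100; the helper names are hypothetical, and the model ignores that several CTAs may be resident on one SM.

```python
import math

def cta_count(M, N, block_m, block_n, split_k=1):
    """Number of CTAs launched; split_k=1 corresponds to data-parallel tiling."""
    return math.ceil(M / block_m) * math.ceil(N / block_n) * split_k

def sm_coverage(ctas, num_sms):
    """Fraction of SMs that receive at least one CTA."""
    return min(1.0, ctas / num_sms)

NUM_SMS = 108  # assumption: SM count of an A100
dp_ctas = cta_count(16, 4096, 16, 64)             # data-parallel tiling
sk_ctas = cta_count(16, 4096, 16, 64, split_k=4)  # SplitK with four splits
print(dp_ctas, round(sm_coverage(dp_ctas, NUM_SMS), 2))  # 64 0.59
print(sk_ctas, round(sm_coverage(sk_ctas, NUM_SMS), 2))  # 256 1.0
```

In this configuration data-parallel tiling launches fewer CTAs than there are SMs, so part of the GPU sits idle, whereas the SplitK grid has enough CTAs to cover every SM.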

5. Limitations and Trade-Offs

SplitK partial fusion presents several operational limitations:

  • Atomic addition contention: Increasing split_k raises the number of CTAs writing to the same $C$-tile, potentially incurring significant atomic update queuing and serialization.
  • Parameter tuning sensitivity: Optimal split_k is architecture-specific (e.g., A100: 4; H100: 8). Larger values can degrade performance due to contention.
  • Resource constraints: Increasing $\text{block}_m$, $\text{block}_n$, or $\text{block}_k$ to exploit more parallelism also drives up register pressure and restricts hardware occupancy.
  • Supported quantization: The implementation targets W4A16 (4-bit weights to FP16) only. Adapting to other quantization levels (W2, W8) requires a reimplementation of the unpack and dequantization logic.
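Two of these trade-offs can be made concrete with simple counting. The sketch below is a rough model, not a measurement: it assumes every split writes every output element once, and that the FP32 accumulator tile is spread evenly over 4 warps of 32 threads with one 32-bit register per element. The helper names are hypothetical.

```python
def atomic_adds(M, N, split_k):
    """Element-wise atomic additions issued to C when every split writes its tile."""
    return M * N * split_k

def accumulator_regs_per_thread(block_m, block_n, num_warps=4, warp_size=32):
    """FP32 accumulator elements held per thread, assuming an even distribution."""
    return block_m * block_n / (num_warps * warp_size)

print(atomic_adds(16, 4096, split_k=4))     # 262144
print(atomic_adds(16, 4096, split_k=8))     # 524288
print(accumulator_regs_per_thread(16, 64))  # 8.0
print(accumulator_regs_per_thread(64, 128)) # 64.0
```

Atomic traffic grows linearly with split_k, while accumulator register demand grows with the tile area, so both knobs need to be balanced against the parallelism they expose.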

6. Future Extensions

Potential extensions for SplitK partial fusion include:

  • StreamK decomposition: Incorporating techniques such as K-streaming and double-buffered pipelines to further improve $K$-axis parallelism and hide memory latencies.
  • Generalized dequantization: Supporting per-group or per-channel quantization, requiring multiple scale/zero arrays and adaptable unpacking logic.
  • Autotuning: Employing automated search over the $(\text{block}_m, \text{block}_n, \text{block}_k, \text{split}_k)$ configuration space using Triton, to optimize for specific hardware and workload profiles.
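As an illustration of the autotuning idea, the search space can be enumerated in plain Python before handing it to a tuner. This is a sketch under assumptions: the candidate values and the pruning rule are illustrative choices, not taken from the source, and `candidate_configs` is a hypothetical helper.

```python
from itertools import product

def candidate_configs(K, block_ms=(16,), block_ns=(32, 64, 128),
                      block_ks=(64, 128), split_ks=(2, 4, 8, 16)):
    """Enumerate (block_m, block_n, block_k, split_k) candidates for a given K."""
    configs = []
    for bm, bn, bk, sk in product(block_ms, block_ns, block_ks, split_ks):
        if bk * sk > K:  # some splits would receive no K-slice at all
            continue
        configs.append({"block_m": bm, "block_n": bn,
                        "block_k": bk, "split_k": sk})
    return configs

print(len(candidate_configs(K=4096)))  # 24
print(len(candidate_configs(K=512)))   # 15
```

Each dictionary could be wrapped in a `triton.Config` and passed to `triton.autotune`. Because the kernel accumulates into $C$ with atomic adds, the output buffer would need to be reset to zero between trial runs for the timings and results to be meaningful.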

SplitK partial fusion constitutes a hardware-efficient, highly parallel kernel design for W4A16 inference, demonstrating substantial performance gains for memory-bound, skinny GEMM layers found in contemporary foundation model deployments (Hoque et al., 2024).
