
Triton Kernels with Partial Fusion

Updated 31 December 2025
  • The paper introduces partial fusion, a GPU kernel strategy that fuses selected adjacent tensor operations to lower kernel launch overhead and reduce memory usage in LLM workflows.
  • It details methodologies like chunkwise processing and SplitK fusion, achieving significant speedups and mitigating issues like register spilling and memory bottlenecks.
  • Empirical benchmarks on GPUs such as the A100 and H100 show multi-fold performance gains and notable memory savings compared to conventional unfused implementations.

Triton kernels with partial fusion describe a set of GPU kernel engineering strategies for optimizing training and inference of LLMs by grouping only subsets of adjacent tensor operations—often performed over structured tiles or chunks of data—into a single Triton kernel launch. This approach addresses kernel-launch overhead, memory bandwidth limitations, and register/shared-memory constraints. By restricting fusion scope to operations sharing contiguous data and memory footprint, these kernels deliver significant throughput and memory efficiency improvements, as documented in Liger-Kernel for LLM training (Hsu et al., 14 Oct 2024) and SplitK-decomposed fused kernels for inference with int4 quantization (Hoque et al., 5 Jan 2024).

1. Motivation and Design Principles

Partial fusion arose from the need to mediate between two extremes in GPU kernel engineering: unfused implementations (one kernel per tensor op, incurring repeated host-GPU transitions and intermediate DRAM allocations) and fully fused kernels (merging large stretches of operator graphs, often at the expense of excessive register usage and poor resource tiling). In Triton, every kernel launch incurs latency and dispatch overhead; repeated intermediate tensor reads and writes to high-bandwidth memory (HBM) constrain performance and power. Full fusion, while minimizing memory traffic, frequently faces hardware occupancy bottlenecks—e.g., register spilling, shared-memory exhaustion, and suboptimal automatic tiler behavior—particularly for large LLM workloads.

Partial fusion in Liger-Kernel targets operations that:

  • Are elementwise or small reductions composable on the same buffered tile.
  • Share an identical memory stride and block layout. For instance, in RMSNorm the mean-square computation, reciprocal-square-root normalization, affine scaling, and RMS caching are fused into a single pass (a minimal sketch follows this list); similarly, in SwiGLU/GeGLU layers the two linear pre-activations, nonlinearities, and final elementwise multiplication are fused.
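
As a minimal illustration of this fusion scope, the sketch below fuses the RMSNorm forward pass into a single Triton program per row. This is not the Liger-Kernel implementation; the buffer names, one-row-per-program layout, and cached reciprocal RMS are assumptions made for the example.

import triton
import triton.language as tl

@triton.jit
def rmsnorm_fwd_kernel(X_ptr, G_ptr, Y_ptr, RMS_ptr, N, eps, BLOCK_N: tl.constexpr):
    # One program normalizes one row: mean-square reduction, reciprocal sqrt,
    # affine scaling, and RMS caching happen in a single pass over the tile.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    x = tl.load(X_ptr + row * N + cols, mask=mask, other=0.0).to(tl.float32)
    g = tl.load(G_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    ms = tl.sum(x * x, axis=0) / N        # mean square of the row
    rstd = 1.0 / tl.sqrt(ms + eps)        # reciprocal RMS
    tl.store(RMS_ptr + row, rstd)         # cached for the backward pass
    tl.store(Y_ptr + row * N + cols, x * rstd * g, mask=mask)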

2. Algorithmic Overview and Triton Kernel Construction

Partial fusion in practice is exemplified by the FusedLinearCrossEntropy (FLCE) operation in Liger-Kernel. In this case, the objective is to avoid materializing the entire $(BT \times V)$ logit matrix in DRAM when performing linear projection followed by softmax and cross-entropy loss with gradient backpropagation. Instead, the data is processed chunkwise over $\mathrm{CHUNK}$ rows:

  • Load a chunk of hidden states $H^{(c)} \in \mathbb{R}^{\mathrm{CHUNK} \times H}$.
  • Multiply by the projection weights $W \in \mathbb{R}^{H \times V}$ to produce chunk logits $X^{(c)}$.
  • Apply a numerically stable softmax, then compute cross-entropy gradients.
  • Immediately backpropagate: compute $\nabla_{H^{(c)}}\mathcal{L}$ and accumulate gradients for $W$ via atomic adds.

The following Triton-style pseudocode, simplified to a single weight tile with no masking or loss accumulation, demonstrates interleaving the forward and backward passes without saving intermediate full-matrix outputs:

import triton
import triton.language as tl

@triton.jit
def fused_linear_cross_entropy_kernel(
    H_ptr, W_ptr, T_ptr,              # hidden states, projection weights, target ids
    W_grad_ptr, H_grad_ptr,           # gradient buffers (float32 assumed for atomics)
    BT, H, V,
    eps,
    CHUNK: tl.constexpr,              # rows handled per program
    BLOCK_H: tl.constexpr,            # == H in this simplified sketch (no masking)
    BLOCK_V: tl.constexpr,            # == V in this simplified sketch (no masking)
):
    # Each program processes one CHUNK of rows end to end.
    off = tl.program_id(0) * CHUNK
    rows = off + tl.arange(0, CHUNK)

    # Load the hidden-state chunk H^(c) and the weight tile W.
    h = tl.load(H_ptr + rows[:, None] * H + tl.arange(0, BLOCK_H)[None, :]).to(tl.float32)
    w = tl.load(W_ptr + tl.arange(0, BLOCK_H)[:, None] * V
                      + tl.arange(0, BLOCK_V)[None, :]).to(tl.float32)

    # Forward: chunk logits X^(c) = H^(c) W, followed by a numerically stable softmax.
    logits = tl.dot(h, w)
    lo_max = tl.max(logits, axis=1)
    exp_logits = tl.exp(logits - lo_max[:, None])
    sum_exp = tl.sum(exp_logits, axis=1)
    probs = exp_logits / (sum_exp[:, None] + eps)

    # Cross-entropy gradient w.r.t. the logits: softmax(X^(c)) - one_hot(targets).
    t = tl.load(T_ptr + rows)
    one_hot = (tl.arange(0, BLOCK_V)[None, :] == t[:, None]).to(tl.float32)
    grad_logits = probs - one_hot

    # Immediate backward for the hidden states: grad_H^(c) = grad_logits W^T.
    h_grad = tl.dot(grad_logits, tl.trans(w))
    tl.store(H_grad_ptr + rows[:, None] * H + tl.arange(0, BLOCK_H)[None, :], h_grad)

    # Accumulate the weight gradient: grad_W += (H^(c))^T grad_logits.
    grad_w = tl.dot(tl.trans(h), grad_logits)
    tl.atomic_add(W_grad_ptr + tl.arange(0, BLOCK_H)[:, None] * V
                             + tl.arange(0, BLOCK_V)[None, :], grad_w)

The workflow avoids materializing full $(BT \times V)$ intermediates and processes the data one chunk of rows at a time, balancing peak memory footprint against arithmetic intensity.
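
A host-side launch of the kernel above might look like the following sketch. The wrapper name and defaults are illustrative, and it inherits the simplifications of the pseudocode: whole-tile BLOCK_H = H and BLOCK_V = V, power-of-two dimensions, no masking, and a float32 weight-gradient buffer for the atomic adds.

import torch
import triton

def fused_linear_cross_entropy(hidden, weight, targets, chunk=256, eps=1e-6):
    # Hypothetical wrapper: one Triton program per chunk of rows; gradients are
    # accumulated into pre-zeroed buffers by the kernel's atomic adds.
    BT, H = hidden.shape
    V = weight.shape[1]
    h_grad = torch.empty_like(hidden)
    w_grad = torch.zeros(H, V, dtype=torch.float32, device=weight.device)
    grid = (triton.cdiv(BT, chunk),)
    fused_linear_cross_entropy_kernel[grid](
        hidden, weight, targets,
        w_grad, h_grad,
        BT, H, V, eps,
        CHUNK=chunk, BLOCK_H=H, BLOCK_V=V,   # simplified: whole H and V as one tile
    )
    return h_grad, w_grad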

3. Comparisons: Unfused vs. Partial Fusion

Fully unfused sequences, typical of eager "PyTorch-style" computation, materialize large intermediate tensors after each operator:

$$X = H\,W \quad\text{(linear projection)},\qquad Y = \mathrm{softmax}(X),\qquad \mathcal{L} = -\sum_i t_i \log Y_i$$

$$\nabla_X \mathcal{L} = Y - t,\qquad \nabla_H \mathcal{L} = \nabla_X \mathcal{L}\,W^\top,\qquad \nabla_W \mathcal{L} = H^\top\,\nabla_X \mathcal{L}$$

This implementation incurs a separate kernel launch for each operator in the forward and backward passes, as well as the maximal memory allocation, since the full $(BT \times V)$ logit and probability tensors are materialized.
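
For reference, the unfused baseline corresponds roughly to the following PyTorch sketch (illustrative only, not a specific HuggingFace code path):

import torch
import torch.nn.functional as F

def unfused_linear_cross_entropy(hidden, weight, targets):
    # The full (BT x V) logit tensor is materialized in HBM before the loss,
    # and autograd keeps matching intermediates for the backward pass.
    logits = hidden @ weight
    return F.cross_entropy(logits, targets)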

The fused and chunked partial fusion sequence is:

$$\text{for chunks } c = 1 \ldots \left\lceil \tfrac{BT}{\mathrm{CHUNK}} \right\rceil:\qquad X^{(c)} = H^{(c)} W,\qquad \nabla_{H^{(c)}} \mathcal{L} = \bigl(\mathrm{softmax}(X^{(c)}) - T^{(c)}\bigr)\,W^\top,\qquad \nabla_W \mathcal{L} \mathrel{+}= (H^{(c)})^\top \bigl(\mathrm{softmax}(X^{(c)}) - T^{(c)}\bigr)$$

This approach eliminates the need to materialize large $(BT \times V)$ matrices, with all forward and backward steps for a chunk completed before its intermediates are discarded.
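
The chunked recurrence can also be written as a plain PyTorch reference loop, shown below as an illustrative sketch (not the Triton implementation; names are hypothetical, and gradients use the unscaled sum form of the equations above):

import torch
import torch.nn.functional as F

def chunked_linear_cross_entropy_grads(hidden, weight, targets, chunk):
    # Mirrors the chunked math: only a (CHUNK x V) logit slice exists at any
    # point, while the weight gradient is accumulated across chunks.
    BT = hidden.shape[0]
    h_grad = torch.empty_like(hidden)
    w_grad = torch.zeros_like(weight)
    for start in range(0, BT, chunk):
        h_c = hidden[start:start + chunk]
        probs = torch.softmax(h_c @ weight, dim=-1)
        one_hot = F.one_hot(targets[start:start + chunk], probs.shape[-1]).to(probs.dtype)
        grad_logits = probs - one_hot                    # softmax(X^(c)) - T^(c)
        h_grad[start:start + chunk] = grad_logits @ weight.T
        w_grad += h_c.T @ grad_logits
    return h_grad, w_grad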

4. Performance Analysis and Empirical Results

Liger-Kernel benchmarks on A100 (80 GB) document substantial improvements:

Kernel        | Speedup vs. PyTorch/TorchScript | Peak Memory Reduction
CrossEntropy  | ≈3× faster                      | ≈5× less
GeGLU/SwiGLU  | ≈ same speed                    | ≈1.6× less at T = 16384
RMSNorm       | ≈7× faster                      | ≈3× less
LayerNorm     | ≈1.3× faster                    | ≤1.1× (negligible)
RoPE          | ≈8× faster                      | ≈3× less

End-to-end LLM fine-tuning on 4×A100 (bfloat16) yields throughput increases from +11.9% to +42.8% and GPU memory savings up to −56.8% compared to HuggingFace implementations, depending on model and batch size (Hsu et al., 14 Oct 2024).

Fused kernels enable these gains chiefly by minimizing temporary tensor allocation and reducing overall kernel launch count, especially for large vocabulary sizes where BT×VBT \times V is otherwise prohibitive.

5. SplitK Partial Fusion for Quantized Inference

In inference scenarios for LLMs with quantized weights, partial fusion is combined with SplitK work decomposition. Here, a GEMM between skinny activation matrices and large square weight matrices is required ($m \ll n = k$). Fusing int4 dequantization with the GEMM eliminates the extra write of dequantized weights. SplitK partitions the $k$ dimension into $S$ blocks, raising occupancy and facilitating parallel summation via atomic adds (Hoque et al., 5 Jan 2024).

Mathematically,

$$C_{p,q} = \sum_{r=1}^{k} A_{p,r}\, s\,(Q_W[r,q] - z),$$

where $s$ and $z$ are the quantization scale and zero point. SplitK decomposes this sum into $S$ slices over the reduction dimension:

$$C_{p,q} = \sum_{j=0}^{S-1} \left( \sum_{r = j k/S}^{(j+1)k/S - 1} A_{p,r}\, \widehat W_{r,q} \right) = \sum_{j=0}^{S-1} C^{(j)}_{p,q},$$

with $\widehat W_{r,q} = s\,(Q_W[r,q] - z)$ denoting the dequantized weight.
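
A plain PyTorch reference for this decomposition is sketched below as an illustrative numerical check, not the fused Triton kernel; it assumes already-dequantized weights W_hat and $k$ divisible by $S$:

import torch

def splitk_reference(A, W_hat, S):
    # Sums S partial GEMMs over k/S-wide slices of the reduction dimension;
    # in the fused kernel each partial result is combined via atomic adds.
    m, k = A.shape
    assert k % S == 0
    step = k // S
    C = torch.zeros(m, W_hat.shape[1], dtype=A.dtype, device=A.device)
    for j in range(S):
        C += A[:, j * step:(j + 1) * step] @ W_hat[j * step:(j + 1) * step, :]
    return C  # equals A @ W_hat up to floating-point summation order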

Benchmark results indicate speedups up to ≈65% on A100 and ≈124% on H100, with peaks approaching 295% for small-batch, large-weight workloads.

Key metrics:

Metric                   | SplitK    | Data Parallel
Global Memory Throughput | 313 GB/s  | 161 GB/s
Achieved Occupancy       | 27.8%     | 7.6%
SM Utilization           | 43.1%     | 20.8%
Latency                  | 27.90 µs  | 52.93 µs

SplitK also reduces per-block register usage (92 vs. 150 registers) and shared-memory usage (102 KB vs. 168 KB), further boosting SM utilization.

6. Architectural Considerations and Limitations

The efficacy of partial fusion is bounded by hardware constraints:

  • Register and Shared-Memory Pressure: Determining optimal block sizes for Triton thread-blocks is essential. Excessively large fusion scopes reduce occupancy and can trigger register spilling.
  • Atomic Add Bottleneck: Especially evident in gradient accumulation for large weight matrices or SplitK reductions. Alternatives such as per-warp partials with intra-block reduction may be required.
  • Memory Contiguity: Triton kernels assume row-major contiguous layouts; tensors must be made contiguous (e.g., via .contiguous()) before kernel launch.
  • Chunk Size Heuristics: In Liger, the chunk size is chosen as

$$\mathrm{CHUNK} = 2^{\left\lceil \log_2 \left\lceil BT / \lceil V/H \rceil \right\rceil \right\rceil},$$

balancing maximum memory savings against GPU utilization (a small transcription of this heuristic follows the list).

  • Hardware Dependencies: Fast 32-bit atomics and high DRAM bandwidth are prerequisites. On architectures with limited atomic throughput, tuning the SplitK split factor may be required.
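
As referenced above, the chunk-size heuristic can be transcribed directly (a sketch; the function name is hypothetical):

import math

def liger_chunk_size(BT, H, V):
    # Next power of two of ceil(BT / ceil(V / H)), i.e. the row count left
    # after scaling BT down by the V/H blow-up factor of the logit matrix.
    inc_factor = math.ceil(V / H)
    return 2 ** math.ceil(math.log2(math.ceil(BT / inc_factor)))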

A plausible implication is that as model and batch sizes scale, the design flexibility afforded by partial fusion and chunking remains essential for maintaining performance and resource efficiency.

Partial fusion implemented in Triton, as analyzed in Liger-Kernel (Hsu et al., 14 Oct 2024) and SplitK inference kernels (Hoque et al., 5 Jan 2024), represents the current state of the art for tailored GPU utilization in both training and inference of large-scale models. These kernels have demonstrated multi-fold speedups and substantial memory reductions for popular LLMs, surpassing conventional PyTorch/HuggingFace approaches.

Grouping tensor operations over shared data tiles allows practitioners to balance memory-bound and compute-bound problem instances, adapting fusion levels as necessary for both forward and backward passes. The deployment flexibility—integrating into standard PyTorch training loops with minimal code changes—further highlights their practical impact.

Current limitations center on hardware constraints, atomic add performance, and fusion scope tuning. Ongoing work investigates extending partial fusion strategies to a broader class of quantized formats and optimizing kernel resource usage for emerging GPU architectures.

Summary Table: Partial Fusion Kernels in Liger-Kernel

Kernel        | Fusion Strategy                   | Throughput Gain | Memory Reduction
RMSNorm       | Mean-square + scale + RMS caching | ≈7× faster      | ≈3× less
GeGLU/SwiGLU  | Linear + activation + product     | ≈ same speed    | ≈1.6× less
CrossEntropy  | Linear + softmax + CE + gradient  | ≈3× faster      | ≈5× less
RoPE          | All rotation ops fused            | ≈8× faster      | ≈3× less

In summary, Triton kernels with partial fusion leverage the granularity of operator grouping over contiguous tiles/chunks to optimize computational throughput and memory usage in LLM training and inference. This approach, validated by Liger-Kernel and SplitK implementations, increasingly defines best practice in high-performance deep learning.
