
tritonBLAS: Analytical GEMM Kernel Selector

Updated 8 December 2025
  • tritonBLAS is an analytical framework that deterministically selects GEMM kernel parameters by modeling tile configurations with hardware-specific details like cache hierarchy and memory bandwidth.
  • It replaces costly runtime autotuning with a closed-form, roofline-based performance model that rapidly predicts optimal tiling parameters, achieving near-optimal throughput.
  • Implemented in the Triton programming environment, tritonBLAS enables reproducible and low-overhead kernel configuration for high-performance computing and machine learning workloads.

tritonBLAS is a deterministic, analytical framework for General Matrix Multiplication (GEMM) kernel parameter selection, implemented entirely within the Triton programming environment. It replaces runtime empirical autotuning with a closed-form model that incorporates architectural parameters—including cache hierarchy, memory bandwidths, register and shared memory capacities, and the topology of compute units—to rapidly select performant GPU kernel configurations. tritonBLAS achieves near-optimal throughput, with selection overhead several orders of magnitude lower than autotuned solutions, making it suitable for both high-performance computing (HPC) and ML production workloads (Swann et al., 3 Dec 2025).

1. Analytical Model for GEMM Kernel Selection

tritonBLAS models the GEMM operation $C \leftarrow \alpha A B + \beta C$, where $A \in \mathbb{R}^{M \times K}$, $B \in \mathbb{R}^{K \times N}$, and $C \in \mathbb{R}^{M \times N}$, as a tiling problem mapped to the GPU hardware. The framework partitions the computation across compute units (CUs), placing spatial "output" tiles of dimensions $M_b \times N_b$ and reduction tiles of size $K_b$ along $K$.

The performance model builds on the roofline approach, using the arithmetic intensity $I$ (FLOP/byte):

$$I = \frac{2 M N K}{4\,[M K + N K + M N]}$$

where 2 FLOPs per MAC and 2 bytes per FP16 element are assumed. The maximum achievable performance is:

$$P \leq \min(F_{peak},\; B_{mem} \cdot I)$$
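A minimal numeric rendering of this bound, assuming FP16 operands (the function name and signature are illustrative):

def roofline_bound(M, N, K, F_peak, B_mem):
    """Roofline upper bound on GEMM throughput (FLOP/s)."""
    flops = 2 * M * N * K                      # 2 FLOPs per multiply-accumulate
    bytes_moved = 4 * (M * K + N * K + M * N)  # denominator of the FP16 intensity formula
    intensity = flops / bytes_moved            # I, in FLOP/byte
    return min(F_peak, B_mem * intensity)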

Per-tile latency comprises:

  • Compute latency $L_{comp}$, derived from the matrix-instruction shape $(m_i, n_i, k_i)$ and instruction latency $L_{mi}$; for a tile, $N_{MI} = \lceil M_b / m_i \rceil \lceil N_b / n_i \rceil \lceil K_b / k_i \rceil$, with $L_{comp} = N_{MI} \cdot L_{mi}$.
  • Memory latency $L_{mem}$, modeled across the L1 and L2 caches and DRAM with bandwidths $R_{L1}$, $R_{L2}$, $R_{mem}$ and hit rates $H_1, H_2$ estimated from tile footprints and reuse. Per-CU load latency is set by the bottleneck level of the memory hierarchy, accounting for uncached and cached loads. Both terms are sketched in code after this list.
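A schematic rendering of the two latency terms, where H bundles the microbenchmarked hardware parameters of Section 2 (the helper names and the max-based bottleneck rule are illustrative simplifications, not the paper's exact model):

from math import ceil

def tile_compute_latency(M_b, N_b, K_b, H):
    """L_comp: matrix-instruction count for one tile times instruction latency."""
    m_i, n_i, k_i = H.mi_shape
    n_mi = ceil(M_b / m_i) * ceil(N_b / n_i) * ceil(K_b / k_i)  # N_MI
    return n_mi * H.L_mi

def tile_memory_latency(M_b, N_b, K_b, H, h1, h2):
    """L_mem: per-CU load latency, taken as the slowest memory-hierarchy level."""
    tile_bytes = (M_b + N_b) * K_b * H.elem_bytes        # A-tile + B-tile traffic
    t_l1 = tile_bytes * h1 / H.R_l1                      # loads serviced from L1
    t_l2 = tile_bytes * (1 - h1) * h2 / H.R_l2           # L1 misses hitting L2
    t_dram = tile_bytes * (1 - h1) * (1 - h2) / H.R_mem  # uncached loads
    return max(t_l1, t_l2, t_dram)                       # bottleneck level dominates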

Pipeline overhead (prologue/epilogue) and occupancy details, such as the number of waves ($\omega = \lceil T_{out} / N_{CU} \rceil$, with $T_{out}$ the number of output tiles) and the number of active compute units, are incorporated. The total latency is:

$$L_{total}(M_b, N_b, K_b) = \omega \cdot L_{tile}$$

where $L_{tile}$ includes all compute, memory, pipeline, and store latencies, as specified by the model.

Optimization selects the $(M_b, N_b, K_b)$ minimizing $L_{total}$, or equivalently maximizing $2 M N K / (L_{total} \cdot \text{clock})$.
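The objective is straightforward to evaluate once $L_{total}$ is predicted; a small sketch, assuming latency in cycles and the clock in Hz:

def predicted_tflops(M, N, K, L_total_cycles, clock_hz):
    """Throughput implied by a predicted total latency."""
    seconds = L_total_cycles / clock_hz
    return (2 * M * N * K) / seconds / 1e12  # TFLOP/s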

2. Architecture-Driven Parameterization

All model inputs are obtained from microbenchmarks executed on the target GPU. Key parameters include:

  • $N_{CU}$: number of compute units (e.g., 304 on the MI300X)
  • Matrix-instruction shape $(m_i, n_i, k_i)$ and latency $L_{mi}$
  • Registers per SIMD and registers per thread (these impose loop-unroll limits)
  • Shared memory per CU, L1/L2 cache sizes and bandwidths, DRAM bandwidth and latency
  • Architectural constraints, e.g., $M_b K_b \cdot \text{elem\_bytes} + N_b K_b \cdot \text{elem\_bytes} \leq \text{smem\_size}$

These values determine feasible tile factors and inform the performance cost function.
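In code, these inputs might be gathered into a single record; the field names below are illustrative, not tritonBLAS's actual API:

from dataclasses import dataclass

@dataclass(frozen=True)
class HardwareParams:
    n_cu: int              # number of compute units (N_CU)
    mi_shape: tuple        # matrix-instruction shape (m_i, n_i, k_i)
    L_mi: int              # matrix-instruction latency (cycles)
    regs_per_simd: int     # register file per SIMD
    regs_per_thread: int   # per-thread register budget (limits unrolling)
    smem_per_cu: int       # shared memory per CU (bytes)
    l1_size: int           # L1 capacity (bytes)
    l2_size: int           # L2 capacity (bytes)
    R_l1: float            # L1 bandwidth (bytes/cycle)
    R_l2: float            # L2 bandwidth (bytes/cycle)
    R_mem: float           # DRAM bandwidth (bytes/cycle)
    elem_bytes: int = 2    # FP16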

3. Blocking Notation and Optimization Constraints

Within this framework, tiling factors are notated as:

  • $M_b$, $N_b$: spatial output tile sizes
  • $K_b$: reduction-axis tile size

Constraints restrict candidate tiles:

  • $(M_b, N_b, K_b)$ must respect shared-memory and register-file capacities
  • Shared memory per block: $(M_b + N_b)\, K_b \cdot \text{elem\_bytes} \leq \text{SMEM\_per\_CU}$
  • Approximate register demand: $M_b \cdot N_b \cdot 2$

The optimization solves:

$$\min_{M_b, N_b, K_b} L_{total}(M_b, N_b, K_b)$$

subject to these constraints, over the valid factor set.
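These constraints translate directly into the feasibility predicate used by the selection loop in Section 4; a minimal sketch against the HardwareParams record above:

def violates_register_or_smem(M_b, N_b, K_b, H):
    """True if a candidate tile exceeds shared-memory or register capacity."""
    smem_needed = (M_b + N_b) * K_b * H.elem_bytes  # A-tile + B-tile footprint
    regs_needed = M_b * N_b * 2                     # approximate accumulator demand
    return smem_needed > H.smem_per_cu or regs_needed > H.regs_per_simd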

4. Static Enumeration and Selection Algorithm

tritonBLAS proceeds by statically enumerating potential tile sizes, invoking the analytical model for each candidate, and selecting the configuration with minimal predicted latency. No empirical testing or JIT compilation during selection is required.

Pseudocode encapsulates this process:

def select_tile(M, N, K, H):
    best, best_cfg = float("inf"), None
    for M_b in tile_factors(M):
        for N_b in tile_factors(N):
            for K_b in tile_divisors(K):
                if violates_register_or_smem(M_b, N_b, K_b, H):
                    continue
                lat = compute_latency(M_b, N_b, K_b, H)  # analytical model of Section 1
                if lat < best:
                    best, best_cfg = lat, (M_b, N_b, K_b)
    return best_cfg

The candidate set $P$ is typically 50–150 configurations for FP16 kernels on contemporary GPUs, enabling sub-millisecond selection latency.
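The candidate generators are simple; one plausible choice (illustrative, not the paper's exact enumeration) uses powers of two, which yields a candidate-set size in the stated range:

def tile_factors(dim, cands=(16, 32, 64, 128, 256)):
    """Power-of-two output-tile candidates, capped near the problem dimension."""
    return [t for t in cands if t <= max(dim, cands[0])]

def tile_divisors(K, cands=(8, 16, 32, 64)):
    """Reduction-axis tile candidates."""
    return [k for k in cands if k <= K]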

5. Triton-Based Implementation

The tritonBLAS workflow is embedded in Triton as a pure-Python module. The analytical selector and GEMM kernel are coupled as follows:

  • The selector determines the optimal $(M_b, N_b, K_b)$
  • The kernel launch uses Triton's grid mapping: $(\lceil M / M_b \rceil, \lceil N / N_b \rceil)$
  • All tile sizes are resolved pre-launch; no calls to @triton.autotune are issued

Example function signature:

@triton.jit
def _triton_gemm(..., M_b: tl.constexpr, N_b: tl.constexpr, K_b: tl.constexpr):
    ...

and the user-facing routine:

def tritonBLAS_gemm(A, B, C, α, β):
    M, K = A.shape                 # problem shape taken from the operands
    _, N = B.shape
    M_b, N_b, K_b = select_tile(M, N, K, hardware_params)  # analytical, no autotuning
    grid = (ceil_div(M, M_b), ceil_div(N, N_b))
    _triton_gemm[grid](...)

Selection and launch overhead is 50–80 μs, independent of matrix shape.
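For concreteness, a minimal kernel body consistent with that signature is sketched below; the argument order, stride parameters, and masking policy are illustrative assumptions, not the library's actual interface:

import triton
import triton.language as tl

@triton.jit
def _triton_gemm(A, B, C, alpha, beta, M, N, K,
                 stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
                 M_b: tl.constexpr, N_b: tl.constexpr, K_b: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    rm = pid_m * M_b + tl.arange(0, M_b)        # rows of this output tile
    rn = pid_n * N_b + tl.arange(0, N_b)        # columns of this output tile
    rk = tl.arange(0, K_b)
    acc = tl.zeros((M_b, N_b), dtype=tl.float32)
    for k in range(0, K, K_b):                  # march along the reduction axis
        a = tl.load(A + rm[:, None] * stride_am + (k + rk)[None, :] * stride_ak,
                    mask=(rm[:, None] < M) & ((k + rk)[None, :] < K), other=0.0)
        b = tl.load(B + (k + rk)[:, None] * stride_bk + rn[None, :] * stride_bn,
                    mask=((k + rk)[:, None] < K) & (rn[None, :] < N), other=0.0)
        acc += tl.dot(a, b)                     # matrix instructions via tl.dot
    c_ptrs = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn
    c_mask = (rm[:, None] < M) & (rn[None, :] < N)
    c_old = tl.load(c_ptrs, mask=c_mask, other=0.0).to(tl.float32)
    out = alpha * acc + beta * c_old            # epilogue: C = αAB + βC
    tl.store(c_ptrs, out.to(C.dtype.element_ty), mask=c_mask)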

6. Performance Characterization

Empirical evaluation demonstrates:

GEMM Shape       Tile Sizes    Predicted Latency (μs)   Measured Latency (μs)   Error (%)
512×512×512      32×32×8       120                      125                     +4.2
1024×512×256     64×32×8       80                       83                      +3.8
2048×1024×512    128×64×16     240                      253                     +5.4

Over 150K random shapes with dimensions up to 8K, tritonBLAS attains 94.7% of the peak performance found by exhaustive autotuning, with median performance near 97%. Selection is 5–6 orders of magnitude faster than runtime autotuning (e.g., ~80 μs vs. 12–50 s). GEMM throughput on MI300X (FP16) is within 3% of vendor-optimized torch.matmul(), and reaches ~95% of the performance of autotuned cuBLAS/CUTLASS kernels across a wide arithmetic-intensity range (0.5–50 FLOP/byte).

The timing model's prediction error remains within ±6% across representative shapes, and the model correctly ranks candidate tile configurations.

7. Strengths and Limitations

Strengths:

  • Eliminates runtime autotuning overhead, enabling rapid deployment and dynamic batching
  • Tile selection is deterministic and reproducible; identical shapes yield identical parameter choices
  • Portable across GPU architectures; calibration of model inputs suffices for adaptation
  • Delivers near-optimal performance (≥95% of exhaustive search), even for memory-bound workloads

Limitations:

  • The model abstracts away cache associativity and replacement details; this is not critical for GEMM's regular access pattern
  • The framework assumes a single-GPU model; multi-GPU scenarios demand separate interconnect modeling
  • For extremely small GEMM problems ($M, N, K \lesssim 64$), kernel launch overhead dominates achievable performance, so tiling decisions have marginal effect
  • Model accuracy depends on the precision of the measured architectural parameters; new hardware requires recalibration, though the microbenchmarks complete quickly

A plausible implication is that tritonBLAS enables analytical optimization in dynamic and rapidly changing workload scenarios with negligible selection cost, while maintaining throughput competitive with state-of-the-art autotuned frameworks (Swann et al., 3 Dec 2025).

References

  • Swann et al., tritonBLAS (3 Dec 2025).
