HipKittens: AMD AI Kernel Framework
- HipKittens is a C++ embedded tile-based programming framework that abstracts complex AMD GPU optimizations for high-performance AI kernels.
- It offers explicit register and shared memory tiling, asynchronous memory transfers, and chiplet-aware grid coordination to maximize throughput.
- Benchmarks show up to roughly 3× speedups over compiler-based baselines such as Triton, and gains over AMD's hand-written assembly, across both compute- and memory-bound kernels.
HipKittens is a C++-embedded, tile-based programming framework designed to facilitate the development of high-performance AI kernels on AMD CDNA3 and CDNA4 GPUs without requiring raw assembly programming. By encapsulating architecture-specific optimizations for memory hierarchy, wave scheduling, and register usage in a concise abstraction layer, HipKittens enables developers to efficiently target modern AMD accelerators for demanding AI workloads.
1. Motivation and Design Goals
The impetus for HipKittens stems from the challenges unique to AMD’s Instinct CDNA3 (MI325X) and CDNA4 (MI355X) GPUs, which offer in excess of 2.5 PFLOPS (BF16) and 8 TB/s HBM bandwidth. Achieving peak hardware performance on these GPUs has traditionally required hand-written assembly, as existing C++/HIP or compiler-based approaches such as Triton and HIPCC rarely approach peak throughput. Prior domain-specific languages (DSLs) like ThunderKittens, CuTe, and Gluon focus on NVIDIA PTX/CUDA, often encoding NVIDIA-centric memory and execution patterns.
HipKittens’ objective is to provide a minimal but expressive set of tile-centric primitives and reusable scheduling motifs encompassing:
- Explicit, bank-conflict-free register and shared memory tiling
- Asynchronous memory movement and fine-grain wave scheduling
- Chiplet-aware grid coordination to match AMD’s hierarchical caches
A central insight is that the tile+bulk operation (“tiles + bulk ops”) abstraction, influential in ThunderKittens, generalizes across vendors, but its instantiation—the particular layouts, swizzles, and schedules—needs substantial adaptation for AMD’s distinctive register file, shared-memory bank arrangement, and chiplet-based cache design.
2. Core Primitives and Programming Model
HipKittens exposes three levels of abstraction (register at the wave level, block, and grid) with explicit data structure and operator primitives:
2.1 Tile Data Structures:
- Register tiles:
```cpp
rt<Dtype, M, TK, Layout, Shape> A;
```
`Dtype` denotes the element type (e.g., BF16, FP16, FP8), `M` and `TK` the tile dimensions, `Layout` the memory layout (row- or column-major), and `Shape` the underlying MFMA tile shape.
- Shared-memory tiles:
```cpp
st<Dtype, TileRows, TileCols, Shape> S;
```
2.2 Bulk Operators:
- `mma(A, B, C)`: fused matrix-multiply-accumulate over register tiles (e.g., MFMA 16×16×32)
- `load`/`store`: data movement between registers and shared memory (or from global memory into LDS), supporting swizzled layouts and asynchronous DMA
- Tile-wise elementwise operators: `exp2`, `add`, `sum`, etc. (combined in the sketch below)
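For illustration, the sketch below strings these primitives together for a single wave's inner step. The type-alias spellings mirror the worked example in Section 5, and `subtile` plus the exact operator signatures are assumptions rather than the verified HipKittens API.

```cpp
// Illustrative only: one wave's inner step using the tile types (2.1) and bulk
// operators (2.2). Alias spellings follow Section 5; signatures are assumed.
using Areg  = rt_bf16<16, 64, row_l, mfma_shape_16x32>;   // 16x64 BF16 register tile
using Breg  = rt_bf16<32, 64, row_l, mfma_shape_16x32>;
using Creg  = rt_f32<16, 32, col_l, mfma_shape_16x32>;    // FP32 accumulator tile
using Asmem = st_bf16<256, 64, swizzle_16x32>;            // 256x64 BF16 tile in LDS

__device__ void wave_inner_step(const Asmem& A_sh, const Asmem& B_sh,
                                Creg& acc, int wave_row, int wave_col) {
    Areg a; Breg b;
    load(a, subtile(A_sh, wave_row));   // LDS -> registers, bank-conflict-free swizzle
    load(b, subtile(B_sh, wave_col));
    mma(a, b, acc);                     // MFMA bulk matrix-multiply-accumulate
    exp2(acc, acc);                     // tile-wise elementwise op (e.g., for softmax-style kernels)
}
```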
2.3 Tile-Level Tiling Parameters:
- For a GEMM of size M×N×K, block tile sizes B_M, B_N, B_K are chosen per architecture (e.g., 256, 256, 64 on CDNA4, matching the worked example in Section 5)
- Each block computes a B_M×B_N output tile; 8 waves per block, with each wave owning a set of MFMA register tiles that compose into a larger per-wave subtile (see the arithmetic check below)
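As a quick arithmetic check on this decomposition, the snippet below assumes the 2×4 arrangement of the 8 waves used in the Section 5 walkthrough and 16×16 MFMA output tiles; both are assumptions about this particular configuration rather than fixed rules.

```cpp
// Block-decomposition arithmetic for the CDNA4 GEMM example (assumes a 2x4
// arrangement of the 8 waves and 16x16 MFMA output tiles; illustrative only).
constexpr int BLOCK_M = 256, BLOCK_N = 256;
constexpr int WAVES_M = 2, WAVES_N = 4;                  // 8 waves per block
constexpr int WAVE_TILE_M = BLOCK_M / WAVES_M;           // 128 output rows per wave
constexpr int WAVE_TILE_N = BLOCK_N / WAVES_N;           // 64 output cols per wave
constexpr int MFMA_M = 16, MFMA_N = 16;
static_assert(WAVE_TILE_M == 128 && WAVE_TILE_N == 64, "per-wave subtile");
// Each wave's 128x64 subtile decomposes into (128/16) x (64/16) = 8 x 4 = 32
// MFMA output tiles, accumulated in that wave's registers.
static_assert((WAVE_TILE_M / MFMA_M) * (WAVE_TILE_N / MFMA_N) == 32, "MFMA tiles per wave");
```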
2.4 Overlapping Compute and Memory:
- 8-wave ping-pong: For balanced workloads, the 8 waves per block split into two groups of 4, and the groups alternate roles: while one group issues matrix (MFMA) ops, the other issues memory ops, and then the roles swap. Bulk loads from global memory to LDS to registers thereby overlap with MFMA computation, coordinated by explicit software barriers.
- 4-wave interleave: For strongly compute- or memory-bound kernels, one wave runs per SIMD (4 per CU); each wave issues small, tightly interleaved groups of load and compute instructions to keep both pipelines saturated.
2.5 Grid-Level Chiplet-Aware Swizzle:
- CDNA4 features 8 compute chiplets (XCDs), each with a private L2 cache, plus a shared last-level cache (LLC). To optimize L2/LLC affinity:
- Block assignment is swizzled: consecutive blocks are assigned in chunks to XCDs, and output tiles are then traversed in vertical windows of fixed height, cycling through columns per XCD for improved cache reuse (an illustrative sketch follows this list).
- This swizzle achieves up to 19% bandwidth improvement over naive row-major grid assignment.
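A simplified, illustrative version of such a block swizzle is sketched below. The function name echoes the `grid_swizzle(…)` call in the Section 5 walkthrough, but the signature, the window height `H`, and the chunk policy (here, one contiguous chunk of logical tiles per XCD) are assumptions, not the HipKittens source. All divisions are assumed exact.

```cpp
#include <hip/hip_runtime.h>

// Simplified chiplet-aware block swizzle (illustrative, not the HK source).
// Workgroups are dispatched to XCDs round-robin by block id, so bid % NUM_XCDS
// is the XCD a block runs on. The remap (1) gives each XCD a contiguous chunk
// of logical output tiles so they share that XCD's private L2, and (2) walks
// tiles in vertical windows of height H to improve reuse out of the shared LLC.
template <int NUM_XCDS = 8, int H = 4>
__device__ inline void grid_swizzle(int bid, int num_pid_m, int num_pid_n,
                                    int& pid_m, int& pid_n) {
    const int num_blocks = num_pid_m * num_pid_n;
    // (1) Chunk-to-XCD regrouping: consecutive logical tile ids land on one XCD.
    const int xcd     = bid % NUM_XCDS;
    const int slot    = bid / NUM_XCDS;
    const int logical = xcd * (num_blocks / NUM_XCDS) + slot;
    // (2) Vertical-window traversal: cover H rows of output tiles before
    //     advancing to the next column of the output grid.
    const int window = logical / (H * num_pid_n);
    const int in_win = logical % (H * num_pid_n);
    pid_m = window * H + in_win % H;
    pid_n = in_win / H;
}
```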
3. Algorithmic Adaptations to AMD Architecture
While tile-centric abstractions generalize, HipKittens concretely adapts kernel algorithms to AMD hardware:
- Register file management: Each SIMD provides 256 VGPRs and 256 AGPRs, statically partitioned. HIPCC cannot allocate AGPRs as MFMA inputs, leading to extra `v_accvgpr_read` instructions. HipKittens' tile objects map directly onto AGPRs/VGPRs via an `arange<…>` API, matching hand-tuned assembly performance, as demonstrated on the non-causal MHA backward pass.
- Matrix-core tile layouts: In contrast to the composable 16×16 blocks on NVIDIA, AMD's MFMA input and output tile layouts are distinct and must be mapped precisely. HipKittens provides dedicated constructors and swizzling mechanisms to guarantee bank-conflict-free accesses.
- Cache hierarchy optimization: The chiplet-aware grid schedule maximizes both L2 and LLC reuse; block swizzling increases achieved bandwidth by up to 19% in benchmarked scenarios.
- Compiler scheduling hints: HK emits `llvm.amdgcn.sched.group.barrier` and `llvm.amdgcn.sched.barrier` intrinsics, as well as the `s_setprio` instruction, to control the interleaving of VALU, MFMA, and LDS ops and the wave priority during ping-pong scheduling (see the sketch below).
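From HIP C++, these hints are exposed as clang AMDGPU builtins. The sketch below shows one way they can be used; the specific 1-MFMA / 2-LDS-read grouping and the priority values are illustrative examples, not HK's actual schedule.

```cpp
#include <hip/hip_runtime.h>

// Illustrative use of the scheduling hints named above. Mask values follow the
// LLVM AMDGPU documentation (0x008 = MFMA/WMMA, 0x100 = DS read); the grouping
// pattern and priority values here are examples only.
__device__ inline void hint_compute_phase() {
    __builtin_amdgcn_s_setprio(1);                      // raise wave priority for the compute phase
    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);  // schedule 1 MFMA instruction ...
    __builtin_amdgcn_sched_group_barrier(0x100, 2, 0);  // ... then 2 LDS reads, as one group
    __builtin_amdgcn_sched_barrier(0);                  // fence: nothing may be reordered across this point
    __builtin_amdgcn_s_setprio(0);                      // restore default priority for the memory phase
}
```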
4. Performance Evaluation
Benchmarks conducted on AMD MI355X (CDNA4) with ROCm 7.0 use 500 warmup and 100 measurement runs per configuration, with random inputs. Performance is reported in TFLOPS/PFLOPS, with speedup defined as HipKittens throughput divided by baseline throughput (e.g., 2.50 / 0.84 ≈ 2.98× for BF16 GEMM) and efficiency as achieved throughput divided by the theoretical hardware peak.
Selected results illustrate the competitiveness of HK versus both AMD hand-crafted assembly (AITER) and compiler-generated code:
| Kernel | Best baseline (framework: throughput) | HipKittens | Speedup |
|---|---|---|---|
| BF16 GEMM 8192³ | Triton: 0.84 PFLOPS | 2.50 PFLOPS | 2.98× |
| FP8 GEMM 8192³ | CK: 2.10 PFLOPS | 5.00 PFLOPS | 2.38× |
| GQA non-causal forward (d=64) | AITER: 410 TFLOPS | 525 TFLOPS | 1.28× |
| GQA non-causal backward (d=64) | AITER: 510 TFLOPS | 910 TFLOPS | 1.78× |
| Fused Dropout+LayerNorm (d=128) | PyTorch compiled: 400 TFLOPS | 520 TFLOPS | 1.30× |
Further breakdown includes:
- GEMM (M=N=K=8192, BF16): HK achieves 2.50 PFLOPS (99% of the 2.52 PFLOPS peak), 3.0× Triton.
- FP8 GEMM: HK at 5.0 PFLOPS (100% peak), 2.4× CK/ROCm.
- Attention (GQA, d=64): HK non-causal forward: 525 TFLOPS (96% of 547 TFLOPS peak); backward: 910 TFLOPS vs. 510 TFLOPS (AITER).
- Memory-bound operators: Fused Dropout+Residual+LayerNorm runs 1.1–1.3× faster than AITER and compiled PyTorch; rotary embedding is 1.2× faster than Triton.
A plausible implication is that HipKittens enables broad performance portability—including both compute- and memory-bound kernels—beyond what AMD hand-optimized assembly or compiler-based DSLs can systematize.
5. Programming Workflow: Developing New Operators
HipKittens offers a structured approach for authoring high-performance matrix-style or attention-like operators, removing the need for explicit assembly coding. The development sequence comprises:
Step 1: Define tile shapes and memory buffers
```cpp
constexpr int BLOCK_M = 256, BLOCK_N = 256, BLOCK_K = 64;

using Areg = rt_bf16<16, BLOCK_K, row_l, mfma_shape_16x32>;   // per-wave A operand tile
using Breg = rt_bf16<32, BLOCK_K, row_l, mfma_shape_16x32>;   // per-wave B operand tile
using Creg = rt_f32<16, 32, col_l, mfma_shape_16x32>;         // FP32 accumulator tile

struct KernelGlobals {
    bf16*  __restrict__ A;
    bf16*  __restrict__ B;
    float* __restrict__ C;
    int M, N, K, lda, ldb, ldc;
};
```
Step 2: Allocate and preload shared memory
```cpp
extern __shared__ uint8_t raw_shm[];
shared_allocator shm((int*)raw_shm);

using Asmem = st_bf16<BLOCK_M, BLOCK_K, swizzle_16x32>;
using Bsmem = st_bf16<BLOCK_N, BLOCK_K, swizzle_16x32>;

auto (&A_sh)[2] = shm.allocate<Asmem, 2>();   // double-buffered A tiles in LDS
auto (&B_sh)[2] = shm.allocate<Bsmem, 2>();   // double-buffered B tiles in LDS
```
Step 3: Compute block indices and wave IDs
```cpp
int bx = blockIdx.x, by = blockIdx.y;
int bid = grid_swizzle(bx, by, …);              // chiplet-aware reorder of block ids
int pid_m = bid % (M / BLOCK_M), pid_n = bid / (M / BLOCK_M);
int warp_id = kittens::warpid();                // 0..7
int wave_row = warp_id / 4, wave_col = warp_id % 4;
```
Step 4: Prologue (double-buffer preload)
```cpp
G::load(A_sh[0], g.A, {0, pid_m, 0});   // preload K-tile 0
G::load(B_sh[0], g.B, {0, pid_n, 0});
kittens::barrier();
G::load(A_sh[1], g.A, {1, pid_m, 0});   // preload K-tile 1
G::load(B_sh[1], g.B, {1, pid_n, 0});
kittens::barrier();
```
Step 5: Hot-loop (8-wave ping-pong)
```cpp
int cur = 0, nxt = 1;
Creg C_acc = {};                                 // per-wave accumulator tile (assumed zero-initialized)
for (int t = 0; t < K / BLOCK_K; ++t) {
    Areg aR; Breg bR;
    load(aR, subtile(A_sh[cur], wave_row));      // LDS -> registers
    load(bR, subtile(B_sh[cur], wave_col));
    kittens::barrier();                          // all waves done with buffer `cur`; prior async loads have landed
    if (t + 2 < K / BLOCK_K) {                   // prefetch K-tile t+2 into the buffer just consumed
        G::load(A_sh[cur], g.A, {t + 2, pid_m, 0});
        G::load(B_sh[cur], g.B, {t + 2, pid_n, 0});
    }
    mma(aR, bR, C_acc);                          // MFMA overlaps the in-flight global->LDS loads
    std::swap(cur, nxt);
}
```
Step 6: Result Epilogue and Store
```cpp
store(g.C, C_acc, {0, pid_m, pid_n});
// … offset the store by (wave_row, wave_col) so each wave writes its own sub-tile of C
```
Step 7: Launch and Tune
- Select block sizes to maximize LDS utilization and respect register pressure (LDS is 64 KB per CU on CDNA3, 160 KB on CDNA4)
- Sweep wave count (4 vs 8) for kernel balance
- Tweak the swizzle window height and chunk size to match grid and chiplet divisibility
- Empirically measure achieved throughput and confirm efficiency stays above 90% of peak (a host-side launch sketch follows)
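To make the launch step concrete, here is a hedged host-side sketch building on the Step 1–2 definitions; the kernel symbol `hk_gemm` is hypothetical, and sizing dynamic shared memory with `sizeof` on the LDS tile types is an assumption.

```cpp
#include <hip/hip_runtime.h>

// Hypothetical host-side launch for the kernel sketched in Steps 1-6. The
// kernel symbol `hk_gemm` and the sizeof()-based LDS sizing are assumptions.
void launch_hk_gemm(const KernelGlobals& g, hipStream_t stream) {
    dim3 grid(g.M / BLOCK_M, g.N / BLOCK_N);     // one block per BLOCK_M x BLOCK_N output tile
    dim3 block(8 * 64);                          // 8 waves of 64 lanes = 512 threads
    size_t lds_bytes = 2 * (sizeof(Asmem) + sizeof(Bsmem));  // double-buffered A/B tiles; must fit the per-CU LDS budget
    hipLaunchKernelGGL(hk_gemm, grid, block, lds_bytes, stream, g);
}
```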
With this approach, new kernels can be prototyped in under 200 lines of code, exploiting full hardware capabilities while abstracting the intricacies of AMD’s underlying architecture.
6. Significance for AI Kernel Portability
HipKittens establishes that tile-based program design and bulk operator scheduling, previously associated with NVIDIA-targeted DSLs, can be systematically generalized and re-instantiated for AMD architectures using appropriate hardware-aware layouts and algorithms. By providing abstractions that encapsulate low-level optimizations such as memory swizzling, wave scheduling, and chiplet locality, HipKittens accelerates the development of AI kernels that are not only performant but also scalable across multiple GPU generations. This suggests a near-term pathway towards vendor-agnostic, performance-portable kernel development workflows for the rapidly evolving landscape of AI accelerators.