From 7870819040f95fcc247cc45803ecb56b8207ebf0 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 29 Jan 2026 11:41:07 +0100 Subject: [PATCH] simple repro for asm generation for gemm Signed-off-by: Alex Zinenko --- lit_tests/kernel/wave/repro.py | 80 ++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 lit_tests/kernel/wave/repro.py diff --git a/lit_tests/kernel/wave/repro.py b/lit_tests/kernel/wave/repro.py new file mode 100644 index 0000000000..dc56b08ca8 --- /dev/null +++ b/lit_tests/kernel/wave/repro.py @@ -0,0 +1,80 @@ +from sympy import ceiling + +import wave_lang.kernel.lang as tkl +import wave_lang.kernel.wave as tkw +from wave_lang.kernel.lang.global_symbols import * +from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile +from wave_lang.kernel.wave.utils.general_utils import ( + run_test, +) + +M = tkl.sym.M +N = tkl.sym.N +K = tkl.sym.K +B = tkl.sym.B +B_KV = tkl.sym.B_KV +BLOCK_M = tkl.sym.BLOCK_M +BLOCK_N = tkl.sym.BLOCK_N +BLOCK_K = tkl.sym.BLOCK_K +BLOCK_B = tkl.sym.BLOCK_B +GROUP_SIZE_N = tkl.sym.GROUP_SIZE_N +ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE +ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0 + + +@run_test +def test_gemm(): + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] + constraints += [tkw.WaveConstraint(M, ceiling(BLOCK_M / 2))] + constraints += [tkw.WaveConstraint(N, ceiling(BLOCK_N / 2))] + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + mma_type=tkw.MMAType.F32_16x16x16_F16, + ) + ] + + @tkw.wave(constraints) + def gemm( + a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32], + ): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + + @tkw.iterate(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: + a_reg = tkw.read(a) + b_reg = tkw.read(b) + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + tkw.write(repeat, c) + + options = WaveCompileOptions( + subs={ + M: 64, + N: 128, + K: 64, + BLOCK_M: 32, + BLOCK_N: 32, + BLOCK_K: 16, + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, + }, + canonicalize=True, + # Option 1: regular MLIR with IREE-isms + # compile_to_mlir=True, + # Option 2: regular MLIR without IREE-isms + compile_to_mlir=True, + use_water_backend=True, + # Option 3: ASM backend + # wave_runtime=True, + # backend="asm", + # compile_to_asm=True, + ) + gemm = wave_compile(options, gemm) + print(gemm.asm)