Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions lit_tests/kernel/wave/repro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from sympy import ceiling

import wave_lang.kernel.lang as tkl
import wave_lang.kernel.wave as tkw
from wave_lang.kernel.lang.global_symbols import *
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
from wave_lang.kernel.wave.utils.general_utils import (
run_test,
)

M = tkl.sym.M
N = tkl.sym.N
K = tkl.sym.K
B = tkl.sym.B
B_KV = tkl.sym.B_KV
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K = tkl.sym.BLOCK_K
BLOCK_B = tkl.sym.BLOCK_B
GROUP_SIZE_N = tkl.sym.GROUP_SIZE_N
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0


@run_test
def test_gemm():
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.TilingConstraint(K, BLOCK_K)]
constraints += [tkw.WaveConstraint(M, ceiling(BLOCK_M / 2))]
constraints += [tkw.WaveConstraint(N, ceiling(BLOCK_N / 2))]

constraints += [
tkw.HardwareConstraint(
threads_per_wave=64,
mma_type=tkw.MMAType.F32_16x16x16_F16,
)
]

@tkw.wave(constraints)
def gemm(
a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
):
c_reg = tkl.Register[M, N, tkl.f32](0.0)

@tkw.iterate(K, init_args=[c_reg])
def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
a_reg = tkw.read(a)
b_reg = tkw.read(b)
acc = tkw.mma(a_reg, b_reg, acc)
return acc

tkw.write(repeat, c)

options = WaveCompileOptions(
subs={
M: 64,
N: 128,
K: 64,
BLOCK_M: 32,
BLOCK_N: 32,
BLOCK_K: 16,
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
},
canonicalize=True,
# Option 1: regular MLIR with IREE-isms
# compile_to_mlir=True,
# Option 2: regular MLIR without IREE-isms
compile_to_mlir=True,
use_water_backend=True,
# Option 3: ASM backend
# wave_runtime=True,
# backend="asm",
# compile_to_asm=True,
)
gemm = wave_compile(options, gemm)
print(gemm.asm)
Loading