Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def get_mxfp4_dbuf_pingpong_schedule(use_stagger: bool = True, shape: tuple = No
use_stagger: Enable wave staggering + WorkgroupBarrier in cluster 0.
Recommended for 8-wave configs; disable for 4-wave.
shape: Tuple of (M, N, K) dimensions. If provided and bigger than
(1024, 1024, 16384), an extra WorkgroupBarrier will be added
(1024, 1024, 1024), an extra WorkgroupBarrier will be added
after the first SchedulingBarrier in cluster 0.
"""
K = tkl.sym.K
Expand Down Expand Up @@ -406,10 +406,10 @@ def mxfp4_dbuf_schedule():
print("loop_scaled_mma_0")
print(loop_scaled_mma_0)

# Check if shape is bigger than threshold (1024, 1024, 16384)
# Check if shape is bigger than threshold (1024, 1024, 1024)
use_extra_barrier = False
if shape is not None and len(shape) >= 3:
threshold = (1024, 1024, 16384)
threshold = (1024, 1024, 1024)
use_extra_barrier = (
shape[0] > threshold[0]
and shape[1] > threshold[1]
Expand Down
74 changes: 74 additions & 0 deletions wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
bitcast_a, bitcast_a_scale, bitcast_b, bitcast_b_scale, scaled_mma.
"""

from sympy import Piecewise, ceiling, floor

import wave_lang.kernel.lang as tkl
import wave_lang.kernel.wave as tkw
from wave_lang.kernel.lang.global_symbols import *
Expand All @@ -33,6 +35,8 @@ def get_tagged_mxfp4_gemm(
mfma_variant: ScaledMMAType = ScaledMMAType.F32_16x16x128_F8F6F4,
a_address_space: tkl.AddressSpace = SHARED_ADDRESS_SPACE,
b_address_space: tkl.AddressSpace = SHARED_ADDRESS_SPACE,
reorder_workgroups=True,
group_size_n=32,
):
"""Return a tagged MXFP4 scaled GEMM kernel + compile options for CDNA4.

Expand All @@ -53,6 +57,7 @@ def get_tagged_mxfp4_gemm(
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K = tkl.sym.BLOCK_K
GROUP_SIZE_N = tkl.sym.GROUP_SIZE_N
A_ADDRESS_SPACE = tkl.sym.A_ADDRESS_SPACE
B_ADDRESS_SPACE = tkl.sym.B_ADDRESS_SPACE
C_ADDRESS_SPACE = tkl.sym.C_ADDRESS_SPACE
Expand All @@ -66,6 +71,13 @@ def get_tagged_mxfp4_gemm(

constraints += [tkw.HardwareConstraint(threads_per_wave=64, mma_type=mfma_variant)]

if reorder_workgroups:
new_wg0, new_wg1 = _reorder_mxfp4_workgroups(
WORKGROUP_0, WORKGROUP_1, M, N, BLOCK_M, BLOCK_N, GROUP_SIZE_N
)
constraints += [tkw.ReorderingConstraint(new_wg0, 0)]
constraints += [tkw.ReorderingConstraint(new_wg1, 1)]

@tkw.wave(constraints)
def gemm(
a: tkl.Memory[M, K / 2, A_ADDRESS_SPACE, tkl.i8],
Expand Down Expand Up @@ -102,6 +114,7 @@ def repeat(
BLOCK_M: block_shape[0],
BLOCK_N: block_shape[1],
BLOCK_K: block_shape[2],
GROUP_SIZE_N: group_size_n,
M: shape[0],
N: shape[1],
K: shape[2],
Expand All @@ -125,6 +138,8 @@ def get_tagged_mxfp4_gemm_preshuffle_b(
wave_shape: tuple[int, int] = (2, 2),
mfma_variant: ScaledMMAType = ScaledMMAType.F32_16x16x128_F8F6F4,
a_address_space: tkl.AddressSpace = SHARED_ADDRESS_SPACE,
reorder_workgroups=True,
group_size_n=32,
):
"""Return a tagged MXFP4 scaled GEMM kernel with preshuffled B and B_scale.

Expand All @@ -151,6 +166,7 @@ def get_tagged_mxfp4_gemm_preshuffle_b(
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K = tkl.sym.BLOCK_K
GROUP_SIZE_N = tkl.sym.GROUP_SIZE_N
A_ADDRESS_SPACE = tkl.sym.A_ADDRESS_SPACE
C_ADDRESS_SPACE = tkl.sym.C_ADDRESS_SPACE
K_PACKED = tkl.sym.K_PACKED
Expand All @@ -165,6 +181,13 @@ def get_tagged_mxfp4_gemm_preshuffle_b(

constraints += [tkw.HardwareConstraint(threads_per_wave=64, mma_type=mfma_variant)]

if reorder_workgroups:
new_wg0, new_wg1 = _reorder_mxfp4_workgroups(
WORKGROUP_0, WORKGROUP_1, M, N, BLOCK_M, BLOCK_N, GROUP_SIZE_N
)
constraints += [tkw.ReorderingConstraint(new_wg0, 0)]
constraints += [tkw.ReorderingConstraint(new_wg1, 1)]

# --- B data preshuffle mapping (aiter shuffle_weight) ---
# Each 16-row x 32-byte tile is reordered from [n, k_sub, k_elem] to
# [k_sub, n, k_elem] so a contiguous 256-byte read fetches one K-chunk
Expand Down Expand Up @@ -270,6 +293,7 @@ def repeat(
BLOCK_M: block_shape[0],
BLOCK_N: block_shape[1],
BLOCK_K: block_shape[2],
GROUP_SIZE_N: group_size_n,
M: shape[0],
N: shape[1],
K: shape[2],
Expand All @@ -287,3 +311,53 @@ def repeat(
)

return gemm, options


def _reorder_mxfp4_workgroups(wg0, wg1, m, n, block_m, block_n, group_size_n):
    """Remap workgroup indices to a new order based on group_size_n along N dimension.

    Workgroups are flattened in column-major order, then regrouped so that
    ``group_size_n`` consecutive flat indices cover ``group_size_n`` adjacent
    N-tiles that share a single M-tile (better data reuse across a group).
    When the number of N-tiles is not a multiple of ``group_size_n``, the
    leftover tail tiles are laid out separately.

    Args:
        wg0: Initial workgroup index symbol along M dimension.
        wg1: Initial workgroup index symbol along N dimension.
        m: Problem size along M dimension (tile count is derived via block_m).
        n: Problem size along N dimension (tile count is derived via block_n).
        block_m: Tile size along M dimension.
        block_n: Tile size along N dimension.
        group_size_n: Number of N-tiles per group.

    Returns:
        (new_wg0, new_wg1): New workgroup indices along M and N dimensions,
        as sympy Piecewise expressions selecting the main or tail mapping.
    """
    # NOTE: intentionally no reassignment of wg0/wg1 here — the previous
    # version clobbered both parameters with the WORKGROUP_0/WORKGROUP_1
    # globals, making the arguments dead. The caller chooses the symbols.
    num_wg_0 = ceiling(m / block_m)
    num_wg_1 = ceiling(n / block_n)

    # Flatten in column-major order
    flat_wg_index = wg0 + wg1 * num_wg_0
    group_index = flat_wg_index // group_size_n

    # Main case, forming full groups of GROUP_SIZE_N tiles along N
    main_new_wg0 = group_index % num_wg_0
    main_new_wg1 = (
        group_index // num_wg_0
    ) * group_size_n + flat_wg_index % group_size_n

    # Tail case, when the number of N tiles is not a multiple of GROUP_SIZE_N
    full_tiles_n = floor(num_wg_1 / group_size_n) * group_size_n
    tail_tiles_n = num_wg_1 - full_tiles_n
    total_full = full_tiles_n * num_wg_0
    tail_linear = flat_wg_index - total_full
    tail_new_wg0 = tail_linear // tail_tiles_n
    tail_new_wg1 = full_tiles_n + tail_linear % tail_tiles_n

    # Select tail path if we can no longer form full groups
    new_wg0 = Piecewise(
        (tail_new_wg0, (flat_wg_index >= total_full) & (tail_tiles_n > 0)),
        (main_new_wg0, True),
    )
    new_wg1 = Piecewise(
        (tail_new_wg1, (flat_wg_index >= total_full) & (tail_tiles_n > 0)),
        (main_new_wg1, True),
    )

    return new_wg0, new_wg1