7,814 changes: 7,808 additions & 6 deletions examples/python/7.1_schedule.py

Large diffs are not rendered by default.

742 changes: 742 additions & 0 deletions examples/python/7.2_mxfp4_gemm_preshuffle_scale.py

Large diffs are not rendered by default.

21 changes: 9 additions & 12 deletions lit_tests/kernel/wave/attention/pipelined_attention.py
@@ -372,36 +372,33 @@ def test_bshd_attention_pipelined_prefetch():
base_attention = wave_compile(options, base_attention)
print(base_attention.asm)

# CHECK: func.func @base_attention
# CHECK-LABEL: func.func @base_attention
# CHECK: {{.*}} = scf.for
# CHECK-COUNT-16: vector.load
# CHECK: arith.subf
# CHECK: math.exp2
# CHECK: math.exp2
# CHECK: arith.mulf
# CHECK: arith.addf
# CHECK-COUNT-16: vector.extract
# CHECK-COUNT-16: arith.addf
# CHECK: vector.extract
# CHECK: arith.addf
# CHECK: vector.broadcast
# CHECK: gpu.shuffle
# CHECK: arith.addf
# CHECK: arith.addf
# CHECK: arith.truncf
# CHECK: arith.truncf
# CHECK: vector.extract
# CHECK: vector.broadcast
# CHECK: arith.mulf
# CHECK: arith.mulf
# CHECK-COUNT-8: vector.extract_strided_slice
# CHECK-COUNT-32: amdgpu.mfma
# CHECK-COUNT-8: vector.load
# CHECK-COUNT-8: vector.extract
# CHECK: amdgpu.mfma
# CHECK: vector.load
# CHECK: vector.extract
# CHECK: vector.from_elements
# CHECK: vector.from_elements
# CHECK: amdgpu.lds_barrier
# CHECK-COUNT-32: vector.load
# CHECK-COUNT-4: vector.load
# CHECK-COUNT-8: amdgpu.mfma
# CHECK: vector.load
# CHECK: amdgpu.mfma


@run_test
@@ -434,7 +431,7 @@ def test_bshd_attention_pipelined_prefetch_pingpong():
base_attention = wave_compile(options, base_attention)
print(base_attention.asm)

# CHECK: func.func @base_attention
# CHECK-LABEL: func.func @base_attention

# CHECK: scf.if
# CHECK-NEXT: rocdl.s.barrier
195 changes: 195 additions & 0 deletions lit_tests/kernel/wave/merge_scale_reads.py
@@ -0,0 +1,195 @@
# RUN: python %s | FileCheck %s

"""
Test the merge_contiguous_reads pass on pre-shuffled (e8m0_shuffle) scale
reads for MXFP4 GEMM.

The e8m0_shuffle index mapping rearranges scale data so that each thread's
scale elements land in contiguous groups in physical memory. The merge pass
should combine the expanded scalar reads into wider vector loads:

BLOCK_K=128 -> 4 scale elements -> 2 groups of 2 -> vector<2xi8>
BLOCK_K=256 -> 8 scale elements -> 2 groups of 4 -> vector<4xi8>

The shuffle layout requires K/32 >= 64 (i.e. K >= 2048) for the groups to
land contiguously in the row-major [M, K/32] scale tensor.

Also verifies that the opsel_scaled_mfma pass enables byte selection in
amdgpu.scaled_mfma, replacing scalar scale operands with vector operands
and scalesIdxA/scalesIdxB attributes for efficient hardware extraction.
"""

import wave_lang.kernel.lang as tkl
import wave_lang.kernel.wave as tkw
from wave_lang.kernel.lang.global_symbols import *
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
from wave_lang.kernel.wave.constraints import ScaledMMAType
from wave_lang.kernel.wave.utils.general_utils import (
get_default_scheduling_params,
run_test,
)

# Symbols shared by all tests.
M = tkl.sym.M
N = tkl.sym.N
K = tkl.sym.K
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K = tkl.sym.BLOCK_K
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
K_SCALE_SHUFFLED = tkl.sym.K_SCALE_SHUFFLED


def get_preshuffle_kernel():
"""Return the pre-shuffled MXFP4 GEMM kernel with e8m0_shuffle mappings."""
constraints: list[tkw.Constraint] = [
tkw.WorkgroupConstraint(M, BLOCK_M, 0),
tkw.WorkgroupConstraint(N, BLOCK_N, 1),
tkw.TilingConstraint(K, BLOCK_K),
tkw.WaveConstraint(M, BLOCK_M / 2),
tkw.WaveConstraint(N, BLOCK_N / 2),
tkw.HardwareConstraint(
threads_per_wave=64,
mma_type=ScaledMMAType.F32_16x16x128_F8F6F4,
),
]

# e8m0_shuffle index mapping: logical (iter0, iter1) -> physical (row, col).
i = tkw.IndexMapping.iterator(0)
j = tkw.IndexMapping.iterator(1)

shuffle_expr = (
(j // 32) * ((K_SCALE_SHUFFLED // 8) * 256)
+ (i // 8) * 256
+ ((i % 8) % 4) * 64
+ ((j % 32) % 16) * 4
+ (((i % 8) // 4) * 2)
+ ((j % 32) // 16)
)
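# Reading the term structure: (j // 32) selects a 32-row super-tile
# ((K_SCALE_SHUFFLED // 8) tiles of 256 bytes each); (i // 8) selects one
# 256-byte tile covering 8 K-scale indices x 32 rows; the remaining terms
# place each byte within that tile.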

a_scale_mapping = tkw.IndexMapping(
num_iterators=2,
inputs={
M: shuffle_expr // K_SCALE_SHUFFLED,
K: shuffle_expr % K_SCALE_SHUFFLED,
},
outputs={K: i, M: j},
)

k = tkw.IndexMapping.iterator(0)
n = tkw.IndexMapping.iterator(1)

shuffle_expr_b = (
(n // 32) * ((K_SCALE_SHUFFLED // 8) * 256)
+ (k // 8) * 256
+ ((k % 8) % 4) * 64
+ ((n % 32) % 16) * 4
+ (((k % 8) // 4) * 2)
+ ((n % 32) // 16)
)

b_scale_mapping = tkw.IndexMapping(
num_iterators=2,
inputs={
N: shuffle_expr_b // K_SCALE_SHUFFLED,
K: shuffle_expr_b % K_SCALE_SHUFFLED,
},
outputs={K: k, N: n},
)

@tkw.wave(constraints)
def preshuffle_scaled_mma(
a: tkl.Memory[M, K / 2, ADDRESS_SPACE, tkl.i8],
a_scale: tkl.Memory[M, K / 32, GLOBAL_ADDRESS_SPACE, tkl.i8],
b: tkl.Memory[N, K / 2, ADDRESS_SPACE, tkl.i8],
b_scale: tkl.Memory[N, K / 32, GLOBAL_ADDRESS_SPACE, tkl.i8],
c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
):
c_reg = tkl.Register[M, N, tkl.f32](0.0)

@tkw.iterate(K, init_args=[c_reg])
def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
a_reg = tkw.read(a)
a_reg = tkw.bitcast(a_reg, tkl.f4e2m1fn)
a_scale_reg = tkw.read(a_scale, mapping=a_scale_mapping)
a_scale_reg = tkw.bitcast(a_scale_reg, tkl.f8e8m0fnu)
b_reg = tkw.read(b)
b_reg = tkw.bitcast(b_reg, tkl.f4e2m1fn)
b_scale_reg = tkw.read(b_scale, mapping=b_scale_mapping)
b_scale_reg = tkw.bitcast(b_scale_reg, tkl.f8e8m0fnu)
acc = tkw.scaled_mma(a_reg, a_scale_reg, b_reg, b_scale_reg, acc)
return acc

tkw.write(repeat, c)

return preshuffle_scaled_mma


def compile_and_print(m, n, k, block_k):
"""Compile the preshuffle kernel with given dimensions and print MLIR."""
k_scale_shuffled = (((k // 32) + 7) // 8) * 8
hyperparams = {
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
BLOCK_M: 128,
BLOCK_N: 128,
BLOCK_K: block_k,
M: m,
N: n,
K: k,
K_SCALE_SHUFFLED: k_scale_shuffled,
}
hyperparams.update(get_default_scheduling_params())

options = WaveCompileOptions(
subs=hyperparams,
canonicalize=True,
device="hip",
target="gfx950",
compile_to_mlir=True,
use_global_to_shared=True,
)
kernel = get_preshuffle_kernel()
result = wave_compile(options, kernel)
print(result.asm)


@run_test
def test_preshuffle_scale_merge_block_k_128():
# BLOCK_K=128: 4 scale elements per thread -> 2 groups of 2 -> vector<2xi8>.
compile_and_print(m=512, n=512, k=2048, block_k=128)

# CHECK-LABEL: test_preshuffle_scale_merge_block_k_128

# Each scale tensor produces 2 merged vector<2xi8> loads from global.
# CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<2xi8>
# CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<2xi8>
# CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<2xi8>
# CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<2xi8>

# No unmerged scalar scale loads from global should remain.
# CHECK-NOT: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<1xi8>


@run_test
def test_preshuffle_scale_merge_block_k_256():
# BLOCK_K=256: 8 scale elements per thread -> 2 groups of 4 -> vector<4xi8>.
compile_and_print(m=512, n=512, k=2048, block_k=256)

# CHECK-LABEL: test_preshuffle_scale_merge_block_k_256

# Each scale tensor produces 2 merged vector<4xi8> loads from global.
# CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<4xi8>
# CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<4xi8>
# CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<4xi8>
# CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<4xi8>

# No unmerged scalar scale loads from global should remain.
# CHECK-NOT: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<1xi8>

# Check that amdgpu.scaled_mfma uses opsel (indexed access into the scale values).
# The key indicator is the [N] indexing syntax on the f8E8M0FNU scale operands.
# CHECK: amdgpu.scaled_mfma {{.*}} (%{{.*}}[{{[0-9]+}}] * %{{.*}}) * (%{{.*}}[{{[0-9]+}}] * %{{.*}}) + %{{.*}} : vector<4xf8E8M0FNU>, vector<{{.*}}xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<{{.*}}xf4E2M1FN>, vector<4xf32>

# Verify that we are not using scalar scale extracts (the old pattern).
# If opsel is working, we should NOT see a vector.extract before scaled_mfma.
# CHECK-NOT: vector.extract %{{.*}}[0] : f8E8M0FNU
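
As a sanity check on the mapping above, the following standalone sketch (not from the PR) evaluates the shuffle expression over the whole logical index space and asserts that it is a permutation of the flattened [M, K/32] byte tensor, a basic validity property of the layout:

def shuffle_flat(i, j, ks):
    """Flat physical byte offset for logical (i = K/32 index, j = row)."""
    return (
        (j // 32) * ((ks // 8) * 256)
        + (i // 8) * 256
        + ((i % 8) % 4) * 64
        + ((j % 32) % 16) * 4
        + ((i % 8) // 4) * 2
        + ((j % 32) // 16)
    )

# Illustrative sizes: K = 2048 -> K/32 = 64; M rows a multiple of 32.
M_ROWS, KS = 64, 64
offsets = {shuffle_flat(i, j, KS) for j in range(M_ROWS) for i in range(KS)}
assert offsets == set(range(M_ROWS * KS)), "shuffle is not a permutation"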
5 changes: 2 additions & 3 deletions lit_tests/kernel/wave/mma.py
@@ -771,10 +771,9 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:

# CHECK: %[[ITER_COND:.*]] = affine.apply #[[MAP1]]()[%[[ITER_ARG]]]
# CHECK: %[[COND1:.*]] = arith.cmpi eq, %[[ITER_COND]], %[[C0]] : index
# CHECK: scf.if %[[COND1]] {
# CHECK-NEXT: scf.if %[[COND0]] {
# CHECK: %[[COND_ANDED:.*]] = arith.andi %[[COND1]], %[[COND0]] : i1
# CHECK: scf.if %[[COND_ANDED]] {
# CHECK-NEXT: rocdl.s.barrier.signal id = -3
# CHECK-NEXT: }
# CHECK-NEXT: }

# CHECK-COUNT-4: rocdl.wmma
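The new CHECK lines reflect a guard-merging rewrite. Schematically, in Python terms (signal_barrier is a stand-in for the rocdl.s.barrier.signal op):

# Before: two nested scf.if regions, each a separate guard.
if cond1:
    if cond0:
        signal_barrier()

# After: a single scf.if on the arith.andi of both conditions.
if cond1 and cond0:
    signal_barrier()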
4 changes: 0 additions & 4 deletions lit_tests/kernel/wave/speculative_decoding.py
@@ -56,14 +56,10 @@ def test_speculative_decoding():
# CHECK: arith.divf
# CHECK: arith.cmpf
# CHECK: arith.ori
# CHECK: arith.xori
# CHECK: vector.extract
# CHECK: scf.if
# CHECK: scf.if
# CHECK: vector.load
# CHECK: vector.store
# CHECK: scf.yield
# CHECK: scf.yield

# --- Reduction and arithmetic patterns:
# CHECK: gpu.shuffle up
30 changes: 23 additions & 7 deletions wave_lang/kernel/wave/asm/handlers_memory.py
@@ -320,11 +320,15 @@ def handle_mfma_op(self, operation: amdgpu_d.MFMAOp, kernel_info: KernelInfo):
def handle_scaled_mfma_op(
self, operation: amdgpu_d.ScaledMFMAOp, kernel_info: KernelInfo
):
"""Handle amdgpu.scaled_mfma operations - emit scaled MFMA instruction for MXFP4/FP6/FP8."""
"""Handle amdgpu.scaled_mfma operations - emit scaled MFMA instruction for MXFP4/FP6/FP8.

# Scaled MFMA format: %result = amdgpu.scaled_mfma M x N x K
# (%scaleA * %dataA) * (%scaleB * %dataB) + %acc
# where scaleA and scaleB are scalar f8E8M0FNU values (often from vector.extract)
Scaled MFMA format: %result = amdgpu.scaled_mfma M x N x K
(%scaleA * %dataA) * (%scaleB * %dataB) + %acc

Supports both scalar f8E8M0FNU and vector<4xf8E8M0FNU> scale types.
When the scale is a vector<4xf8E8M0FNU>, the scalesIdx attribute
selects which byte within the 32-bit VGPR to use (opsel).
"""
from .kernel_mfma import _MFMASupport

ctx = self.walker.kernel_ctx
@@ -342,13 +346,21 @@ def handle_scaled_mfma_op(
cbsz = _MFMASupport._get_scaled_mfma_format_code(a_type_str)
blgp = _MFMASupport._get_scaled_mfma_format_code(b_type_str)

# Extract opsel (byte index within scale VGPR) from attributes
scales_idx_a = int(operation.attributes["scalesIdxA"])
scales_idx_b = int(operation.attributes["scalesIdxB"])

# Get operands based on actual MLIR structure
# Operand order: sourceA, sourceB, destC, scaleA, scaleB
data_a_ssa = str(operation.operands[0]) # sourceA: vector<32xf4E2M1FN>
data_b_ssa = str(operation.operands[1]) # sourceB: vector<32xf4E2M1FN>
acc_ssa = str(operation.operands[2]) # destC: vector<4xf32>
scale_a_ssa = str(operation.operands[3]) # scaleA: f8E8M0FNU (scalar)
scale_b_ssa = str(operation.operands[4]) # scaleB: f8E8M0FNU (scalar)
scale_a_ssa = str(
operation.operands[3]
) # scaleA: f8E8M0FNU or vector<4xf8E8M0FNU>
scale_b_ssa = str(
operation.operands[4]
) # scaleB: f8E8M0FNU or vector<4xf8E8M0FNU>

# Get registers from kernel context
scale_a_reg = ctx.ssa_to_reg.get(scale_a_ssa)
@@ -363,7 +375,9 @@ def handle_scaled_mfma_op(
# For MXFP4: 32 elements of FP4 = 16 bytes = 4 VGPRs (4 bytes/VGPR)
# vector<32xf4E2M1FN> bitcast from vector<16xi8> -> 4 VGPRs

# Scale registers should be single VGPRs (extracted from vector<1xf8E8M0FNU>)
# Scale register: either a single VGPR (scalar f8E8M0FNU)
# or a single VGPR containing 4 packed bytes (vector<4xf8E8M0FNU>).
# In both cases it maps to a single VGPR register.
if isinstance(scale_a_reg, (list, tuple)):
scale_a_vreg = scale_a_reg[0] if len(scale_a_reg) > 0 else None
else:
@@ -390,6 +404,8 @@ def handle_scaled_mfma_op(
acc_regs if acc_regs and len(acc_regs) == 4 else None,
cbsz=cbsz,
blgp=blgp,
scales_idx_a=scales_idx_a,
scales_idx_b=scales_idx_b,
)

# Track result in SSA mapping
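For intuition, the byte selection that scalesIdxA/scalesIdxB request works roughly like this (standalone sketch with little-endian byte numbering assumed; not the emitter's actual code):

def select_scale_byte(vgpr_value: int, scales_idx: int) -> int:
    """Pick the e8m0 scale byte that opsel index scales_idx (0-3) selects."""
    assert 0 <= scales_idx <= 3
    return (vgpr_value >> (8 * scales_idx)) & 0xFF

packed = 0x44434241  # bytes 0x41, 0x42, 0x43, 0x44 at indices 0..3
assert select_scale_byte(packed, 0) == 0x41
assert select_scale_byte(packed, 3) == 0x44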
11 changes: 9 additions & 2 deletions wave_lang/kernel/wave/asm/kernel_mfma.py
@@ -182,6 +182,8 @@ def emit_mfma_f32_16x16x128_f8f6f4(
acc_regs: Optional[Tuple[KReg, ...]] = None,
cbsz: int = 4,
blgp: int = 4,
scales_idx_a: int = 0,
scales_idx_b: int = 0,
) -> Tuple[KReg, ...]:
"""
Emit scaled MFMA instruction for MXFP4 (16x16x128 F8F6F4).
@@ -194,19 +196,24 @@
Args:
a_regs: Tuple of 4 VGPRs for A operand (32 x f4E2M1FN packed as i8)
b_regs: Tuple of 4 VGPRs for B operand (32 x f4E2M1FN packed as i8)
a_scale_reg: Single VGPR for A scale factor (f8E8M0FNU)
b_scale_reg: Single VGPR for B scale factor (f8E8M0FNU)
a_scale_reg: Single VGPR for A scale factor (f8E8M0FNU or
4 packed bytes with opsel byte selection)
b_scale_reg: Single VGPR for B scale factor (same as above)
acc_regs: Optional tuple of 4 VGPRs for accumulator (f32x4)
If None, allocates new result registers
cbsz: Format code for A source data (0=FP8, 1=BF8, 2=FP6_E2M3,
3=FP6_E3M2, 4=FP4). Default 4 (FP4).
blgp: Format code for B source data. Same encoding as cbsz.
Default 4 (FP4).
scales_idx_a: Byte index (0-3) within the A scale VGPR. Default 0.
scales_idx_b: Byte index (0-3) within the B scale VGPR. Default 0.

Returns:
Tuple of 4 VGPRs containing the result
"""
modifiers = f"cbsz:{cbsz} blgp:{blgp}"
if scales_idx_a != 0 or scales_idx_b != 0:
modifiers += f" op_sel_hi:[0,0,0,{scales_idx_a},{scales_idx_b}]"
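# Illustrative resulting modifier strings:
#   scales_idx_a=0, scales_idx_b=0 -> "cbsz:4 blgp:4"
#   scales_idx_a=2, scales_idx_b=3 -> "cbsz:4 blgp:4 op_sel_hi:[0,0,0,2,3]"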

# Build operand ranges - For FP4: 32 elements = 16 bytes = 4 VGPRs
a_range = KRegRange(a_regs[0], len(a_regs), alignment=4)
7 changes: 7 additions & 0 deletions wave_lang/kernel/wave/compile.py
@@ -78,6 +78,7 @@
from .type_inference import infer_types
from .wave_schedule import WaveSchedule
from .workgroup_reordering import reorder_workgroups
from .opsel_scaled_mfma import apply_opsel_scaled_mfma

# Utilities.
from .utils.compile_utils import canonicalize_module, apply_transform, compile_to_vmfb
@@ -652,6 +653,12 @@ def compile_launchable_to_mlir(
if options.canonicalize:
canonicalize_module(mb.module_op)

# Replace scalar extract+bitcast scale chains on scaled_mfma ops
# with vector-level bitcast and opsel byte selection.
apply_opsel_scaled_mfma(mb.module_op)
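# Re-canonicalize afterwards so any ops left dead by the rewrite
# (presumably the old scalar extract/bitcast chains) are swept away.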
if options.canonicalize:
canonicalize_module(mb.module_op)

return mb, trace, exe, kernel_sig, entrypoint_name

