Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions tests/kernel/wave_gemm_mxfp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,35 +337,55 @@ def testScaledBatchedGemmMXFP4Codegen(use_water_backend: bool, tmp_path: Path):
# We encode the exact registers and wait counts as we want to know if
# they suddenly change due to backend or upstream MLIR changes.
if use_water_backend:
vgpr_count = 164
vgpr_count = 146
vgpr_spill_count = 0
sgpr_count = 57
sgpr_count = 61
sgpr_spill_count = 0
waitcounts = [
"s_waitcnt lgkmcnt(0)",
"s_waitcnt vmcnt(0)",
"s_waitcnt lgkmcnt(6)",
"s_waitcnt lgkmcnt(4)",
"s_waitcnt lgkmcnt(2)",
"s_waitcnt lgkmcnt(0)",
"s_waitcnt vmcnt(0) lgkmcnt(0)",
"s_waitcnt vmcnt(0)",
"s_waitcnt lgkmcnt(7)",
"s_waitcnt lgkmcnt(6)",
"s_waitcnt lgkmcnt(6)",
"s_waitcnt lgkmcnt(5)",
"s_waitcnt lgkmcnt(4)",
"s_waitcnt lgkmcnt(3)",
"s_waitcnt lgkmcnt(1)",
"s_waitcnt lgkmcnt(0)",
"s_waitcnt lgkmcnt(3)",
"s_waitcnt lgkmcnt(1)",
"s_waitcnt lgkmcnt(0)",
]
else:
vgpr_count = 164
vgpr_count = 140
vgpr_spill_count = 0
sgpr_count = 59
sgpr_spill_count = 0
waitcounts = [
"s_waitcnt lgkmcnt(0)",
"s_waitcnt vmcnt(0)",
"s_waitcnt lgkmcnt(6)",
"s_waitcnt lgkmcnt(4)",
"s_waitcnt lgkmcnt(2)",
"s_waitcnt lgkmcnt(0)",
"s_waitcnt lgkmcnt(4)",
"s_waitcnt lgkmcnt(3)",
"s_waitcnt lgkmcnt(1)",
"s_waitcnt lgkmcnt(0)",
"s_waitcnt vmcnt(0) lgkmcnt(0)",
"s_waitcnt vmcnt(0)",
"s_waitcnt lgkmcnt(1)",
"s_waitcnt lgkmcnt(7)",
"s_waitcnt lgkmcnt(6)",
"s_waitcnt lgkmcnt(6)",
"s_waitcnt lgkmcnt(5)",
"s_waitcnt lgkmcnt(4)",
"s_waitcnt lgkmcnt(3)",
"s_waitcnt lgkmcnt(1)",
"s_waitcnt lgkmcnt(0)",
"s_waitcnt lgkmcnt(3)",
"s_waitcnt lgkmcnt(1)",
"s_waitcnt lgkmcnt(0)",
Expand All @@ -387,6 +407,11 @@ def testScaledBatchedGemmMXFP4Codegen(use_water_backend: bool, tmp_path: Path):
metadata.waitcnt_ops == waitcounts
), f"Expected {waitcounts} waitcnt operations, got {metadata.waitcnt_ops}"

# Verify interleaved scale instructions are generated.
# op_sel_hi:[0,0,0] selects lower bytes (0,1), op_sel_hi:[1,1,0] selects upper bytes (2,3).
assert "op_sel_hi:[0,0,0]" in text, "Expected lower interleaved scale instructions"
assert "op_sel_hi:[1,1,0]" in text, "Expected upper interleaved scale instructions"


@require_e2e
@require_cdna4
Expand Down
4 changes: 2 additions & 2 deletions wave_lang/kernel/compiler/wave_codegen/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1200,14 +1200,14 @@ def cast_vector(emitter: WaveEmitter, value, *, element_type: Optional[IrType] =
return vector_d.broadcast(vector_type, value)


def cast_scalar(emitter: WaveEmitter, value):
def cast_scalar(emitter: WaveEmitter, value: Value, position: int = 0) -> Value:
proxy_value = cast_py_value(emitter, value)
value = proxy_value.ir_value

# After scalar promotion, promote to vector.
if isinstance(value.type, VectorType):
# Vector -> scalar.
return vector_d.extract(value, static_position=[0], dynamic_position=[])
return vector_d.extract(value, static_position=[position], dynamic_position=[])
else:
# Already a scalar. Coerce or return.
# No target element_type.
Expand Down
55 changes: 41 additions & 14 deletions wave_lang/kernel/compiler/wave_codegen/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,13 +437,21 @@ def handle_mma(emitter: WaveEmitter, node: fx.Node):


def emit_mfma_scaled(
m: int, n: int, k: int, acc: Value, values: list[Value], scales: list[Value]
m: int,
n: int,
k: int,
acc: Value,
values: list[Value],
scales: list[Value],
idx_a: int,
idx_b: int,
) -> Value:
m = get_constant_attr(m, IntegerType.get_signless(32))
n = get_constant_attr(n, IntegerType.get_signless(32))
k = get_constant_attr(k, IntegerType.get_signless(32))
idx_a = get_constant_attr(0, IntegerType.get_signless(32))
idx_b = get_constant_attr(0, IntegerType.get_signless(32))
i32 = IntegerType.get_signless(32)
m = get_constant_attr(m, i32)
n = get_constant_attr(n, i32)
k = get_constant_attr(k, i32)
idx_a = get_constant_attr(idx_a, i32)
idx_b = get_constant_attr(idx_b, i32)

result = amdgpu_d.scaled_mfma(
m=m,
Expand Down Expand Up @@ -488,7 +496,9 @@ def emit_wmma_scaled(
@handle_op(scaled_mma)
def handle_scaled_mma(emitter: WaveEmitter, node: fx.Node):
try:
lhs, lhs_scale, rhs, rhs_scale, acc, mma_type = node.args
lhs, lhs_scale, rhs, rhs_scale, acc, mma_type, scale_idx_a, scale_idx_b = (
node.args
)
acc = cast_vector(emitter, acc)
values = [cast_vector(emitter, val) for val in [lhs, rhs]]
except ValueError as e:
Expand All @@ -515,8 +525,21 @@ def handle_scaled_mma(emitter: WaveEmitter, node: fx.Node):
scales = [cast_vector(emitter, val) for val in [lhs_scale, rhs_scale]]
result = emit_wmma_scaled(m, n, k, acc, values, scales)
else:
scales = [cast_scalar(emitter, val) for val in [lhs_scale, rhs_scale]]
result = emit_mfma_scaled(m, n, k, acc, values, scales)
pos_a = scale_idx_a if scale_idx_a is not None else 0
pos_b = scale_idx_b if scale_idx_b is not None else 0
scale_a = (
cast_vector(emitter, lhs_scale)
if scale_idx_a is not None
else cast_scalar(emitter, lhs_scale)
)
scale_b = (
cast_vector(emitter, rhs_scale)
if scale_idx_b is not None
else cast_scalar(emitter, rhs_scale)
)
result = emit_mfma_scaled(
m, n, k, acc, values, [scale_a, scale_b], pos_a, pos_b
)

emitter.bind_node_proxy(node, IRProxyValue(result))

Expand Down Expand Up @@ -2166,18 +2189,22 @@ def handle_reshape(emitter: WaveEmitter, node: fx.Node):
# Determine whether to extract or combine.
if len(args) > 1:
vectors = [cast_vector(emitter, arg) for arg in args]
shape = vectors[0].type.shape[0]
if shape == 1:
shape = vectors[0].type.shape[0] if vectors[0].type.rank > 0 else 0
if shape <= 1:
# If source is 1-element vector or scalar (which will be casted to
# 1-element vector by `cast_vector`), we can construct the result
# 0-d vector by `cast_vector`), we can construct the result
# vector using `extract` and a single `from_elements` op instead of
# series of `insert_strided_slice` ops.
values = [
vector_d.extract(vector, static_position=[0], dynamic_position=[])
vector_d.extract(
vector,
static_position=[] if vector.type.rank == 0 else [0],
dynamic_position=[],
)
for vector in vectors
]
element_type = vectors[0].type.element_type
vector_type = VectorType.get([shape * len(args)], element_type)
vector_type = VectorType.get([len(args)], element_type)
result = vector_d.from_elements(vector_type, values)
emitter.bind_node_proxy(node, IRProxyValue(result))
return
Expand Down
2 changes: 2 additions & 0 deletions wave_lang/kernel/ops/wave_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1922,6 +1922,8 @@ class ScaledMMA(MMABase):
rhs_scale: fx.Node
acc: fx.Node
mma_type: Optional["ScaledMMAType"] = None
scale_idx_a: Optional[int] = None
scale_idx_b: Optional[int] = None

@property
def indexing_dims(self) -> list[IndexSymbol]:
Expand Down
2 changes: 2 additions & 0 deletions wave_lang/kernel/wave/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from .hardware_transpose import mark_hardware_transpose_candidates
from .hoisting import hoist_loop_invariant_ops
from .in_thread_transpose import in_thread_transpose
from .interleave_scaled_mma import interleave_scaled_mma
from .location_check_pass import location_check_pass
from .memory_analysis.minimize_shared_allocs import minimize_shared_allocs
from .minimize_global_loads import minimize_global_loads
Expand Down Expand Up @@ -801,6 +802,7 @@ def _trace_launchable_and_get_kernel_signature(
trace,
options.minimize_shared_allocs,
),
partial(interleave_scaled_mma, trace, launchable.constraints),
]
graph_passes += [
partial(
Expand Down
102 changes: 102 additions & 0 deletions wave_lang/kernel/wave/interleave_scaled_mma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Copyright 2025 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch.fx as fx

from .._support.tracing import CapturedTrace
from ..ops.wave_ops import (
NewScalar,
Reshape,
ScaledMMA,
get_custom,
)
from .constraints import (
Constraint,
ScaledMMAType,
)
from .utils.general_utils import get_hardware_constraint


def interleave_scaled_mma(trace: CapturedTrace, constraints: list[Constraint]):
    """
    Packs scale values of ScaledMMA operations into shared VGPRs
    using byte indexing to reduce register pressure.

    When two ScaledMMA ops exist in the same subgraph, their scales are
    packed into one register [a0, b0, a1, b1] and the first op uses
    scale_idx (0,1) while the second uses (2,3).

    Unpaired ops fall back to [a, b, 0, 0] with scale_idx (0,1).
    """
    hw_constraint = get_hardware_constraint(constraints)

    def _wants_interleaving(node: fx.Node) -> bool:
        # Only F32_16x16x128_F8F6F4 ScaledMMA ops that have not yet been
        # assigned scale byte indices are candidates for packing.
        op = get_custom(node)
        if not isinstance(op, ScaledMMA):
            return False
        if op.scale_idx_a is not None or op.scale_idx_b is not None:
            return False
        return (op.mma_type or hw_constraint.mma_type) == (
            ScaledMMAType.F32_16x16x128_F8F6F4
        )

    candidates = trace.walk(_wants_interleaving)
    if not candidates:
        return

    # Bucket candidates by their owning subgraph; pairing only makes sense
    # between ops that live in the same graph.
    by_graph: dict[fx.Graph, list[fx.Node]] = {}
    for node in candidates:
        by_graph.setdefault(node.graph, []).append(node)

    for ops in by_graph.values():
        idx = 0
        count = len(ops)
        while idx < count:
            first = get_custom(ops[idx])
            scale_dtype = get_custom(first.lhs_scale).type.dtype
            has_partner = idx + 1 < count

            if has_partner:
                # Pair: pack all 4 scales into one register.
                second = get_custom(ops[idx + 1])
                with first.graph.inserting_before(first.fx_node):
                    packed = Reshape(
                        [
                            first.lhs_scale,
                            first.rhs_scale,
                            second.lhs_scale,
                            second.rhs_scale,
                        ],
                        {},
                    ).add_to_graph(first.graph, loc=first.location)
                targets = [first, second]
            else:
                # Unpaired: pad upper half with zeros.
                with first.graph.inserting_before(first.fx_node):
                    zero = NewScalar(0.0, scale_dtype).add_to_graph(
                        first.graph, loc=first.location
                    )
                    packed = Reshape(
                        [first.lhs_scale, first.rhs_scale, zero, zero],
                        {},
                    ).add_to_graph(first.graph, loc=first.location)
                targets = [first]

            # Rewire each op to read its (a, b) scales out of the packed
            # register: op 0 reads bytes (0, 1), op 1 reads bytes (2, 3).
            for slot, op in enumerate(targets):
                op.update_arg("lhs_scale", packed)
                op.update_arg("rhs_scale", packed)
                op.update_arg("scale_idx_a", 2 * slot)
                op.update_arg("scale_idx_b", 2 * slot + 1)

            idx += len(targets)
10 changes: 9 additions & 1 deletion wave_lang/kernel/wave/water.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,15 @@ def add_opt(pipeline):

def add_transform(transform: str, entry_point: str) -> tuple[str, dict[str, Any]]:
nonlocal mlir_asm
# Erase the last occurrence of '}' from mlir_asm which closes the module operation
# Add transform.with_named_sequence attribute to the module if missing.
attr_name = "transform.with_named_sequence"
if attr_name not in mlir_asm:
mlir_asm = mlir_asm.replace(
"gpu.container_module",
"gpu.container_module, " + attr_name,
1,
)
# Erase the last occurrence of '}' from mlir_asm which closes the module operation.
last_close = mlir_asm.rfind("}")
if last_close != -1:
mlir_asm = mlir_asm[:last_close]
Expand Down
Loading