Patch summary (text before the first "diff --git" header is ignored by
`git apply` and `patch`):

  * wave_lang/kernel/wave/interleave_scaled_mma.py (new): graph pass that
    combines the scale operands of ScaledMMA ops of type
    F32_16x16x128_F8F6F4 into one Reshape node and records scale_idx_a /
    scale_idx_b on each op.
  * wave_ops.py: ScaledMMA gains optional scale_idx_a / scale_idx_b fields.
  * handlers.py / emitter.py: codegen forwards the indices to
    amdgpu.scaled_mfma; handle_reshape and cast_scalar are adjusted.
  * compile.py: registers the new pass after minimize_shared_allocs.
  * water.py: add_transform inserts transform.with_named_sequence into the
    module text when it is missing.
  * tests: expected register counts / waitcnt sequence updated, op_sel_hi
    assertions added.

NOTE(review): this patch was recovered from a whitespace-collapsed copy.
Line breaks and blank lines were rebuilt so that every hunk matches its
"@@ -a,b +c,d @@" counts; indentation of context lines in files not fully
visible here (emitter.py, compile.py, water.py) is a best guess - confirm
against the target tree before applying.
NOTE(review): in water.py the str.replace on "gpu.container_module" is a
silent no-op when that substring is absent - confirm that is acceptable.
NOTE(review): for a paired op the Reshape is inserted before the first MMA
but consumes the second MMA's scales - confirm those values are always
defined earlier in the graph.

diff --git a/tests/kernel/wave_gemm_mxfp_test.py b/tests/kernel/wave_gemm_mxfp_test.py
index 15925f0de..28cad65c4 100644
--- a/tests/kernel/wave_gemm_mxfp_test.py
+++ b/tests/kernel/wave_gemm_mxfp_test.py
@@ -337,35 +337,55 @@ def testScaledBatchedGemmMXFP4Codegen(use_water_backend: bool, tmp_path: Path):
     # We encode the exact registers and wait counts as we want to know if
     # they suddenly change due to backend or upstream MLIR changes.
     if use_water_backend:
-        vgpr_count = 164
+        vgpr_count = 146
         vgpr_spill_count = 0
-        sgpr_count = 57
+        sgpr_count = 61
         sgpr_spill_count = 0
         waitcounts = [
             "s_waitcnt lgkmcnt(0)",
             "s_waitcnt vmcnt(0)",
+            "s_waitcnt lgkmcnt(6)",
+            "s_waitcnt lgkmcnt(4)",
+            "s_waitcnt lgkmcnt(2)",
+            "s_waitcnt lgkmcnt(0)",
             "s_waitcnt vmcnt(0) lgkmcnt(0)",
             "s_waitcnt vmcnt(0)",
             "s_waitcnt lgkmcnt(7)",
+            "s_waitcnt lgkmcnt(6)",
+            "s_waitcnt lgkmcnt(6)",
             "s_waitcnt lgkmcnt(5)",
-            "s_waitcnt lgkmcnt(4)",
+            "s_waitcnt lgkmcnt(3)",
+            "s_waitcnt lgkmcnt(1)",
+            "s_waitcnt lgkmcnt(0)",
             "s_waitcnt lgkmcnt(3)",
             "s_waitcnt lgkmcnt(1)",
             "s_waitcnt lgkmcnt(0)",
         ]
     else:
-        vgpr_count = 164
+        vgpr_count = 140
         vgpr_spill_count = 0
         sgpr_count = 59
         sgpr_spill_count = 0
         waitcounts = [
             "s_waitcnt lgkmcnt(0)",
             "s_waitcnt vmcnt(0)",
+            "s_waitcnt lgkmcnt(6)",
+            "s_waitcnt lgkmcnt(4)",
+            "s_waitcnt lgkmcnt(2)",
+            "s_waitcnt lgkmcnt(0)",
+            "s_waitcnt lgkmcnt(4)",
+            "s_waitcnt lgkmcnt(3)",
+            "s_waitcnt lgkmcnt(1)",
+            "s_waitcnt lgkmcnt(0)",
             "s_waitcnt vmcnt(0) lgkmcnt(0)",
             "s_waitcnt vmcnt(0)",
-            "s_waitcnt lgkmcnt(1)",
+            "s_waitcnt lgkmcnt(7)",
+            "s_waitcnt lgkmcnt(6)",
+            "s_waitcnt lgkmcnt(6)",
             "s_waitcnt lgkmcnt(5)",
-            "s_waitcnt lgkmcnt(4)",
+            "s_waitcnt lgkmcnt(3)",
+            "s_waitcnt lgkmcnt(1)",
+            "s_waitcnt lgkmcnt(0)",
             "s_waitcnt lgkmcnt(3)",
             "s_waitcnt lgkmcnt(1)",
             "s_waitcnt lgkmcnt(0)",
@@ -387,6 +407,11 @@ def testScaledBatchedGemmMXFP4Codegen(use_water_backend: bool, tmp_path: Path):
         metadata.waitcnt_ops == waitcounts
     ), f"Expected {waitcounts} waitcnt operations, got {metadata.waitcnt_ops}"
 
+    # Verify interleaved scale instructions are generated.
+    # op_sel_hi:[0,0,0] selects lower bytes (0,1), op_sel_hi:[1,1,0] selects upper bytes (2,3).
+    assert "op_sel_hi:[0,0,0]" in text, "Expected lower interleaved scale instructions"
+    assert "op_sel_hi:[1,1,0]" in text, "Expected upper interleaved scale instructions"
+
 
 @require_e2e
 @require_cdna4
diff --git a/wave_lang/kernel/compiler/wave_codegen/emitter.py b/wave_lang/kernel/compiler/wave_codegen/emitter.py
index 970d6b52b..d1ed6b169 100644
--- a/wave_lang/kernel/compiler/wave_codegen/emitter.py
+++ b/wave_lang/kernel/compiler/wave_codegen/emitter.py
@@ -1200,14 +1200,14 @@ def cast_vector(emitter: WaveEmitter, value, *, element_type: Optional[IrType] =
         return vector_d.broadcast(vector_type, value)
 
 
-def cast_scalar(emitter: WaveEmitter, value):
+def cast_scalar(emitter: WaveEmitter, value: Value, position: int = 0) -> Value:
     proxy_value = cast_py_value(emitter, value)
     value = proxy_value.ir_value
 
     # After scalar promotion, promote to vector.
     if isinstance(value.type, VectorType):
         # Vector -> scalar.
-        return vector_d.extract(value, static_position=[0], dynamic_position=[])
+        return vector_d.extract(value, static_position=[position], dynamic_position=[])
     else:
         # Already a scalar. Coerce or return.
         # No target element_type.
diff --git a/wave_lang/kernel/compiler/wave_codegen/handlers.py b/wave_lang/kernel/compiler/wave_codegen/handlers.py
index 226a8f9fa..4d4ec6d5f 100644
--- a/wave_lang/kernel/compiler/wave_codegen/handlers.py
+++ b/wave_lang/kernel/compiler/wave_codegen/handlers.py
@@ -437,13 +437,21 @@ def handle_mma(emitter: WaveEmitter, node: fx.Node):
 
 
 def emit_mfma_scaled(
-    m: int, n: int, k: int, acc: Value, values: list[Value], scales: list[Value]
+    m: int,
+    n: int,
+    k: int,
+    acc: Value,
+    values: list[Value],
+    scales: list[Value],
+    idx_a: int,
+    idx_b: int,
 ) -> Value:
-    m = get_constant_attr(m, IntegerType.get_signless(32))
-    n = get_constant_attr(n, IntegerType.get_signless(32))
-    k = get_constant_attr(k, IntegerType.get_signless(32))
-    idx_a = get_constant_attr(0, IntegerType.get_signless(32))
-    idx_b = get_constant_attr(0, IntegerType.get_signless(32))
+    i32 = IntegerType.get_signless(32)
+    m = get_constant_attr(m, i32)
+    n = get_constant_attr(n, i32)
+    k = get_constant_attr(k, i32)
+    idx_a = get_constant_attr(idx_a, i32)
+    idx_b = get_constant_attr(idx_b, i32)
 
     result = amdgpu_d.scaled_mfma(
         m=m,
@@ -488,7 +496,9 @@ def emit_wmma_scaled(
 @handle_op(scaled_mma)
 def handle_scaled_mma(emitter: WaveEmitter, node: fx.Node):
     try:
-        lhs, lhs_scale, rhs, rhs_scale, acc, mma_type = node.args
+        lhs, lhs_scale, rhs, rhs_scale, acc, mma_type, scale_idx_a, scale_idx_b = (
+            node.args
+        )
         acc = cast_vector(emitter, acc)
         values = [cast_vector(emitter, val) for val in [lhs, rhs]]
     except ValueError as e:
@@ -515,8 +525,21 @@ def handle_scaled_mma(emitter: WaveEmitter, node: fx.Node):
         scales = [cast_vector(emitter, val) for val in [lhs_scale, rhs_scale]]
         result = emit_wmma_scaled(m, n, k, acc, values, scales)
     else:
-        scales = [cast_scalar(emitter, val) for val in [lhs_scale, rhs_scale]]
-        result = emit_mfma_scaled(m, n, k, acc, values, scales)
+        pos_a = scale_idx_a if scale_idx_a is not None else 0
+        pos_b = scale_idx_b if scale_idx_b is not None else 0
+        scale_a = (
+            cast_vector(emitter, lhs_scale)
+            if scale_idx_a is not None
+            else cast_scalar(emitter, lhs_scale)
+        )
+        scale_b = (
+            cast_vector(emitter, rhs_scale)
+            if scale_idx_b is not None
+            else cast_scalar(emitter, rhs_scale)
+        )
+        result = emit_mfma_scaled(
+            m, n, k, acc, values, [scale_a, scale_b], pos_a, pos_b
+        )
 
     emitter.bind_node_proxy(node, IRProxyValue(result))
 
@@ -2166,18 +2189,22 @@ def handle_reshape(emitter: WaveEmitter, node: fx.Node):
     # Determine whether to extract or combine.
     if len(args) > 1:
         vectors = [cast_vector(emitter, arg) for arg in args]
-        shape = vectors[0].type.shape[0]
-        if shape == 1:
+        shape = vectors[0].type.shape[0] if vectors[0].type.rank > 0 else 0
+        if shape <= 1:
             # If source is 1-element vector or scalar (which will be casted to
-            # 1-element vector by `cast_vector`), we can construct the result
+            # 0-d vector by `cast_vector`), we can construct the result
             # vector using `extract` and a single `from_elements` op instead of
             # series of `insert_strided_slice` ops.
             values = [
-                vector_d.extract(vector, static_position=[0], dynamic_position=[])
+                vector_d.extract(
+                    vector,
+                    static_position=[] if vector.type.rank == 0 else [0],
+                    dynamic_position=[],
+                )
                 for vector in vectors
             ]
             element_type = vectors[0].type.element_type
-            vector_type = VectorType.get([shape * len(args)], element_type)
+            vector_type = VectorType.get([len(args)], element_type)
             result = vector_d.from_elements(vector_type, values)
             emitter.bind_node_proxy(node, IRProxyValue(result))
             return
diff --git a/wave_lang/kernel/ops/wave_ops.py b/wave_lang/kernel/ops/wave_ops.py
index 2aef02ed4..617ec500c 100644
--- a/wave_lang/kernel/ops/wave_ops.py
+++ b/wave_lang/kernel/ops/wave_ops.py
@@ -1922,6 +1922,8 @@ class ScaledMMA(MMABase):
     rhs_scale: fx.Node
     acc: fx.Node
     mma_type: Optional["ScaledMMAType"] = None
+    scale_idx_a: Optional[int] = None
+    scale_idx_b: Optional[int] = None
 
     @property
     def indexing_dims(self) -> list[IndexSymbol]:
diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py
index 7d97bc272..6c5ba059c 100644
--- a/wave_lang/kernel/wave/compile.py
+++ b/wave_lang/kernel/wave/compile.py
@@ -64,6 +64,7 @@
 from .hardware_transpose import mark_hardware_transpose_candidates
 from .hoisting import hoist_loop_invariant_ops
 from .in_thread_transpose import in_thread_transpose
+from .interleave_scaled_mma import interleave_scaled_mma
 from .location_check_pass import location_check_pass
 from .memory_analysis.minimize_shared_allocs import minimize_shared_allocs
 from .minimize_global_loads import minimize_global_loads
@@ -801,6 +802,7 @@ def _trace_launchable_and_get_kernel_signature(
                 trace,
                 options.minimize_shared_allocs,
             ),
+            partial(interleave_scaled_mma, trace, launchable.constraints),
         ]
         graph_passes += [
             partial(
diff --git a/wave_lang/kernel/wave/interleave_scaled_mma.py b/wave_lang/kernel/wave/interleave_scaled_mma.py
new file mode 100644
index 000000000..22de3066c
--- /dev/null
+++ b/wave_lang/kernel/wave/interleave_scaled_mma.py
@@ -0,0 +1,102 @@
+# Copyright 2025 The IREE Authors
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import torch.fx as fx
+
+from .._support.tracing import CapturedTrace
+from ..ops.wave_ops import (
+    NewScalar,
+    Reshape,
+    ScaledMMA,
+    get_custom,
+)
+from .constraints import (
+    Constraint,
+    ScaledMMAType,
+)
+from .utils.general_utils import get_hardware_constraint
+
+
+def interleave_scaled_mma(trace: CapturedTrace, constraints: list[Constraint]):
+    """
+    Packs scale values of ScaledMMA operations into shared VGPRs
+    using byte indexing to reduce register pressure.
+
+    When two ScaledMMA ops exist in the same subgraph, their scales are
+    packed into one register [a0, b0, a1, b1] and the first op uses
+    scale_idx (0,1) while the second uses (2,3).
+
+    Unpaired ops fall back to [a, b, 0, 0] with scale_idx (0,1).
+    """
+    hardware_constraint = get_hardware_constraint(constraints)
+
+    def is_target_scaled_mma(node: fx.Node) -> bool:
+        custom = get_custom(node)
+        if not isinstance(custom, ScaledMMA):
+            return False
+        if custom.scale_idx_a is not None or custom.scale_idx_b is not None:
+            return False
+        mma_type = custom.mma_type or hardware_constraint.mma_type
+        return mma_type == ScaledMMAType.F32_16x16x128_F8F6F4
+
+    nodes = trace.walk(is_target_scaled_mma)
+    if not nodes:
+        return
+
+    # Group nodes by subgraph so we can pair within each one.
+    graph_groups: dict[fx.Graph, list[fx.Node]] = {}
+    for node in nodes:
+        graph_groups.setdefault(node.graph, []).append(node)
+
+    for group in graph_groups.values():
+        i = 0
+        while i < len(group):
+            mma_a = get_custom(group[i])
+            scale_dtype = get_custom(mma_a.lhs_scale).type.dtype
+
+            if i + 1 < len(group):
+                # Pair: pack all 4 scales into one register.
+                mma_b = get_custom(group[i + 1])
+
+                with mma_a.graph.inserting_before(mma_a.fx_node):
+                    combined = Reshape(
+                        [
+                            mma_a.lhs_scale,
+                            mma_a.rhs_scale,
+                            mma_b.lhs_scale,
+                            mma_b.rhs_scale,
+                        ],
+                        {},
+                    ).add_to_graph(mma_a.graph, loc=mma_a.location)
+
+                mma_a.update_arg("lhs_scale", combined)
+                mma_a.update_arg("rhs_scale", combined)
+                mma_a.update_arg("scale_idx_a", 0)
+                mma_a.update_arg("scale_idx_b", 1)
+
+                mma_b.update_arg("lhs_scale", combined)
+                mma_b.update_arg("rhs_scale", combined)
+                mma_b.update_arg("scale_idx_a", 2)
+                mma_b.update_arg("scale_idx_b", 3)
+
+                i += 2
+            else:
+                # Unpaired: pad upper half with zeros.
+                with mma_a.graph.inserting_before(mma_a.fx_node):
+                    zero = NewScalar(0.0, scale_dtype).add_to_graph(
+                        mma_a.graph, loc=mma_a.location
+                    )
+                    combined = Reshape(
+                        [mma_a.lhs_scale, mma_a.rhs_scale, zero, zero],
+                        {},
+                    ).add_to_graph(mma_a.graph, loc=mma_a.location)
+
+                mma_a.update_arg("lhs_scale", combined)
+                mma_a.update_arg("rhs_scale", combined)
+                mma_a.update_arg("scale_idx_a", 0)
+                mma_a.update_arg("scale_idx_b", 1)
+
+                i += 1
diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py
index 62eed0425..048bfa8ec 100644
--- a/wave_lang/kernel/wave/water.py
+++ b/wave_lang/kernel/wave/water.py
@@ -383,7 +383,15 @@ def add_opt(pipeline):
 
     def add_transform(transform: str, entry_point: str) -> tuple[str, dict[str, Any]]:
         nonlocal mlir_asm
-        # Erase the last occurrence of '}' from mlir_asm which closes the module operation
+        # Add transform.with_named_sequence attribute to the module if missing.
+        attr_name = "transform.with_named_sequence"
+        if attr_name not in mlir_asm:
+            mlir_asm = mlir_asm.replace(
+                "gpu.container_module",
+                "gpu.container_module, " + attr_name,
+                1,
+            )
+        # Erase the last occurrence of '}' from mlir_asm which closes the module operation.
         last_close = mlir_asm.rfind("}")
         if last_close != -1:
             mlir_asm = mlir_asm[:last_close]