diff --git a/lit_tests/kernel/wave/mlir_roundtrip_pipeline.py b/lit_tests/kernel/wave/mlir_roundtrip_pipeline.py index 008510a8b..41783ed96 100644 --- a/lit_tests/kernel/wave/mlir_roundtrip_pipeline.py +++ b/lit_tests/kernel/wave/mlir_roundtrip_pipeline.py @@ -206,9 +206,6 @@ def gemm_progressive_roundtrip(): "remove_chained_getresult", "decompose_vmma_ops", "decompose_dot_mma", - "generate_bounds_exprs", - "location_check_pass", - "merge_contiguous_reads", } ) diff --git a/water/include/water/Dialect/Wave/IR/WaveAttrs.td b/water/include/water/Dialect/Wave/IR/WaveAttrs.td index 5f52dbb67..764550d8b 100644 --- a/water/include/water/Dialect/Wave/IR/WaveAttrs.td +++ b/water/include/water/Dialect/Wave/IR/WaveAttrs.td @@ -599,14 +599,10 @@ class WaveExprListAttrOf inputTypes> : Attr< def WaveReadWriteBoundsAttr : AttrDef { let mnemonic = "read_write_bounds"; let description = [{ - This attribute contains a dictionary mapping from symbolic dimension (strings) - to a WaveExprListAttr specifying the bounds of the read/write operations - this is attached to. - - Example: - ``` - #wave.read_write_bounds<{M = #wave.expr_list<[BLOCK_M] -> BLOCK_M * 2>}> - ``` + This attribute contains a dictionary mapping from symbolic dimensions + (strings) to a WaveExprListAttr specifying the bounds of the read/write + operations this is attached to. The dictionary may be sparse: only + dimensions that require masking need an entry. }]; let parameters = (ins "::mlir::DictionaryAttr":$mapping); diff --git a/water/lib/Dialect/Wave/IR/WaveOps.cpp b/water/lib/Dialect/Wave/IR/WaveOps.cpp index 547c07874..cbcfa2bb0 100644 --- a/water/lib/Dialect/Wave/IR/WaveOps.cpp +++ b/water/lib/Dialect/Wave/IR/WaveOps.cpp @@ -1531,9 +1531,12 @@ verifyIndexElementsPerThread(Operation *op, ArrayAttr indexAttr, return success(); } -// Check that if the given read/write operation has bound expressions specified, -// each symbolic dimension of the WaveTensorType has exactly one bound -// expression. +// Verify that every key in the bounds dictionary names a symbolic dimension of +// the WaveTensorType and that each value is a single-result WaveExprListAttr. +// The dictionary may be sparse: only dimensions that actually require masking +// (e.g. because the tile size does not evenly divide the dimension) need an +// entry. Dimensions without an entry are assumed to be fully in-bounds and +// will not generate mask operations during lowering. static LogicalResult verifyReadWriteBounds(Location loc, wave::WaveTensorType boundedType, DictionaryAttr bounds) { @@ -1545,12 +1548,12 @@ static LogicalResult verifyReadWriteBounds(Location loc, // TODO: consider refactoring bounds and other dictionary-like attributes to // be indexed by symbol expressions rather than string attributes to avoid // string comparisons everywhere. - SmallVector requiredSymbolNames = llvm::map_to_vector( + SmallVector validSymbolNames = llvm::map_to_vector( boundedType.getShape(), [](wave::WaveSymbolAttr symbol) { return symbol.getName(); }); - llvm::StringSet<> knownSymbolNames; + for (NamedAttribute value : bounds) { - if (!llvm::is_contained(requiredSymbolNames, value.getName().strref())) { + if (!llvm::is_contained(validSymbolNames, value.getName().strref())) { return emitError(loc) << "'bounds' specified for a symbol " << value.getName() << " not used in the " @@ -1566,15 +1569,6 @@ static LogicalResult verifyReadWriteBounds(Location loc, return emitError(loc) << "'bounds' must only contain single-result expressions"; } - - knownSymbolNames.insert(value.getName().strref()); - } - for (StringRef requiredName : requiredSymbolNames) { - if (knownSymbolNames.contains(requiredName)) - continue; - - return emitError(loc) << "bounds not provided for memory tensor symbol '" - << requiredName << "'"; } return success(); diff --git a/water/lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp b/water/lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp index 494f8e80c..2a0397c33 100644 --- a/water/lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp +++ b/water/lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp @@ -107,12 +107,15 @@ buildMask(Location loc, wave::WaveReadWriteBoundsAttr boundsDict, IntegerType i1Type = IntegerType::get(rewriter.getContext(), 1); VectorType maskType = VectorType::get({elementsPerThread}, i1Type); - // finalMask is the AND of per-dimension bound checks. + // finalMask is the AND of per-dimension bound checks. The bounds dict may + // be sparse: only dimensions that require masking have an entry. Dimensions + // without an entry are fully in-bounds and are skipped. Value finalMask; for (uint64_t d = 0; d < rank; ++d) { StringRef name = orderedSyms[d].getName(); Attribute a = boundsDict.getMapping().get(name); - assert(a && "bounds dict missing entry for dimension symbol"); + if (!a) + continue; auto boundAttr = cast(a); // Materialize bounds. FailureOr> boundValsFo = wave::materializeAffine( diff --git a/water/test/Dialect/Wave/lower-wave-to-mlir.mlir b/water/test/Dialect/Wave/lower-wave-to-mlir.mlir index 0d33672a6..61308ce4c 100644 --- a/water/test/Dialect/Wave/lower-wave-to-mlir.mlir +++ b/water/test/Dialect/Wave/lower-wave-to-mlir.mlir @@ -661,6 +661,34 @@ normalform.module [#wave.normal_form] { + // CHECK-LABEL: @lower_read_sparse_bounds + func.func @lower_read_sparse_bounds(%mem: !wave.tensor<[@M, @N] of f16, >) + attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64, BLOCK_N = 64, M = 100, N = 64}>} { + %v = wave.read %mem index [{ + // CHECK: %[[BIDX_X:.*]] = gpu.block_id x + // CHECK: %[[TIDX_X:.*]] = gpu.thread_id x + // CHECK: %[[ROW:.*]] = affine.apply affine_map<()[s0, s1] -> (s0 * 64 + s1)>()[%[[BIDX_X]], %[[TIDX_X]]] + M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0, 1, 64)>, + N : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N + T1 * 32, 4, 1)> + }] { bounds = #wave.read_write_bounds<{ + M = #wave.expr_list<[#wave.symbol<"M">] -> (M)>}>} + : (!wave.tensor<[@M, @N] of f16, >) -> vector<4xf16> + // Only M dimension produces a mask — no andi needed. + // CHECK: %[[DIM0_SIZE:.+]] = affine.apply affine_map<() -> (100)>() + // CHECK: %[[DIM0_CMP:.+]] = arith.cmpi slt, %[[ROW]], %[[DIM0_SIZE]] + // CHECK: %[[MASK:.+]] = vector.broadcast %[[DIM0_CMP]] : i1 to vector<4xi1> + // CHECK-NOT: arith.andi + // CHECK: %[[CST0:.+]] = arith.constant dense<0.000000e+00> : vector<4xf16> + // CHECK: vector.maskedload %{{.*}}[%[[ROW]], %{{.*}}], %[[MASK]], %[[CST0]] + return + } +} + +// ----- + normalform.module [#wave.normal_form] { // CHECK-LABEL: @read_with_vector_result func.func @read_with_vector_result(%mem: !wave.tensor<[@M, @N] of f16, >) diff --git a/water/test/Dialect/Wave/ops-invalid.mlir b/water/test/Dialect/Wave/ops-invalid.mlir index 0f2a3b662..39c5a5601 100644 --- a/water/test/Dialect/Wave/ops-invalid.mlir +++ b/water/test/Dialect/Wave/ops-invalid.mlir @@ -466,14 +466,6 @@ normalform.module [#wave.normal_form] { // ----- -func.func @bounds_missing_dim(%mem: !wave.tensor<[@M, @N] of f32>, %val: !wave.tensor<[@M, @N] of f32, >) { - // expected-error @below {{bounds not provided for memory tensor symbol 'N'}} - wave.write %val, %mem { bounds = #wave.read_write_bounds<{ M = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M * 64)>}> } : !wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32> - return -} - -// ----- - func.func @bounds_extraneous_dim(%mem: !wave.tensor<[@N] of f32>, %val: !wave.tensor<[@N] of f32, >) { // expected-error @below {{'bounds' specified for a symbol "M" not used in the indexed memory tensor}} wave.write %val, %mem { bounds = #wave.read_write_bounds<{ M = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M * 64)>}> } : !wave.tensor<[@N] of f32, >, !wave.tensor<[@N] of f32> diff --git a/water/test/Dialect/Wave/ops.mlir b/water/test/Dialect/Wave/ops.mlir index 3fadc8dec..9e3b82367 100644 --- a/water/test/Dialect/Wave/ops.mlir +++ b/water/test/Dialect/Wave/ops.mlir @@ -379,6 +379,14 @@ func.func @write_with_bounds(%memo: !wave.tensor<[@M] of f32>, %val: !wave.tenso return } +// Sparse bounds: only M needs masking. +// CHECK-LABEL: @write_with_sparse_bounds +func.func @write_with_sparse_bounds(%mem: !wave.tensor<[@M, @N] of f32>, %val: !wave.tensor<[@M, @N] of f32, >) { + // CHECK: wave.read_write_bounds + wave.write %val, %mem { bounds = #wave.read_write_bounds<{ M = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M * 64)>}> } : !wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32> + return +} + // CHECK-LABEL: @cast_wave_tensor func.func @cast_wave_tensor(%arg0: !wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M, @N] of bf16> { // CHECK: wave.cast