Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions lit_tests/kernel/wave/mlir_roundtrip_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,6 @@ def gemm_progressive_roundtrip():
"remove_chained_getresult",
"decompose_vmma_ops",
"decompose_dot_mma",
"generate_bounds_exprs",
"location_check_pass",
"merge_contiguous_reads",
}
)

Expand Down
12 changes: 4 additions & 8 deletions water/include/water/Dialect/Wave/IR/WaveAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -599,14 +599,10 @@ class WaveExprListAttrOf<list<AttrDef> inputTypes> : Attr<
def WaveReadWriteBoundsAttr : AttrDef<WaveDialect, "WaveReadWriteBounds"> {
let mnemonic = "read_write_bounds";
let description = [{
This attribute contains a dictionary mapping from symbolic dimension (strings)
to a WaveExprListAttr specifying the bounds of the read/write operations
this is attached to.

Example:
```
#wave.read_write_bounds<{M = #wave.expr_list<[BLOCK_M] -> BLOCK_M * 2>}>
```
This attribute contains a dictionary mapping from symbolic dimensions
(strings) to a WaveExprListAttr specifying the bounds of the read/write
operations this is attached to. The dictionary may be sparse: only
dimensions that require masking need an entry.
}];

let parameters = (ins "::mlir::DictionaryAttr":$mapping);
Expand Down
24 changes: 9 additions & 15 deletions water/lib/Dialect/Wave/IR/WaveOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1531,9 +1531,12 @@ verifyIndexElementsPerThread(Operation *op, ArrayAttr indexAttr,
return success();
}

// Check that if the given read/write operation has bound expressions specified,
// each symbolic dimension of the WaveTensorType has exactly one bound
// expression.
// Verify that every key in the bounds dictionary names a symbolic dimension of
// the WaveTensorType and that each value is a single-result WaveExprListAttr.
// The dictionary may be sparse: only dimensions that actually require masking
// (e.g. because the tile size does not evenly divide the dimension) need an
// entry. Dimensions without an entry are assumed to be fully in-bounds and
// will not generate mask operations during lowering.
static LogicalResult verifyReadWriteBounds(Location loc,
wave::WaveTensorType boundedType,
DictionaryAttr bounds) {
Expand All @@ -1545,12 +1548,12 @@ static LogicalResult verifyReadWriteBounds(Location loc,
// TODO: consider refactoring bounds and other dictionary-like attributes to
// be indexed by symbol expressions rather than string attributes to avoid
// string comparisons everywhere.
SmallVector<StringRef> requiredSymbolNames = llvm::map_to_vector(
SmallVector<StringRef> validSymbolNames = llvm::map_to_vector(
boundedType.getShape(),
[](wave::WaveSymbolAttr symbol) { return symbol.getName(); });
llvm::StringSet<> knownSymbolNames;

for (NamedAttribute value : bounds) {
if (!llvm::is_contained(requiredSymbolNames, value.getName().strref())) {
if (!llvm::is_contained(validSymbolNames, value.getName().strref())) {
return emitError(loc)
<< "'bounds' specified for a symbol " << value.getName()
<< " not used in the "
Expand All @@ -1566,15 +1569,6 @@ static LogicalResult verifyReadWriteBounds(Location loc,
return emitError(loc)
<< "'bounds' must only contain single-result expressions";
}

knownSymbolNames.insert(value.getName().strref());
}
for (StringRef requiredName : requiredSymbolNames) {
if (knownSymbolNames.contains(requiredName))
continue;

return emitError(loc) << "bounds not provided for memory tensor symbol '"
<< requiredName << "'";
}

return success();
Expand Down
7 changes: 5 additions & 2 deletions water/lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,15 @@ buildMask(Location loc, wave::WaveReadWriteBoundsAttr boundsDict,
IntegerType i1Type = IntegerType::get(rewriter.getContext(), 1);
VectorType maskType = VectorType::get({elementsPerThread}, i1Type);

// finalMask is the AND of per-dimension bound checks.
// finalMask is the AND of per-dimension bound checks. The bounds dict may
// be sparse: only dimensions that require masking have an entry. Dimensions
// without an entry are fully in-bounds and are skipped.
Value finalMask;
for (uint64_t d = 0; d < rank; ++d) {
StringRef name = orderedSyms[d].getName();
Attribute a = boundsDict.getMapping().get(name);
assert(a && "bounds dict missing entry for dimension symbol");
if (!a)
continue;
auto boundAttr = cast<wave::WaveExprListAttr>(a);
// Materialize bounds.
FailureOr<SmallVector<Value>> boundValsFo = wave::materializeAffine(
Expand Down
28 changes: 28 additions & 0 deletions water/test/Dialect/Wave/lower-wave-to-mlir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,34 @@ normalform.module [#wave.normal_form<full_types,index_exprs,memory_only_types,re

// -----

// Sparse bounds: only M needs masking, N is fully tiled (no entry).
// The mask should only check the M dimension — no arith.andi.
normalform.module [#wave.normal_form<full_types,index_exprs,memory_only_types,resolved_allocations,ordered_syms>] {
// CHECK-LABEL: @lower_read_sparse_bounds
func.func @lower_read_sparse_bounds(%mem: !wave.tensor<[@M, @N] of f16, <global>>)
attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64, BLOCK_N = 64, M = 100, N = 64}>} {
%v = wave.read %mem index [{
// CHECK: %[[BIDX_X:.*]] = gpu.block_id x
// CHECK: %[[TIDX_X:.*]] = gpu.thread_id x
// CHECK: %[[ROW:.*]] = affine.apply affine_map<()[s0, s1] -> (s0 * 64 + s1)>()[%[[BIDX_X]], %[[TIDX_X]]]
M : <[#wave.index_symbol<WG0>, #wave.index_symbol<T0>, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0, 1, 64)>,
N : <[#wave.index_symbol<WG1>, #wave.index_symbol<T1>, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N + T1 * 32, 4, 1)>
}] { bounds = #wave.read_write_bounds<{
M = #wave.expr_list<[#wave.symbol<"M">] -> (M)>}>}
: (!wave.tensor<[@M, @N] of f16, <global>>) -> vector<4xf16>
// Only M dimension produces a mask — no andi needed.
// CHECK: %[[DIM0_SIZE:.+]] = affine.apply affine_map<() -> (100)>()
// CHECK: %[[DIM0_CMP:.+]] = arith.cmpi slt, %[[ROW]], %[[DIM0_SIZE]]
// CHECK: %[[MASK:.+]] = vector.broadcast %[[DIM0_CMP]] : i1 to vector<4xi1>
// CHECK-NOT: arith.andi
// CHECK: %[[CST0:.+]] = arith.constant dense<0.000000e+00> : vector<4xf16>
// CHECK: vector.maskedload %{{.*}}[%[[ROW]], %{{.*}}], %[[MASK]], %[[CST0]]
return
}
}

// -----

normalform.module [#wave.normal_form<full_types,index_exprs,memory_only_types,resolved_allocations,ordered_syms>] {
// CHECK-LABEL: @read_with_vector_result
func.func @read_with_vector_result(%mem: !wave.tensor<[@M, @N] of f16, <global>>)
Expand Down
8 changes: 0 additions & 8 deletions water/test/Dialect/Wave/ops-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -466,14 +466,6 @@ normalform.module [#wave.normal_form<full_types>] {

// -----

func.func @bounds_missing_dim(%mem: !wave.tensor<[@M, @N] of f32>, %val: !wave.tensor<[@M, @N] of f32, <register>>) {
// expected-error @below {{bounds not provided for memory tensor symbol 'N'}}
wave.write %val, %mem { bounds = #wave.read_write_bounds<{ M = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M * 64)>}> } : !wave.tensor<[@M, @N] of f32, <register>>, !wave.tensor<[@M, @N] of f32>
return
}

// -----

func.func @bounds_extraneous_dim(%mem: !wave.tensor<[@N] of f32>, %val: !wave.tensor<[@N] of f32, <register>>) {
// expected-error @below {{'bounds' specified for a symbol "M" not used in the indexed memory tensor}}
wave.write %val, %mem { bounds = #wave.read_write_bounds<{ M = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M * 64)>}> } : !wave.tensor<[@N] of f32, <register>>, !wave.tensor<[@N] of f32>
Expand Down
8 changes: 8 additions & 0 deletions water/test/Dialect/Wave/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,14 @@ func.func @write_with_bounds(%memo: !wave.tensor<[@M] of f32>, %val: !wave.tenso
return
}

// Sparse bounds: only M needs masking.
// CHECK-LABEL: @write_with_sparse_bounds
func.func @write_with_sparse_bounds(%mem: !wave.tensor<[@M, @N] of f32>, %val: !wave.tensor<[@M, @N] of f32, <register>>) {
// CHECK: wave.read_write_bounds
wave.write %val, %mem { bounds = #wave.read_write_bounds<{ M = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M * 64)>}> } : !wave.tensor<[@M, @N] of f32, <register>>, !wave.tensor<[@M, @N] of f32>
return
}

// CHECK-LABEL: @cast_wave_tensor
func.func @cast_wave_tensor(%arg0: !wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M, @N] of bf16> {
// CHECK: wave.cast
Expand Down