diff --git a/lit_tests/kernel/wave/mlir_converter.py b/lit_tests/kernel/wave/mlir_converter.py index 84dc87f9f..a278465f9 100644 --- a/lit_tests/kernel/wave/mlir_converter.py +++ b/lit_tests/kernel/wave/mlir_converter.py @@ -298,9 +298,9 @@ def mlir_converter_matrix_add(): # CHECK-SAME: M : <[{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, 1, 64)> # CHECK-SAME: N : <[{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, BLOCK_N ceildiv 2, 1)> # CHECK-SAME: bounds - # CHECK-SAME: #wave.read_write_bounds - # CHECK-SAME: M = #wave.expr_list - # CHECK-SAME: N = #wave.expr_list + # CHECK-SAME: #wave.symbol_mapping + # CHECK-SAME: @M = #wave.expr_list + # CHECK-SAME: @N = #wave.expr_list # CHECK-SAME: elements_per_thread = 32 : i64 # CHECK-SAME: (!wave.tensor<[@M, @N] of f16, >) -> !wave.tensor<[@M, @N] of f16, > @@ -309,9 +309,9 @@ def mlir_converter_matrix_add(): # CHECK-SAME: M : <[{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, 1, 64)> # CHECK-SAME: N : <[{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, BLOCK_N ceildiv 2, 1)> # CHECK-SAME: bounds - # CHECK-SAME: #wave.read_write_bounds - # CHECK-SAME: M = #wave.expr_list - # CHECK-SAME: N = #wave.expr_list + # CHECK-SAME: #wave.symbol_mapping + # CHECK-SAME: @M = #wave.expr_list + # CHECK-SAME: @N = #wave.expr_list # CHECK-SAME: elements_per_thread = 32 : i64 # CHECK-SAME: (!wave.tensor<[@M, @N] of f16, >) -> !wave.tensor<[@M, @N] of f16, > @@ -332,9 +332,9 @@ def mlir_converter_matrix_add(): # CHECK-SAME: M : <[{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, 1, 64)> # CHECK-SAME: N : <[{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, BLOCK_N ceildiv 2, 1)> # CHECK-SAME: bounds - # CHECK-SAME: #wave.read_write_bounds - # CHECK-SAME: M = #wave.expr_list - # CHECK-SAME: N = #wave.expr_list + # CHECK-SAME: #wave.symbol_mapping + # CHECK-SAME: @M = #wave.expr_list + # CHECK-SAME: @N = #wave.expr_list # CHECK-SAME: elements_per_thread = 32 : i64 # CHECK-SAME: !wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32, > diff --git a/water/include/water/Dialect/Wave/IR/WaveAttrs.td b/water/include/water/Dialect/Wave/IR/WaveAttrs.td index ae50a3cc4..7c0ed08f6 100644 --- a/water/include/water/Dialect/Wave/IR/WaveAttrs.td +++ b/water/include/water/Dialect/Wave/IR/WaveAttrs.td @@ -595,30 +595,6 @@ class WaveExprListAttrOf inputTypes> : Attr< let returnType = WaveExprListAttr.returnType; } - -def WaveReadWriteBoundsAttr : AttrDef { - let mnemonic = "read_write_bounds"; - let description = [{ - This attribute contains a dictionary mapping from symbolic dimensions - (strings) to a WaveExprListAttr specifying the bounds of the read/write - operations this is attached to. The dictionary may be sparse: only - dimensions that require masking need an entry. - }]; - - let parameters = (ins "::mlir::DictionaryAttr":$mapping); - let assemblyFormat = "`<` $mapping `>`"; - - let extraClassDeclaration = [{ - /// Check if a symbolic dimension exists in the mapping - bool hasSymbol(::llvm::StringRef symbolicDim) const { - return getMapping().get(symbolicDim) != nullptr; - } - - /// Get the number of symbols in the mapping - size_t getNumSymbols() const { return getMapping().size(); } - }]; -} - //----------------------------------------------------------------------------- // Symbol mapping attribute //----------------------------------------------------------------------------- @@ -647,6 +623,12 @@ def WaveSymbolMappingAttr : AttrDef { let hasCustomAssemblyFormat = 1; let genVerifyDecl = 1; + let builders = [ + AttrBuilder<(ins + "::llvm::ArrayRef>" + :$entries)> + ]; + let extraClassDeclaration = [{ /// Verify that every value `WaveExprListAttr` in the mapping has exactly /// `numResults` results. Returns failure with a diagnostic on the first @@ -658,6 +640,14 @@ def WaveSymbolMappingAttr : AttrDef { /// Look up the `WaveExprListAttr` associated with `key`. /// Returns a default-constructed (null) attribute if the key is absent. ::wave::WaveExprListAttr lookup(::wave::WaveSymbolAttr key) const; + + /// Check if a symbolic dimension exists in the mapping. + bool hasSymbol(::wave::WaveSymbolAttr key) const { + return static_cast(lookup(key)); + } + + /// Get the number of entries in the mapping. + size_t getNumEntries() const { return getKeys().size(); } }]; } diff --git a/water/include/water/Dialect/Wave/IR/WaveOps.td b/water/include/water/Dialect/Wave/IR/WaveOps.td index 5379d27a0..05dd10d6c 100644 --- a/water/include/water/Dialect/Wave/IR/WaveOps.td +++ b/water/include/water/Dialect/Wave/IR/WaveOps.td @@ -451,7 +451,7 @@ def ReadOp : WaveOp<"read", [ Arg:$memory, Arg, "Number of elements processed by each thread">:$elements_per_thread, - Arg, + Arg, "Bound expressions for each symbolic dimension">:$bounds, Arg, "Ordered dimension symbols from memory type shape">:$ordered_syms @@ -517,7 +517,7 @@ def WriteOp : WaveOp<"write", [ Arg:$memory, Arg, "Number of elements processed by each thread">:$elements_per_thread, - Arg, + Arg, "Bound expressions for each symbolic dimension">:$bounds, Arg, "Ordered dimension symbols from memory type shape">:$ordered_syms diff --git a/water/include/water/c/Dialects.h b/water/include/water/c/Dialects.h index b41c728a8..eab66043c 100644 --- a/water/include/water/c/Dialects.h +++ b/water/include/water/c/Dialects.h @@ -432,24 +432,37 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirWaveExprListAttrGetSymbol(MlirAttribute attr, intptr_t index); //===---------------------------------------------------------------------===// -// WaveReadWriteBoundsAttr +// WaveSymbolMappingAttr //===---------------------------------------------------------------------===// -/// Checks whether the given MLIR attribute is a WaveReadWriteBoundsAttr. +/// Checks whether the given MLIR attribute is a WaveSymbolMappingAttr. MLIR_CAPI_EXPORTED bool -mlirAttributeIsAWaveReadWriteBoundsAttr(MlirAttribute attr); +mlirAttributeIsAWaveSymbolMappingAttr(MlirAttribute attr); -/// Creates a new WaveReadWriteBoundsAttr with the given mapping from symbolic -/// dimensions to their bound expressions. +/// Creates a new WaveSymbolMappingAttr from parallel arrays of keys (symbol +/// attrs) and values (expr list attrs). MLIR_CAPI_EXPORTED MlirAttribute -mlirWaveReadWriteBoundsAttrGet(MlirAttribute mapping); +mlirWaveSymbolMappingAttrGet(MlirContext ctx, intptr_t numEntries, + MlirAttribute *keys, MlirAttribute *values); -/// Gets the underlying dictionary mapping from a WaveReadWriteBoundsAttr. +/// Returns the number of entries in a WaveSymbolMappingAttr. +MLIR_CAPI_EXPORTED intptr_t +mlirWaveSymbolMappingAttrGetNumEntries(MlirAttribute attr); + +/// Returns the key at the given index. +MLIR_CAPI_EXPORTED MlirAttribute +mlirWaveSymbolMappingAttrGetKey(MlirAttribute attr, intptr_t index); + +/// Returns the value at the given index. +MLIR_CAPI_EXPORTED MlirAttribute +mlirWaveSymbolMappingAttrGetValue(MlirAttribute attr, intptr_t index); + +/// Returns the value for the given key or null if the key is not present. MLIR_CAPI_EXPORTED MlirAttribute -mlirWaveReadWriteBoundsAttrGetMapping(MlirAttribute attr); +mlirWaveSymbolMappingAttrLookup(MlirAttribute attr, MlirAttribute key); -/// Returns the typeID of a WaveReadWriteBoundsAttr. -MLIR_CAPI_EXPORTED MlirTypeID mlirWaveReadWriteBoundsAttrGetTypeID(); +/// Returns the typeID of a WaveSymbolMappingAttr. +MLIR_CAPI_EXPORTED MlirTypeID mlirWaveSymbolMappingAttrGetTypeID(); //===---------------------------------------------------------------------===// // HardwareConstraintAttr diff --git a/water/lib/CAPI/Dialects.cpp b/water/lib/CAPI/Dialects.cpp index 591f0aac1..779baf2af 100644 --- a/water/lib/CAPI/Dialects.cpp +++ b/water/lib/CAPI/Dialects.cpp @@ -455,35 +455,55 @@ MlirAttribute mlirWaveExprListAttrGetSymbol(MlirAttribute attr, llvm::cast(unwrap(attr)).getSymbols()[index]); } //===---------------------------------------------------------------------===// -// WaveReadWriteBoundsAttr +// WaveSymbolMappingAttr //===---------------------------------------------------------------------===// -bool mlirAttributeIsAWaveReadWriteBoundsAttr(MlirAttribute attr) { - return llvm::isa(unwrap(attr)); +bool mlirAttributeIsAWaveSymbolMappingAttr(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); } -MlirAttribute mlirWaveReadWriteBoundsAttrGet(MlirAttribute mapping) { - auto dictAttr = llvm::cast(unwrap(mapping)); +MlirAttribute mlirWaveSymbolMappingAttrGet(MlirContext ctx, intptr_t numEntries, + MlirAttribute *keys, + MlirAttribute *values) { + SmallVector keyAttrs; + SmallVector valueAttrs; + keyAttrs.reserve(numEntries); + valueAttrs.reserve(numEntries); + for (intptr_t i = 0; i < numEntries; ++i) { + keyAttrs.push_back(llvm::cast(unwrap(keys[i]))); + valueAttrs.push_back(llvm::cast(unwrap(values[i]))); + } + return wrap( + wave::WaveSymbolMappingAttr::get(unwrap(ctx), keyAttrs, valueAttrs)); +} - MLIRContext *ctx = dictAttr.getContext(); +intptr_t mlirWaveSymbolMappingAttrGetNumEntries(MlirAttribute attr) { + return llvm::cast(unwrap(attr)).getNumEntries(); +} - assert(llvm::all_of(dictAttr, - [](const NamedAttribute &namedAttr) { - return llvm::isa( - namedAttr.getValue()); - }) && - "expected mapping to contain only WaveExprListAttr values"); +MlirAttribute mlirWaveSymbolMappingAttrGetKey(MlirAttribute attr, + intptr_t index) { + return wrap( + llvm::cast(unwrap(attr)).getKeys()[index]); +} - return wrap(wave::WaveReadWriteBoundsAttr::get(ctx, dictAttr)); +MlirAttribute mlirWaveSymbolMappingAttrGetValue(MlirAttribute attr, + intptr_t index) { + return wrap( + llvm::cast(unwrap(attr)).getValues()[index]); } -MlirAttribute mlirWaveReadWriteBoundsAttrGetMapping(MlirAttribute attr) { +MlirAttribute mlirWaveSymbolMappingAttrLookup(MlirAttribute attr, + MlirAttribute key) { + auto keyAttr = llvm::dyn_cast(unwrap(key)); + if (!keyAttr) + return MlirAttribute(); return wrap( - llvm::cast(unwrap(attr)).getMapping()); + llvm::cast(unwrap(attr)).lookup(keyAttr)); } -MlirTypeID mlirWaveReadWriteBoundsAttrGetTypeID() { - return wrap(TypeID::get()); +MlirTypeID mlirWaveSymbolMappingAttrGetTypeID() { + return wrap(TypeID::get()); } //===---------------------------------------------------------------------===// diff --git a/water/lib/Dialect/Wave/IR/WaveAttrs.cpp b/water/lib/Dialect/Wave/IR/WaveAttrs.cpp index 9675308c4..783bd2ec0 100644 --- a/water/lib/Dialect/Wave/IR/WaveAttrs.cpp +++ b/water/lib/Dialect/Wave/IR/WaveAttrs.cpp @@ -709,6 +709,20 @@ DeviceConstraintAttr::verify(function_ref emitError, // WaveSymbolMappingAttr //===----------------------------------------------------------------------===// +WaveSymbolMappingAttr WaveSymbolMappingAttr::get( + MLIRContext *context, + ArrayRef> entries) { + SmallVector keys; + SmallVector values; + keys.reserve(entries.size()); + values.reserve(entries.size()); + for (auto &[k, v] : entries) { + keys.push_back(k); + values.push_back(v); + } + return Base::get(context, keys, values); +} + Attribute WaveSymbolMappingAttr::parse(AsmParser &parser, Type) { // Capture the location before consuming any tokens so that verification // errors are reported at the opening `<`. diff --git a/water/lib/Dialect/Wave/IR/WaveOps.cpp b/water/lib/Dialect/Wave/IR/WaveOps.cpp index 5b3b08cdf..e8908b575 100644 --- a/water/lib/Dialect/Wave/IR/WaveOps.cpp +++ b/water/lib/Dialect/Wave/IR/WaveOps.cpp @@ -28,7 +28,6 @@ #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVectorExtras.h" -#include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; @@ -1595,33 +1594,20 @@ verifyIndexElementsPerThread(Operation *op, ArrayAttr indexAttr, // will not generate mask operations during lowering. static LogicalResult verifyReadWriteBounds(Location loc, wave::WaveTensorType boundedType, - DictionaryAttr bounds) { + WaveSymbolMappingAttr bounds) { assert(bounds && "expected non-null bounds"); assert(boundedType && "expected non-null type"); - // We need a fixed iteration order of names for determinism of error messages, - // so using a vector instead of a StringSet. - // TODO: consider refactoring bounds and other dictionary-like attributes to - // be indexed by symbol expressions rather than string attributes to avoid - // string comparisons everywhere. - SmallVector validSymbolNames = llvm::map_to_vector( - boundedType.getShape(), - [](wave::WaveSymbolAttr symbol) { return symbol.getName(); }); - - for (NamedAttribute value : bounds) { - if (!llvm::is_contained(validSymbolNames, value.getName().strref())) { + ArrayRef validSymbols = boundedType.getShape(); + + for (auto [key, value] : llvm::zip(bounds.getKeys(), bounds.getValues())) { + if (!llvm::is_contained(validSymbols, key)) { return emitError(loc) - << "'bounds' specified for a symbol " << value.getName() + << "'bounds' specified for a symbol " << key.getName() << " not used in the " "indexed memory tensor"; } - - // Value type must be WaveExprListAttr. - auto exprListAttr = dyn_cast(value.getValue()); - if (!exprListAttr) - return emitError(loc) << "'bounds' values must be WaveExprListAttr, got " - << value.getValue(); - if (exprListAttr.getRank() != 1) { + if (value.getRank() != 1) { return emitError(loc) << "'bounds' must only contain single-result expressions"; } @@ -1634,7 +1620,7 @@ static LogicalResult verifyReadWriteBounds(Location loc, static LogicalResult verifyReadWriteOp(Operation *op, ArrayAttr indexAttr, std::optional elementsPerThread, Type memoryType, Type valueType, - WaveReadWriteBoundsAttr bounds, + WaveSymbolMappingAttr bounds, ArrayAttr orderedSyms) { if (failed(wave::detail::verifyElementTypesMatch( @@ -1675,7 +1661,7 @@ static LogicalResult verifyReadWriteOp(Operation *op, ArrayAttr indexAttr, if (!bounds) return success(); - return verifyReadWriteBounds(op->getLoc(), tensorType, bounds.getMapping()); + return verifyReadWriteBounds(op->getLoc(), tensorType, bounds); } LogicalResult ReadOp::verify() { @@ -2422,11 +2408,12 @@ permuteIndexExprsStrides(const IndexExprsLatticeStorage &inputLattice, DictionaryAttr inputDict = inputLattice.getConcreteValue(); - llvm::StringMap symbolToMapping; + llvm::DenseMap symbolToMapping; for (NamedAttribute namedAttr : inputDict) { if (auto mapping = llvm::dyn_cast(namedAttr.getValue())) { - symbolToMapping[namedAttr.getName().getValue()] = mapping; + auto key = WaveSymbolAttr::get(ctx, namedAttr.getName()); + symbolToMapping[key] = mapping; } } @@ -2438,11 +2425,8 @@ permuteIndexExprsStrides(const IndexExprsLatticeStorage &inputLattice, permutedMappings.reserve(srcShape.size()); for (auto [srcSymbol, targetSymbol] : llvm::zip_equal(srcShape, targetShape)) { - llvm::StringRef srcName = srcSymbol.getName(); - auto srcMappingIt = symbolToMapping.find(srcName); - - llvm::StringRef targetName = targetSymbol.getName(); - auto targetMappingIt = symbolToMapping.find(targetName); + auto srcMappingIt = symbolToMapping.find(srcSymbol); + auto targetMappingIt = symbolToMapping.find(targetSymbol); assert(srcMappingIt != symbolToMapping.end() && "source mapping not found for symbol"); @@ -2469,7 +2453,7 @@ permuteIndexExprsStrides(const IndexExprsLatticeStorage &inputLattice, alignedStep, alignedStride); permutedMappings.push_back( - NamedAttribute(StringAttr::get(ctx, srcName), newMapping)); + NamedAttribute(StringAttr::get(ctx, srcSymbol.getName()), newMapping)); } return IndexExprsLatticeStorage(DictionaryAttr::get(ctx, permutedMappings)); diff --git a/water/lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp b/water/lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp index 2a0397c33..376166f97 100644 --- a/water/lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp +++ b/water/lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp @@ -91,14 +91,14 @@ buildStartIndices(Location loc, DictionaryAttr indexDict, /// bound_d(elements_per_thread)) /// foreach d in dimensions. /// -/// whenever a bounds dictionary is provided. When it is not provided, return a +/// whenever a bounds mapping is provided. When it is not provided, return a /// null mask. If the vectorized dimension cannot be identified, return failure. static FailureOr -buildMask(Location loc, wave::WaveReadWriteBoundsAttr boundsDict, +buildMask(Location loc, wave::WaveSymbolMappingAttr boundsMapping, ArrayRef orderedSyms, PatternRewriter &rewriter, DictionaryAttr indexDict, wave::WaveHyperparameterAttr hyper, ArrayRef startIdx, int64_t elementsPerThread) { - if (!boundsDict) + if (!boundsMapping) return Value(); const uint64_t rank = startIdx.size(); @@ -112,11 +112,9 @@ buildMask(Location loc, wave::WaveReadWriteBoundsAttr boundsDict, // without an entry are fully in-bounds and are skipped. Value finalMask; for (uint64_t d = 0; d < rank; ++d) { - StringRef name = orderedSyms[d].getName(); - Attribute a = boundsDict.getMapping().get(name); - if (!a) + wave::WaveExprListAttr boundAttr = boundsMapping.lookup(orderedSyms[d]); + if (!boundAttr) continue; - auto boundAttr = cast(a); // Materialize bounds. FailureOr> boundValsFo = wave::materializeAffine( loc, boundAttr.getSymbols(), boundAttr.getMap(), rewriter, hyper); @@ -312,7 +310,7 @@ createMemoryIndicesAndMask(ConversionPatternRewriter &rewriter, Type memoryTypeArg, VectorType vectorType) { int64_t elementsPerThread = vectorType.getNumElements(); - wave::WaveReadWriteBoundsAttr boundsDict = op.getBoundsAttr(); + wave::WaveSymbolMappingAttr boundsMapping = op.getBoundsAttr(); wave::WaveHyperparameterAttr hyper = static_cast(*typeConverter) .getHyperparameters(); @@ -366,7 +364,7 @@ createMemoryIndicesAndMask(ConversionPatternRewriter &rewriter, SmallVector startIndices = std::move(*maybeStartIndices); FailureOr mask = - buildMask(op->getLoc(), boundsDict, orderedSyms, rewriter, indexDict, + buildMask(op->getLoc(), boundsMapping, orderedSyms, rewriter, indexDict, hyper, startIndices, elementsPerThread); if (failed(mask)) return rewriter.notifyMatchFailure(op, "couldn't build the required mask"); diff --git a/water/python/WaterExtensionNanobind.cpp b/water/python/WaterExtensionNanobind.cpp index acffa2b5a..7055e18ad 100644 --- a/water/python/WaterExtensionNanobind.cpp +++ b/water/python/WaterExtensionNanobind.cpp @@ -8,6 +8,7 @@ #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/IR.h" +#include "mlir-c/Support.h" #include "mlir/Bindings/Python/IRCore.h" #include "mlir/Bindings/Python/Nanobind.h" #include "water/c/Dialects.h" @@ -664,17 +665,17 @@ struct PyWaveExprListAttr }; //===---------------------------------------------------------------------===// -// WaveReadWriteBoundsAttr +// WaveSymbolMappingAttr //===---------------------------------------------------------------------===// -struct PyWaveReadWriteBoundsAttr +struct PyWaveSymbolMappingAttr : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteAttribute< - PyWaveReadWriteBoundsAttr> { + PyWaveSymbolMappingAttr> { static constexpr IsAFunctionTy isaFunction = - mlirAttributeIsAWaveReadWriteBoundsAttr; + mlirAttributeIsAWaveSymbolMappingAttr; static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirWaveReadWriteBoundsAttrGetTypeID; - static constexpr const char *pyClassName = "WaveReadWriteBoundsAttr"; + mlirWaveSymbolMappingAttrGetTypeID; + static constexpr const char *pyClassName = "WaveSymbolMappingAttr"; using PyConcreteAttribute::PyConcreteAttribute; static void bindDerived(ClassTy &c) { @@ -683,47 +684,87 @@ struct PyWaveReadWriteBoundsAttr [](const nb::dict &symDimDict, mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext context) { - std::vector namedAttrs; - namedAttrs.reserve(symDimDict.size()); + std::vector keys; + std::vector values; + keys.reserve(symDimDict.size()); + values.reserve(symDimDict.size()); for (auto [key, value] : symDimDict) { - // Get the key (symbolic dimension) - nb::handle key_handle = key; - if (!nb::isinstance(key_handle)) { - throw nb::type_error( - "Symbolic dimension dictionary key must be a string"); + nb::handle keyHandle = key; + MlirAttribute keyAttr; + if (nb::isinstance(keyHandle)) { + std::string symbolicDim = nb::cast(keyHandle); + keyAttr = mlirWaveSymbolAttrGet( + context->get(), + mlirStringRefCreate(symbolicDim.data(), symbolicDim.size())); + } else { + try { + keyAttr = nb::cast(keyHandle); + } catch (const nb::cast_error &e) { + throw nb::type_error("Symbolic dimension dictionary key must " + "be a string or a WaveSymbolAttr"); + } } - std::string symbolicDim = nb::cast(key_handle); - // Get the value (bound expression) - MlirAttribute attr; + MlirAttribute valueAttr; try { - attr = nb::cast(value); + valueAttr = nb::cast(value); } catch (const nb::cast_error &e) { throw nb::type_error( "Symbolic dimension dictionary value must be an attribute"); } - if (!mlirAttributeIsAWaveExprListAttr(attr)) { + if (!mlirAttributeIsAWaveExprListAttr(valueAttr)) { throw nb::type_error("Symbolic dimension dictionary value must " "be a WaveExprListAttr"); } - namedAttrs.push_back(mlirNamedAttributeGet( - mlirIdentifierGet(context->get(), - mlirStringRefCreate(symbolicDim.data(), - symbolicDim.size())), - attr)); + keys.push_back(keyAttr); + values.push_back(valueAttr); } - return PyWaveReadWriteBoundsAttr( + return PyWaveSymbolMappingAttr( context->getRef(), - mlirWaveReadWriteBoundsAttrGet(mlirDictionaryAttrGet( - context->get(), namedAttrs.size(), namedAttrs.data()))); + mlirWaveSymbolMappingAttrGet(context->get(), keys.size(), + keys.data(), values.data())); }, nb::arg("sym_dim_dict"), nb::arg("context") = nb::none(), - "Gets a wave.WaveReadWriteBoundsAttr from parameters."); - c.def_prop_ro("mapping", [](MlirAttribute self) { - return mlirWaveReadWriteBoundsAttrGetMapping(self); + "Gets a wave.WaveSymbolMappingAttr from parameters."); + c.def("__contains__", [](MlirAttribute self, MlirAttribute key) { + return !mlirAttributeIsNull(mlirWaveSymbolMappingAttrLookup(self, key)); + }); + c.def("__contains__", [](MlirAttribute self, std::string key) { + MlirAttribute keyAttr = + mlirWaveSymbolAttrGet(mlirAttributeGetContext(self), + mlirStringRefCreate(key.data(), key.size())); + return !mlirAttributeIsNull( + mlirWaveSymbolMappingAttrLookup(self, keyAttr)); + }); + c.def("__len__", [](MlirAttribute self) -> intptr_t { + return mlirWaveSymbolMappingAttrGetNumEntries(self); + }); + c.def("__getitem__", [](MlirAttribute self, MlirAttribute key) { + MlirAttribute value = mlirWaveSymbolMappingAttrLookup(self, key); + if (mlirAttributeIsNull(value)) { + throw nb::key_error("Key not found."); + } + return value; + }); + c.def("__getitem__", [](MlirAttribute self, std::string key) { + MlirAttribute keyAttr = + mlirWaveSymbolAttrGet(mlirAttributeGetContext(self), + mlirStringRefCreate(key.data(), key.size())); + MlirAttribute value = mlirWaveSymbolMappingAttrLookup(self, keyAttr); + if (mlirAttributeIsNull(value)) { + throw nb::key_error("Key not found."); + } + return value; + }); + c.def("__getitem__", [](MlirAttribute self, intptr_t index) { + if (index < 0 || index >= mlirWaveSymbolMappingAttrGetNumEntries(self)) { + throw nb::index_error("Index out of range."); + } + return nb::make_tuple(mlirWaveSymbolMappingAttrGetKey(self, index), + mlirWaveSymbolMappingAttrGetValue(self, index)); }); } }; @@ -1091,7 +1132,7 @@ NB_MODULE(_waterDialects, m) { PyWaveApplyExprCombinatorAttr::bind(d); PyWaveMmaKindAttr::bind(d); PyWaveExprListAttr::bind(d); - PyWaveReadWriteBoundsAttr::bind(d); + PyWaveSymbolMappingAttr::bind(d); PyWaveTensorType::bind(d); PyHardwareConstraintAttr::bind(d); PyDeviceConstraintAttr::bind(d); diff --git a/water/test/Dialect/Wave/lower-wave-to-mlir.mlir b/water/test/Dialect/Wave/lower-wave-to-mlir.mlir index 61308ce4c..bd06a1e89 100644 --- a/water/test/Dialect/Wave/lower-wave-to-mlir.mlir +++ b/water/test/Dialect/Wave/lower-wave-to-mlir.mlir @@ -613,9 +613,9 @@ normalform.module [#wave.normal_form (s0 * 64 + s1 * 32)>()[%[[BIDX_Y]], %[[TIDX_Y]]] N : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N + T1 * 32, 4, 1)> - }] { bounds = #wave.read_write_bounds<{ - M = #wave.expr_list<[#wave.symbol<"M">] -> (M)>, - N = #wave.expr_list<[#wave.symbol<"N">] -> (N)>}>} + }] { bounds = #wave.symbol_mapping< + @M = #wave.expr_list<[#wave.symbol<"M">] -> (M)>, + @N = #wave.expr_list<[#wave.symbol<"N">] -> (N)>>} : (!wave.tensor<[@M, @N] of f16, >) -> vector<4xf16> // Bounds for dim 0. // CHECK: %[[DIM0_SIZE:.+]] = affine.apply affine_map<() -> (100)>() @@ -648,9 +648,9 @@ normalform.module [#wave.normal_form, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0, 8, 64)>, N : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N + T1 * 32, 1, 1)> - }] { bounds = #wave.read_write_bounds<{ - M = #wave.expr_list<[#wave.symbol<"M">] -> (M)>, - N = #wave.expr_list<[#wave.symbol<"N">] -> (N)>}>} + }] { bounds = #wave.symbol_mapping< + @M = #wave.expr_list<[#wave.symbol<"M">] -> (M)>, + @N = #wave.expr_list<[#wave.symbol<"N">] -> (N)>>} : (!wave.tensor<[@M, @N] of f16, >) -> vector<8xf16> // CHECK: %[[MASK:.+]] = arith.andi {{.*}}, {{.*}} // CHECK: %[[PAD:.*]] = arith.constant {{.*}} : f16 @@ -673,8 +673,8 @@ normalform.module [#wave.normal_form (s0 * 64 + s1)>()[%[[BIDX_X]], %[[TIDX_X]]] M : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0, 1, 64)>, N : <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N + T1 * 32, 4, 1)> - }] { bounds = #wave.read_write_bounds<{ - M = #wave.expr_list<[#wave.symbol<"M">] -> (M)>}>} + }] { bounds = #wave.symbol_mapping< + @M = #wave.expr_list<[#wave.symbol<"M">] -> (M)>>} : (!wave.tensor<[@M, @N] of f16, >) -> vector<4xf16> // Only M dimension produces a mask — no andi needed. // CHECK: %[[DIM0_SIZE:.+]] = affine.apply affine_map<() -> (100)>() diff --git a/water/test/Dialect/Wave/ops-invalid.mlir b/water/test/Dialect/Wave/ops-invalid.mlir index 4cb7735b5..9140e3b17 100644 --- a/water/test/Dialect/Wave/ops-invalid.mlir +++ b/water/test/Dialect/Wave/ops-invalid.mlir @@ -474,16 +474,8 @@ normalform.module [#wave.normal_form] { // ----- func.func @bounds_extraneous_dim(%mem: !wave.tensor<[@N] of f32>, %val: !wave.tensor<[@N] of f32, >) { - // expected-error @below {{'bounds' specified for a symbol "M" not used in the indexed memory tensor}} - wave.write %val, %mem { bounds = #wave.read_write_bounds<{ M = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M * 64)>}> } : !wave.tensor<[@N] of f32, >, !wave.tensor<[@N] of f32> - return -} - -// ----- - -func.func @bounds_wrong_type(%mem: !wave.tensor<[@N] of f32>) { - // expected-error @below {{'bounds' values must be WaveExprListAttr, got 42 : i64}} - wave.read %mem { bounds = #wave.read_write_bounds<{ N = 42 }> } : (!wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32, > + // expected-error @below {{'bounds' specified for a symbol M not used in the indexed memory tensor}} + wave.write %val, %mem { bounds = #wave.symbol_mapping<@M = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M * 64)>> } : !wave.tensor<[@N] of f32, >, !wave.tensor<[@N] of f32> return } @@ -491,7 +483,7 @@ func.func @bounds_wrong_type(%mem: !wave.tensor<[@N] of f32>) { func.func @bounds_wrong_rank(%mem: !wave.tensor<[@N] of f32>) { // expected-error @below {{'bounds' must only contain single-result expressions}} - wave.read %mem { bounds = #wave.read_write_bounds<{ N = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M * 64, BLOCK_M * 64)>}> } : (!wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32, > + wave.read %mem { bounds = #wave.symbol_mapping<@N = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M * 64, BLOCK_M * 64)>> } : (!wave.tensor<[@N] of f32>) -> !wave.tensor<[@N] of f32, > return } diff --git a/water/test/Dialect/Wave/ops.mlir b/water/test/Dialect/Wave/ops.mlir index 2fca86a74..b2a01309b 100644 --- a/water/test/Dialect/Wave/ops.mlir +++ b/water/test/Dialect/Wave/ops.mlir @@ -384,16 +384,16 @@ attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 32, BLOCK_N // CHECK-LABEL: @write_with_bounds func.func @write_with_bounds(%memo: !wave.tensor<[@M] of f32>, %val: !wave.tensor<[@M] of f32, >) { - // CHECK: wave.read_write_bounds - wave.write %val, %memo { bounds = #wave.read_write_bounds<{ M = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M * 64)>}> } : !wave.tensor<[@M] of f32, >, !wave.tensor<[@M] of f32> + // CHECK: wave.symbol_mapping + wave.write %val, %memo { bounds = #wave.symbol_mapping<@M = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M * 64)>> } : !wave.tensor<[@M] of f32, >, !wave.tensor<[@M] of f32> return } // Sparse bounds: only M needs masking. // CHECK-LABEL: @write_with_sparse_bounds func.func @write_with_sparse_bounds(%mem: !wave.tensor<[@M, @N] of f32>, %val: !wave.tensor<[@M, @N] of f32, >) { - // CHECK: wave.read_write_bounds - wave.write %val, %mem { bounds = #wave.read_write_bounds<{ M = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M * 64)>}> } : !wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32> + // CHECK: wave.symbol_mapping + wave.write %val, %mem { bounds = #wave.symbol_mapping<@M = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M * 64)>> } : !wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32> return } diff --git a/water/test/Dialect/Wave/python_bindings.py b/water/test/Dialect/Wave/python_bindings.py index 274a2cbf2..dfcba681b 100644 --- a/water/test/Dialect/Wave/python_bindings.py +++ b/water/test/Dialect/Wave/python_bindings.py @@ -217,25 +217,78 @@ else: assert False, "Expected to fail with ValueError." - # CHECK: #wave.read_write_bounds<{M = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">, #wave.index_symbol] -> (WG0 * 3)>}> - print(wave.WaveReadWriteBoundsAttr.get({"M": expr_attr})) + # CHECK: #wave.symbol_mapping<@M = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">, #wave.index_symbol] -> (WG0 * 3)>> + mapping_attr = wave.WaveSymbolMappingAttr.get({"M": expr_attr}) + print(mapping_attr) + assert len(mapping_attr) == 1 + assert mapping_attr["M"] == expr_attr try: - wave.WaveReadWriteBoundsAttr.get({3: expr_attr}) + mapping_attr["nyan"] + except KeyError as e: + assert "Key not found." in str(e) + else: + assert False, "Expected to fail with KeyError." + + mapping_attr_2 = wave.WaveSymbolMappingAttr.get( + {wave.WaveSymbolAttr.get("M"): expr_attr} + ) + print(mapping_attr_2) + assert len(mapping_attr_2) == 1 + assert mapping_attr_2[wave.WaveSymbolAttr.get("M")] == expr_attr + + expr_attr_2 = wave.WaveExprListAttr.get( + [wave.WaveSymbolAttr.get("A")], + ir.AffineMap.get(0, 1, [ir.AffineExpr.get_constant(1)]), + ) + mapping_attr_ordered = wave.WaveSymbolMappingAttr.get( + {"M": expr_attr, "A": expr_attr_2} + ) + assert mapping_attr_ordered[0][0] == wave.WaveSymbolAttr.get("M") + assert mapping_attr_ordered[0][1] == expr_attr + assert mapping_attr_ordered[1][0] == wave.WaveSymbolAttr.get("A") + assert mapping_attr_ordered[1][1] == expr_attr_2 + assert "M" in mapping_attr_ordered + assert "N" not in mapping_attr_ordered + assert wave.WaveSymbolAttr.get("A") in mapping_attr_ordered + + try: + mapping_attr_ordered[42] + except IndexError as e: + assert "Index out of range." in str(e) + else: + assert False, "Expected to fail with IndexError." + + try: + mapping_attr_ordered["N"] + except KeyError as e: + assert "Key not found." in str(e) + else: + assert False, "Expected to fail with KeyError." + + try: + mapping_attr_ordered[wave.WaveSymbolAttr.get("B")] + except KeyError as e: + assert "Key not found." in str(e) + else: + assert False, "Expected to fail with KeyError." + + try: + wave.WaveSymbolMappingAttr.get({3: expr_attr}) except TypeError as e: assert "must be a string" in str(e) else: assert False, "Expected to fail with TypeError." try: - wave.WaveReadWriteBoundsAttr.get({"A": 1.0}) + wave.WaveSymbolMappingAttr.get({"A": 1.0}) except TypeError as e: assert "must be an attribute" in str(e) else: assert False, "Expected to fail with TypeError." try: - wave.WaveReadWriteBoundsAttr.get({"A": addr_attr}) + wave.WaveSymbolMappingAttr.get({"A": addr_attr}) except TypeError as e: assert "must be a WaveExprListAttr" in str(e) else: diff --git a/wave_lang/kernel/wave/mlir_converter/fx_emitter.py b/wave_lang/kernel/wave/mlir_converter/fx_emitter.py index 3641948b4..dcfa82e40 100644 --- a/wave_lang/kernel/wave/mlir_converter/fx_emitter.py +++ b/wave_lang/kernel/wave/mlir_converter/fx_emitter.py @@ -40,7 +40,7 @@ YieldOp, WaveAddressSpaceAttr, WaveMmaKindAttr, - WaveReadWriteBoundsAttr, + WaveSymbolMappingAttr, WaveWorkgroupDimAttr, WaveTensorType, iterate_make_isolated, @@ -200,7 +200,7 @@ def _convert_mma_kind(attr: WaveMmaKindAttr) -> MMAType | ScaledMMAType: def _convert_read_write_bounds( - attr: WaveReadWriteBoundsAttr, + attr: WaveSymbolMappingAttr, ) -> dict[IndexSymbol, IndexExpr]: """ Converts Wave read/write bounds attribute into a dictionary mapping dimensions to expressions. @@ -208,11 +208,9 @@ def _convert_read_write_bounds( Bounds specify the iteration space for memory operations (read/write) along each dimension. """ bounds: dict[IndexSymbol, IndexExpr] = {} - for named in attr.mapping: - key = named.name - value = named.attr + for key, value in attr: exprs = expr_list_attr_to_exprs(value) - bounds[index_symbol(key)] = exprs[0] + bounds[index_symbol(key.name)] = exprs[0] return bounds diff --git a/wave_lang/kernel/wave/mlir_converter/water_emitter.py b/wave_lang/kernel/wave/mlir_converter/water_emitter.py index 69acb7334..fe36cf186 100644 --- a/wave_lang/kernel/wave/mlir_converter/water_emitter.py +++ b/wave_lang/kernel/wave/mlir_converter/water_emitter.py @@ -443,7 +443,7 @@ def _attach_attributes( symbol_name_to_attribute(sym.name) for sym in symbol_mapping.values() ] bounds[dim.name] = wave.WaveExprListAttr.get(symbol_attrs, result) - op.attributes["bounds"] = wave.WaveReadWriteBoundsAttr.get(bounds) + op.attributes["bounds"] = wave.WaveSymbolMappingAttr.get(bounds) if water_id := getattr(node.fx_node, "_water_id", None): op.attributes[_INTERNAL_WATER_ID_ATTR_NAME] = ir.StringAttr.get(water_id)