Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions lit_tests/kernel/wave/mlir_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,9 @@ def mlir_converter_matrix_add():
# CHECK-SAME: M : <[{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, 1, 64)>
# CHECK-SAME: N : <[{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, BLOCK_N ceildiv 2, 1)>
# CHECK-SAME: bounds
# CHECK-SAME: #wave.read_write_bounds
# CHECK-SAME: M = #wave.expr_list
# CHECK-SAME: N = #wave.expr_list
# CHECK-SAME: #wave.symbol_mapping
# CHECK-SAME: @M = #wave.expr_list
# CHECK-SAME: @N = #wave.expr_list
# CHECK-SAME: elements_per_thread = 32 : i64
# CHECK-SAME: (!wave.tensor<[@M, @N] of f16, <global>>) -> !wave.tensor<[@M, @N] of f16, <register>>

Expand All @@ -309,9 +309,9 @@ def mlir_converter_matrix_add():
# CHECK-SAME: M : <[{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, 1, 64)>
# CHECK-SAME: N : <[{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, BLOCK_N ceildiv 2, 1)>
# CHECK-SAME: bounds
# CHECK-SAME: #wave.read_write_bounds
# CHECK-SAME: M = #wave.expr_list
# CHECK-SAME: N = #wave.expr_list
# CHECK-SAME: #wave.symbol_mapping
# CHECK-SAME: @M = #wave.expr_list
# CHECK-SAME: @N = #wave.expr_list
# CHECK-SAME: elements_per_thread = 32 : i64
# CHECK-SAME: (!wave.tensor<[@M, @N] of f16, <global>>) -> !wave.tensor<[@M, @N] of f16, <register>>

Expand All @@ -332,9 +332,9 @@ def mlir_converter_matrix_add():
# CHECK-SAME: M : <[{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, 1, 64)>
# CHECK-SAME: N : <[{{.*}}, {{.*}}, {{.*}}] -> ({{.*}}, BLOCK_N ceildiv 2, 1)>
# CHECK-SAME: bounds
# CHECK-SAME: #wave.read_write_bounds
# CHECK-SAME: M = #wave.expr_list
# CHECK-SAME: N = #wave.expr_list
# CHECK-SAME: #wave.symbol_mapping
# CHECK-SAME: @M = #wave.expr_list
# CHECK-SAME: @N = #wave.expr_list
# CHECK-SAME: elements_per_thread = 32 : i64
# CHECK-SAME: !wave.tensor<[@M, @N] of f32, <register>>, !wave.tensor<[@M, @N] of f32, <global>>

Expand Down
38 changes: 14 additions & 24 deletions water/include/water/Dialect/Wave/IR/WaveAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -595,30 +595,6 @@ class WaveExprListAttrOf<list<AttrDef> inputTypes> : Attr<
let returnType = WaveExprListAttr.returnType;
}


def WaveReadWriteBoundsAttr : AttrDef<WaveDialect, "WaveReadWriteBounds"> {
let mnemonic = "read_write_bounds";
let description = [{
This attribute contains a dictionary mapping from symbolic dimensions
(strings) to a WaveExprListAttr specifying the bounds of the read/write
operations this is attached to. The dictionary may be sparse: only
dimensions that require masking need an entry.
}];

let parameters = (ins "::mlir::DictionaryAttr":$mapping);
let assemblyFormat = "`<` $mapping `>`";

let extraClassDeclaration = [{
/// Check if a symbolic dimension exists in the mapping
bool hasSymbol(::llvm::StringRef symbolicDim) const {
return getMapping().get(symbolicDim) != nullptr;
}

/// Get the number of symbols in the mapping
size_t getNumSymbols() const { return getMapping().size(); }
}];
}

//-----------------------------------------------------------------------------
// Symbol mapping attribute
//-----------------------------------------------------------------------------
Expand Down Expand Up @@ -647,6 +623,12 @@ def WaveSymbolMappingAttr : AttrDef<WaveDialect, "WaveSymbolMapping"> {
let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;

let builders = [
AttrBuilder<(ins
"::llvm::ArrayRef<std::pair<::wave::WaveSymbolAttr, ::wave::WaveExprListAttr>>"
:$entries)>
];

let extraClassDeclaration = [{
/// Verify that every value `WaveExprListAttr` in the mapping has exactly
/// `numResults` results. Returns failure with a diagnostic on the first
Expand All @@ -658,6 +640,14 @@ def WaveSymbolMappingAttr : AttrDef<WaveDialect, "WaveSymbolMapping"> {
/// Look up the `WaveExprListAttr` associated with `key`.
/// Returns a default-constructed (null) attribute if the key is absent.
::wave::WaveExprListAttr lookup(::wave::WaveSymbolAttr key) const;

/// Check if a symbolic dimension exists in the mapping.
bool hasSymbol(::wave::WaveSymbolAttr key) const {
return static_cast<bool>(lookup(key));
}

/// Get the number of entries in the mapping.
size_t getNumEntries() const { return getKeys().size(); }
}];
}

Expand Down
4 changes: 2 additions & 2 deletions water/include/water/Dialect/Wave/IR/WaveOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def ReadOp : WaveOp<"read", [
Arg<WaveMemoryType, "Memory to read from">:$memory,
Arg<OptionalAttr<I64Attr>,
"Number of elements processed by each thread">:$elements_per_thread,
Arg<OptionalAttr<WaveReadWriteBoundsAttr>,
Arg<OptionalAttr<WaveSymbolMappingAttr>,
"Bound expressions for each symbolic dimension">:$bounds,
Arg<OptionalAttr<WaveSymbolArrayAttr>,
"Ordered dimension symbols from memory type shape">:$ordered_syms
Expand Down Expand Up @@ -517,7 +517,7 @@ def WriteOp : WaveOp<"write", [
Arg<WaveMemoryType, "Memory to write into">:$memory,
Arg<OptionalAttr<I64Attr>,
"Number of elements processed by each thread">:$elements_per_thread,
Arg<OptionalAttr<WaveReadWriteBoundsAttr>,
Arg<OptionalAttr<WaveSymbolMappingAttr>,
"Bound expressions for each symbolic dimension">:$bounds,
Arg<OptionalAttr<WaveSymbolArrayAttr>,
"Ordered dimension symbols from memory type shape">:$ordered_syms
Expand Down
33 changes: 23 additions & 10 deletions water/include/water/c/Dialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -432,24 +432,37 @@ MLIR_CAPI_EXPORTED MlirAttribute
mlirWaveExprListAttrGetSymbol(MlirAttribute attr, intptr_t index);

//===---------------------------------------------------------------------===//
// WaveReadWriteBoundsAttr
// WaveSymbolMappingAttr
//===---------------------------------------------------------------------===//

/// Checks whether the given MLIR attribute is a WaveReadWriteBoundsAttr.
/// Checks whether the given MLIR attribute is a WaveSymbolMappingAttr.
MLIR_CAPI_EXPORTED bool
mlirAttributeIsAWaveReadWriteBoundsAttr(MlirAttribute attr);
mlirAttributeIsAWaveSymbolMappingAttr(MlirAttribute attr);

/// Creates a new WaveReadWriteBoundsAttr with the given mapping from symbolic
/// dimensions to their bound expressions.
/// Creates a new WaveSymbolMappingAttr from parallel arrays of keys (symbol
/// attrs) and values (expr list attrs).
MLIR_CAPI_EXPORTED MlirAttribute
mlirWaveReadWriteBoundsAttrGet(MlirAttribute mapping);
mlirWaveSymbolMappingAttrGet(MlirContext ctx, intptr_t numEntries,
MlirAttribute *keys, MlirAttribute *values);

/// Gets the underlying dictionary mapping from a WaveReadWriteBoundsAttr.
/// Returns the number of entries in a WaveSymbolMappingAttr.
MLIR_CAPI_EXPORTED intptr_t
mlirWaveSymbolMappingAttrGetNumEntries(MlirAttribute attr);

/// Returns the key at the given index.
MLIR_CAPI_EXPORTED MlirAttribute
mlirWaveSymbolMappingAttrGetKey(MlirAttribute attr, intptr_t index);

/// Returns the value at the given index.
MLIR_CAPI_EXPORTED MlirAttribute
mlirWaveSymbolMappingAttrGetValue(MlirAttribute attr, intptr_t index);

/// Returns the value for the given key or null if the key is not present.
MLIR_CAPI_EXPORTED MlirAttribute
mlirWaveReadWriteBoundsAttrGetMapping(MlirAttribute attr);
mlirWaveSymbolMappingAttrLookup(MlirAttribute attr, MlirAttribute key);

/// Returns the typeID of a WaveReadWriteBoundsAttr.
MLIR_CAPI_EXPORTED MlirTypeID mlirWaveReadWriteBoundsAttrGetTypeID();
/// Returns the typeID of a WaveSymbolMappingAttr.
MLIR_CAPI_EXPORTED MlirTypeID mlirWaveSymbolMappingAttrGetTypeID();

//===---------------------------------------------------------------------===//
// HardwareConstraintAttr
Expand Down
54 changes: 37 additions & 17 deletions water/lib/CAPI/Dialects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -455,35 +455,55 @@ MlirAttribute mlirWaveExprListAttrGetSymbol(MlirAttribute attr,
llvm::cast<wave::WaveExprListAttr>(unwrap(attr)).getSymbols()[index]);
}
//===---------------------------------------------------------------------===//
// WaveReadWriteBoundsAttr
// WaveSymbolMappingAttr
//===---------------------------------------------------------------------===//

bool mlirAttributeIsAWaveReadWriteBoundsAttr(MlirAttribute attr) {
return llvm::isa<wave::WaveReadWriteBoundsAttr>(unwrap(attr));
bool mlirAttributeIsAWaveSymbolMappingAttr(MlirAttribute attr) {
return llvm::isa<wave::WaveSymbolMappingAttr>(unwrap(attr));
}

MlirAttribute mlirWaveReadWriteBoundsAttrGet(MlirAttribute mapping) {
auto dictAttr = llvm::cast<DictionaryAttr>(unwrap(mapping));
MlirAttribute mlirWaveSymbolMappingAttrGet(MlirContext ctx, intptr_t numEntries,
MlirAttribute *keys,
MlirAttribute *values) {
SmallVector<wave::WaveSymbolAttr> keyAttrs;
SmallVector<wave::WaveExprListAttr> valueAttrs;
keyAttrs.reserve(numEntries);
valueAttrs.reserve(numEntries);
for (intptr_t i = 0; i < numEntries; ++i) {
keyAttrs.push_back(llvm::cast<wave::WaveSymbolAttr>(unwrap(keys[i])));
valueAttrs.push_back(llvm::cast<wave::WaveExprListAttr>(unwrap(values[i])));
}
return wrap(
wave::WaveSymbolMappingAttr::get(unwrap(ctx), keyAttrs, valueAttrs));
}

MLIRContext *ctx = dictAttr.getContext();
intptr_t mlirWaveSymbolMappingAttrGetNumEntries(MlirAttribute attr) {
return llvm::cast<wave::WaveSymbolMappingAttr>(unwrap(attr)).getNumEntries();
}

assert(llvm::all_of(dictAttr,
[](const NamedAttribute &namedAttr) {
return llvm::isa<wave::WaveExprListAttr>(
namedAttr.getValue());
}) &&
"expected mapping to contain only WaveExprListAttr values");
MlirAttribute mlirWaveSymbolMappingAttrGetKey(MlirAttribute attr,
intptr_t index) {
return wrap(
llvm::cast<wave::WaveSymbolMappingAttr>(unwrap(attr)).getKeys()[index]);
}

return wrap(wave::WaveReadWriteBoundsAttr::get(ctx, dictAttr));
MlirAttribute mlirWaveSymbolMappingAttrGetValue(MlirAttribute attr,
intptr_t index) {
return wrap(
llvm::cast<wave::WaveSymbolMappingAttr>(unwrap(attr)).getValues()[index]);
}

MlirAttribute mlirWaveReadWriteBoundsAttrGetMapping(MlirAttribute attr) {
MlirAttribute mlirWaveSymbolMappingAttrLookup(MlirAttribute attr,
MlirAttribute key) {
auto keyAttr = llvm::dyn_cast<wave::WaveSymbolAttr>(unwrap(key));
if (!keyAttr)
return MlirAttribute();
return wrap(
llvm::cast<wave::WaveReadWriteBoundsAttr>(unwrap(attr)).getMapping());
llvm::cast<wave::WaveSymbolMappingAttr>(unwrap(attr)).lookup(keyAttr));
}

MlirTypeID mlirWaveReadWriteBoundsAttrGetTypeID() {
return wrap(TypeID::get<wave::WaveReadWriteBoundsAttr>());
MlirTypeID mlirWaveSymbolMappingAttrGetTypeID() {
return wrap(TypeID::get<wave::WaveSymbolMappingAttr>());
}

//===---------------------------------------------------------------------===//
Expand Down
14 changes: 14 additions & 0 deletions water/lib/Dialect/Wave/IR/WaveAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,20 @@ DeviceConstraintAttr::verify(function_ref<InFlightDiagnostic()> emitError,
// WaveSymbolMappingAttr
//===----------------------------------------------------------------------===//

WaveSymbolMappingAttr WaveSymbolMappingAttr::get(
MLIRContext *context,
ArrayRef<std::pair<WaveSymbolAttr, WaveExprListAttr>> entries) {
SmallVector<WaveSymbolAttr> keys;
SmallVector<WaveExprListAttr> values;
keys.reserve(entries.size());
values.reserve(entries.size());
for (auto &[k, v] : entries) {
keys.push_back(k);
values.push_back(v);
}
return Base::get(context, keys, values);
}

Attribute WaveSymbolMappingAttr::parse(AsmParser &parser, Type) {
// Capture the location before consuming any tokens so that verification
// errors are reported at the opening `<`.
Expand Down
46 changes: 15 additions & 31 deletions water/lib/Dialect/Wave/IR/WaveOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
Expand Down Expand Up @@ -1595,33 +1594,20 @@ verifyIndexElementsPerThread(Operation *op, ArrayAttr indexAttr,
// will not generate mask operations during lowering.
static LogicalResult verifyReadWriteBounds(Location loc,
wave::WaveTensorType boundedType,
DictionaryAttr bounds) {
WaveSymbolMappingAttr bounds) {
assert(bounds && "expected non-null bounds");
assert(boundedType && "expected non-null type");

// We need a fixed iteration order of names for determinism of error messages,
// so using a vector instead of a StringSet.
// TODO: consider refactoring bounds and other dictionary-like attributes to
// be indexed by symbol expressions rather than string attributes to avoid
// string comparisons everywhere.
SmallVector<StringRef> validSymbolNames = llvm::map_to_vector(
boundedType.getShape(),
[](wave::WaveSymbolAttr symbol) { return symbol.getName(); });

for (NamedAttribute value : bounds) {
if (!llvm::is_contained(validSymbolNames, value.getName().strref())) {
ArrayRef<wave::WaveSymbolAttr> validSymbols = boundedType.getShape();

for (auto [key, value] : llvm::zip(bounds.getKeys(), bounds.getValues())) {
if (!llvm::is_contained(validSymbols, key)) {
return emitError(loc)
<< "'bounds' specified for a symbol " << value.getName()
<< "'bounds' specified for a symbol " << key.getName()
<< " not used in the "
"indexed memory tensor";
}

// Value type must be WaveExprListAttr.
auto exprListAttr = dyn_cast<wave::WaveExprListAttr>(value.getValue());
if (!exprListAttr)
return emitError(loc) << "'bounds' values must be WaveExprListAttr, got "
<< value.getValue();
if (exprListAttr.getRank() != 1) {
if (value.getRank() != 1) {
return emitError(loc)
<< "'bounds' must only contain single-result expressions";
}
Expand All @@ -1634,7 +1620,7 @@ static LogicalResult verifyReadWriteBounds(Location loc,
static LogicalResult verifyReadWriteOp(Operation *op, ArrayAttr indexAttr,
std::optional<int64_t> elementsPerThread,
Type memoryType, Type valueType,
WaveReadWriteBoundsAttr bounds,
WaveSymbolMappingAttr bounds,
ArrayAttr orderedSyms) {

if (failed(wave::detail::verifyElementTypesMatch(
Expand Down Expand Up @@ -1675,7 +1661,7 @@ static LogicalResult verifyReadWriteOp(Operation *op, ArrayAttr indexAttr,
if (!bounds)
return success();

return verifyReadWriteBounds(op->getLoc(), tensorType, bounds.getMapping());
return verifyReadWriteBounds(op->getLoc(), tensorType, bounds);
}

LogicalResult ReadOp::verify() {
Expand Down Expand Up @@ -2422,11 +2408,12 @@ permuteIndexExprsStrides(const IndexExprsLatticeStorage &inputLattice,

DictionaryAttr inputDict = inputLattice.getConcreteValue();

llvm::StringMap<WaveIndexMappingAttr> symbolToMapping;
llvm::DenseMap<WaveSymbolAttr, WaveIndexMappingAttr> symbolToMapping;
for (NamedAttribute namedAttr : inputDict) {
if (auto mapping =
llvm::dyn_cast<WaveIndexMappingAttr>(namedAttr.getValue())) {
symbolToMapping[namedAttr.getName().getValue()] = mapping;
auto key = WaveSymbolAttr::get(ctx, namedAttr.getName());
symbolToMapping[key] = mapping;
}
}

Expand All @@ -2438,11 +2425,8 @@ permuteIndexExprsStrides(const IndexExprsLatticeStorage &inputLattice,
permutedMappings.reserve(srcShape.size());
for (auto [srcSymbol, targetSymbol] :
llvm::zip_equal(srcShape, targetShape)) {
llvm::StringRef srcName = srcSymbol.getName();
auto srcMappingIt = symbolToMapping.find(srcName);

llvm::StringRef targetName = targetSymbol.getName();
auto targetMappingIt = symbolToMapping.find(targetName);
auto srcMappingIt = symbolToMapping.find(srcSymbol);
auto targetMappingIt = symbolToMapping.find(targetSymbol);

assert(srcMappingIt != symbolToMapping.end() &&
"source mapping not found for symbol");
Expand All @@ -2469,7 +2453,7 @@ permuteIndexExprsStrides(const IndexExprsLatticeStorage &inputLattice,
alignedStep, alignedStride);

permutedMappings.push_back(
NamedAttribute(StringAttr::get(ctx, srcName), newMapping));
NamedAttribute(StringAttr::get(ctx, srcSymbol.getName()), newMapping));
}

return IndexExprsLatticeStorage(DictionaryAttr::get(ctx, permutedMappings));
Expand Down
Loading