Skip to content
Open
13 changes: 13 additions & 0 deletions water/include/water/Dialect/Wave/IR/WaveInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@ verifyTypesMatchingDimensions(std::optional<mlir::Location> loc,
// corresponding flag is set, compatible address spaces.
llvm::LogicalResult verifyCompatibleOperandsAndResultsOpTrait(
mlir::Operation *op, bool includeAddressSpace, bool includeElementalType);

// Verification logic for the equal-element-types trait. Succeeds if all
// operands and results have the same element type.
llvm::LogicalResult verifyEqualElementTypesOpTrait(mlir::Operation *op);
}; // namespace detail

template <typename OpTy>
Expand Down Expand Up @@ -301,6 +305,15 @@ class CompatibleOperandsAndResultsShapeOpTrait
}
};

template <typename OpTy>
class EqualElementTypesOpTrait
: public mlir::OpTrait::TraitBase<OpTy, EqualElementTypesOpTrait> {
public:
static llvm::LogicalResult verifyTrait(mlir::Operation *op) {
return detail::verifyEqualElementTypesOpTrait(op);
}
};

//-----------------------------------------------------------------------------
// WaveElementsPerThreadOpInterface
//-----------------------------------------------------------------------------
Expand Down
5 changes: 5 additions & 0 deletions water/include/water/Dialect/Wave/IR/WaveInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,11 @@ def CompatibleOperandsAndResultsShapeOpTrait
let cppNamespace = "::wave";
}

def EqualElementTypesOpTrait
: NativeOpTrait<"EqualElementTypesOpTrait"> {
let cppNamespace = "::wave";
}

//-----------------------------------------------------------------------------
// WaveInferIndexExprsOpInterface and implementation traits
//-----------------------------------------------------------------------------
Expand Down
9 changes: 6 additions & 3 deletions water/include/water/Dialect/Wave/IR/WaveOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,8 @@ def AllocateOp : WaveOp<"allocate", [
}

def ExtractOp : WaveOp<"extract",
[DeclareOpInterfaceMethods<WaveElementsPerThreadOpInterface>]> {
[DeclareOpInterfaceMethods<WaveElementsPerThreadOpInterface>,
EqualElementTypesOpTrait]> {
let summary = "Extracts a single element from a vector at the given index";
let description = [{
This is an internal operation that appears during expansion/lowering and
Expand Down Expand Up @@ -661,7 +662,8 @@ def CastOp : WaveOp<"cast", [
def BroadcastOp : WaveOp<"broadcast", [
DeclareOpInterfaceMethods<WaveInferTypeOpInterface>,
DeclareOpInterfaceMethods<WaveElementsPerThreadOpInterface>,
WaveInferIndexExprsOpInterface, IdentityIndexExprsOpTrait]> {
WaveInferIndexExprsOpInterface, IdentityIndexExprsOpTrait,
EqualElementTypesOpTrait]> {
let summary = "Broadcast a tensor to a larger shape by replicating values";
let description = [{
Broadcasts the source tensor to the result shape by replicating values
Expand Down Expand Up @@ -740,7 +742,8 @@ def SelfIndexOp : WaveOp<"self_index", [
def PermuteOp : WaveOp<"permute", [
DeclareOpInterfaceMethods<WaveInferTypeOpInterface>,
WaveElementsPerThreadOpInterface, IdentityElementsPerThreadOpTrait,
DeclareOpInterfaceMethods<WaveInferIndexExprsOpInterface>]> {
DeclareOpInterfaceMethods<WaveInferIndexExprsOpInterface>,
EqualElementTypesOpTrait]> {
let summary = "Permute the dimensions of a register-resident tensor";
let description = [{
Reorders the symbolic dimensions of a register-resident tensor according
Expand Down
42 changes: 42 additions & 0 deletions water/lib/Dialect/Wave/IR/WaveInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,48 @@ llvm::LogicalResult wave::detail::verifyCompatibleOperandsAndResultsOpTrait(
kResultNamePrefix, os.str());
}

llvm::LogicalResult
wave::detail::verifyEqualElementTypesOpTrait(Operation *op) {
if (op->getNumOperands() == 0 && op->getNumResults() == 0)
return llvm::success();

// Get the reference type from the first operand if available, otherwise from
// the first result
Type referenceType;
llvm::StringRef referenceName;
bool isOperandReference = false;
if (op->getNumOperands() > 0) {
referenceType = op->getOperandTypes()[0];
referenceName = "operand #0";
isOperandReference = true;
} else {
referenceType = op->getResultTypes()[0];
referenceName = "result #0";
}

// Verify all operand element types match (skip operand #0 if it's the
// reference to avoid redundant check)
unsigned startIdx = isOperandReference ? 1 : 0;
for (unsigned idx = startIdx; idx < op->getNumOperands(); ++idx) {
if (failed(verifyElementTypesMatch(
op->getLoc(), referenceName, referenceType,
"operand #" + llvm::Twine(idx), op->getOperandTypes()[idx])))
return llvm::failure();
}

// Verify all result element types match (skip result #0 if it's the
// reference to avoid redundant check)
startIdx = !isOperandReference ? 1 : 0;
for (unsigned idx = startIdx; idx < op->getNumResults(); ++idx) {
if (failed(verifyElementTypesMatch(
op->getLoc(), referenceName, referenceType,
"result #" + llvm::Twine(idx), op->getResultTypes()[idx])))
return llvm::failure();
}

return llvm::success();
}

//-----------------------------------------------------------------------------
// Lattice implementation
//-----------------------------------------------------------------------------
Expand Down
15 changes: 0 additions & 15 deletions water/lib/Dialect/Wave/IR/WaveOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1715,12 +1715,6 @@ LogicalResult ExtractOp::verify() {
<< position.getRank();
}

if (failed(detail::verifyElementTypesMatch(getLoc(), "source",
getSource().getType(), "result",
getResult().getType()))) {
return failure();
}

if (auto resultVectorType = dyn_cast<VectorType>(getResult().getType())) {
if (resultVectorType.getShape()[0] != 1) {
return emitOpError() << "result must be a 1-element vector, got "
Expand Down Expand Up @@ -2151,11 +2145,6 @@ llvm::SmallVector<WaveSymbolAttr> wave::BroadcastOp::inferBroadcastDims() {
}

LogicalResult wave::BroadcastOp::verify() {
if (failed(detail::verifyElementTypesMatch(getLoc(), "source",
getSource().getType(), "result",
getResult().getType())))
return failure();

auto sourceType = llvm::dyn_cast<WaveTensorType>(getSource().getType());
auto resultType = llvm::dyn_cast<WaveTensorType>(getResult().getType());

Expand Down Expand Up @@ -2290,10 +2279,6 @@ LogicalResult wave::PermuteOp::verify() {
Value input = getValue();
Value result = getResult();

if (failed(detail::verifyElementTypesMatch(getLoc(), "input", input.getType(),
"result", result.getType())))
return failure();

auto inputType = dyn_cast<WaveTensorType>(input.getType());
auto resultType = dyn_cast<WaveTensorType>(result.getType());

Expand Down
12 changes: 10 additions & 2 deletions water/test/Dialect/Wave/ops-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,14 @@ func.func @extract_dimension_mismatch(%src: !wave.tensor<[@M, @N] of f32>) {

// -----

func.func @extract_element_type_mismatch(%src: !wave.tensor<[@M, @N] of f32>) {
// expected-error @below {{expected operand #0 and result #0 elemental types to match}}
%0 = wave.extract %src[#wave.expr_list<[] -> (0)>] : (!wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M] of f16>
return
}

// -----

func.func @extract_slice_mismatch_offset_size(%memory: !wave.tensor<[@A, @B] of f16>) {
// expected-error @below {{offset, size, and stride must all have the same rank, but got offset rank 1, size rank 2, and stride rank 1}}
wave.extract_slice %memory {offset = #wave.expr_list<[] -> (3)>, size = #wave.expr_list<[] -> (32, 16)>, stride = #wave.expr_list<[] -> (2)>} : (!wave.tensor<[@A, @B] of f16>) -> !wave.tensor<[@A, @B] of f16>
Expand Down Expand Up @@ -843,7 +851,7 @@ func.func @broadcast_explicit_dims_with_fully_specified_types(%arg0: !wave.tenso

func.func @broadcast_element_type_mismatch(%arg0: !wave.tensor<[@M, @N] of f32, <register>>) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot Please also a negative test for ExtractOp, where the trait was added as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in 0b07423. Added test extract_element_type_mismatch that verifies the trait catches element type mismatches between source and result.

// Source and result must have matching element types.
// expected-error @below {{expected source and result elemental types to match, got 'f32', 'f16'}}
// expected-error @below {{expected operand #0 and result #0 elemental types to match}}
wave.broadcast %arg0 : (!wave.tensor<[@M, @N] of f32, <register>>) -> !wave.tensor<[@M, @N, @K] of f16, <register>>
return
}
Expand Down Expand Up @@ -893,7 +901,7 @@ func.func @permute_empty_result_shape(%arg0: !wave.tensor<[@M, @N] of f32, <regi

// Test that permute input and result element types must match
func.func @permute_element_type_mismatch(%arg0: !wave.tensor<[@M, @N] of f32, <register>>) {
// expected-error @below {{expected input and result elemental types to match, got 'f32', 'f16'}}
// expected-error @below {{expected operand #0 and result #0 elemental types to match}}
wave.permute %arg0 : !wave.tensor<[@M, @N] of f32, <register>> to !wave.tensor<[@N, @M] of f16, <register>>
return
}
Expand Down
Loading