diff --git a/water/include/water/Dialect/Wave/IR/WaveInterfaces.h b/water/include/water/Dialect/Wave/IR/WaveInterfaces.h index 421340f82..2a9117382 100644 --- a/water/include/water/Dialect/Wave/IR/WaveInterfaces.h +++ b/water/include/water/Dialect/Wave/IR/WaveInterfaces.h @@ -266,6 +266,10 @@ verifyTypesMatchingDimensions(std::optional loc, // corresponding flag is set, compatible address spaces. llvm::LogicalResult verifyCompatibleOperandsAndResultsOpTrait( mlir::Operation *op, bool includeAddressSpace, bool includeElementalType); + +// Verification logic for the equal-element-types trait. Succeeds if all +// operands and results have the same element type. +llvm::LogicalResult verifyEqualElementTypesOpTrait(mlir::Operation *op); }; // namespace detail template @@ -301,6 +305,15 @@ class CompatibleOperandsAndResultsShapeOpTrait } }; +template +class EqualElementTypesOpTrait + : public mlir::OpTrait::TraitBase { +public: + static llvm::LogicalResult verifyTrait(mlir::Operation *op) { + return detail::verifyEqualElementTypesOpTrait(op); + } +}; + //----------------------------------------------------------------------------- // WaveElementsPerThreadOpInterface //----------------------------------------------------------------------------- diff --git a/water/include/water/Dialect/Wave/IR/WaveInterfaces.td b/water/include/water/Dialect/Wave/IR/WaveInterfaces.td index d8481f473..0c97a892b 100644 --- a/water/include/water/Dialect/Wave/IR/WaveInterfaces.td +++ b/water/include/water/Dialect/Wave/IR/WaveInterfaces.td @@ -181,6 +181,11 @@ def CompatibleOperandsAndResultsShapeOpTrait let cppNamespace = "::wave"; } +def EqualElementTypesOpTrait + : NativeOpTrait<"EqualElementTypesOpTrait"> { + let cppNamespace = "::wave"; +} + //----------------------------------------------------------------------------- // WaveInferIndexExprsOpInterface and implementation traits //----------------------------------------------------------------------------- diff --git a/water/include/water/Dialect/Wave/IR/WaveOps.td b/water/include/water/Dialect/Wave/IR/WaveOps.td index 5379d27a0..d26e1a167 100644 --- a/water/include/water/Dialect/Wave/IR/WaveOps.td +++ b/water/include/water/Dialect/Wave/IR/WaveOps.td @@ -373,7 +373,8 @@ def AllocateOp : WaveOp<"allocate", [ } def ExtractOp : WaveOp<"extract", - [DeclareOpInterfaceMethods]> { + [DeclareOpInterfaceMethods, + EqualElementTypesOpTrait]> { let summary = "Extracts a single element from a vector at the given index"; let description = [{ This is an internal operation that appears during expansion/lowering and @@ -661,7 +662,8 @@ def CastOp : WaveOp<"cast", [ def BroadcastOp : WaveOp<"broadcast", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - WaveInferIndexExprsOpInterface, IdentityIndexExprsOpTrait]> { + WaveInferIndexExprsOpInterface, IdentityIndexExprsOpTrait, + EqualElementTypesOpTrait]> { let summary = "Broadcast a tensor to a larger shape by replicating values"; let description = [{ Broadcasts the source tensor to the result shape by replicating values @@ -740,7 +742,8 @@ def SelfIndexOp : WaveOp<"self_index", [ def PermuteOp : WaveOp<"permute", [ DeclareOpInterfaceMethods, WaveElementsPerThreadOpInterface, IdentityElementsPerThreadOpTrait, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + EqualElementTypesOpTrait]> { let summary = "Permute the dimensions of a register-resident tensor"; let description = [{ Reorders the symbolic dimensions of a register-resident tensor according diff --git a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp index ee72f069d..977a17dd2 100644 --- a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp +++ b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp @@ -685,6 +685,48 @@ llvm::LogicalResult wave::detail::verifyCompatibleOperandsAndResultsOpTrait( kResultNamePrefix, os.str()); } +llvm::LogicalResult +wave::detail::verifyEqualElementTypesOpTrait(Operation *op) { + if (op->getNumOperands() == 0 && op->getNumResults() == 0) + return llvm::success(); + + // Get the reference type from the first operand if available, otherwise from + // the first result + Type referenceType; + llvm::StringRef referenceName; + bool isOperandReference = false; + if (op->getNumOperands() > 0) { + referenceType = op->getOperandTypes()[0]; + referenceName = "operand #0"; + isOperandReference = true; + } else { + referenceType = op->getResultTypes()[0]; + referenceName = "result #0"; + } + + // Verify all operand element types match (skip operand #0 if it's the + // reference to avoid redundant check) + unsigned startIdx = isOperandReference ? 1 : 0; + for (unsigned idx = startIdx; idx < op->getNumOperands(); ++idx) { + if (failed(verifyElementTypesMatch( + op->getLoc(), referenceName, referenceType, + "operand #" + llvm::Twine(idx), op->getOperandTypes()[idx]))) + return llvm::failure(); + } + + // Verify all result element types match (skip result #0 if it's the + // reference to avoid redundant check) + startIdx = !isOperandReference ? 1 : 0; + for (unsigned idx = startIdx; idx < op->getNumResults(); ++idx) { + if (failed(verifyElementTypesMatch( + op->getLoc(), referenceName, referenceType, + "result #" + llvm::Twine(idx), op->getResultTypes()[idx]))) + return llvm::failure(); + } + + return llvm::success(); +} + //----------------------------------------------------------------------------- // Lattice implementation //----------------------------------------------------------------------------- diff --git a/water/lib/Dialect/Wave/IR/WaveOps.cpp b/water/lib/Dialect/Wave/IR/WaveOps.cpp index 547c07874..7af3296db 100644 --- a/water/lib/Dialect/Wave/IR/WaveOps.cpp +++ b/water/lib/Dialect/Wave/IR/WaveOps.cpp @@ -1715,12 +1715,6 @@ LogicalResult ExtractOp::verify() { << position.getRank(); } - if (failed(detail::verifyElementTypesMatch(getLoc(), "source", - getSource().getType(), "result", - getResult().getType()))) { - return failure(); - } - if (auto resultVectorType = dyn_cast(getResult().getType())) { if (resultVectorType.getShape()[0] != 1) { return emitOpError() << "result must be a 1-element vector, got " @@ -2151,11 +2145,6 @@ llvm::SmallVector wave::BroadcastOp::inferBroadcastDims() { } LogicalResult wave::BroadcastOp::verify() { - if (failed(detail::verifyElementTypesMatch(getLoc(), "source", - getSource().getType(), "result", - getResult().getType()))) - return failure(); - auto sourceType = llvm::dyn_cast(getSource().getType()); auto resultType = llvm::dyn_cast(getResult().getType()); @@ -2290,10 +2279,6 @@ LogicalResult wave::PermuteOp::verify() { Value input = getValue(); Value result = getResult(); - if (failed(detail::verifyElementTypesMatch(getLoc(), "input", input.getType(), - "result", result.getType()))) - return failure(); - auto inputType = dyn_cast(input.getType()); auto resultType = dyn_cast(result.getType()); diff --git a/water/test/Dialect/Wave/ops-invalid.mlir b/water/test/Dialect/Wave/ops-invalid.mlir index 0f2a3b662..a1682c596 100644 --- a/water/test/Dialect/Wave/ops-invalid.mlir +++ b/water/test/Dialect/Wave/ops-invalid.mlir @@ -599,6 +599,14 @@ func.func @extract_dimension_mismatch(%src: !wave.tensor<[@M, @N] of f32>) { // ----- +func.func @extract_element_type_mismatch(%src: !wave.tensor<[@M, @N] of f32>) { + // expected-error @below {{expected operand #0 and result #0 elemental types to match}} + %0 = wave.extract %src[#wave.expr_list<[] -> (0)>] : (!wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M] of f16> + return +} + +// ----- + func.func @extract_slice_mismatch_offset_size(%memory: !wave.tensor<[@A, @B] of f16>) { // expected-error @below {{offset, size, and stride must all have the same rank, but got offset rank 1, size rank 2, and stride rank 1}} wave.extract_slice %memory {offset = #wave.expr_list<[] -> (3)>, size = #wave.expr_list<[] -> (32, 16)>, stride = #wave.expr_list<[] -> (2)>} : (!wave.tensor<[@A, @B] of f16>) -> !wave.tensor<[@A, @B] of f16> @@ -843,7 +851,7 @@ func.func @broadcast_explicit_dims_with_fully_specified_types(%arg0: !wave.tenso func.func @broadcast_element_type_mismatch(%arg0: !wave.tensor<[@M, @N] of f32, >) { // Source and result must have matching element types. - // expected-error @below {{expected source and result elemental types to match, got 'f32', 'f16'}} + // expected-error @below {{expected operand #0 and result #0 elemental types to match}} wave.broadcast %arg0 : (!wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N, @K] of f16, > return } @@ -893,7 +901,7 @@ func.func @permute_empty_result_shape(%arg0: !wave.tensor<[@M, @N] of f32, >) { - // expected-error @below {{expected input and result elemental types to match, got 'f32', 'f16'}} + // expected-error @below {{expected operand #0 and result #0 elemental types to match}} wave.permute %arg0 : !wave.tensor<[@M, @N] of f32, > to !wave.tensor<[@N, @M] of f16, > return }