From 8727650ea52ed660f4447c9276e6ddfca6d34646 Mon Sep 17 00:00:00 2001 From: Alex Collins Date: Wed, 28 Aug 2024 09:04:50 +0100 Subject: [PATCH] Allow dot operand hoisting for math dialect, arith.truncf, arith.trunci --- .../TritonGPU/Transforms/OptimizeDotOperands.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index e6e0ec8d7cef..fc2bd52c7b72 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -131,19 +131,21 @@ class HoistLayoutConversion : public OpRewritePattern { // bitwidth is unable to realize that there is a mixed-precision dot // (hence kWidth = 1) but wants to hoist through the type conversion. if (isa(src) && dotOpEnc.getKWidth() == 1) - return failure(); + return failure(); - // Only consider custom conversions or arith ops. + // Only consider custom conversions, math or arith ops. // TODO(jlebar): Is this too restrictive? if (!isa(src) && !isPureUnaryInlineAsm(src) && - src->getDialect()->getTypeID() != TypeID::get()) + src->getDialect()->getTypeID() != TypeID::get() && + src->getDialect()->getTypeID() != TypeID::get()) return failure(); // Currently, these instructions are not supported during lowering of // shared -> dot_operand layout. Not all types and type conversions are // supported. - if (isa(src)) + if (isa(src)) { return failure(); + } // Don't hoist through u1 -> fp casts as they aren't supported in // ElementwiseOpToLLVM::reorderValues().