From 5a1e295e642975a476dc805c0b80b45daf274e01 Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Thu, 16 Jan 2025 12:38:53 +0900 Subject: [PATCH 1/6] [NNPA] Revise compiler options for quantization (#3043) * Introduce two new options -nnpa-quant-dynamic and -nnpa-quant-op-types, and remove the old option --nnpa-quanzation. Signed-off-by: Tung D. Le --------- Signed-off-by: Tung D. Le --- docs/AddCustomAccelerators.md | 7 ++ src/Accelerators/Accelerator.hpp | 7 ++ .../NNPA/Compiler/NNPACompilerOptions.cpp | 52 +++++++--- .../NNPA/Compiler/NNPACompilerOptions.hpp | 15 +-- .../NNPA/Compiler/NNPACompilerUtils.cpp | 51 +++++++++- .../ONNXToZHigh/DevicePlacement.cpp | 6 +- .../ONNXToZHigh/ONNXLegalityCheck.cpp | 2 +- .../Conversion/ONNXToZHigh/ONNXToZHigh.cpp | 99 ++++++++----------- .../Conversion/ONNXToZHigh/ONNXToZHigh.hpp | 6 +- .../ONNXToZHigh/ONNXToZHighCommon.cpp | 25 ++++- .../ONNXToZHigh/ONNXToZHighCommon.hpp | 9 +- src/Accelerators/NNPA/NNPAAccelerator.cpp | 5 + src/Accelerators/NNPA/NNPAAccelerator.hpp | 3 +- src/Accelerators/NNPA/Pass/NNPAPasses.hpp | 5 +- src/Compiler/CompilerUtils.cpp | 1 + src/Tools/onnx-mlir-opt/onnx-mlir-opt.cpp | 2 + .../onnx-to-zhigh/quantization.mlir | 4 +- 17 files changed, 198 insertions(+), 101 deletions(-) diff --git a/docs/AddCustomAccelerators.md b/docs/AddCustomAccelerators.md index 722abc6ee3..4047cd65f8 100644 --- a/docs/AddCustomAccelerators.md +++ b/docs/AddCustomAccelerators.md @@ -92,6 +92,13 @@ virtual void registerDialects(mlir::DialectRegistry ®istry) const = 0; /// command line options. virtual void registerPasses(int optLevel) const = 0; +//===--------------------------------------------------------------------===// +// Hooks for both onnx-mlir and onnx-mlir-opt drivers +//===--------------------------------------------------------------------===// + +/// Configure passes for the accelerator. +virtual void configurePasses() const = 0; + //===--------------------------------------------------------------------===// // Hooks for onnx-to-krnl pass //===--------------------------------------------------------------------===// diff --git a/src/Accelerators/Accelerator.hpp b/src/Accelerators/Accelerator.hpp index 5c2b47187e..e10449cdf1 100644 --- a/src/Accelerators/Accelerator.hpp +++ b/src/Accelerators/Accelerator.hpp @@ -108,6 +108,13 @@ class Accelerator { /// command line options. virtual void registerPasses(int optLevel) const = 0; + //===--------------------------------------------------------------------===// + // Hooks for both onnx-mlir and onnx-mlir-opt drivers + //===--------------------------------------------------------------------===// + + /// Configure passes for the accelerator. + virtual void configurePasses() const = 0; + //===--------------------------------------------------------------------===// // Hooks for onnx-to-krnl pass //===--------------------------------------------------------------------===// diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp index 52d7933888..34457eafd8 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp @@ -17,6 +17,10 @@ namespace onnx_mlir { +// Use external storage for the options so that they are globally accessible +std::vector nnpaQuantDynamic; // common for both +std::vector nnpaQuantOpTypes; // common for both + llvm::cl::opt nnpaEmissionTarget( llvm::cl::desc("[Optional] Choose NNPA-related target to emit " "(once selected it will cancel the other targets):"), @@ -101,6 +105,41 @@ llvm::cl::opt nnpaEnableSaturation("nnpa-saturation", "Default is false."), llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions)); +llvm::cl::list> + nnpaQuantDynamicOpt("nnpa-quant-dynamic", + llvm::cl::desc( + "Enable dynamic quantization of the input model. If enabled, it " + "only quantizes from fp32 to i8. If an ONNX operation is already " + "in i8, no quantization is applied to that operation. Optionally, " + "a comma-separated list of quantization options can be specified " + "as its value, e.g. -nnpa-quant-dynamic=symActivation,symWeight."), + llvm::cl::values(clEnumVal(symWeight, "Symmetric quant for weights."), + clEnumVal(asymWeight, "Asymmetric quant for weights."), + clEnumVal(symActivation, "Symmetric quant for activations."), + clEnumVal(asymActivation, "Asymmetric quant for activations."), + // Use an empty string for the case where `--nnpa-quant-dynamic` is + // specified on the command line WITHOUT value, which is different + // from the case where `--nnpa-quant-dynamic` is NOT specified on + // the command line. + clEnumValN(autoQuantOpt, "", + "Compiler automatically finds the best options. Once this " + "option (an empty string) is in the list, the other options " + "are ignored. This is the default option when " + "`-nnpa-quant-dynamic` is specified without any value.")), + llvm::cl::location(nnpaQuantDynamic), llvm::cl::ValueOptional, + llvm::cl::CommaSeparated, llvm::cl::cat(OnnxMlirCommonOptions)); + +llvm::cl::list> nnpaQuantOpTypesOpt( + "nnpa-quant-op-types", + llvm::cl::desc( + "A comma-separated list of types of operations that are quantized. " + "E.g. 'MatMul,Conv'. Strings for types are the same as ONNX operator " + "names in https://onnx.ai/onnx/operators/. Currently, only MatMul is " + "supported. Without specifying this option, the compiler will " + "determine the operation types by itself."), + llvm::cl::location(nnpaQuantOpTypes), llvm::cl::ValueOptional, + llvm::cl::CommaSeparated, llvm::cl::cat(OnnxMlirCommonOptions)); + llvm::cl::opt nnpaUseDynamicQuantizeLinearOnCPU("nnpa-cpu-dql", llvm::cl::desc("Use dynamic quantized linear on CPU. Default is false"), llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions)); @@ -111,17 +150,4 @@ llvm::cl::opt nnpaUseDynamicQuantizeLinearOnCPUForScaleOffset( " scale and offset on CPU. Default is false"), llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions)); -llvm::cl::opt nnpaQuantization("nnpa-quantization", - llvm::cl::desc("Enable quantization with a specific type. Only " - "MatMul whose weight is a constant is supported."), - llvm::cl::values( - clEnumVal(DynSymI8, - "Dynamic Quantization to signed integer 8. Asymmetric " - "quant for activations and symmetric quant for weights."), - clEnumVal(SymSymI8, - "Dynamic Quantization to signed integer 8. Symmetric " - "quant for activations and symmetric quant for weights."), - clEnumVal(QNONE, "No quantization (default).")), - llvm::cl::init(QNONE), llvm::cl::cat(OnnxMlirOptions)); - } // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp index 366efee3fe..e6f7cf6aa7 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp @@ -57,12 +57,12 @@ typedef enum { // Quantization type typedef enum { - DynSymI8, /* Dynamic quantization to signed integer 8. Asymmetric quant for - activations and symmetric quant for weights.*/ - SymSymI8, /* Dynamic quantization to signed integer 8. Symmetric quant for - activations and symmetric quant for weights.*/ - QNONE, /* Only qualifying ops that are faster on NNPA. */ -} NNPAQuantType; + symWeight, + asymWeight, + symActivation, + asymActivation, + autoQuantOpt, +} NNPAQuantOptions; extern llvm::cl::OptionCategory OnnxMlirOptions; extern llvm::cl::OptionCategory OnnxMlirCommonOptions; @@ -79,7 +79,8 @@ extern llvm::cl::opt nnpaSaveDevicePlacementFile; extern llvm::cl::opt nnpaEnableSaturation; extern llvm::cl::opt nnpaUseDynamicQuantizeLinearOnCPU; extern llvm::cl::opt nnpaUseDynamicQuantizeLinearOnCPUForScaleOffset; -extern llvm::cl::opt nnpaQuantization; +extern std::vector nnpaQuantDynamic; +extern std::vector nnpaQuantOpTypes; } // namespace onnx_mlir #endif diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp index d7c5cfcac0..45a9af09f8 100644 --- a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp +++ b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp @@ -49,11 +49,56 @@ using namespace onnx_mlir; namespace onnx_mlir { void configurePassesNNPA() { - configureOnnxToZHighLoweringPass(optReport == OptReport::NNPAUnsupportedOps); // z16 does not support for hardware saturation. // So, force its usage to compiler generated sticks. if (nnpaEnableSaturation && isLessEqualNNPALevel(NNPALevel::M14)) nnpaEnableCompilerStickUnstick = true; + + // Configure ONNXToZHighLoweringPass. + bool isDynQuant = !nnpaQuantDynamic.empty(); + // Default/auto mode: symmetric for weighs and asymmetric for activations. + bool isActivationSym = false; + bool isWeightSym = true; + std::vector quantOpTypes; + if (isDynQuant) { + // Set options for activations and weights if they are given. + // When auto mode is specified, the other specified options are ignored. + if (!llvm::is_contained(nnpaQuantDynamic, NNPAQuantOptions::autoQuantOpt)) { + for (unsigned i = 0; i < nnpaQuantDynamic.size(); ++i) { + switch (nnpaQuantDynamic[i]) { + case NNPAQuantOptions::symWeight: + isWeightSym = true; + break; + case NNPAQuantOptions::asymWeight: + isWeightSym = false; + break; + case NNPAQuantOptions::symActivation: + isActivationSym = true; + break; + case NNPAQuantOptions::asymActivation: + isActivationSym = false; + break; + default: + llvm_unreachable("Unsupported quantization options"); + break; + } + } + } + if (!isWeightSym) { + // TODO: Support asymmetric quantiation for weights. + llvm::outs() + << "Asymmetric quantization for weights is not yet supported. " + "Turning off quantization.\n"; + isDynQuant = false; + } + if (nnpaQuantOpTypes.empty()) { + quantOpTypes.emplace_back("MatMul"); + } else { + quantOpTypes = nnpaQuantOpTypes; + } + } + configureONNXToZHighLoweringPass(optReport == OptReport::NNPAUnsupportedOps, + isDynQuant, isActivationSym, isWeightSym, quantOpTypes); } void addONNXToZHighPasses(mlir::PassManager &pm) { @@ -85,7 +130,8 @@ void addONNXToZHighPasses(mlir::PassManager &pm) { pm.addNestedPass( onnx_mlir::createInstrumentPass(instrumentOps, instrumentActions)); - pm.addPass(onnx_mlir::createONNXToZHighPass(nnpaQuantization)); + // Lowering ONNX to ZHigh. + pm.addPass(onnx_mlir::createONNXToZHighPass()); pm.addNestedPass(onnx_mlir::createShapeInferencePass()); // There are more opportunities for const propagation once all zhigh ops were @@ -191,7 +237,6 @@ void addPassesNNPA(mlir::OwningOpRef &module, // Override pass configurations. configurePasses(); - configurePassesNNPA(); // LLVM_DEBUG(llvm::dbgs() << "Adding NNPA passes" << std::endl;); if (emissionTarget >= EmitONNXIR) { diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp index 47724d8d3e..9979f0bbf3 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp @@ -161,7 +161,7 @@ void DevicePlacementPass::runOnOperation() { // Disable reporting on NNPA unsupported ops in this pass even if // `-opt-report=NNPAUnsupportedOps` is specified.. - OnnxToZHighLoweringConfiguration::reportOnNNPAUnsupportedOps = 0; + ONNXToZHighLoweringConfiguration::reportOnNNPAUnsupportedOps = 0; // Run the unknown dimension analysis to help check equality of unknown // dimensions at compile time. @@ -200,13 +200,13 @@ void DevicePlacementPass::runOnOperation() { // Call ONNXToZHigh pass for lowering multiple ONNX ops at once to ZHigh. // E.g. `onnx.ReLu (onnx.Conv)` to zhigh.Conv. RewritePatternSet Patterns2(context); - getONNXToZHighMultipleOpPatterns(Patterns2, nnpaQuantization); + getONNXToZHighMultipleOpPatterns(Patterns2); (void)applyAnalysisConversion(module, target, std::move(Patterns2), ConversionConfig{.legalizableOps = &legalizedOps2}); // Call ONNXToZHigh pass for lowering a single ONNX op to ZHigh. RewritePatternSet Patterns3(context); - getONNXToZHighOneOpPatterns(Patterns3, nnpaQuantization); + getONNXToZHighOneOpPatterns(Patterns3); getONNXToZHighOneOpDynamicallyLegal(&target, &dimAnalysis); (void)applyAnalysisConversion(module, target, std::move(Patterns3), ConversionConfig{.legalizableOps = &legalizedOps3}); diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp index 76fa3fa547..80c41b77ae 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp @@ -27,7 +27,7 @@ using namespace onnx_mlir; /// Report NNPA unsupported case. bool onnxToZHighUnsupportedReport(Operation *op, const std::string &message) { - if (OnnxToZHighLoweringConfiguration::reportOnNNPAUnsupportedOps && + if (ONNXToZHighLoweringConfiguration::reportOnNNPAUnsupportedOps && !message.empty()) { StringAttr opName = op->getName().getIdentifier(); std::string nodeNameStr = getNodeNameInPresenceOfOpt(op); diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp index 78e94a6a2a..921fba751d 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp @@ -644,8 +644,8 @@ class replaceONNXMatMulByDynQuantI8Pattern using OpRewritePattern::OpRewritePattern; replaceONNXMatMulByDynQuantI8Pattern( - MLIRContext *context, PatternBenefit benefit = 1, bool symForA = false) - : OpRewritePattern(context, benefit), symForA(symForA) {} + MLIRContext *context, PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit) {} LogicalResult matchAndRewrite( ONNXMatMulOp mmOp, PatternRewriter &rewriter) const override { @@ -655,7 +655,8 @@ class replaceONNXMatMulByDynQuantI8Pattern Value B = mmOp.getB(); // Dynamic quantization helper. - DynQuantI8PatternHelper dqHelper(rewriter, loc, op, A, B, nullptr, symForA); + DynQuantI8PatternHelper dqHelper(rewriter, loc, op, A, B, nullptr, + ONNXToZHighLoweringConfiguration::Quant::isActivationSym); // Match if (!isSuitableForZDNN(mmOp) || failed(dqHelper.match())) @@ -666,9 +667,6 @@ class replaceONNXMatMulByDynQuantI8Pattern rewriter.replaceOp(op, res); return success(); } - -private: - bool symForA = false; }; /** @@ -684,8 +682,8 @@ class replaceONNXMatMulAddByDynQuantI8Pattern using OpRewritePattern::OpRewritePattern; replaceONNXMatMulAddByDynQuantI8Pattern( - MLIRContext *context, PatternBenefit benefit = 1, bool symForA = false) - : OpRewritePattern(context, benefit), symForA(symForA) {} + MLIRContext *context, PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit) {} LogicalResult matchAndRewrite( ONNXAddOp addOp, PatternRewriter &rewriter) const override { @@ -704,7 +702,8 @@ class replaceONNXMatMulAddByDynQuantI8Pattern Value B = mmOp.getB(); // Match A, B, C. - DynQuantI8PatternHelper dqHelper(rewriter, loc, op, A, B, C, symForA); + DynQuantI8PatternHelper dqHelper(rewriter, loc, op, A, B, C, + ONNXToZHighLoweringConfiguration::Quant::isActivationSym); if (succeeded(dqHelper.match())) { Value res = dqHelper.rewriteSym(); rewriter.replaceOp(op, res); @@ -713,9 +712,6 @@ class replaceONNXMatMulAddByDynQuantI8Pattern return failure(); } - -private: - bool symForA = false; }; /** @@ -732,8 +728,8 @@ class replaceONNXGemmByDynQuantI8Pattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; replaceONNXGemmByDynQuantI8Pattern( - MLIRContext *context, PatternBenefit benefit = 1, bool symForA = false) - : OpRewritePattern(context, benefit), symForA(symForA) {} + MLIRContext *context, PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit) {} LogicalResult matchAndRewrite( ONNXGemmOp gemmOp, PatternRewriter &rewriter) const override { @@ -747,8 +743,9 @@ class replaceONNXGemmByDynQuantI8Pattern : public OpRewritePattern { bool transB = (gemmOp.getTransB() != 0); // Dynamic quantization helper. - DynQuantI8PatternHelper dqHelper( - rewriter, loc, op, A, B, isNoneValue(C) ? nullptr : C, symForA); + DynQuantI8PatternHelper dqHelper(rewriter, loc, op, A, B, + isNoneValue(C) ? nullptr : C, + ONNXToZHighLoweringConfiguration::Quant::isActivationSym); // Match // TODO: if B is a constant and it is transposed, we can do transpose @@ -765,9 +762,6 @@ class replaceONNXGemmByDynQuantI8Pattern : public OpRewritePattern { rewriter.replaceOp(op, res); return success(); } - -private: - bool symForA = false; }; class replaceONNXMatMulIntegerPattern @@ -1535,28 +1529,11 @@ struct ONNXToZHighLoweringPass ONNXToZHighLoweringPass() = default; ONNXToZHighLoweringPass(const ONNXToZHighLoweringPass &pass) : PassWrapper>() {} - ONNXToZHighLoweringPass(NNPAQuantType quantMode) { - this->quantMode = quantMode; - } void runOnOperation() final; - -public: - Option quantMode{*this, "quantization", - llvm::cl::desc("Enable quantization"), - llvm::cl::values( - clEnumVal(DynSymI8, - "Dynamic Quantization to signed integer 8. Asymmetric quant for " - "activations and symmetric quant for weights."), - clEnumVal(SymSymI8, - "Dynamic Quantization to signed integer 8. Symmetric quant for " - "activations and symmetric quant for weights."), - clEnumVal(QNONE, "No quantization (default).")), - llvm::cl::init(QNONE)}; }; } // end anonymous namespace. -void getONNXToZHighOneOpPatterns( - RewritePatternSet &patterns, NNPAQuantType quantMode) { +void getONNXToZHighOneOpPatterns(RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); patterns.insert(context); patterns.insert(context); @@ -1602,17 +1579,21 @@ void getONNXToZHighOneOpPatterns( patterns.insert(context); patterns.insert(context); - // Pattern for i8 dynamic quantization, symmetric mode. + // Pattern for i8 dynamic quantization. if (isCompatibleWithNNPALevel(NNPALevel::M15) && - (quantMode == NNPAQuantType::DynSymI8 || - quantMode == NNPAQuantType::SymSymI8)) { + ONNXToZHighLoweringConfiguration::isDynQuant) { // Bump up the pattern benefit to run these before non-quantization // patterns. PatternBenefit quantPriority(QUANT_PATTERN_BENEFIT); - patterns.insert( - context, quantPriority, quantMode == NNPAQuantType::SymSymI8); - patterns.insert( - context, quantPriority, quantMode == NNPAQuantType::SymSymI8); + if (llvm::any_of(ONNXToZHighLoweringConfiguration::Quant::opTypes, + [](std::string s) { + return StringRef(s).equals_insensitive("MatMul"); + })) { + patterns.insert( + context, quantPriority); + patterns.insert( + context, quantPriority); + } } } @@ -1648,8 +1629,7 @@ void getONNXToZHighOneOpDynamicallyLegal( addDynamicallyLegalOpFor(target, dimAnalysis); } -void getONNXToZHighMultipleOpPatterns( - RewritePatternSet &patterns, NNPAQuantType quantMode) { +void getONNXToZHighMultipleOpPatterns(RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); patterns.insert(context); patterns.insert(context); @@ -1663,15 +1643,19 @@ void getONNXToZHighMultipleOpPatterns( patterns.insert(context); patterns.insert(context); - // Pattern for i8 dynamic quantization, symmetric mode. + // Pattern for i8 dynamic quantization. if (isCompatibleWithNNPALevel(NNPALevel::M15) && - (quantMode == NNPAQuantType::DynSymI8 || - quantMode == NNPAQuantType::SymSymI8)) { + (ONNXToZHighLoweringConfiguration::isDynQuant)) { // Bump up the pattern benefit to run these before non-quantization // patterns. PatternBenefit quantPriority(QUANT_PATTERN_BENEFIT); - patterns.insert( - context, quantPriority, quantMode == NNPAQuantType::SymSymI8); + if (llvm::any_of(ONNXToZHighLoweringConfiguration::Quant::opTypes, + [](std::string s) { + return StringRef(s).equals_insensitive("MatMul"); + })) { + patterns.insert( + context, quantPriority); + } } // Shape inference for newly-added operations. @@ -1687,8 +1671,8 @@ void ONNXToZHighLoweringPass::runOnOperation() { // Enable reporting on NNPA unsupported ops when specifying // `--opt-report=NNPAUnsupportedOps`. - OnnxToZHighLoweringConfiguration::reportOnNNPAUnsupportedOps = - OnnxToZHighLoweringConfiguration::optReportNNPAUnsupportedOps; + ONNXToZHighLoweringConfiguration::reportOnNNPAUnsupportedOps = + ONNXToZHighLoweringConfiguration::optReportNNPAUnsupportedOps; // We define the specific operations, or dialects, that are legal targets for // this lowering. @@ -1706,8 +1690,7 @@ void ONNXToZHighLoweringPass::runOnOperation() { // a single ONNX Op, because the single op lowering might have conditions that // prohibit the combined ops lowering happened. RewritePatternSet combinedPatterns(&getContext()); - onnx_mlir::getONNXToZHighMultipleOpPatterns( - combinedPatterns, this->quantMode); + onnx_mlir::getONNXToZHighMultipleOpPatterns(combinedPatterns); // It's ok to fail. (void)applyPatternsAndFoldGreedily(module, std::move(combinedPatterns)); @@ -1719,7 +1702,7 @@ void ONNXToZHighLoweringPass::runOnOperation() { // Single ONNX to ZHigh operation lowering. RewritePatternSet patterns(&getContext()); - onnx_mlir::getONNXToZHighOneOpPatterns(patterns, this->quantMode); + onnx_mlir::getONNXToZHighOneOpPatterns(patterns); // This is to make sure we don't want to alloc any MemRef at this high-level // representation. @@ -1742,8 +1725,4 @@ std::unique_ptr createONNXToZHighPass() { return std::make_unique(); } -std::unique_ptr createONNXToZHighPass(NNPAQuantType quantMode) { - return std::make_unique(quantMode); -} - } // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.hpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.hpp index d121058168..caddfc24b8 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.hpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.hpp @@ -24,10 +24,8 @@ namespace onnx_mlir { // Exports ONNXtoZHigh patterns. -void getONNXToZHighOneOpPatterns( - mlir::RewritePatternSet &patterns, NNPAQuantType quantMode); -void getONNXToZHighMultipleOpPatterns( - mlir::RewritePatternSet &patterns, NNPAQuantType quantMode); +void getONNXToZHighOneOpPatterns(mlir::RewritePatternSet &patterns); +void getONNXToZHighMultipleOpPatterns(mlir::RewritePatternSet &patterns); // Exports ONNXtoZHigh dynamically legal checks. void getONNXToZHighOneOpDynamicallyLegal( diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.cpp index ce7c4160bd..4a3c03205d 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.cpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.cpp @@ -103,13 +103,30 @@ Value getDynShape(Location loc, PatternRewriter &rewriter, Value x) { RankedTensorType::get({r}, rewriter.getI64Type()), dims, 0); } -int OnnxToZHighLoweringConfiguration::optReportNNPAUnsupportedOps = +int ONNXToZHighLoweringConfiguration::optReportNNPAUnsupportedOps = 0; // 0: Compile option (--opt-report=NNPAUnsupportedOps) not specified. -int OnnxToZHighLoweringConfiguration::reportOnNNPAUnsupportedOps = +int ONNXToZHighLoweringConfiguration::reportOnNNPAUnsupportedOps = 0; // 0: no reporting. -void configureOnnxToZHighLoweringPass(bool optReportNNPAUnsupportedOps) { - OnnxToZHighLoweringConfiguration::optReportNNPAUnsupportedOps = +bool ONNXToZHighLoweringConfiguration::isDynQuant = false; +bool ONNXToZHighLoweringConfiguration::Quant::isActivationSym = false; +bool ONNXToZHighLoweringConfiguration::Quant::isWeightSym = true; +llvm::SmallVector + ONNXToZHighLoweringConfiguration::Quant::opTypes = {}; + +void configureONNXToZHighLoweringPass(bool optReportNNPAUnsupportedOps, + bool isDynQuant, bool quantIsActivationSym, bool quantIsWeightSym, + llvm::ArrayRef quantOpTypes) { + ONNXToZHighLoweringConfiguration::optReportNNPAUnsupportedOps = optReportNNPAUnsupportedOps; + ONNXToZHighLoweringConfiguration::isDynQuant = isDynQuant; + if (isDynQuant) { + ONNXToZHighLoweringConfiguration::Quant::isActivationSym = + quantIsActivationSym; + ONNXToZHighLoweringConfiguration::Quant::isWeightSym = quantIsWeightSym; + ONNXToZHighLoweringConfiguration::Quant::opTypes.insert( + ONNXToZHighLoweringConfiguration::Quant::opTypes.begin(), + quantOpTypes.begin(), quantOpTypes.end()); + } } } // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp index 382d596e35..4a92309443 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp @@ -30,9 +30,16 @@ const std::string NNPA_DEVICE = "nnpa"; bool isEnableScalarBcastBinary(); -struct OnnxToZHighLoweringConfiguration { +// Populated by configureONNXToZHighLoweringPass(). +struct ONNXToZHighLoweringConfiguration { static int optReportNNPAUnsupportedOps; static int reportOnNNPAUnsupportedOps; + static bool isDynQuant; + struct Quant { + static bool isActivationSym; + static bool isWeightSym; + static llvm::SmallVector opTypes; + }; }; template diff --git a/src/Accelerators/NNPA/NNPAAccelerator.cpp b/src/Accelerators/NNPA/NNPAAccelerator.cpp index 2e4a06c477..50ef2bf0ba 100644 --- a/src/Accelerators/NNPA/NNPAAccelerator.cpp +++ b/src/Accelerators/NNPA/NNPAAccelerator.cpp @@ -138,6 +138,11 @@ void NNPAAccelerator::registerPasses(int optLevel) const { }); } +void NNPAAccelerator::configurePasses() const { + LLVM_DEBUG(llvm::dbgs() << "Configuring passes for NNPA accelerator\n"); + configurePassesNNPA(); +} + mlir::MemRefType NNPAAccelerator::convertTensorTypeToMemRefType( const mlir::TensorType tensorType) const { assert(tensorType.hasRank() && "expected only ranked shapes"); diff --git a/src/Accelerators/NNPA/NNPAAccelerator.hpp b/src/Accelerators/NNPA/NNPAAccelerator.hpp index e40bd774b6..a908c02da2 100644 --- a/src/Accelerators/NNPA/NNPAAccelerator.hpp +++ b/src/Accelerators/NNPA/NNPAAccelerator.hpp @@ -47,7 +47,7 @@ class NNPAAccelerator final : public Accelerator { uint64_t getVersionNumber() const final; //===--------------------------------------------------------------------===// - // Hooks for onnx-mlir-opt driver + // Hooks for onnx-mlir driver //===--------------------------------------------------------------------===// virtual void addPasses(mlir::OwningOpRef &module, mlir::PassManager &pm, onnx_mlir::EmissionTargetType &emissionTarget, @@ -57,6 +57,7 @@ class NNPAAccelerator final : public Accelerator { //===--------------------------------------------------------------------===// virtual void registerDialects(mlir::DialectRegistry ®istry) const final; virtual void registerPasses(int optLevel) const final; + virtual void configurePasses() const final; //===--------------------------------------------------------------------===// // Hooks for onnx-to-krnl pass //===--------------------------------------------------------------------===// diff --git a/src/Accelerators/NNPA/Pass/NNPAPasses.hpp b/src/Accelerators/NNPA/Pass/NNPAPasses.hpp index f00fcdedff..c23fb7f158 100644 --- a/src/Accelerators/NNPA/Pass/NNPAPasses.hpp +++ b/src/Accelerators/NNPA/Pass/NNPAPasses.hpp @@ -30,8 +30,9 @@ std::unique_ptr createDevicePlacementPass( /// Add pass for lowering ONNX ops to ZHigh ops. std::unique_ptr createONNXToZHighPass(); -std::unique_ptr createONNXToZHighPass(NNPAQuantType quantMode); -void configureOnnxToZHighLoweringPass(bool reportOnNNPAUnsupportedOps); +void configureONNXToZHighLoweringPass(bool reportOnNNPAUnsupportedOps, + bool isDynQuant, bool quantIsActivationSym, bool quantIsWeightSym, + llvm::ArrayRef quantOpTypes); /// Add pass for rewriting ONNX ops for ZHigh. std::unique_ptr createRewriteONNXForZHighPass(); diff --git a/src/Compiler/CompilerUtils.cpp b/src/Compiler/CompilerUtils.cpp index 4310010d36..8ca220989b 100644 --- a/src/Compiler/CompilerUtils.cpp +++ b/src/Compiler/CompilerUtils.cpp @@ -1024,6 +1024,7 @@ int compileModule(mlir::OwningOpRef &module, bool hasAccel = false; for (auto *accel : onnx_mlir::accel::Accelerator::getAccelerators()) { hasAccel = true; + accel->configurePasses(); accel->addPasses(module, pm, emissionTarget, outputNameNoExt); } if (!hasAccel) diff --git a/src/Tools/onnx-mlir-opt/onnx-mlir-opt.cpp b/src/Tools/onnx-mlir-opt/onnx-mlir-opt.cpp index 29411aaf68..2acbca10b0 100644 --- a/src/Tools/onnx-mlir-opt/onnx-mlir-opt.cpp +++ b/src/Tools/onnx-mlir-opt/onnx-mlir-opt.cpp @@ -181,6 +181,8 @@ int main(int argc, char **argv) { // Passes are configured with command line options so they must be configured // after command line parsing but before any passes are run. configurePasses(); + for (auto *accel : accel::Accelerator::getAccelerators()) + accel->configurePasses(); auto passManagerSetupFn = [&](PassManager &pm) { MLIRContext *ctx = pm.getContext(); diff --git a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/quantization.mlir b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/quantization.mlir index 83565c6e42..658cc0e4e7 100644 --- a/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/quantization.mlir +++ b/test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/quantization.mlir @@ -1,5 +1,5 @@ -// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --convert-onnx-to-zhigh="quantization=DynSymI8" --constprop-onnx --canonicalize --mlir-print-elementsattrs-with-hex-if-larger=-1 %s -split-input-file | FileCheck %s -// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --convert-onnx-to-zhigh="quantization=SymSymI8" --constprop-onnx --canonicalize --mlir-print-elementsattrs-with-hex-if-larger=-1 %s -split-input-file | FileCheck %s --check-prefix=SYMSYMI8 +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --convert-onnx-to-zhigh --nnpa-quant-dynamic --constprop-onnx --canonicalize --mlir-print-elementsattrs-with-hex-if-larger=-1 %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --march=arch15 --maccel=NNPA --convert-onnx-to-zhigh --nnpa-quant-dynamic=symActivation,symWeight --constprop-onnx --canonicalize --mlir-print-elementsattrs-with-hex-if-larger=-1 %s -split-input-file | FileCheck %s --check-prefix=SYMSYMI8 func.func @test_correctness_of_symmetric_quant_for_weight(%arg0: tensor) -> tensor { %0 = onnx.Constant dense<[[-0.00718058366], [5.253110e-01], [-0.0434652828], [-0.305256933], [0.193365857], [0.0105065238], [-0.143788248], [-0.0161222648], [0.0230324212], [-0.34107244], [-0.273072243], [-0.104352467], [0.0164068397], [-1.32305741], [-0.0345043093], [-0.232206389], [-0.150001124], [0.119475454], [0.730642438], [-0.407772154], [-0.0164191965], [-1.625590e-01], [-0.112515017], [0.158920377], [-0.0997497215], [0.0788274407], [1.1542908], [0.492949218], [-0.125796661], [0.0107790371], [0.141159713], [-0.0774109289], [-0.438130081], [-0.0888700857], [0.207725927], [-0.0913108587], [0.258232892], [0.0672571063], [-0.100412264], [1.68460846], [-0.289168775], [-0.686722457], [0.903651654], [0.110602334], [-0.0505490415], [1.31204939], [0.136107579], [0.26376456], [-0.508291602], [-0.0118971812], [-0.0373991691], [0.448705465], [0.00448446581], [-0.165114298], [0.156860754], [0.141124308], [-0.272756487], [-0.0834815949], [0.020905681], [-0.0877983123], [-1.0087887], [-0.353012145], [-0.0439243801], [-0.00592191564], [-0.0637216269], [0.175808683], [-0.193864927], [-0.0574007072], [0.390869558], [0.138100505], [0.429396927], [1.10117233], [-0.362377733], [0.116578773], [0.0540139228], [-5.85162896E-4], [-0.335441321], [-0.0902953073], [0.017575942], [-0.0359748788], [1.50025952], [-0.668821096], [0.0109066488], [9.907780e-01], [0.10227681], [-0.0582750589], [0.0172416102], [0.0429656394], [0.0465254933], [0.350135148], [-0.260139734], [0.199394852], [-0.136131078], [0.241424322], [0.855418264], [-0.160689577], [-0.825074911], [-0.124827594], [0.0153419804], [0.389386117], [0.153694436], [-0.897866904], [-0.292769879], [0.181667477], [-0.188009143], [-0.0245181341], [-2.17088842], [-0.0526076891], [-0.108600065], [0.187120304], [0.171495944], [0.310159177], [2.204240e+00], [0.0506350659], [-0.159419239], [-0.145082235], [-0.0991335287], [-0.0680764392], [-0.311415762], [-0.187137261], [-0.416945577], [0.0703471377], [0.498331547], [-0.41216433], [-0.427900195], [0.102105901], [0.130767033], [-0.440281332], [0.778514624], [-0.253678083], [0.395671815], [0.380029172], [-0.418493837], [-0.288157403], [0.0689846799], [1.269960e+00], [-0.0585722439], [-0.138125435], [-0.191710189], [0.0163070802], [0.159242466], [0.116627224], [0.289637923], [-0.299413532], [-0.0216965247], [0.271396786], [0.250576884], [-0.131420374], [0.137698188], [-0.0102280416], [0.234722644], [-0.0366179943], [-0.105632246], [-0.145528033], [-0.278210133], [-0.247100428], [0.217718393], [0.171669215], [0.0151556451], [0.961385667], [-0.0484847203], [0.434219301], [-0.00167646946], [-0.0308207348], [-0.102328695], [-0.127907664], [-0.185960412], [0.210866481], [0.140434876], [-0.233541235], [-0.123745643], [-0.0113738365], [1.30043447], [0.179708347], [-0.331716627], [0.0133318678], [-0.107284561], [-0.114116102], [-0.478514463], [0.0616452768], [-0.781869769], [-0.121830635], [-0.0684970543], [-6.584100e-02], [-0.131784603], [-0.619898796], [0.160366163], [-0.50115186], [0.0228514839], [0.581515431], [4.220270e-01], [1.944400e-01], [-1.07740963], [3.732520e-01], [0.725471556], [-0.117193311], [-0.105938725], [0.320118755], [-0.484032601], [-0.0467250831]]> : tensor<200x1xf32> From 008f7771331e95b3f50171f6aa9f35375d3b8b46 Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Mon, 20 Jan 2025 09:42:54 +0900 Subject: [PATCH 2/6] Update the instruction for building multiple accelerators (#3046) Signed-off-by: Tung D. Le --- docs/AddCustomAccelerators.md | 3 ++- src/Accelerators/CMakeLists.txt | 6 ++++-- test/mlir/CMakeLists.txt | 3 ++- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/docs/AddCustomAccelerators.md b/docs/AddCustomAccelerators.md index 4047cd65f8..bb1bbf8e8c 100644 --- a/docs/AddCustomAccelerators.md +++ b/docs/AddCustomAccelerators.md @@ -20,8 +20,9 @@ The folder content is flexible depending on each accelerator. However, we recomm To build accelerators in onnx-mlir, use the cmake variable `ONNX_MLIR_ACCELERATORS` when building onnx-mlir. `ONNX_MLIR_ACCELERATORS` accepts a semicolon-separated list of accelerator names. For example, ```bash $ cd build -$ cmake .. -DONNX_MLIR_ACCELERATORS=accel1;accel2 +$ cmake .. -DONNX_MLIR_ACCELERATORS='accel1;accel2' ``` +Note that the list should be quoted. ### 1.2 Compile a model to run with selected accelerators. diff --git a/src/Accelerators/CMakeLists.txt b/src/Accelerators/CMakeLists.txt index 4a4f97a2e0..db3a75e2db 100644 --- a/src/Accelerators/CMakeLists.txt +++ b/src/Accelerators/CMakeLists.txt @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # Populate the accelerator list and add the accelerator subdirectories. -# ONNX_MLIR_ACCELERATORS is the list of accelerators user specified +# ONNX_MLIR_ACCELERATORS is the semicolon-separated list of accelerators user specified +# Note that the list should be quoted, e.g. -DONNX_MLIR_ACCELERATORS='A;B' # ACCEL_TARGET_LIST is the list of cmake targets # ACCEL_LINK_LIST is the lists of accelerator libraries # ACCEL_INCLUDE_LIST is the list passed to inc generator @@ -10,7 +11,8 @@ if (ONNX_MLIR_ACCELERATORS) add_subdirectory(${t}) # If the accelerator can be built - if (${t}_ENABLED) + string(TOUPPER ${t} T) + if (${T}_ENABLED) list(APPEND ACCEL_TARGET_LIST "${t}Accel") list(APPEND ACCEL_LINK_LIST "OM${t}Accel") list(APPEND ACCEL_INCLUDE_LIST "${t}") diff --git a/test/mlir/CMakeLists.txt b/test/mlir/CMakeLists.txt index 0e2408fcda..60b9bf67b9 100644 --- a/test/mlir/CMakeLists.txt +++ b/test/mlir/CMakeLists.txt @@ -4,7 +4,8 @@ # accelerator code itself cannot be built. if (ONNX_MLIR_ACCELERATORS) foreach(t ${ONNX_MLIR_ACCELERATORS}) - set(${t}_LIT_ENABLED 1) + string(TOUPPER ${t} T) + set(${T}_LIT_ENABLED 1) list(APPEND ACCEL_LIT_LIST "${t}") endforeach(t) endif(ONNX_MLIR_ACCELERATORS) From bd41f89e199d6bf7a9285baf6976b6817c59898d Mon Sep 17 00:00:00 2001 From: "Tung D. Le" Date: Mon, 20 Jan 2025 12:46:23 +0900 Subject: [PATCH 3/6] Add a document for quantization on NNPA (#3045) * Add a document for quantization on NNPA Signed-off-by: Tung D. Le --------- Signed-off-by: Tung D. Le --- docs/Quantization-NNPA.md | 65 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) create mode 100644 docs/Quantization-NNPA.md diff --git a/docs/Quantization-NNPA.md b/docs/Quantization-NNPA.md new file mode 100644 index 0000000000..ab53c3f5e6 --- /dev/null +++ b/docs/Quantization-NNPA.md @@ -0,0 +1,65 @@ + + +# Overview + +NNPA in IBM Telum II supports 8-bit signed-integer quantized matrix multiplications. This document shows how to compile an ONNX model for 8-bit quantization on NNPA. When not following these steps, models will still be accelerated when targeting Telum systems using a mixture of 16-bit floating-point numbers for computations mapped to the Telum's Integrated AI accelerator and 32-bit floating-point numbers for computations mapped to the Telum CPUs. + +There are two approaches to using quantization in the onnx-mlir compiler, depending on the input ONNX model to the compile: +- The input model is a quantized model that was quantized by other frameworks such as ONNX Runtime. In this case, the input ONNX model contains 8-bit operations, and the onnx-mlir compiler selects suitable 8-bit operations to run on NNPA. There is no special compile flags needed to enable quantization when compiling this quantized model. Hence, we do not discuss this case in this document. + - In this approach, the compiler supports both static and dynamic quantized models. +- The input model is a non-quantized model, e.g. operations operate on float32 data types. In this case, the onnx-mlir compiler provides several quantization options in order to quantize the model during compilation, then run the compiled model on NNPA. The remaining of this document describes this approach. + - In this approach, the compiler only supports dynamic quantization. + +In both approaches, the following constraints are applied: +- Only per-tensor quantization is supported, meaning `scale` and `zero_point` are computed per-tensor and are scalar values. +- Target quantization data type is 8-bit signed-integer. + +Quantization requires NNPA in IBM Telum II, meaning that the following compile flags must be specified to enable quantization: `-maccel=NNPA -march=arch15`. + +# Dynamic quantization by the compiler + +Again, it is important to note that the onnx-mlir compiler currently: +- supports per-tensor dynamic quantization, and +- quantizes data tensors from float32 to 8-bit signed integer. If a data tensor in the input model is already in 8-bit singed integer, the compiler will not quantize it again. + +The compiler provides two compile flags for dynamically quantizing a model at compile time: +- `--nnpa-quant-dynamic` to enable dynamic quantization. +- `--nnpa-quant-op-types` to specify the types of ONNX operations to quantize manually, e.g. `MatMul,Conv`. + +Users can specify whether or not to symmetrize data for activations and weights by using options `symActivation, asymActivation, symWeight, asymWeight` as values for `--nnpa-quant-dynamic`. +For examples, to asymmetrize data for activations and to symmetrize data for weights, one can use `--nnpa-quant-dynamic=asymActivation,symWeight`. + +By specifying `--nnpa-quant-dynamic` only, the compiler will decide quantization options and operation types by itself. + +## Computing `scale` and `zero_point` +The compiler uses the following equations to compute `scale` and `zero_point` for 8-bit signed integer quantization. + +Asymmetric quantization +``` +scale = (maximum(0, max(x)) - minimum(0, min(x))) / (qmax - qmin) +zero_point = cast(round(saturate(qmin - min(x)/scale))) +``` +where +- `x` is the input tensor to quantize, +- data range is adjusted to include 0, +- `qmax=127` and `qmin=-128` are the max and min values for quantization range. +- `saturate` is to saturate to `[-128, 127]`. + +Symmetric quantization +``` +scale = max(abs(x)) / 127 +zero_point = 0 +``` + +Given `scale` and `zero_point`, the input `x` is quantized to +``` +quantized_x = x/scale + zero_point +``` + +# Performance notes + +It is often the case that symmetric quantization leads to better inference performance but poorer accuracy than asymmetric quantization. +Users may want to experiment with different quantization schemes to find the best combination for their own model. + +# Resources +- [A visual guide to quantization](https://www.maartengrootendorst.com/blog/quantization/) From 6fb96c520dd704c6358e84528e0d25daf326d8b5 Mon Sep 17 00:00:00 2001 From: Sunny Anand <164108690+Sunny-Anand@users.noreply.github.com> Date: Thu, 23 Jan 2025 15:28:44 -0600 Subject: [PATCH 4/6] update onnx opset (#3050) * update onnx opset Signed-off-by: Sunny Anand * address feedback Signed-off-by: Sunny Anand * reformat for clang formatting Signed-off-by: Sunny Anand * reformat for black formattter Signed-off-by: Sunny Anand --------- Signed-off-by: Sunny Anand --- docs/SupportedONNXOps-NNPA.md | 4 ++-- docs/SupportedONNXOps-cpu.md | 4 ++-- src/Dialect/ONNX/ONNXOps.hpp | 4 +++- utils/pre-onnx-mlir.py | 5 +++-- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/docs/SupportedONNXOps-NNPA.md b/docs/SupportedONNXOps-NNPA.md index a0f85aef41..fd91a1b016 100644 --- a/docs/SupportedONNXOps-NNPA.md +++ b/docs/SupportedONNXOps-NNPA.md @@ -3,11 +3,11 @@ # Supported ONNX Operation for Target *NNPA*. -Onnx-mlir currently supports ONNX operations targeting up to opset 21. Limitations are listed when applicable. This documentation highlights the minimum and maximum opset versions that are fully supported by onnx-mlir and not the version changes. +Onnx-mlir currently supports ONNX operations targeting up to opset 22. Limitations are listed when applicable. This documentation highlights the minimum and maximum opset versions that are fully supported by onnx-mlir and not the version changes. * Operations are defined by the [ONNX Standard](https://github.com/onnx/onnx/blob/main/docs/Operators.md). * **Supported Opsets** indicates the lowest and highest opset a model may have for onnx-mlir to support compiling a model with the operator. - * A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 21. + * A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 22. * A ^ indicates onnx-mlir is compatible with the latest level of the NNPA Architecture which is z16. diff --git a/docs/SupportedONNXOps-cpu.md b/docs/SupportedONNXOps-cpu.md index 172c9ed7ca..7b6c643776 100644 --- a/docs/SupportedONNXOps-cpu.md +++ b/docs/SupportedONNXOps-cpu.md @@ -3,11 +3,11 @@ # Supported ONNX Operation for Target *cpu*. -Onnx-mlir currently supports ONNX operations targeting up to opset 21. Limitations are listed when applicable. This documentation highlights the minimum and maximum opset versions that are fully supported by onnx-mlir and not the version changes. +Onnx-mlir currently supports ONNX operations targeting up to opset 22. Limitations are listed when applicable. This documentation highlights the minimum and maximum opset versions that are fully supported by onnx-mlir and not the version changes. * Operations are defined by the [ONNX Standard](https://github.com/onnx/onnx/blob/main/docs/Operators.md). * **Supported Opsets** indicates the lowest and highest opset a model may have for onnx-mlir to support compiling a model with the operator. - * A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 21. + * A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 22. | Op |Supported Opsets (inclusive) |Limitations |Notes | diff --git a/src/Dialect/ONNX/ONNXOps.hpp b/src/Dialect/ONNX/ONNXOps.hpp index febb5207c0..e104b97ae6 100644 --- a/src/Dialect/ONNX/ONNXOps.hpp +++ b/src/Dialect/ONNX/ONNXOps.hpp @@ -26,7 +26,9 @@ namespace mlir { // OpSet level supported by onnx-mlir -static constexpr int CURRENT_ONNX_OPSET = 20; +// To update all occurrence of the current ONNX opset, please grep +// "CURRENT_ONNX_OPSET" and update all locations accordingly. +static constexpr int CURRENT_ONNX_OPSET = 22; } // end namespace mlir #define GET_OP_CLASSES diff --git a/utils/pre-onnx-mlir.py b/utils/pre-onnx-mlir.py index 66a5ea4a2c..54f8b6f517 100644 --- a/utils/pre-onnx-mlir.py +++ b/utils/pre-onnx-mlir.py @@ -39,7 +39,8 @@ # ==UPDATE_ONNX_VERSION_OPSET== # Look for tag above and update all references when upgrading the ONNX support within ONNX-MLIR. -current_onnx_opset = 21 +# To update all occurrence of the current ONNX opset, please grep "CURRENT_ONNX_OPSET" and update all locations accordingly. +current_onnx_opset = 22 converted_model = version_converter.convert_version(original_model, current_onnx_opset) @@ -58,4 +59,4 @@ + ".onnx" ) onnx.save(converted_model, outFile) - print("The converted model is aved to " + outFile) + print("The converted model is saved to " + outFile) From cf17e0de27029e05dda3c27188fe3b7ee9ef1bd2 Mon Sep 17 00:00:00 2001 From: srcarroll <50210727+srcarroll@users.noreply.github.com> Date: Mon, 27 Jan 2025 11:12:43 -0600 Subject: [PATCH 5/6] Remove element type restriction in softmax lowering (#3051) Signed-off-by: Sam --- .../ONNXToStablehlo/Math/Softmax.cpp | 6 +- .../Math/Softmax-Decompose.mlir | 101 +++++++++++++++ .../onnx_to_stablehlo/Math/Softmax.mlir | 116 ++++-------------- 3 files changed, 125 insertions(+), 98 deletions(-) create mode 100644 test/mlir/conversion/onnx_to_stablehlo/Math/Softmax-Decompose.mlir diff --git a/src/Conversion/ONNXToStablehlo/Math/Softmax.cpp b/src/Conversion/ONNXToStablehlo/Math/Softmax.cpp index f1b0d61657..fa833d0fbd 100644 --- a/src/Conversion/ONNXToStablehlo/Math/Softmax.cpp +++ b/src/Conversion/ONNXToStablehlo/Math/Softmax.cpp @@ -126,10 +126,6 @@ struct ONNXSoftmaxOpLoweringToStablehlo : public ConversionPattern { Location loc = op->getLoc(); Type outputType = *op->result_type_begin(); assert(isRankedShapedType(outputType) && "Expected Ranked ShapedType"); - assert(mlir::cast(operand.getType()) - .getElementType() - .isF32() && - "Currently Only float32 is supported for input"); // Exponential operation Value ElementwiseExpStableHLO = rewriter.create( @@ -204,4 +200,4 @@ void populateLoweringONNXSoftmaxOpToStablehloPattern( RewritePatternSet &patterns, MLIRContext *ctx) { patterns.insert(ctx); } -} // namespace onnx_mlir \ No newline at end of file +} // namespace onnx_mlir diff --git a/test/mlir/conversion/onnx_to_stablehlo/Math/Softmax-Decompose.mlir b/test/mlir/conversion/onnx_to_stablehlo/Math/Softmax-Decompose.mlir new file mode 100644 index 0000000000..0da75f096a --- /dev/null +++ b/test/mlir/conversion/onnx_to_stablehlo/Math/Softmax-Decompose.mlir @@ -0,0 +1,101 @@ +// RUN: onnx-mlir-opt --decompose-onnx="target=stablehlo" --convert-onnx-to-stablehlo %s --canonicalize -split-input-file | FileCheck %s + +func.func @test_softmax(%arg0 : tensor<10x20x30xf32>) -> tensor<10x20x30xf32> { + %0 = "onnx.Softmax"(%arg0) {axis = 1: si64} : (tensor<10x20x30xf32>) -> tensor<10x20x30xf32> + "func.return"(%0) : (tensor<10x20x30xf32>) -> () +} + +// CHECK-LABEL: func.func @test_softmax +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x20x30xf32>) -> tensor<10x20x30xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = shape.const_shape [10, 1, 30] : tensor<3xindex> +// CHECK-DAG: [[VAR_1_:%.+]] = shape.const_shape [10, 20, 30] : tensor<3xindex> +// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor +// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor +// CHECK: [[VAR_4_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_3_]]) applies stablehlo.maximum across dimensions = [1] : (tensor<10x20x30xf32>, tensor) -> tensor<10x30xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = stablehlo.dynamic_reshape [[VAR_4_]], [[VAR_0_]] : (tensor<10x30xf32>, tensor<3xindex>) -> tensor<10x1x30xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_1_]], dims = [0, 1, 2] : (tensor<10x20x30xf32>, tensor<3xindex>) -> tensor<10x20x30xf32> +// CHECK: [[VAR_7_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_5_]], [[VAR_1_]], dims = [0, 1, 2] : (tensor<10x1x30xf32>, tensor<3xindex>) -> tensor<10x20x30xf32> +// CHECK: [[VAR_8_:%.+]] = stablehlo.subtract [[VAR_6_]], [[VAR_7_]] : tensor<10x20x30xf32> +// CHECK: [[VAR_9_:%.+]] = stablehlo.exponential [[VAR_8_]] : tensor<10x20x30xf32> +// CHECK: [[VAR_10_:%.+]] = stablehlo.reduce([[VAR_9_]] init: [[VAR_2_]]) applies stablehlo.add across dimensions = [1] : (tensor<10x20x30xf32>, tensor) -> tensor<10x30xf32> +// CHECK-DAG: [[VAR_11_:%.+]] = stablehlo.dynamic_reshape [[VAR_10_]], [[VAR_0_]] : (tensor<10x30xf32>, tensor<3xindex>) -> tensor<10x1x30xf32> +// CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_9_]], [[VAR_1_]], dims = [0, 1, 2] : (tensor<10x20x30xf32>, tensor<3xindex>) -> tensor<10x20x30xf32> +// CHECK: [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_11_]], [[VAR_1_]], dims = [0, 1, 2] : (tensor<10x1x30xf32>, tensor<3xindex>) -> tensor<10x20x30xf32> +// CHECK: [[VAR_14_:%.+]] = stablehlo.divide [[VAR_12_]], [[VAR_13_]] : tensor<10x20x30xf32> +// CHECK: return [[VAR_14_]] : tensor<10x20x30xf32> +// CHECK: } + +// ----- + +func.func @test_softmax_dynamic(%arg0 : tensor) -> tensor { + %0 = "onnx.Softmax"(%arg0) {axis = 1: si64} : (tensor) -> tensor + "func.return"(%0) : (tensor) -> () +} + +// CHECK-LABEL: func.func @test_softmax_dynamic +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { +// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_1_]]) applies stablehlo.maximum across dimensions = [1] : (tensor, tensor) -> tensor +// CHECK-DAG: [[VAR_3_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor -> tensor<3xindex> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_4_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_0_]] : tensor<3xindex>, index -> index +// CHECK-DAG: [[VAR_5_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_2_]] : tensor<3xindex>, index -> index +// CHECK: [[VAR_6_:%.+]] = shape.from_extents [[VAR_4_]], [[CST_1_]], [[VAR_5_]] : index, index, index +// CHECK: [[VAR_7_:%.+]] = shape.to_extent_tensor [[VAR_6_]] : !shape.shape -> tensor<3xindex> +// CHECK-DAG: [[VAR_8_:%.+]] = stablehlo.dynamic_reshape [[VAR_2_]], [[VAR_7_]] : (tensor, tensor<3xindex>) -> tensor +// CHECK-DAG: [[VAR_9_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor -> tensor<3xindex> +// CHECK: [[VAR_10_:%.+]] = shape.shape_of [[VAR_8_]] : tensor -> tensor<3xindex> +// CHECK: [[VAR_11_:%.+]] = shape.broadcast [[VAR_9_]], [[VAR_10_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex> +// CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor +// CHECK-DAG: [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_8_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor +// CHECK: [[VAR_14_:%.+]] = stablehlo.subtract [[VAR_12_]], [[VAR_13_]] : tensor +// CHECK: [[VAR_15_:%.+]] = stablehlo.exponential [[VAR_14_]] : tensor +// CHECK-DAG: [[VAR_16_:%.+]] = stablehlo.reduce([[VAR_15_]] init: [[VAR_0_]]) applies stablehlo.add across dimensions = [1] : (tensor, tensor) -> tensor +// CHECK-DAG: [[VAR_17_:%.+]] = shape.shape_of [[VAR_15_]] : tensor -> tensor<3xindex> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_18_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_0_]] : tensor<3xindex>, index -> index +// CHECK-DAG: [[VAR_19_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_2_]] : tensor<3xindex>, index -> index +// CHECK: [[VAR_20_:%.+]] = shape.from_extents [[VAR_18_]], [[CST_1_]], [[VAR_19_]] : index, index, index +// CHECK: [[VAR_21_:%.+]] = shape.to_extent_tensor [[VAR_20_]] : !shape.shape -> tensor<3xindex> +// CHECK-DAG: [[VAR_22_:%.+]] = stablehlo.dynamic_reshape [[VAR_16_]], [[VAR_21_]] : (tensor, tensor<3xindex>) -> tensor +// CHECK-DAG: [[VAR_23_:%.+]] = shape.shape_of [[VAR_15_]] : tensor -> tensor<3xindex> +// CHECK: [[VAR_24_:%.+]] = shape.shape_of [[VAR_22_]] : tensor -> tensor<3xindex> +// CHECK: [[VAR_25_:%.+]] = shape.broadcast [[VAR_23_]], [[VAR_24_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex> +// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_15_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor +// CHECK-DAG: [[VAR_27_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_22_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor +// CHECK: [[VAR_28_:%.+]] = stablehlo.divide [[VAR_26_]], [[VAR_27_]] : tensor +// CHECK: return [[VAR_28_]] : tensor +// CHECK: } + + +// ----- + +func.func @test_softmax_2d(%arg0 : tensor<1x10xf32>) -> tensor<1x10xf32> { + %0 = "onnx.Softmax"(%arg0) {axis = -1 : si64} : (tensor<1x10xf32>) -> tensor<1x10xf32> + "func.return"(%0) : (tensor<1x10xf32>) -> () +} + +// CHECK-LABEL: func.func @test_softmax_2d +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x10xf32>) -> tensor<1x10xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = shape.const_shape [1, 1] : tensor<2xindex> +// CHECK-DAG: [[VAR_1_:%.+]] = shape.const_shape [1, 10] : tensor<2xindex> +// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor +// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor +// CHECK: [[VAR_4_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_3_]]) applies stablehlo.maximum across dimensions = [1] : (tensor<1x10xf32>, tensor) -> tensor<1xf32> +// CHECK-DAG: [[VAR_5_:%.+]] = stablehlo.dynamic_reshape [[VAR_4_]], [[VAR_0_]] : (tensor<1xf32>, tensor<2xindex>) -> tensor<1x1xf32> +// CHECK-DAG: [[VAR_6_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_1_]], dims = [0, 1] : (tensor<1x10xf32>, tensor<2xindex>) -> tensor<1x10xf32> +// CHECK: [[VAR_7_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_5_]], [[VAR_1_]], dims = [0, 1] : (tensor<1x1xf32>, tensor<2xindex>) -> tensor<1x10xf32> +// CHECK: [[VAR_8_:%.+]] = stablehlo.subtract [[VAR_6_]], [[VAR_7_]] : tensor<1x10xf32> +// CHECK: [[VAR_9_:%.+]] = stablehlo.exponential [[VAR_8_]] : tensor<1x10xf32> +// CHECK: [[VAR_10_:%.+]] = stablehlo.reduce([[VAR_9_]] init: [[VAR_2_]]) applies stablehlo.add across dimensions = [1] : (tensor<1x10xf32>, tensor) -> tensor<1xf32> +// CHECK-DAG: [[VAR_11_:%.+]] = stablehlo.dynamic_reshape [[VAR_10_]], [[VAR_0_]] : (tensor<1xf32>, tensor<2xindex>) -> tensor<1x1xf32> +// CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_9_]], [[VAR_1_]], dims = [0, 1] : (tensor<1x10xf32>, tensor<2xindex>) -> tensor<1x10xf32> +// CHECK: [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_11_]], [[VAR_1_]], dims = [0, 1] : (tensor<1x1xf32>, tensor<2xindex>) -> tensor<1x10xf32> +// CHECK: [[VAR_14_:%.+]] = stablehlo.divide [[VAR_12_]], [[VAR_13_]] : tensor<1x10xf32> +// CHECK: return [[VAR_14_]] : tensor<1x10xf32> +// CHECK: } diff --git a/test/mlir/conversion/onnx_to_stablehlo/Math/Softmax.mlir b/test/mlir/conversion/onnx_to_stablehlo/Math/Softmax.mlir index 0da75f096a..3fe15a13d1 100644 --- a/test/mlir/conversion/onnx_to_stablehlo/Math/Softmax.mlir +++ b/test/mlir/conversion/onnx_to_stablehlo/Math/Softmax.mlir @@ -1,101 +1,31 @@ -// RUN: onnx-mlir-opt --decompose-onnx="target=stablehlo" --convert-onnx-to-stablehlo %s --canonicalize -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --convert-onnx-to-stablehlo %s --canonicalize -split-input-file | FileCheck %s -func.func @test_softmax(%arg0 : tensor<10x20x30xf32>) -> tensor<10x20x30xf32> { - %0 = "onnx.Softmax"(%arg0) {axis = 1: si64} : (tensor<10x20x30xf32>) -> tensor<10x20x30xf32> - "func.return"(%0) : (tensor<10x20x30xf32>) -> () +func.func @test_softmax_bf16(%arg0 : tensor<10x20x30xbf16>) -> tensor<10x20x30xbf16> { + %0 = "onnx.Softmax"(%arg0) {axis = 1: si64} : (tensor<10x20x30xbf16>) -> tensor<10x20x30xbf16> + "func.return"(%0) : (tensor<10x20x30xbf16>) -> () } -// CHECK-LABEL: func.func @test_softmax -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x20x30xf32>) -> tensor<10x20x30xf32> { -// CHECK-DAG: [[VAR_0_:%.+]] = shape.const_shape [10, 1, 30] : tensor<3xindex> -// CHECK-DAG: [[VAR_1_:%.+]] = shape.const_shape [10, 20, 30] : tensor<3xindex> -// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor -// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor -// CHECK: [[VAR_4_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_3_]]) applies stablehlo.maximum across dimensions = [1] : (tensor<10x20x30xf32>, tensor) -> tensor<10x30xf32> -// CHECK-DAG: [[VAR_5_:%.+]] = stablehlo.dynamic_reshape [[VAR_4_]], [[VAR_0_]] : (tensor<10x30xf32>, tensor<3xindex>) -> tensor<10x1x30xf32> -// CHECK-DAG: [[VAR_6_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_1_]], dims = [0, 1, 2] : (tensor<10x20x30xf32>, tensor<3xindex>) -> tensor<10x20x30xf32> -// CHECK: [[VAR_7_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_5_]], [[VAR_1_]], dims = [0, 1, 2] : (tensor<10x1x30xf32>, tensor<3xindex>) -> tensor<10x20x30xf32> -// CHECK: [[VAR_8_:%.+]] = stablehlo.subtract [[VAR_6_]], [[VAR_7_]] : tensor<10x20x30xf32> -// CHECK: [[VAR_9_:%.+]] = stablehlo.exponential [[VAR_8_]] : tensor<10x20x30xf32> -// CHECK: [[VAR_10_:%.+]] = stablehlo.reduce([[VAR_9_]] init: [[VAR_2_]]) applies stablehlo.add across dimensions = [1] : (tensor<10x20x30xf32>, tensor) -> tensor<10x30xf32> -// CHECK-DAG: [[VAR_11_:%.+]] = stablehlo.dynamic_reshape [[VAR_10_]], [[VAR_0_]] : (tensor<10x30xf32>, tensor<3xindex>) -> tensor<10x1x30xf32> -// CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_9_]], [[VAR_1_]], dims = [0, 1, 2] : (tensor<10x20x30xf32>, tensor<3xindex>) -> tensor<10x20x30xf32> -// CHECK: [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_11_]], [[VAR_1_]], dims = [0, 1, 2] : (tensor<10x1x30xf32>, tensor<3xindex>) -> tensor<10x20x30xf32> -// CHECK: [[VAR_14_:%.+]] = stablehlo.divide [[VAR_12_]], [[VAR_13_]] : tensor<10x20x30xf32> -// CHECK: return [[VAR_14_]] : tensor<10x20x30xf32> -// CHECK: } +// CHECK-LABEL: func.func @test_softmax_bf16 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x20x30xbf16>) -> tensor<10x20x30xbf16> { +// CHECK: [[CST:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor +// CHECK-NEXT: [[EXP:%.+]] = stablehlo.exponential [[PARAM_0_]] : tensor<10x20x30xbf16> +// CHECK-NEXT: [[REDUCE:%.+]] = stablehlo.reduce([[EXP]] init: [[CST]]) applies stablehlo.add across dimensions = [1] : (tensor<10x20x30xbf16>, tensor) -> tensor<10x30xbf16> +// CHECK-NEXT: [[DENOM:%.+]] = stablehlo.broadcast_in_dim [[REDUCE]], dims = [0, 2] : (tensor<10x30xbf16>) -> tensor<10x20x30xbf16> +// CHECK-NEXT: [[RES:%.+]] = stablehlo.divide [[EXP]], [[DENOM]] : tensor<10x20x30xbf16> +// CHECK-NEXT: return [[RES]] : tensor<10x20x30xbf16> // ----- -func.func @test_softmax_dynamic(%arg0 : tensor) -> tensor { - %0 = "onnx.Softmax"(%arg0) {axis = 1: si64} : (tensor) -> tensor - "func.return"(%0) : (tensor) -> () +func.func @test_softmax_f64(%arg0 : tensor<10x20x30xf64>) -> tensor<10x20x30xf64> { + %0 = "onnx.Softmax"(%arg0) {axis = -1: si64} : (tensor<10x20x30xf64>) -> tensor<10x20x30xf64> + "func.return"(%0) : (tensor<10x20x30xf64>) -> () } -// CHECK-LABEL: func.func @test_softmax_dynamic -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor) -> tensor { -// CHECK-DAG: [[VAR_0_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor -// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : index -// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[VAR_1_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_1_]]) applies stablehlo.maximum across dimensions = [1] : (tensor, tensor) -> tensor -// CHECK-DAG: [[VAR_3_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor -> tensor<3xindex> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_4_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_0_]] : tensor<3xindex>, index -> index -// CHECK-DAG: [[VAR_5_:%.+]] = shape.get_extent [[VAR_3_]], [[CST_2_]] : tensor<3xindex>, index -> index -// CHECK: [[VAR_6_:%.+]] = shape.from_extents [[VAR_4_]], [[CST_1_]], [[VAR_5_]] : index, index, index -// CHECK: [[VAR_7_:%.+]] = shape.to_extent_tensor [[VAR_6_]] : !shape.shape -> tensor<3xindex> -// CHECK-DAG: [[VAR_8_:%.+]] = stablehlo.dynamic_reshape [[VAR_2_]], [[VAR_7_]] : (tensor, tensor<3xindex>) -> tensor -// CHECK-DAG: [[VAR_9_:%.+]] = shape.shape_of [[PARAM_0_]] : tensor -> tensor<3xindex> -// CHECK: [[VAR_10_:%.+]] = shape.shape_of [[VAR_8_]] : tensor -> tensor<3xindex> -// CHECK: [[VAR_11_:%.+]] = shape.broadcast [[VAR_9_]], [[VAR_10_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex> -// CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor -// CHECK-DAG: [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_8_]], [[VAR_11_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor -// CHECK: [[VAR_14_:%.+]] = stablehlo.subtract [[VAR_12_]], [[VAR_13_]] : tensor -// CHECK: [[VAR_15_:%.+]] = stablehlo.exponential [[VAR_14_]] : tensor -// CHECK-DAG: [[VAR_16_:%.+]] = stablehlo.reduce([[VAR_15_]] init: [[VAR_0_]]) applies stablehlo.add across dimensions = [1] : (tensor, tensor) -> tensor -// CHECK-DAG: [[VAR_17_:%.+]] = shape.shape_of [[VAR_15_]] : tensor -> tensor<3xindex> -// CHECK-NOT: separator of consecutive DAGs -// CHECK-DAG: [[VAR_18_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_0_]] : tensor<3xindex>, index -> index -// CHECK-DAG: [[VAR_19_:%.+]] = shape.get_extent [[VAR_17_]], [[CST_2_]] : tensor<3xindex>, index -> index -// CHECK: [[VAR_20_:%.+]] = shape.from_extents [[VAR_18_]], [[CST_1_]], [[VAR_19_]] : index, index, index -// CHECK: [[VAR_21_:%.+]] = shape.to_extent_tensor [[VAR_20_]] : !shape.shape -> tensor<3xindex> -// CHECK-DAG: [[VAR_22_:%.+]] = stablehlo.dynamic_reshape [[VAR_16_]], [[VAR_21_]] : (tensor, tensor<3xindex>) -> tensor -// CHECK-DAG: [[VAR_23_:%.+]] = shape.shape_of [[VAR_15_]] : tensor -> tensor<3xindex> -// CHECK: [[VAR_24_:%.+]] = shape.shape_of [[VAR_22_]] : tensor -> tensor<3xindex> -// CHECK: [[VAR_25_:%.+]] = shape.broadcast [[VAR_23_]], [[VAR_24_]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex> -// CHECK-DAG: [[VAR_26_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_15_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor -// CHECK-DAG: [[VAR_27_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_22_]], [[VAR_25_]], dims = [0, 1, 2] : (tensor, tensor<3xindex>) -> tensor -// CHECK: [[VAR_28_:%.+]] = stablehlo.divide [[VAR_26_]], [[VAR_27_]] : tensor -// CHECK: return [[VAR_28_]] : tensor -// CHECK: } - - -// ----- - -func.func @test_softmax_2d(%arg0 : tensor<1x10xf32>) -> tensor<1x10xf32> { - %0 = "onnx.Softmax"(%arg0) {axis = -1 : si64} : (tensor<1x10xf32>) -> tensor<1x10xf32> - "func.return"(%0) : (tensor<1x10xf32>) -> () -} - -// CHECK-LABEL: func.func @test_softmax_2d -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x10xf32>) -> tensor<1x10xf32> { -// CHECK-DAG: [[VAR_0_:%.+]] = shape.const_shape [1, 1] : tensor<2xindex> -// CHECK-DAG: [[VAR_1_:%.+]] = shape.const_shape [1, 10] : tensor<2xindex> -// CHECK-DAG: [[VAR_2_:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor -// CHECK-DAG: [[VAR_3_:%.+]] = stablehlo.constant dense<0xFF800000> : tensor -// CHECK: [[VAR_4_:%.+]] = stablehlo.reduce([[PARAM_0_]] init: [[VAR_3_]]) applies stablehlo.maximum across dimensions = [1] : (tensor<1x10xf32>, tensor) -> tensor<1xf32> -// CHECK-DAG: [[VAR_5_:%.+]] = stablehlo.dynamic_reshape [[VAR_4_]], [[VAR_0_]] : (tensor<1xf32>, tensor<2xindex>) -> tensor<1x1xf32> -// CHECK-DAG: [[VAR_6_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[PARAM_0_]], [[VAR_1_]], dims = [0, 1] : (tensor<1x10xf32>, tensor<2xindex>) -> tensor<1x10xf32> -// CHECK: [[VAR_7_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_5_]], [[VAR_1_]], dims = [0, 1] : (tensor<1x1xf32>, tensor<2xindex>) -> tensor<1x10xf32> -// CHECK: [[VAR_8_:%.+]] = stablehlo.subtract [[VAR_6_]], [[VAR_7_]] : tensor<1x10xf32> -// CHECK: [[VAR_9_:%.+]] = stablehlo.exponential [[VAR_8_]] : tensor<1x10xf32> -// CHECK: [[VAR_10_:%.+]] = stablehlo.reduce([[VAR_9_]] init: [[VAR_2_]]) applies stablehlo.add across dimensions = [1] : (tensor<1x10xf32>, tensor) -> tensor<1xf32> -// CHECK-DAG: [[VAR_11_:%.+]] = stablehlo.dynamic_reshape [[VAR_10_]], [[VAR_0_]] : (tensor<1xf32>, tensor<2xindex>) -> tensor<1x1xf32> -// CHECK-DAG: [[VAR_12_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_9_]], [[VAR_1_]], dims = [0, 1] : (tensor<1x10xf32>, tensor<2xindex>) -> tensor<1x10xf32> -// CHECK: [[VAR_13_:%.+]] = stablehlo.dynamic_broadcast_in_dim [[VAR_11_]], [[VAR_1_]], dims = [0, 1] : (tensor<1x1xf32>, tensor<2xindex>) -> tensor<1x10xf32> -// CHECK: [[VAR_14_:%.+]] = stablehlo.divide [[VAR_12_]], [[VAR_13_]] : tensor<1x10xf32> -// CHECK: return [[VAR_14_]] : tensor<1x10xf32> -// CHECK: } +// CHECK-LABEL: func.func @test_softmax_f64 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x20x30xf64>) -> tensor<10x20x30xf64> { +// CHECK: [[CST:%.+]] = stablehlo.constant dense<0.000000e+00> : tensor +// CHECK-NEXT: [[EXP:%.+]] = stablehlo.exponential [[PARAM_0_]] : tensor<10x20x30xf64> +// CHECK-NEXT: [[REDUCE:%.+]] = stablehlo.reduce([[EXP]] init: [[CST]]) applies stablehlo.add across dimensions = [2] : (tensor<10x20x30xf64>, tensor) -> tensor<10x20xf64> +// CHECK-NEXT: [[DENOM:%.+]] = stablehlo.broadcast_in_dim [[REDUCE]], dims = [0, 1] : (tensor<10x20xf64>) -> tensor<10x20x30xf64> +// CHECK-NEXT: [[RES:%.+]] = stablehlo.divide [[EXP]], [[DENOM]] : tensor<10x20x30xf64> +// CHECK-NEXT: return [[RES]] : tensor<10x20x30xf64> From 2a8b11109fa4c7bd08ac08d2f4d1fa7838c2531d Mon Sep 17 00:00:00 2001 From: Jonas Rickert Date: Wed, 29 Jan 2025 08:59:11 +0100 Subject: [PATCH 6/6] Fix ASAN/UBSAN issues in DimAnalysis (#3052) - Fixes a memory leak - Fixes an integer overflow caused by a dynamic shape - Fixes reshape to wrong type in LIT tests Signed-off-by: Rickert, Jonas --- src/Dialect/ONNX/ONNXDimAnalysis.cpp | 16 +++++++------ test/mlir/onnx/onnx_dim_analysis.mlir | 34 +++++++++++++-------------- 2 files changed, 26 insertions(+), 24 deletions(-) diff --git a/src/Dialect/ONNX/ONNXDimAnalysis.cpp b/src/Dialect/ONNX/ONNXDimAnalysis.cpp index 17a91a42ec..26e5e97a62 100644 --- a/src/Dialect/ONNX/ONNXDimAnalysis.cpp +++ b/src/Dialect/ONNX/ONNXDimAnalysis.cpp @@ -361,11 +361,11 @@ static bool exploreSameDimsUsingShapeHelper(const DimAnalysis::DimT &dim, ONNXOpShapeHelper *shapeHelper = shape_op.getShapeHelper(op, {}, nullptr, nullptr); // If no shape helper, or unimplemented, just abort. - if (!shapeHelper || !shapeHelper->isImplemented()) + if (!shapeHelper) return false; // Compute shape. - if (failed(shapeHelper->computeShape())) { + if (!shapeHelper->isImplemented() || failed(shapeHelper->computeShape())) { delete shapeHelper; return false; } @@ -961,12 +961,14 @@ void DimAnalysis::visitDim( bool outputHasOneDynamicDim = (llvm::count(outputType.getShape(), ShapedType::kDynamic) == 1); // Check if the products of static sizes in the data and output are equal. - // It's ok to count ShapedType::kDynamic (dynamic dimension) in the size. int64_t dataStaticSize = 1, outputStaticSize = 1; - for (int64_t i = 0; i < dataType.getRank(); ++i) - dataStaticSize *= dataType.getShape()[i]; - for (int64_t i = 0; i < outputType.getRank(); ++i) - outputStaticSize *= outputType.getShape()[i]; + for (int64_t i = 0; i < dataType.getRank(); ++i) { + dataStaticSize *= dataType.isDynamicDim(i) ? -1 : dataType.getShape()[i]; + } + for (int64_t i = 0; i < outputType.getRank(); ++i) { + outputStaticSize *= + outputType.isDynamicDim(i) ? -1 : outputType.getShape()[i]; + } // Conditions hold, the dynamic dimension can be from the data. if (dataHasOneDynamicDim && outputHasOneDynamicDim && (dataStaticSize == outputStaticSize)) { diff --git a/test/mlir/onnx/onnx_dim_analysis.mlir b/test/mlir/onnx/onnx_dim_analysis.mlir index d0459f07ff..74f51c0f65 100644 --- a/test/mlir/onnx/onnx_dim_analysis.mlir +++ b/test/mlir/onnx/onnx_dim_analysis.mlir @@ -184,38 +184,38 @@ func.func @test_matmul_batchsize(%arg0: tensor) -> tensor) -> tensor<8x?x16x32xf32> { +func.func @test_matmul_batchsize_diff_rank(%arg0: tensor<8x?x16x4xf32>) -> tensor<8x?x16x128xf32> { %shape = onnx.Constant dense<[-1, 4, 128]> : tensor<3xi64> - %0 = "onnx.Reshape"(%arg0, %shape) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor - %1 = "onnx.MatMul"(%arg0, %0) : (tensor<8x?x16x4xf32>, tensor) -> tensor<8x?x16x32xf32> - "onnx.Return"(%1) : (tensor<8x?x16x32xf32>) -> () + %0 = "onnx.Reshape"(%arg0, %shape) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor + %1 = "onnx.MatMul"(%arg0, %0) : (tensor<8x?x16x4xf32>, tensor) -> tensor<8x?x16x128xf32> + "onnx.Return"(%1) : (tensor<8x?x16x128xf32>) -> () // CHECK-LABEL: func.func @test_matmul_batchsize_diff_rank -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<8x?x16x4xf32>) -> tensor<8x?x16x32xf32> { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<8x?x16x4xf32>) -> tensor<8x?x16x128xf32> { // CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 1 : si64, group_id = 0 : si64} : (tensor<8x?x16x4xf32>) -> () // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[-1, 4, 128]> : tensor<3xi64> -// CHECK: [[VAR_1_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_0_]]) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor -// CHECK: "onnx.DimGroup"([[VAR_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () -// CHECK: [[VAR_2_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_1_]]) : (tensor<8x?x16x4xf32>, tensor) -> tensor<8x?x16x32xf32> -// CHECK: "onnx.DimGroup"([[VAR_2_]]) {axis = 1 : si64, group_id = 0 : si64} : (tensor<8x?x16x32xf32>) -> () -// CHECK: onnx.Return [[VAR_2_]] : tensor<8x?x16x32xf32> +// CHECK: [[VAR_1_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_0_]]) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor +// CHECK: "onnx.DimGroup"([[VAR_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () +// CHECK: [[VAR_2_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_1_]]) : (tensor<8x?x16x4xf32>, tensor) -> tensor<8x?x16x128xf32> +// CHECK: "onnx.DimGroup"([[VAR_2_]]) {axis = 1 : si64, group_id = 0 : si64} : (tensor<8x?x16x128xf32>) -> () +// CHECK: onnx.Return [[VAR_2_]] : tensor<8x?x16x128xf32> // CHECK: } } // ----- -func.func @test_reshape_single_dyn_dim(%arg0: tensor<8x?x16x4xf32>) -> tensor { +func.func @test_reshape_single_dyn_dim(%arg0: tensor<8x?x16x4xf32>) -> tensor { %shape = onnx.Constant dense<[-1, 4, 128]> : tensor<3xi64> - %0 = "onnx.Reshape"(%arg0, %shape) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor - "onnx.Return"(%0) : (tensor) -> () + %0 = "onnx.Reshape"(%arg0, %shape) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor + "onnx.Return"(%0) : (tensor) -> () // CHECK-LABEL: func.func @test_reshape_single_dyn_dim -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<8x?x16x4xf32>) -> tensor { +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<8x?x16x4xf32>) -> tensor { // CHECK-DAG: "onnx.DimGroup"([[PARAM_0_]]) {axis = 1 : si64, group_id = 0 : si64} : (tensor<8x?x16x4xf32>) -> () // CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[-1, 4, 128]> : tensor<3xi64> -// CHECK: [[VAR_1_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_0_]]) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor -// CHECK: "onnx.DimGroup"([[VAR_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () -// CHECK: onnx.Return [[VAR_1_]] : tensor +// CHECK: [[VAR_1_:%.+]] = "onnx.Reshape"([[PARAM_0_]], [[VAR_0_]]) {allowzero = 0 : si64} : (tensor<8x?x16x4xf32>, tensor<3xi64>) -> tensor +// CHECK: "onnx.DimGroup"([[VAR_1_]]) {axis = 0 : si64, group_id = 0 : si64} : (tensor) -> () +// CHECK: onnx.Return [[VAR_1_]] : tensor // CHECK: } }