From 68b5072fdb9df6b6edab1392b02a705394b2e906 Mon Sep 17 00:00:00 2001 From: "Xiaodong (Vincent) Huang" Date: Wed, 8 Mar 2023 11:42:45 +0800 Subject: [PATCH] Cherry-pick fix fcPlugin build failure on CUDA12 (#2742) Signed-off-by: Vincent Huang --- plugin/fcPlugin/fcPlugin.cpp | 26 +++++++++++++++----------- plugin/fcPlugin/fcPlugin.h | 16 ++++++++-------- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/plugin/fcPlugin/fcPlugin.cpp b/plugin/fcPlugin/fcPlugin.cpp index e9045a14..3f8e7340 100644 --- a/plugin/fcPlugin/fcPlugin.cpp +++ b/plugin/fcPlugin/fcPlugin.cpp @@ -58,7 +58,11 @@ static void printPerfStructure(const customMatmulPerf_t& perf, int const& m, int double timeAvg = perf.time * 1e-3; // Convert to seconds. It has been divided by kernelRepeats in customMatmulRun(). double gflop = (2 * static_cast(m * n) * k) * 1e-9; // Real - gLogVerbose << "Algo=" << p.algoId << " Tile=" << p.tile << " (" << matmulTileName[p.tile] << ") K=" << p.numSplitsK << " Red.Sch.=" << p.reductionScheme << " Swiz=" << p.swizzle << " Cust=" << p.customOption << " Stat=" << perf.status << " Time=" << perf.time << " WSbytes=" << perf.workspaceSize << " math=" << p.mathMode << " waves=" << perf.wavesCount << "GFlops=" << (gflop / timeAvg) << std::endl; + gLogVerbose << "Algo=" << p.algoId << " Tile=" << p.tile << " (" << matmulTileName[p.tile] << ") K=" << p.numSplitsK + << " Red.Sch.=" << p.reductionScheme << " Swiz=" << p.swizzle << " Cust=" << p.customOption + << " Stat=" << perf.status << " Time=" << perf.time << " WSbytes=" << perf.workspaceSize + << " math=" << p.numericImpl << " waves=" << perf.wavesCount << "GFlops=" << (gflop / timeAvg) + << std::endl; } static inline bool time_compare(const customMatmulPerf_t& perf_a, const customMatmulPerf_t& perf_b) @@ -166,9 +170,10 @@ void LtGemmSearch(cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOpe PLUGIN_CUBLASASSERT(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workSpaceSize, sizeof(workSpaceSize))); - const int mathMode = Ctype == CUDA_R_16F ? 1 : 0; + uint64_t const numericImplPrefer + = Ctype == CUDA_R_16F ? CUBLASLT_NUMERICAL_IMPL_FLAGS_HMMA : CUBLASLT_NUMERICAL_IMPL_FLAGS_FMA; PLUGIN_CUBLASASSERT(cublasLtMatmulPreferenceSetAttribute( - preference, CUBLASLT_MATMUL_PREF_MATH_MODE_MASK, &mathMode, sizeof(mathMode))); + preference, CUBLASLT_MATMUL_PREF_IMPL_MASK, &numericImplPrefer, sizeof(numericImplPrefer))); // Create operation descriptor; see cublasLtMatmulDescAttributes_t for details // about defaults; here we just need to set the transforms for A and B #if CUBLAS_VER_MAJOR < 11 @@ -213,13 +218,12 @@ void LtGemmSearch(cublasLtHandle_t ltHandle, cublasOperation_t transa, cublasOpe continue; } - int mathMode = -1; + uint64_t numericImpl = -1; PLUGIN_CUBLASASSERT(cublasLtMatmulAlgoCapGetAttribute( - &algo, CUBLASLT_ALGO_CAP_MATHMODE_IMPL, &mathMode, sizeof(mathMode), nullptr)); - // TODO is this the right way to check that it's SGEMM? - if (Ctype == CUDA_R_32F && mathMode == 1) + &algo, CUBLASLT_ALGO_CAP_NUMERICAL_IMPL_FLAGS, &numericImpl, sizeof(numericImpl), nullptr)); + if (Ctype == CUDA_R_32F && numericImpl == CUBLASLT_NUMERICAL_IMPL_FLAGS_HMMA) { - // if mathMode is 1, cublasLt chooses automatically to run in mixed precision for certain sizes + // skip HMMA-fp32accu kernels continue; } @@ -524,19 +528,19 @@ void FCPluginDynamic::configurePlugin( AlgoProps p; p.populate(mAlgo); - if (mType == DataType::kFLOAT && p.mathMode == 1) + if (mType == DataType::kFLOAT && p.numericImpl == CUBLASLT_NUMERICAL_IMPL_FLAGS_HMMA) { gLogWarning << "cuBLAS might use mixed precision instead of FP32" << std::endl; } - if (mType == DataType::kHALF && p.mathMode == 0) + if (mType == DataType::kHALF && p.numericImpl != CUBLASLT_NUMERICAL_IMPL_FLAGS_HMMA) { gLogWarning << "TensorCore support was not selected" << std::endl; } gLogVerbose << "FCPluginDynamic configuration Algo=" << p.algoId << " Tile=" << p.tile << " (" << matmulTileName[p.tile] << ") K=" << p.numSplitsK << " Red.Sch.=" << p.reductionScheme - << " Swiz=" << p.swizzle << " Cust=" << p.customOption << " mathMode=" << p.mathMode + << " Swiz=" << p.swizzle << " Cust=" << p.customOption << " numericImpl=" << p.numericImpl << " ws=" << actualWorkspace << std::endl; } catch (const std::exception& e) diff --git a/plugin/fcPlugin/fcPlugin.h b/plugin/fcPlugin/fcPlugin.h index 48860ab1..3bf2ed79 100644 --- a/plugin/fcPlugin/fcPlugin.h +++ b/plugin/fcPlugin/fcPlugin.h @@ -352,13 +352,13 @@ const char* const matmulTileName[] = { struct AlgoProps { - int algoId; - int tile; - int swizzle; - int customOption; - int numSplitsK; - int reductionScheme; - int mathMode; + int32_t algoId; + int32_t tile; + int32_t swizzle; + int32_t customOption; + int32_t numSplitsK; + int32_t reductionScheme; + uint64_t numericImpl; void populate(const cublasLtMatmulAlgo_t& algo) { @@ -376,7 +376,7 @@ struct AlgoProps PLUGIN_CUBLASASSERT(cublasLtMatmulAlgoConfigGetAttribute( matmulAlgo, CUBLASLT_ALGO_CONFIG_CUSTOM_OPTION, &customOption, sizeof(customOption), nullptr)); PLUGIN_CUBLASASSERT(cublasLtMatmulAlgoCapGetAttribute( - matmulAlgo, CUBLASLT_ALGO_CAP_MATHMODE_IMPL, &mathMode, sizeof(mathMode), nullptr)); + matmulAlgo, CUBLASLT_ALGO_CAP_NUMERICAL_IMPL_FLAGS, &numericImpl, sizeof(numericImpl), nullptr)); } };