diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 827dec5010..5f8ccab334 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -18,6 +18,7 @@ #include #include "common/util/logging.h" +#include "common/util/system.h" #include "userbuffers/userbuffers.h" #define HALF_BYTES 2 @@ -112,6 +113,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { cudaDeviceProp prop; cudaGetDeviceProperties(&prop, 0); _math_sms = (set_sm_margin) ? prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount; + _math_sms -= transformer_engine::getenv("NVTE_EXT_MARGIN_SM", 0); output_tensor = torch::Tensor(); auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); @@ -587,6 +589,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { cudaDeviceProp prop; cudaGetDeviceProperties(&prop, 0); _math_sms = (set_sm_margin) ? prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount; + _math_sms -= transformer_engine::getenv("NVTE_EXT_MARGIN_SM", 0); _tp_size = tp_size; _aggregate2 = aggregate2; diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp index b25a8cf110..71402d2001 100755 --- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp +++ b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp @@ -6,7 +6,9 @@ #include #include "extensions.h" - +#include +#include +#include "common/util/system.h" namespace { transformer_engine::DType reverse_map_dtype(int64_t dtype) { @@ -316,6 +318,13 @@ at::Tensor te_gemm_ts(at::Tensor A, bool accumulate_arg = static_cast(accumulate); bool use_split_accumulator_arg = static_cast(use_split_accumulator); + // Set an external SM Margin to all the GEMMs. + // This comes in handy when DP is overlapped with GEMMs + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, 0); + int num_math_sms = prop.multiProcessorCount \ + - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", 0); + if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; @@ -342,7 +351,7 @@ at::Tensor te_gemm_ts(at::Tensor A, workspaceSize_arg, accumulate_arg, use_split_accumulator_arg, - 0); + num_math_sms); return D; }