add external margin (#713)
Add envvar for SM margin in GEMM

Signed-off-by: Rachit Garg <[email protected]>
Co-authored-by: Rachit Garg <[email protected]>
rachitgarg91 and Rachit Garg authored Mar 13, 2024
1 parent a38b291 commit e3d2efd
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
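
In outline, the change reads an integer from the environment and subtracts it from the SM count handed to math kernels. A minimal standalone sketch of that logic follows; getenv_int is a stand-in for TE's internal transformer_engine::getenv&lt;int&gt; helper (assumed here to parse an integer environment variable with a fallback default), not the library's actual API:

    // Minimal sketch of the margin logic added by this commit.
    // getenv_int stands in for transformer_engine::getenv<int>.
    #include <cstdlib>
    #include <string>
    #include <cuda_runtime.h>

    int getenv_int(const char* name, int default_value) {
      const char* value = std::getenv(name);
      return value ? std::stoi(value) : default_value;
    }

    int main() {
      cudaDeviceProp prop;
      cudaGetDeviceProperties(&prop, 0);
      // Reserve NVTE_EXT_MARGIN_SM SMs for other work (e.g. communication
      // kernels overlapped with the GEMM); the rest do the math.
      int num_math_sms =
          prop.multiProcessorCount - getenv_int("NVTE_EXT_MARGIN_SM", 0);
      return num_math_sms > 0 ? 0 : 1;
    }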
transformer_engine/pytorch/csrc/comm_gemm_overlap.h (3 additions, 0 deletions)
@@ -18,6 +18,7 @@
#include <torch/types.h>

#include "common/util/logging.h"
#include "common/util/system.h"
#include "userbuffers/userbuffers.h"

#define HALF_BYTES 2
@@ -112,6 +113,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
_math_sms = (set_sm_margin) ? prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount;
_math_sms -= transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", 0);

output_tensor = torch::Tensor();
auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
@@ -587,6 +589,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
_math_sms = (set_sm_margin) ? prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount;
_math_sms -= transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", 0);

_tp_size = tp_size;
_aggregate2 = aggregate2;
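For both overlap classes the resulting arithmetic is the same. As an illustration (the SM count is hypothetical): on a 132-SM device with set_sm_margin enabled, num_comm_sm = 16, and NVTE_EXT_MARGIN_SM=8, the GEMMs get _math_sms = 132 - 16 - 8 = 108 SMs; with the variable unset, the default of 0 leaves the previous behavior unchanged.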
transformer_engine/pytorch/csrc/ts_fp8_op.cpp (11 additions, 2 deletions)
@@ -6,7 +6,9 @@

#include <torch/script.h>
#include "extensions.h"

#include <cuda.h>
#include <cuda_fp8.h>
#include "common/util/system.h"

namespace {
transformer_engine::DType reverse_map_dtype(int64_t dtype) {
@@ -316,6 +318,13 @@ at::Tensor te_gemm_ts(at::Tensor A,
bool accumulate_arg = static_cast<bool>(accumulate);
bool use_split_accumulator_arg = static_cast<bool>(use_split_accumulator);

// Apply an external SM margin to all GEMMs. This comes in handy
// when data-parallel (DP) communication is overlapped with GEMMs.
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
int num_math_sms = prop.multiProcessorCount
                   - transformer_engine::getenv<int>("NVTE_EXT_MARGIN_SM", 0);

if (A_scale_inverse.numel())
A_scale_inverse = A_scale_inverse[A_fp8_tensor];

@@ -342,7 +351,7 @@ at::Tensor te_gemm_ts(at::Tensor A,
workspaceSize_arg,
accumulate_arg,
use_split_accumulator_arg,
- 0);
+ num_math_sms);
return D;
}

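The variable is read when the overlap objects are constructed and on every te_gemm_ts call, so it must be in the environment before those paths run. A hypothetical launcher-side sketch (the margin of 8 is arbitrary; setenv is the POSIX call):

    #include <cstdlib>

    int main() {
      // Must be set before Transformer Engine reads it; unset, the
      // default of 0 means no external margin (the old behavior).
      setenv("NVTE_EXT_MARGIN_SM", "8", /*overwrite=*/1);
      // ... initialize Transformer Engine and run training here ...
      return 0;
    }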
