Skip to content

Commit

Permalink
Add additional carveout for DP
Browse files Browse the repository at this point in the history
Signed-off-by: rachitg <[email protected]>
  • Loading branch information
rachitg committed Mar 11, 2024
1 parent 8255f87 commit ef84339
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions transformer_engine/pytorch/csrc/comm_gemm_overlap.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
cga_size = comm_cga_size;
_empty_tensor = empty_tensor;

const char* ext_margin_sm = std::getenv("NVTE_EXT_MARGIN_SM");

int num_ext_margin_sm = 0;
if (ext_margin_sm != NULL){
num_ext_margin_sm = atoi(ext_margin_sm);
}
// Allocate and register extra userbuffers
int ubuf_bytes = sample.numel() * sample.element_size();
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
Expand All @@ -112,6 +118,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
_math_sms = (set_sm_margin) ? prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount;
_math_sms -= num_ext_margin_sm;

output_tensor = torch::Tensor();
auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
Expand Down Expand Up @@ -555,6 +562,12 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
sms = 1;
cga_size = 1;

const char* ext_margin_sm = std::getenv("NVTE_EXT_MARGIN_SM");
int num_ext_margin_sm = 0;
if (ext_margin_sm != NULL){
num_ext_margin_sm = atoi(ext_margin_sm);
}

_empty_tensor = empty_tensor;
// Create workspace tensor with userbuffer
int ubuf_bytes = sample.numel() * sample.element_size();
Expand Down Expand Up @@ -587,6 +600,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
_math_sms = (set_sm_margin) ? prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount;
_math_sms -= num_ext_margin_sm;

_tp_size = tp_size;
_aggregate2 = aggregate2;
Expand Down

0 comments on commit ef84339

Please sign in to comment.