From 93f15d21851315004e4993bad04ebec89012529d Mon Sep 17 00:00:00 2001
From: Dennis Liu
Date: Wed, 31 Jan 2024 01:38:13 -0800
Subject: [PATCH] code clean.

---
 transformer_engine/pytorch/module/linear.py | 63 +++++++++++----------
 1 file changed, 33 insertions(+), 30 deletions(-)

diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index f7e0ede862..7e24a01b5c 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -310,14 +310,13 @@ def forward(
         ctx.requires_dgrad = inp.requires_grad
 
         # Row Parallel Linear
-        if ub_split_rs or ub_atomic_gemm_rs:
-            out = rs_out
-        elif explicit_expert_comm:
-            out = out
-        elif parallel_mode == "row" and sequence_parallel:
-            out, _ = reduce_scatter_along_first_dim(out, tp_group)
-        elif parallel_mode == "row" and tensor_parallel:
-            out, _ = allreduce(out, tp_group)
+        if not explicit_expert_comm:
+            if ub_split_rs or ub_atomic_gemm_rs:
+                out = rs_out
+            elif parallel_mode == "row" and sequence_parallel:
+                out, _ = reduce_scatter_along_first_dim(out, tp_group)
+            elif parallel_mode == "row" and tensor_parallel:
+                out, _ = allreduce(out, tp_group)
 
         # [*, in_features] -> [*, out_features] except first dimension changes for SP
         return out.view(-1, *inp.shape[1:-1], out.shape[-1])
@@ -425,16 +424,15 @@ def backward(
                 )
 
             # Overlap dgrad-RS/AR with wgrad
-            if ctx.explicit_expert_comm:
-                dgrad = dgrad
-            elif ctx.parallel_mode == "column" and ctx.sequence_parallel:
-                if handle is not None:
-                    handle.wait()
-                dgrad, handle = reduce_scatter_along_first_dim(
-                    dgrad, ctx.tp_group, async_op=True
-                )
-            elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
-                dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
+            if not ctx.explicit_expert_comm:
+                if ctx.parallel_mode == "column" and ctx.sequence_parallel:
+                    if handle is not None:
+                        handle.wait()
+                    dgrad, handle = reduce_scatter_along_first_dim(
+                        dgrad, ctx.tp_group, async_op=True
+                    )
+                elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
+                    dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
 
         if weight.requires_grad:
             if ctx.fp8:
@@ -588,6 +586,7 @@ class Linear(TransformerEngineBaseModule):
              forward pass to supply the tensor parallel group needed for tensor and sequence
              parallel collectives.
     ep_size : int, default = 1
+             used as EP (expert parallel) world size.
     parallel_mode : {None, 'Column', 'Row'}, default = `None`
                    used to decide whether this Linear layer is Column Parallel Linear or Row
                    Parallel Linear as described `here `_.
@@ -775,7 +774,8 @@ def __init__(
                                     init_fn=init_method,
                                     get_rng_state_tracker=get_rng_state_tracker,
                                     fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT)
-            setattr(getattr(self, self.weight_names[i]), 'allreduce', not (self.is_expert and self.expert_parallel))
+            setattr(getattr(self, self.weight_names[i]), 'allreduce',
+                    not (self.is_expert and self.expert_parallel))
 
             # Construct bias parameter if needed
             if self.use_bias:
@@ -785,7 +785,8 @@ def __init__(
                 bias = torch.nn.Parameter(bias)
                 self.register_parameter(self.bias_names[i], bias,
                                         init_fn=init_method_constant(0.0))
-                setattr(getattr(self, self.bias_names[i]), 'allreduce', not (self.is_expert and self.expert_parallel))
+                setattr(getattr(self, self.bias_names[i]), 'allreduce',
+                        not (self.is_expert and self.expert_parallel))
             else:
                 bias = torch.Tensor().to(dtype=params_dtype, device=device)
                 setattr(self, self.bias_names[i], bias)
@@ -908,6 +909,17 @@ def forward(
             weight_tensor = self.weight_tensor
             bias_tensor = self.bias_tensor
 
+        # When tokens are not passed to the expert
+        if inp.nelement() == 0:
+            # Manually call prepare_backward for amax global buffer key deletion
+            def no_tokens_backward(grad):
+                with _prepare_backward(self.fp8, self.fp8_meta,
+                                       self.tp_group, self.tp_size, name="_Linear"):
+                    pass
+                return grad
+            inp.register_hook(no_tokens_backward)
+            return torch.matmul(inp, weight_tensor.t())
+
         # Fetch the fp8 weights placeholders (for linear/gemm)
         weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(
             is_first_microbatch
@@ -949,16 +961,7 @@ def forward(
                 self.ub_atomic_gemm_ag,
                 self.ub_name,
             )
-            if inp.nelement() == 0:
-                # Enable global buffer key deletion when no tokens are passed
-                def no_tokens_backward(grad):
-                    with _prepare_backward(self.fp8, self.fp8_meta, self.tp_group, self.tp_size, name="_Linear"):
-                        pass
-                    return grad
-                inp.register_hook(no_tokens_backward)
-                return torch.matmul(inp, weight_tensor.t())
-            else:
-                out = linear_fn(*args)
+            out = linear_fn(*args)
 
         if self.gemm_bias_unfused_add:
             out = out + cast_if_needed(bias_tensor, self.activation_dtype)
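
Note on the last two hunks: the zero-token early return is moved ahead of the FP8 weight-placeholder fetch, and a backward hook keeps the per-iteration amax buffer cleanup alive even when an expert receives no tokens. Below is a minimal, standalone sketch of that hook pattern, not Transformer Engine's actual module: TinyExpert, fake_prepare_backward, and cleanup_calls are hypothetical stand-ins for the real Linear forward and the _prepare_backward context manager. It only demonstrates that a tensor hook registered on an empty input still fires during backward, which is what lets the cleanup run on the no-token path.

import torch

cleanup_calls = []  # records that the backward-side cleanup ran

def fake_prepare_backward(name):
    # Hypothetical stand-in for _prepare_backward(fp8, fp8_meta, tp_group, tp_size, name=name)
    cleanup_calls.append(name)

class TinyExpert(torch.nn.Module):
    def __init__(self, in_features=4, out_features=8):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, inp):
        if inp.nelement() == 0:
            # No tokens routed here: skip the fused GEMM path, but register a
            # hook so cleanup still happens when gradients flow back.
            def no_tokens_backward(grad):
                fake_prepare_backward("_Linear")
                return grad
            inp.register_hook(no_tokens_backward)
            return torch.matmul(inp, self.weight.t())
        return torch.matmul(inp, self.weight.t())

expert = TinyExpert()
empty_inp = torch.empty(0, 4, requires_grad=True)
out = expert(empty_inp)          # shape [0, 8]
out.sum().backward()             # the hook fires even though out is empty
print(out.shape, cleanup_calls)  # torch.Size([0, 8]) ['_Linear']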