Skip to content

Commit

Permalink
code clean.
Browse files: browse the repository at this point in the history
  • Loading branch information
Victarry committed Jan 31, 2024
1 parent ba5ad8b commit 93f15d2
Showing 1 changed file with 33 additions and 30 deletions.
63 changes: 33 additions & 30 deletions transformer_engine/pytorch/module/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,14 +310,13 @@ def forward(
ctx.requires_dgrad = inp.requires_grad

# Row Parallel Linear
if ub_split_rs or ub_atomic_gemm_rs:
out = rs_out
elif explicit_expert_comm:
out = out
elif parallel_mode == "row" and sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif parallel_mode == "row" and tensor_parallel:
out, _ = allreduce(out, tp_group)
if not explicit_expert_comm:
if ub_split_rs or ub_atomic_gemm_rs:
out = rs_out
elif parallel_mode == "row" and sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif parallel_mode == "row" and tensor_parallel:
out, _ = allreduce(out, tp_group)

# [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp.shape[1:-1], out.shape[-1])
Expand Down Expand Up @@ -425,16 +424,15 @@ def backward(
)

# Overlap dgrad-RS/AR with wgrad
if ctx.explicit_expert_comm:
dgrad = dgrad
elif ctx.parallel_mode == "column" and ctx.sequence_parallel:
if handle is not None:
handle.wait()
dgrad, handle = reduce_scatter_along_first_dim(
dgrad, ctx.tp_group, async_op=True
)
elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
if not ctx.explicit_expert_comm:
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
if handle is not None:
handle.wait()
dgrad, handle = reduce_scatter_along_first_dim(
dgrad, ctx.tp_group, async_op=True
)
elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)

if weight.requires_grad:
if ctx.fp8:
Expand Down Expand Up @@ -588,6 +586,7 @@ class Linear(TransformerEngineBaseModule):
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
ep_size : int, default = 1
used as EP (expert parallel) world size.
parallel_mode : {None, 'Column', 'Row'}, default = `None`
used to decide whether this Linear layer is Column Parallel Linear or Row
Parallel Linear as described `here <https://arxiv.org/pdf/1909.08053.pdf>`_.
Expand Down Expand Up @@ -775,7 +774,8 @@ def __init__(
init_fn=init_method,
get_rng_state_tracker=get_rng_state_tracker,
fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT)
setattr(getattr(self, self.weight_names[i]), 'allreduce', not (self.is_expert and self.expert_parallel))
setattr(getattr(self, self.weight_names[i]), 'allreduce',
not (self.is_expert and self.expert_parallel))

# Construct bias parameter if needed
if self.use_bias:
Expand All @@ -785,7 +785,8 @@ def __init__(
bias = torch.nn.Parameter(bias)
self.register_parameter(self.bias_names[i], bias,
init_fn=init_method_constant(0.0))
setattr(getattr(self, self.bias_names[i]), 'allreduce', not (self.is_expert and self.expert_parallel))
setattr(getattr(self, self.bias_names[i]), 'allreduce',
not (self.is_expert and self.expert_parallel))
else:
bias = torch.Tensor().to(dtype=params_dtype, device=device)
setattr(self, self.bias_names[i], bias)
Expand Down Expand Up @@ -908,6 +909,17 @@ def forward(
weight_tensor = self.weight_tensor
bias_tensor = self.bias_tensor

# When tokens are not passed to the expert
if inp.nelement() == 0:
# Manually call prepare_backward for amax global buffer key deletion
def no_tokens_backward(grad):
with _prepare_backward(self.fp8, self.fp8_meta,
self.tp_group, self.tp_size, name="_Linear"):
pass
return grad
inp.register_hook(no_tokens_backward)
return torch.matmul(inp, weight_tensor.t())

# Fetch the fp8 weights placeholders (for linear/gemm)
weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(
is_first_microbatch
Expand Down Expand Up @@ -949,16 +961,7 @@ def forward(
self.ub_atomic_gemm_ag,
self.ub_name,
)
if inp.nelement() == 0:
# Enable global buffer key deletion when no tokens are passed
def no_tokens_backward(grad):
with _prepare_backward(self.fp8, self.fp8_meta, self.tp_group, self.tp_size, name="_Linear"):
pass
return grad
inp.register_hook(no_tokens_backward)
return torch.matmul(inp, weight_tensor.t())
else:
out = linear_fn(*args)
out = linear_fn(*args)

if self.gemm_bias_unfused_add:
out = out + cast_if_needed(bias_tensor, self.activation_dtype)
Expand Down

0 comments on commit 93f15d2

Please sign in to comment.