diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py
index 9abbb93d29..82754c33a5 100644
--- a/tests/pytorch/distributed/run_layer_with_overlap.py
+++ b/tests/pytorch/distributed/run_layer_with_overlap.py
@@ -138,13 +138,13 @@ def _parse_args(argv=None, namespace=None):
         type=str.lower,
         default="row",
         choices=["row", "column"],
-        help="Parallel mode for te.Linear."
+        help="Parallel mode for te.Linear.",
     )
     parser.add_argument(
         "--overlap-rs-dgrad",
         action="store_true",
         default=False,
-        help="Overlap reduce-scatter with DGRAD in the backward pass instead of bulk overlaps."
+        help="Overlap reduce-scatter with DGRAD in the backward pass instead of bulk overlaps.",
     )
     parser.add_argument(
         "--debug",
@@ -254,8 +254,8 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False):
     ub_cfgs = None
     if opts.overlap_rs_dgrad:
         ub_cfgs = {
-            "proj_dgrad" : {"method" : "ring_exchange"},
-            "qkv_dgrad" : {"method" : "ring_exchange"},
+            "proj_dgrad": {"method": "ring_exchange"},
+            "qkv_dgrad": {"method": "ring_exchange"},
         }
     te.module.base.initialize_ub(
         [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 92ceb3d82e..f10ece53e4 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -192,9 +192,7 @@ def forward(
         ub_algo = None
         rs_out = None
         inputmat_data = (
-            inputmat_total._data
-            if isinstance(inputmat_total, Float8Tensor)
-            else inputmat_total
+            inputmat_total._data if isinstance(inputmat_total, Float8Tensor) else inputmat_total
         )
         if ub_overlap_rs_fprop:
             ub_obj = get_ub(ub_name + "_fprop")
@@ -202,8 +200,7 @@ def forward(
             dim_size = list(inputmat_total.size())
             dim_size[0] = dim_size[0] // tp_world_size
             dim_size[1] = out_features
-            rs_out = torch.empty(dim_size, dtype=activation_dtype,
-                                 device=inputmat_total.device)
+            rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
             if ub_obj.is_p2p_overlap():
                 if ub_obj.is_atomic_gemm():
                     ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P
@@ -299,8 +296,7 @@ def forward(
             dim_size = list(inputmat_total.size())
             dim_size[0] = dim_size[0] // tp_world_size
             dim_size[1] = out_features
-            rs_out = torch.empty(dim_size, dtype=activation_dtype,
-                                 device=inputmat_total.device)
+            rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
             if ub_obj.is_p2p_overlap():
                 ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P
             else:
@@ -460,14 +456,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
         dgrad = None
         dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
         if ctx.ub_overlap_ag:
-            # Overlap grad_output all-gather with dgrad compute
-            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
-            if ctx.ub_obj_gradout.is_atomic_gemm():
-                ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
-            else:
-                ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
-            dgrad = torch.empty(dgrad_shape, dtype=ctx.activation_dtype,
-                                device=grad_output.device)
+            # Overlap grad_output all-gather with dgrad compute
+            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
+            if ctx.ub_obj_gradout.is_atomic_gemm():
+                ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
+            else:
+                ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
+            dgrad = torch.empty(
+                dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
+            )

         elif ctx.ub_overlap_rs_dgrad:
             # Overlap dgrad reduce-scatter with dgrad compute
@@ -483,8 +480,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                 ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_RS
             else:
                 ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS
-            rs_out = torch.empty(dgrad_shape, dtype=ctx.activation_dtype,
-                                 device=grad_output.device)
+            rs_out = torch.empty(
+                dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
+            )

             ctx.ub_bulk_dgrad = False
             ctx.ub_bulk_wgrad = False
@@ -494,9 +492,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
             ub_algo_dgrad = tex.CommOverlapAlgo.BULK_OVERLAP_AG
             ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
             inputmat_data = (
-                inputmat._data
-                if isinstance(inputmat, Float8Tensor)
-                else inputmat
+                inputmat._data if isinstance(inputmat, Float8Tensor) else inputmat
             )
             ctx.ub_obj_gradout.copy_input_to_ubuf(inputmat_data, True)
             inputmat_ubuf = ctx.ub_obj_gradout.get_ubuf_output(1)
@@ -514,8 +510,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
         if dgrad is None:
             if ctx.parallel_mode == "column" and ctx.sequence_parallel:
                 dgrad_shape[0] = dgrad_shape[0] * tp_world_size
-            dgrad = torch.empty(dgrad_shape, dtype=ctx.activation_dtype,
-                                device=grad_output.device)
+            dgrad = torch.empty(
+                dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
+            )

         (
             grad_output,
@@ -530,10 +527,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
         inputmat_total = None
         inputmat_t_total = None
         inputmat_gather_handle = None
-        if (weight.requires_grad
+        if (
+            weight.requires_grad
             and ctx.parallel_mode == "column"
             and ctx.sequence_parallel
-            and not ctx.ub_bulk_dgrad):
+            and not ctx.ub_bulk_dgrad
+        ):
             inputmat_total, inputmat_gather_handle = gather_along_first_dim(
                 inputmat, ctx.tp_group, async_op=ctx.requires_dgrad
             )
@@ -554,8 +553,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],

         if ctx.requires_dgrad:
             if ctx.fp8:
-                if (ctx.is_input_fp8
-                    or (ctx.ub_overlap_rs_dgrad and ctx.ub_obj_gradout.is_fp8_ubuf())):
+                if ctx.is_input_fp8 or (
+                    ctx.ub_overlap_rs_dgrad and ctx.ub_obj_gradout.is_fp8_ubuf()
+                ):
                     out_index, meta_tensor, output_te_dtype, output_dtype = (
                         tex.FP8BwdTensors.GRAD_INPUT1,
                         ctx.fp8_meta["scaling_bwd"],
@@ -634,7 +634,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],

         elif ctx.tensor_parallel and not ctx.sequence_parallel:
             dgrad, dgrad_reduce_handle = allreduce(dgrad, ctx.tp_group, async_op=True)
-
         wgrad = None
         if weight.requires_grad:
             if ctx.fp8:
@@ -931,9 +930,9 @@ def __init__(

         assert not (self.ub_overlap_rs_fprop and self.ub_overlap_ag_fprop), "Internal TE error!"
         assert not (self.ub_overlap_ag_dgrad and self.ub_overlap_rs_dgrad), "Internal TE error!"
-        assert not (self.ub_overlap_rs_dgrad and (self.ub_bulk_dgrad or self.ub_bulk_wgrad)), (
-            "Internal TE error!"
-        )
+        assert not (
+            self.ub_overlap_rs_dgrad and (self.ub_bulk_dgrad or self.ub_bulk_wgrad)
+        ), "Internal TE error!"

         self.get_rng_state_tracker = get_rng_state_tracker
         self.rng_tracker_name = rng_tracker_name
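
For reference, the test-side portion of this patch boils down to the sketch below: the new `--overlap-rs-dgrad` flag selects ring-exchange userbuffer overlaps for the `proj_dgrad` and `qkv_dgrad` GEMMs by building the `ub_cfgs` dict that the script then hands to `te.module.base.initialize_ub(...)`. This is a minimal standalone illustration only; the helper name `build_ub_cfgs` is hypothetical, and the full `initialize_ub` call (buffer shape, TP size, and the remaining arguments) is omitted because it is not shown in full in the diff.

```python
import argparse


def build_ub_cfgs(argv=None):
    """Hypothetical helper mirroring the test-script change above.

    When --overlap-rs-dgrad is set, the dgrad reduce-scatter for the
    attention projection and QKV layers is overlapped via the
    "ring_exchange" method instead of the bulk overlaps.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--overlap-rs-dgrad",
        action="store_true",
        default=False,
        help="Overlap reduce-scatter with DGRAD in the backward pass instead of bulk overlaps.",
    )
    opts = parser.parse_args(argv)

    ub_cfgs = None
    if opts.overlap_rs_dgrad:
        ub_cfgs = {
            "proj_dgrad": {"method": "ring_exchange"},
            "qkv_dgrad": {"method": "ring_exchange"},
        }
    # The test script then forwards ub_cfgs to te.module.base.initialize_ub(...).
    return ub_cfgs


if __name__ == "__main__":
    # Prints the ring-exchange config dict when the flag is supplied.
    print(build_ub_cfgs(["--overlap-rs-dgrad"]))
```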