Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 20, 2024
1 parent 7cb1954 commit 90458d4
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 34 deletions.
8 changes: 4 additions & 4 deletions tests/pytorch/distributed/run_layer_with_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,13 @@ def _parse_args(argv=None, namespace=None):
type=str.lower,
default="row",
choices=["row", "column"],
help="Parallel mode for te.Linear."
help="Parallel mode for te.Linear.",
)
parser.add_argument(
"--overlap-rs-dgrad",
action="store_true",
default=False,
help="Overlap reduce-scatter with DGRAD in the backward pass instead of bulk overlaps."
help="Overlap reduce-scatter with DGRAD in the backward pass instead of bulk overlaps.",
)
parser.add_argument(
"--debug",
Expand Down Expand Up @@ -254,8 +254,8 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False):
ub_cfgs = None
if opts.overlap_rs_dgrad:
ub_cfgs = {
"proj_dgrad" : {"method" : "ring_exchange"},
"qkv_dgrad" : {"method" : "ring_exchange"},
"proj_dgrad": {"method": "ring_exchange"},
"qkv_dgrad": {"method": "ring_exchange"},
}
te.module.base.initialize_ub(
[opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim],
Expand Down
59 changes: 29 additions & 30 deletions transformer_engine/pytorch/module/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,18 +192,15 @@ def forward(
ub_algo = None
rs_out = None
inputmat_data = (
inputmat_total._data
if isinstance(inputmat_total, Float8Tensor)
else inputmat_total
inputmat_total._data if isinstance(inputmat_total, Float8Tensor) else inputmat_total
)
if ub_overlap_rs_fprop:
ub_obj = get_ub(ub_name + "_fprop")
out = ub_obj.get_ubuf_output(1)
dim_size = list(inputmat_total.size())
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = out_features
rs_out = torch.empty(dim_size, dtype=activation_dtype,
device=inputmat_total.device)
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
if ub_obj.is_p2p_overlap():
if ub_obj.is_atomic_gemm():
ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P
Expand Down Expand Up @@ -299,8 +296,7 @@ def forward(
dim_size = list(inputmat_total.size())
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = out_features
rs_out = torch.empty(dim_size, dtype=activation_dtype,
device=inputmat_total.device)
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device)
if ub_obj.is_p2p_overlap():
ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P
else:
Expand Down Expand Up @@ -460,14 +456,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
dgrad = None
dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
if ctx.ub_overlap_ag:
# Overlap grad_output all-gather with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
if ctx.ub_obj_gradout.is_atomic_gemm():
ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
else:
ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
dgrad = torch.empty(dgrad_shape, dtype=ctx.activation_dtype,
device=grad_output.device)
# Overlap grad_output all-gather with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
if ctx.ub_obj_gradout.is_atomic_gemm():
ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
else:
ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
dgrad = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
)

elif ctx.ub_overlap_rs_dgrad:
# Overlap dgrad reduce-scatter with dgrad compute
Expand All @@ -483,8 +480,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
ub_algo_dgrad = tex.CommOverlapAlgo.ATOMIC_GEMM_RS
else:
ub_algo_dgrad = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS
rs_out = torch.empty(dgrad_shape, dtype=ctx.activation_dtype,
device=grad_output.device)
rs_out = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
)
ctx.ub_bulk_dgrad = False
ctx.ub_bulk_wgrad = False

Expand All @@ -494,9 +492,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
ub_algo_dgrad = tex.CommOverlapAlgo.BULK_OVERLAP_AG
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
inputmat_data = (
inputmat._data
if isinstance(inputmat, Float8Tensor)
else inputmat
inputmat._data if isinstance(inputmat, Float8Tensor) else inputmat
)
ctx.ub_obj_gradout.copy_input_to_ubuf(inputmat_data, True)
inputmat_ubuf = ctx.ub_obj_gradout.get_ubuf_output(1)
Expand All @@ -514,8 +510,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
if dgrad is None:
if ctx.parallel_mode == "column" and ctx.sequence_parallel:
dgrad_shape[0] = dgrad_shape[0] * tp_world_size
dgrad = torch.empty(dgrad_shape, dtype=ctx.activation_dtype,
device=grad_output.device)
dgrad = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
)

(
grad_output,
Expand All @@ -530,10 +527,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
inputmat_total = None
inputmat_t_total = None
inputmat_gather_handle = None
if (weight.requires_grad
if (
weight.requires_grad
and ctx.parallel_mode == "column"
and ctx.sequence_parallel
and not ctx.ub_bulk_dgrad):
and not ctx.ub_bulk_dgrad
):
inputmat_total, inputmat_gather_handle = gather_along_first_dim(
inputmat, ctx.tp_group, async_op=ctx.requires_dgrad
)
Expand All @@ -554,8 +553,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],

if ctx.requires_dgrad:
if ctx.fp8:
if (ctx.is_input_fp8
or (ctx.ub_overlap_rs_dgrad and ctx.ub_obj_gradout.is_fp8_ubuf())):
if ctx.is_input_fp8 or (
ctx.ub_overlap_rs_dgrad and ctx.ub_obj_gradout.is_fp8_ubuf()
):
out_index, meta_tensor, output_te_dtype, output_dtype = (
tex.FP8BwdTensors.GRAD_INPUT1,
ctx.fp8_meta["scaling_bwd"],
Expand Down Expand Up @@ -634,7 +634,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
elif ctx.tensor_parallel and not ctx.sequence_parallel:
dgrad, dgrad_reduce_handle = allreduce(dgrad, ctx.tp_group, async_op=True)


wgrad = None
if weight.requires_grad:
if ctx.fp8:
Expand Down Expand Up @@ -931,9 +930,9 @@ def __init__(

assert not (self.ub_overlap_rs_fprop and self.ub_overlap_ag_fprop), "Internal TE error!"
assert not (self.ub_overlap_ag_dgrad and self.ub_overlap_rs_dgrad), "Internal TE error!"
assert not (self.ub_overlap_rs_dgrad and (self.ub_bulk_dgrad or self.ub_bulk_wgrad)), (
"Internal TE error!"
)
assert not (
self.ub_overlap_rs_dgrad and (self.ub_bulk_dgrad or self.ub_bulk_wgrad)
), "Internal TE error!"

self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name
Expand Down

0 comments on commit 90458d4

Please sign in to comment.