Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 5, 2024
1 parent 340e033 commit 1036168
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 11 deletions.
10 changes: 4 additions & 6 deletions tests/pytorch/distributed/run_layer_with_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _get_ub_cfg(config):
"qkv_fprop": dict(),
"qkv_dgrad": {
"method": "pipeline" if config.rs_dgrad_overlap else "bulk",
"fp8_buf": True if config.fp8_buf and config.rs_dgrad_overlap else False
"fp8_buf": True if config.fp8_buf and config.rs_dgrad_overlap else False,
},
}
)
Expand All @@ -101,9 +101,7 @@ def _get_ub_cfg(config):
if config.layer_type in [te.Linear, te.MultiheadAttention, te.TransformerLayer]:
ub_cfg.update(
{
"proj_fprop": {
"fp8_buf": True if config.fp8_buf else False
},
"proj_fprop": {"fp8_buf": True if config.fp8_buf else False},
"proj_dgrad": dict(),
}
)
Expand All @@ -113,7 +111,7 @@ def _get_ub_cfg(config):
"fc1_fprop": dict(),
"fc1_dgrad": {
"method": "pipeline" if config.rs_dgrad_overlap else "bulk",
"fp8_buf": True if config.fp8_buf and config.rs_dgrad_overlap else False
"fp8_buf": True if config.fp8_buf and config.rs_dgrad_overlap else False,
},
"fc2_fprop": {
"fp8_buf": True if config.fp8_buf else False,
Expand Down Expand Up @@ -150,7 +148,7 @@ def _parse_args(argv=None, namespace=None):
"--fp8-buf",
action="store_true",
default=False,
help="Allocate FP8 communication buffers for layers that support it."
help="Allocate FP8 communication buffers for layers that support it.",
)
parser.add_argument(
"--rs-dgrad-overlap",
Expand Down
6 changes: 1 addition & 5 deletions tests/pytorch/distributed/test_comm_gemm_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,7 @@ def test_bulk_overlaps(comm_type, fp8):
)
@pytest.mark.parametrize(
"fp8,fp8_init",
[
(False, False),
(True, False),
(True, True)
],
[(False, False), (True, False), (True, True)],
ids=[
" BF16 GEMM - BF16 PARAMS ",
" FP8 GEMM - BF16 PARAMS ",
Expand Down

0 comments on commit 1036168

Please sign in to comment.