Merge branch 'main' into denliu/moe_fp8
Signed-off-by: Dennis Liu <[email protected]>
Victarry committed Mar 22, 2024
2 parents d8a9d48 + b855656 commit c862ac0
Showing 11 changed files with 497 additions and 268 deletions.
18 changes: 6 additions & 12 deletions transformer_engine/pytorch/attention.py
@@ -3175,10 +3175,8 @@ def __init__(
         qkv_weight_interleaved: bool = True,
         ub_bulk_wgrad: bool = False,
         ub_bulk_dgrad: bool = False,
-        ub_split_rs: bool = False,
-        ub_split_ag: bool = False,
-        ub_atomic_gemm_rs: bool = False,
-        ub_atomic_gemm_ag: bool = False,
+        ub_overlap_rs: bool = False,
+        ub_overlap_ag: bool = False,
         bias: bool = True,
         normalization: str = "LayerNorm",
         device: Union[torch.device, str] = "cuda",
@@ -3265,9 +3263,8 @@ def __init__(
                 zero_centered_gamma=zero_centered_gamma,
                 ub_bulk_wgrad=ub_bulk_wgrad,
                 ub_bulk_dgrad=ub_bulk_dgrad,
-                ub_split_ag=ub_split_ag,
+                ub_overlap_ag=ub_overlap_ag,
                 normalization=normalization,
-                ub_atomic_gemm_ag=ub_atomic_gemm_ag,
                 ub_name="qkv",
                 **common_gemm_kwargs,
             )
@@ -3297,9 +3294,8 @@ def __init__(
                 zero_centered_gamma=zero_centered_gamma,
                 ub_bulk_wgrad=ub_bulk_wgrad,
                 ub_bulk_dgrad=ub_bulk_dgrad,
-                ub_split_ag=ub_split_ag,
+                ub_overlap_ag=ub_overlap_ag,
                 normalization=normalization,
-                ub_atomic_gemm_ag=ub_atomic_gemm_ag,
                 ub_name="qkv",
                 **common_gemm_kwargs,
             )
@@ -3347,10 +3343,8 @@ def __init__(
             bias=bias,
             return_bias=return_bias,
             parallel_mode="row" if set_parallel_mode else None,
-            ub_split_rs=ub_split_rs,
-            ub_split_ag=ub_split_ag,
-            ub_atomic_gemm_rs=ub_atomic_gemm_rs,
-            ub_atomic_gemm_ag=ub_atomic_gemm_ag,
+            ub_overlap_rs=ub_overlap_rs,
+            ub_overlap_ag=ub_overlap_ag,
             ub_name="proj",
             **common_gemm_kwargs,
         )
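To make the rename concrete, here is a minimal sketch of a call site after this change. It is an illustration, not code from the commit: te.MultiheadAttention as the affected class and the sizing values are assumptions; only the ub_* keyword names come from the signature above.

import transformer_engine.pytorch as te

# Hedged sketch: the consolidated userbuffers-overlap flags after this commit.
# hidden_size / num_attention_heads are placeholder values. The removed
# per-algorithm flags (ub_split_rs/ag, ub_atomic_gemm_rs/ag) are no longer
# part of this signature, so passing them would raise a TypeError.
attn = te.MultiheadAttention(
    hidden_size=4096,
    num_attention_heads=32,
    ub_bulk_wgrad=True,   # unchanged: bulk overlap for the wgrad GEMM
    ub_bulk_dgrad=True,   # unchanged: bulk overlap for the dgrad GEMM
    ub_overlap_rs=True,   # replaces ub_split_rs and ub_atomic_gemm_rs
    ub_overlap_ag=True,   # replaces ub_split_ag and ub_atomic_gemm_ag
)

With the per-algorithm flags gone, the choice between split-pipelined and atomic-GEMM overlap is presumably made elsewhere (e.g. in the userbuffers configuration) rather than per flag at the module boundary.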
30 changes: 24 additions & 6 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -103,14 +103,14 @@ def fp8_gemm(
                empty_tensor if extra_output_tensor is None else extra_output_tensor
            )
            args = tuple(args + (0, extra_output_tensor,))
-        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
-            fn = ub.split_overlap_ag
+        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P:
+            fn = ub.split_overlap_ag_p2p
            extra_output_tensor = (
                empty_tensor if extra_output_tensor is None else extra_output_tensor
            )
            args = tuple(args + (extra_output_tensor,))
-        elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG:
-            fn = ub.atomic_gemm_overlap_ag
+        elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P:
+            fn = ub.atomic_gemm_overlap_ag_p2p
            extra_output_tensor = (
                empty_tensor if extra_output_tensor is None else extra_output_tensor
            )
@@ -121,12 +121,24 @@ def fp8_gemm(
                extra_output_tensor is not None
            ), 'SPLIT_PIPELINED_RS requires extra output tensor'
            args = tuple(args + (True, extra_output_tensor,))
+        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P:
+            fn = ub.split_overlap_rs_p2p
+            assert (
+                extra_output_tensor is not None
+            ), 'SPLIT_PIPELINED_RS_P2P requires extra output tensor'
+            args = tuple(args + (extra_output_tensor,))
        elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS:
            fn = ub.atomic_gemm_overlap_rs
            assert (
                extra_output_tensor is not None
            ), 'ATOMIC_GEMM_RS requires extra output tensor'
            args = tuple(args + (True, extra_output_tensor,))
+        elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P:
+            fn = ub.atomic_gemm_overlap_rs_p2p
+            assert (
+                extra_output_tensor is not None
+            ), 'ATOMIC_GEMM_RS_P2P requires extra output tensor'
+            args = tuple(args + (extra_output_tensor,))
    _ = fn(*args)

    return out, gelu_input
@@ -221,8 +233,8 @@ def gemm(
        elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
            fn = ub.bulk_overlap
            args = tuple(args + (0, empty_tensor))
-        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
-            fn = ub.split_overlap_ag
+        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P:
+            fn = ub.split_overlap_ag_p2p
            extra_output_tensor = (
                empty_tensor if extra_output_tensor is None else extra_output_tensor
            )
@@ -233,6 +245,12 @@ def gemm(
                extra_output_tensor is not None
            ), 'SPLIT_PIPELINED_RS requires extra output tensor'
            args = tuple(args + (False, extra_output_tensor,))
+        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P:
+            fn = ub.split_overlap_rs_p2p
+            assert (
+                extra_output_tensor is not None
+            ), 'SPLIT_PIPELINED_RS_P2P requires extra output tensor'
+            args = tuple(args + (extra_output_tensor,))
    _ = fn(*args)

    return out, grad_bias, gelu_input
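For context, a hedged sketch of how a caller might select the new P2P reduce-scatter overlap path through gemm(). Only ub_algo, ub, and extra_output_tensor are established by this diff; the leading positional arguments (A, B, dtype, workspace), the get_ub helper, the "proj_fprop" buffer name, and all shapes are assumptions for illustration.

import torch
import transformer_engine_extensions as tex                 # assumed module behind the `tex` alias
from transformer_engine.pytorch.cpp_extensions import gemm
from transformer_engine.pytorch.module.base import get_ub   # assumed communicator lookup

tp_size = 8  # assumed tensor-parallel world size
A = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")  # weight (assumed layout)
B = torch.randn(8192, 4096, dtype=torch.bfloat16, device="cuda")  # input
workspace = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device="cuda")
# Per-rank reduce-scatter output; the *_RS_P2P paths assert this is non-None.
rs_out = torch.empty(8192 // tp_size, 4096, dtype=torch.bfloat16, device="cuda")

out, grad_bias, gelu_input = gemm(
    A, B, torch.bfloat16, workspace,
    ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P,  # new enum value from this commit
    ub=get_ub("proj_fprop"),            # assumes initialize_ub(...) has run beforehand
    extra_output_tensor=rs_out,
)

Note the argument convention the diff establishes: the non-P2P reduce-scatter variants append a boolean before extra_output_tensor (True in fp8_gemm, False in gemm), whereas the P2P variants append only the tensor itself.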