Skip to content

Commit

Permalink
fixed incorrect TP overlap option asserts
Browse files Browse the repository at this point in the history
Signed-off-by: Alp Dener <[email protected]>
  • Loading branch information
denera committed Dec 18, 2024
1 parent 3afa7c1 commit bb0a330
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions transformer_engine/pytorch/module/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,17 +928,15 @@ def __init__(
self.ub_name = ub_name

assert not (self.ub_overlap_rs_fprop and self.ub_overlap_ag_fprop), (
"Failed to infer TP overlap options! "
+ "Forward pass cannot do AG+GEMM and GEMM+RS at the same time."
"Cannot enable AG+GEMM and GEMM+RS overlaps at the same time."
)
assert not (self.ub_overlap_ag_dgrad and self.ub_overlap_rs_dgrad), (
"Failed to infer TP overlap options! "
+ "Backward pass cannot do AG+DGRAD and DGRAD+RS at the same time."
)
assert not self.ub_overlap_rs_dgrad and (self.ub_bulk_dgrad or self.ub_bulk_wgrad), (
"Failed to infer TP overlap options! "
+ "Backward pass cannot do DGRAD+RS and bulk overlaps at the same time."
assert not (self.ub_overlap_rs_dgrad and self.ub_bulk_dgrad), (
"Cannot enable DGRAD+RS and bulk DGRAD overlaps at the same time."
)
assert not (
self.ub_overlap_ag_dgrad
and (self.ub_overlap_rs_dgrad or self.ub_bulk_dgrad)
), "Cannot enable AG+DGRAD and DGRAD+RS or bulk DGRAD overlaps at the same time."

self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name
Expand Down

0 comments on commit bb0a330

Please sign in to comment.