Skip to content

Commit

Permalink
fixed UB config reference before assignment and corrected FP8 UB buffer logic
Browse files Browse the repository at this point in the history

Signed-off-by: Alp Dener <[email protected]>
  • Loading branch information
denera committed Dec 5, 2024
1 parent 2ca29de commit d2d9938
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def add_ub(

# Loop over user configs and disable dgrad and wgrad bulk overlaps for every layer that has a
# reduce-scatter dgrad overlap.
- ub_cfg = {} if ub_cfg is None else ub_cfg
+ ub_cfgs = {} if ub_cfgs is None else ub_cfgs
for name in dgrad_reduce_scatter_overlap:
if name in ub_cfgs:
final_cfg = get_default_config(name)
Expand All @@ -410,15 +410,16 @@ def add_ub(
new_method = final_cfg["method"]
methods[new_method].append(name)

- ub_cfg[name] = final_cfg
+ ub_cfgs[name] = final_cfg

# Now initialize the UB objects for each layer
for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
- if ub_cfgs is not None and name in ub_cfgs:
+ if name in ub_cfgs:
final_cfg = get_default_config(name)
final_cfg.update(ub_cfgs[name])
- final_cfg["fp8_buf"] = (name in layers_all_gather_overlap) or (
- ub_cfgs[name].get("fp8_buf", False) and name in methods["pipeline"]
+ final_cfg["fp8_buf"] = (
+ (name in layers_all_gather_overlap)
+ or ub_cfgs[name].get("fp8_buf", False)
)
add_ub(name, **final_cfg)

Expand Down

0 comments on commit d2d9938

Please sign in to comment.