[PyTorch] Bugfix for wgrad bulk overlap conflict when dgrad overlap is reduce-scatter #1341

Open
wants to merge 9 commits into
base: main
from

Conversation

denera
Collaborator

@denera denera commented Nov 18, 2024

Description

When the Userbuffers config dictionary sets the overlap method to ring-exchange or pipeline for any *_dgrad layer, that layer's *_wgrad overlap needs to be disabled in order for the ub_overlap_rs_dgrad=True option of the related TE modules to function correctly.

This PR fixes a bug where the "*_wgrad" overlap persisted in the Userbuffers configuration and the corresponding UB object was initialized even when it was not needed.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Changes

Please list the changes introduced in this PR:

  • *_wgrad overlap is now removed from methods["bulk"] list when the same layer's *_dgrad overlap has its method set to either ring-exchange or pipeline.
  • add_ub(name, **ub_cfg) is now only called if name is in the original user-provided ub_cfg. This avoids creating UB objects with default configs that may conflict with the user's intended TP overlap use.
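
The two changes above can be sketched roughly as follows. This is a minimal illustration, not the code in transformer_engine/pytorch/module/base.py: the helper names `resolve_methods` and `layers_to_initialize`, the default layer lists, and the example user config are all assumptions made for the sketch.

```python
# Hypothetical sketch of the method resolution described above.
RS_METHODS = ("ring_exchange", "pipeline")


def resolve_methods(default_methods, ub_cfgs):
    """Return a copy of the method lists with conflicting bulk overlaps removed."""
    methods = {key: list(names) for key, names in default_methods.items()}
    for name, cfg in ub_cfgs.items():
        new_method = cfg.get("method")
        if not (name.endswith("_dgrad") and new_method in RS_METHODS):
            continue
        # The dgrad layer is now a reduce-scatter overlap, so neither it nor
        # the matching wgrad layer may stay in the bulk list.
        wgrad_name = name[: -len("_dgrad")] + "_wgrad"
        for bulk_name in (name, wgrad_name):
            if bulk_name in methods["bulk"]:
                methods["bulk"].remove(bulk_name)
        if name not in methods[new_method]:
            methods[new_method].append(name)
    return methods


def layers_to_initialize(methods, ub_cfgs):
    """Only layers named in the user-provided config get a UB object."""
    ordered = methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]
    return [name for name in ordered if name in ub_cfgs]


# Assumed defaults and user config, for illustration only.
defaults = {
    "ring_exchange": ["fc2_dgrad"],
    "pipeline": [],
    "bulk": ["fc1_dgrad", "fc1_wgrad", "qkv_dgrad", "qkv_wgrad"],
}
user_cfgs = {"fc1_dgrad": {"method": "ring_exchange"}}

methods = resolve_methods(defaults, user_cfgs)
to_init = layers_to_initialize(methods, user_cfgs)
```

In this sketch, `fc1_wgrad` leaves the bulk list because `fc1_dgrad` was switched to a reduce-scatter method, and only `fc1_dgrad` would be passed on to something like `add_ub`.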

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@denera denera added the bug Something isn't working label Nov 18, 2024
@denera denera requested review from timmoon10 and ksivaman November 18, 2024 16:45
@denera denera self-assigned this Nov 18, 2024
Collaborator

@timmoon10 timmoon10 left a comment

LGTM, pending CI

transformer_engine/pytorch/module/base.py (outdated, resolved)
@timmoon10
Collaborator

/te-ci pytorch L0 L1

@denera denera force-pushed the rs-dgrad-overlap-bugfix branch from ab9e05f to 2ca29de Compare November 22, 2024 16:09
@denera
Collaborator Author

denera commented Nov 22, 2024

/te-ci pytorch L0 L1

if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk":
# Loop over user configs and disable dgrad and wgrad bulk overlaps for every layer that has a
# reduce-scatter dgrad overlap.
ub_cfg = {} if ub_cfg is None else ub_cfg

local variable ub_cfg referenced before assignment.

ub_cfg --> ub_cfgs?
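
A minimal reproduction of this failure mode, assuming the argument is named `ub_cfgs` and the line mistakenly reads and writes `ub_cfg`. The function names here are hypothetical and exist only for the demonstration.

```python
def init_with_typo(ub_cfgs=None):
    # Bug: `ub_cfg` is assigned inside this function, so Python treats it as a
    # local name, and the right-hand side reads it before it is bound.
    ub_cfg = {} if ub_cfg is None else ub_cfg
    return ub_cfg


def init_fixed(ub_cfgs=None):
    # Fix: normalize the actual argument, so later `name in ub_cfgs` checks
    # work without a separate `is not None` guard.
    ub_cfgs = {} if ub_cfgs is None else ub_cfgs
    return ub_cfgs


try:
    init_with_typo()
    error_name = None
except UnboundLocalError as exc:
    error_name = type(exc).__name__
```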

methods[new_method].append(name)

ub_cfg[name] = final_cfg

ub_cfg --> ub_cfgs?

for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]:
ub_cfg = get_default_config(name)
if ub_cfgs is not None and name in ub_cfgs:

if ub_cfgs is not None and name in ub_cfgs ==> if name in ub_cfgs

fp8_buf = (name in layers_all_gather_overlap) or (
final_cfg = get_default_config(name)
final_cfg.update(ub_cfgs[name])
final_cfg["fp8_buf"] = (name in layers_all_gather_overlap) or (

If I am using the 'bf16' dtype for training, and name in layers_all_gather_overlap is true for some pattern (e.g., fc2_dgrad), then final_cfg["fp8_buf"] will be set to True, which is unexpected.

Okay, line 361 already uses use_fp8 to guard the effect of fp8_buf, but it would be better to advance this guard to where we set the default value of cfg['fp8_buf'], to avoid confusion.

Collaborator Author

@denera denera Dec 5, 2024

We can't advance this guard because it's not only for all-gather overlaps. I'll run through the logic to explain.

use_fp8 tells the initialization whether the user intends to invoke comm+GEMM overlap under the te.fp8_autocast() context (i.e. with FP8 GEMM inputs).

fp8_buf tells the initialization whether the communication buffer should be allocated as FP8.

For all-gather overlaps, the buffer has to match the GEMM input type, so it will always be allocated in FP8 when use_fp8 == True (i.e. GEMM inputs are FP8) regardless of what fp8_buf is set to in the user's layer configuration. In other words, the user does not get a choice here.

For reduce-scatter overlaps, the GEMM output has to match the buffer type, which can be either FP8 or BF16 when GEMM inputs are FP8. In this scenario, setting fp8_buf = True means that we communicate FP8 data between devices, and then fuse the BF16 upcast into the sum-reduce.

Advancing this guard to the default config options means that the user is denied the option to set fp8_buf for reduce-scatter overlaps, and that RS overlaps always communicate BF16 data, which is not always the optimal choice.
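
The logic above can be summarized in a small sketch. The function `buffer_is_fp8` and the `overlap_type` strings are hypothetical and do not appear in the TE code; the sketch also assumes the buffer is never FP8 when use_fp8 is False.

```python
def buffer_is_fp8(overlap_type, use_fp8, fp8_buf):
    """Decide whether the communication buffer would be allocated as FP8."""
    if not use_fp8:
        # Without FP8 GEMM inputs, fp8_buf has no effect.
        return False
    if overlap_type == "all_gather":
        # The buffer must match the FP8 GEMM input; the user has no choice.
        return True
    if overlap_type == "reduce_scatter":
        # The GEMM output may be FP8 or BF16, so the user's fp8_buf decides.
        return bool(fp8_buf)
    raise ValueError(f"unknown overlap type: {overlap_type}")


# Enumerate the cases discussed above.
decisions = {
    (kind, use_fp8, fp8_buf): buffer_is_fp8(kind, use_fp8, fp8_buf)
    for kind in ("all_gather", "reduce_scatter")
    for use_fp8 in (False, True)
    for fp8_buf in (False, True)
}
```

Under these assumptions, folding use_fp8 into the default value of fp8_buf would remove the one case where the user's setting matters: a reduce-scatter overlap with FP8 GEMM inputs.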

On a side note, the name in method["pipeline"] part of this is outdated and needs to be removed because we now support optional FP8 GEMM outputs/buffers in all reduce-scatter overlaps, not just collective/pipeline methods.

@denera denera force-pushed the rs-dgrad-overlap-bugfix branch from aad8294 to d2d9938 Compare December 5, 2024 14:27
@denera denera force-pushed the rs-dgrad-overlap-bugfix branch from 1036168 to 1797f14 Compare December 17, 2024 20:00
Labels
bug Something isn't working

3 participants