
Fix cuda graph capture for grouped gemm #1345

Merged

6 commits merged into NVIDIA:main on Nov 27, 2024

Conversation

@xrennvidia (Collaborator) commented on Nov 21, 2024

Description

CUDA graph capture does not work with grouped GEMM: the saved forward activations are corrupted before the backward graph (bwd_graph) is replayed. Explicitly setting retain_graph=True when capturing the backward graph keeps those activations alive and fixes the issue.
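For context, the sketch below is a simplified illustration of where the flag applies (it is not the actual Transformer Engine code, and the warmup iterations and side-stream handling that real graph capture requires are omitted): without retain_graph=True, autograd frees the saved activations once the backward capture finishes, so later replays of bwd_graph read freed memory.

```python
# Simplified illustration (assumed names, not Transformer Engine's code) of
# capturing forward and backward passes into CUDA graphs. Warmup iterations,
# which real capture requires, are omitted for brevity.
import torch

module = torch.nn.Linear(16, 16).cuda()
static_input = torch.randn(8, 16, device="cuda", requires_grad=True)

# Capture the forward pass.
fwd_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(fwd_graph):
    static_output = module(static_input)

static_grad_output = torch.ones_like(static_output)

# Capture the backward pass. Without retain_graph=True, autograd frees the
# saved forward activations after this capture, so replaying bwd_graph later
# reads corrupted memory; retain_graph=True keeps them alive across replays.
bwd_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(bwd_graph):
    static_grad_inputs = torch.autograd.grad(
        outputs=static_output,
        inputs=(static_input,) + tuple(module.parameters()),
        grad_outputs=static_grad_output,
        retain_graph=True,  # the fix described above
    )
```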

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactor

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@xrennvidia (Collaborator, Author) commented:

/te-ci pytorch

@xrennvidia requested a review from timmoon10 on Nov 22, 2024
@timmoon10 (Collaborator) left a comment:

Wouldn't we expect this to increase memory usage?

I see that torch.cuda.make_graphed_callables doesn't set retain_graph=True:
https://github.com/pytorch/pytorch/blob/c25b201583fc28243b87c460a2f18e2531a676e7/torch/cuda/graphs.py#L326-L336
We want to match plain PyTorch as much as possible unless there is a good reason to introduce divergence. If this is MoE-specific, perhaps we could add a kwarg like retain_graph_in_backward that is False by default.
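As a purely illustrative sketch of that suggestion (the wrapper and its signature here are hypothetical, not Transformer Engine's actual API), the flag could default to False so behaviour matches plain PyTorch, and only grouped-GEMM/MoE callers opt in:

```python
# Hypothetical sketch only: the kwarg name retain_graph_in_backward comes from
# the suggestion above; the wrapper itself is illustrative, not TE's real API.
from typing import Sequence

import torch


def capture_backward_graph(
    static_output: torch.Tensor,
    static_grad_output: torch.Tensor,
    static_inputs: Sequence[torch.Tensor],
    retain_graph_in_backward: bool = False,  # False keeps parity with plain PyTorch
):
    """Capture the backward pass into a CUDA graph, optionally retaining the
    autograd graph so the saved activations survive until replay."""
    bwd_graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(bwd_graph):
        static_grad_inputs = torch.autograd.grad(
            outputs=static_output,
            inputs=tuple(static_inputs),
            grad_outputs=static_grad_output,
            retain_graph=retain_graph_in_backward,
        )
    return bwd_graph, static_grad_inputs


# Grouped-GEMM/MoE callers would opt in with retain_graph_in_backward=True;
# all other callers keep the default PyTorch behaviour.
```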

@xrennvidia (Collaborator, Author) commented:

/te-ci pytorch

@timmoon10 (Collaborator) left a comment:


LGTM

Signed-off-by: Xiaowei Ren <[email protected]>
@xrennvidia (Collaborator, Author) commented:

/te-ci pytorch

@xrennvidia merged commit a132ac4 into NVIDIA:main on Nov 27, 2024
14 of 15 checks passed
@xrennvidia deleted the xren/cg_fix_grouped_gemm branch on Nov 29, 2024