FP8 Support for MCore MoE #648

Merged: 4 commits merged into NVIDIA:main from denliu/moe_fp8 on Apr 29, 2024

Conversation

@Victarry (Contributor) commented Jan 31, 2024

Add FP8 support for MoE in MCore.

Related MR in MCore:
https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/1089

Implementation details (a usage sketch follows below):

  • Add rng_tracker_name so that expert weights are initialized correctly under expert parallelism (EP).
  • Add an if statement in the kernel to handle zero tokens passed to an expert.
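A minimal usage sketch (not taken from this PR) of how the two changes fit together, assuming the public te.pytorch.Linear API; the layer sizes and the tracker name "expert-parallel-rng" are illustrative, and in MCore rng_tracker_name is paired with the model's CUDA RNG state tracker passed via get_rng_state_tracker:

```python
import torch
import transformer_engine.pytorch as te

# Expert weights are initialized under a dedicated RNG tracker name so that
# expert-parallel ranks draw independent initializations (in MCore this is
# combined with get_rng_state_tracker; the name below is illustrative).
expert_fc1 = te.Linear(
    1024,
    4096,
    params_dtype=torch.bfloat16,
    rng_tracker_name="expert-parallel-rng",
)

# After routing, an expert may receive zero tokens; the kernels now guard this
# case, so the forward pass returns an equally empty output instead of failing.
no_tokens = torch.empty(0, 1024, dtype=torch.bfloat16, device="cuda")
out = expert_fc1(no_tokens)
assert out.shape == (0, 4096)
```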

@Victarry force-pushed the denliu/moe_fp8 branch 2 times, most recently from 9b520c6 to 8df68fc on March 6, 2024 05:36
@Victarry marked this pull request as ready for review on March 6, 2024 08:10
@Victarry force-pushed the denliu/moe_fp8 branch 2 times, most recently from 42f28d3 to 8e03976 on March 6, 2024 16:10
@ptrendx (Member) commented Mar 7, 2024

I don't like the fact that the layers need to know that they are experts. Can't it be abstracted in some way using the options that we already have, or by adding options that are more generic?

@Victarry (Contributor, Author) commented Mar 8, 2024

I see, good advice.

I removed the is_expert flag and now use explicit_parallel_comm to indicate that the communications are handled outside te.Linear.

@Victarry (Contributor, Author) commented:

Hi @ptrendx, could you please continue the review and share your comments?

To make sure this feature can be included in MCore v0.6, I think it would be better to merge this PR this week.

@ptrendx (Member) commented Mar 13, 2024

So, to be honest, I don't quite understand why we need that communication flag at all. MCore should be able to just call te.Linear without setting row or column parallelism to the same effect, no? Then we would not need any special flag on the TE side.

Also, you added the rng tracker name option, but did not document it.

I think the handling of the zero-token case is fine.
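A minimal sketch of the alternative described in the first paragraph above, assuming only the public te.pytorch.Linear API (sizes and shapes are illustrative): with parallel_mode left as None the layer performs no tensor-parallel communication, so the caller, e.g. MCore's MoE dispatcher, can own all communication without a TE-side flag.

```python
import torch
import transformer_engine.pytorch as te

# No row/column parallelism inside the layer: parallel_mode=None means
# te.Linear does a plain local GEMM and performs no TP communication itself.
local_expert_fc = te.Linear(
    1024,
    4096,
    parallel_mode=None,
    params_dtype=torch.bfloat16,
)

# Tokens already dispatched to this expert by the caller; any all-to-all or
# gather/scatter happens outside the layer.
tokens = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")
out = local_expert_fc(tokens)
assert out.shape == (16, 4096)
```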

@Victarry (Contributor, Author) commented:

Added documentation. Thanks.

@ptrendx (Member) commented Mar 18, 2024

/te-ci pytorch

@ptrendx (Member) left a comment:

Changes LGTM. @Victarry, please add a test that runs the Linear layer with empty input to show that it works, and then we will be able to merge.

@Victarry (Contributor, Author) commented:

Hi @ptrendx, I have added a new unit test for the Linear layer with empty input.

Since I didn't find an appropriate file for the new test case, I created a new test file named test_linear_layer.py.
Please tell me if you have any other suggestions.

@ptrendx (Member) commented Mar 25, 2024

I would put it in test_sanity.py. It looks good, but please also add a check to it, like verifying that the batch size of the output of the linear layer is the same as that of the input (so if 0 rows are passed as input, 0 rows are also produced as output).
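A minimal sketch of the kind of check requested above, assuming the public te.pytorch.Linear and te.fp8_autocast APIs; it is not the exact test that was added to test_sanity.py:

```python
import torch
import transformer_engine.pytorch as te


def test_linear_empty_input():
    # Zero rows model an expert that received no tokens after routing.
    layer = te.Linear(128, 256, params_dtype=torch.float32)
    x = torch.empty(0, 128, device="cuda")

    # fp8_autocast can be enabled on FP8-capable GPUs; disabled here so the
    # sketch also runs on older hardware.
    with te.fp8_autocast(enabled=False):
        y = layer(x)

    # The batch dimension of the output must match the (empty) input.
    assert y.shape == (0, 256)
```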

@Victarry (Contributor, Author) commented:

Done! Thanks for your advice. 👍🏻

@Victarry (Contributor, Author) commented Apr 7, 2024

Hi @ptrendx, can this PR be merged now?

@ptrendx (Member) commented Apr 7, 2024

/te-ci pytorch

@Victarry (Contributor, Author) commented Apr 7, 2024

Hi @ptrendx, I just fixed the unit-test error in CI; could you please trigger the CI again?

@ptrendx (Member) commented Apr 7, 2024

/te-ci pytorch

@Victarry (Contributor, Author) commented Apr 8, 2024

Hi @ptrendx, the CI pipeline has passed; could you please merge this PR?
Thanks a lot!

@ptrendx (Member) commented Apr 8, 2024

Hi @Victarry, we are trying to minimize the changes going into the 1.6 release, so we will merge this PR after the 1.6 branch is created.

@ptrendx (Member) commented Apr 16, 2024

Hi @Victarry, now that the 1.6 branch is created, could you resolve conflicts in your PR? Then we will be able to merge it.

@Victarry (Contributor, Author) commented:

Hi @ptrendx, I just resolved the conflicts; please merge this PR. Thanks!

@ptrendx (Member) commented Apr 17, 2024

/te-ci pytorch

@Victarry (Contributor, Author) commented:

Hi @ptrendx, I guess the CI failure is due to other code changes in the main branch. Could you please trigger the PyTorch CI again?

@ptrendx (Member) commented Apr 25, 2024

/te-ci pytorch

@viclzhu commented Apr 25, 2024

Hi @Victarry, I'm wondering when the MCore-related changes will be available in the public MCore repository. Or, if they are already available, could you point me to the relevant changes or PR?
It sounds like they will likely land in MCore 0.7, but I couldn't find the changes on the beta-0.7 branch yet.

Thanks!

@Victarry (Contributor, Author) commented:

Hi @viclzhu, the MCore-related changes are planned to be published before the end of May.
Thanks.

@Victarry (Contributor, Author) commented:

Hi @ptrendx, I found that the unit test only fails on L40 GPUs, but I'm not sure why this happens.

Do you have any insights?

@ptrendx (Member) commented Apr 29, 2024

Hi Victarry - I checked and this failure is unrelated to this PR, so I believe it is safe to merge.

@ptrendx merged commit 32d1eb1 into NVIDIA:main on Apr 29, 2024 (19 of 20 checks passed)
phu0ngng pushed a commit to phu0ngng/TransformerEngine that referenced this pull request May 3, 2024
* Add support for MoE with FP8.

Signed-off-by: Dennis Liu <[email protected]>

* Fix unittest.

Signed-off-by: Dennis Liu <[email protected]>

* Fix error in linear backward.

Signed-off-by: Dennis Liu <[email protected]>

---------

Signed-off-by: Dennis Liu <[email protected]>
Co-authored-by: Przemyslaw Tredak <[email protected]>
pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request May 15, 2024
pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request May 16, 2024
pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request May 23, 2024
```diff
@@ -335,6 +338,8 @@ at::Tensor fp8_transpose(at::Tensor input,

   size_t M = static_cast<size_t>(input.size(0));
   size_t N = static_cast<size_t>(input.size(1));
+  if (M == 0 || N == 0)
+    return input;
```
@Victarry (Contributor, Author) commented on this diff:
This will cause a shape mismatch error between wgrad and the weight when gradient accumulation fusion is disabled.
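A rough PyTorch-level illustration of the concern (the actual code path is the C++ fp8_transpose above; the shapes and names here are made up for the example): the wgrad GEMM expects the weight's shape even when the token dimension is zero, so returning the untransposed input hands downstream code a tensor with a different layout.

```python
import torch

# Linear-layer shapes: input (M, K), grad_output (M, N), weight/wgrad (N, K).
M, K, N = 0, 128, 256
inp = torch.empty(M, K)
grad_out = torch.empty(M, N)

# wgrad = grad_output^T @ input must still have the weight's shape (N, K)
# when M == 0 (an expert that received no tokens).
wgrad = grad_out.t() @ inp
assert wgrad.shape == (N, K)

# An early return of the untransposed (0, K) input, instead of a (K, 0)
# transpose, gives the downstream GEMM a tensor with a different layout,
# which is the kind of wgrad/weight shape mismatch described above.
early_return = inp
assert early_return.shape != inp.t().shape
```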
