[Pytorch] LayerNormMLP seems to be causing grad norm explosion under multi-node #709

Closed
tylaar opened this issue Mar 10, 2024 · 2 comments
Labels: bug (Something isn't working), needinfo

Comments


tylaar commented Mar 10, 2024

Hello there. I've noticed something weird during my pre-training process.

Setup:
TransformerEngine 1.3.0 + Megatron-LM
TP=1 PP=1, only using DP
Structure: LLaMA-like LayerNormMLP (SwiGLU as activation and RMSNorm); a rough sketch is below
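For reference, the block in question looks roughly like this (a minimal sketch only; the hidden sizes and the bias flag are illustrative, not my exact production config):

```python
# Minimal sketch of the TE block in question; sizes/flags are illustrative.
import transformer_engine.pytorch as te

mlp = te.LayerNormMLP(
    hidden_size=4096,
    ffn_hidden_size=11008,
    normalization="RMSNorm",  # LLaMA-style norm instead of LayerNorm
    activation="swiglu",      # gated SiLU activation, as in LLaMA
    bias=False,
)
```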

When I train on a single host with 8 H100s, the training grad norm looks fine and decreases at the expected pace.
However, when I add more hosts (each with 8 H100s), the grad norm starts to blow up: it is normal (~3.4) for the first 50 steps, grows during steps 50-100 and 100-200, and finally reaches ~1000 ...

I've tried setting fp8_wgrad=True and increasing the DelayedScaling amax_history window from 1024 to 2048 (see the sketch after the questions below), which mitigated the issue somewhat, but my questions are:

  1. Is grad norm explosion in multi-host training a known issue, one that requires expanding amax_history and setting fp8_wgrad to True?
  2. Should I expect any performance degradation from setting fp8_wgrad=True and amax_history_len to 2048?
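For context, this is roughly how I changed the recipe (a sketch only; I'm assuming Megatron's usual mapping of fp8_wgrad onto DelayedScaling's override_linear_precision, and the shapes below are illustrative):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

mlp = te.LayerNormMLP(4096, 11008, normalization="RMSNorm",
                      activation="swiglu", bias=False)  # as in the sketch above

fp8_wgrad = True
recipe = DelayedScaling(
    fp8_format=Format.HYBRID,  # E4M3 for forward tensors, E5M2 for backward
    amax_history_len=2048,     # bumped from the default 1024
    amax_compute_algo="max",
    # (fprop, dgrad, wgrad): True means fall back to higher precision,
    # so fp8_wgrad=True keeps the wgrad GEMM in FP8.
    override_linear_precision=(False, False, not fp8_wgrad),
)

inp = torch.randn(128, 4096, device="cuda")  # [tokens, hidden]
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    out = mlp(inp)
```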

Thanks a lot!

@tylaar tylaar changed the title from "LayerNormMLP seems to be causing grad norm explosion under multi-node" to "[Pytorch] LayerNormMLP seems to be causing grad norm explosion under multi-node" on Mar 10, 2024

ptrendx (Member) commented Mar 11, 2024

Hi @tylaar. No, that is definitely not expected, and you should not have to change the values of those parameters (fp8_wgrad should actually be True by default). Could you provide us with information on how to reproduce the issue you are seeing?


tylaar (Author) commented Jun 11, 2024

Hi @ptrendx, sorry for taking so long to reply to this thread. I found that using the TE implementation for all components (e.g., the SwiGLU activation and RMSNorm), together with fp8_wgrad=True and fp8_amax_history_len=1024, resolved my production issue.
One takeaway from my side: even though there are some bitwise-alignment discrepancies like those mentioned in #717, the TE implementation of SwiGLU and RMSNorm is still reliable and gave us fewer problems than our in-house implementation.
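In case it is useful to others, the combination that ended up stable for us is essentially the two sketches from my first comment, with amax_history_len back at the default 1024 (again, a rough sketch rather than our exact production code):

```python
from transformer_engine.common.recipe import DelayedScaling, Format

# TE-native RMSNorm + SwiGLU via LayerNormMLP (as in the earlier sketch),
# plus an FP8 recipe that keeps the wgrad GEMM in FP8.
recipe = DelayedScaling(
    fp8_format=Format.HYBRID,
    amax_history_len=1024,                            # default history length
    amax_compute_algo="max",
    override_linear_precision=(False, False, False),  # wgrad stays in FP8
)
```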

I will close this now; if I run into other problems I will file another issue. Thanks a lot!

@tylaar tylaar closed this as completed Jun 11, 2024