Very slow convergence with bf16 #2254
Hey @EugenHotaj, thanks for flagging this! I can take a look at it this week. To unblock you, if you have time, can you run it with bf16 and try NOT compiling the loss? You can comment it out in the recipe or set compile = false (but that also disables it for the model). I have a vague memory of someone saying that CE + compile skips upcasting to fp32.
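To make the suspicion concrete, here is a minimal sketch (not torchtune's actual loss code) of the kind of upcast in question: computing the cross-entropy in fp32 even when the model produces bf16 logits.

```python
import torch
import torch.nn.functional as F

def ce_loss_fp32(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Upcast bf16 logits to fp32 so the softmax / log-sum-exp runs in full
    # precision; this is the step a compiled loss might skip or fuse away.
    return F.cross_entropy(logits.float(), labels)
```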
@felipemello1 thanks for taking a look! I'm not using compile right now, just running the default llama3_1 config with a different dataset. I could try turning on compile (with the change you suggested) to see if it makes a difference.
Compile (skipping the loss) didn't seem to help.
I tried manually hacking the Adam optimizer to keep moments in fp32. One thing I discovered is that FSDP2 allows you to keep your weights in fp32 while running compute in bf16.
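A rough sketch of what that FSDP2 setup can look like (illustrative only, assuming a recent PyTorch where fully_shard and MixedPrecisionPolicy are available, and a model built in fp32 with a model.layers attribute): the sharded parameters, and therefore the optimizer states created from them, stay in fp32, while forward/backward compute runs in bf16 and gradients are reduced in fp32.

```python
import torch
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

# `model` is assumed to be built/loaded in fp32, so the sharded parameters
# (and the optimizer states created from them) stay in fp32.
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # all-gather + forward/backward compute in bf16
    reduce_dtype=torch.float32,   # gradient reduce-scatter in fp32
)
for block in model.layers:
    fully_shard(block, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)  # moments live in fp32
```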
@felipemello1 any luck on this?
Hey @EugenHotaj, just got time to start on it today.
If it doesn't happen in HF, we know it's a torchtune issue. If it happens there, then I wonder whether our premise that bf16 convergence == fp32 convergence is correct.
Single device Llama 3B, no difference; bf16 actually seems better.
Gonna try to reproduce it using distributed.
@felipemello1 one thing I would test is setting gradient_accumulation_steps > 1.
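A tiny standalone illustration (editorial, not code from this thread) of why accumulation precision is a reasonable suspect: adding many small values to a bf16 accumulator can round them all away, while fp32 keeps them.

```python
import torch

acc_bf16 = torch.tensor(1.0, dtype=torch.bfloat16)
acc_fp32 = torch.tensor(1.0, dtype=torch.float32)
for _ in range(1000):
    # 1e-4 is far below bf16's ~2^-7 relative resolution around 1.0,
    # so each addition rounds back to the previous value.
    acc_bf16 += 1e-4
    acc_fp32 += 1e-4
print(acc_bf16.item())  # stays at 1.0
print(acc_fp32.item())  # ends near 1.1
```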
For distributed, no large changes. Let me take a look at grad accumulation.
Used single_device, grad_acc=4, bsz=1, nothing crazy :/
Same for distributed, but I can run for longer to confirm.
Can you reproduce it with some public dataset and share the config with me? I can test it with 8xA100.
Hmm that's definitely surprising. Are you able to try with the following config (but your own dataset):
If you still don't see any issues, I can try to repro with a public dataset.
Maybe the alpaca dataset is too easy? Any suggestions for a harder public dataset?
@felipemello1 on second thought, even in your graphs we already see loss divergence between bf16 and fp32.
From @gau-nernst:
@EugenHotaj, wanna give the stochastic rounding a try? (only works with torch nightlies)
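For anyone following along, a hedged sketch of what enabling stochastic rounding typically looks like with torchao's low-bit optimizers; the exact module path has moved between torchao releases, so treat the import and argument names below as assumptions and check the torchao optimizer README for your version.

```python
import torch
# In some releases this lives under torchao.prototype.low_bit_optim instead.
from torchao.optim import AdamW

# `model` is assumed to be a bf16 model. The optimizer stochastically rounds
# the updated parameters when writing them back in bf16, which recovers much
# of the fp32 update behavior without fp32 master weights.
optimizer = AdamW(
    model.parameters(),
    lr=2e-5,                     # illustrative value, not from the thread
    bf16_stochastic_round=True,
)
```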
@felipemello1 thanks for the pointer! Just had a chance to try it out today and it seems to make a significant difference. These were just some quick 1-node runs, but it looks pretty promising; I'll kick off some real runs and report back. When would this make it to stable if we want to use it? The implementation doesn't look too difficult, but there seem to be some ops missing. Also, I wasn't able to run the original code and kept getting the following failure here (cc @gau-nernst):
Not sure if it's a bug or just something off with my setup, but the following fixed it:
@EugenHotaj Stochastic-rounding Adam/AdamW + FSDP2 requires torch nightly (pytorch/ao#1505). I'm curious that the curves look quite different. I haven't tried SR AdamW much in a fine-tuning setting, so I'm not sure if that's to be expected. Do you use the same hparams for all settings? My uneducated guess is that you're using a higher LR for BF16-SR.
@gau-nernst the issue I posted above seems slightly different from pytorch/ao#1505. I actually hit pytorch/ao#1505 as well, but it was fixed by using nightlies as you mentioned. Not sure if I'm on the wrong version though; I basically just did:
The hparams are identical; the only setting I'm changing is
We've noticed very bad convergence when training in bf16 vs fp32. As a comparison, here are the loss curves for bf16 and fp32:

This is a full finetune of 8B llama running on 8 nodes (64 GPUs), but the issue exists even on 1 node (8 GPUs). The runs are identical besides the dtype. Notice that even after 250 steps the bf16 run does not go below 0.7 loss. In theory, it should be possible to get similar convergence rates with either dtype (at least I think there are multiple existence proofs inside Meta 😛).

One thing I tried was setting FSDP's reduce_dtype=fp32 (had to hardcode it because torchtune doesn't expose this option AFAICT), but it did not seem to help much. Any other options we should be looking into?

Need to confirm this, but I think one thing that would greatly help is keeping the optimizer states in fp32. It would use a lot more memory than end-to-end bf16, but at least it would not slow down training as much as doing everything in fp32. Is there an easy way to do this in torchtune/PyTorch? Would doing something like below work?
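A minimal single-device sketch of one way this could look (an illustrative example, not the snippet from the original post): keep an fp32 master copy of the parameters so the Adam moments and the weight update happen in fp32, then copy the result back into the bf16 model. With FSDP-sharded (DTensor) parameters the bookkeeping is more involved.

```python
import torch

# `model` is assumed to be a bf16 model; hyperparameters are illustrative.
master_params = [p.detach().clone().float() for p in model.parameters()]
optimizer = torch.optim.AdamW(master_params, lr=2e-5)  # moments live in fp32

def optimizer_step():
    # Copy bf16 grads into the fp32 master params, step in fp32, then
    # write the updated fp32 weights back into the bf16 model.
    for p, mp in zip(model.parameters(), master_params):
        mp.grad = p.grad.float() if p.grad is not None else None
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    with torch.no_grad():
        for p, mp in zip(model.parameters(), master_params):
            p.copy_(mp)  # implicit fp32 -> bf16 cast
    model.zero_grad(set_to_none=True)
```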