QLoRA uses more memory than regular LoRA #2255
Hi @AndrewMead10, thanks for creating the issue. At least on our default Llama 3.2 1B configs I do not see this: e.g. running the default LoRA recipe vs. the default QLoRA recipe (commands sketched below), QLoRA's peak memory comes out lower, not higher. Let me try on your config to see if I can reproduce. One general comment: for the 1B model, proportionally more of the memory will be taken up by the tied embedding/output layer and the activations (vs. the memory of the layers you're applying LoRA/QLoRA to), so in general I would expect the memory savings of QLoRA to be a bit smaller (though certainly not negative).
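Presumably the comparison was between torchtune's stock single-device 1B configs, along these lines (a sketch; the exact config names are assumptions based on torchtune's usual naming):

```bash
# Baseline: LoRA finetune on the default Llama 3.2 1B single-device config
tune run lora_finetune_single_device --config llama3_2/1B_lora_single_device

# Same recipe, QLoRA variant of the config
tune run lora_finetune_single_device --config llama3_2/1B_qlora_single_device
```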
Since @ebsmothers was not able to reproduce it, it could be that you had some dead process taking up GPU memory, making it look like QLoRA requires more. Looking at peak active memory should help, but I don't see it in the plot. Maybe before re-running QLoRA, confirm that nothing else is holding GPU memory (something like the sketch below).
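A minimal sketch of that check, plus one way to get peak memory into the logs; the `log_peak_memory_stats` flag and the config name are assumptions and may differ across torchtune versions:

```bash
# Confirm no stale processes are still holding GPU memory
nvidia-smi

# Re-run QLoRA with peak memory stats reported to the metric logger
tune run lora_finetune_single_device --config llama3_2/1B_qlora_single_device \
  log_peak_memory_stats=True
```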
I did some more tests, making sure no other processes were using GPU memory, and it seems that torch.compile is the reason the memory is higher. When compile=False, QLoRA uses less memory than LoRA as expected, but when compile=True the ordering flips, with LoRA using less memory(!) than QLoRA.
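For reference, the compile toggle can be set in the YAML or passed as a CLI override, so the four runs being compared look roughly like this (a sketch on the default 1B configs; config names are assumptions):

```bash
# Eager: QLoRA should use less memory than LoRA
tune run lora_finetune_single_device --config llama3_2/1B_lora_single_device compile=False
tune run lora_finetune_single_device --config llama3_2/1B_qlora_single_device compile=False

# Compiled: the ordering reportedly flips
tune run lora_finetune_single_device --config llama3_2/1B_lora_single_device compile=True
tune run lora_finetune_single_device --config llama3_2/1B_qlora_single_device compile=True
```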
I can't test 3B or larger right now, but my guess is that this may just be a quirk of the small size of the 1B model? @felipemello1 @ebsmothers
Thanks @AndrewMead10, this is an interesting finding. I do see something similar: compiled LoRA has lower allocated memory than compiled QLoRA. Really, compiling QLoRA seems to not yield any memory savings (at least under default mode). However, the perf improvement from compiling QLoRA is much greater: it's ~40% slower than LoRA in the uncompiled version, but with the compiled version the gap is <10%. I'm pleasantly surprised it gets so close, since QLoRA has the extra NF4 -> bf16 ops that LoRA does not have. Since we were just talking about tensor subclasses + compile in torchtune, cc @bdhirsh in case you have any thoughts on why the expected memory savings don't show up when QLoRA is compiled.
If QLoRA + compile is giving higher peak memory compared to QLoRA + eager (aka compile is "removing" some of the memory savings you expect from QLoRA), that sounds like a bug / worth investigating. Are you able to get a memory profile snapshot of both the eager and compiled runs? That would probably tell us a lot more. There are some nice instructions at https://pytorch.org/docs/stable/torch_cuda_memory.html. One example that I know of where compile can give worse peak memory than eager (although I'm not sure if it applies to this case) is when you implement your own version of careful checkpointing in a custom autograd.Function.
Thanks @bdhirsh! To clarify, we don't actually see higher peak memory for compiled QLoRA than we do with eager QLoRA. We just don't see any memory savings, which leads to the unexpected behavior observed by @AndrewMead10: compiled QLoRA peak memory exceeds compiled LoRA peak memory (since compiled LoRA has nice memory savings vs. eager). So there may not be an obvious bug here, just a counterintuitive result. But I agree that the best next step for further investigation is probably to gather memory profiles; we can follow up after that.
IMO, we shouldn't over-index on 1B. Maybe we can run it for 8B and check the behavior?
I ran some tests on Llama 3B and 8B.

For 3B:

For 8B*:

* I had to use activation checkpointing to run the training for Llama 8B, since my 3090 didn't have enough VRAM otherwise. Also, for both 3B and 8B I reduced the batch size to 1; otherwise the config is the same as the 1B one (the overrides are sketched after this comment).

1B (for completeness):
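For reference, a rough sketch of the 8B comparison with those overrides applied on the CLI, assuming the 8B model is Llama 3.1 8B (the config names and the override keys `batch_size`, `enable_activation_checkpointing`, and `compile` are assumptions based on the standard torchtune single-device configs):

```bash
# Llama 3.1 8B, LoRA vs. QLoRA on a 24 GB card: batch size 1 + activation checkpointing
tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device \
  batch_size=1 enable_activation_checkpointing=True compile=True
tune run lora_finetune_single_device --config llama3_1/8B_qlora_single_device \
  batch_size=1 enable_activation_checkpointing=True compile=True
```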
Hey @AndrewMead10, thanks for your patience on this one, just getting back to it now. The 3B results are the ones that would concern me; as Brian mentioned, it's generally unexpected to have memory increase with compile enabled. However, I ran it myself (on an A100, not a 3090) and do not see the same results. The commands I ran were just the 3B LoRA and QLoRA runs with compile enabled (along the lines of the sketch below).
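A sketch of what those two commands presumably looked like on torchtune's default Llama 3.2 3B single-device configs (the exact config names and overrides are assumptions):

```bash
# Compiled LoRA, Llama 3.2 3B
tune run lora_finetune_single_device --config llama3_2/3B_lora_single_device compile=True

# Compiled QLoRA, Llama 3.2 3B
tune run lora_finetune_single_device --config llama3_2/3B_qlora_single_device compile=True
```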
@bdhirsh is your previous statement about memory increases hardware-dependent in any way? (My assumption would be no, but just want to confirm.) I can also try the same on a 3090 once I get a bit more time.
Hmm, no. Off the top of my head, I can't think of any obvious reason we would see "compile gives you higher peak memory than eager" on one CUDA device but not another. If that's true, then a repro/memory profile would be interesting to see.
I wanted to compare LoRA and QLoRA finetuning for Llama 3.2 1B, but I found that QLoRA was using more memory than LoRA.
Here is the wandb report with the logs
Here is my config; the only diff between the runs is changing
_component_: torchtune.models.llama3_2.qlora_llama3_2_1b
to
_component_: torchtune.models.llama3_2.lora_llama3_2_1b