Qlora uses more memory than regular lora #2255

Open · AndrewMead10 opened this issue Jan 11, 2025 · 11 comments
Labels: triaged (This issue has been assigned an owner and appropriate label)

@AndrewMead10 (Contributor) commented Jan 11, 2025

I wanted to compare LoRA and QLoRA finetuning for Llama 3.2 1B, but I found that QLoRA was using more memory than LoRA.

Here is the wandb report with the logs

Here is my config; the only difference between the runs is changing

_component_: torchtune.models.llama3_2.qlora_llama3_2_1b

to

_component_: torchtune.models.llama3_2.lora_llama3_2_1b

# using a Llama3.2 1B Instruct model
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download meta-llama/Llama-3.2-1B-Instruct --output-dir ./tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth"
#
# To launch on a single device, run the following command from root:
#   tune run lora_finetune_single_device --config llama3_2/1B_lora_single_device
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
#   tune run lora_finetune_single_device --config llama3_2/1B_lora_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on single device.

output_dir: ./tmp/torchtune/llama3_2_1B/lora_single_device # /tmp may be deleted by your system. Change it to your preference.

# Model Arguments
model:
  _component_: torchtune.models.llama3_2.qlora_llama3_2_1b
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  lora_rank: 32  # higher increases accuracy and memory
  lora_alpha: 64  # usually alpha=2*rank
  lora_dropout: 0.0
  use_dora: False

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: ./tmp/Llama-3.2-1B/original/tokenizer.model
  max_seq_len: 2048

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: ./tmp/Llama-3.2-1B/
  checkpoint_files: [
     model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3_2
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.chat_dataset
  packed: True
  source: qnguyen3/orca_math_10k
  conversation_column: conversations
  conversation_style: sharegpt
  split: train
seed: 42
shuffle: True
batch_size: 4

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  lr: 1e-4
  weight_decay: 1e-2
  fused: True
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training
epochs: 3
max_steps_per_epoch: null
gradient_accumulation_steps: 2  # Use to increase effective batch size
compile: True  # torch.compile the model + loss, True increases speed + decreases memory

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger
  project: benchmark_torchtune
log_every_n_steps: 1
log_peak_memory_stats: True

# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: False  # True reduces memory
enable_activation_offloading: False  # True reduces memory


# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  #Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  #`torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  #trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
@ebsmothers (Contributor):

Hi @AndrewMead10, thanks for creating the issue. At least on our default Llama 3.2 1B config I do not see this. E.g. if I run

tune run lora_finetune_single_device --config llama3_2/1B_lora_single_device \
metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=test-2255 metric_logger.name=lora

vs

tune run lora_finetune_single_device --config llama3_2/1B_lora_single_device \
metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=test-2255 \
metric_logger.name=qlora model=torchtune.models.llama3_2.qlora_llama3_2_1b

I see the following:

[Image: peak memory comparison, LoRA vs. QLoRA]

Let me try on your config to see if I can reproduce. One general comment is that for the 1B model, proportionally more of the memory will be taken up by the tied embedding/output layer and the activations (vs. the memory of the layers you're applying LoRA/QLoRA to), so in general I would expect the memory savings of QLoRA to be a bit smaller (though certainly not negative).
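
For a rough sense of scale, here is a back-of-the-envelope sketch (approximate numbers based on the published Llama 3.2 1B dimensions; these figures are an assumption for illustration, not taken from this thread):

# Approximate parameter split for Llama 3.2 1B (dimensions assumed from the model card).
vocab_size, dim, n_layers, n_heads, n_kv_heads, ffn_dim = 128_256, 2048, 16, 32, 8, 8192
head_dim = dim // n_heads  # 64

embed = vocab_size * dim  # tied with the output projection, so counted once
attn = n_layers * (dim * dim                          # q_proj
                   + 2 * dim * n_kv_heads * head_dim  # k_proj, v_proj (GQA)
                   + dim * dim)                       # output_proj
mlp = n_layers * 3 * dim * ffn_dim                    # gate/up/down projections
total = embed + attn + mlp                            # norms are negligible here

print(f"total params ~{total / 1e9:.2f}B, embedding share ~{embed / total:.0%}")
# -> roughly 1.24B params, with ~21% of them in the embedding.
# NF4 quantization only touches the linear layers, so that slice of the model
# (plus all activations) sees no benefit, shrinking QLoRA's relative savings at 1B scale.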

@felipemello1 (Contributor):

Since @ebsmothers was not able to reproduce it, it could be that you had some dead process taking up memory, making it look like QLoRA requires more. Looking at peak active memory should help, but I don't see it in the plot.

Maybe before re-running QLoRA, run nvidia-smi to confirm that no memory is being used.
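
For reference, the same checks can be done from Python. This is a minimal sketch using the standard torch.cuda APIs; it is not part of the torchtune recipe:

import torch

# Before launching training: confirm nothing else is holding GPU memory (same idea as nvidia-smi).
free_bytes, total_bytes = torch.cuda.mem_get_info()
print(f"free: {free_bytes / 1e9:.2f} GB / total: {total_bytes / 1e9:.2f} GB")

# After (or during) a run: peak memory as tracked by the caching allocator.
stats = torch.cuda.memory_stats()
print(f"peak active:    {stats['active_bytes.all.peak'] / 1e9:.2f} GB")
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")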

@AndrewMead10 (Contributor, Author):

I did some more tests, making sure no other processes were using GPU memory, and it seems that torch.compile is the reason the memory is higher. When compile=False, QLoRA uses less memory than LoRA as expected, but when compile=True the memory usage flips, with LoRA using less(!) than QLoRA.

Here is the wandb report showing it

@AndrewMead10 (Contributor, Author):

I can't test 3B or higher right now, but my guess is that this may just be a quirk of the small size of the 1B model? @felipemello1 @ebsmothers

@ebsmothers (Contributor):

Thanks @AndrewMead10, this is an interesting finding. I do see something similar: compiled LoRA has lower allocated memory than compiled QLoRA. Really, compiling QLoRA does not seem to yield any memory savings (at least under the default compile mode).

[Image: allocated memory, compiled LoRA vs. compiled QLoRA]

However, the perf improvement from compiling QLoRA is much greater. QLoRA is ~40% slower than LoRA in the uncompiled version, but with the compiled version the gap is <10%. I'm pleasantly surprised it gets so close, since QLoRA has the extra NF4 -> bf16 ops that LoRA does not have.

[Image: training speed comparison, LoRA vs. QLoRA, compiled and uncompiled]

Since we were just talking about tensor subclasses + compile in torchtune, cc @bdhirsh in case you have any thoughts on why memory savings with NF4Tensor would be minimal.

@bdhirsh commented Jan 13, 2025

If qLoRA + compile is giving higher peak memory compared to qLoRA + eager (aka compile is "removing" some of the memory savings you expect from qLoRA), that sounds like a bug / worth investigating.

Are you able to get a memory profile snapshot of both the eager and compiled runs? That would probably tell us a lot more. There are some nice instructions at https://pytorch.org/docs/stable/torch_cuda_memory.html.
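
For anyone following along, here is a minimal sketch of capturing such a snapshot with the APIs documented on that page (the training-step call below is a placeholder for whichever loop you want to profile):

import torch

# Start recording allocator events (keeps a history of allocations with stack traces).
torch.cuda.memory._record_memory_history(max_entries=100_000)

run_a_few_training_steps()  # placeholder: run the eager or compiled training loop here

# Dump the snapshot and load the file at https://pytorch.org/memory_viz to inspect it.
torch.cuda.memory._dump_snapshot("qlora_memory_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)  # stop recording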

One example that I know of where compile can give worse peak memory than eager, although I'm not sure if it applies to this case: if you implement your own version of careful checkpointing in a custom autograd.Function, the compiler is free today to change what you've saved for backward if it thinks it can yield better perf. This only really applies in cases where you are doing interesting saving decisions with custom autograd.Functions though.
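
To make that pattern concrete, here is a generic sketch (my own illustration, not something torchtune necessarily does) of a custom autograd.Function that makes its own decision about what to save for backward, which is exactly the kind of choice the compiler is allowed to revisit:

import torch

class RecomputeSilu(torch.autograd.Function):
    # Hand-rolled saving decision: stash only the input and recompute the sigmoid
    # in backward, rather than relying on whatever autograd would save by default.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.nn.functional.silu(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        sig = torch.sigmoid(x)  # recomputed here instead of saved in forward
        return grad_out * (sig + x * sig * (1 - sig))  # d/dx [x * sigmoid(x)]

y = RecomputeSilu.apply(torch.randn(8, requires_grad=True))
y.sum().backward()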

@ebsmothers (Contributor):

Thanks @bdhirsh! To clarify, we don't actually see higher peak memory for compiled QLoRA than we do with eager QLoRA. We just don't see any memory savings, which leads to the unexpected behavior observed by @AndrewMead10: compiled QLoRA peak memory exceeds compiled LoRA peak memory (since compiled LoRA has nice memory savings vs eager). So there may not be an obvious bug here, just a counterintuitive result. But agree that maybe the best next step for further investigation is to gather memory profiles, we can follow up after that.

@felipemello1 (Contributor):

IMO, we shouldn't over-index on 1B. Maybe we can run it for 8B and check the behavior?

@AndrewMead10 (Contributor, Author):

@felipemello1 @ebsmothers

I ran some tests on Llama 3B and 8B:

For 3B:

  • For QLoRA, the compiled version uses more memory than the uncompiled version
  • When not compiled, QLoRA uses less memory than LoRA, as expected
  • When compiled, QLoRA uses more memory than LoRA

For 8B*:

  • Everything is normal
  • Compiled QLoRA uses less memory than uncompiled
  • QLoRA uses less memory than LoRA

* I had to use activation checkpointing to run the training for Llama 8B, since my 3090 didn't have enough VRAM otherwise. Also, for both 3B and 8B I reduced the batch size to 1; otherwise the config is the same as the 1B one.

1B (for completeness):

  • QLoRA used less memory when compiled
  • QLoRA used more memory than LoRA when compiled

3B report

8B report

1B report

@joecummings added the triaged label Jan 14, 2025
@ebsmothers (Contributor):

Hey @AndrewMead10, thanks for your patience on this one, just getting back to it now. The 3B results are the ones that would concern me; as Brian mentioned, it's generally unexpected to have memory increase with compile enabled. However, I ran it myself (on an A100, not a 3090) and do not see the same results.

[Image: 3B QLoRA peak memory, compiled vs. uncompiled]

The commands I ran are just

tune run lora_finetune_single_device --config llama3_2/3B_qlora_single_device 

and

tune run lora_finetune_single_device --config llama3_2/3B_qlora_single_device compile=True

@bdhirsh is your previous statement about memory increases hardware-dependent in any way? (My assumption would be no, but just want to confirm.) I can also try the same on 3090 once I get a bit more time.

@bdhirsh commented Jan 16, 2025

@bdhirsh is your previous statement about memory increases hardware-dependent in any way? (My assumption would be no, but just want to confirm.)

Hmm, no - off the top of my head, I can't think of any obvious reasons we would see "compile gives you higher peak memory than eager" on one CUDA hardware but not another. If that's true then a repro/memory profile would be interesting to see.
