
Finetuning Llama 3.1 8B Base Model on ChatML Format Dataset – Loss Reaches NaN After 2000 Steps #2246

Open
abdul-456 opened this issue Jan 10, 2025 · 11 comments

@abdul-456 commented Jan 10, 2025

I am running into a significant problem while finetuning the Llama 3.1 8B base model with torchtune on my custom ChatML-formatted dataset:

<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>
{user_question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{model_answer}<|eot_id|>

Previously, I successfully finetuned the Instruct variant of Llama 3.1 8B on an Alpaca-style dataset, but the resulting model produced gibberish responses. To address this, I converted my citation-based dataset into the Llama 3.1 ChatML format shown above. Despite this change, the training loss blows up to NaN after approximately 2000 steps, making the finetuning run unusable.
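
For reference, here is a minimal sketch of how one example can be rendered into this template. The column names ("system", "question", "answer") are placeholders, not the dataset's actual fields:

```python
# Illustrative sketch only: render a single dataset row into the Llama 3.1
# chat template shown above. "system", "question", and "answer" are
# placeholder column names for whatever the source dataset actually uses.
LLAMA3_TEMPLATE = (
    "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
    "{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n"
    "{user_question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
    "{model_answer}<|eot_id|>"
)


def to_llama3_prompt(example: dict) -> str:
    """Format one dataset row (a dict of strings) into the chat template."""
    return LLAMA3_TEMPLATE.format(
        system_prompt=example["system"],
        user_question=example["question"],
        model_answer=example["answer"],
    )
```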

Below are excerpts from the training logs where the loss becomes NaN:

Step 4363 | loss:14433.9736328125 lr:2e-05 tokens_per_second_per_gpu:309.3285827636719 peak_memory_active:18.35957908630371 peak_memory_alloc:18.35957908630371 peak_memory_reserved:19.521484375
...
Step 4398 | loss:nan lr:2e-05 tokens_per_second_per_gpu:212.1794891357422 peak_memory_active:18.29242515563965 peak_memory_alloc:18.29242515563965 peak_memory_reserved:19.521484375
Step 4399 | loss:nan lr:2e-05 tokens_per_second_per_gpu:249.82809448242188 peak_memory_active:18.321777820587158 peak_memory_alloc:18.321777820587158 peak_memory_reserved:19.521484375

@joecummings (Contributor)

Woah, this looks pretty wacky. Can you post your full training config and the recipe you're using?

@abdul-456 (Author) commented Jan 10, 2025

Sure, below is the recipe I am using:
tune run full_finetune_single_device --config llama3_1/8B_full_single_device

8B_full_single_device.yaml:

tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model

dataset:
  _component_: torchtune.datasets.alpaca_dataset
  source: "abdulmannan-01/rag_combined_dataset_orca_and_openscholar_alpaca_format"
  split: "train"
  train_on_input: false
  max_seq_len: 2048
  packed: false

model:
  _component_: torchtune.models.llama3_1.llama3_1_8b

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  checkpoint_files: [
    model-00001-of-00004.safetensors,
    model-00002-of-00004.safetensors,
    model-00003-of-00004.safetensors,
    model-00004-of-00004.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  model_type: LLAMA3
resume_from_checkpoint: False

batch_size: 2
epochs: 3
optimizer:
  _component_: bitsandbytes.optim.PagedAdamW8bit
  lr: 2e-5
loss:
  _component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1
optimizer_in_bwd: True
compile: False

device: cuda

enable_activation_checkpointing: True

dtype: bf16

metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /tmp/full-llama3.1-finetune
log_every_n_steps: 1
log_peak_memory_stats: False

joecummings added the triaged label on Jan 10, 2025
@ebsmothers (Contributor)

Hi @abdul-456, thanks for creating the issue. I've run the base 8B models on the slimorca dataset a lot without seeing NaNs, and I have another run going now that I will keep an eye on. One thing about the base models is that some of the special tokens are untrained, so it's possible you're seeing issues due to that. If the NaN is deterministic, can you perhaps look at the norms of different model weights just before the NaN occurs? If it's coming from the embeddings, for example, that would help validate this hypothesis. It is possible to do some custom initialization of the untrained special tokens; I will look into that in parallel.
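
For reference, a minimal sketch of that kind of check using only plain PyTorch; it assumes the recipe's model object is in scope and that it is called just before the step where the loss becomes NaN:

```python
import torch


def report_weight_health(model: torch.nn.Module, top_k: int = 10) -> None:
    """Print the largest parameter norms and flag any NaN/Inf weights."""
    stats = []
    for name, param in model.named_parameters():
        p = param.detach().float()
        if torch.isnan(p).any() or torch.isinf(p).any():
            print(f"BAD WEIGHT: {name}")
        stats.append((p.norm().item(), name))

    # Largest norms first; a runaway embedding table or output head would
    # show up at the top of this list.
    for norm, name in sorted(stats, reverse=True)[:top_k]:
        print(f"{name}: norm = {norm:.2f}")
```

The same idea can be applied per row of the token embedding to see whether the untrained special-token rows specifically are the ones blowing up.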

@abdul-456 (Author) commented Jan 11, 2025

Hi @ebsmothers, can you share the config file you used for finetuning the base 8B model on the slimorca dataset?

@ebsmothers (Contributor)

@abdul-456 I pretty much just used the base torchtune config but overrode the dataset. I also overrode the optimizer because I was hitting an (unrelated) error with the bitsandbytes optimizer. The final config can be seen here. Latest update on the run is that I made it about 40k steps and did not see any issues.

@abdul-456 (Author)

I converted my dataset to the same format as slimorca, but the loss still reached NaN: after roughly 8000 steps on the Llama 3.1 8B Instruct model, and after roughly 2000 steps on the base model.

@abdul-456 (Author)

Hey @ebsmothers, I tried finetuning the Llama 3.1 8B base model on the slimorca dataset using your config file, but it also gives me a NaN loss after about 2300 of the 26000 total steps. Any idea why this is happening? I didn't change anything and kept everything the same; I just replaced my config file with yours.

@joecummings (Contributor)

> Hey @ebsmothers, I tried finetuning the Llama 3.1 8B base model on the slimorca dataset using your config file, but it also gives me a NaN loss after about 2300 of the 26000 total steps. Any idea why this is happening? I didn't change anything and kept everything the same; I just replaced my config file with yours.

Time to potentially go deeper here - I want to see what's happening in the data itself. Can you copy the recipe script you're using to your local machine? (tune cp full_finetune_single_device .) Then you can modify the training loop to manually examine the batches where the loss starts spiking. Hopefully it will be obvious that something is messed up, but feel free to share what you find with us, too.
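
To make that concrete, here is a rough sketch of a guard that could be dropped into the copied recipe's training loop right after the loss is computed; the argument names are placeholders for whatever the recipe actually has in scope:

```python
import torch


def dump_if_nonfinite(step: int, loss: torch.Tensor, batch: dict) -> None:
    """Save the current batch to disk the moment the loss goes NaN/Inf."""
    if torch.isfinite(loss).all():
        return
    print(f"Non-finite loss {loss.item()} at step {step}")
    # Keep only the tensor fields (token ids, labels, masks, ...) so the bad
    # samples can be decoded and inspected offline with the same tokenizer.
    torch.save(
        {k: v.cpu() for k, v in batch.items() if torch.is_tensor(v)},
        f"bad_batch_step_{step}.pt",
    )
    raise RuntimeError("Stopping so the offending batch can be examined.")
```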

@abdul-456 (Author)

I am finetuning it on an A6000 instance (48 GB VRAM) on LambdaLabs; could this cause an issue?

@joecummings (Contributor)

> I am finetuning it on an A6000 instance (48 GB VRAM) on LambdaLabs; could this cause an issue?

I don't think so - I just tested w/ an A6000 on RunPod and saw no issues. Did you try modifying the training loop to examine the batches where loss starts diverging?

@vasrap commented Jan 17, 2025

I found that changing the optimizer to torch.optim.AdamW solves this issue for me, but I am unsure why the issue exists in the first place.
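
For anyone who wants to verify whether the optimizer state itself is where things go bad before swapping optimizers, a minimal generic check might look like the sketch below. It assumes only that the recipe's optimizer object is accessible after a step; since the bitsandbytes optimizers subclass torch.optim.Optimizer, state_dict works for them too:

```python
import torch


def check_optimizer_state(optimizer: torch.optim.Optimizer) -> None:
    """Flag NaN/Inf values in any floating-point optimizer state tensor."""
    state = optimizer.state_dict().get("state", {})
    for param_id, buffers in state.items():
        for key, value in buffers.items():
            if torch.is_tensor(value) and value.is_floating_point():
                if not torch.isfinite(value).all():
                    print(f"Non-finite optimizer state: param {param_id}, key {key}")
```

The swap itself is just pointing the optimizer _component_ in the config at torch.optim.AdamW instead of bitsandbytes.optim.PagedAdamW8bit; note that full-precision AdamW keeps its state in fp32, so it needs noticeably more memory than the paged 8-bit variant.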
