Finetuning Llama 3.1 8B Base Model on ChatML Format Dataset – Loss Reaches NaN After 2000 Steps #2246
Comments
Woah, this looks pretty wacky. Can you post your full training config and the recipe you're using?
Sure, below is the config (8B_full_single_device.yaml) I am using:

tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model

dataset:
  _component_: torchtune.datasets.alpaca_dataset
  source: "abdulmannan-01/rag_combined_dataset_orca_and_openscholar_alpaca_format"
  split: "train"
  train_on_input: false
  max_seq_len: 2048
  packed: false

model:
  _component_: torchtune.models.llama3_1.llama3_1_8b

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  checkpoint_files: [
    model-00001-of-00004.safetensors,
    model-00002-of-00004.safetensors,
    model-00003-of-00004.safetensors,
    model-00004-of-00004.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  model_type: LLAMA3
resume_from_checkpoint: False

batch_size: 2
epochs: 3
optimizer:
  _component_: bitsandbytes.optim.PagedAdamW8bit
  lr: 2e-5
loss:
  _component_: torch.nn.CrossEntropyLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1
optimizer_in_bwd: True
compile: False

device: cuda
enable_activation_checkpointing: True
dtype: bf16

metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
output_dir: /tmp/full-llama3.1-finetune
log_every_n_steps: 1
log_peak_memory_stats: False
Hi @abdul-456 thanks for creating the issue. I've run the base 8B models on the slimorca dataset a lot without seeing NaNs. I have another run going now that I will keep an eye on. One thing about the base models is that some of the special tokens are untrained, so it's possible you're seeing issues due to that. If the NaN is deterministic, can you perhaps look at the norm of different model weights just before the NaN occurs? If e.g. it's coming from the embeddings that would help validate this hypothesis. It is possible to do some custom initialization of the untrained special tokens, I will look into that in parallel.
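To make the weight-norm check and the custom initialization concrete, here is a minimal sketch; the `tok_embeddings` attribute and the helper names are assumptions about a torchtune Llama 3 model, not recipe code:

```python
import math
import torch

def log_param_norms(model: torch.nn.Module, step: int) -> None:
    # Print the norm of every parameter so the tensor that first goes
    # non-finite (e.g. the embeddings) can be identified just before the NaN step.
    for name, param in model.named_parameters():
        norm = param.detach().float().norm().item()
        print(f"step {step} | {name} | norm={norm:.4f} | finite={math.isfinite(norm)}")

def init_untrained_special_tokens(model: torch.nn.Module, special_token_ids: list[int]) -> None:
    # Re-initialize untrained special-token rows of the embedding table to the
    # mean of all embedding rows; `tok_embeddings` is an assumed attribute name.
    emb = model.tok_embeddings.weight
    with torch.no_grad():
        mean_row = emb.mean(dim=0)
        for tok_id in special_token_ids:
            emb[tok_id].copy_(mean_row)
```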
Hi @ebsmothers, can you share your config file that you used for finetuning base 8b models on the slimorca dataset? |
@abdul-456 I pretty much just used the base torchtune config but overrode the dataset. I also overrode the optimizer because I was hitting an (unrelated) error with the bitsandbytes optimizer. The final config can be seen here. Latest update on the run is that I made it about 40k steps and did not see any issues.
I converted my dataset to the same format as slimorca, but the loss still reaches NaN: after about 8000 steps on the Llama 3.1 8B Instruct model, and after about 2000 steps on the base model.
Hey @ebsmothers, I tried finetuning the Llama 3.1 8B base model on the slimorca dataset using your config file, but it also gives me a NaN loss after 2300 of the 26000 total steps. Any idea why this is happening? I didn't change anything and kept everything the same; I just overrode my config file with yours.
Time to potentially go deeper here - I want to see what's happening in the data itself. Can you copy the script you're using to your local machine? (
I am finetuning it on an A6000 instance (48 GB VRAM) on LambdaLabs; could this cause an issue?
I don't think so - I just tested w/ an A6000 on RunPod and saw no issues. Did you try modifying the training loop to examine the batches where loss starts diverging? |
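Something along these lines could be dropped into the recipe's training loop right after the loss is computed; this is only a sketch, and `tokens`, `labels`, and `tokenizer` are assumptions about the loop's local variables rather than the actual recipe code:

```python
import torch

def dump_batch_if_diverged(loss, tokens, labels, tokenizer, step, out_dir="/tmp"):
    # Save and decode the offending batch when the loss stops being finite,
    # so the raw text and special tokens can be inspected offline.
    if torch.isfinite(loss):
        return
    path = f"{out_dir}/bad_batch_step{step}.pt"
    torch.save({"tokens": tokens.cpu(), "labels": labels.cpu()}, path)
    print(f"Non-finite loss {loss.item()} at step {step}; batch saved to {path}")
    print(tokenizer.decode(tokens[0].tolist()))  # eyeball the first sample
    raise RuntimeError("Stopping so the diverging batch can be examined")
```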
I found that changing the optimizer to
I am encountering significant challenges while attempting to finetune the Llama 3.1 8B base model using Torchtune on my custom ChatML-formatted dataset:
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>
{user_question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{model_answer}<|eot_id|>
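(For concreteness, each sample is assembled roughly like the plain-Python sketch below; the field names just mirror the placeholders above and are not a torchtune API.)

```python
def format_llama31_sample(system_prompt: str, user_question: str, model_answer: str) -> str:
    # Mirrors the template shown above; argument names are illustrative.
    return (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
        f"{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n"
        f"{user_question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
        f"{model_answer}<|eot_id|>"
    )
```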
Previously, I successfully finetuned the Instruct variant of the Llama 3.1 8B model on an Alpaca-style dataset, but the resulting model produced gibberish responses. To address this, I converted my citation-based dataset into the Llama 3.1 ChatML format. Despite these adjustments, the training loss blows up and becomes NaN after approximately 2000 steps, rendering the finetuning process ineffective.
Below are excerpts from the training logs where the loss becomes NaN:
Step 4363 | loss:14433.9736328125 lr:2e-05 tokens_per_second_per_gpu:309.3285827636719 peak_memory_active:18.35957908630371 peak_memory_alloc:18.35957908630371 peak_memory_reserved:19.521484375
...
Step 4398 | loss:nan lr:2e-05 tokens_per_second_per_gpu:212.1794891357422 peak_memory_active:18.29242515563965 peak_memory_alloc:18.29242515563965 peak_memory_reserved:19.521484375
Step 4399 | loss:nan lr:2e-05 tokens_per_second_per_gpu:249.82809448242188 peak_memory_active:18.321777820587158 peak_memory_alloc:18.321777820587158 peak_memory_reserved:19.521484375
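To pin down the exact step at which a run diverges, a small helper like the following can scan the DiskLogger output (a sketch only; the log-line pattern is taken from the excerpts above, and the file name is illustrative):

```python
import math
import re

def find_divergence(log_path: str, spike_threshold: float = 100.0):
    # Return the first (step, loss) whose logged loss is non-finite or above the threshold.
    pattern = re.compile(r"Step (\d+) \| loss:(\S+)")
    with open(log_path) as f:
        for line in f:
            m = pattern.search(line)
            if not m:
                continue
            step, loss = int(m.group(1)), float(m.group(2))
            if not math.isfinite(loss) or loss > spike_threshold:
                return step, loss
    return None

print(find_divergence("/tmp/full-llama3.1-finetune/log.txt"))  # illustrative file name
```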