
Very slow convergence with bf16 #2254

Open

EugenHotaj opened this issue Jan 11, 2025 · 18 comments

@EugenHotaj
Contributor

EugenHotaj commented Jan 11, 2025

We've noticed very bad convergence when training in bf16 vs fp32.

As a comparison, here are the loss curves for bf16:
[Screenshot: bf16 loss curve]

and fp32:
[Screenshot: fp32 loss curve]

This is a full finetune of 8B Llama running on 8 nodes (64 GPUs), but the issue exists even on 1 node (8 GPUs). The runs are identical besides the dtype. Notice that even after 250 steps the bf16 run does not go below 0.7 loss. In theory, it should be possible to get similar convergence rates with either dtype (at least I think there are multiple existence proofs inside Meta 😛).

One thing I tried doing was setting FSDP's reduce_dtype=fp32 (had to hardcode because torchtune doesn't expose this option AFAICT) but it did not seem to help much. Any other options we should be looking into?

I still need to confirm this, but I think one thing that would greatly help is keeping the optimizer states in fp32. It would use a lot more memory than end-to-end bf16, but at least it would not slow down training as much as doing everything in fp32. Is there an easy way to do this in torchtune/pytorch? Would doing something like the below work?

import torch

model = create_model(dtype=torch.float32)          # build the weights in fp32
optimizer = torch.optim.AdamW(model.parameters())  # optimizer sees the fp32 params
model = model.to(torch.bfloat16)                   # then cast the weights to bf16 for training
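
For concreteness, here is a rough sketch of the kind of thing I mean (create_model is a hypothetical stand-in for the torchtune model builders, and this is just the classic fp32 master-weights pattern, not existing torchtune code): keep an fp32 copy of the params for the optimizer, step in fp32, and copy the result back into the bf16 model.

import torch

model = create_model(dtype=torch.bfloat16)            # hypothetical factory; compute weights in bf16
master_params = [p.detach().clone().float() for p in model.parameters()]  # fp32 master copies
optimizer = torch.optim.AdamW(master_params, lr=6e-7)  # optimizer states live in fp32

def optimizer_step():
    # assumes every param received a grad on this step
    with torch.no_grad():
        # move the bf16 grads onto the fp32 masters, step in fp32, then copy back
        for master, p in zip(master_params, model.parameters()):
            master.grad = p.grad.detach().float()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        for master, p in zip(master_params, model.parameters()):
            p.copy_(master)                            # downcast the fp32 update into the bf16 weights
    model.zero_grad(set_to_none=True)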
@EugenHotaj changed the title from "Very bad convergence with bf16" to "Very slow convergence with bf16" on Jan 12, 2025
@felipemello1
Contributor

felipemello1 commented Jan 12, 2025

Hey @EugenHotaj, thanks for flagging this! I can take a look at it this week. To unblock you, if you have time, can you run it with bf16 and try NOT compiling the loss? You can comment it out in the recipe or set compile = false (but that also disables it for the model)

I have a vague memory of someone saying that CE + compile skips upcasting to fp32.

@EugenHotaj
Contributor Author

@felipemello1 thanks for taking a look! I'm not using compile right now, just running the default llama3_1 config with a different dataset. I could try turning on compile (with the change you suggested) to see if it makes a difference.

@EugenHotaj
Contributor Author

EugenHotaj commented Jan 13, 2025

Compile (skipping the loss) didn't seem to help.

@EugenHotaj
Contributor Author

I tried manually hacking the Adam optimizer to keep the moments in fp32, but it didn't help, unfortunately.

One thing I discovered is that FSDP2 lets you keep the weights in fp32 and do the compute in bf16 (here -- it might be good to expose these params in the torchtune config, btw). The good news is that training is basically as fast as plain bf16 again, but we're still using 2x the memory, unfortunately.
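
For reference, a minimal sketch of what that FSDP2 option looks like, assuming the torch 2.5-era composable fully_shard API (the exact import path varies across torch versions, and create_model is again a hypothetical stand-in for the torchtune builders):

import torch
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

# the weights stay in fp32 because the model is built in fp32; the policy only
# controls the dtype used for forward/backward compute and gradient reduction
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # all-gather / compute in bf16
    reduce_dtype=torch.float32,   # reduce-scatter grads in fp32
)
model = create_model(dtype=torch.float32)   # hypothetical factory for the fp32 model
for layer in model.layers:                  # assumes a TransformerDecoder-style .layers list
    fully_shard(layer, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)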

@EugenHotaj
Contributor Author

@felipemello1 any luck on this?

@felipemello1
Contributor

felipemello1 commented Jan 14, 2025

Hey @EugenHotaj, just got time to start it today.

  • I am running it now to reproduce the pattern on single device
  • Then I will check whether there is some quick win related to the loss
  • Then I will try to reproduce it in HF

If it doesn't happen in HF, we know it's a torchtune issue. If it does happen there, then I wonder whether our premise that bf16 convergence == fp32 convergence is correct.

@felipemello1
Contributor

Single-device Llama 3B, no difference. bf16 actually seems better.

Image

tune run full_finetune_single_device --config llama3_2/3B_full_single_device enable_activation_checkpointing=True enable_activation_offloading=True compile=False dataset.packed=True dataset.split=train[:5%] tokenizer.max_seq_len=1024 metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=profiling log_every_n_steps=1 log_peak_memory_stats=True gradient_accumulation_steps=1 max_steps_per_epoch=50 epochs=1 batch_size=1 profiler.enabled=False optimizer_in_bwd=True optimizer=torch.optim.AdamW dtype=fp32

gonna try to reproduce it using distributed

@EugenHotaj
Contributor Author

@felipemello1 one thing I would test is setting gradient_accumulation_steps > 1. I think what might be happening is that we lose precision when accumulating gradients in bf16.
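
As a toy illustration of the concern (not from the actual run, just float behavior): an increment that is small relative to the running sum gets rounded away entirely in bf16.

import torch

small = torch.tensor(1e-3)
acc_bf16 = torch.tensor(1.0, dtype=torch.bfloat16)
acc_fp32 = torch.tensor(1.0)
for _ in range(1000):
    acc_bf16 += small.to(torch.bfloat16)   # smaller than half a bf16 ulp at 1.0, so it is rounded away
    acc_fp32 += small
print(acc_bf16.item(), acc_fp32.item())    # ≈1.0 vs ≈2.0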

@felipemello1
Contributor

For distributed, no large changes. Let me take a look at grad accumulation

Image

tune run --nproc_per_node 8 full_finetune_distributed --config llama3_2/3B_full enable_activation_checkpointing=True enable_activation_offloading=True compile=False dataset.packed=True dataset.split=train[:5%] tokenizer.max_seq_len=1024 metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=profiling log_every_n_steps=1 log_peak_memory_stats=True gradient_accumulation_steps=1 max_steps_per_epoch=50 epochs=1 batch_size=1 profiler.enabled=False optimizer_in_bwd=True optimizer=torch.optim.AdamW dtype=fp32

@felipemello1
Contributor

felipemello1 commented Jan 14, 2025

used single_device, grad_acc=4, bsz=1, nothing crazy :/

Image

tune run full_finetune_single_device --config llama3_2/3B_full_single_device enable_activation_checkpointing=True enable_activation_offloading=True compile=False dataset.packed=True dataset.split=train[:5%] tokenizer.max_seq_len=1024 metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=profiling log_every_n_steps=1 log_peak_memory_stats=True gradient_accumulation_steps=4 max_steps_per_epoch=30 epochs=1 batch_size=1 profiler.enabled=False optimizer_in_bwd=False optimizer=torch.optim.AdamW dtype=fp32

Same for distributed, but I can run it for longer to confirm.

Image

tune run --nproc_per_node 8 full_finetune_distributed --config llama3_2/3B_full enable_activation_checkpointing=True enable_activation_offloading=True compile=False dataset.packed=True dataset.split=train[:30%] tokenizer.max_seq_len=1024 metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=profiling log_every_n_steps=1 log_peak_memory_stats=True gradient_accumulation_steps=4 max_steps_per_epoch=50 epochs=1 batch_size=1 profiler.enabled=False optimizer_in_bwd=False optimizer=torch.optim.AdamW dtype=fp32

@felipemello1
Contributor

Can you reproduce it with some public dataset and share the config with me? I can test it with 8xA100.

@EugenHotaj
Contributor Author

EugenHotaj commented Jan 14, 2025

Hmm, that's definitely surprising. Are you able to try with the following config (but your own dataset)? A few thoughts:

  • Maybe the issue is not as prevalent for the 3B model
  • Maybe the 2e-5 LR is too high
  • Let it run for longer, say 100-200 steps.

If you still don't see any issues I can try to repro with a public dataset.

name: ???
output_dir: /traindata/eugen/runs/torchtune/${name}

# Dataset Arguments
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /traindata/llama_hf_ckpt/Llama-3.1-8B-Instruct/original/tokenizer.model
  max_seq_len: 131072

dataset: <your-ds>
  
seed: 42
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /traindata/llama_hf_ckpt/Llama-3.1-8B-Instruct/
  checkpoint_files: [
    model-00001-of-00004.safetensors,
    model-00002-of-00004.safetensors,
    model-00003-of-00004.safetensors,
    model-00004-of-00004.safetensors
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 1
epochs: 1

optimizer:
  _component_: torch.optim.Adam
  lr: 6e-7
  fused: True
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: 230
compile: False  # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False  # True saves memory. Requires gradient_accumulation_steps=1
gradient_accumulation_steps: 8  # Use to increase effective batch size
clip_grad_norm: 1.0

# Training env
device: cuda
dtype: bf16

# Memory management
enable_activation_checkpointing: True  # True reduces memory
enable_activation_offloading: False  # True reduces memory
custom_sharded_layers: ['tok_embeddings', 'output']  # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed.

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger
  project: torchtune
  name: ${name}
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True

@felipemello1
Contributor

felipemello1 commented Jan 14, 2025

Maybe the alpaca dataset is too easy? Any suggestions for a harder public dataset?

@EugenHotaj
Contributor Author

@felipemello1 on second thought, even in your graphs we already see loss divergence between bf16 and fp32, with the latter being lower. It's likely this gap would grow over time.

Image Image

@felipemello1
Contributor

felipemello1 commented Jan 14, 2025

From @gau-nernst

It's expected that full BF16 training may have convergence issues compared to FP32, especially with a small LR. It's simply because BF16 doesn't have enough mantissa bits (precision) for small weight updates.
The usual approach is mixed-precision training (weights in FP32, compute in BF16). You can also try Adam with stochastic rounding, which aims to help with small weight updates in BF16, but it still couldn't quite match the convergence of FP32 weights: https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#stochastic-rounding-for-bf16-weight

@EugenHotaj, wanna give the stochastic rounding a try? (only works with torch nightlies)
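
Roughly something like this (a sketch only; the exact export name and import path of the stochastic-rounding AdamW has moved around between torchao versions, so double-check against the README linked above, and create_model is a hypothetical stand-in for the model builder):

import torch
from torchao.prototype.low_bit_optim import _AdamW  # name/path may differ across torchao versions

model = create_model(dtype=torch.bfloat16)           # hypothetical factory; weights stay in bf16
optimizer = _AdamW(
    model.parameters(),
    lr=6e-7,
    bf16_stochastic_round=True,                      # stochastically round the bf16 weight updates
)

In the torchtune config this should just be a matter of pointing optimizer._component_ at that class and adding bf16_stochastic_round: True under the optimizer section, but I haven't verified the exact export name.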

@EugenHotaj
Contributor Author

EugenHotaj commented Jan 18, 2025

@felipemello1 thanks for the pointer! I just had a chance to try it out today, but it seems to make a significant difference:

Image Image

These were just some quick 1-node runs, but it looks pretty promising. I'll kick off some real runs and report back.

When would this make it to stable, if we want to use it? The implementation doesn't look too difficult, but there seem to be some ops missing in 2.5.1 and I couldn't get it to run when I ported over the code.

Also, I wasn't able to run the original code and kept getting the following failure here (cc @gau-nernst):

[rank0]:     raise RuntimeError(
[rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function lt>(*(FakeTensor(..., device='cuda:0', size=(128256, 4096), dtype=torch.int32), DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(16032, 4096), dtype=torch.int32), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3, 4, 5, 6, 7]), placements=(Shard(dim=0),))), **{}):
[rank0]: aten.lt.Tensor: got mixed torch.Tensor and DTensor, need to convert all torch.Tensor to DTensor before calling distributed operators!

Not sure if it's a bug or just something off with my setup, but the following fixed it:

Image

@gau-nernst
Contributor

@EugenHotaj Stochastic-rounding Adam/AdamW + FSDP2 requires torch nightly pytorch/ao#1505

I'm curious that the curves look quite different. I haven't tried SR AdamW in the fine-tuning setting much, so I'm not sure if that's to be expected. Do you use the same hparams for all settings? My uneducated guess is that you might be using a higher LR for BF16-SR.

@EugenHotaj
Contributor Author

EugenHotaj commented Jan 18, 2025

@gau-nernst the issue I posted above seems slightly different from pytorch/ao#1505. I actually hit pytorch/ao#1505 as well, but it was fixed by using nightlies as you mentioned. Not sure if I'm on the wrong version though; I basically just did:

pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124

I'm curious that the curves look quite different.

The hparams are identical; the only setting I'm changing is bf16_stochastic_round. We observed the same thing when using fp32 vs bf16 params (see my very first message). It's pretty surprising to me as well; my guess is it might be due to our very low learning rate (6e-7).
