
Loss Fails to Converge in Nemo2-sft.ipynb with Precision 16 #12102

Open

twotwoiscute opened this issue Feb 8, 2025 · 0 comments

twotwoiscute commented Feb 8, 2025

Description

I wrote a fine-tuning script based on the nemo2-sft.ipynb notebook provided by this repository, running on a machine with 8 V100 (32 GB) GPUs. However, the original notebook uses bf16 precision, which the V100 does not support, so it eventually falls back to fp32. When I switched the precision to fp16 and ran training, the loss failed to converge and the gradient norm reported on wandb was extremely small (on the order of 1e-9). In contrast, the default setting (precision=bf16, which falls back to fp32) resulted in a much longer training time, but the loss decreased and the gradient norm was significantly larger.
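
For reference, the lack of native bf16 support on V100 can be confirmed directly from PyTorch (V100 is compute capability 7.0, while bf16 tensor cores require 8.0 / Ampere or newer); a minimal check:

import torch

# Tesla V100 reports compute capability (7, 0); native bf16 starts with Ampere, (8, 0)+.
major, minor = torch.cuda.get_device_capability(0)
print(torch.cuda.get_device_name(0), (major, minor))
print("native bf16:", (major, minor) >= (8, 0))  # False on V100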

Environment

Docker image: nvcr.io/nvidia/nemo:24.12
torch.__version__ 
'2.5.0a0+e000cf0ad9.nv24.10'

# Driver
Driver Version: 535.183.01

# CUDA things
root@ai-server:/opt/pytorch/apex# nvcc -V 
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Sep_12_02:18:05_PDT_2024
Cuda compilation tools, release 12.6, V12.6.77
Build cuda_12.6.r12.6/compiler.34841621_0

Script

# Imports (not part of the original paste; assumed from the nemo2-sft.ipynb
# tutorial, exact module paths may vary by NeMo version).
import argparse
import time
from typing import Optional

import torch
import lightning.pytorch as pl
import nemo_run as run
from lightning.pytorch.loggers import WandbLogger
from megatron.core.optimizer import OptimizerConfig
from nemo import lightning as nl
from nemo.collections import llm
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed, fp16_mixed
from nemo.lightning.pytorch.optim import CosineAnnealingScheduler

def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "dataset",
        type=str,
        choices=["smart_chat", "dolly"],
        help="Specify the name of dataset for training"
    )
    parser.add_argument(
        "precision",
        type=str,
        choices=["16", "bf16"],
        help="Specify the precision for training"
    )
    parser.add_argument(
        "model",
        type=str,
        choices=["llama3", "llama3.1"],
        help="Select the model for training"
    )
    return parser.parse_args()

def smart_chat() -> run.Config[pl.LightningDataModule]:
    return run.Config(llm.SmartChatDataModule, seq_length=6400, micro_batch_size=4, global_batch_size=16, num_workers=4)

def dolly() -> run.Config[pl.LightningDataModule]:
    return run.Config(llm.DollyDataModule, seq_length=2048, micro_batch_size=1, global_batch_size=128, num_workers=4)

def trainer(args) -> run.Config[nl.Trainer]:
    strategy = run.Config(
        nl.MegatronStrategy,
        tensor_model_parallel_size=4,
        pipeline_model_parallel_size=2,
        pipeline_dtype=torch.float16 if args.precision == "16" else torch.bfloat16,
        megatron_amp_O2=True
    )
    trainer = run.Config(
        nl.Trainer,
        devices=8,
        max_steps=1000,
        accelerator="gpu",
        strategy=strategy,
        plugins=fp16_mixed() if args.precision == "16" else bf16_mixed(),
        log_every_n_steps=1,
        val_check_interval=250,
        num_sanity_val_steps=0,
    )
    return trainer

def wandb_logger(args, project: str = "nemo", entity: Optional[str] = None) -> run.Config[WandbLogger]:
    # This code is based on: NeMo/nemo/collections/llm/recipes/log/default.py:29
    name = f"{args.precision}_{args.model}_{args.dataset}"
    cfg = run.Config(
        WandbLogger,
        project=project,
        name=name,
        config={},
    )

    if entity:
        cfg.entity = entity

    return cfg

def logger(args) -> run.Config[nl.NeMoLogger]:
    ckpt = run.Config(
        nl.ModelCheckpoint,
        save_last=True,
        monitor="val_loss",
        save_top_k=1,
        save_on_train_epoch_end=True,
        save_optim_on_train_end=True,
    )

    return run.Config(
        nl.NeMoLogger,
        name="nemo2_sft",
        log_dir="/workspace/law_ai_model/nemo/result",
        use_datetime_version=False,
        ckpt=ckpt,
        wandb=wandb_logger(args)
    )

def adam_with_cosine_annealing(args) -> run.Config[nl.OptimizerModule]:
    opt_cfg = run.Config(
        OptimizerConfig,
        optimizer="adam",
        lr=1e-4,
        adam_beta1=0.9,
        adam_beta2=0.98,
        use_distributed_optimizer=True,
        clip_grad=1.0,
        fp16=True,
    ) if args.precision == "16" else run.Config(
        OptimizerConfig,
        optimizer="adam",
        lr=1e-4,
        adam_beta1=0.9,
        adam_beta2=0.98,
        use_distributed_optimizer=True,
        clip_grad=1.0,
        bf16=True
    )
    sched = run.Config(
        CosineAnnealingScheduler,
        warmup_steps=10,
        constant_steps=0,
        min_lr=0.0
    )
        
    return run.Config(
        nl.MegatronOptimizerModule,
        config=opt_cfg,
        lr_scheduler=sched
    )

def llama3_8b() -> run.Config[pl.LightningModule]:
    return run.Config(
        llm.LlamaModel, 
        config=run.Config(
            llm.Llama3Config8B,
            recompute_granularity="full",
            recompute_method="uniform",
            recompute_num_layers=8,
        )
    )

def llama31_8b() -> run.Config[pl.LightningModule]:
    return run.Config(
        llm.LlamaModel, 
        config=run.Config(
            llm.Llama31Config8B,
            recompute_granularity="full",
            recompute_method="uniform",
            recompute_num_layers=8,
        )
    )

def resume(args) -> run.Config[nl.AutoResume]:
    return run.Config(
        nl.AutoResume,
        restore_config=run.Config(nl.RestoreConfig,
            path="nemo://meta-llama/Meta-Llama-3-8B" if args.model == "llama3" else "nemo://meta-llama/Meta-Llama-3.1-8B"
        ),
        resume_if_exists=True,
    )

def configure_finetuning_recipe(args):
    return run.Partial(
        llm.finetune,
        model=llama3_8b() if args.model == "llama3" else llama31_8b(),
        trainer=trainer(args),
        data=dolly() if args.dataset == "dolly" else smart_chat(),
        log=logger(args),
        optim=adam_with_cosine_annealing(args),
        resume=resume(args),
    )

def local_executor_torchrun(nodes: int = 1, devices: int = 8) -> run.LocalExecutor:
    # Env vars for jobs are configured here
    env_vars = {
        "TORCH_NCCL_AVOID_RECORD_STREAMS": "1",
        "NCCL_NVLS_ENABLE": "0",
        "NVTE_DP_AMAX_REDUCE_INTERVAL": "0",
        "NVTE_ASYNC_AMAX_REDUCTION": "1",
    }

    executor = run.LocalExecutor(ntasks_per_node=devices, launcher="torchrun", env_vars=env_vars)
    return executor

if __name__ == '__main__':
    args = get_parser()
    t0 = time.time()
    run.run(
        configure_finetuning_recipe(args), 
        executor=local_executor_torchrun()
    )
    t1 = time.time()
    print(
        f"Args: {args} takes {round((t1-t0)/3600, 2)} hours to finish training."
    )
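
The script is launched with the three positional arguments in the order dataset, precision, model; for example (assuming it is saved as finetune.py, the filename is arbitrary):

python finetune.py dolly 16 llama3     # fp16 run
python finetune.py dolly bf16 llama3   # bf16 run (falls back to fp32 on V100)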

Report from wandb

  1. grad_norm for fp16 vs. bf16: the grad_norm of the fp16 run is nearly zero.
     [Image: grad_norm curves from wandb]
  2. Loss for fp16 vs. bf16: the bf16 loss decreases and is much lower than the fp16 loss, which shows no sign that the model is learning.
     [Image: loss curves from wandb]

Question

I believe that mixed-precision techniques such as PyTorch's Automatic Mixed Precision (AMP) or NVIDIA's Apex apply loss scaling to prevent FP16 gradients from underflowing. However, it appears that no such scaling takes effect when training with fp16 here. What is the recommended workaround? Thanks.
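
For context, below are the dynamic loss-scaling knobs I would expect to matter for the fp16 case. This is only a sketch, assuming megatron.core.optimizer.OptimizerConfig exposes these fields (field names are my reading of megatron-core and may differ by version); values are illustrative:

# Sketch only: fp16 optimizer config with explicit dynamic loss scaling.
# Field names assumed from megatron.core.optimizer.OptimizerConfig; values illustrative.
opt_cfg = run.Config(
    OptimizerConfig,
    optimizer="adam",
    lr=1e-4,
    adam_beta1=0.9,
    adam_beta2=0.98,
    use_distributed_optimizer=True,
    clip_grad=1.0,
    fp16=True,
    loss_scale=None,            # None -> dynamic loss scaling
    initial_loss_scale=2**32,   # starting value of the dynamic scale
    min_loss_scale=1.0,
    loss_scale_window=1000,     # overflow-free steps before the scale is raised
    hysteresis=2,
)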

twotwoiscute changed the title from "Loss Fails to Converge in Nemo2-sft.ipynb with Precision FP16 due to Extremely Small Gradients" to "Loss Fails to Converge in Nemo2-sft.ipynb with Precision 16" on Feb 9, 2025