Handling of "auto" in deepspeed config causes crash under Zero3 #2154

Open · Ben-Schneider-code opened this issue Oct 2, 2024 · 1 comment · May be fixed by #2224
Labels: 🐛 bug · 🚀 deepspeed · 🏋 DPO · 🙋 help wanted

Comments

Ben-Schneider-code (Contributor) commented Oct 2, 2024

System Info

  • Platform: Linux-5.15.0-86-generic-x86_64-with-glibc2.35
  • Python version: 3.10.14
  • PyTorch version: 2.4.1
  • CUDA device: NVIDIA A100-SXM4-80GB
  • Transformers version: 4.45.0.dev0
  • Accelerate version: 0.34.2
  • Accelerate config: not found
  • Datasets version: 3.0.0
  • HF Hub version: 0.25.0
  • TRL version: 0.12.0.dev0+5c21de3
  • bitsandbytes version: 0.41.1
  • DeepSpeed version: 0.14.5+ffe0af23
  • Diffusers version: 0.30.3
  • Liger-Kernel version: 0.3.1
  • LLM-Blender version: 0.0.2
  • OpenAI version: 1.51.0
  • PEFT version: 0.13.0

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

This issue was initially reported in the Hugging Face transformers repo:
huggingface/transformers#29348
I can probably put together a fix for trl when I have some more free time, if y'all are interested, since I understand the behaviour now.

Current Behaviour

The base Hugging Face Trainer calls hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps) during the inner training loop (once total_num_steps is known) to replace the "auto" values of total_num_steps and warmup_num_steps with their computed values. In DPOTrainer, however, the ref model is wrapped before that finalization ever runs: if total_num_steps is set to "auto", the trainer crashes when deepspeed.initialize is called at self.ref_model = self._prepare_deepspeed(self.ref_model).
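
To make the failure concrete, here is a dependency-free sketch of the comparison that blows up inside DeepSpeed's WarmupDecayLR constructor (the values here are illustrative, not pulled from the actual run):

# Illustrative only: by the time the scheduler is built for the ref model,
# total_num_steps is still the literal string "auto" from the DS config.
total_num_steps = "auto"   # unresolved placeholder
warmup_num_steps = 1000    # an int by this point in DeepSpeed

if total_num_steps < warmup_num_steps:  # TypeError: '<' not supported between 'str' and 'int'
    pass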

DS config

{
    "resource": {
        "num_gpus": 0
    },
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "total_num_steps": "auto",
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 3,
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": "auto",
        "contiguous_gradients": true,
        "stage3_gather_16bit_weights_on_model_save": "auto"
    },
    "gradient_accumulation_steps": 1,
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}
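
As a temporary workaround until this is fixed in trl, the scheduler's "auto" placeholders can be pre-resolved before the config ever reaches DPOTrainer. This is an untested sketch; the step counts are placeholders you would derive from your own dataset size, epochs, batch size, gradient accumulation, and world size. It relies on the fact that TrainingArguments (and hence DPOConfig) accepts either a file path or a dict for deepspeed.

import json

# Load the DS config shown above.
with open("ds_config.json") as f:
    ds_config = json.load(f)

# Placeholder values: compute these from len(train_dataset), num_train_epochs,
# per-device batch size, gradient accumulation steps, and world size.
num_training_steps = 1000
warmup_steps = int(0.1 * num_training_steps)

# Replace the scheduler's "auto" placeholders with concrete integers so they
# never reach deepspeed.initialize unresolved.
sched_params = ds_config.get("scheduler", {}).get("params", {})
if sched_params.get("total_num_steps") == "auto":
    sched_params["total_num_steps"] = num_training_steps
if sched_params.get("warmup_num_steps") == "auto":
    sched_params["warmup_num_steps"] = warmup_steps

# The resolved dict can then be passed directly: DPOConfig(..., deepspeed=ds_config)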

Script

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOTrainer, DPOConfig

# Load the LLaMA 2 model and tokenizer
model_name = "/home/b3schnei/pretrained/Llama-2-7b"  # Change this to the specific LLaMA 2 model you want to use
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Load the reference model
ref_model = AutoModelForCausalLM.from_pretrained(model_name)

# Define training arguments
training_args = DPOConfig(
    learning_rate=2e-4,
    num_train_epochs=1,
    per_device_train_batch_size=4,
    output_dir="./results",
    logging_steps=10,
    remove_unused_columns=False,
    max_length=1024,
    max_prompt_length=512,
    fp16=True,
    deepspeed="/home/b3schnei/transformers_debug/debug/29348/ds_config.json",  # the DS config above
)

train_dataset = load_dataset("json", data_files="debug/29348/dpo.json", split="train")

# Initialize the DPOTrainer
dpo_trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    args=training_args,
)

# Start training
if __name__ == "__main__":
    dpo_trainer.train()

Crash log

[2024-10-02 01:18:15,497] [INFO] [utils.py:789:see_memory_usage] CPU Virtual Memory: used = 121.5 GB, percent = 12.1%
[2024-10-02 01:18:15,497] [INFO] [logging.py:96:log_dist] [Rank 0] DeepSpeed Final Optimizer = DeepSpeedZeroOptimizer_Stage3
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/b3schnei/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/pydevd.py", line 3489, in <module>
[rank0]:     main()
[rank0]:   File "/home/b3schnei/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/pydevd.py", line 3482, in main
[rank0]:     globals = debugger.run(setup['file'], None, None, is_module)
[rank0]:   File "/home/b3schnei/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/pydevd.py", line 2510, in run
[rank0]:     return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
[rank0]:   File "/home/b3schnei/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/pydevd.py", line 2517, in _exec
[rank0]:     globals = pydevd_runpy.run_path(file, globals, 'main')
[rank0]:   File "/home/b3schnei/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
[rank0]:     return _run_module_code(code, init_globals, run_name,
[rank0]:   File "/home/b3schnei/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
[rank0]:     _run_code(code, mod_globals, init_globals,
[rank0]:   File "/home/b3schnei/.vscode-server/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
[rank0]:     exec(code, run_globals)
[rank0]:   File "/home/b3schnei/transformers_debug/debug/29348/reproduce.py", line 34, in <module>
[rank0]:     dpo_trainer = DPOTrainer(
[rank0]:   File "/home/b3schnei/anaconda3/envs/test_transformers/lib/python3.10/site-packages/huggingface_hub/utils/_deprecation.py", line 101, in inner_f
[rank0]:     return f(*args, **kwargs)
[rank0]:   File "/home/b3schnei/anaconda3/envs/test_transformers/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 883, in __init__
[rank0]:     self.ref_model = self._prepare_deepspeed(self.ref_model)
[rank0]:   File "/home/b3schnei/anaconda3/envs/test_transformers/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 924, in _prepare_deepspeed
[rank0]:     model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
[rank0]:   File "/home/b3schnei/anaconda3/envs/test_transformers/lib/python3.10/site-packages/deepspeed/__init__.py", line 181, in initialize
[rank0]:     engine = DeepSpeedEngine(args=args,
[rank0]:   File "/home/b3schnei/anaconda3/envs/test_transformers/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 307, in __init__
[rank0]:     self._configure_lr_scheduler(lr_scheduler)
[rank0]:   File "/home/b3schnei/anaconda3/envs/test_transformers/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 907, in _configure_lr_scheduler
[rank0]:     lr_scheduler = self._scheduler_from_config(self.optimizer)
[rank0]:   File "/home/b3schnei/anaconda3/envs/test_transformers/lib/python3.10/site-packages/deepspeed/runtime/engine.py", line 962, in _scheduler_from_config
[rank0]:     instantiated_scheduler = scheduler(optimizer, **scheduler_params)
[rank0]:   File "/home/b3schnei/anaconda3/envs/test_transformers/lib/python3.10/site-packages/deepspeed/runtime/lr_schedules.py", line 758, in __init__
[rank0]:     if self.total_num_steps < self.warmup_num_steps:
[rank0]: TypeError: '<' not supported between instances of 'str' and 'int'

Expected behavior

I expect DPOTrainer to initialize under Zero3 when ds_config values are set to "auto", the same way transformers' Trainer handles them.
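
One possible direction for a fix (my assumption, not necessarily how the linked PR does it): the reference model is frozen and only runs forward passes, so _prepare_deepspeed could strip the training-only sections from the copied config before calling deepspeed.initialize, and the unresolved "auto" values would never reach the scheduler constructor.

import deepspeed

# Hypothetical patch sketch for trl/trainer/dpo_trainer.py::_prepare_deepspeed
# (untested; attribute names follow the traceback above).
def _prepare_deepspeed(self, model):
    deepspeed_plugin = self.accelerator.state.deepspeed_plugin
    config_kwargs = dict(deepspeed_plugin.deepspeed_config)

    # The ref model never takes an optimizer step, so drop the sections whose
    # "auto" values are only resolved by transformers' trainer_config_finalize.
    config_kwargs.pop("optimizer", None)
    config_kwargs.pop("scheduler", None)

    model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
    model.eval()
    return model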

Ben-Schneider-code added the 🐛 bug label on Oct 2, 2024
qgallouedec added the 🏋 DPO and 🚀 deepspeed labels on Oct 7, 2024
qgallouedec (Member) commented:
> I can probably put together a fix for trl when I have some more free time, if y'all are interested, since I understand the behaviour now.

Thanks for reporting; help in proposing a fix would be greatly appreciated.

qgallouedec added the 🙋 help wanted label on Oct 7, 2024