[BUG] Universal Checkpoint Conversion: Resumed Training Behaves as If Model Initialized from Scratch #6691

Open · purefall opened this issue Oct 30, 2024 · 1 comment · Labels: bug, training

I'm experiencing an issue when using DeepSpeed's universal checkpointing. After converting my DeepSpeed checkpoint to a universal checkpoint with ds_to_universal.py, resuming training from the converted checkpoint makes the model behave as if it had been initialized from scratch: the training loss is significantly higher, roughly as if training had started over from the beginning.

Environment:

  • DeepSpeed version: 0.15.0
  • PyTorch version: 2.4.1
  • Python version: 3.9
  • Model: Llama
  • ZeRO Optimization Stage: 3
  • Number of GPUs: 16
  • Distributed Backend: NCCL

Steps to Reproduce:

  1. Train a model using DeepSpeed with ZeRO Stage 3 optimization and save checkpoints.

  2. Use ds_to_universal.py to convert the DeepSpeed checkpoint to a universal checkpoint:

    python ds_to_universal.py \
        --input_folder path_to_deepspeed_checkpoint \
        --output_folder path_to_universal_checkpoint \
        --num_extract_workers 1 \
        --num_merge_workers 1
    • Corresponding logs:

[2024-10-30 11:42:34,345] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
args = Namespace(input_folder='.global/', output_folder='global_universal', num_extract_workers=1, num_merge_workers=1, keep_temp_folder=False, strict=True, inject_missing_state=False)
Convert DeepSpeed Checkpoint to Universal Checkpoint
Converting DeepSpeed checkpoint in ./global_universal
ds_to_universal.py:449: FutureWarning: You are using torch.load with weights_only=False (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for weights_only will be flipped to True. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via torch.serialization.add_safe_globals. We recommend you start setting weights_only=True for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(optim_files[0], map_location=torch.device('cpu'))
ds_to_universal.py:407: FutureWarning: You are using torch.load with weights_only=False (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for weights_only will be flipped to True. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via torch.serialization.add_safe_globals. We recommend you start setting weights_only=True for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
return torch.load(files[0], map_location=torch.device('cpu'))[PARAM_SHAPES]

*** 1. Extracting ZeRO fragments
0%| | 0/16 [00:00<?, ?it/s] ds_to_universal.py:153: FutureWarning: You are using torch.load with weights_only=False (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for weights_only will be flipped to True. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via torch.serialization.add_safe_globals. We recommend you start setting weights_only=True for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(optim_files[dp_index], map_location='cpu')
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [01:46<00:00, 6.68s/it]

*** 2. Merging slices .....
0%| | 0/201 [00:00<?, ?it/s] ds_to_universal.py:217: FutureWarning: You are using torch.load with weights_only=False (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for weights_only will be flipped to True. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via torch.serialization.add_safe_globals. We recommend you start setting weights_only=True for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
shards = [torch.load(p) for p in paths]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 201/201 [01:34<00:00, 2.13it/s]

*** 3. Saving common optimizer states
ds_to_universal.py:423: FutureWarning: You are using torch.load with weights_only=False (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for weights_only will be flipped to True. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via torch.serialization.add_safe_globals. We recommend you start setting weights_only=True for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
sd = torch.load(optim_files[0], map_location=torch.device('cpu'))
*** Done!

  3. Modify the training script to resume training from the universal checkpoint:

    • Set load_universal to True in the DeepSpeed config.
    • Load the checkpoint using model.load_checkpoint().
  4. Resume training using the converted universal checkpoint (a minimal sketch of the resume call follows these steps).
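
For reference, the resume step looks roughly like the sketch below. This is a minimal sketch, not the exact code: the path is a placeholder and the keyword arguments are just the defaults spelled out for clarity. load_checkpoint() returns the resolved load path and the client state, so a None path is a quick signal that nothing was actually loaded:

# Minimal resume sketch; "path_to_universal_checkpoint" is a placeholder.
# load_checkpoint() returns (load_path, client_state); load_path is None if
# no checkpoint was found in the given directory.
load_path, client_state = model.load_checkpoint(
    "path_to_universal_checkpoint",
    load_optimizer_states=True,
    load_lr_scheduler_states=True,
)
assert load_path is not None, "no checkpoint was loaded from the given path"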

Expected Behavior:

  • The model should resume training from the checkpointed state, with training loss consistent with the point at which training was paused.

Actual Behavior:

  • Upon resuming training, the training loss is significantly higher, as if the model weights and optimizer states were not correctly restored.
  • It appears that the model is starting from random initialization rather than the checkpointed state.

Additional Details:

  • Optimizer State Loading:

    • After loading the checkpoint, the optimizer states are updated.
  • Checkpoint Conversion:

    • The conversion script ds_to_universal.py runs without errors.
    • The generated universal checkpoint seems to have the correct structure (a rough way to spot-check it is sketched after this list).
  • Resuming from Original Checkpoint:

    • When resuming training from the original DeepSpeed checkpoint (before conversion), the training loss is consistent, indicating the issue arises after conversion to the universal checkpoint.
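
A rough way to spot-check the converted output is sketched below. It assumes the converter writes per-parameter fp32.pt files under a zero/ subfolder of the output directory; the exact layout may vary across DeepSpeed versions:

import glob
import torch

# Rough inspection of the converted universal checkpoint. Assumed layout:
# <output_folder>/zero/<param_name>/fp32.pt; adjust the glob if it differs.
for path in sorted(glob.glob("global_universal/zero/*/fp32.pt"))[:5]:
    obj = torch.load(path, map_location="cpu")
    # The saved object may be a bare tensor or a small dict containing one.
    tensor = obj if torch.is_tensor(obj) else next(v for v in obj.values() if torch.is_tensor(v))
    print(path, tuple(tensor.shape), f"norm={tensor.float().norm().item():.4f}")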

Code Snippets:

Relevant sections of my training script:

# DeepSpeed initialization
model, optimizer, _, lr_scheduler = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    model_parameters=model.parameters(),
    config=deepspeed_config,
)

assert model.load_universal_checkpoint() == args.universal_checkpoint, f"universal checkpoint flag mismatch (expected {args.universal_checkpoint})"
model.load_checkpoint("path_to_universal_checkpoint")
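
To rule out a silent load failure, a check along the following lines can confirm whether the weights actually change when the checkpoint is loaded. This is only a diagnostic variant of the load call above (param_norm is a made-up helper), and under ZeRO Stage 3 the parameters are partitioned, so they have to be gathered before they can be read:

# Diagnostic sketch (not part of the training script): compare one parameter's
# norm before and after loading. Under ZeRO-3 the parameters are partitioned,
# so gather them with deepspeed.zero.GatheredParameters before reading.
def param_norm(engine):
    p = next(engine.module.parameters())
    with deepspeed.zero.GatheredParameters([p], modifier_rank=None):
        return p.detach().float().norm().item()

norm_before = param_norm(model)
load_path, _ = model.load_checkpoint("path_to_universal_checkpoint")
norm_after = param_norm(model)
print(f"load_path={load_path}  norm before={norm_before:.4f}  after={norm_after:.4f}")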

DeepSpeed configuration:

{
    "bf16": {
        "enabled": "auto"
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 5e8,
        "reduce_bucket_size": 5e8,
        "stage3_max_live_parameters": 5e8,
        "stage3_max_reuse_distance": 5e8,
        "stage3_gather_16bit_weights_on_model_save": true
    },
    "train_micro_batch_size_per_gpu": args.batch_size,
    "gradient_clipping": args.gradient_clipping,
    "gradient_accumulation_steps": args.gradient_accumulation_steps,
    "checkpoint": {
        "load_universal": true
    }
}

Attempts to Resolve:

  • Verified Configuration:
    • Ensured that load_universal is set to True in the DeepSpeed config.
    • Captured optimizer and lr_scheduler returned by deepspeed.initialize().
  • Monitoring:
    • Checked learning rate and global steps before and after loading the checkpoint (one way to do this is sketched after this list).
    • Observed that the learning rate remains the same, and global steps do not reflect the checkpointed state.
  • Testing:
    • Reduced the number of extract and merge workers to 1 when running ds_to_universal.py to rule out parallelism issues.
    • Ensured version consistency of DeepSpeed and PyTorch across training, conversion, and resumption.
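
One way to do the learning-rate / global-step check is sketched below. It assumes the DeepSpeed engine exposes global_steps and get_lr(), which may differ across versions:

# Rough sketch of the monitoring check. After a successful resume,
# global_steps should match the step at which the checkpoint was saved.
print("before load:", model.global_steps, model.get_lr())
load_path, _ = model.load_checkpoint("path_to_universal_checkpoint")
print("after load: ", model.global_steps, model.get_lr())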

Questions:

  • All the examples provided in the repo resume from a fairly high loss (around 7). Is there an example of continuing training from a checkpoint that has already been trained well, i.e. for a long time / a large number of tokens?
  • Is there a known issue with ds_to_universal.py not correctly converting optimizer states for ZeRO Stage 3, such that training resumes with a much higher loss?

Any help or guidance would be greatly appreciated.


purefall commented Nov 5, 2024

Any update on this issue?
