From ba8231f78feb6192912fd2340662b4aaf0d3c114 Mon Sep 17 00:00:00 2001
From: Matthieu Le
Date: Tue, 28 Jan 2025 07:41:15 -0800
Subject: [PATCH] ADLR/megatron-lm!2580 - Fix dataloader save state

---
 megatron/training/checkpointing.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py
index 70b9eac4ca..47aca3b21f 100644
--- a/megatron/training/checkpointing.py
+++ b/megatron/training/checkpointing.py
@@ -374,7 +374,7 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati
         tensor_rank=tensor_rank, pipeline_rank=pipeline_rank, expert_parallel=expert_parallel,
         expert_rank=expert_rank, return_base_dir=return_base_dir)
     # Save dataloader state if the dataloader supports it (currently only Megatron Energon).
-    save_dataloader_state(train_data_iterator, iteration, getattr(args, "dataloader_save", None))
+    maybe_save_dataloader_state(train_data_iterator, iteration, getattr(args, "dataloader_save", None))
 
     # Save distributed optimizer's custom parameter state.
     if (
@@ -562,7 +562,7 @@ def remove_iter_ckpts(_iter_ckpts):
             remove_iter_ckpts(rm_iter_ckpts)
 
 
-def save_dataloader_state(train_iterator, iteration, dataloader_save_path):
+def maybe_save_dataloader_state(train_iterator, iteration, dataloader_save_path):
     """Saves dataloader state if the dataloader supports it.
 
     Currently, this is only used by Megatron Energon dataloader (multimodal) to store its state at a
@@ -577,13 +577,13 @@
         iteration (int): Current iteration.
         dataloader_save_path (str): Path where the dataloader state is saved.
     """
-    # If no dataloader or saving path is provided, then exit early.
-    if train_iterator is None or dataloader_save_path is None:
+    # If no dataloader or saving path is provided, exit early; otherwise, raise an error.
+    if train_iterator is None or dataloader_save_path is None or dataloader_save_path == "":
         return
 
-    # If dataloader doesn't support saving state, exit early.
-    if not hasattr(train_iterator, "save_state"):
-        return
+    # If dataloader doesn't support saving state, raise an error.
+    if not hasattr(train_iterator.iterable, "save_state"):
+        raise RuntimeError(f"Could not find a save_state for the train_iterator of type {type(train_iterator)}")
 
     # Save dataloader state for each data parallel rank only once.
     first_rank = mpu.is_pipeline_first_stage(ignore_virtual=True) and mpu.get_tensor_model_parallel_rank() == 0
@@ -592,7 +592,7 @@
         dp_rank = mpu.get_data_parallel_rank()
         print(f"saving dataloader checkpoint at iteration {iteration} to {dataloader_save_path}")
 
-        train_dataloader_state_dict = train_iterator.save_state()
+        train_dataloader_state_dict = train_iterator.iterable.save_state()
         data_state_save_path = get_checkpoint_name(
             dataloader_save_path, iteration,
             basename=f'train_dataloader_dprank{dp_rank:03d}.pt'
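
Note: for reference, a minimal sketch of the kind of iterator this change assumes. After the patch, maybe_save_dataloader_state() looks for save_state() on train_iterator.iterable (the attribute the Megatron Energon-backed train iterator exposes) rather than on the iterator itself, returns early when the dataloader or save path is unset or empty, and raises instead of silently skipping when the wrapped dataloader cannot save its state. The ResumableLoader and IteratorWrapper classes below are hypothetical stand-ins for illustration, not the actual Energon API.

# Hypothetical stand-ins for illustration only; the real objects come from the
# Megatron Energon dataloader and Megatron-LM's training loop.
class ResumableLoader:
    """Toy dataloader that can report its position, mimicking a save_state() API."""

    def __init__(self, samples):
        self.samples = list(samples)
        self.position = 0

    def __iter__(self):
        while self.position < len(self.samples):
            sample = self.samples[self.position]
            self.position += 1
            yield sample

    def save_state(self):
        # The returned dict is what would be written to train_dataloader_dprank{dp_rank:03d}.pt.
        return {"position": self.position}


class IteratorWrapper:
    """Toy train iterator exposing the wrapped loader as .iterable, as the new check expects."""

    def __init__(self, loader):
        self.iterable = loader        # hasattr(train_iterator.iterable, "save_state") passes
        self._iter = iter(loader)

    def __next__(self):
        return next(self._iter)


if __name__ == "__main__":
    it = IteratorWrapper(ResumableLoader(range(10)))
    next(it), next(it), next(it)
    # With this patch, maybe_save_dataloader_state(it, iteration, path) would persist:
    print(it.iterable.save_state())   # {'position': 3}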