diff --git a/megatron/core/dist_checkpointing/strategies/two_stage.py b/megatron/core/dist_checkpointing/strategies/two_stage.py
index b76bf6a103..50b31e2497 100644
--- a/megatron/core/dist_checkpointing/strategies/two_stage.py
+++ b/megatron/core/dist_checkpointing/strategies/two_stage.py
@@ -25,9 +25,16 @@
 timers = defaultdict(list)
 
 logger = getLogger(__name__)
+logger.warning(
+    'megatron.core.dist_checkpointing.two_stage module is deprecated'
+    ' and will be removed in Megatron-Core v0.12. Please use'
+    ' FullyParallelLoadStrategyWrapper to accomplish a parallelized checkpoint load.'
+)
 
 
 def timed(verbose=True):
+    """Timing decorator."""
+
     def timed_dec(fn):
         name = fn.__name__
 
@@ -58,6 +65,7 @@ class _ShardedTensorMetadata:
 
 
 def sharded_tensor_chunk_id(sharded_tensor: ShardedTensor):
+    """Id of a sharded tensor."""
     return (sharded_tensor.key, sharded_tensor.global_offset)
 
 
@@ -100,6 +108,7 @@ def __init__(self, data_parallel_group, cpu_transfer=True):
         self.global_rank = torch.distributed.get_rank()
 
     def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
+        """Main load method."""
         self.maybe_init_gloo_group()
         all_tensors_sorted = self._build_load_plan(sharded_state_dict)
         self._exchange_loaded_tensors(all_tensors_sorted, sharded_state_dict, checkpoint_dir)
@@ -108,6 +117,7 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
         return sharded_state_dict
 
     def summarize_load_times(self):
+        """Summarize load times."""
        torch.distributed.barrier()
        logger.info('Checkpoint loading finished. Summary:')
        # TODO: `timers` keys are not guaranteed to be the same across ranks which causes hangs
@@ -123,6 +133,7 @@ def summarize_load_times(self):
 
     @timed(verbose=False)
     def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetadata):
+        """Load tensor from storage."""
         logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) init')
         ret = _load_from_array(
             ten_meta.sharded_tensor_no_data,
@@ -135,6 +146,7 @@ def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetad
 
     @timed()
     def maybe_init_gloo_group(self):
+        """Create Gloo groups."""
         if not self.cpu_transfer:
             return
         all_groups = [None] * torch.distributed.get_world_size()
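
For reference, a minimal migration sketch of what the new deprecation warning points to, not part of the diff above. It assumes FullyParallelLoadStrategyWrapper is importable from megatron.core.dist_checkpointing.strategies.fully_parallel, that TorchDistLoadShardedStrategy is a valid base load strategy, and that the parallelization_group argument name matches the installed Megatron-Core version; the parallel_load helper itself is hypothetical.

    from pathlib import Path

    from megatron.core import dist_checkpointing
    from megatron.core.dist_checkpointing.strategies.fully_parallel import (
        FullyParallelLoadStrategyWrapper,
    )
    from megatron.core.dist_checkpointing.strategies.torch import TorchDistLoadShardedStrategy


    def parallel_load(sharded_state_dict, checkpoint_dir: Path, parallelization_group=None):
        """Load a distributed checkpoint with reads parallelized across a process group."""
        # Wrap a base load strategy instead of using the deprecated two_stage module.
        load_strategy = FullyParallelLoadStrategyWrapper(
            TorchDistLoadShardedStrategy(),  # assumed base strategy for torch-dist checkpoints
            parallelization_group=parallelization_group,  # e.g. the data-parallel group
        )
        return dist_checkpointing.load(
            sharded_state_dict, str(checkpoint_dir), sharded_strategy=load_strategy
        )

Passing the data-parallel group as parallelization_group is intended to spread storage reads across that group, which is the parallelized load the warning refers to.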