Merge branch 'mblaz/deprecate-two-stage' into 'main'
Deprecate unused checkpointing module

See merge request ADLR/megatron-lm!2530
jaredcasper committed Feb 1, 2025
2 parents 6356152 + d0a0c78 commit 731fbfd
Showing 1 changed file with 12 additions and 0 deletions.
megatron/core/dist_checkpointing/strategies/two_stage.py: 12 additions, 0 deletions
@@ -25,9 +25,16 @@
 timers = defaultdict(list)

 logger = getLogger(__name__)
+logger.warning(
+    'megatron.core.dist_checkpointing.two_stage module is deprecated'
+    ' and will be removed in Megatron-Core v0.12. Please use'
+    ' FullyParallelLoadStrategyWrapper to accomplish a parallelized checkpoint load.'
+)


 def timed(verbose=True):
+    """Timing decorator."""
+
     def timed_dec(fn):
         name = fn.__name__

@@ -58,6 +65,7 @@ class _ShardedTensorMetadata:


 def sharded_tensor_chunk_id(sharded_tensor: ShardedTensor):
+    """Id of a sharded tensor."""
     return (sharded_tensor.key, sharded_tensor.global_offset)


@@ -100,6 +108,7 @@ def __init__(self, data_parallel_group, cpu_transfer=True):
         self.global_rank = torch.distributed.get_rank()

     def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
+        """Main load method."""
         self.maybe_init_gloo_group()
         all_tensors_sorted = self._build_load_plan(sharded_state_dict)
         self._exchange_loaded_tensors(all_tensors_sorted, sharded_state_dict, checkpoint_dir)
@@ -108,6 +117,7 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
         return sharded_state_dict

     def summarize_load_times(self):
+        """Summarize load times."""
         torch.distributed.barrier()
         logger.info('Checkpoint loading finished. Summary:')
         # TODO: `timers` keys are not guaranteed to be the same across ranks which causes hangs
@@ -123,6 +133,7 @@ def summarize_load_times(self):

     @timed(verbose=False)
     def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetadata):
+        """Load tensor from storage."""
         logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) init')
         ret = _load_from_array(
             ten_meta.sharded_tensor_no_data,
@@ -135,6 +146,7 @@ def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetadata):

     @timed()
     def maybe_init_gloo_group(self):
+        """Create Gloo groups."""
         if not self.cpu_transfer:
             return
         all_groups = [None] * torch.distributed.get_world_size()

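Migration note: the new warning points existing users of this module at FullyParallelLoadStrategyWrapper. Below is a minimal sketch of one way the replacement load path can be wired up. It assumes the wrapper accepts a base load strategy plus a parallelization process group and that get_default_load_sharded_strategy is available in the installed Megatron-Core version; verify both against megatron/core/dist_checkpointing before relying on it.

# Sketch only: names and signatures assumed from recent Megatron-Core releases.
from megatron.core import dist_checkpointing
from megatron.core.dist_checkpointing.serialization import get_default_load_sharded_strategy
from megatron.core.dist_checkpointing.strategies.fully_parallel import (
    FullyParallelLoadStrategyWrapper,
)


def parallel_load(sharded_state_dict, checkpoint_dir, data_parallel_group):
    # Build the default sharded-load strategy for the checkpoint on disk, then
    # wrap it so that storage reads are split across the ranks of the given
    # process group (the role the deprecated two_stage module used to play).
    base_strategy = get_default_load_sharded_strategy(checkpoint_dir)
    parallel_strategy = FullyParallelLoadStrategyWrapper(base_strategy, data_parallel_group)
    return dist_checkpointing.load(sharded_state_dict, checkpoint_dir, parallel_strategy)

In Megatron-LM itself the process group would typically be the data-parallel group; any group that spans ranks holding replicated shards should serve the same purpose.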