From d0a0c787d7cfa3456e5d367ac63ca342a3947e6d Mon Sep 17 00:00:00 2001
From: Mikolaj Blaz
Date: Sat, 1 Feb 2025 05:36:38 -0800
Subject: [PATCH] ADLR/megatron-lm!2530 - Deprecate unused checkpointing module

---
 .../strategies/two_stage.py | 24 ++++++++++++++++++------
 1 file changed, 18 insertions(+), 6 deletions(-)

diff --git a/megatron/core/dist_checkpointing/strategies/two_stage.py b/megatron/core/dist_checkpointing/strategies/two_stage.py
index 72e60bc79b..97a07d8821 100644
--- a/megatron/core/dist_checkpointing/strategies/two_stage.py
+++ b/megatron/core/dist_checkpointing/strategies/two_stage.py
@@ -1,23 +1,22 @@
 # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
 
 """ 2-stage checkpoint loading. """
-import os
 import time
 from collections import defaultdict
 from dataclasses import dataclass
 from functools import partial, wraps
 from itertools import chain
-from logging import DEBUG, INFO, StreamHandler, getLogger
+from logging import getLogger
 from operator import attrgetter, itemgetter
 from pathlib import Path
-from typing import Iterable, List, NamedTuple, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import torch
 
 from ..dict_utils import dict_list_map_inplace, map_reduce, nested_values
-from ..mapping import ShardedStateDict, ShardedTensor, StateDict
+from ..mapping import ShardedStateDict, ShardedTensor
 from .base import LoadShardedStrategy
-from .tensorstore import TensorStoreLoadShardedStrategy, _load_from_array, open_ts_array
+from .tensorstore import _load_from_array, open_ts_array
 from .zarr import flatten_range, load_zarr_based_sharded_metadata
 
 _import_trigger = None
@@ -26,9 +25,16 @@
 timers = defaultdict(list)
 
 logger = getLogger(__name__)
+logger.warning(
+    'megatron.core.dist_checkpointing.two_stage module is deprecated'
+    ' and will be removed in Megatron-Core v0.12. Please use'
+    ' FullyParallelLoadStrategyWrapper to accomplish a parallelized checkpoint load.'
+)
 
 
 def timed(verbose=True):
+    """Timing decorator."""
+
     def timed_dec(fn):
         name = fn.__name__
 
@@ -59,6 +65,7 @@ class _ShardedTensorMetadata:
 
 
 def sharded_tensor_chunk_id(sharded_tensor: ShardedTensor):
+    """Id of a sharded tensor."""
     return (sharded_tensor.key, sharded_tensor.global_offset)
 
 
@@ -101,6 +108,7 @@ def __init__(self, data_parallel_group, cpu_transfer=True):
         self.global_rank = torch.distributed.get_rank()
 
     def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
+        """Main load method."""
         self.maybe_init_gloo_group()
         all_tensors_sorted = self._build_load_plan(sharded_state_dict)
         self._exchange_loaded_tensors(all_tensors_sorted, sharded_state_dict, checkpoint_dir)
@@ -109,6 +117,7 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path):
         return sharded_state_dict
 
     def summarize_load_times(self):
+        """Summarize load times."""
         torch.distributed.barrier()
         logger.info('Checkpoint loading finished. Summary:')
         # TODO: `timers` keys are not guaranteed to be the same across ranks which causes hangs
@@ -124,6 +133,7 @@ def summarize_load_times(self):
 
     @timed(verbose=False)
     def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetadata):
+        """Load tensor from storage."""
         logger.debug(f'_load_from_array({ten_meta.sharded_tensor_no_data.key}) init')
         ret = _load_from_array(
             ten_meta.sharded_tensor_no_data,
@@ -136,6 +146,7 @@ def load_tensor_from_storage(self, checkpoint_dir, ten_meta: _ShardedTensorMetad
 
     @timed()
     def maybe_init_gloo_group(self):
+        """Create Gloo groups."""
         if not self.cpu_transfer:
             return
         all_groups = [None] * torch.distributed.get_world_size()
@@ -211,7 +222,8 @@ def _exchange_loaded_tensors(
             )
 
             logger.debug(
-                f'exchange {ten_meta.sharded_tensor_no_data.key}, {exchange_tensor.shape}({exchange_tensor.numel()}), broadcast({src_rank} -> {self.dp_group_ranks})'
+                f'exchange {ten_meta.sharded_tensor_no_data.key}, {exchange_tensor.shape}\
+({exchange_tensor.numel()}), broadcast({src_rank} -> {self.dp_group_ranks})'
             )
             torch.distributed.broadcast(
                 exchange_tensor, group=self.data_parallel_group, src=src_rank
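
For reference, the deprecation warning added by this patch points users at
FullyParallelLoadStrategyWrapper. Below is a minimal migration sketch, not part
of the patch itself: it assumes the current Megatron-Core API, in particular
the get_default_load_sharded_strategy helper and the
FullyParallelLoadStrategyWrapper(strategy, parallelization_group=...) constructor;
the checkpoint path and the model object are hypothetical placeholders.

```python
# Sketch only: replacing TwoStageDataParallelLoadShardedStrategy with
# FullyParallelLoadStrategyWrapper, as suggested by the new deprecation warning.
# Helper and constructor signatures below are assumptions, not defined by this patch.
from pathlib import Path

from megatron.core import dist_checkpointing
from megatron.core.dist_checkpointing.serialization import get_default_load_sharded_strategy
from megatron.core.dist_checkpointing.strategies.fully_parallel import (
    FullyParallelLoadStrategyWrapper,
)

checkpoint_dir = Path('/path/to/checkpoint')  # hypothetical location

# Wrap the default load strategy so that shards are read in parallel across ranks
# and then exchanged within the parallelization group.
base_strategy = get_default_load_sharded_strategy(checkpoint_dir)
parallel_strategy = FullyParallelLoadStrategyWrapper(
    base_strategy,
    parallelization_group=None,  # assumed: None falls back to the default process group
)

# sharded_state_dict would come from the model, e.g. model.sharded_state_dict()
state_dict = dist_checkpointing.load(
    sharded_state_dict, checkpoint_dir, sharded_strategy=parallel_strategy
)
```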