From 5d609e4ea0cfa6964d45788a3ad8a922eceb35c2 Mon Sep 17 00:00:00 2001 From: Shay Aharon Date: Mon, 3 Feb 2025 07:39:06 -0800 Subject: [PATCH] ADLR/megatron-lm!2517 - Reuse global metadata for first saves MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Mikołaj Błaż --- .../cached_metadata_filesystem_reader.py | 38 ++++ .../strategies/filesystem_async.py | 18 ++ .../strategies/fully_parallel.py | 4 + .../strategies/state_dict_saver.py | 95 +++++++++- .../dist_checkpointing/strategies/torch.py | 74 +++++++- megatron/training/checkpointing.py | 17 +- .../model_config.yaml | 1 + .../model_config.yaml | 1 + .../model_config.yaml | 1 + .../test_global_metadata_reuse.py | 170 ++++++++++++++++++ 10 files changed, 407 insertions(+), 12 deletions(-) create mode 100644 megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py create mode 100644 tests/unit_tests/dist_checkpointing/test_global_metadata_reuse.py diff --git a/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py b/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py new file mode 100644 index 0000000000..f8cb0326e2 --- /dev/null +++ b/megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py @@ -0,0 +1,38 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + +""" FS Reader with metadata cached support. """ + +import os +from typing import Union + +from torch.distributed.checkpoint import FileSystemReader, Metadata + + +class CachedMetadataFileSystemReader(FileSystemReader): + """ + Extends FileSystemReader to cache metadata for improved performance. + + Attributes: + _cached_metadata (Metadata or None): Cached metadata from the file system. + """ + + def __init__(self, path: Union[str, os.PathLike]) -> None: + """ + Initialize with file system path. + + Args: + path (Union[str, os.PathLike]): Path to the checkpoint directory or file. + """ + super().__init__(path=path) + self._cached_metadata = None + + def read_metadata(self) -> Metadata: + """ + Read metadata from file system, caching for subsequent calls. + + Returns: + Metadata: Checkpoint metadata. + """ + if self._cached_metadata is None: + self._cached_metadata = super().read_metadata() + return self._cached_metadata diff --git a/megatron/core/dist_checkpointing/strategies/filesystem_async.py b/megatron/core/dist_checkpointing/strategies/filesystem_async.py index 47ab4d1126..cc39931d06 100644 --- a/megatron/core/dist_checkpointing/strategies/filesystem_async.py +++ b/megatron/core/dist_checkpointing/strategies/filesystem_async.py @@ -1,6 +1,7 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. """ Storage writer for PyT Distributed format allowing asynchronous save. """ +import dataclasses import gc import logging import os @@ -76,6 +77,8 @@ def __init__(self, *args, separation_hint: Optional[str] = None, **kwargs): 'single_file_per_rank flag not supported for FileSystemWriterAsync' ) + self.can_run_decentralized_global_plan: bool = True + # Intermediate state between preparation and finalization self.write_buckets: Optional[List[WriteBucket]] = None self.results_queue: Optional[mp.Queue] = None @@ -334,6 +337,21 @@ def retrieve_write_results(self) -> List[WriteResult]: ) return list(chain.from_iterable(write_results.values())) + def prepare_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan: + """Instead of assigning indices by plan order, uses PyT rank (same outcome). 
+ + Args: + local_plan (SavePlan): local plan to turn to a global plan + (without interactions with other ranks) + + Returns: + SavePlan - locally transformed plan equivalent to the plan that would be + created by the coordinator + """ + return dataclasses.replace( + local_plan, storage_data=_StoragePrefix(f"__{torch.distributed.get_rank()}_") + ) + def _split_by_size_and_type( bins: int, items: List[WriteItem], separation_hint: Optional[str] = None diff --git a/megatron/core/dist_checkpointing/strategies/fully_parallel.py b/megatron/core/dist_checkpointing/strategies/fully_parallel.py index c36ee5cc99..8c426c74cd 100644 --- a/megatron/core/dist_checkpointing/strategies/fully_parallel.py +++ b/megatron/core/dist_checkpointing/strategies/fully_parallel.py @@ -6,6 +6,7 @@ import torch import torch.distributed as dist +from torch.distributed.checkpoint import Metadata from megatron.core.dist_checkpointing import ShardedTensor from megatron.core.dist_checkpointing.core import CheckpointingException @@ -170,6 +171,7 @@ def __init__( self.exchange_algo = exchange_algo self.cached_distribution: Optional[ShardDistribution] = None + self.cached_global_metadata: Optional[Metadata] = None @debug_time("FullyParallelLoadStrategyWrapper.load", logger) def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict: @@ -249,6 +251,8 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St self.fill_in_deferred_sharded_tensors(sharded_tensors, all_loaded_tensors) merge(loaded_state_dict, sharded_tensors) + if hasattr(self.base_strategy, "cached_global_metadata"): + self.cached_global_metadata = self.base_strategy.cached_global_metadata return loaded_state_dict def _defer_loading_sharded_tensors( diff --git a/megatron/core/dist_checkpointing/strategies/state_dict_saver.py b/megatron/core/dist_checkpointing/strategies/state_dict_saver.py index 7b35209f21..65c394b9ba 100644 --- a/megatron/core/dist_checkpointing/strategies/state_dict_saver.py +++ b/megatron/core/dist_checkpointing/strategies/state_dict_saver.py @@ -4,7 +4,7 @@ from logging import getLogger from time import time -from typing import TYPE_CHECKING, Optional, Tuple, cast +from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -16,19 +16,37 @@ if TYPE_CHECKING: from .filesystem_async import FileSystemWriterAsync + from .torch import MCoreSavePlanner logger = getLogger(__name__) +from dataclasses import fields + + +def _compare_dataclasses(obj1, obj2): + if type(obj1) != type(obj2): + return f"Objects are of different types: {type(obj1)} and {type(obj2)}" + + differences = [] + for field in fields(obj1): + value1 = getattr(obj1, field.name) + value2 = getattr(obj2, field.name) + if value1 != value2: + differences.append(f"{field.name}: {value1} != {value2}") + + return differences if differences else "All fields are equal" + def save_state_dict_async_plan( state_dict: STATE_DICT_TYPE, storage_writer: 'FileSystemWriterAsync', process_group: Optional[dist.ProcessGroup] = None, coordinator_rank: int = 0, - planner: Optional[SavePlanner] = None, + planner: Optional[Union[SavePlanner, 'MCoreSavePlanner']] = None, cached_ckpt_structure: Optional[Tuple[SavePlan, SavePlan, bool]] = None, -) -> Tuple[Tuple['FileSystemWriterAsync', Metadata, _DistWrapper], SavePlan, bool]: + loaded_all_plans: Optional[List[SavePlan]] = None, +) -> Tuple[Tuple['FileSystemWriterAsync', Union[Metadata, None], _DistWrapper], SavePlan, bool]: """ First stage of saving a 
state dict to storage. @@ -62,7 +80,7 @@ def save_state_dict_async_plan( Returns: Tuple of: - storage writer (the one passed as input) - - metadata from planning + - metadata from planning (or None if we reuse cached global metadata) - distributed wrapper used for planning The return value of this function should be passed as an input to `save_state_dict_async_finalize` and cached_plan to skip `reduce_scatter` at planning. @@ -80,6 +98,7 @@ def save_state_dict_async_plan( global_metadata = None logger.debug(f"rank: {rank}, starting state dict save") local_plan = cached_local_plan + global_md_verify_reuse = False def local_step(): nonlocal local_plan @@ -101,11 +120,34 @@ def global_step(all_local_plans): return all_local_plans # Execute local and global planning + # Ideally we want to use the cached plan. Otherwise if the planner and storage_writer + # allow it (`can_run_decentralized_global_plan`) we gather the plans to create + # the metadata but prepare the plans independently on each rank. + # In the worst case we have to reduce_scatter all the plans. start_plan = time() if validated_cache_reuse and cached_central_plan: logger.debug(f"rank: {rank}, Passed cache reusable") local_step() central_plan = cached_central_plan + elif getattr(planner, 'can_run_decentralized_global_plan', False) and getattr( + storage_writer, 'can_run_decentralized_global_plan', False + ): + local_plan = local_step() + global_md_verify_reuse = verify_global_md_reuse( + loaded_all_plans, local_plan, rank, dist_wrapper + ) + + if not loaded_all_plans or not global_md_verify_reuse: + all_local_plans = dist_wrapper.gather_object(local_plan) + if dist_wrapper.is_coordinator: + _, global_metadata = planner.create_global_plan(all_local_plans) + global_metadata.all_local_plans = all_local_plans + else: + logger.debug(f"rank: {rank}, Passed cached global metadata") + global_metadata = None + local_plan = planner.create_decentralized_global_plan(local_plan) + local_plan = storage_writer.prepare_decentralized_global_plan(local_plan) + central_plan = local_plan else: central_plan = dist_wrapper.reduce_scatter("plan", local_step, global_step) central_plan = planner.finish_plan(central_plan) @@ -118,13 +160,56 @@ def global_step(all_local_plans): end = time() logger.debug(f"{time()} rank: {rank}, write(async) time: {end - start}") return ( - (storage_writer, cast(Metadata, global_metadata), dist_wrapper), + (storage_writer, global_metadata, dist_wrapper), central_plan, local_plan, cached_central_plan == central_plan, + global_md_verify_reuse, ) +def verify_global_md_reuse( + loaded_all_plans: List[SavePlan], local_plan: SavePlan, rank: int, dist_wrapper: _DistWrapper +) -> bool: + """ + Verifies that global metadata reuse is possible by checking the loaded plans from the + checkpoint are consistent, which means we have the same settings when resuming training. + Args: + loaded_all_plans: List[SavePlan], The loaded plans from the checkpoint + (stored in checkpoint metadata). + local_plan: SavePlan, The local save plan. + rank: Current process rank. + dist_wrapper (_DistWrapper): distributed wrapper created during planning + + Returns: True iff the global metadata reuse is possible. 
+ + """ + logger.debug(f"verifying reuse of global metadata") + if not loaded_all_plans: + global_md_verify_reuse = False + logger.debug("loaded global metadata reuse verification: no loaded plans passed") + + elif len(loaded_all_plans) == dist_wrapper.get_world_size(): + local_verify_reuse = all( + getattr(local_plan, f.name) == getattr(loaded_all_plans[rank], f.name) + for f in fields(local_plan) + if f.name != 'storage_data' + ) + + if not local_verify_reuse: + logger.debug( + f"local_verify_reuse is False: diffs -" + f" {_compare_dataclasses(local_plan, loaded_all_plans[rank])}" + ) + all_results = torch.tensor([local_verify_reuse], dtype=torch.int, device='cuda') + torch.distributed.all_reduce(all_results, op=torch.distributed.ReduceOp.MIN) + # Check if all reduced results are True + global_md_verify_reuse = all_results.item() == 1 + else: + global_md_verify_reuse = False + return global_md_verify_reuse + + def save_state_dict_async_finalize( storage_writer: 'FileSystemWriterAsync', global_metadata: Metadata, dist_wrapper: _DistWrapper ) -> None: diff --git a/megatron/core/dist_checkpointing/strategies/torch.py b/megatron/core/dist_checkpointing/strategies/torch.py index ee83a7015d..9272677339 100644 --- a/megatron/core/dist_checkpointing/strategies/torch.py +++ b/megatron/core/dist_checkpointing/strategies/torch.py @@ -55,6 +55,7 @@ StrategyAction, register_default_strategy, ) +from .cached_metadata_filesystem_reader import CachedMetadataFileSystemReader from .filesystem_async import FileSystemWriterAsync from .resharding import ( TensorReformulationMetadata, @@ -443,6 +444,7 @@ def __init__( *args, dedup_replicated_tensors: Optional[bool] = None, nd_flattened_global_shapes: Optional[Dict[str, Tuple[int, ...]]] = None, + can_run_decentralized_global_plan: bool = True, **kwargs, ) -> None: # `dedup_replicated_tensors` was deprecated in 2.3; this check avoids warnings @@ -451,6 +453,14 @@ def __init__( kwargs['dedup_replicated_tensors'] = dedup_replicated_tensors super().__init__(*args, **kwargs) self.nd_flattened_global_shapes = nd_flattened_global_shapes or {} + self.can_run_decentralized_global_plan = can_run_decentralized_global_plan + if can_run_decentralized_global_plan: + assert ( + not dedup_replicated_tensors + ), 'Cannot run decentralized plan with dedup_replicated_tensors=True' + assert ( + not self.flatten_state_dict + ), 'Cannot run decentralized plan with flatten_state_dict=True' def create_local_plan(self) -> SavePlan: """Adds IOBytes write request on non-coordinator ranks.""" @@ -486,6 +496,23 @@ def create_global_plan(self, all_plans: List[MCoreSavePlan]) -> Tuple[List[SaveP metadata.mcore_data = dict(ChainMap(*(plan.mcore_data for plan in all_plans))) return global_plan, metadata + def create_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan: + """Nothing to do, just some checks. 
+ + Args: + local_plan (SavePlan): local plan to turn to a global plan + (without interactions with other ranks) + + Returns: + SavePlan - locally transformed plan equivalent to the plan that would be + created by the coordinator + """ + assert ( + not self.flatten_state_dict + ), 'Cannot run decentralized plan with flatten_state_dict=True' + assert not local_plan.planner_data, 'Planner data should be empty with decentralized plan' + return local_plan + def transform_object(self, write_item: WriteItem, object: Any): """Make no transformations - bytes objects are already serialized.""" return object @@ -623,6 +650,8 @@ def __init__( self.separation_hint = separation_hint + self.validated_loaded_metadata_reuse = False + def async_save( self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path ) -> AsyncRequest: @@ -652,7 +681,14 @@ def async_save( # From the 3rd iteration, `save_state_dict_async_plan` will not generate `global_metadata` # (return None) so `self.cached_global_metadata` is reused args_cached_plans = None + loaded_all_plans = None if self.use_cached_ckpt_structure: + loaded_all_plans = getattr(self.cached_global_metadata, "all_local_plans", None) + if loaded_all_plans is None: + logger.debug( + "no all_local_plans in metadata - can't verify global metadata reuse..." + ) + args_cached_plans = ( self.cached_central_plan, self.cached_local_plan, @@ -664,24 +700,44 @@ def async_save( self.cached_central_plan, self.cached_local_plan, self.validated_cache_reuse, + self.validated_loaded_metadata_reuse, ) = save_state_dict_async_plan( pyt_state_dict, writer, None, coordinator, - planner=MCoreSavePlanner(dedup_replicated_tensors=not self.keep_only_main_replica), + planner=MCoreSavePlanner( + dedup_replicated_tensors=not self.keep_only_main_replica, flatten_state_dict=False + ), cached_ckpt_structure=args_cached_plans, + loaded_all_plans=loaded_all_plans, ) rank = torch.distributed.get_rank() if self.use_cached_ckpt_structure: - if self.validated_cache_reuse: + if ( + loaded_all_plans + and self.cached_global_metadata + and self.validated_loaded_metadata_reuse + ): + if coordinator == rank: + logger.debug( + f"rank: {rank}, reuse global metadata from loaded" + f" .metadata, {save_state_dict_ret[1]}" + ) + save_state_dict_ret = list(save_state_dict_ret) + save_state_dict_ret[1] = self.cached_global_metadata + + elif self.validated_cache_reuse: logger.debug(f"rank: {rank}, cache validated") if save_state_dict_ret[1]: # when global_metadata is not cached self.cached_global_metadata = save_state_dict_ret[1] # Cache Metadata # Only Coordinator rank holds cached global_metadata # (None is returned for global_metadata) elif coordinator == rank: - logger.debug(f"rank: {rank}, reuse metadata, {save_state_dict_ret[1]}") + logger.debug( + f"rank: {rank}, reuse global metadata cached from previous" + f" save iteration, {save_state_dict_ret[1]}" + ) save_state_dict_ret = list(save_state_dict_ret) save_state_dict_ret[1] = self.cached_global_metadata @@ -745,6 +801,10 @@ def get_reformulation_metadata( class TorchDistLoadShardedStrategy(LoadShardedStrategy): """Basic load strategy for the PyT Distributed format.""" + def __init__(self): + self.cached_global_metadata: Optional[Metadata] = None + super().__init__() + def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict: """Translates MCore ShardedTensors to PyT ShardedTensors & loads from PyT Distributed fmt. 
@@ -783,13 +843,19 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St sharded_state_dict, True, load_legacy_1d_flatten_tensors=has_legacy_1d_flattened_tensors ) # Load PyT Distributed format + fsr = CachedMetadataFileSystemReader(checkpoint_dir) checkpoint.load_state_dict( pyt_state_dict, - FileSystemReader(checkpoint_dir), + fsr, planner=MCoreLoadPlanner( shapes_validation_sharded_tensors=flexible_shape_sharded_tensors ), ) + + self.cached_global_metadata = ( + fsr.read_metadata() + ) # no storage interaction thanks to caching + pyt_state_dict = cast( Dict[str, Union[TorchShardedTensor, List[io.BytesIO]]], pyt_state_dict ) diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py index b67e895552..51080227ca 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -433,6 +433,14 @@ def save_checkpoint(iteration, model, optimizer, opt_param_scheduler, num_floati save_strategy = get_default_save_sharded_strategy(args.ckpt_format) if args.ckpt_assume_constant_structure and args.ckpt_format == 'torch_dist': save_strategy.use_cached_ckpt_structure = args.ckpt_assume_constant_structure + if checkpointing_context is not None and 'load_strategy' in checkpointing_context: + cached_global_metadata = getattr(checkpointing_context['load_strategy'], 'cached_global_metadata', None) + if cached_global_metadata is not None: + logger.debug("Plugging in the read metadata from the load strategy...") + save_strategy.cached_global_metadata = cached_global_metadata + else: + logger.debug("Failed to plug in the read metadata from the load strategy...") + if args.ckpt_fully_parallel_save: save_strategy = FullyParallelSaveStrategyWrapper(save_strategy, mpu.get_data_parallel_group(with_context_parallel=True), args.ckpt_assume_constant_structure) @@ -771,7 +779,8 @@ def _load_non_persistent_base_checkpoint( f'Loading from a non-persistent checkpoint (non-persistent iter {non_persistent_iteration})' ) return _load_global_dist_base_checkpoint( - non_persistent_global_dir, args, rank0, sharded_state_dict, non_persistent_iteration, False + non_persistent_global_dir, args, rank0, sharded_state_dict, non_persistent_iteration, False, + checkpointing_context=checkpointing_context ) elif args.non_persistent_ckpt_type == "local": intermediate_state_dict, checkpoint_name = checkpointing_context[ @@ -789,7 +798,7 @@ def _load_non_persistent_base_checkpoint( def _load_global_dist_base_checkpoint( - load_dir, args, rank0, sharded_state_dict, iteration, release + load_dir, args, rank0, sharded_state_dict, iteration, release, checkpointing_context=None ): """ Load the base state_dict from the given directory containing the global distributed checkpoint """ if rank0: @@ -813,6 +822,8 @@ def _load_global_dist_base_checkpoint( load_strategy = FullyParallelLoadStrategyWrapper( load_strategy, mpu.get_data_parallel_group(with_context_parallel=True) ) + if checkpointing_context is not None: + checkpointing_context["load_strategy"] = load_strategy state_dict = dist_checkpointing.load(sharded_state_dict, checkpoint_name, load_strategy, strict=args.dist_ckpt_strictness) return state_dict, checkpoint_name, release, CheckpointType.GLOBAL @@ -886,7 +897,7 @@ def _load_base_checkpoint( # Handle global distributed checkpoint if is_dist_ckpt: return _load_global_dist_base_checkpoint( - load_dir, args, rank0, sharded_state_dict, iteration, release + load_dir, args, rank0, sharded_state_dict, iteration, release, checkpointing_context=checkpointing_context 
) # Handle global legacy checkpoint if rank0: diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_resume_torch_dist_dgx_a100_1N8G/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_resume_torch_dist_dgx_a100_1N8G/model_config.yaml index 1a5bde9821..ffed68baf4 100644 --- a/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_resume_torch_dist_dgx_a100_1N8G/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_resume_torch_dist_dgx_a100_1N8G/model_config.yaml @@ -46,6 +46,7 @@ MODEL_ARGS: --use-checkpoint-opt_param-scheduler: true --use-mcore-models: true --ckpt-format: torch_dist + --ckpt-assume-constant-structure: true --data-cache-path: ${DATA_CACHE_PATH} --bf16: true --log-memory-to-tensorboard: true diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_resume_torch_dist_dist_optimizer_overlap_grad_reduce_param_gather_dgx_a100_1N8G/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_resume_torch_dist_dist_optimizer_overlap_grad_reduce_param_gather_dgx_a100_1N8G/model_config.yaml index a61f7bc062..e2587a1494 100644 --- a/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_resume_torch_dist_dist_optimizer_overlap_grad_reduce_param_gather_dgx_a100_1N8G/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp4_vp1_resume_torch_dist_dist_optimizer_overlap_grad_reduce_param_gather_dgx_a100_1N8G/model_config.yaml @@ -51,6 +51,7 @@ MODEL_ARGS: --use-checkpoint-opt_param-scheduler: true --use-mcore-models: true --ckpt-format: torch_dist + --ckpt-assume-constant-structure: true --data-cache-path: ${DATA_CACHE_PATH} --bf16: true --log-memory-to-tensorboard: true diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp2_pp1_resume_torch_dist_te_8experts2parallel_dist_optimizer_dgx_a100_1N8G/model_config.yaml b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp2_pp1_resume_torch_dist_te_8experts2parallel_dist_optimizer_dgx_a100_1N8G/model_config.yaml index 5dc27dee75..fb7007baa4 100644 --- a/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp2_pp1_resume_torch_dist_te_8experts2parallel_dist_optimizer_dgx_a100_1N8G/model_config.yaml +++ b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp2_pp1_resume_torch_dist_te_8experts2parallel_dist_optimizer_dgx_a100_1N8G/model_config.yaml @@ -52,6 +52,7 @@ MODEL_ARGS: --use-checkpoint-opt_param-scheduler: true --use-mcore-models: true --ckpt-format: torch_dist + --ckpt-assume-constant-structure: true --data-cache-path: ${DATA_CACHE_PATH} --bf16: true --log-memory-to-tensorboard: true diff --git a/tests/unit_tests/dist_checkpointing/test_global_metadata_reuse.py b/tests/unit_tests/dist_checkpointing/test_global_metadata_reuse.py new file mode 100644 index 0000000000..5b3a9ddad7 --- /dev/null +++ b/tests/unit_tests/dist_checkpointing/test_global_metadata_reuse.py @@ -0,0 +1,170 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+ + +from types import SimpleNamespace +from unittest import mock + +import pytest + +from megatron.training.checkpointing import load_checkpoint, save_checkpoint +from tests.unit_tests.dist_checkpointing import ( + TempNamedDir, + init_basic_mock_args, + init_checkpointing_mock_args, + setup_model_and_optimizer, +) +from tests.unit_tests.test_utilities import Utils + + +class TestGlobalMetadataReuse: + def setup_method(self, method): + pass + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.parametrize(('tp,pp'), [(2, 4)]) + def test_global_metadata_reuse(self, tmp_path_dist_ckpt, tp, pp): + Utils.initialize_model_parallel(tp, pp) + num_floating_point_operations_so_far = 0 + model, optimizer = setup_model_and_optimizer(1, tp, pp) + opt_param_scheduler = None + + mock_args = SimpleNamespace() + with TempNamedDir( + tmp_path_dist_ckpt / "test_global_metadata_reuse" + ) as non_persistent_ckpt_dir, mock.patch( + 'megatron.training.checkpointing.get_args', new=lambda: mock_args + ), mock.patch( + "megatron.training.checkpointing.update_num_microbatches" + ): + init_basic_mock_args(mock_args, tp, pp) + init_checkpointing_mock_args(mock_args, non_persistent_ckpt_dir) + mock_args.non_persistent_ckpt_type = "global" + mock_args.ckpt_assume_constant_structure = True + save_ckpt_context = {} + + # Check we avoid reduce_scatter + with mock.patch( + 'torch.distributed.checkpoint.utils._DistWrapper.reduce_scatter' + ) as reduce_scatter_mock: + save_checkpoint( + 1, + model, + optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far, + save_ckpt_context, + ) + + assert reduce_scatter_mock.call_count == 0 + + assert save_ckpt_context['save_strategy'].cached_global_metadata is None + + resume_ckpt_context = {} + _, _ = load_checkpoint( + model, optimizer, opt_param_scheduler, checkpointing_context=resume_ckpt_context + ) + + load_strategy_cached_metadata = resume_ckpt_context[ + 'load_strategy' + ].cached_global_metadata + assert load_strategy_cached_metadata is not None + assert getattr(load_strategy_cached_metadata, "all_local_plans", None) is not None + + # Check we avoid reduce_scatter + with mock.patch( + 'torch.distributed.checkpoint.utils._DistWrapper.reduce_scatter' + ) as reduce_scatter_mock: + save_checkpoint( + 2, + model, + optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far, + resume_ckpt_context, + ) + assert reduce_scatter_mock.call_count == 0 + + assert ( + load_strategy_cached_metadata + is resume_ckpt_context['save_strategy'].cached_global_metadata + ) + + assert resume_ckpt_context['save_strategy'].validated_loaded_metadata_reuse + + @pytest.mark.parametrize(('tp,pp'), [(2, 4)]) + def test_no_global_metadata_reuse_on_different_parallelism(self, tmp_path_dist_ckpt, tp, pp): + Utils.initialize_model_parallel(tp, pp) + num_floating_point_operations_so_far = 0 + model, optimizer = setup_model_and_optimizer(1, tp, pp) + opt_param_scheduler = None + + mock_args = SimpleNamespace() + with TempNamedDir( + tmp_path_dist_ckpt / "test_global_metadata_reuse" + ) as non_persistent_ckpt_dir, mock.patch( + 'megatron.training.checkpointing.get_args', new=lambda: mock_args + ), mock.patch( + "megatron.training.checkpointing.update_num_microbatches" + ): + init_basic_mock_args(mock_args, tp, pp) + init_checkpointing_mock_args(mock_args, non_persistent_ckpt_dir) + mock_args.non_persistent_ckpt_type = "global" + mock_args.ckpt_assume_constant_structure = True + mock_args.ckpt_fully_parallel_save = True + + save_ckpt_context = {} + + 
# Check we avoid reduce_scatter + with mock.patch( + 'torch.distributed.checkpoint.utils._DistWrapper.reduce_scatter' + ) as reduce_scatter_mock: + save_checkpoint( + 1, + model, + optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far, + save_ckpt_context, + ) + + assert reduce_scatter_mock.call_count == 0 + + assert save_ckpt_context['save_strategy'].base_strategy.cached_global_metadata is None + + Utils.destroy_model_parallel() + Utils.initialize_model_parallel(pp, tp) + model, optimizer = setup_model_and_optimizer(1, pp, tp) + init_basic_mock_args(mock_args, pp, tp) + mock_args.no_load_rng = True + + resume_ckpt_context = {} + _, _ = load_checkpoint( + model, optimizer, opt_param_scheduler, checkpointing_context=resume_ckpt_context + ) + + load_strategy_cached_metadata = resume_ckpt_context[ + 'load_strategy' + ].cached_global_metadata + + assert load_strategy_cached_metadata is not None + assert getattr(load_strategy_cached_metadata, "all_local_plans", None) is not None + + # Check we avoid reduce_scatter + with mock.patch( + 'torch.distributed.checkpoint.utils._DistWrapper.reduce_scatter' + ) as reduce_scatter_mock: + save_checkpoint( + 2, + model, + optimizer, + opt_param_scheduler, + num_floating_point_operations_so_far, + resume_ckpt_context, + ) + assert reduce_scatter_mock.call_count == 0 + + assert not resume_ckpt_context[ + 'save_strategy' + ].base_strategy.validated_loaded_metadata_reuse
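
Illustration (not part of the patch): the CachedMetadataFileSystemReader added above is a memoizing wrapper around FileSystemReader.read_metadata(). A minimal stand-alone sketch of the same pattern, with a hypothetical read_fn standing in for the real storage call:

# Sketch only: read the (expensive) checkpoint metadata once, then serve it
# from memory so a subsequent save can reuse it without touching storage again.
class MemoizedMetadataReader:
    def __init__(self, read_fn):
        self._read_fn = read_fn          # e.g. the real reader's read_metadata
        self._cached_metadata = None

    def read_metadata(self):
        if self._cached_metadata is None:
            self._cached_metadata = self._read_fn()
        return self._cached_metadata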
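
The saved metadata also carries all_local_plans, so on resume each rank can rebuild its local plan and compare it with the recorded one, skipping only the rank-specific storage_data prefix. The sketch below restates the per-rank part of verify_global_md_reuse in plain Python; PlanStub is a hypothetical stand-in for torch.distributed.checkpoint.SavePlan, and the cross-rank all_reduce(MIN) agreement step is omitted:

from dataclasses import dataclass, fields
from typing import Any, List, Optional


@dataclass
class PlanStub:
    # Hypothetical stand-in for torch.distributed.checkpoint.SavePlan.
    items: Any = None          # write requests; must match for reuse
    storage_data: Any = None   # rank-specific file prefix; deliberately ignored
    planner_data: Any = None   # must match for reuse


def local_plan_matches(local_plan: PlanStub, loaded_plan: PlanStub) -> bool:
    """Field-wise equality, skipping the rank-specific storage_data."""
    return all(
        getattr(local_plan, f.name) == getattr(loaded_plan, f.name)
        for f in fields(local_plan)
        if f.name != "storage_data"
    )


def can_reuse_global_metadata(
    loaded_all_plans: Optional[List[PlanStub]],
    local_plan: PlanStub,
    rank: int,
    world_size: int,
) -> bool:
    """Single-rank view of the reuse check (no collective reduction)."""
    if not loaded_all_plans or len(loaded_all_plans) != world_size:
        return False
    return local_plan_matches(local_plan, loaded_all_plans[rank])


if __name__ == "__main__":
    loaded = [
        PlanStub(items=("w0",), storage_data="__0_"),
        PlanStub(items=("w1",), storage_data="__1_"),
    ]
    fresh = PlanStub(items=("w0",), storage_data="__0_resumed_")
    print(can_reuse_global_metadata(loaded, fresh, rank=0, world_size=2))  # True

In the actual strategy the per-rank results are combined with all_reduce(ReduceOp.MIN), and reuse also requires len(loaded_all_plans) == world_size, which is why a change in parallelism (as exercised by test_no_global_metadata_reuse_on_different_parallelism) falls back to regenerating the global metadata.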