ADLR/megatron-lm!2517 - Reuse global metadata for first saves
Co-authored-by: Mikołaj Błaż <[email protected]>
2 people authored and ko3n1g committed Feb 3, 2025
1 parent 2a9793d commit 5d609e4
Showing 10 changed files with 407 additions and 12 deletions.
@@ -0,0 +1,38 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

""" FS Reader with metadata cached support. """

import os
from typing import Union

from torch.distributed.checkpoint import FileSystemReader, Metadata


class CachedMetadataFileSystemReader(FileSystemReader):
"""
Extends FileSystemReader to cache metadata for improved performance.
Attributes:
_cached_metadata (Metadata or None): Cached metadata from the file system.
"""

def __init__(self, path: Union[str, os.PathLike]) -> None:
"""
Initialize with file system path.
Args:
path (Union[str, os.PathLike]): Path to the checkpoint directory or file.
"""
super().__init__(path=path)
self._cached_metadata = None

def read_metadata(self) -> Metadata:
"""
Read metadata from file system, caching for subsequent calls.
Returns:
Metadata: Checkpoint metadata.
"""
if self._cached_metadata is None:
self._cached_metadata = super().read_metadata()
return self._cached_metadata
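
The class above memoizes the result of read_metadata(), so repeated loads against the same checkpoint directory skip the on-disk metadata read. A minimal usage sketch, with a placeholder checkpoint path of our own choosing (not part of the commit):

# Hypothetical usage sketch; the path below is a placeholder, not taken from the commit.
reader = CachedMetadataFileSystemReader("/tmp/my_checkpoint")

meta_first = reader.read_metadata()   # reads the .metadata file from storage
meta_again = reader.read_metadata()   # served from the in-memory cache

assert meta_first is meta_again       # same object, no second filesystem read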
18 changes: 18 additions & 0 deletions megatron/core/dist_checkpointing/strategies/filesystem_async.py
@@ -1,6 +1,7 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

""" Storage writer for PyT Distributed format allowing asynchronous save. """
import dataclasses
import gc
import logging
import os
@@ -76,6 +77,8 @@ def __init__(self, *args, separation_hint: Optional[str] = None, **kwargs):
'single_file_per_rank flag not supported for FileSystemWriterAsync'
)

self.can_run_decentralized_global_plan: bool = True

# Intermediate state between preparation and finalization
self.write_buckets: Optional[List[WriteBucket]] = None
self.results_queue: Optional[mp.Queue] = None
@@ -334,6 +337,21 @@ def retrieve_write_results(self) -> List[WriteResult]:
)
return list(chain.from_iterable(write_results.values()))

def prepare_decentralized_global_plan(self, local_plan: SavePlan) -> SavePlan:
"""Instead of assigning indices by plan order, uses PyT rank (same outcome).
Args:
local_plan (SavePlan): local plan to turn to a global plan
(without interactions with other ranks)
Returns:
SavePlan - locally transformed plan equivalent to the plan that would be
created by the coordinator
"""
return dataclasses.replace(
local_plan, storage_data=_StoragePrefix(f"__{torch.distributed.get_rank()}_")
)
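
For intuition, the centralized path this method replaces has the coordinator enumerate the gathered plans in rank order, so the i-th plan receives prefix __{i}_. A sketch of that assignment, assuming it mirrors the default PyT FileSystemWriter behavior (illustrative, not code from this commit; the _StoragePrefix import path is assumed):

# Illustrative sketch, not part of the commit: the coordinator-side assignment that
# the decentralized variant reproduces locally on each rank.
import dataclasses
from typing import List

from torch.distributed.checkpoint.filesystem import _StoragePrefix  # assumed (private) import path
from torch.distributed.checkpoint.planner import SavePlan


def _centralized_prefixes(all_plans: List[SavePlan]) -> List[SavePlan]:
    # Gather order equals rank order, so plan i belongs to rank i.
    return [
        dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_"))
        for i, plan in enumerate(all_plans)
    ]

# Because of that ordering, rank r computing _StoragePrefix(f"__{r}_") locally yields
# the same storage_data without any inter-rank communication.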


def _split_by_size_and_type(
bins: int, items: List[WriteItem], separation_hint: Optional[str] = None
4 changes: 4 additions & 0 deletions megatron/core/dist_checkpointing/strategies/fully_parallel.py
@@ -6,6 +6,7 @@

import torch
import torch.distributed as dist
from torch.distributed.checkpoint import Metadata

from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.core import CheckpointingException
@@ -170,6 +171,7 @@ def __init__(
self.exchange_algo = exchange_algo

self.cached_distribution: Optional[ShardDistribution] = None
self.cached_global_metadata: Optional[Metadata] = None

@debug_time("FullyParallelLoadStrategyWrapper.load", logger)
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict:
@@ -249,6 +251,8 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict:

self.fill_in_deferred_sharded_tensors(sharded_tensors, all_loaded_tensors)
merge(loaded_state_dict, sharded_tensors)
if hasattr(self.base_strategy, "cached_global_metadata"):
self.cached_global_metadata = self.base_strategy.cached_global_metadata
return loaded_state_dict

def _defer_loading_sharded_tensors(
95 changes: 90 additions & 5 deletions megatron/core/dist_checkpointing/strategies/state_dict_saver.py
@@ -4,7 +4,7 @@

from logging import getLogger
from time import time
from typing import TYPE_CHECKING, Optional, Tuple, cast
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
@@ -16,19 +16,37 @@

if TYPE_CHECKING:
from .filesystem_async import FileSystemWriterAsync
from .torch import MCoreSavePlanner


logger = getLogger(__name__)

from dataclasses import fields


def _compare_dataclasses(obj1, obj2):
if type(obj1) != type(obj2):
return f"Objects are of different types: {type(obj1)} and {type(obj2)}"

differences = []
for field in fields(obj1):
value1 = getattr(obj1, field.name)
value2 = getattr(obj2, field.name)
if value1 != value2:
differences.append(f"{field.name}: {value1} != {value2}")

return differences if differences else "All fields are equal"
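
A toy example of the helper above, using a hypothetical stand-in dataclass rather than a real SavePlan, to show the shape of its output:

# Hypothetical example, not from the commit; _ToyPlan stands in for a real plan dataclass.
from dataclasses import dataclass


@dataclass
class _ToyPlan:
    world_size: int
    sharding: str


print(_compare_dataclasses(_ToyPlan(8, "tp2"), _ToyPlan(8, "tp4")))
# -> ['sharding: tp2 != tp4']
print(_compare_dataclasses(_ToyPlan(8, "tp2"), _ToyPlan(8, "tp2")))
# -> 'All fields are equal'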


def save_state_dict_async_plan(
state_dict: STATE_DICT_TYPE,
storage_writer: 'FileSystemWriterAsync',
process_group: Optional[dist.ProcessGroup] = None,
coordinator_rank: int = 0,
planner: Optional[SavePlanner] = None,
planner: Optional[Union[SavePlanner, 'MCoreSavePlanner']] = None,
cached_ckpt_structure: Optional[Tuple[SavePlan, SavePlan, bool]] = None,
) -> Tuple[Tuple['FileSystemWriterAsync', Metadata, _DistWrapper], SavePlan, bool]:
loaded_all_plans: Optional[List[SavePlan]] = None,
) -> Tuple[Tuple['FileSystemWriterAsync', Union[Metadata, None], _DistWrapper], SavePlan, bool]:
"""
First stage of saving a state dict to storage.
@@ -62,7 +80,7 @@ def save_state_dict_async_plan(
Returns: Tuple of:
- storage writer (the one passed as input)
- metadata from planning
- metadata from planning (or None if we reuse cached global metadata)
- distributed wrapper used for planning
The return value of this function should be passed as an input to
`save_state_dict_async_finalize` and cached_plan to skip `reduce_scatter` at planning.
@@ -80,6 +98,7 @@
global_metadata = None
logger.debug(f"rank: {rank}, starting state dict save")
local_plan = cached_local_plan
global_md_verify_reuse = False

def local_step():
nonlocal local_plan
@@ -101,11 +120,34 @@ def global_step(all_local_plans):
return all_local_plans

# Execute local and global planning
# Ideally we want to use the cached plan. Otherwise if the planner and storage_writer
# allow it (`can_run_decentralized_global_plan`) we gather the plans to create
# the metadata but prepare the plans independently on each rank.
# In the worst case we have to reduce_scatter all the plans.
start_plan = time()
if validated_cache_reuse and cached_central_plan:
logger.debug(f"rank: {rank}, Passed cache reusable")
local_step()
central_plan = cached_central_plan
elif getattr(planner, 'can_run_decentralized_global_plan', False) and getattr(
storage_writer, 'can_run_decentralized_global_plan', False
):
local_plan = local_step()
global_md_verify_reuse = verify_global_md_reuse(
loaded_all_plans, local_plan, rank, dist_wrapper
)

if not loaded_all_plans or not global_md_verify_reuse:
all_local_plans = dist_wrapper.gather_object(local_plan)
if dist_wrapper.is_coordinator:
_, global_metadata = planner.create_global_plan(all_local_plans)
global_metadata.all_local_plans = all_local_plans
else:
logger.debug(f"rank: {rank}, Passed cached global metadata")
global_metadata = None
local_plan = planner.create_decentralized_global_plan(local_plan)
local_plan = storage_writer.prepare_decentralized_global_plan(local_plan)
central_plan = local_plan
else:
central_plan = dist_wrapper.reduce_scatter("plan", local_step, global_step)
central_plan = planner.finish_plan(central_plan)
@@ -118,13 +160,56 @@ def global_step(all_local_plans):
end = time()
logger.debug(f"{time()} rank: {rank}, write(async) time: {end - start}")
return (
(storage_writer, cast(Metadata, global_metadata), dist_wrapper),
(storage_writer, global_metadata, dist_wrapper),
central_plan,
local_plan,
cached_central_plan == central_plan,
global_md_verify_reuse,
)
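
To summarize the branching above: a validated cached central plan is preferred; otherwise, when both the planner and the storage writer advertise can_run_decentralized_global_plan, local plans are gathered only if the stored global metadata cannot be reused; the reduce_scatter exchange is the fallback. A condensed sketch of that decision (illustrative only; the function and its strings are hypothetical, not an API from the commit):

# Illustrative decision summary; names and return strings are hypothetical.
def _choose_planning_path(cache_valid: bool, decentralized_ok: bool, md_reusable: bool) -> str:
    if cache_valid:
        return "reuse cached central plan (no collective)"
    if decentralized_ok:
        if md_reusable:
            return "decentralized plan, reuse stored global metadata (no gather)"
        return "decentralized plan, gather local plans once to rebuild global metadata"
    return "reduce_scatter local plans into a central plan"


assert _choose_planning_path(False, True, True).endswith("(no gather)")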


def verify_global_md_reuse(
loaded_all_plans: List[SavePlan], local_plan: SavePlan, rank: int, dist_wrapper: _DistWrapper
) -> bool:
"""
Verifies that global metadata reuse is possible by checking that the plans loaded from the
checkpoint are consistent, which means we have the same settings when resuming training.
Args:
loaded_all_plans: List[SavePlan], The loaded plans from the checkpoint
(stored in checkpoint metadata).
local_plan: SavePlan, The local save plan.
rank: Current process rank.
dist_wrapper (_DistWrapper): distributed wrapper created during planning
Returns: True iff the global metadata reuse is possible.
"""
logger.debug(f"verifying reuse of global metadata")
if not loaded_all_plans:
global_md_verify_reuse = False
logger.debug("loaded global metadata reuse verification: no loaded plans passed")

elif len(loaded_all_plans) == dist_wrapper.get_world_size():
local_verify_reuse = all(
getattr(local_plan, f.name) == getattr(loaded_all_plans[rank], f.name)
for f in fields(local_plan)
if f.name != 'storage_data'
)

if not local_verify_reuse:
logger.debug(
f"local_verify_reuse is False: diffs -"
f" {_compare_dataclasses(local_plan, loaded_all_plans[rank])}"
)
all_results = torch.tensor([local_verify_reuse], dtype=torch.int, device='cuda')
torch.distributed.all_reduce(all_results, op=torch.distributed.ReduceOp.MIN)
# Check if all reduced results are True
global_md_verify_reuse = all_results.item() == 1
else:
global_md_verify_reuse = False
return global_md_verify_reuse
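
The final agreement is a MIN all-reduce over per-rank booleans, so reuse is enabled only when every rank's local check passed. A single-process illustration of that reduction (toy values, not from the commit):

# Toy, single-process illustration of the ReduceOp.MIN consensus used above.
per_rank_checks = [True, True, False]                  # hypothetical results from 3 ranks
global_reuse = min(int(ok) for ok in per_rank_checks) == 1
assert global_reuse is False                           # one failing rank disables reuse for all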


def save_state_dict_async_finalize(
storage_writer: 'FileSystemWriterAsync', global_metadata: Metadata, dist_wrapper: _DistWrapper
) -> None: