Merge branch 'fix_1d_flatten_tensor' into 'main'
Handle 1-D flattened ShardedTensor edge case

See merge request ADLR/megatron-lm!2468
ko3n1g committed Jan 22, 2025
2 parents 5c12382 + 244ec97 commit 0d59157
Showing 3 changed files with 108 additions and 110 deletions.
27 changes: 15 additions & 12 deletions megatron/core/dist_checkpointing/strategies/resharding.py
@@ -13,7 +13,7 @@
import math
from dataclasses import dataclass
from itertools import product
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Tuple, Union

import numpy as np
import torch
@@ -27,7 +27,6 @@
extract_matching_values,
)
from megatron.core.dist_checkpointing.mapping import (
ReplicaId,
ShardedStateDict,
ShardedTensorFactory,
StateDict,
@@ -84,11 +83,7 @@ def is_nd_flattened_tensor(sh_ten: Any) -> bool:
Returns:
bool: whether the given object is a flattened ShardedTensor and is N-dimensional (N > 1)
"""
return (
isinstance(sh_ten, ShardedTensor)
and sh_ten.flattened_range is not None
and len(sh_ten.global_shape) > 1
)
return isinstance(sh_ten, ShardedTensor) and sh_ten.flattened_range is not None


# information needed to restore. With current implementation, this is a nested state dict
@@ -132,8 +127,12 @@ def maybe_reformulate_nd_flattened_tensor(sh_ten: Any):
try:
sh_ten_reformulation_metadata = reformulation_metadata[sh_ten.key]
except KeyError as e:
# Handle legacy checkpointing where 1-D flatten tensor metadata was not saved
if len(sh_ten.global_shape) == 1:
return sh_ten
raise CheckpointingException(
f'Missing reformulation metadata for tensor {sh_ten}. Existing keys: {reformulation_metadata.keys()}'
f'Missing reformulation metadata for tensor {sh_ten}. '
f'Existing keys: {reformulation_metadata.keys()}'
) from e

ckpt_actual_saved_shape = sh_ten_reformulation_metadata.ckpt_reform_global_shape
@@ -235,13 +234,16 @@ def reformulate_single_nd_flattened_tensor(
):
# without `int`, it's an exact offset of the app shard expressed in ckpt_local_shape units
first_overlap_dim_offset = int(ckpt_fragm / app_fragm * app_chunk_dim_offset)
# `math.ceil` argument is an exact offset of the app next shard expressed in ckpt_local_shape units
# `math.ceil` argument is an exact offset of the app next shard expressed
# in ckpt_local_shape units
next_overlap_dim_offset = math.ceil(ckpt_fragm / app_fragm * (app_chunk_dim_offset + 1))
overlap_dim_offsets.append(range(first_overlap_dim_offset, next_overlap_dim_offset))

logger.debug(
f'Generated the following number of overlap shards for each dimension: {list(map(len, overlap_dim_offsets))}'
f' for fragmentation ckpt {ckpt_axis_fragmentation} vs app {sh_ten.axis_fragmentations} and chunk offset {sh_ten.local_chunk_offset_in_global()}'
f'Generated the following number of overlap shards for each dimension: '
f'{list(map(len, overlap_dim_offsets))} for fragmentation ckpt '
f'{ckpt_axis_fragmentation} vs app {sh_ten.axis_fragmentations} '
f'and chunk offset {sh_ten.local_chunk_offset_in_global()}'
)
reformulated_sh_tens = {}
for chunk_offset in product(*overlap_dim_offsets):
@@ -286,7 +288,8 @@ def sh_ten_merge_fn(sub_state_dict):
# For each ckpt shard, we fill the appropriate application shard part
dest_ten = app_non_flat_ten
src_ten = ckpt_ten.view(ckpt_local_shape)
# We don't need narrowing over `prepend_axis_num` axes so we take the [sh_ten.prepend_axis_num:] offsets slice
# We don't need narrowing over `prepend_axis_num` axes so we take
# the [sh_ten.prepend_axis_num:] offsets slice
for (
dim,
offset_for_saved_tensor,
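For readers skimming the resharding changes: the per-dimension overlap computation in reformulate_single_nd_flattened_tensor is easiest to follow with concrete numbers. The standalone sketch below (the helper name overlap_ckpt_chunk_offsets is hypothetical, not part of the module) reproduces the int/math.ceil arithmetic from the hunk above to enumerate which checkpoint chunks overlap a given application chunk.

import math
from itertools import product

def overlap_ckpt_chunk_offsets(ckpt_axis_fragmentation, app_axis_fragmentation, app_chunk_offset):
    # For each dimension, find the checkpoint chunks that overlap the given
    # application chunk (illustrative sketch of the logic in the diff above).
    overlap_dim_offsets = []
    for ckpt_fragm, app_fragm, app_chunk_dim_offset in zip(
        ckpt_axis_fragmentation, app_axis_fragmentation, app_chunk_offset
    ):
        # Exact offset of the app shard expressed in ckpt_local_shape units
        first_overlap_dim_offset = int(ckpt_fragm / app_fragm * app_chunk_dim_offset)
        # Exact offset of the *next* app shard, rounded up to the covering ckpt chunk
        next_overlap_dim_offset = math.ceil(ckpt_fragm / app_fragm * (app_chunk_dim_offset + 1))
        overlap_dim_offsets.append(range(first_overlap_dim_offset, next_overlap_dim_offset))
    # The Cartesian product enumerates every overlapping checkpoint chunk offset.
    return list(product(*overlap_dim_offsets))

# Checkpoint saved with 4-way fragmentation along a dimension, application uses 2-way:
# application chunk 1 overlaps checkpoint chunks 2 and 3 along that dimension.
print(overlap_ckpt_chunk_offsets((4,), (2,), (1,)))  # [(2,), (3,)]
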
85 changes: 45 additions & 40 deletions megatron/core/dist_checkpointing/strategies/torch.py
@@ -126,7 +126,9 @@ def flat_copy(path: OBJ_PATH, value: Any) -> None:


def sharded_tensor_to_torch_sharded_tensor(
sh_tens: List[ShardedTensor], rank: Optional[int] = None
sh_tens: List[ShardedTensor],
rank: Optional[int] = None,
load_legacy_1d_flatten_tensors: bool = False,
) -> TorchShardedTensor:
"""Convert MCore ShardedTensor to PyT ShardedTensor. PyT requires information about all chunks.
@@ -138,13 +140,12 @@ def sharded_tensor_to_torch_sharded_tensor(
NOTE: this function assumes regular (grid) sharding of the MCore ShardedTensor.
The only local irregularities could be introduced with a `flattened_range` attribute.
This function handles 3 different type of ShardedTensors:
This function handles 2 different type of ShardedTensors:
1. Non-flat regular ShardedTensors (`not has_flattened_range`)
2. 1D flattened ShardedTensors (`is_flattened_range_1d`)
3. N-D flattened ShardedTensors (`has_flattened_range`)
2. N-D flattened ShardedTensors (`has_flattened_range`)
(1) and (2) type are saved according to their original shape.
Type (3) however requires global shape adjustment for efficiency:
(1) type are saved according to their original shape.
Type (2) however requires global shape adjustment for efficiency:
we treat [X, Y, Z] global shape tensor with local shape [x, y, z]
as a [X // x, Y // y, Z // z, x * y * z] tensor with last axis
partitioned according to `flattened_range` slices.
@@ -154,6 +155,8 @@
sh_tens (List[ShardedTensor]): list of sharded tensors to convert
rank (int, optional): current process rank passed to PyT ShardedTensor.
If None, assumes rank in the default pg.
load_legacy_1d_flatten_tensors (bool, optional): flag indicating if 1-D flattened tensors
should be loaded in a legacy way. Defaults to False.
Returns (TorchShardedTensor): PyT ShardedTensor containing all passed shards.
@@ -163,41 +166,21 @@

some_sh_ten = sh_tens[0]
has_flattened_range = some_sh_ten.flattened_range is not None
is_flattened_range_1d = has_flattened_range and len(some_sh_ten.global_shape) == 1

for sh_ten in sh_tens:
assert (sh_ten.flattened_range is not None) == has_flattened_range, sh_tens
if not sh_ten.data.is_contiguous():
sh_ten.data = sh_ten.data.contiguous()

if load_legacy_1d_flatten_tensors and len(some_sh_ten.global_shape) == 1:
# Legacy 1-D flattened tensors are loaded as non-flat regular ShardedTensors
has_flattened_range = False

local_global_offsets = {}

prepend_axis_num = sh_tens[0].prepend_axis_num
# Determine local shards according to tensor type (see docs)
if is_flattened_range_1d:
# Type (2) case: 1D flattened ShardedTensors
for sh_ten in sh_tens:
assert len(sh_ten.global_offset) == 1, sh_ten
assert sh_ten.prepend_axis_num == 0, sh_ten
local_global_offsets.setdefault(sh_ten.global_offset, []).append(sh_ten)

global_shape = some_sh_ten.global_shape
offsets_shape = (
some_sh_ten.local_shape
) # local shape is not flattened, we need it for chunk offsets

local_shards = [
Shard.from_tensor_and_offsets(
sh_ten.data,
[
sh_ten.global_offset[0] + sh_ten.flattened_range.start
], # additional flattened offset
rank,
)
for sh_ten in sh_tens
]

elif has_flattened_range:
if has_flattened_range:
# Type (3) case: N-D flattened ShardedTensors
for sh_ten in sh_tens:
local_global_offsets.setdefault(sh_ten.local_chunk_offset_in_global(), []).append(
@@ -250,10 +233,7 @@ def sharded_tensor_to_torch_sharded_tensor(
# local shard
placement = f"rank:{rank}/cuda"
for sh_ten in local_global_offsets[offset]:
if is_flattened_range_1d:
offset = (sh_ten.global_offset[0] + sh_ten.flattened_range.start,)
size = sh_ten.data.shape
elif has_flattened_range:
if has_flattened_range:
assert offset == sh_ten.local_chunk_offset_in_global()
# This is not an actual offset, but an offset of the whole shard
# This is needed for a PyT Dist internal integrity check
@@ -270,7 +250,7 @@
# Due to a bug in PyT 24.05 container we must specify some concrete rank within a world size.
# The exact rank doesn't matter as long as it's different than my rank - hence (rank + 1) % WS.
placement = f"rank:{(rank + 1) % world_size}/cuda"
if has_flattened_range and not is_flattened_range_1d:
if has_flattened_range:
offset = offset + (0,)
size = (1,) * len(offsets_shape) + global_shape[-1:]
else:
@@ -296,7 +276,7 @@
# This won't be stored in the checkpoint, only for runtime purposes
pyt_sh_ten.mcore_sh_ten = sh_ten.without_data()
pyt_sh_ten.mcore_metadata = {}
if has_flattened_range and not is_flattened_range_1d:
if has_flattened_range:
pyt_sh_ten.mcore_metadata['nd_reformulated_orig_global_shape'] = sh_ten.global_shape
return pyt_sh_ten

@@ -305,6 +285,7 @@ def mcore_to_pyt_state_dict(
state_dict: Dict[str, List[ShardedBase]],
is_loading: bool = False,
init_device: torch.device = torch.device("cpu"),
load_legacy_1d_flatten_tensors: bool = False,
) -> Dict[str, Union[TorchShardedTensor, io.BytesIO]]:
"""Convert state dict with ShardedTensors and ShardedObjects
to state dict compatible with PyT Dist format.
@@ -348,7 +329,9 @@ def _mcore_to_torch_sharded_tensor(sh_tens: List[ShardedTensor]) -> TorchSharded
if sh_ten.allow_shape_mismatch and is_loading:
sh_ten.data.zero_()

torch_sh_ten = sharded_tensor_to_torch_sharded_tensor(sh_tens, rank)
torch_sh_ten = sharded_tensor_to_torch_sharded_tensor(
sh_tens, rank, load_legacy_1d_flatten_tensors
)
torch_sh_ten.key = sh_tens[0].key
return torch_sh_ten

@@ -535,6 +518,12 @@ def _validate_global_shapes(self, metadata, sharded_tensors):
else:
expected_shape = nd_flattened_tensor_reformulated_global_shape(sh_ten)
if loaded_shape != expected_shape:
if is_nd_flattened_tensor(sh_ten) and len(sh_ten.global_shape) == 1:
# Handle legacy 1-D flattened tensors checkpoint format
# where the global shape is not stored in the metadata
expected_shape = sh_ten.global_shape
if loaded_shape == expected_shape:
continue
_msg = (
f'Global shape mismatch for loaded ({loaded_shape})'
f' and expected ({expected_shape}) tensor'
@@ -736,6 +725,12 @@ def get_reformulation_metadata(
'nd_reformulated_orig_global_shape'
]
except KeyError as e:
if len(sh_ten.global_shape) == 1:
warnings.warn(
f'Legacy checkpoint format detected for 1-D flattened tensor {sh_ten}. '
'Skip metadata reformulation.'
)
continue
raise CheckpointingException(
f'Cannot find global shape metadata for N-D flattened tensor {sh_ten} '
f'in checkpoint metadata: {ckpt_metadata.mcore_data}'
@@ -761,10 +756,18 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St
Returns: loaded state dict
"""
# Apply N-D tensors resharding
reformulation_metadata = get_reformulation_metadata(sharded_state_dict, checkpoint_dir)
sharded_state_dict, formulation_restore_data = apply_nd_flattened_tensors_reformulation(
sharded_state_dict, get_reformulation_metadata(sharded_state_dict, checkpoint_dir)
sharded_state_dict, reformulation_metadata
)

# Check if there are legacy 1-D flattened tensors in the checkpoint
has_legacy_1d_flattened_tensors = False
for sh_ten in nested_values(sharded_state_dict):
if is_nd_flattened_tensor(sh_ten) and sh_ten.key not in reformulation_metadata:
has_legacy_1d_flattened_tensors = True
break

flexible_shape_sharded_tensors = [
sh_ten
for sh_ten in nested_values(sharded_state_dict)
Expand All @@ -776,7 +779,9 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St
(sharded_state_dict, flat_mapping, rename_mapping) = (
_replace_state_dict_keys_with_sharded_keys(sharded_state_dict)
)
pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, True)
pyt_state_dict = mcore_to_pyt_state_dict(
sharded_state_dict, True, load_legacy_1d_flatten_tensors=has_legacy_1d_flattened_tensors
)
# Load PyT Distributed format
checkpoint.load_state_dict(
pyt_state_dict,
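The docstring changes above keep the key idea of the N-D flattened case: a ShardedTensor with global shape [X, Y, Z] and local shape [x, y, z] is presented to PyT Dist as a [X // x, Y // y, Z // z, x * y * z] tensor whose last axis is partitioned according to flattened_range. A minimal sketch of that shape arithmetic, assuming regular grid sharding (the function name below is illustrative only, not an MCore API):

import math

def reformulated_global_shape(global_shape, local_shape):
    # Sketch assumes regular (grid) sharding: each local dim divides its global dim.
    assert all(g % l == 0 for g, l in zip(global_shape, local_shape))
    grid = tuple(g // l for g, l in zip(global_shape, local_shape))  # [X//x, Y//y, Z//z]
    flat_chunk = math.prod(local_shape)                              # x * y * z
    return grid + (flat_chunk,)

# A [1024, 4096] tensor sharded into [256, 1024] local chunks is treated as a
# [4, 4, 262144] tensor; each rank owns a flattened_range slice of the last axis.
print(reformulated_global_shape((1024, 4096), (256, 1024)))  # (4, 4, 262144)
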
106 changes: 48 additions & 58 deletions tests/unit_tests/dist_checkpointing/test_optimizer.py
@@ -110,6 +110,39 @@ def sharded_state_dict(self):
return sharded_state_dict


class Model1dFlattenTensor(torch.nn.Module):
"""This model is used to test whether a 1d flatten tensor can be correctly
transformed into torch dist-ckpt form
"""

def __init__(self):
super().__init__()
self.config = TransformerConfig(hidden_size=128, num_attention_heads=1, num_layers=1)
self.weight_1d = torch.nn.Parameter(torch.randn(self.config.hidden_size))

def sharded_state_dict(self):
sharded_state_dict = self.state_dict(keep_vars=True)
sharded_state_dict['weight_1d'] = ShardedTensor.from_rank_offsets(
'weight_1d',
sharded_state_dict['weight_1d'],
(
(
0,
parallel_state.get_tensor_model_parallel_rank(),
parallel_state.get_tensor_model_parallel_world_size(),
)
),
replica_id=(
(
parallel_state.get_pipeline_model_parallel_rank(),
0,
parallel_state.get_data_parallel_rank(with_context_parallel=True),
)
),
)
return sharded_state_dict


class TestOptimizer:
def setup_method(self, method):
pass
@@ -152,6 +185,17 @@ def initialize_small_model(pre_process=True, post_process=True, seed=0, **config
return SwigluFactoryModel()


def initialize_1d_flatten_tensor_model(
pre_process=True, post_process=True, seed=0, **config_kwargs
):
# This model is used to test whether a 1d flatten tensor can be correctly
# transformed into torch dist-ckpt form
torch.manual_seed(seed)
model_parallel_cuda_manual_seed(seed)

return Model1dFlattenTensor()


def load_checkpoint_no_arg_checks(*args, **kwargs):
with mock.patch('megatron.training.checkpointing.check_checkpoint_args'):
with mock.patch('megatron.training.checkpointing.update_num_microbatches'):
@@ -165,7 +209,10 @@ def setup_method(self, method):
def teardown_method(self, method):
Utils.destroy_model_parallel()

@pytest.mark.parametrize("initialize_fn", [initialize_small_model, initialize_gpt_model])
@pytest.mark.parametrize(
"initialize_fn",
[initialize_small_model, initialize_gpt_model, initialize_1d_flatten_tensor_model],
)
@pytest.mark.parametrize("use_fpsl", [False, True])
# TODO: changing DP doesn't work in unit tests because of NCCL crashes
@pytest.mark.parametrize(
@@ -332,63 +379,6 @@ def test_finetune_doesnt_load_optimizer(
assert not diffs[0] and not diffs[1] and diffs[2]
assert not any(diff(optimizer.state_dict(), optim_unloaded_state_dict))

def test_can_load_deprecated_bucket_space_format(self, tmp_path_dist_ckpt):
# sync=True to make sure other ranks wait for rank 0 to finish creating directory.
tp = 4
pp = 2

Utils.initialize_model_parallel(tp, pp)
with TempNamedDir(
tmp_path_dist_ckpt / 'test_can_load_deprecated_bucket_space_format', sync=True
) as ckpt_dir:
mock_args = SimpleNamespace()
with mock.patch('megatron.training.checkpointing.get_args', new=lambda: mock_args):

init_basic_mock_args(mock_args, tp=tp, pp=pp)
init_checkpointing_mock_args(mock_args, ckpt_dir, True)

model, optimizer = setup_model_and_optimizer(
seed=2, tp=tp, pp=pp, initialize_fn=initialize_gpt_model
)

# Mock optimizer sharded_state_dict so that it ignores the externally
# passed sharding_type and uses 'fully_sharded_bucket_space' instead
orig_optim_sharded_state_dict_fn = optimizer.sharded_state_dict

def sharded_state_dict_bucket_space(
self, *args, sharding_type: str = 'fully_sharded_model_space', **kwargs
):
return orig_optim_sharded_state_dict_fn(
*args, sharding_type='fully_sharded_bucket_space', **kwargs
)

optimizer.sharded_state_dict = MethodType(
sharded_state_dict_bucket_space, optimizer
)
save_checkpoint(10, model, optimizer, None, 0)

flag = 0
key_list = []
torch.distributed.barrier()
if Utils.rank == 0:
sharded_metadata = load_tensors_metadata(ckpt_dir / 'iter_0000010')
key_list = list(sharded_metadata.keys())
# Check if actually using `fully_parallel_bucket_space` format.
key = (
"optimizer.distributed.dp_group_idx_0.gbuf_idx_0.dtype_"
"(torch.bfloat16, torch.bfloat16).bucket_idx_0.exp_avg_sq"
)
if key in key_list:
flag = 1

tensor = torch.tensor([flag], dtype=torch.long, device='cuda')
torch.distributed.broadcast(tensor, 0)
flag = tensor[0].item()
assert flag == 1, key_list

optimizer.sharded_state_dict = orig_optim_sharded_state_dict_fn
load_checkpoint_no_arg_checks(model, optimizer, None)


class TestFP32Optimizer:
def setup_method(self, method):
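The new Model1dFlattenTensor test shards a single 1-D parameter along dim 0 across the tensor-model-parallel group. The small sketch below shows the slices each rank would own; hidden_size=128 matches the test, while the 4-way TP group is an assumed value for illustration.

import torch

hidden_size, tp = 128, 4  # hidden_size from the test; tp is an assumed world size
weight = torch.randn(hidden_size)
shard_size = hidden_size // tp
for rank in range(tp):
    shard = weight[rank * shard_size:(rank + 1) * shard_size]
    # Each rank registers its shard with global offset rank * shard_size, which is
    # what ShardedTensor.from_rank_offsets('weight_1d', ..., (0, rank, tp)) encodes.
    print(f"rank {rank}: global offset {rank * shard_size}, local shape {tuple(shard.shape)}")
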
