
Commit 203b463
Merge branch 'mlaz/fix-24.05-pyt-dist' into 'main'
PyT Dist fix for 24.05 container

See merge request ADLR/megatron-lm!1823
ericharper committed Aug 9, 2024
2 parents 1bb6337 + 58a8a62 commit 203b463
Showing 1 changed file with 6 additions and 2 deletions.
megatron/core/dist_checkpointing/strategies/torch.py (6 additions, 2 deletions)

@@ -221,6 +221,7 @@ def sharded_tensor_to_torch_sharded_tensor(
     ]
 
     # Create a ShardedTensor without invoking communication. Determine global shards
+    world_size = torch.distributed.get_world_size()
     shard_metadata = []
     # NOTE: here we assume a regular grid of shards
     for fragment_offsets in itertools.product(*map(range, some_sh_ten.axis_fragmentations)):
@@ -244,13 +245,16 @@ def sharded_tensor_to_torch_sharded_tensor(
 
         else:
             # for shards from other ranks we provide simplistic data - this information will be discarded
-            # during TorchShardedTensor._init_from_local_shards_and_global_metadata call
+            # during TorchShardedTensor._init_from_local_shards_and_global_metadata call.
+            # Due to a bug in PyT 24.05 container we must specify some concrete rank within a world size.
+            # The exact rank doesn't matter as long as it's different than my rank - hence (rank + 1) % WS.
+            placement = f"rank:{(rank + 1) % world_size}/cuda"
             if has_flattened_range and not is_flattened_range_1d:
                 offset = offset + (0,)
                 size = (1,) * len(offsets_shape) + global_shape[-1:]
             else:
                 size = offsets_shape
-            shard_metadata.append(ShardMetadata(offset, size, "cuda"))
+            shard_metadata.append(ShardMetadata(offset, size, placement))
 
     tensor = some_sh_ten.data
     sharded_tensor_metadata = ShardedTensorMetadata(
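
For context, a minimal sketch (not part of the commit) of the placement workaround the diff introduces for shards owned by other ranks. It assumes an initialized torch.distributed process group; the helper name placeholder_shard_metadata and the ShardMetadata import path are illustrative assumptions, not code taken from the repository.

# Illustrative sketch only; assumes torch.distributed.init_process_group() has been called
# and world_size > 1.
import torch.distributed as dist
from torch.distributed._shard.metadata import ShardMetadata

def placeholder_shard_metadata(offset, size):
    """Hypothetical helper: metadata for a shard this rank does not own."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Per the commit message, the PyTorch build in the 24.05 container requires a
    # concrete rank in the placement (a bare "cuda" no longer passes validation).
    # Any rank other than our own works; (rank + 1) % world_size is the simplest choice.
    placement = f"rank:{(rank + 1) % world_size}/cuda"
    return ShardMetadata(list(offset), list(size), placement)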
