Commit cf0f9b2

Merge branch 'lmcafee/dist-ckpt-reduce-frag-v2' into 'main'
Reduce fragmentation when loading dist-opt + dist-ckpt. See merge request ADLR/megatron-lm!1742
2 parents: d7dda92 + ad729e8

2 files changed: +58 -27 lines changed


megatron/core/optimizer/distrib_optimizer.py (+57 -27)

@@ -45,6 +45,10 @@ class Range:
     """
     A range represents a start and end points for indexing a shard
     from a full tensor.
+
+    Args:
+        start (int): Start index.
+        end (int): End index.
     """

     def __init__(self, start: int, end: int):
@@ -53,6 +57,13 @@ def __init__(self, start: int, end: int):
         self.size = end - start

     def normalize(self, start: int = 0):
+        """Shift start/end indexes to start at new start index.
+
+        Both start and end indexes will be shifted by [new start] - [old start].
+
+        Args:
+            start (int): New start index.
+        """
         return Range(start, start + self.size)

     def __str__(self):
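
For context, here is a minimal usage sketch of the Range class documented above (illustrative values, not taken from the commit): normalize() returns a new Range of the same size whose start is moved to the given index.

    # Illustrative only: a Range covering indices [100, 164) of the full tensor.
    r = Range(100, 164)
    assert r.size == 64
    local = r.normalize()      # Range(0, 64): re-based to start at index 0
    shifted = r.normalize(10)  # Range(10, 74): same size, shifted start
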
@@ -63,6 +74,11 @@ def __len__(self):


 class DistributedOptimizer(MixedPrecisionOptimizer):
+    """Distributed optimizer, for all data types (fp16, bf16, and fp32).
+
+    See __init__() below for argument details.
+    """
+
     @classmethod
     def _build_model_gbuf_param_range_map(
         cls,
@@ -613,7 +629,7 @@ def load_state_dict(self, state_dict):

         # Get the Torch optimizer's state dict.
         # - This 'inner' optimizer at this point is unallocated, and only
-        #   contains an integer odering of parameters within each group, and
+        #   contains an integer ordering of parameters within each group, and
         #   the ordering of parameters within its flattened parameter state
         #   list.
         inner_state_dict = self.optimizer.state_dict()
@@ -622,34 +638,45 @@ def load_state_dict(self, state_dict):
             for idx, group in enumerate(state_dict["optimizer"]["param_groups"])
         ]

-        # Allocate 'dummy' data for optimizer state (i.e., torch.empty() below)
-        # - Real data is overwritten during load_parameter_state().
-        state_dict_state = []
-        for gbuf_range_maps in self.gbuf_ranges:
-            for gbuf_range_map_for_all_buckets in gbuf_range_maps.values():
-                for gbuf_range_map in gbuf_range_map_for_all_buckets:
-                    for model_param, param_range_map in gbuf_range_map["param_map"].items():
+        # Allocate or retrieve optimizer state (i.e., tensors).
+        if len(self.optimizer.state) == 0:
+            # Allocate empty optimizer state if not previously initialized.
+            # - If len(self.optimizer.state) == 0, this means that the optimizer
+            #   state has not been previously initialized. Once it has been
+            #   initialized, we skip this code block to avoid reallocating
+            #   empty tensors (i.e., torch.empty), which in turn reduces memory
+            #   fragmentation.
+            # - Real data is overwritten during load_parameter_state().
+            state_dict_state = []
+            for gbuf_range_maps in self.gbuf_ranges:
+                for gbuf_range_map_for_all_buckets in gbuf_range_maps.values():
+                    for gbuf_range_map in gbuf_range_map_for_all_buckets:
+                        for model_param, param_range_map in gbuf_range_map["param_map"].items():

-                        # Get parameter ordering information (see method docstring
-                        # for details).
-                        group_index, group_order = self.model_param_group_index_map[model_param]
-                        state_order = inner_state_dict["param_groups"][group_index]["params"][
-                            group_order
-                        ]
-
-                        # Allocate dummy tensors.
-                        numel = len(param_range_map["gbuf_world"])
-                        init_shard = lambda: torch.empty(
-                            (numel,), dtype=torch.float32, device=torch.cuda.current_device()
-                        )
+                            # Get parameter ordering information (see method docstring
+                            # for details).
+                            group_index, group_order = self.model_param_group_index_map[model_param]
+                            state_order = inner_state_dict["param_groups"][group_index]["params"][
+                                group_order
+                            ]

-                        state_dict_state.append(
-                            (state_order, {"exp_avg": init_shard(), "exp_avg_sq": init_shard()})
-                        )
+                            # Allocate dummy tensors.
+                            numel = len(param_range_map["gbuf_world"])
+                            init_shard = lambda: torch.empty(
+                                (numel,), dtype=torch.float32, device=torch.cuda.current_device()
+                            )
+
+                            state_dict_state.append(
+                                (state_order, {"exp_avg": init_shard(), "exp_avg_sq": init_shard()})
+                            )
+
+            # Sort by state order (see method docstring for details).
+            state_dict_state.sort(key=lambda s: s[0])
+            state_dict_state = {s[0]: s[1] for s in state_dict_state}

-        # Sort by state order (see method docstring for details).
-        state_dict_state.sort(key=lambda s: s[0])
-        state_dict_state = {s[0]: s[1] for s in state_dict_state}
+        else:
+            # Retrieve existing optimizer state.
+            state_dict_state = inner_state_dict["state"]

         # Extract 'step', for non-Apex/TE support.
         if not HAVE_APEX_OR_TE:
894921
}
895922

896923
if is_loading:
897-
self.init_state_fn(self.optimizer)
924+
# Call the distributed optimizer's specialized load_state_dict(),
925+
# which conditionally skips re-allocating the optimizer's state if
926+
# already initialized, which in turn reduces memory fragmentation.
927+
self.load_state_dict(self.state_dict())
898928

899929
if sharding_type == 'fully_sharded_bucket_space':
900930
param_state = self.sharded_param_state_fs_bucket_space(

megatron/training/checkpointing.py

+1
Original file line numberDiff line numberDiff line change
@@ -1144,6 +1144,7 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
11441144
f'p {mpu.get_pipeline_model_parallel_rank()} ] '
11451145
f'at iteration {iteration}')
11461146

1147+
torch.cuda.empty_cache()
11471148
return iteration, num_floating_point_operations_so_far
11481149

11491150
