@@ -45,6 +45,10 @@ class Range:
45
45
"""
46
46
A range represents a start and end points for indexing a shard
47
47
from a full tensor.
48
+
49
+ Args:
50
+ start (int): Start index.
51
+ end (int): End index.
48
52
"""
49
53
50
54
def __init__ (self , start : int , end : int ):
@@ -53,6 +57,13 @@ def __init__(self, start: int, end: int):
53
57
self .size = end - start
54
58
55
59
def normalize (self , start : int = 0 ):
60
+ """Shift start/end indexes to start at new start index.
61
+
62
+ Both start and end indexes will be shifted by [new start] - [old start].
63
+
64
+ Args:
65
+ start (int): New start index.
66
+ """
56
67
return Range (start , start + self .size )
57
68
58
69
def __str__ (self ):
@@ -63,6 +74,11 @@ def __len__(self):
63
74
64
75
65
76
class DistributedOptimizer (MixedPrecisionOptimizer ):
77
+ """Distributed optimizer, for all data types (fp16, bf16, and fp32).
78
+
79
+ See __init__() below for argument details.
80
+ """
81
+
66
82
@classmethod
67
83
def _build_model_gbuf_param_range_map (
68
84
cls ,
@@ -613,7 +629,7 @@ def load_state_dict(self, state_dict):
613
629
614
630
# Get the Torch optimizer's state dict.
615
631
# - This 'inner' optimizer at this point is unallocated, and only
616
- # contains an integer odering of parameters within each group, and
632
+ # contains an integer ordering of parameters within each group, and
617
633
# the ordering of parameters within its flattened parameter state
618
634
# list.
619
635
inner_state_dict = self .optimizer .state_dict ()
@@ -622,34 +638,45 @@ def load_state_dict(self, state_dict):
622
638
for idx , group in enumerate (state_dict ["optimizer" ]["param_groups" ])
623
639
]
624
640
625
- # Allocate 'dummy' data for optimizer state (i.e., torch.empty() below)
626
- # - Real data is overwritten during load_parameter_state().
627
- state_dict_state = []
628
- for gbuf_range_maps in self .gbuf_ranges :
629
- for gbuf_range_map_for_all_buckets in gbuf_range_maps .values ():
630
- for gbuf_range_map in gbuf_range_map_for_all_buckets :
631
- for model_param , param_range_map in gbuf_range_map ["param_map" ].items ():
641
+ # Allocate or retrieve optimizer state (i.e., tensors).
642
+ if len (self .optimizer .state ) == 0 :
643
+ # Allocate empty optimizer state if not previously initialized.
644
+ # - If len(self.optimizer.state) == 0, this means that the optimizer
645
+ # state has not been previously initialized. Once it has been
646
+ # initialized, we skip this code block to avoid reallocating
647
+ # empty tensors (i.e., torch.empty), which in turn reduces memory
648
+ # fragmentation.
649
+ # - Real data is overwritten during load_parameter_state().
650
+ state_dict_state = []
651
+ for gbuf_range_maps in self .gbuf_ranges :
652
+ for gbuf_range_map_for_all_buckets in gbuf_range_maps .values ():
653
+ for gbuf_range_map in gbuf_range_map_for_all_buckets :
654
+ for model_param , param_range_map in gbuf_range_map ["param_map" ].items ():
632
655
633
- # Get parameter ordering information (see method docstring
634
- # for details).
635
- group_index , group_order = self .model_param_group_index_map [model_param ]
636
- state_order = inner_state_dict ["param_groups" ][group_index ]["params" ][
637
- group_order
638
- ]
639
-
640
- # Allocate dummy tensors.
641
- numel = len (param_range_map ["gbuf_world" ])
642
- init_shard = lambda : torch .empty (
643
- (numel ,), dtype = torch .float32 , device = torch .cuda .current_device ()
644
- )
656
+ # Get parameter ordering information (see method docstring
657
+ # for details).
658
+ group_index , group_order = self .model_param_group_index_map [model_param ]
659
+ state_order = inner_state_dict ["param_groups" ][group_index ]["params" ][
660
+ group_order
661
+ ]
645
662
646
- state_dict_state .append (
647
- (state_order , {"exp_avg" : init_shard (), "exp_avg_sq" : init_shard ()})
648
- )
663
+ # Allocate dummy tensors.
664
+ numel = len (param_range_map ["gbuf_world" ])
665
+ init_shard = lambda : torch .empty (
666
+ (numel ,), dtype = torch .float32 , device = torch .cuda .current_device ()
667
+ )
668
+
669
+ state_dict_state .append (
670
+ (state_order , {"exp_avg" : init_shard (), "exp_avg_sq" : init_shard ()})
671
+ )
672
+
673
+ # Sort by state order (see method docstring for details).
674
+ state_dict_state .sort (key = lambda s : s [0 ])
675
+ state_dict_state = {s [0 ]: s [1 ] for s in state_dict_state }
649
676
650
- # Sort by state order (see method docstring for details).
651
- state_dict_state . sort ( key = lambda s : s [ 0 ])
652
- state_dict_state = { s [ 0 ]: s [ 1 ] for s in state_dict_state }
677
+ else :
678
+ # Retrieve existing optimizer state.
679
+ state_dict_state = inner_state_dict [ "state" ]
653
680
654
681
# Extract 'step', for non-Apex/TE support.
655
682
if not HAVE_APEX_OR_TE :
@@ -894,7 +921,10 @@ def sharded_state_dict(
894
921
}
895
922
896
923
if is_loading :
897
- self .init_state_fn (self .optimizer )
924
+ # Call the distributed optimizer's specialized load_state_dict(),
925
+ # which conditionally skips re-allocating the optimizer's state if
926
+ # already initialized, which in turn reduces memory fragmentation.
927
+ self .load_state_dict (self .state_dict ())
898
928
899
929
if sharding_type == 'fully_sharded_bucket_space' :
900
930
param_state = self .sharded_param_state_fs_bucket_space (
0 commit comments