     Storage,
     Topology,
 )
-from torchrec.distributed.planner.utils import bytes_to_gb, reset_shard_rank
+from torchrec.distributed.planner.utils import (
+    bytes_to_gb,
+    mb_to_bytes,
+    reset_shard_rank,
+)
 from torchrec.distributed.types import ShardingType
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -57,6 +61,25 @@ def _get_uniform_sharding_options(
     return uniform_sharding_options
 
 
+def _get_shards_assignment(
+    sharding_options: List[ShardingOption],
+) -> List[List[Optional[int]]]:
+    assignment_per_option = []
+    for sharding_option in sharding_options:
+        assignment_per_option.append(sharding_option.get_shards_assignment())
+    return assignment_per_option
+
+
+def _apply_shards_assignment(
+    sharding_options: List[ShardingOption],
+    assignment_per_option: List[List[Optional[int]]],
+) -> None:
+    assert len(sharding_options) == len(assignment_per_option)
+    for sharding_option, assignment in zip(sharding_options, assignment_per_option):
+        for shard_id, rank in enumerate(assignment):
+            sharding_option.shards[shard_id].rank = rank
+
+
 @dataclass
 class ShardingOptionGroup:
     sharding_options: List[ShardingOption]
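
These two helpers capture only the per-shard rank assignments and later write them back in place, replacing the deep copies of whole plans that the partitioner made before. A minimal sketch of that round trip, using hypothetical stand-in classes (Shard/Option below are not the real torchrec types, they merely expose the same shards/rank/get_shards_assignment surface):

# Hypothetical stand-ins, only to illustrate the save/restore pattern.
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Shard:
    rank: Optional[int] = None

@dataclass
class Option:
    shards: List[Shard] = field(default_factory=list)

    def get_shards_assignment(self) -> List[Optional[int]]:
        # same contract as ShardingOption.get_shards_assignment()
        return [shard.rank for shard in self.shards]

plan = [Option(shards=[Shard(rank=0), Shard(rank=1)])]
saved = [opt.get_shards_assignment() for opt in plan]  # _get_shards_assignment
plan[0].shards[0].rank = 7                             # a later proposal mutates ranks in place
for opt, assignment in zip(plan, saved):               # _apply_shards_assignment
    for shard_id, rank in enumerate(assignment):
        opt.shards[shard_id].rank = rank
assert plan[0].shards[0].rank == 0                     # original assignment restored
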
@@ -171,6 +194,7 @@ def partition(
         self,
         proposal: List[ShardingOption],
         storage_constraint: Topology,
+        hbm_per_device: Optional[int] = None,
     ) -> List[ShardingOption]:
         """
         Places sharding options on topology based on each sharding option's
@@ -230,13 +254,27 @@ def partition(
             f"GreedyPerfPartitioner - sort_by: {self._sort_by}, balance_modules: {self._balance_modules}"
         )
 
-        _topology: Topology = copy.deepcopy(storage_constraint)
         minheap_devices: Optional[List[OrderedDeviceHardware]] = None
-        _host_level_devices = self._get_host_level_devices(_topology)
+
+        # Don't store the topology, since the topology itself cannot be changed;
+        # the algorithm only modifies device perf & storage sizes, so copy just those.
+        devices = [
+            DeviceHardware(
+                rank=d.rank,
+                storage=Storage(hbm=hbm_per_device or d.storage.hbm, ddr=d.storage.ddr),
+                perf=copy.deepcopy(d.perf),
+            )
+            for d in storage_constraint.devices
+        ]
+
+        host_level_devices = GreedyPerfPartitioner._get_host_level_devices(
+            storage_constraint, devices
+        )
 
         # first partition the uniform sharding options (RW & DP)
         uniform_sharding_options = _get_uniform_sharding_options(proposal)
-        self._uniform_partition(uniform_sharding_options, _topology.devices)
+
+        GreedyPerfPartitioner._uniform_partition(uniform_sharding_options, devices)
 
         # group the rest sharding options by colocation type (co-host, co-device, none)
         # and sort the groups by storage in reverse order
@@ -249,15 +287,15 @@ def partition(
                 sharding_option_group.sharding_options[0].partition_by
                 == PartitionByType.MULTI_HOST.value
             ):
-                self._multi_hosts_partition(sharding_option_group, _host_level_devices)
+                self._multi_hosts_partition(sharding_option_group, host_level_devices)
                 # _multi_hosts_partition invalidates minheap_devices, force rebuild before using
                 minheap_devices = None
 
             elif (
                 sharding_option_group.sharding_options[0].partition_by
                 == PartitionByType.HOST.value
             ):
-                self._cohost_partition(sharding_option_group, _host_level_devices)
+                self._cohost_partition(sharding_option_group, host_level_devices)
                 # _cohost_partition invalidates minheap_devices, force rebuild before using
                 minheap_devices = None
             elif (
@@ -266,7 +304,7 @@ def partition(
             ):
                 if minheap_devices is None:
                     minheap_devices = self._establish_minheap(
-                        _topology.devices, _topology.local_world_size
+                        devices, storage_constraint.local_world_size
                     )
                 assert (
                     len(sharding_option_group.sharding_options) == 1
@@ -279,8 +317,6 @@ def partition(
                 raise RuntimeError(
                     f"Unexpected sharding option group {sharding_option_group}"
                 )
-        # pyre-ignore [16]: `GreedyPerfPartitioner` has no attribute `_topology`.
-        self._topology: Topology = _topology
         return proposal
 
     @classmethod
@@ -432,7 +468,9 @@ def _multi_hosts_partition(
             sharding_option = sharding_option_group.sharding_options[0]
             try:
                 if sharding_option.sharding_type == ShardingType.GRID_SHARD.value:
-                    cls._uniform_partition([sharding_option], host_devices)
+                    GreedyPerfPartitioner._uniform_partition(
+                        [sharding_option], host_devices
+                    )
                 else:
                     raise PlannerError(
                         error_type=PlannerErrorType.PARTITION,
@@ -486,7 +524,9 @@ def _cohost_partition(
                         sharding_option.sharding_type
                         == ShardingType.TABLE_ROW_WISE.value
                     ):
-                        cls._uniform_partition([sharding_option], host_devices)
+                        GreedyPerfPartitioner._uniform_partition(
+                            [sharding_option], host_devices
+                        )
                         # _uniform_partition invalidates minheap_devices, force rebuild
                         # before using
                         minheap_devices = None
@@ -521,20 +561,22 @@ def _cohost_partition(
             message=f"can't find a host for sharding option group {sharding_option_group}",
         )
 
-    @classmethod
-    def _get_host_level_devices(cls, _topology: Topology) -> List[List[DeviceHardware]]:
-        num_hosts: int = _topology.world_size // _topology.local_world_size
+    @staticmethod
+    def _get_host_level_devices(
+        topology: Topology, all_devices: List[DeviceHardware]
+    ) -> List[List[DeviceHardware]]:
+        num_hosts: int = topology.world_size // topology.local_world_size
         host_level_devices: List[List[DeviceHardware]] = []
         for i in range(num_hosts):
-            devices_in_host = _topology.devices[
-                i * _topology.local_world_size : (i + 1) * _topology.local_world_size
+            devices_in_host = all_devices[
+                i * topology.local_world_size : (i + 1) * topology.local_world_size
             ]
             host_level_devices.append(devices_in_host)
         return host_level_devices
 
-    @classmethod
+    @staticmethod
     def _uniform_partition(
-        cls, sharding_options: List[ShardingOption], devices: List[DeviceHardware]
+        sharding_options: List[ShardingOption], devices: List[DeviceHardware]
     ) -> None:
         for sharding_option in sharding_options:
             if sharding_option.num_shards != len(devices):
@@ -543,16 +585,17 @@ def _uniform_partition(
                     message=f"For a uniform partition, the number of shards ({sharding_option.num_shards}) must equal the number of devices ({len(devices)})",
                 )
             for i in range(len(devices)):
-                storage_needed = cast(Storage, sharding_option.shards[i].storage)
+                shard = sharding_option.shards[i]
+                storage_needed = cast(Storage, shard.storage)
                 if not storage_needed.fits_in(devices[i].storage):
                     raise PlannerError(
                         error_type=PlannerErrorType.PARTITION,
                         message=f"Shard of size {storage_needed} bytes does not fit on any rank. Device memory cap: {devices[i].storage}.",
                     )
                 else:
-                    sharding_option.shards[i].rank = devices[i].rank
+                    shard.rank = devices[i].rank
                     devices[i].storage -= storage_needed
-                    devices[i].perf += cast(Perf, sharding_option.shards[i].perf)
+                    devices[i].perf += cast(Perf, shard.perf)
 
 
 class MemoryBalancedPartitioner(Partitioner):
@@ -598,16 +641,15 @@ def partition(
         _partitioner = GreedyPerfPartitioner(
             sort_by=SortBy.PERF, balance_modules=self._balance_modules
         )
-        # copying storage_constraint, since we modify it in place
-        _topology: Topology = copy.deepcopy(storage_constraint)
 
         # set up default plan to fall back on
-        default_plan = _partitioner.partition(proposal, _topology)
-        default_plan = copy.deepcopy(default_plan)
+        default_plan = _partitioner.partition(proposal, storage_constraint)
+        best_shard_assignment = _get_shards_assignment(default_plan)
+
         original_plan_perf = _perf_model.rate(default_plan)
 
         # compute shard and default plan HBM stats
-        hbm_by_rank = [0] * _topology.world_size
+        hbm_by_rank = [0] * storage_constraint.world_size
         hbm_requirement: int = 0
         max_shard_hbm: int = 0
         for sharding_option in default_plan:
@@ -626,7 +668,7 @@ def partition(
         )
 
         # Lower bound for the search is the maximum of avg. HBM usage or the biggest shard
-        avg_hbm_usage: int = int(hbm_requirement / _topology.world_size)
+        avg_hbm_usage: int = int(hbm_requirement / storage_constraint.world_size)
         min_hbm_per_device: int = max(avg_hbm_usage, max_shard_hbm)
         logger.info(
             "Searching in the range (min_hbm_per_device, max_hbm_per_device): "
@@ -636,16 +678,19 @@ def partition(
 
         # binary search with (min, max] setting
         search_count = 0
+        hbm_diff = mb_to_bytes(10)  # 10MB
         while (
            search_count < self._max_search_count
-            and min_hbm_per_device + 10 * 1024**2 < max_hbm_per_device  # 10MB
+            and min_hbm_per_device + hbm_diff < max_hbm_per_device
         ):
             search_count += 1
             reset_shard_rank(proposal)
             mid_hbm_per_device: int = (max_hbm_per_device + min_hbm_per_device) // 2
-            set_hbm_per_device(_topology, mid_hbm_per_device)
+
             try:
-                new_plan = _partitioner.partition(proposal, _topology)
+                new_plan = _partitioner.partition(
+                    proposal, storage_constraint, mid_hbm_per_device
+                )
                 new_plan_perf = _perf_model.rate(new_plan)
                 perf_diff = (
                     (new_plan_perf - original_plan_perf) / original_plan_perf
@@ -674,7 +719,7 @@ def partition(
                         f"Found a more memory-balanced plan with {round(bytes_to_gb(mid_hbm_per_device), 3)} "
                         f"GB per device for embedding tables. The new plan is {perf_diff_str}"
                     )
-                    default_plan = copy.deepcopy(new_plan)
+                    best_shard_assignment = _get_shards_assignment(new_plan)
                     max_hbm_per_device = mid_hbm_per_device
             except PlannerError:
                 logger.info(
@@ -683,9 +728,5 @@ def partition(
                 )
                 min_hbm_per_device = mid_hbm_per_device
 
+        _apply_shards_assignment(default_plan, best_shard_assignment)
         return default_plan
-
-
-def set_hbm_per_device(storage_constraint: Topology, hbm_per_device: int) -> None:
-    for device in storage_constraint.devices:
-        device.storage.hbm = hbm_per_device
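
Taken together, MemoryBalancedPartitioner now binary-searches the smallest per-device HBM cap whose plan still partitions and stays within the performance threshold, passing the candidate cap through the new hbm_per_device argument instead of mutating a copied topology. A hedged sketch of that search loop, with hypothetical try_partition and rate callables standing in for GreedyPerfPartitioner.partition and the perf model (bounds and tolerance would come from the default plan, as in the real code):

# Sketch only: try_partition(hbm_cap) returns a plan or None; rate(plan) returns a perf estimate.
from typing import Callable, List, Optional

def search_hbm_cap(
    try_partition: Callable[[int], Optional[List[object]]],
    rate: Callable[[List[object]], float],
    min_hbm: int,
    max_hbm: int,
    baseline_perf: float,
    perf_tolerance: float,
    max_search_count: int = 10,
) -> int:
    hbm_diff = 10 * 1024**2  # stop once the search window shrinks below 10MB
    best_cap = max_hbm  # fall back to the default plan's cap
    search_count = 0
    while search_count < max_search_count and min_hbm + hbm_diff < max_hbm:
        search_count += 1
        mid_hbm = (min_hbm + max_hbm) // 2
        plan = try_partition(mid_hbm)
        if plan is not None and (rate(plan) - baseline_perf) / baseline_perf <= perf_tolerance:
            best_cap, max_hbm = mid_hbm, mid_hbm  # feasible and fast enough: tighten the cap
        else:
            min_hbm = mid_hbm  # failed to partition or too slow: relax the cap
    return best_cap
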