
Commit 9a6f3d0

isururanawaka authored and facebook-github-bot committed
Input distribution latency estimations
Summary: This introduces input distribution latency estimations. Input distribution is a two-step communication that happens inside SDD pipelines:

- Split exchange: exchanges the buffer sizes needed to receive the input IDs from KJTs. This cost does not depend on the input itself; it is a metadata-exchange phase, so this diff excludes it from the computations.
- ID exchange: exchanges the actual IDs to look up. We estimate its cost by modeling the all-to-all comms.

Differential Revision: D87389540
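For intuition, the ID-exchange estimate reduces to bytes moved over the all-to-all divided by the available all-to-all bandwidth. Below is a minimal standalone sketch of that arithmetic with made-up batch sizes, pooling factors, and lookup lengths, and a flat hypothetical bandwidth constant in place of the planner's GeneralizedCommsBandwidth topology model.

import math
from typing import List


def input_dist_latency_sketch(
    batch_sizes: List[int],
    num_poolings: List[float],
    input_lengths: List[float],
    world_size: int,
    a2a_comm_data_type_size: float,
    a2a_bw_bytes_per_sec: float,
) -> float:
    # Expected number of IDs per batch: sum over features of
    # batch_size * poolings_per_sample * avg_lookups_per_pooling.
    batch_inputs = sum(
        b * p * l for b, p, l in zip(batch_sizes, num_poolings, input_lengths)
    )
    # Bytes moved in the ID-exchange all-to-all (same arithmetic as the diff's
    # input_read_size: expected IDs * world_size * bytes per ID).
    input_read_size = math.ceil(batch_inputs * world_size * a2a_comm_data_type_size)
    # Latency estimate: bytes moved / all-to-all bandwidth.
    return input_read_size / a2a_bw_bytes_per_sec


# Made-up example: 3 features, batch size 512, int64 IDs, 16 ranks,
# ~50 GB/s effective all-to-all bandwidth.
latency_s = input_dist_latency_sketch(
    batch_sizes=[512, 512, 512],
    num_poolings=[1.0, 1.0, 1.0],
    input_lengths=[20.0, 5.0, 80.0],
    world_size=16,
    a2a_comm_data_type_size=8.0,
    a2a_bw_bytes_per_sec=50e9,
)
print(f"estimated input dist latency: {latency_s * 1e6:.1f} us")

The actual helper added in this diff queries comms_bandwidths.get_bw(...) with CollectiveType.ALL_TO_ALL and both world_size and local_world_size, which lets the bandwidth model distinguish intra-host from inter-host links instead of using a flat constant.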
1 parent a533012 · commit 9a6f3d0

File tree

3 files changed: +126 -0 lines changed

torchrec/distributed/planner/shard_estimators.py

Lines changed: 68 additions & 0 deletions
@@ -460,6 +460,44 @@ def _get_expected_cache_prefetch_time(
         prefetch_bytes = expected_cache_fetches * emb_dim * table_data_type_size
         return prefetch_bytes / hbm_to_ddr_mem_bw

+    @classmethod
+    def _input_dist_expected_latency(
+        cls,
+        batch_sizes: List[int],
+        world_size: int,
+        local_world_size: int,
+        num_poolings: List[float],
+        input_lengths: List[float],
+        a2a_comm_data_type_size: float,
+        comms_bandwidths: GeneralizedCommsBandwidth,
+    ) -> float:
+        """
+        Calculates the expected latency for A2A input dist.
+
+        Args:
+            batch_sizes (List[int]): The batch size for each input feature.
+            world_size (int): The total number of devices in the distributed setup.
+            local_world_size (int): The number of devices on a single host.
+            num_poolings (List[float]): Number of poolings per sample for each input feature.
+            input_lengths (List[float]): Average number of lookups per input feature.
+            a2a_comm_data_type_size (float): Data type size (in bytes) for forward all-to-all communication.
+            comms_bandwidths (GeneralizedCommsBandwidth): Object to query communication bandwidths.
+
+        Returns:
+            float: The expected latency (in seconds) for input distribution.
+        """
+        batch_inputs = sum(
+            [x * y * z for x, y, z in zip(input_lengths, num_poolings, batch_sizes)]
+        )
+        input_read_size = math.ceil(batch_inputs * world_size * a2a_comm_data_type_size)
+
+        comms_bw = comms_bandwidths.get_bw(
+            world_size=world_size,
+            local_world_size=local_world_size,
+            collective_type=CollectiveType.ALL_TO_ALL,
+        )
+        return input_read_size / comms_bw
+
     @classmethod
     def _get_tw_sharding_perf(
         cls,
@@ -550,6 +588,15 @@ def _get_tw_sharding_perf(
             hbm_to_ddr_mem_bw, expected_cache_fetches, emb_dim, table_data_type_size
         )

+        input_dist_comms = cls._input_dist_expected_latency(
+            batch_sizes=batch_sizes,
+            world_size=world_size,
+            local_world_size=local_world_size,
+            num_poolings=num_poolings,
+            input_lengths=input_lengths,
+            a2a_comm_data_type_size=input_data_type_size,
+            comms_bandwidths=comms_bandwidths,
+        )
         # in order of model parallel execution, starting with:
         # BWD DP -> BWD MP ... FWD MP -> FWD DP
         return Perf(
@@ -558,6 +605,7 @@ def _get_tw_sharding_perf(
             bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel,
             bwd_comms=bwd_comms,
             prefetch_compute=prefetch_compute,
+            input_dist_comms=input_dist_comms,
         )

     @classmethod
@@ -657,13 +705,23 @@ def _get_rw_sharding_perf(
             emb_dim,
             table_data_type_size,
         )
+        input_dist_comms = cls._input_dist_expected_latency(
+            batch_sizes=batch_sizes,
+            world_size=world_size,
+            local_world_size=local_world_size,
+            num_poolings=num_poolings,
+            input_lengths=input_lengths,
+            a2a_comm_data_type_size=input_data_type_size,
+            comms_bandwidths=comms_bandwidths,
+        )

         return Perf(
             fwd_compute=fwd_compute,
             fwd_comms=fwd_comms,
             bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel,
             bwd_comms=bwd_comms + bwd_batched_copy,
             prefetch_compute=prefetch_compute,
+            input_dist_comms=input_dist_comms,
         )

     @classmethod
@@ -789,13 +847,23 @@ def _get_twrw_sharding_perf(
             emb_dim,
             table_data_type_size,
         )
+        input_dist_comms = cls._input_dist_expected_latency(
+            batch_sizes=batch_sizes,
+            world_size=world_size,
+            local_world_size=local_world_size,
+            num_poolings=num_poolings,
+            input_lengths=input_lengths,
+            a2a_comm_data_type_size=input_data_type_size,
+            comms_bandwidths=comms_bandwidths,
+        )

         return Perf(
             fwd_compute=fwd_compute,
             fwd_comms=fwd_comms,
             bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel,
             bwd_comms=bwd_comms + bwd_batched_copy,
             prefetch_compute=prefetch_compute,
+            input_dist_comms=input_dist_comms,
         )

     @classmethod
