@@ -460,6 +460,44 @@ def _get_expected_cache_prefetch_time(
460460 prefetch_bytes = expected_cache_fetches * emb_dim * table_data_type_size
461461 return prefetch_bytes / hbm_to_ddr_mem_bw
462462
463+ @classmethod
464+ def _input_dist_expected_latency (
465+ cls ,
466+ batch_sizes : List [int ],
467+ world_size : int ,
468+ local_world_size : int ,
469+ num_poolings : List [float ],
470+ input_lengths : List [float ],
471+ a2a_comm_data_type_size : float ,
472+ comms_bandwidths : GeneralizedCommsBandwidth ,
473+ ) -> float :
474+ """
475+ Calculates the expected latency for A2A input dist.
476+
477+ Args:
478+ batch_sizes (int): The batch size for each input feature.
479+ world_size (int): The total number of devices in the distributed setup.
480+ local_world_size (int): The number of devices on a single host.
481+ num_poolings (List[float]): Number of poolings per sample for each input feature.
482+ input_lengths (List[float]): Average number of lookups per input feature.
483+ a2a_comm_data_type_size (float): Data type size (in bytes) for forward all-to-all communication.
484+ comms_bandwidths (GeneralizedCommsBandwidth): Object to query communication bandwidths.
485+
486+ Returns:
487+ float: The expected latency (in seconds) for input distribution.
488+ """
489+ batch_inputs = sum (
490+ [x * y * z for x , y , z in zip (input_lengths , num_poolings , batch_sizes )]
491+ )
492+ input_read_size = math .ceil (batch_inputs * world_size * a2a_comm_data_type_size )
493+
494+ comms_bw = comms_bandwidths .get_bw (
495+ world_size = world_size ,
496+ local_world_size = local_world_size ,
497+ collective_type = CollectiveType .ALL_TO_ALL ,
498+ )
499+ return input_read_size / comms_bw
500+
463501 @classmethod
464502 def _get_tw_sharding_perf (
465503 cls ,
@@ -550,6 +588,15 @@ def _get_tw_sharding_perf(
550588 hbm_to_ddr_mem_bw , expected_cache_fetches , emb_dim , table_data_type_size
551589 )
552590
591+ input_dist_comms = cls ._input_dist_expected_latency (
592+ batch_sizes = batch_sizes ,
593+ world_size = world_size ,
594+ local_world_size = local_world_size ,
595+ num_poolings = num_poolings ,
596+ input_lengths = input_lengths ,
597+ a2a_comm_data_type_size = input_data_type_size ,
598+ comms_bandwidths = comms_bandwidths ,
599+ )
553600 # in order of model parallel execution, starting with:
554601 # BWD DP -> BWD MP ... FWD MP -> FWD DP
555602 return Perf (
@@ -558,6 +605,7 @@ def _get_tw_sharding_perf(
558605 bwd_compute = bwd_compute + bwd_grad_indice_weights_kernel ,
559606 bwd_comms = bwd_comms ,
560607 prefetch_compute = prefetch_compute ,
608+ input_dist_comms = input_dist_comms ,
561609 )
562610
563611 @classmethod
@@ -657,13 +705,23 @@ def _get_rw_sharding_perf(
657705 emb_dim ,
658706 table_data_type_size ,
659707 )
708+ input_dist_comms = cls ._input_dist_expected_latency (
709+ batch_sizes = batch_sizes ,
710+ world_size = world_size ,
711+ local_world_size = local_world_size ,
712+ num_poolings = num_poolings ,
713+ input_lengths = input_lengths ,
714+ a2a_comm_data_type_size = input_data_type_size ,
715+ comms_bandwidths = comms_bandwidths ,
716+ )
660717
661718 return Perf (
662719 fwd_compute = fwd_compute ,
663720 fwd_comms = fwd_comms ,
664721 bwd_compute = bwd_compute + bwd_grad_indice_weights_kernel ,
665722 bwd_comms = bwd_comms + bwd_batched_copy ,
666723 prefetch_compute = prefetch_compute ,
724+ input_dist_comms = input_dist_comms ,
667725 )
668726
669727 @classmethod
@@ -789,13 +847,23 @@ def _get_twrw_sharding_perf(
789847 emb_dim ,
790848 table_data_type_size ,
791849 )
850+ input_dist_comms = cls ._input_dist_expected_latency (
851+ batch_sizes = batch_sizes ,
852+ world_size = world_size ,
853+ local_world_size = local_world_size ,
854+ num_poolings = num_poolings ,
855+ input_lengths = input_lengths ,
856+ a2a_comm_data_type_size = input_data_type_size ,
857+ comms_bandwidths = comms_bandwidths ,
858+ )
792859
793860 return Perf (
794861 fwd_compute = fwd_compute ,
795862 fwd_comms = fwd_comms ,
796863 bwd_compute = bwd_compute + bwd_grad_indice_weights_kernel ,
797864 bwd_comms = bwd_comms + bwd_batched_copy ,
798865 prefetch_compute = prefetch_compute ,
866+ input_dist_comms = input_dist_comms ,
799867 )
800868
801869 @classmethod
0 commit comments