diff --git a/torchrec/distributed/planner/shard_estimators.py b/torchrec/distributed/planner/shard_estimators.py index acf672047..07b1e3131 100644 --- a/torchrec/distributed/planner/shard_estimators.py +++ b/torchrec/distributed/planner/shard_estimators.py @@ -477,6 +477,48 @@ def _get_expected_cache_prefetch_time( prefetch_bytes = expected_cache_fetches * emb_dim * table_data_type_size return prefetch_bytes / hbm_to_ddr_mem_bw + @classmethod + def _input_dist_expected_latency( + cls, + batch_sizes: List[int], + world_size: int, + local_world_size: int, + num_poolings: List[float], + input_lengths: List[float], + a2a_comm_data_type_size: float, + comms_bandwidths: GeneralizedCommsBandwidth, + is_weighted: bool = False, + ) -> float: + """ + Calculates the expected latency for A2A input dist. + + Args: + batch_sizes (int): The batch size for each input feature. + world_size (int): The total number of devices in the distributed setup. + local_world_size (int): The number of devices on a single host. + num_poolings (List[float]): Number of poolings per sample for each input feature. + input_lengths (List[float]): Average number of lookups per input feature. + a2a_comm_data_type_size (float): Data type size (in bytes) for forward all-to-all communication. + comms_bandwidths (GeneralizedCommsBandwidth): Object to query communication bandwidths. + + Returns: + float: The expected latency (in seconds) for input distribution. + """ + batch_inputs = sum( + [x * y * z for x, y, z in zip(input_lengths, num_poolings, batch_sizes)] + ) + input_read_size = math.ceil(batch_inputs * world_size * a2a_comm_data_type_size) + + if is_weighted: + input_read_size *= 2 + + comms_bw = comms_bandwidths.get_bw( + world_size=world_size, + local_world_size=local_world_size, + collective_type=CollectiveType.ALL_TO_ALL, + ) + return input_read_size / comms_bw + @classmethod def _get_tw_sharding_perf( cls, @@ -567,6 +609,15 @@ def _get_tw_sharding_perf( hbm_to_ddr_mem_bw, expected_cache_fetches, emb_dim, table_data_type_size ) + input_dist_comms = cls._input_dist_expected_latency( + batch_sizes=batch_sizes, + world_size=world_size, + local_world_size=local_world_size, + num_poolings=num_poolings, + input_lengths=input_lengths, + a2a_comm_data_type_size=input_data_type_size, + comms_bandwidths=comms_bandwidths, + ) # in order of model parallel execution, starting with: # BWD DP -> BWD MP ... FWD MP -> FWD DP return Perf( @@ -575,6 +626,7 @@ def _get_tw_sharding_perf( bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel, bwd_comms=bwd_comms, prefetch_compute=prefetch_compute, + input_dist_comms=input_dist_comms, ) @classmethod @@ -674,6 +726,15 @@ def _get_rw_sharding_perf( emb_dim, table_data_type_size, ) + input_dist_comms = cls._input_dist_expected_latency( + batch_sizes=batch_sizes, + world_size=world_size, + local_world_size=local_world_size, + num_poolings=num_poolings, + input_lengths=input_lengths, + a2a_comm_data_type_size=input_data_type_size, + comms_bandwidths=comms_bandwidths, + ) return Perf( fwd_compute=fwd_compute, @@ -681,6 +742,7 @@ def _get_rw_sharding_perf( bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel, bwd_comms=bwd_comms + bwd_batched_copy, prefetch_compute=prefetch_compute, + input_dist_comms=input_dist_comms, ) @classmethod @@ -806,6 +868,15 @@ def _get_twrw_sharding_perf( emb_dim, table_data_type_size, ) + input_dist_comms = cls._input_dist_expected_latency( + batch_sizes=batch_sizes, + world_size=world_size, + local_world_size=local_world_size, + num_poolings=num_poolings, + input_lengths=input_lengths, + a2a_comm_data_type_size=input_data_type_size, + comms_bandwidths=comms_bandwidths, + ) return Perf( fwd_compute=fwd_compute, @@ -813,6 +884,7 @@ def _get_twrw_sharding_perf( bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel, bwd_comms=bwd_comms + bwd_batched_copy, prefetch_compute=prefetch_compute, + input_dist_comms=input_dist_comms, ) @classmethod diff --git a/torchrec/distributed/planner/tests/test_shard_estimators.py b/torchrec/distributed/planner/tests/test_shard_estimators.py index a2c7ed5e6..5bb1cb91e 100644 --- a/torchrec/distributed/planner/tests/test_shard_estimators.py +++ b/torchrec/distributed/planner/tests/test_shard_estimators.py @@ -141,6 +141,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.000654920154856466, bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, ) ], ("fused_uvm", "table_wise"): [ @@ -149,6 +150,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.18358230590820312, bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, ) ], ("fused_uvm_caching", "table_wise"): [ @@ -157,6 +159,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.02865675019054878, bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, ) ], ("fused", "column_wise"): [ @@ -165,6 +168,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.000654920154856466, bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, ) ], ("fused_uvm", "column_wise"): [ @@ -173,6 +177,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.18358230590820312, bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, ) ], ("fused_uvm_caching", "column_wise"): [ @@ -181,6 +186,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.02865675019054878, bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, ) ], ("fused", "table_column_wise"): [ @@ -189,6 +195,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.000654920154856466, bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, ) ], ("fused_uvm", "table_column_wise"): [ @@ -197,6 +204,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.18358230590820312, bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, ) ], ("fused_uvm_caching", "table_column_wise"): [ @@ -205,6 +213,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.02865675019054878, bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, ) ], ("fused", "row_wise"): [ @@ -213,12 +222,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525, + input_dist_comms=1.2715657552083334e-05, ), Perf( fwd_compute=6.804365245261984e-05, fwd_comms=6.357828776041667e-05, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525, + input_dist_comms=1.2715657552083334e-05, ), ], ("fused_uvm", "row_wise"): [ @@ -227,12 +238,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.03814697265625, bwd_comms=0.029329458872477215, + input_dist_comms=1.2715657552083334e-05, ), Perf( fwd_compute=0.019073486328125, fwd_comms=6.357828776041667e-05, bwd_compute=0.03814697265625, bwd_comms=0.029329458872477215, + input_dist_comms=1.2715657552083334e-05, ), ], ("fused_uvm_caching", "row_wise"): [ @@ -241,12 +254,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.0059546493902439025, bwd_comms=0.004631910866838161, + input_dist_comms=1.2715657552083334e-05, ), Perf( fwd_compute=0.0029773246951219513, fwd_comms=6.357828776041667e-05, bwd_compute=0.0059546493902439025, bwd_comms=0.004631910866838161, + input_dist_comms=1.2715657552083334e-05, ), ], ("fused", "table_row_wise"): [ @@ -255,12 +270,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525, + input_dist_comms=1.2715657552083334e-05, ), Perf( fwd_compute=6.804365245261984e-05, fwd_comms=6.357828776041667e-05, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525, + input_dist_comms=1.2715657552083334e-05, ), ], ("fused_uvm", "table_row_wise"): [ @@ -269,12 +286,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.03814697265625, bwd_comms=0.029329458872477215, + input_dist_comms=1.2715657552083334e-05, ), Perf( fwd_compute=0.019073486328125, fwd_comms=6.357828776041667e-05, bwd_compute=0.03814697265625, bwd_comms=0.029329458872477215, + input_dist_comms=1.2715657552083334e-05, ), ], ("fused_uvm_caching", "table_row_wise"): [ @@ -283,12 +302,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.0059546493902439025, bwd_comms=0.004631910866838161, + input_dist_comms=1.2715657552083334e-05, ), Perf( fwd_compute=0.0029773246951219513, fwd_comms=6.357828776041667e-05, bwd_compute=0.0059546493902439025, bwd_comms=0.004631910866838161, + input_dist_comms=1.2715657552083334e-05, ), ], # grid_shard is the same as table_row_wise @@ -298,12 +319,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525, + input_dist_comms=1.2715657552083334e-05, ), Perf( fwd_compute=6.804365245261984e-05, fwd_comms=6.357828776041667e-05, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525, + input_dist_comms=1.2715657552083334e-05, ), ], ("fused_uvm", "grid_shard"): [ @@ -312,12 +335,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.03814697265625, bwd_comms=0.029329458872477215, + input_dist_comms=1.2715657552083334e-05, ), Perf( fwd_compute=0.019073486328125, fwd_comms=6.357828776041667e-05, bwd_compute=0.03814697265625, bwd_comms=0.029329458872477215, + input_dist_comms=1.2715657552083334e-05, ), ], ("fused_uvm_caching", "grid_shard"): [ @@ -326,12 +351,300 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05, bwd_compute=0.0059546493902439025, bwd_comms=0.004631910866838161, + input_dist_comms=1.2715657552083334e-05, ), Perf( fwd_compute=0.0029773246951219513, fwd_comms=6.357828776041667e-05, bwd_compute=0.0059546493902439025, bwd_comms=0.004631910866838161, + input_dist_comms=1.2715657552083334e-05, + ), + ], + } + + perfs = { + ( + sharding_option.compute_kernel, + sharding_option.sharding_type, + ): [shard.perf for shard in sharding_option.shards] + for sharding_option in sharding_options + } + + self.assertEqual(expected_perfs, perfs) + + def test_1_weighted_table_perf(self) -> None: + """ + Test perf estimation for a single weighted table. + Weighted tables have additional overhead for processing per-sample weights + during backward computation (bwd_grad_indice_weights_kernel). + """ + weighted_tables = [ + EmbeddingBagConfig( + num_embeddings=100, + embedding_dim=10, + name="weighted_table_0", + feature_names=["weighted_feature_0"], + ) + ] + model = TestSparseNN(tables=[], weighted_tables=weighted_tables) + """ + GRID_SHARD only is available if specified by user in parameter constraints, however, + adding parameter constraints does not work because of the non deterministic nature of + _filter_sharding_types (set & set) operation when constraints are present, we mock the + call to _filter_sharding_types to ensure the order of the sharding types list is always + the same. + """ + self.enumerator._filter_sharding_types = MagicMock( + return_value=self._sharding_types + ) + sharding_options = self.enumerator.enumerate( + module=model, + sharders=[ + cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder()) + ], + ) + + # Weighted tables should have higher bwd_compute due to bwd_grad_indice_weights_kernel + expected_perfs = { + ("dense", "data_parallel"): [ + Perf( + fwd_compute=0.00010206547867892977, + fwd_comms=0, + bwd_compute=0.00031640298390468225, + bwd_comms=0.000225593945387348, + ), + Perf( + fwd_compute=0.00010206547867892977, + fwd_comms=0, + bwd_compute=0.00031640298390468225, + bwd_comms=0.000225593945387348, + ), + ], + ("fused", "table_wise"): [ + Perf( + fwd_compute=0.00034234462640224357, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.001061268341846955, + bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, + ) + ], + ("fused_uvm", "table_wise"): [ + Perf( + fwd_compute=0.0959634780883789, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.2974867820739746, + bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, + ) + ], + ("fused_uvm_caching", "table_wise"): [ + Perf( + fwd_compute=0.014979664872332316, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.04643696110423018, + bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, + ) + ], + ("fused", "column_wise"): [ + Perf( + fwd_compute=0.00034234462640224357, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.001061268341846955, + bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, + ) + ], + ("fused_uvm", "column_wise"): [ + Perf( + fwd_compute=0.0959634780883789, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.2974867820739746, + bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, + ) + ], + ("fused_uvm_caching", "column_wise"): [ + Perf( + fwd_compute=0.014979664872332316, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.04643696110423018, + bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, + ) + ], + ("fused", "row_wise"): [ + Perf( + fwd_compute=7.229638073090858e-05, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.0002241187802658166, + bwd_comms=0.00016798276699240525, + input_dist_comms=1.2715657552083334e-05, + ), + Perf( + fwd_compute=7.229638073090858e-05, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.0002241187802658166, + bwd_comms=0.00016798276699240525, + input_dist_comms=1.2715657552083334e-05, + ), + ], + ("fused_uvm", "row_wise"): [ + Perf( + fwd_compute=0.020265579223632812, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.06282329559326172, + bwd_comms=0.029329458872477215, + input_dist_comms=1.2715657552083334e-05, + ), + Perf( + fwd_compute=0.020265579223632812, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.06282329559326172, + bwd_comms=0.029329458872477215, + input_dist_comms=1.2715657552083334e-05, + ), + ], + ("fused_uvm_caching", "row_wise"): [ + Perf( + fwd_compute=0.003163407488567073, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.009806563214557928, + bwd_comms=0.004631910866838161, + input_dist_comms=1.2715657552083334e-05, + ), + Perf( + fwd_compute=0.003163407488567073, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.009806563214557928, + bwd_comms=0.004631910866838161, + input_dist_comms=1.2715657552083334e-05, + ), + ], + ("fused", "table_row_wise"): [ + Perf( + fwd_compute=7.229638073090858e-05, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.0002241187802658166, + bwd_comms=0.00016798276699240525, + input_dist_comms=1.2715657552083334e-05, + ), + Perf( + fwd_compute=7.229638073090858e-05, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.0002241187802658166, + bwd_comms=0.00016798276699240525, + input_dist_comms=1.2715657552083334e-05, + ), + ], + ("fused_uvm", "table_row_wise"): [ + Perf( + fwd_compute=0.020265579223632812, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.06282329559326172, + bwd_comms=0.029329458872477215, + input_dist_comms=1.2715657552083334e-05, + ), + Perf( + fwd_compute=0.020265579223632812, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.06282329559326172, + bwd_comms=0.029329458872477215, + input_dist_comms=1.2715657552083334e-05, + ), + ], + ("fused_uvm_caching", "table_row_wise"): [ + Perf( + fwd_compute=0.003163407488567073, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.009806563214557928, + bwd_comms=0.004631910866838161, + input_dist_comms=1.2715657552083334e-05, + ), + Perf( + fwd_compute=0.003163407488567073, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.009806563214557928, + bwd_comms=0.004631910866838161, + input_dist_comms=1.2715657552083334e-05, + ), + ], + ("fused", "table_column_wise"): [ + Perf( + fwd_compute=0.00034234462640224357, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.001061268341846955, + bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, + ) + ], + ("fused_uvm", "table_column_wise"): [ + Perf( + fwd_compute=0.0959634780883789, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.2974867820739746, + bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, + ) + ], + ("fused_uvm_caching", "table_column_wise"): [ + Perf( + fwd_compute=0.014979664872332316, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.04643696110423018, + bwd_comms=6.357828776041667e-05, + input_dist_comms=1.2715657552083334e-05, + ) + ], + # grid_shard is the same as table_row_wise + ("fused", "grid_shard"): [ + Perf( + fwd_compute=7.229638073090858e-05, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.0002241187802658166, + bwd_comms=0.00016798276699240525, + input_dist_comms=1.2715657552083334e-05, + ), + Perf( + fwd_compute=7.229638073090858e-05, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.0002241187802658166, + bwd_comms=0.00016798276699240525, + input_dist_comms=1.2715657552083334e-05, + ), + ], + ("fused_uvm", "grid_shard"): [ + Perf( + fwd_compute=0.020265579223632812, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.06282329559326172, + bwd_comms=0.029329458872477215, + input_dist_comms=1.2715657552083334e-05, + ), + Perf( + fwd_compute=0.020265579223632812, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.06282329559326172, + bwd_comms=0.029329458872477215, + input_dist_comms=1.2715657552083334e-05, + ), + ], + ("fused_uvm_caching", "grid_shard"): [ + Perf( + fwd_compute=0.003163407488567073, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.009806563214557928, + bwd_comms=0.004631910866838161, + input_dist_comms=1.2715657552083334e-05, + ), + Perf( + fwd_compute=0.003163407488567073, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.009806563214557928, + bwd_comms=0.004631910866838161, + input_dist_comms=1.2715657552083334e-05, ), ], } @@ -860,6 +1173,7 @@ def test_1_table_perf(self) -> None: bwd_compute=0.000654920154856466, bwd_comms=6.357828776041667e-05 * 2, # bw is set to half in this test + input_dist_comms=2.5431315104166668e-05, ) ], ("fused_uvm", "table_wise"): [ @@ -868,6 +1182,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.18358230590820312, bwd_comms=6.357828776041667e-05 * 2, + input_dist_comms=2.5431315104166668e-05, ) ], ("fused_uvm_caching", "table_wise"): [ @@ -876,6 +1191,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.02865675019054878, bwd_comms=6.357828776041667e-05 * 2, + input_dist_comms=2.5431315104166668e-05, ) ], ("fused", "column_wise"): [ @@ -884,6 +1200,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.000654920154856466, bwd_comms=6.357828776041667e-05 * 2, + input_dist_comms=2.5431315104166668e-05, ) ], ("fused_uvm", "column_wise"): [ @@ -892,6 +1209,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.18358230590820312, bwd_comms=6.357828776041667e-05 * 2, + input_dist_comms=2.5431315104166668e-05, ) ], ("fused_uvm_caching", "column_wise"): [ @@ -900,6 +1218,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.02865675019054878, bwd_comms=6.357828776041667e-05 * 2, + input_dist_comms=2.5431315104166668e-05, ) ], ("fused", "table_column_wise"): [ @@ -908,6 +1227,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.000654920154856466, bwd_comms=6.357828776041667e-05 * 2, + input_dist_comms=2.5431315104166668e-05, ) ], ("fused_uvm", "table_column_wise"): [ @@ -916,6 +1236,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.18358230590820312, bwd_comms=6.357828776041667e-05 * 2, + input_dist_comms=2.5431315104166668e-05, ) ], ("fused_uvm_caching", "table_column_wise"): [ @@ -924,6 +1245,7 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.02865675019054878, bwd_comms=6.357828776041667e-05 * 2, + input_dist_comms=2.5431315104166668e-05, ) ], ("fused", "row_wise"): [ @@ -932,12 +1254,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), Perf( fwd_compute=6.804365245261984e-05, fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), ], ("fused_uvm", "row_wise"): [ @@ -946,12 +1270,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.03814697265625, bwd_comms=0.02939303716023763, # 0.029329458872477215 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), Perf( fwd_compute=0.019073486328125, fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.03814697265625, bwd_comms=0.02939303716023763, # 0.029329458872477215 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), ], ("fused_uvm_caching", "row_wise"): [ @@ -960,12 +1286,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0059546493902439025, bwd_comms=0.004695489154598577, # 0.004631910866838161 + 6.357828776041667e-05 + input_dist_comms=2.5431315104166668e-05, ), Perf( fwd_compute=0.0029773246951219513, fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0059546493902439025, bwd_comms=0.004695489154598577, # 0.004631910866838161 + 6.357828776041667e-05 + input_dist_comms=2.5431315104166668e-05, ), ], ("fused", "table_row_wise"): [ @@ -974,12 +1302,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), Perf( fwd_compute=6.804365245261984e-05, fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), ], ("fused_uvm", "table_row_wise"): [ @@ -988,12 +1318,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.03814697265625, bwd_comms=0.02939303716023763, # 0.029329458872477215 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), Perf( fwd_compute=0.019073486328125, fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.03814697265625, bwd_comms=0.02939303716023763, # 0.029329458872477215 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), ], ("fused_uvm_caching", "table_row_wise"): [ @@ -1002,12 +1334,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0059546493902439025, bwd_comms=0.004695489154598577, # 0.004631910866838161 + 6.357828776041667e-05 + input_dist_comms=2.5431315104166668e-05, ), Perf( fwd_compute=0.0029773246951219513, fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0059546493902439025, bwd_comms=0.004695489154598577, # 0.004631910866838161 + 6.357828776041667e-05 + input_dist_comms=2.5431315104166668e-05, ), ], # grid_shard is the same as table_row_wise @@ -1017,12 +1351,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), Perf( fwd_compute=6.804365245261984e-05, fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0001360873049052397, bwd_comms=0.00016798276699240525 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), ], ("fused_uvm", "grid_shard"): [ @@ -1031,12 +1367,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.03814697265625, bwd_comms=0.02939303716023763, # 0.029329458872477215 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), Perf( fwd_compute=0.019073486328125, fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.03814697265625, bwd_comms=0.02939303716023763, # 0.029329458872477215 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), ], ("fused_uvm_caching", "grid_shard"): [ @@ -1045,12 +1383,14 @@ def test_1_table_perf(self) -> None: fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0059546493902439025, bwd_comms=0.004695489154598577, # 0.004631910866838161 + 6.357828776041667e-05, + input_dist_comms=2.5431315104166668e-05, ), Perf( fwd_compute=0.0029773246951219513, fwd_comms=6.357828776041667e-05 * 2, bwd_compute=0.0059546493902439025, bwd_comms=0.004695489154598577, # 0.004631910866838161 + 6.357828776041667e-05 + input_dist_comms=2.5431315104166668e-05, ), ], } @@ -1070,6 +1410,7 @@ def test_1_table_perf(self) -> None: ): [shard.perf for shard in sharding_option.shards] for sharding_option in sharding_options2 } + self.assertEqual(expected_perfs, perfs) self.assertEqual(expected_perfs, perfs2) diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py index 32a750290..4891bbc53 100644 --- a/torchrec/distributed/planner/types.py +++ b/torchrec/distributed/planner/types.py @@ -54,6 +54,7 @@ class Perf: fwd_comms: float bwd_compute: float bwd_comms: float + input_dist_comms: float = 0.0 prefetch_compute: float = 0.0 @property @@ -87,6 +88,7 @@ def __add__(self, other: "Perf") -> "Perf": fwd_comms=self.fwd_comms + other.fwd_comms, bwd_compute=self.bwd_compute + other.bwd_compute, bwd_comms=self.bwd_comms + other.bwd_comms, + input_dist_comms=self.input_dist_comms + other.input_dist_comms, prefetch_compute=self.prefetch_compute + other.prefetch_compute, ) @@ -97,6 +99,7 @@ def __hash__(self) -> int: self.fwd_comms, self.bwd_compute, self.bwd_comms, + self.input_dist_comms, self.prefetch_compute, ) ) diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py index cb7004670..86e624ba6 100644 --- a/torchrec/distributed/test_utils/test_model.py +++ b/torchrec/distributed/test_utils/test_model.py @@ -1567,7 +1567,7 @@ def __init__( torch.device("meta"), return_remapped=True, ) - elif isinstance(tables[0], EmbeddingConfig): + elif len(tables) > 0 and isinstance(tables[0], EmbeddingConfig): self.sparse = TestECSparseArch( tables, # pyre-ignore [6] sparse_device,