72 changes: 72 additions & 0 deletions torchrec/distributed/planner/shard_estimators.py
@@ -477,6 +477,48 @@ def _get_expected_cache_prefetch_time(
prefetch_bytes = expected_cache_fetches * emb_dim * table_data_type_size
return prefetch_bytes / hbm_to_ddr_mem_bw

@classmethod
def _input_dist_expected_latency(
cls,
batch_sizes: List[int],
world_size: int,
local_world_size: int,
num_poolings: List[float],
input_lengths: List[float],
a2a_comm_data_type_size: float,
comms_bandwidths: GeneralizedCommsBandwidth,
is_weighted: bool = False,
) -> float:
"""
Calculates the expected latency of the all-to-all (A2A) input dist.

Args:
batch_sizes (List[int]): The batch size for each input feature.
world_size (int): The total number of devices in the distributed setup.
local_world_size (int): The number of devices on a single host.
num_poolings (List[float]): Number of poolings per sample for each input feature.
input_lengths (List[float]): Average number of lookups per input feature.
a2a_comm_data_type_size (float): Data type size (in bytes) for the forward all-to-all communication.
comms_bandwidths (GeneralizedCommsBandwidth): Object to query communication bandwidths.
is_weighted (bool): Whether the features are weighted; per-id weights are sent alongside the ids, doubling the communicated bytes.

Returns:
float: The expected latency (in seconds) for input distribution.
"""
batch_inputs = sum(
[x * y * z for x, y, z in zip(input_lengths, num_poolings, batch_sizes)]
)
input_read_size = math.ceil(batch_inputs * world_size * a2a_comm_data_type_size)

if is_weighted:
input_read_size *= 2

comms_bw = comms_bandwidths.get_bw(
world_size=world_size,
local_world_size=local_world_size,
collective_type=CollectiveType.ALL_TO_ALL,
)
return input_read_size / comms_bw
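
For intuition, the helper above estimates the pooled id count of the local batch, scaled by world_size and the comm dtype size, divided by the all-to-all bandwidth (and doubled when per-id weights travel with the ids). A minimal standalone sketch of that arithmetic follows; it is illustrative only and not part of this diff, and the feature stats and the 12.5 GB/s effective A2A bandwidth are assumed values.

import math

# Hypothetical inputs: two features averaging 20 and 5 ids per sample, one pooling
# each, local batch size 512, 16 ranks, and 4-byte elements in the forward A2A.
input_lengths = [20.0, 5.0]
num_poolings = [1.0, 1.0]
batch_sizes = [512, 512]
world_size = 16
a2a_comm_data_type_size = 4.0
assumed_a2a_bw = 12.5e9  # bytes/sec; assumed effective all-to-all bandwidth

batch_inputs = sum(l * p * b for l, p, b in zip(input_lengths, num_poolings, batch_sizes))
input_read_size = math.ceil(batch_inputs * world_size * a2a_comm_data_type_size)
# A weighted feature would also send a per-id weight, doubling input_read_size.
latency_sec = input_read_size / assumed_a2a_bw
print(f"{input_read_size} bytes over A2A -> {latency_sec * 1e6:.1f} us")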

@classmethod
def _get_tw_sharding_perf(
cls,
@@ -567,6 +609,15 @@ def _get_tw_sharding_perf(
hbm_to_ddr_mem_bw, expected_cache_fetches, emb_dim, table_data_type_size
)

input_dist_comms = cls._input_dist_expected_latency(
batch_sizes=batch_sizes,
world_size=world_size,
local_world_size=local_world_size,
num_poolings=num_poolings,
input_lengths=input_lengths,
a2a_comm_data_type_size=input_data_type_size,
comms_bandwidths=comms_bandwidths,
)
# in order of model parallel execution, starting with:
# BWD DP -> BWD MP ... FWD MP -> FWD DP
return Perf(
@@ -575,6 +626,7 @@
bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel,
bwd_comms=bwd_comms,
prefetch_compute=prefetch_compute,
input_dist_comms=input_dist_comms,
)

@classmethod
@@ -674,13 +726,23 @@ def _get_rw_sharding_perf(
emb_dim,
table_data_type_size,
)
input_dist_comms = cls._input_dist_expected_latency(
batch_sizes=batch_sizes,
world_size=world_size,
local_world_size=local_world_size,
num_poolings=num_poolings,
input_lengths=input_lengths,
a2a_comm_data_type_size=input_data_type_size,
comms_bandwidths=comms_bandwidths,
)

return Perf(
fwd_compute=fwd_compute,
fwd_comms=fwd_comms,
bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel,
bwd_comms=bwd_comms + bwd_batched_copy,
prefetch_compute=prefetch_compute,
input_dist_comms=input_dist_comms,
)

@classmethod
@@ -806,13 +868,23 @@ def _get_twrw_sharding_perf(
emb_dim,
table_data_type_size,
)
input_dist_comms = cls._input_dist_expected_latency(
batch_sizes=batch_sizes,
world_size=world_size,
local_world_size=local_world_size,
num_poolings=num_poolings,
input_lengths=input_lengths,
a2a_comm_data_type_size=input_data_type_size,
comms_bandwidths=comms_bandwidths,
)

return Perf(
fwd_compute=fwd_compute,
fwd_comms=fwd_comms,
bwd_compute=bwd_compute + bwd_grad_indice_weights_kernel,
bwd_comms=bwd_comms + bwd_batched_copy,
prefetch_compute=prefetch_compute,
input_dist_comms=input_dist_comms,
)

@classmethod
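Downstream, each sharding-type estimator now threads this term into its Perf as input_dist_comms. How it folds into the planner's total cost depends on the Perf type in torchrec/distributed/planner/types.py, which is not shown in this diff; the SimplePerf class below is a hypothetical, simplified stand-in that only illustrates the effect of adding the new term to an otherwise summed latency.

from dataclasses import dataclass

@dataclass
class SimplePerf:
    # Hypothetical stand-in for torchrec's Perf, used only to show how the new
    # input_dist_comms term would contribute to a shard's total estimated latency.
    fwd_compute: float
    fwd_comms: float
    bwd_compute: float
    bwd_comms: float
    prefetch_compute: float = 0.0
    input_dist_comms: float = 0.0

    @property
    def total(self) -> float:
        # Assumed aggregation: all terms sum into a single latency figure.
        return (
            self.fwd_compute
            + self.fwd_comms
            + self.bwd_compute
            + self.bwd_comms
            + self.prefetch_compute
            + self.input_dist_comms
        )

perf = SimplePerf(
    fwd_compute=1.2e-3,
    fwd_comms=0.4e-3,
    bwd_compute=2.4e-3,
    bwd_comms=0.8e-3,
    input_dist_comms=6.6e-5,  # e.g. the ~66 us from the sketch above
)
print(f"total estimated latency: {perf.total * 1e3:.2f} ms")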