Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 24 additions & 11 deletions torchrec/distributed/batched_embedding_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1764,19 +1764,29 @@ def init_parameters(self) -> None:
weight_init_max,
)

def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
hash_zch_identities = self._get_hash_zch_identities(features)
if hash_zch_identities is None:
def forward(
self,
features: KeyedJaggedTensor,
) -> torch.Tensor:
forward_args: Dict[str, Any] = {}
identities_and_metadata = self._get_hash_zch_identities_and_metadata(features)
if identities_and_metadata is not None:
hash_zch_identities, hash_zch_runtime_meta = identities_and_metadata
forward_args["hash_zch_identities"] = hash_zch_identities
if hash_zch_runtime_meta is not None:
forward_args["hash_zch_runtime_meta"] = hash_zch_runtime_meta

if len(forward_args) == 0:
return self.emb_module(
indices=features.values().long(),
offsets=features.offsets().long(),
)

return self.emb_module(
indices=features.values().long(),
offsets=features.offsets().long(),
hash_zch_identities=hash_zch_identities,
)
else:
return self.emb_module(
indices=features.values().long(),
offsets=features.offsets().long(),
**forward_args,
)

# pyre-fixme[14]: `state_dict` overrides method defined in `Module` inconsistently.
def state_dict(
Expand Down Expand Up @@ -2841,9 +2851,12 @@ def forward(
features: KeyedJaggedTensor,
) -> torch.Tensor:
forward_args: Dict[str, Any] = {}
hash_zch_identities = self._get_hash_zch_identities(features)
if hash_zch_identities is not None:
identities_and_metadata = self._get_hash_zch_identities_and_metadata(features)
if identities_and_metadata is not None:
hash_zch_identities, hash_zch_runtime_meta = identities_and_metadata
forward_args["hash_zch_identities"] = hash_zch_identities
if hash_zch_runtime_meta is not None:
forward_args["hash_zch_runtime_meta"] = hash_zch_runtime_meta

weights = features.weights_or_none()
if weights is not None and not torch.is_floating_point(weights):
Expand Down
22 changes: 15 additions & 7 deletions torchrec/distributed/embedding_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ def init_raw_id_tracker(
get_indexed_lookups, delete
)

def _get_hash_zch_identities(
def _get_hash_zch_identities_and_metadata(
self, features: KeyedJaggedTensor
) -> Optional[torch.Tensor]:
) -> Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
if self._raw_id_tracker_wrapper is None or not isinstance(
self.emb_module, SplitTableBatchedEmbeddingBagsCodegen
):
Expand All @@ -131,8 +131,8 @@ def _get_hash_zch_identities(
# across multiple training iterations. Current logic appends raw_ids from
# all batches sequentially. This may cause misalignment with
# features.values() which only contains the current batch.
raw_ids_dict = raw_id_tracker_wrapper.get_indexed_lookups(
table_names, emb_module.uuid
indexed_lookups_dict = raw_id_tracker_wrapper.get_indexed_lookups(
table_names, self.emb_module.uuid
)

# Build hash_zch_identities by concatenating raw IDs from tracked tables.
Expand All @@ -148,11 +148,14 @@ def _get_hash_zch_identities(
# raw_ids are included. If some tables lack identities while others have them,
# padding with -1 may be needed to maintain alignment.
all_raw_ids = []
all_runtime_meta = []
for table_name in table_names:
if table_name in raw_ids_dict:
raw_ids_list = raw_ids_dict[table_name]
if table_name in indexed_lookups_dict:
raw_ids_list, runtime_meta_list = indexed_lookups_dict[table_name]
for raw_ids in raw_ids_list:
all_raw_ids.append(raw_ids)
for runtime_meta in runtime_meta_list:
all_runtime_meta.append(runtime_meta)

if not all_raw_ids:
return None
Expand All @@ -162,7 +165,12 @@ def _get_hash_zch_identities(
f"hash_zch_identities row count ({hash_zch_identities.size(0)}) must match "
f"features.values() length ({features.values().numel()}) to maintain 1-to-1 alignment"
)
return hash_zch_identities

if all_runtime_meta:
hash_zch_runtime_meta = torch.cat(all_runtime_meta)
return (hash_zch_identities, hash_zch_runtime_meta)
else:
return (hash_zch_identities, None)


def create_virtual_table_local_metadata(
Expand Down
20 changes: 20 additions & 0 deletions torchrec/distributed/mc_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def __init__(
torch.Tensor,
Optional[nn.Module],
Optional[torch.Tensor],
Optional[torch.Tensor],
],
None,
]
Expand Down Expand Up @@ -763,13 +764,22 @@ def compute(
"_hash_zch_identities",
):
if self.post_lookup_tracker_fn is not None:
runtime_meta = None
if (
hasattr(mcm, "_hash_zch_runtime_meta")
and mcm._hash_zch_runtime_meta is not None
):
runtime_meta = mcm._hash_zch_runtime_meta.index_select(
dim=0, index=mc_input[table].values()
)
self.post_lookup_tracker_fn(
KeyedJaggedTensor.from_jt_dict(mc_input),
torch.empty(0),
None,
mcm._hash_zch_identities.index_select(
dim=0, index=mc_input[table].values()
),
runtime_meta,
)
values = torch.cat([jt.values() for jt in output.values()])
else:
Expand All @@ -791,11 +801,20 @@ def compute(
values = mc_input[table].values()
if hasattr(mcm, "_hash_zch_identities"):
if self.post_lookup_tracker_fn is not None:
runtime_meta = None
if (
hasattr(mcm, "_hash_zch_runtime_meta")
and mcm._hash_zch_runtime_meta is not None
):
runtime_meta = mcm._hash_zch_runtime_meta.index_select(
dim=0, index=values
)
self.post_lookup_tracker_fn(
KeyedJaggedTensor.from_jt_dict(mc_input),
torch.empty(0),
None,
mcm._hash_zch_identities.index_select(dim=0, index=values),
runtime_meta,
)

remapped_kjts.append(
Expand Down Expand Up @@ -895,6 +914,7 @@ def register_post_lookup_tracker_fn(
torch.Tensor,
Optional[nn.Module],
Optional[torch.Tensor],
Optional[torch.Tensor],
],
None,
],
Expand Down
7 changes: 6 additions & 1 deletion torchrec/distributed/model_tracker/delta_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def append(
ids: torch.Tensor,
states: Optional[torch.Tensor] = None,
raw_ids: Optional[torch.Tensor] = None,
runtime_meta: Optional[torch.Tensor] = None,
) -> None:
"""
Append a batch of ids and states to the store for a specific table.
Expand Down Expand Up @@ -165,6 +166,7 @@ def append(
ids: torch.Tensor,
states: Optional[torch.Tensor] = None,
raw_ids: Optional[torch.Tensor] = None,
runtime_meta: Optional[torch.Tensor] = None,
) -> None:
table_fqn_lookup = self.per_fqn_lookups.get(fqn, [])
table_fqn_lookup.append(
Expand Down Expand Up @@ -284,10 +286,13 @@ def append(
ids: torch.Tensor,
states: Optional[torch.Tensor] = None,
raw_ids: Optional[torch.Tensor] = None,
runtime_meta: Optional[torch.Tensor] = None,
) -> None:
table_fqn_lookup = self.per_fqn_lookups.get(fqn, [])
table_fqn_lookup.append(
RawIndexedLookup(batch_idx=batch_idx, ids=ids, raw_ids=raw_ids)
RawIndexedLookup(
batch_idx=batch_idx, ids=ids, raw_ids=raw_ids, runtime_meta=runtime_meta
)
)
self.per_fqn_lookups[fqn] = table_fqn_lookup

Expand Down
46 changes: 38 additions & 8 deletions torchrec/distributed/model_tracker/trackers/raw_id_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,11 @@ def record_lookup(
states: torch.Tensor,
emb_module: Optional[nn.Module] = None,
raw_ids: Optional[torch.Tensor] = None,
runtime_meta: Optional[torch.Tensor] = None,
) -> None:
per_table_ids: Dict[str, List[torch.Tensor]] = {}
per_table_raw_ids: Dict[str, List[torch.Tensor]] = {}
per_table_runtime_meta: Dict[str, List[torch.Tensor]] = {}

# Skip storing invalid input or raw ids
if (
Expand All @@ -197,28 +199,50 @@ def record_lookup(
):
return

embeddings_2d = raw_ids.view(kjt.values().numel(), -1)
# Skip storing if runtime_meta is provided but its element count is not a multiple of kjt.values().numel()
if runtime_meta is not None and not (
runtime_meta.numel() % kjt.values().numel() == 0
):
return

raw_ids_2d = raw_ids.view(kjt.values().numel(), -1)
runtime_meta_2d = None
# runtime_meta is optional: raw_ids may be present without it, in which case only raw_ids are recorded
if runtime_meta is not None:
runtime_meta_2d = runtime_meta.view(kjt.values().numel(), -1)

offset: int = 0
for key in kjt.keys():
table_fqn = self.table_to_fqn[key]
ids_list: List[torch.Tensor] = per_table_ids.get(table_fqn, [])
emb_list: List[torch.Tensor] = per_table_raw_ids.get(table_fqn, [])
raw_ids_list: List[torch.Tensor] = per_table_raw_ids.get(table_fqn, [])
runtime_meta_list: List[torch.Tensor] = per_table_runtime_meta.get(
table_fqn, []
)

ids = kjt[key].values()
ids_list.append(ids)
emb_list.append(embeddings_2d[offset : offset + ids.numel()])
raw_ids_list.append(raw_ids_2d[offset : offset + ids.numel()])
if runtime_meta_2d is not None:
runtime_meta_list.append(runtime_meta_2d[offset : offset + ids.numel()])
offset += ids.numel()

per_table_ids[table_fqn] = ids_list
per_table_raw_ids[table_fqn] = emb_list
per_table_raw_ids[table_fqn] = raw_ids_list
if runtime_meta_2d is not None:
per_table_runtime_meta[table_fqn] = runtime_meta_list

for table_fqn, ids_list in per_table_ids.items():
self.store.append(
batch_idx=self.curr_batch_idx,
fqn=table_fqn,
ids=torch.cat(ids_list),
raw_ids=torch.cat(per_table_raw_ids[table_fqn]),
runtime_meta=(
torch.cat(per_table_runtime_meta[table_fqn])
if table_fqn in per_table_runtime_meta
else None
),
)

def _clean_fqn_fn(self, fqn: str) -> str:
Expand Down Expand Up @@ -277,8 +301,8 @@ def get_indexed_lookups(
self,
tables: List[str],
consumer: Optional[str] = None,
) -> Dict[str, List[torch.Tensor]]:
raw_id_per_table: Dict[str, List[torch.Tensor]] = {}
) -> Dict[str, Tuple[List[torch.Tensor], List[torch.Tensor]]]:
result: Dict[str, Tuple[List[torch.Tensor], List[torch.Tensor]]] = {}
consumer = consumer or self.DEFAULT_CONSUMER
assert (
consumer in self.per_consumer_batch_idx
Expand All @@ -293,17 +317,23 @@ def get_indexed_lookups(

for table in tables:
raw_ids_list = []
runtime_meta_list = []
fqn = self.table_to_fqn[table]
if fqn in indexed_lookups:
for indexed_lookup in indexed_lookups[fqn]:
if indexed_lookup.raw_ids is not None:
raw_ids_list.append(indexed_lookup.raw_ids)
raw_id_per_table[table] = raw_ids_list
if indexed_lookup.runtime_meta is not None:
runtime_meta_list.append(indexed_lookup.runtime_meta)
if (
raw_ids_list
): # if raw_ids doesn't exist runtime_meta will not exist so no need to check for runtime_meta
result[table] = (raw_ids_list, runtime_meta_list)

if self._delete_on_read:
self.store.delete(up_to_idx=min(self.per_consumer_batch_idx.values()))

return raw_id_per_table
return result

def delete(self, up_to_idx: Optional[int]) -> None:
self.store.delete(up_to_idx)
1 change: 1 addition & 0 deletions torchrec/distributed/model_tracker/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class RawIndexedLookup:
batch_idx: int
ids: torch.Tensor
raw_ids: Optional[torch.Tensor] = None
runtime_meta: Optional[torch.Tensor] = None


@dataclass
Expand Down
Loading