diff --git a/torchrec/distributed/model_tracker/delta_store.py b/torchrec/distributed/model_tracker/delta_store.py
index d067c055d..29f621d18 100644
--- a/torchrec/distributed/model_tracker/delta_store.py
+++ b/torchrec/distributed/model_tracker/delta_store.py
@@ -93,6 +93,7 @@ def append(
         ids: torch.Tensor,
         states: Optional[torch.Tensor] = None,
         raw_ids: Optional[torch.Tensor] = None,
+        runtime_meta: Optional[torch.Tensor] = None,
     ) -> None:
         """
         Append a batch of ids and states to the store for a specific table.
@@ -165,6 +166,7 @@ def append(
         ids: torch.Tensor,
         states: Optional[torch.Tensor] = None,
         raw_ids: Optional[torch.Tensor] = None,
+        runtime_meta: Optional[torch.Tensor] = None,
     ) -> None:
         table_fqn_lookup = self.per_fqn_lookups.get(fqn, [])
         table_fqn_lookup.append(
@@ -284,10 +286,13 @@ def append(
         ids: torch.Tensor,
         states: Optional[torch.Tensor] = None,
         raw_ids: Optional[torch.Tensor] = None,
+        runtime_meta: Optional[torch.Tensor] = None,
     ) -> None:
         table_fqn_lookup = self.per_fqn_lookups.get(fqn, [])
         table_fqn_lookup.append(
-            RawIndexedLookup(batch_idx=batch_idx, ids=ids, raw_ids=raw_ids)
+            RawIndexedLookup(
+                batch_idx=batch_idx, ids=ids, raw_ids=raw_ids, runtime_meta=runtime_meta
+            )
         )
         self.per_fqn_lookups[fqn] = table_fqn_lookup
 
diff --git a/torchrec/distributed/model_tracker/trackers/raw_id_tracker.py b/torchrec/distributed/model_tracker/trackers/raw_id_tracker.py
index 0db0e6b50..d277741e3 100644
--- a/torchrec/distributed/model_tracker/trackers/raw_id_tracker.py
+++ b/torchrec/distributed/model_tracker/trackers/raw_id_tracker.py
@@ -185,33 +185,62 @@ def record_lookup(
         states: torch.Tensor,
         emb_module: Optional[nn.Module] = None,
         raw_ids: Optional[torch.Tensor] = None,
+        runtime_meta: Optional[torch.Tensor] = None,
     ) -> None:
         per_table_ids: Dict[str, List[torch.Tensor]] = {}
         per_table_raw_ids: Dict[str, List[torch.Tensor]] = {}
+        per_table_runtime_meta: Dict[str, List[torch.Tensor]] = {}
 
-        # Skip storing invalid input or raw ids
-        if (
-            raw_ids is None
-            or (kjt.values().numel() == 0)
-            or not (raw_ids.numel() % kjt.values().numel() == 0)
+        # Skip storing invalid input or raw ids; runtime_meta only exists if raw_ids exists, so we can return early
+        if raw_ids is None:
+            logger.debug("Skipping record_lookup: raw_ids is None")
+            return
+
+        if kjt.values().numel() == 0:
+            logger.debug("Skipping record_lookup: kjt.values() is empty")
+            return
+
+        if not (raw_ids.numel() % kjt.values().numel() == 0):
+            logger.warning(
+                f"Skipping record_lookup. raw_ids has invalid shape {raw_ids.shape}, expected a multiple of {kjt.values().numel()}"
+            )
+            return
+
+        # Skip storing if runtime_meta is provided but has an invalid shape
+        if runtime_meta is not None and not (
+            runtime_meta.numel() % kjt.values().numel() == 0
         ):
+            logger.warning(
+                f"Skipping record_lookup. runtime_meta has invalid shape {runtime_meta.shape}, expected a multiple of {kjt.values().numel()}"
+            )
             return
 
-        embeddings_2d = raw_ids.view(kjt.values().numel(), -1)
+        raw_ids_2d = raw_ids.view(kjt.values().numel(), -1)
+        runtime_meta_2d = None
+        # runtime_meta may be None even when raw_ids is not, so proceed either way
+        if runtime_meta is not None:
+            runtime_meta_2d = runtime_meta.view(kjt.values().numel(), -1)
 
         offset: int = 0
         for key in kjt.keys():
             table_fqn = self.table_to_fqn[key]
             ids_list: List[torch.Tensor] = per_table_ids.get(table_fqn, [])
-            emb_list: List[torch.Tensor] = per_table_raw_ids.get(table_fqn, [])
+            raw_ids_list: List[torch.Tensor] = per_table_raw_ids.get(table_fqn, [])
+            runtime_meta_list: List[torch.Tensor] = per_table_runtime_meta.get(
+                table_fqn, []
+            )
             ids = kjt[key].values()
             ids_list.append(ids)
-            emb_list.append(embeddings_2d[offset : offset + ids.numel()])
+            raw_ids_list.append(raw_ids_2d[offset : offset + ids.numel()])
+            if runtime_meta_2d is not None:
+                runtime_meta_list.append(runtime_meta_2d[offset : offset + ids.numel()])
             offset += ids.numel()
             per_table_ids[table_fqn] = ids_list
-            per_table_raw_ids[table_fqn] = emb_list
+            per_table_raw_ids[table_fqn] = raw_ids_list
+            if runtime_meta_2d is not None:
+                per_table_runtime_meta[table_fqn] = runtime_meta_list
 
         for table_fqn, ids_list in per_table_ids.items():
             self.store.append(
@@ -219,6 +248,11 @@
                 fqn=table_fqn,
                 ids=torch.cat(ids_list),
                 raw_ids=torch.cat(per_table_raw_ids[table_fqn]),
+                runtime_meta=(
+                    torch.cat(per_table_runtime_meta[table_fqn])
+                    if table_fqn in per_table_runtime_meta
+                    else None
+                ),
             )
 
     def _clean_fqn_fn(self, fqn: str) -> str:
@@ -277,8 +311,8 @@ def get_indexed_lookups(
         self,
         tables: List[str],
         consumer: Optional[str] = None,
-    ) -> Dict[str, List[torch.Tensor]]:
-        raw_id_per_table: Dict[str, List[torch.Tensor]] = {}
+    ) -> Dict[str, Tuple[List[torch.Tensor], List[torch.Tensor]]]:
+        result: Dict[str, Tuple[List[torch.Tensor], List[torch.Tensor]]] = {}
         consumer = consumer or self.DEFAULT_CONSUMER
         assert (
             consumer in self.per_consumer_batch_idx
@@ -293,17 +327,23 @@
 
         for table in tables:
             raw_ids_list = []
+            runtime_meta_list = []
             fqn = self.table_to_fqn[table]
             if fqn in indexed_lookups:
                 for indexed_lookup in indexed_lookups[fqn]:
                     if indexed_lookup.raw_ids is not None:
                         raw_ids_list.append(indexed_lookup.raw_ids)
-            raw_id_per_table[table] = raw_ids_list
+                    if indexed_lookup.runtime_meta is not None:
+                        runtime_meta_list.append(indexed_lookup.runtime_meta)
+            if (
+                raw_ids_list
+            ):  # if raw_ids doesn't exist, runtime_meta will not exist either, so only raw_ids_list needs checking
+                result[table] = (raw_ids_list, runtime_meta_list)
 
         if self._delete_on_read:
             self.store.delete(up_to_idx=min(self.per_consumer_batch_idx.values()))
 
-        return raw_id_per_table
+        return result
 
     def delete(self, up_to_idx: Optional[int]) -> None:
         self.store.delete(up_to_idx)
diff --git a/torchrec/distributed/model_tracker/types.py b/torchrec/distributed/model_tracker/types.py
index 15988e872..d926b64a2 100644
--- a/torchrec/distributed/model_tracker/types.py
+++ b/torchrec/distributed/model_tracker/types.py
@@ -35,6 +35,7 @@ class RawIndexedLookup:
     batch_idx: int
     ids: torch.Tensor
     raw_ids: Optional[torch.Tensor] = None
+    runtime_meta: Optional[torch.Tensor] = None
 
 
 @dataclass
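
Note for reviewers: a minimal consumer-side sketch of the changed get_indexed_lookups contract, which now maps each table to a (raw_ids_list, runtime_meta_list) tuple instead of a bare List[torch.Tensor]. The tracker instance, table name, and consumer name below are hypothetical and not part of this diff; only the return shape reflects the change.

import torch

# `tracker` is assumed to be an already-constructed RawIdTracker attached to a sharded model.
lookups = tracker.get_indexed_lookups(tables=["table_0"], consumer="my_consumer")
for table, (raw_ids_list, runtime_meta_list) in lookups.items():
    raw_ids = torch.cat(raw_ids_list)  # raw ids recorded since the last read
    # runtime_meta_list may be empty when no runtime_meta was recorded for this table
    runtime_meta = torch.cat(runtime_meta_list) if runtime_meta_list else None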