7 changes: 6 additions & 1 deletion torchrec/distributed/model_tracker/delta_store.py
@@ -93,6 +93,7 @@ def append(
ids: torch.Tensor,
states: Optional[torch.Tensor] = None,
raw_ids: Optional[torch.Tensor] = None,
runtime_meta: Optional[torch.Tensor] = None,
) -> None:
"""
Append a batch of ids and states to the store for a specific table.
@@ -165,6 +166,7 @@ def append(
ids: torch.Tensor,
states: Optional[torch.Tensor] = None,
raw_ids: Optional[torch.Tensor] = None,
runtime_meta: Optional[torch.Tensor] = None,
) -> None:
table_fqn_lookup = self.per_fqn_lookups.get(fqn, [])
table_fqn_lookup.append(
@@ -284,10 +286,13 @@ def append(
ids: torch.Tensor,
states: Optional[torch.Tensor] = None,
raw_ids: Optional[torch.Tensor] = None,
runtime_meta: Optional[torch.Tensor] = None,
) -> None:
table_fqn_lookup = self.per_fqn_lookups.get(fqn, [])
table_fqn_lookup.append(
RawIndexedLookup(batch_idx=batch_idx, ids=ids, raw_ids=raw_ids)
RawIndexedLookup(
batch_idx=batch_idx, ids=ids, raw_ids=raw_ids, runtime_meta=runtime_meta
)
)
self.per_fqn_lookups[fqn] = table_fqn_lookup

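For orientation, here is a minimal, self-contained sketch of the per-FQN append pattern these `append` overloads share, now carrying the optional `runtime_meta` tensor alongside `raw_ids`. The `_Store`/`_Lookup` stand-ins, the FQN string, and the tensor values are illustrative assumptions, not the real DeltaStore API:

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

import torch


@dataclass
class _Lookup:  # illustrative stand-in for RawIndexedLookup
    batch_idx: int
    ids: torch.Tensor
    raw_ids: Optional[torch.Tensor] = None
    runtime_meta: Optional[torch.Tensor] = None


class _Store:  # illustrative stand-in, not the real DeltaStore
    def __init__(self) -> None:
        self.per_fqn_lookups: Dict[str, List[_Lookup]] = {}

    def append(
        self,
        batch_idx: int,
        fqn: str,
        ids: torch.Tensor,
        states: Optional[torch.Tensor] = None,
        raw_ids: Optional[torch.Tensor] = None,
        runtime_meta: Optional[torch.Tensor] = None,
    ) -> None:
        # runtime_meta is stored next to raw_ids for the same batch of ids.
        self.per_fqn_lookups.setdefault(fqn, []).append(
            _Lookup(batch_idx=batch_idx, ids=ids, raw_ids=raw_ids, runtime_meta=runtime_meta)
        )


store = _Store()
store.append(
    batch_idx=0,
    fqn="embedding_bags.table_0",  # hypothetical FQN
    ids=torch.tensor([3, 7, 11]),
    raw_ids=torch.tensor([[1001], [1002], [1003]]),
    runtime_meta=torch.tensor([[5], [6], [7]]),  # optional; may be None
)
```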
66 changes: 53 additions & 13 deletions torchrec/distributed/model_tracker/trackers/raw_id_tracker.py
@@ -185,40 +185,74 @@ def record_lookup(
states: torch.Tensor,
emb_module: Optional[nn.Module] = None,
raw_ids: Optional[torch.Tensor] = None,
runtime_meta: Optional[torch.Tensor] = None,
) -> None:
per_table_ids: Dict[str, List[torch.Tensor]] = {}
per_table_raw_ids: Dict[str, List[torch.Tensor]] = {}
per_table_runtime_meta: Dict[str, List[torch.Tensor]] = {}

# Skip storing invalid input or raw ids
if (
raw_ids is None
or (kjt.values().numel() == 0)
or not (raw_ids.numel() % kjt.values().numel() == 0)
# Skip storing invalid input or raw ids. runtime_meta only exists when raw_ids does, so returning early here is safe.
if raw_ids is None:
logger.debug("Skipping record_lookup: raw_ids is None")
return

if kjt.values().numel() == 0:
logger.debug("Skipping record_lookup: kjt.values() is empty")
return

if raw_ids.numel() % kjt.values().numel() != 0:
logger.warning(
f"Skipping record_lookup: raw_ids has shape {raw_ids.shape}, its numel must be a multiple of {kjt.values().numel()}"
)
return

# Skip storing if runtime_meta is provided but has invalid shape
if runtime_meta is not None and (
runtime_meta.numel() % kjt.values().numel() != 0
):
logger.warning(
f"Skipping record_lookup: runtime_meta has shape {runtime_meta.shape}, its numel must be a multiple of {kjt.values().numel()}"
)
return

embeddings_2d = raw_ids.view(kjt.values().numel(), -1)
raw_ids_2d = raw_ids.view(kjt.values().numel(), -1)
runtime_meta_2d = None
# runtime_meta may be None even when raw_ids is not, so proceed either way
if runtime_meta is not None:
runtime_meta_2d = runtime_meta.view(kjt.values().numel(), -1)

offset: int = 0
for key in kjt.keys():
table_fqn = self.table_to_fqn[key]
ids_list: List[torch.Tensor] = per_table_ids.get(table_fqn, [])
emb_list: List[torch.Tensor] = per_table_raw_ids.get(table_fqn, [])
raw_ids_list: List[torch.Tensor] = per_table_raw_ids.get(table_fqn, [])
runtime_meta_list: List[torch.Tensor] = per_table_runtime_meta.get(
table_fqn, []
)

ids = kjt[key].values()
ids_list.append(ids)
emb_list.append(embeddings_2d[offset : offset + ids.numel()])
raw_ids_list.append(raw_ids_2d[offset : offset + ids.numel()])
if runtime_meta_2d is not None:
runtime_meta_list.append(runtime_meta_2d[offset : offset + ids.numel()])
offset += ids.numel()

per_table_ids[table_fqn] = ids_list
per_table_raw_ids[table_fqn] = emb_list
per_table_raw_ids[table_fqn] = raw_ids_list
if runtime_meta_2d is not None:
per_table_runtime_meta[table_fqn] = runtime_meta_list

for table_fqn, ids_list in per_table_ids.items():
self.store.append(
batch_idx=self.curr_batch_idx,
fqn=table_fqn,
ids=torch.cat(ids_list),
raw_ids=torch.cat(per_table_raw_ids[table_fqn]),
runtime_meta=(
torch.cat(per_table_runtime_meta[table_fqn])
if table_fqn in per_table_runtime_meta
else None
),
)

def _clean_fqn_fn(self, fqn: str) -> str:
@@ -277,8 +311,8 @@ def get_indexed_lookups(
self,
tables: List[str],
consumer: Optional[str] = None,
) -> Dict[str, List[torch.Tensor]]:
raw_id_per_table: Dict[str, List[torch.Tensor]] = {}
) -> Dict[str, Tuple[List[torch.Tensor], List[torch.Tensor]]]:
result: Dict[str, Tuple[List[torch.Tensor], List[torch.Tensor]]] = {}
consumer = consumer or self.DEFAULT_CONSUMER
assert (
consumer in self.per_consumer_batch_idx
@@ -293,17 +327,23 @@

for table in tables:
raw_ids_list = []
runtime_meta_list = []
fqn = self.table_to_fqn[table]
if fqn in indexed_lookups:
for indexed_lookup in indexed_lookups[fqn]:
if indexed_lookup.raw_ids is not None:
raw_ids_list.append(indexed_lookup.raw_ids)
raw_id_per_table[table] = raw_ids_list
if indexed_lookup.runtime_meta is not None:
runtime_meta_list.append(indexed_lookup.runtime_meta)
if (
raw_ids_list
):  # if raw_ids is absent, runtime_meta is absent too, so only raw_ids_list needs checking
result[table] = (raw_ids_list, runtime_meta_list)

if self._delete_on_read:
self.store.delete(up_to_idx=min(self.per_consumer_batch_idx.values()))

return raw_id_per_table
return result

def delete(self, up_to_idx: Optional[int]) -> None:
self.store.delete(up_to_idx)
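As a rough illustration of the shape contract `record_lookup` now enforces, and of how a caller might unpack the new `get_indexed_lookups` return type, here is a hedged sketch; the table name, tensor sizes, and the hand-built `result` dict are assumptions for illustration only:

```python
from typing import Dict, List, Tuple

import torch

# raw_ids (and runtime_meta, when provided) must have a numel that is a
# multiple of kjt.values().numel(); both are reshaped to 2-D and sliced per key.
n_ids = 4                                        # stands in for kjt.values().numel()
raw_ids = torch.arange(8)                        # 8 is a multiple of 4 -> accepted
runtime_meta = torch.arange(4)                   # 4 is a multiple of 4 -> accepted
raw_ids_2d = raw_ids.view(n_ids, -1)             # shape (4, 2)
runtime_meta_2d = runtime_meta.view(n_ids, -1)   # shape (4, 1)

# get_indexed_lookups() now returns, per table, a tuple of (raw_ids list,
# runtime_meta list); the meta list may be empty when no meta was recorded.
result: Dict[str, Tuple[List[torch.Tensor], List[torch.Tensor]]] = {
    "table_0": ([raw_ids_2d], [runtime_meta_2d]),
}
for table, (raw_ids_list, runtime_meta_list) in result.items():
    all_raw_ids = torch.cat(raw_ids_list)
    all_meta = torch.cat(runtime_meta_list) if runtime_meta_list else None
```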
1 change: 1 addition & 0 deletions torchrec/distributed/model_tracker/types.py
@@ -35,6 +35,7 @@ class RawIndexedLookup:
batch_idx: int
ids: torch.Tensor
raw_ids: Optional[torch.Tensor] = None
runtime_meta: Optional[torch.Tensor] = None


@dataclass
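A short usage sketch of the updated dataclass, assuming this PR's field addition is applied; the tensor values are illustrative:

```python
import torch
from torchrec.distributed.model_tracker.types import RawIndexedLookup

# Each stored lookup can now optionally carry runtime_meta alongside raw_ids;
# both fields default to None.
lookup = RawIndexedLookup(
    batch_idx=0,
    ids=torch.tensor([3, 7, 11]),
    raw_ids=torch.tensor([[1001], [1002], [1003]]),
    runtime_meta=torch.tensor([[5], [6], [7]]),
)
```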