Skip to content

Commit 93eae33

Browse files
jiayulufacebook-github-bot
authored andcommitted
torchrec changes v4 (#3348)
Summary: Pull Request resolved: #3348 add utility functions to re-initialize torch states of ShardedEmbeddingBag classes Reviewed By: liangbeixu Differential Revision: D81154653 fbshipit-source-id: 80e90575f66a01d127558420842686e45669a3ff
1 parent 6cd43b2 commit 93eae33

File tree

4 files changed

+272
-221
lines changed

4 files changed

+272
-221
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2310,15 +2310,22 @@ def __init__(
23102310
self._emb_module,
23112311
pg,
23122312
)
2313-
self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = dict(
2313+
self.init_parameters()
2314+
2315+
@property
2316+
def _param_per_table(self) -> Dict[str, TableBatchedEmbeddingSlice]:
2317+
return dict(
23142318
_gen_named_parameters_by_table_fused(
23152319
emb_module=self._emb_module,
23162320
table_name_to_count=self.table_name_to_count.copy(),
23172321
config=self._config,
2318-
pg=pg,
2322+
pg=self._pg,
23192323
)
23202324
)
2321-
self.init_parameters()
2325+
2326+
@_param_per_table.setter
2327+
def _param_per_table(self, v: Dict[str, TableBatchedEmbeddingSlice]) -> None:
2328+
self.__dict__["_param_per_table"] = v
23222329

23232330
@property
23242331
def emb_module(
@@ -3169,15 +3176,22 @@ def __init__(
31693176
self._emb_module,
31703177
pg,
31713178
)
3172-
self._param_per_table: Dict[str, TableBatchedEmbeddingSlice] = dict(
3179+
self.init_parameters()
3180+
3181+
@property
3182+
def _param_per_table(self) -> Dict[str, TableBatchedEmbeddingSlice]:
3183+
return dict(
31733184
_gen_named_parameters_by_table_fused(
31743185
emb_module=self._emb_module,
31753186
table_name_to_count=self.table_name_to_count.copy(),
31763187
config=self._config,
3177-
pg=pg,
3188+
pg=self._pg,
31783189
)
31793190
)
3180-
self.init_parameters()
3191+
3192+
@_param_per_table.setter
3193+
def _param_per_table(self, v: Dict[str, TableBatchedEmbeddingSlice]) -> None:
3194+
self.__dict__["_param_per_table"] = v
31813195

31823196
@property
31833197
def emb_module(

torchrec/distributed/fp_embeddingbag.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,9 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
163163
if "_embedding_bag_collection" in fqn:
164164
yield append_prefix(prefix, fqn)
165165

166+
def _initialize_torch_state(self, skip_registering: bool = False) -> None: # noqa
167+
self._embedding_bag_collection._initialize_torch_state(skip_registering)
168+
166169

167170
class FeatureProcessedEmbeddingBagCollectionSharder(
168171
BaseEmbeddingSharder[FeatureProcessedEmbeddingBagCollection]

0 commit comments

Comments
 (0)