Skip to content

Commit 82aaaa8

Browse files
EddyLXJfacebook-github-bot
authored andcommitted
Using eviction policy to populate _enable_feature_score_weight_accumulation (#3354)
Summary: Pull Request resolved: #3354 As title, if kvzch table is enable feature score based eviction, we need to populate _enable_feature_score_weight_accumulation to true to collect feature score. Reviewed By: emlin Differential Revision: D81744008 fbshipit-source-id: 79b0b2e5dfb101387545e7d24c04ee133046431e
1 parent b811c5e commit 82aaaa8

File tree

2 files changed

+13
-16
lines changed

2 files changed

+13
-16
lines changed

torchrec/distributed/embedding.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
ShardingType,
4848
)
4949
from torchrec.distributed.fused_params import (
50-
ENABLE_FEATURE_SCORE_WEIGHT_ACCUMULATION,
5150
FUSED_PARAM_IS_SSD_TABLE,
5251
FUSED_PARAM_SSD_TABLE_LIST,
5352
)
@@ -91,6 +90,7 @@
9190
from torchrec.modules.embedding_configs import (
9291
EmbeddingConfig,
9392
EmbeddingTableConfig,
93+
FeatureScoreBasedEvictionPolicy,
9494
PoolingType,
9595
)
9696
from torchrec.modules.embedding_modules import (
@@ -422,18 +422,6 @@ def __init__(
422422
super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
423423
self._enable_feature_score_weight_accumulation: bool = False
424424

425-
if (
426-
fused_params is not None
427-
and ENABLE_FEATURE_SCORE_WEIGHT_ACCUMULATION in fused_params
428-
):
429-
self._enable_feature_score_weight_accumulation = cast(
430-
bool, fused_params[ENABLE_FEATURE_SCORE_WEIGHT_ACCUMULATION]
431-
)
432-
fused_params.pop(ENABLE_FEATURE_SCORE_WEIGHT_ACCUMULATION)
433-
logger.info(
434-
f"EC feature score weight accumulation enabled: {self._enable_feature_score_weight_accumulation}."
435-
)
436-
437425
self._module_fqn = module_fqn
438426
self._embedding_configs: List[EmbeddingConfig] = module.embedding_configs()
439427
self._table_names: List[str] = [
@@ -492,6 +480,18 @@ def __init__(
492480
self._has_uninitialized_input_dist: bool = True
493481
logger.info(f"EC index dedup enabled: {self._use_index_dedup}.")
494482

483+
for config in self._embedding_configs:
484+
virtual_table_eviction_policy = config.virtual_table_eviction_policy
485+
if virtual_table_eviction_policy is not None and isinstance(
486+
virtual_table_eviction_policy, FeatureScoreBasedEvictionPolicy
487+
):
488+
self._enable_feature_score_weight_accumulation = True
489+
break
490+
491+
logger.info(
492+
f"EC feature score weight accumulation enabled: {self._enable_feature_score_weight_accumulation}."
493+
)
494+
495495
# Get all fused optimizers and combine them.
496496
optims = []
497497
for lookup in self._lookups:

torchrec/distributed/fused_params.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,6 @@
3333
FUSED_PARAM_SSD_TABLE_LIST: str = "__register_ssd_table_list"
3434
# Bool fused param per table to check if the table is offloaded to SSD
3535
FUSED_PARAM_IS_SSD_TABLE: str = "__register_is_ssd_table"
36-
ENABLE_FEATURE_SCORE_WEIGHT_ACCUMULATION: str = (
37-
"enable_feature_score_weight_accumulation"
38-
)
3936

4037

4138
class TBEToRegisterMixIn:

0 commit comments

Comments
 (0)