From aea4525031f41cda977d6e24e26dfa29b971d54d Mon Sep 17 00:00:00 2001
From: Fangyu Luo
Date: Sun, 23 Oct 2022 21:46:27 -0700
Subject: [PATCH] Authoring Aware ROO on LTV ReAgent model

Summary: [Free] Authoring Aware ROO on LTV ReAgent model

Differential Revision: D40456561

fbshipit-source-id: 5f04033b615b0aa6c4ae43cd4d3b6f9743919933
---
Review note: in SparseDQN.__init__, sparse_embedding_dim_ro is now derived
from embedding_bag_collection_ro (both the config iteration and the None
guard); it previously reused embedding_bag_collection, which mis-sizes
input_dim whenever the two collections differ. Only the text of two added
lines changed, so hunk counts and the diffstat are unaffected. The blob ids
on the sparse_dqn.py "index" line were not regenerated.

 reagent/core/types.py                      |   2 +
 reagent/models/sparse_dqn.py               | 103 +++++++++++++++------
 reagent/test/models/test_sparse_dqn_net.py |  20 +++-
 3 files changed, 97 insertions(+), 28 deletions(-)

diff --git a/reagent/core/types.py b/reagent/core/types.py
index deaf168a..f05d50e0 100644
--- a/reagent/core/types.py
+++ b/reagent/core/types.py
@@ -313,6 +313,7 @@ class FeatureData(TensorDataClass):
     float_features: torch.Tensor
     # For sparse features saved in KeyedJaggedTensor format
     id_list_features: Optional[KeyedJaggedTensor] = None
+    id_list_features_ro: Optional[KeyedJaggedTensor] = None
     id_score_list_features: Optional[KeyedJaggedTensor] = None
 
     # For sparse features saved in dictionary format
@@ -339,6 +340,7 @@ def __post_init__(self):
     def has_float_features_only(self) -> bool:
         return (
             not self.id_list_features
+            and not self.id_list_features_ro
             and not self.id_score_list_features
             and self.time_since_first is None
             and self.candidate_docs is None
diff --git a/reagent/models/sparse_dqn.py b/reagent/models/sparse_dqn.py
index 249fa874..bf6f7d46 100644
--- a/reagent/models/sparse_dqn.py
+++ b/reagent/models/sparse_dqn.py
@@ -1,13 +1,13 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
-from typing import List
+from typing import List, Optional, Tuple
 
 import torch
 from reagent.core import types as rlt
 from reagent.models import FullyConnectedNetwork
 from reagent.models.base import ModelBase
-from torchrec.models.dlrm import SparseArch
+from torchrec.models.dlrm import SparseArchRO
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
@@ -16,20 +16,39 @@
 @torch.fx.wrap
 def fetch_id_list_features(
     state: rlt.FeatureData, action: rlt.FeatureData
-) -> KeyedJaggedTensor:
-    assert state.id_list_features is not None or action.id_list_features is not None
-    if state.id_list_features is not None and action.id_list_features is None:
-        sparse_features = state.id_list_features
-    elif state.id_list_features is None and action.id_list_features is not None:
-        sparse_features = action.id_list_features
-    elif state.id_list_features is not None and action.id_list_features is not None:
-        sparse_features = KeyedJaggedTensor.concat(
-            [state.id_list_features, action.id_list_features]
-        )
-    else:
+) -> Tuple[Optional[KeyedJaggedTensor], Optional[KeyedJaggedTensor]]:
+    assert (
+        state.id_list_features is not None
+        or state.id_list_features_ro is not None
+        or action.id_list_features is not None
+        or action.id_list_features_ro is not None
+    )
+
+    def _get_sparse_features(
+        id_list_features_1, id_list_features_2
+    ) -> Optional[KeyedJaggedTensor]:
+        sparse_features = None
+        if id_list_features_1 is not None and id_list_features_2 is None:
+            sparse_features = id_list_features_1
+        elif id_list_features_1 is None and id_list_features_2 is not None:
+            sparse_features = id_list_features_2
+        elif id_list_features_1 is not None and id_list_features_2 is not None:
+            sparse_features = KeyedJaggedTensor.concat(
+                [id_list_features_1, id_list_features_2]
+            )
+        return sparse_features
+
+    sparse_features = _get_sparse_features(
+        state.id_list_features, action.id_list_features
+    )
+    sparse_features_ro = _get_sparse_features(
+        state.id_list_features_ro, action.id_list_features_ro
+    )
+    if sparse_features is None and sparse_features_ro is None:
         raise ValueError
+
     # TODO: add id_list_score_features
-    return sparse_features
+    return sparse_features, sparse_features_ro
 
 
 class SparseDQN(ModelBase):
@@ -41,7 +60,8 @@ class SparseDQN(ModelBase):
     def __init__(
         self,
         state_dense_dim: int,
-        embedding_bag_collection: EmbeddingBagCollection,
+        embedding_bag_collection: Optional[EmbeddingBagCollection],
+        embedding_bag_collection_ro: Optional[EmbeddingBagCollection],
         action_dense_dim: int,
         overarch_dims: List[int],
         activation: str = "relu",
@@ -51,17 +71,37 @@ def __init__(
         output_dim: int = 1,
     ) -> None:
         super().__init__()
-        self.sparse_arch: SparseArch = SparseArch(embedding_bag_collection)
+        self.sparse_arch: SparseArchRO = SparseArchRO(
+            embedding_bag_collection, embedding_bag_collection_ro
+        )
+
+        self.sparse_embedding_dim: int = (
+            sum(
+                [
+                    len(embc.feature_names) * embc.embedding_dim
+                    for embc in embedding_bag_collection.embedding_bag_configs()
+                ]
+            )
+            if embedding_bag_collection is not None
+            else 0
+        )
 
-        self.sparse_embedding_dim: int = sum(
-            [
-                len(embc.feature_names) * embc.embedding_dim
-                for embc in embedding_bag_collection.embedding_bag_configs()
-            ]
+        self.sparse_embedding_dim_ro: int = (
+            sum(
+                [
+                    len(embc.feature_names) * embc.embedding_dim
+                    for embc in embedding_bag_collection_ro.embedding_bag_configs()
+                ]
+            )
+            if embedding_bag_collection_ro is not None
+            else 0
         )
 
         self.input_dim: int = (
-            state_dense_dim + self.sparse_embedding_dim + action_dense_dim
+            state_dense_dim
+            + self.sparse_embedding_dim
+            + self.sparse_embedding_dim_ro
+            + action_dense_dim
         )
         layers = [self.input_dim] + overarch_dims + [output_dim]
         activations = [activation] * len(overarch_dims) + [final_activation]
@@ -76,11 +116,20 @@ def forward(self, state: rlt.FeatureData, action: rlt.FeatureData) -> torch.Tens
             (state.float_features, action.float_features), dim=-1
         )
         batch_size = dense_features.shape[0]
-        sparse_features = fetch_id_list_features(state, action)
+        sparse_features, sparse_features_ro = fetch_id_list_features(state, action)
         # shape: batch_size, num_sparse_features, embedding_dim
-        embedded_sparse = self.sparse_arch(sparse_features)
-        # shape: batch_size, num_sparse_features * embedding_dim
-        embedded_sparse = embedded_sparse.reshape(batch_size, -1)
-        concatenated_dense = torch.cat((dense_features, embedded_sparse), dim=-1)
+        embedded_sparse, embedded_sparse_ro = self.sparse_arch(
+            sparse_features, sparse_features_ro
+        )
+        features_list: List[torch.Tensor] = [dense_features]
+        if embedded_sparse is not None:
+            # shape: batch_size, num_sparse_features * embedding_dim
+            embedded_sparse = embedded_sparse.reshape(batch_size, -1)
+            features_list.append(embedded_sparse)
+        if embedded_sparse_ro is not None:
+            # shape: batch_size, num_sparse_features * embedding_dim
+            embedded_sparse_ro = embedded_sparse_ro.reshape(batch_size, -1)
+            features_list.append(embedded_sparse_ro)
+        concatenated_dense = torch.cat(features_list, dim=-1)
 
         return self.q_network(concatenated_dense)
diff --git a/reagent/test/models/test_sparse_dqn_net.py b/reagent/test/models/test_sparse_dqn_net.py
index 896cf4b1..5da2cf6d 100644
--- a/reagent/test/models/test_sparse_dqn_net.py
+++ b/reagent/test/models/test_sparse_dqn_net.py
@@ -16,6 +16,7 @@ def test_single_step_sparse_dqn(self):
         embedding_table_size = 1000
         embedding_dim = 32
         num_sparse_features = 2  # refer to watched_ids and liked_ids below
+
         embedding_bag_configs = [
             EmbeddingBagConfig(
                 name="video_id",
@@ -24,13 +25,27 @@ def test_single_step_sparse_dqn(self):
                 embedding_dim=embedding_dim,
             )
         ]
+
+        num_sparse_features_ro = 2  # refer to watched_page_ids and liked_ids below
+        embedding_bag_configs_ro = [
+            EmbeddingBagConfig(
+                name="watched_page_ids",
+                feature_names=["watched_page_ids", "liked_ids"],
+                num_embeddings=embedding_table_size,
+                embedding_dim=embedding_dim,
+            )
+        ]
         embedding_bag_col = EmbeddingBagCollection(
             device=torch.device("cpu"), tables=embedding_bag_configs
         )
+        embedding_bag_col_ro = EmbeddingBagCollection(
+            device=torch.device("cpu"), tables=embedding_bag_configs_ro
+        )
 
         net = SparseDQN(
             state_dense_dim=state_dense_dim,
             embedding_bag_collection=embedding_bag_col,
+            embedding_bag_collection_ro=embedding_bag_col_ro,
             action_dense_dim=action_dense_dim,
             overarch_dims=dense_sizes,
             activation=activation,
@@ -42,7 +57,10 @@ def test_single_step_sparse_dqn(self):
         # number of sparse features times embedding dimension for sparse features
         assert (
             net[1].in_features
-            == state_dense_dim + action_dense_dim + num_sparse_features * embedding_dim
+            == state_dense_dim
+            + action_dense_dim
+            + num_sparse_features * embedding_dim
+            + num_sparse_features_ro * embedding_dim
         )
         assert net[1].out_features == dense_sizes[0]
         assert net[4].in_features == dense_sizes[0]