Skip to content

Commit

Permalink
Authoring Aware ROO on LTV ReAgent model
Browse files Browse the repository at this point in the history
Summary: [Free] Authoring Aware ROO on LTV ReAgent model

Differential Revision: D40456561

fbshipit-source-id: 5f04033b615b0aa6c4ae43cd4d3b6f9743919933
  • Loading branch information
Fangyu Luo authored and facebook-github-bot committed Oct 24, 2022
1 parent 30715aa commit aea4525
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 28 deletions.
2 changes: 2 additions & 0 deletions reagent/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ class FeatureData(TensorDataClass):
float_features: torch.Tensor
# For sparse features saved in KeyedJaggedTensor format
id_list_features: Optional[KeyedJaggedTensor] = None
id_list_features_ro: Optional[KeyedJaggedTensor] = None
id_score_list_features: Optional[KeyedJaggedTensor] = None

# For sparse features saved in dictionary format
Expand All @@ -339,6 +340,7 @@ def __post_init__(self):
def has_float_features_only(self) -> bool:
return (
not self.id_list_features
and not self.id_list_features_ro
and not self.id_score_list_features
and self.time_since_first is None
and self.candidate_docs is None
Expand Down
103 changes: 76 additions & 27 deletions reagent/models/sparse_dqn.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from typing import List
from typing import List, Optional, Tuple

import torch
from reagent.core import types as rlt
from reagent.models import FullyConnectedNetwork
from reagent.models.base import ModelBase
from torchrec.models.dlrm import SparseArch
from torchrec.models.dlrm import SparseArchRO
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

Expand All @@ -16,20 +16,39 @@
@torch.fx.wrap
def fetch_id_list_features(
state: rlt.FeatureData, action: rlt.FeatureData
) -> KeyedJaggedTensor:
assert state.id_list_features is not None or action.id_list_features is not None
if state.id_list_features is not None and action.id_list_features is None:
sparse_features = state.id_list_features
elif state.id_list_features is None and action.id_list_features is not None:
sparse_features = action.id_list_features
elif state.id_list_features is not None and action.id_list_features is not None:
sparse_features = KeyedJaggedTensor.concat(
[state.id_list_features, action.id_list_features]
)
else:
) -> Tuple[Optional[KeyedJaggedTensor], Optional[KeyedJaggedTensor]]:
assert (
state.id_list_features is not None
or state.id_list_features_ro is not None
or action.id_list_features is not None
or action.id_list_features_ro is not None
)

def _get_sparse_features(
id_list_features_1, id_list_features_2
) -> Optional[KeyedJaggedTensor]:
sparse_features = None
if id_list_features_1 is not None and id_list_features_2 is None:
sparse_features = id_list_features_1
elif id_list_features_1 is None and id_list_features_2 is not None:
sparse_features = id_list_features_2
elif id_list_features_1 is not None and id_list_features_2 is not None:
sparse_features = KeyedJaggedTensor.concat(
[id_list_features_1, id_list_features_2]
)
return sparse_features

sparse_features = _get_sparse_features(
state.id_list_features, action.id_list_features
)
sparse_features_ro = _get_sparse_features(
state.id_list_features_ro, action.id_list_features_ro
)
if sparse_features is None and sparse_features_ro is None:
raise ValueError

# TODO: add id_list_score_features
return sparse_features
return sparse_features, sparse_features_ro


class SparseDQN(ModelBase):
Expand All @@ -41,7 +60,8 @@ class SparseDQN(ModelBase):
def __init__(
self,
state_dense_dim: int,
embedding_bag_collection: EmbeddingBagCollection,
embedding_bag_collection: Optional[EmbeddingBagCollection],
embedding_bag_collection_ro: Optional[EmbeddingBagCollection],
action_dense_dim: int,
overarch_dims: List[int],
activation: str = "relu",
Expand All @@ -51,17 +71,37 @@ def __init__(
output_dim: int = 1,
) -> None:
super().__init__()
self.sparse_arch: SparseArch = SparseArch(embedding_bag_collection)
self.sparse_arch: SparseArchRO = SparseArchRO(
embedding_bag_collection, embedding_bag_collection_ro
)

self.sparse_embedding_dim: int = (
sum(
[
len(embc.feature_names) * embc.embedding_dim
for embc in embedding_bag_collection.embedding_bag_configs()
]
)
if embedding_bag_collection is not None
else 0
)

self.sparse_embedding_dim: int = sum(
[
len(embc.feature_names) * embc.embedding_dim
for embc in embedding_bag_collection.embedding_bag_configs()
]
# Flattened width of the read-only (RO) sparse embeddings: for each table
# config in the RO collection, (number of features) * (embedding dim),
# summed over all configs. Falls back to 0 when no RO collection is given.
# NOTE: this must read `embedding_bag_collection_ro`. Reading the non-RO
# collection here (as the original did) double-counts the non-RO width when
# the RO collection is absent and drops the RO width when the non-RO
# collection is absent, so `input_dim` no longer matches what `forward`
# concatenates.
self.sparse_embedding_dim_ro: int = (
    sum(
        [
            len(embc.feature_names) * embc.embedding_dim
            for embc in embedding_bag_collection_ro.embedding_bag_configs()
        ]
    )
    if embedding_bag_collection_ro is not None
    else 0
)

self.input_dim: int = (
state_dense_dim + self.sparse_embedding_dim + action_dense_dim
state_dense_dim
+ self.sparse_embedding_dim
+ self.sparse_embedding_dim_ro
+ action_dense_dim
)
layers = [self.input_dim] + overarch_dims + [output_dim]
activations = [activation] * len(overarch_dims) + [final_activation]
Expand All @@ -76,11 +116,20 @@ def forward(self, state: rlt.FeatureData, action: rlt.FeatureData) -> torch.Tens
(state.float_features, action.float_features), dim=-1
)
batch_size = dense_features.shape[0]
sparse_features = fetch_id_list_features(state, action)
sparse_features, sparse_features_ro = fetch_id_list_features(state, action)
# shape: batch_size, num_sparse_features, embedding_dim
embedded_sparse = self.sparse_arch(sparse_features)
# shape: batch_size, num_sparse_features * embedding_dim
embedded_sparse = embedded_sparse.reshape(batch_size, -1)
concatenated_dense = torch.cat((dense_features, embedded_sparse), dim=-1)
embedded_sparse, embedded_sparse_ro = self.sparse_arch(
sparse_features, sparse_features_ro
)
features_list: List[torch.Tensor] = [dense_features]
if embedded_sparse is not None:
# shape: batch_size, num_sparse_features * embedding_dim
embedded_sparse = embedded_sparse.reshape(batch_size, -1)
features_list.append(embedded_sparse)
if embedded_sparse_ro is not None:
# shape: batch_size, num_sparse_features * embedding_dim
embedded_sparse_ro = embedded_sparse_ro.reshape(batch_size, -1)
features_list.append(embedded_sparse_ro)

concatenated_dense = torch.cat(features_list, dim=-1)
return self.q_network(concatenated_dense)
20 changes: 19 additions & 1 deletion reagent/test/models/test_sparse_dqn_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def test_single_step_sparse_dqn(self):
embedding_table_size = 1000
embedding_dim = 32
num_sparse_features = 2 # refer to watched_ids and liked_ids below

embedding_bag_configs = [
EmbeddingBagConfig(
name="video_id",
Expand All @@ -24,13 +25,27 @@ def test_single_step_sparse_dqn(self):
embedding_dim=embedding_dim,
)
]

num_sparse_features_ro = 2 # refer to watched_page_ids and liked_ids below
embedding_bag_configs_ro = [
EmbeddingBagConfig(
name="watched_page_ids",
feature_names=["watched_page_ids", "liked_ids"],
num_embeddings=embedding_table_size,
embedding_dim=embedding_dim,
)
]
embedding_bag_col = EmbeddingBagCollection(
device=torch.device("cpu"), tables=embedding_bag_configs
)
embedding_bag_col_ro = EmbeddingBagCollection(
device=torch.device("cpu"), tables=embedding_bag_configs_ro
)

net = SparseDQN(
state_dense_dim=state_dense_dim,
embedding_bag_collection=embedding_bag_col,
embedding_bag_collection_ro=embedding_bag_col_ro,
action_dense_dim=action_dense_dim,
overarch_dims=dense_sizes,
activation=activation,
Expand All @@ -42,7 +57,10 @@ def test_single_step_sparse_dqn(self):
# number of sparse features times embedding dimension for sparse features
assert (
net[1].in_features
== state_dense_dim + action_dense_dim + num_sparse_features * embedding_dim
== state_dense_dim
+ action_dense_dim
+ num_sparse_features * embedding_dim
+ num_sparse_features_ro * embedding_dim
)
assert net[1].out_features == dense_sizes[0]
assert net[4].in_features == dense_sizes[0]
Expand Down

0 comments on commit aea4525

Please sign in to comment.