Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mask out non-present arm scores for Offline Eval #702

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions reagent/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,7 @@ class FrechetSortConfig:
@dataclass
class CBInput(TensorDataClass):
context_arm_features: torch.Tensor
arm_presence: Final[Optional[torch.Tensor]] = None
action: Final[Optional[torch.Tensor]] = None
reward: Final[Optional[torch.Tensor]] = None
log_prob: Final[Optional[torch.Tensor]] = None
Expand All @@ -1137,6 +1138,7 @@ def input_prototype(
def from_dict(cls, d: Dict[str, torch.Tensor]) -> "CBInput":
return cls(
context_arm_features=d["context_arm_features"],
arm_presence=d.get("arm_presence", None),
action=d.get("action", None),
reward=d.get("reward", None),
log_prob=d.get("log_prob", None),
Expand Down
128 changes: 127 additions & 1 deletion reagent/preprocessing/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,94 @@ def __call__(self, data):
return data


class VarLengthSequences:
    """
    Like FixedLengthSequences, but doesn't require the sequence lengths to be the
    same. Instead, the largest slate size from the batch is used. For observations
    with smaller slate sizes, the values are padded with zeros.
    Additionally, a presence tensor is produced to indicate which elements are
    present vs padded.
    The item presence tensor is a float tensor of 0/1 values with shape
    `[B, max_slate_size]`.

    Args:
        keys: keys of `data` to transform. Each `data[key][sequence_id]` is read
            as `(offsets, (value, presence))`.
        sequence_id: id used to look up the sequence feature under each key.
        to_keys: output keys for the padded `(value, presence)` tuples.
            Defaults to `keys` (i.e. the inputs are overwritten).
        to_keys_item_presence: output keys for the item presence tensors.
            Defaults to `<to_key>_item_presence` for every output key.

    After each call, `max_len` holds the slate size of the last processed key.
    """

    def __init__(
        self,
        keys: List[str],
        sequence_id: int,
        *,
        to_keys: Optional[List[str]] = None,
        to_keys_item_presence: Optional[List[str]] = None,
    ):
        self.keys = keys
        self.sequence_id = sequence_id
        self.to_keys = to_keys or keys
        self.to_keys_item_presence = to_keys_item_presence or [
            k + "_item_presence" for k in self.to_keys
        ]
        # Slate size of the most recently processed key; populated by __call__().
        self.max_len: Optional[int] = None
        assert len(self.to_keys) == len(keys)
        # zip() in __call__() would silently drop keys on a length mismatch.
        assert len(self.to_keys_item_presence) == len(
            keys
        ), "to_keys_item_presence must have one entry per key"

    def __call__(self, data):
        """
        Pad every configured sequence feature to the longest sequence in the batch.

        For each key, writes the padded `(value, presence)` tuple (each of shape
        `[B * max_len, feature_dim]`) to the matching `to_key` and the item
        presence tensor (shape `[B, max_len]`) to the matching item-presence key.
        Returns the same `data` mapping that was passed in.
        """
        for key, to_key, to_key_item_presence in zip(
            self.keys, self.to_keys, self.to_keys_item_presence
        ):
            offsets, (value, presence) = data[key][self.sequence_id]

            # Append the total row count so that diff() also yields the length
            # of the last observation.
            boundaries = torch.cat(
                (
                    offsets,
                    torch.tensor(
                        [value.shape[0]], dtype=offsets.dtype, device=offsets.device
                    ),
                )
            )
            lengths = torch.diff(boundaries).to(value.device)

            num_obs = lengths.shape[0]
            # max() is undefined on an empty tensor, so treat "no observations"
            # as a slate of size zero.
            max_len = int(lengths.max().item()) if num_obs > 0 else 0
            self.max_len = max_len
            feature_dim = value.shape[1]

            # mask[i, j] is True iff slot j of observation i holds a real item.
            # Flattened row-major, the True positions are exactly the rows of the
            # padded tensors that receive the original (unpadded) rows, in order.
            mask = torch.arange(max_len, device=value.device).unsqueeze(
                0
            ) < lengths.unsqueeze(1)
            flat_mask = mask.reshape(-1)

            padded_shape = (num_obs * max_len, feature_dim)
            padded_value = torch.zeros(
                *padded_shape, dtype=value.dtype, device=value.device
            )
            padded_presence = torch.zeros(
                *padded_shape, dtype=presence.dtype, device=presence.device
            )
            padded_value[flat_mask] = value
            padded_presence[flat_mask] = presence

            data[to_key] = (padded_value, padded_presence)
            data[to_key_item_presence] = mask.float()

        return data


class FixedLengthSequenceDenseNormalization:
"""
Combines the FixedLengthSequences, DenseNormalization, and SlateView transforms
Expand All @@ -604,8 +692,9 @@ def __init__(
normalization_data: NormalizationData,
expected_length: Optional[int] = None,
device: Optional[torch.device] = None,
to_keys: Optional[List[str]] = None,
):
to_keys = [f"{k}:{sequence_id}" for k in keys]
to_keys = to_keys or [f"{k}:{sequence_id}" for k in keys]
self.fixed_length_sequences = FixedLengthSequences(
keys, sequence_id, to_keys=to_keys, expected_length=expected_length
)
Expand All @@ -622,6 +711,43 @@ def __call__(self, data):
return self.slate_view(data)


class VarLengthSequenceDenseNormalization:
    """
    Chains three transforms: VarLengthSequences, then DenseNormalization, then
    SlateView. The slate size handed to SlateView is not known up front; it is
    read from the VarLengthSequences transform after it has processed the batch.
    """

    def __init__(
        self,
        keys: List[str],
        sequence_id: int,
        normalization_data: NormalizationData,
        to_keys_item_presence: Optional[List[str]] = None,
        device: Optional[torch.device] = None,
        to_keys: Optional[List[str]] = None,
    ):
        # Default output keys are "<key>:<sequence_id>".
        if not to_keys:
            to_keys = [f"{k}:{sequence_id}" for k in keys]

        self.var_length_sequences = VarLengthSequences(
            keys,
            sequence_id,
            to_keys=to_keys,
            to_keys_item_presence=to_keys_item_presence,
        )
        self.dense_normalization = DenseNormalization(
            to_keys, normalization_data, device=device
        )
        # slate_size=-1 is only a placeholder; __call__() overwrites it per batch.
        self.slate_view = SlateView(to_keys, slate_size=-1)

    def __call__(self, data):
        padded = self.var_length_sequences(data)
        normalized = self.dense_normalization(padded)
        # VarLengthSequences keeps a single max_len value, so this assumes that
        # max_len is the same for all keys.
        batch_slate_size = self.var_length_sequences.max_len
        self.slate_view.slate_size = batch_slate_size
        return self.slate_view(normalized)


class AppendConstant:
"""
Append a column of constant value at the beginning of the specified dimension
Expand Down
1 change: 1 addition & 0 deletions reagent/preprocessing/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,4 @@ class InputColumn(object):
ARM_FEATURES = "arm_features"
CONTEXT_ARM_FEATURES = "context_arm_features"
ARMS = "arms"
ARM_PRESENCE = "arm_presence"
6 changes: 6 additions & 0 deletions reagent/test/evaluation/cb/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,22 @@ def test_eval_during_training(self):
[
[1, 7],
[1, 8],
[
1,
9,
], # this arm would have been chosen by the model if it was present
],
[
[1, 9],
[1, 10],
[1, 11],
],
],
dtype=torch.float,
),
action=torch.tensor([[1], [0]], dtype=torch.long),
reward=torch.tensor([[1.2], [2.9]], dtype=torch.float),
arm_presence=torch.tensor([[1, 1, 0], [1, 1, 1]], dtype=torch.bool),
)
self.trainer.training_step(batch_2, 0)

Expand Down
108 changes: 108 additions & 0 deletions reagent/test/preprocessing/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,3 +765,111 @@ def test_ToDtype(self) -> None:
self.assertEqual(t_data["a"].dtype, torch.float) # was float, didn't change
self.assertEqual(t_data["b"].dtype, torch.float) # changed from double to float
self.assertEqual(t_data["c"].dtype, torch.double) # was double, didn't change

def test_VarLengthSequences(self) -> None:
    # Covers output keys, untouched inputs, constructor validation, padded
    # shapes/values and the item presence masks of VarLengthSequences.
    sequence_id = 1

    # Inputs have the form {sequence_id: (offsets, (values, feature_presence))}.
    a_pair = (
        torch.tensor([[0, 1, 3], [2, 3, 7], [4, 5, 8], [2, 3, 1]]).float(),
        torch.ones(4, 3),
    )
    b_pair = (
        torch.tensor(
            [[1, 1, 3], [2, 2, 5], [3, 3, 1], [9, 10, 4], [5, 1, 7]]
        ).float(),
        torch.ones(5, 3),
    )
    a_in = {sequence_id: (torch.tensor([0, 1]), a_pair)}
    b_in = {sequence_id: (torch.tensor([0, 4]), b_pair)}

    in_place = transforms.VarLengthSequences(keys=["a", "b"], sequence_id=sequence_id)
    renamed = transforms.VarLengthSequences(
        keys=["a", "b"], sequence_id=sequence_id, to_keys=["a_to_key", "b_to_key"]
    )
    renamed_explicit = transforms.VarLengthSequences(
        keys=["a", "b"],
        sequence_id=sequence_id,
        to_keys=["a_to_key", "b_to_key"],
        to_keys_item_presence=["a_to_key_item_presence", "b_to_key_item_presence"],
    )
    out_in_place = in_place({"a": a_in, "b": b_in})
    out_renamed = renamed({"a": a_in, "b": b_in})
    out_renamed_explicit = renamed_explicit({"a": a_in, "b": b_in})

    # Output keys: defaults overwrite the inputs, custom to_keys add new entries.
    self.assertSetEqual(
        set(out_in_place.keys()), {"a", "b", "a_item_presence", "b_item_presence"}
    )
    for out in (out_renamed, out_renamed_explicit):
        self.assertSetEqual(
            set(out.keys()),
            {
                "a",
                "b",
                "a_to_key",
                "b_to_key",
                "a_to_key_item_presence",
                "b_to_key_item_presence",
            },
        )

    # Inputs must stay untouched when the output keys differ from the input keys.
    for out in (out_renamed, out_renamed_explicit):
        self.assertEqual(out["a"], a_in)
        self.assertEqual(out["b"], b_in)

    # The constructor rejects a to_keys list whose length differs from keys.
    with self.assertRaises(AssertionError):
        transforms.VarLengthSequences(
            keys=["a", "b"], sequence_id=1, to_keys=["to_a"]
        )

    # Padded values are flattened to [num_obs * max_len, feature_dim].
    self.assertTupleEqual(tuple(out_in_place["a"][0].shape), (6, 3))
    self.assertTupleEqual(tuple(out_in_place["b"][0].shape), (8, 3))

    # NOTE(review): the expected tensors below have shape [1, N, 3] while the
    # shapes asserted above are (N, 3) — confirm the tensor comparison used by
    # assertEqual tolerates the extra leading dimension.
    expected_a = torch.tensor(
        [
            [
                [0, 1, 3],
                [0, 0, 0],
                [0, 0, 0],
                [2, 3, 7],
                [4, 5, 8],
                [2, 3, 1],
            ],
        ]
    ).float()
    expected_b = torch.tensor(
        [
            [
                [1, 1, 3],
                [2, 2, 5],
                [3, 3, 1],
                [9, 10, 4],
                [5, 1, 7],
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
            ],
        ]
    ).float()
    self.assertEqual(out_in_place["a"][0], expected_a)
    self.assertEqual(out_in_place["b"][0], expected_b)

    # Item presence: 1 for real items, 0 for padding, one row per observation.
    expected_a_item_presence = torch.tensor([[1, 0, 0], [1, 1, 1]]).float()
    expected_b_item_presence = torch.tensor([[1, 1, 1, 1], [1, 0, 0, 0]]).float()
    self.assertEqual(out_in_place["a_item_presence"], expected_a_item_presence)
    self.assertEqual(out_in_place["b_item_presence"], expected_b_item_presence)
14 changes: 13 additions & 1 deletion reagent/training/cb/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,19 @@ def training_step(self, batch: CBInput, batch_idx: int, optimizer_idx: int = 0):
eval_module.num_eval_model_updates += 1
with torch.no_grad():
eval_scores = eval_module.eval_model(batch.context_arm_features)
model_actions = torch.argmax(eval_scores, dim=1).reshape(-1, 1)
if batch.arm_presence is not None:
# mask out non-present arms
eval_scores = torch.masked.as_masked_tensor(
eval_scores, batch.arm_presence.bool()
)
model_actions = (
# pyre-fixme[16]: `Tensor` has no attribute `get_data`.
torch.argmax(eval_scores, dim=1)
.get_data()
.reshape(-1, 1)
)
else:
model_actions = torch.argmax(eval_scores, dim=1).reshape(-1, 1)
new_batch = eval_module.ingest_batch(batch, model_actions)
eval_module.sum_weight_since_update += (
batch.weight.sum() if batch.weight is not None else len(batch)
Expand Down