Skip to content

Commit a546e78

Browse files
Alireza Tehrani and meta-codesync[bot]
authored and committed
Fix preserving strides, inverse_indices in ManagedCollisionCollection (#3597)
Summary: Pull Request resolved: #3597 The forward method of `ManagedCollisionCollection` goes through each table and maps the original indices of the input KJT to new indices using a hash function. This produces a Dict[str, JaggedTensor], which is then converted into a KeyedJaggedTensor. MCC should only change the values attribute of the KJT, while preserving all other attributes. This conversion did not preserve key attributes of the KJT, such as `inverse_indices` and `stride`, that are essential for working with VBE. Reviewed By: kausv Differential Revision: D84944895 fbshipit-source-id: 386a70046d37f956cfb1949077dc9cd69511356b
1 parent b02f57d commit a546e78

File tree

2 files changed

+216
-0
lines changed

2 files changed

+216
-0
lines changed

torchrec/modules/mc_modules.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,10 @@ def forward(
417417
values=values,
418418
lengths=lengths,
419419
weights=features.weights_or_none(),
420+
stride=features.stride(),
421+
stride_per_key=features.stride_per_key(),
422+
stride_per_key_per_rank=features._stride_per_key_per_rank,
423+
inverse_indices=features.inverse_indices_or_none(),
420424
)
421425

422426
def evict(self) -> Dict[str, Optional[torch.Tensor]]:

torchrec/modules/tests/test_hash_mc_modules.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -927,3 +927,215 @@ def test_zch_hash_zero_rows(self) -> None:
927927
torch.nonzero(row_mask, as_tuple=False).squeeze(),
928928
)
929929
)
930+
931+
932+
@unittest.skipIf(
    torch.cuda.device_count() < 1,
    "Not enough GPUs, this test requires at least one GPU",
)
class TestVBEWithManagedCollision(unittest.TestCase):
    """Tests for Variable Batch Embeddings (VBE) with ManagedCollisionCollection.

    The fixtures build a two-table ManagedCollisionCollection ("user_table",
    "product_table") and a KeyedJaggedTensor that carries
    ``stride_per_key_per_rank`` and ``inverse_indices``. The tests check that
    these attributes survive a pass through the collection and that the
    pooled embeddings produced by the MCEBC match a hand-computed reference.
    """

    def setUp(self) -> None:
        """Set up common test fixtures for VBE tests."""
        self.hash_sizes_table = {"product_table": 5, "user_table": 8}
        self.total_ids = {"product_table": 10, "user_table": 20}

        # Create hash modules for collision management
        self.hash_modules = {
            "user_table": HashZchManagedCollisionModule(
                zch_size=self.hash_sizes_table["user_table"],
                device=torch.device("cuda"),
                input_hash_size=self.total_ids["user_table"],
                total_num_buckets=1,
            ),
            "product_table": HashZchManagedCollisionModule(
                zch_size=self.hash_sizes_table["product_table"],
                device=torch.device("cuda"),
                input_hash_size=self.total_ids["product_table"],
                total_num_buckets=1,
            ),
        }

        # Create embedding configs
        self.embedding_configs = [
            EmbeddingBagConfig(
                name="user_table",
                embedding_dim=3,
                num_embeddings=self.hash_sizes_table["user_table"],
                feature_names=["user"],
            ),
            EmbeddingBagConfig(
                name="product_table",
                embedding_dim=2,
                num_embeddings=self.hash_sizes_table["product_table"],
                feature_names=["product"],
            ),
        ]

        # Create ManagedCollisionCollection
        self.mcc = ManagedCollisionCollection(
            managed_collision_modules=self.hash_modules,
            embedding_configs=self.embedding_configs,
        )

        # Create test KJT with VBE (deduped values with inverse_indices)
        # User values: [[5, 6, 7], [1, 2, 3]] - 2 unique pooled groups
        # Product values: [[0, 1]] - 1 unique pooled group
        # Each inverse_indices row has 3 entries that select, per key, one of
        # the unique pooled groups above.
        self.kjt = KeyedJaggedTensor(
            keys=["user", "product"],
            values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
            lengths=torch.tensor([3, 3, 2]),
            stride_per_key_per_rank=[[2], [1]],
            inverse_indices=(
                ["user", "product"],
                torch.tensor([[0, 1, 0], [0, 0, 0]]),
            ),
        ).to("cuda")

    def test_mcc_preserves_kjt_attributes(self) -> None:
        """Test that ManagedCollisionCollection preserves all KJT attributes with VBE."""
        # Add weights to test kjt
        kjt_with_weights = KeyedJaggedTensor(
            keys=self.kjt.keys(),
            values=self.kjt.values(),
            lengths=self.kjt.lengths(),
            weights=torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]),
            stride_per_key_per_rank=self.kjt.stride_per_key_per_rank(),
            inverse_indices=self.kjt.inverse_indices(),
        ).to("cuda")

        # Pass through MCC
        output = self.mcc.forward(kjt_with_weights)

        # Verify ID remapping on values is correct for each table.
        # The per-key offsets are taken from the input KJT for both input and
        # output values; the lengths check below confirms they share a layout.
        offset_per_key = kjt_with_weights.offset_per_key()
        for i, table in enumerate(["user_table", "product_table"]):
            start, end = offset_per_key[i], offset_per_key[i + 1]
            # Flattened identity table: position = remapped id, entry = original id.
            mapping = torch.ravel(
                self.mcc._managed_collision_modules[table]._hash_zch_identities
            )
            original_inds = kjt_with_weights.values()[start:end]
            remapped_inds = output.values()[start:end]
            self.assertTrue(
                torch.equal(original_inds, mapping[remapped_inds]),
                f"ID remapping incorrect for {table}",
            )

        # Verify all other attributes (relevant to VBE) are preserved
        self.assertTrue(
            torch.equal(kjt_with_weights.lengths(), output.lengths()),
            "Lengths should be preserved",
        )
        self.assertTrue(
            torch.equal(kjt_with_weights.weights(), output.weights()),
            "Weights should be preserved",
        )
        self.assertEqual(
            kjt_with_weights.stride(), output.stride(), "Stride should be preserved"
        )
        self.assertEqual(
            kjt_with_weights.stride_per_key(),
            output.stride_per_key(),
            "stride_per_key should be preserved",
        )
        self.assertEqual(
            kjt_with_weights.stride_per_key_per_rank(),
            output.stride_per_key_per_rank(),
            "stride_per_key_per_rank should be preserved",
        )

        # Verify inverse_indices are preserved (VBE support)
        input_inverse_indices = kjt_with_weights.inverse_indices()
        output_inverse_indices = output.inverse_indices()

        self.assertEqual(
            input_inverse_indices[0],
            output_inverse_indices[0],
            "inverse_indices keys should be preserved",
        )
        self.assertTrue(
            torch.equal(input_inverse_indices[1], output_inverse_indices[1]),
            "inverse_indices tensor should be preserved",
        )

    def test_mcebc_with_vbe(self) -> None:
        """Test that MCEBC correctly handles VBE using inverse_indices.

        The expected output is rebuilt by hand: every unique pooled group is
        summed from the embedding rows selected by the remapped ids, and the
        result is expanded with ``inverse_indices``.
        """
        # Set up MCEBC (same device object style as the hash modules in setUp)
        ebc = EmbeddingBagCollection(
            device=torch.device("cuda"),
            tables=self.embedding_configs,
        )
        mcebc = ManagedCollisionEmbeddingBagCollection(
            embedding_bag_collection=ebc,
            managed_collision_collection=self.mcc,
        )

        # Run forward pass; the identity tables read below are inspected only
        # after this call, as in the attribute-preservation test.
        actual_output, _ = mcebc(self.kjt)

        # Manually compute results on hard-coded VBE example
        tables = {
            "user_table": ebc.embedding_bags["user_table"].weight,
            "product_table": ebc.embedding_bags["product_table"].weight,
        }

        pooled_embeddings = {
            "user_table": torch.zeros((2, 3)),
            "product_table": torch.zeros((1, 2)),
        }

        # Loop-invariant KJT metadata, fetched once. Reading lengths as a
        # Python list avoids slicing with 0-dim CUDA tensors in the loop.
        stride_per_key = self.kjt.stride_per_key()
        offset_per_key = self.kjt.offset_per_key()
        lengths = self.kjt.lengths().tolist()
        mc_modules = mcebc._managed_collision_collection._managed_collision_modules

        i_length = 0  # running index into the flat lengths list across keys
        for i_table, table in enumerate(["user_table", "product_table"]):
            # Flattened identity table: position = remapped id, entry = original id.
            remapped_indices = torch.ravel(mc_modules[table]._hash_zch_identities)

            original_inds_per_key = self.kjt.values()[
                offset_per_key[i_table] : offset_per_key[i_table + 1]
            ]

            # Process each unique pooled group
            offset_per_key_per_pool = 0
            for i_pooled in range(stride_per_key[i_table]):
                length_of_pool = lengths[i_length]

                pooled_original_indices = original_inds_per_key[
                    offset_per_key_per_pool : offset_per_key_per_pool + length_of_pool
                ]

                # Get the new indices from hash-map. `.item()` requires exactly
                # one match per original id; explicit long dtype keeps the
                # index tensor valid even for an empty pool.
                new_indices = torch.tensor(
                    [
                        torch.where(remapped_indices == idx)[0].item()
                        for idx in pooled_original_indices
                    ],
                    dtype=torch.long,
                )

                # Sum embeddings for the pooled group from new_indices.
                # Detached: the reference needs values only, no autograd graph.
                pooled_embeddings[table][i_pooled] = (
                    tables[table][new_indices, :].sum(axis=0).detach().to("cpu")
                )

                i_length += 1
                offset_per_key_per_pool += length_of_pool

        # Use inverse_indices to expand pooled embeddings to final output
        inverse_keys, inverse_tensor = self.kjt.inverse_indices()

        user_inverse = inverse_tensor[inverse_keys.index("user")].to("cpu")
        expected_user = pooled_embeddings["user_table"][user_inverse]

        prod_inverse = inverse_tensor[inverse_keys.index("product")].to("cpu")
        expected_prod = pooled_embeddings["product_table"][prod_inverse]

        # Verify actual output matches expected output. Floating-point sums may
        # be accumulated in a different order by the module than by the
        # reference above, so compare within tolerance rather than bit-exactly;
        # assert_close also reports the mismatching elements on failure.
        torch.testing.assert_close(
            actual_output["user"].detach().to("cpu"), expected_user
        )
        torch.testing.assert_close(
            actual_output["product"].detach().to("cpu"), expected_prod
        )

0 commit comments

Comments
 (0)