@@ -927,3 +927,215 @@ def test_zch_hash_zero_rows(self) -> None:
                 torch.nonzero(row_mask, as_tuple=False).squeeze(),
             )
         )
+
+
+@unittest.skipIf(
+    torch.cuda.device_count() < 1,
+    "Not enough GPUs, this test requires at least one GPU",
+)
+class TestVBEWithManagedCollision(unittest.TestCase):
+    """Tests for Variable Batch Embeddings (VBE) with ManagedCollisionCollection."""
+
+    def setUp(self) -> None:
+        """Set up common test fixtures for VBE tests."""
+        self.hash_sizes_table = {"product_table": 5, "user_table": 8}
+        self.total_ids = {"product_table": 10, "user_table": 20}
+
+        # Create hash modules for collision management
+        self.hash_modules = {
+            "user_table": HashZchManagedCollisionModule(
+                zch_size=self.hash_sizes_table["user_table"],
+                device=torch.device("cuda"),
+                input_hash_size=self.total_ids["user_table"],
+                total_num_buckets=1,
+            ),
+            "product_table": HashZchManagedCollisionModule(
+                zch_size=self.hash_sizes_table["product_table"],
+                device=torch.device("cuda"),
+                input_hash_size=self.total_ids["product_table"],
+                total_num_buckets=1,
+            ),
+        }
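+        # Each hash module remaps raw ids from an input_hash_size-sized id space
+        # into zch_size managed slots (total_num_buckets=1, i.e. a single bucket).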
+
+        # Create embedding configs
+        self.embedding_configs = [
+            EmbeddingBagConfig(
+                name="user_table",
+                embedding_dim=3,
+                num_embeddings=self.hash_sizes_table["user_table"],
+                feature_names=["user"],
+            ),
+            EmbeddingBagConfig(
+                name="product_table",
+                embedding_dim=2,
+                num_embeddings=self.hash_sizes_table["product_table"],
+                feature_names=["product"],
+            ),
+        ]
+
+        # Create ManagedCollisionCollection
+        self.mcc = ManagedCollisionCollection(
+            managed_collision_modules=self.hash_modules,
+            embedding_configs=self.embedding_configs,
+        )
+
+        # Create test KJT with VBE (deduped values with inverse_indices)
+        # User values: [[5, 6, 7], [1, 2, 3]] - 2 unique pooled groups
+        # Product values: [[0, 1]] - 1 unique pooled group
+        self.kjt = KeyedJaggedTensor(
+            keys=["user", "product"],
+            values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
+            lengths=torch.tensor([3, 3, 2]),
+            stride_per_key_per_rank=[[2], [1]],
+            inverse_indices=(["user", "product"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
+        ).to("cuda")
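+        # stride_per_key_per_rank=[[2], [1]]: "user" carries 2 deduplicated pooled
+        # rows and "product" carries 1, while inverse_indices maps each of the 3
+        # logical samples back to its pooled row (user: [0, 1, 0], product: [0, 0, 0]).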
+
+    def test_mcc_preserves_kjt_attributes(self) -> None:
+        """Test that ManagedCollisionCollection preserves all KJT attributes with VBE."""
+        # Add weights to test kjt
+        kjt_with_weights = KeyedJaggedTensor(
+            keys=self.kjt.keys(),
+            values=self.kjt.values(),
+            lengths=self.kjt.lengths(),
+            weights=torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]),
+            stride_per_key_per_rank=self.kjt.stride_per_key_per_rank(),
+            inverse_indices=self.kjt.inverse_indices(),
+        ).to("cuda")
+
+        # Pass through MCC
+        output = self.mcc.forward(kjt_with_weights)
+
+        # Verify ID remapping on values is correct for each table
+        for i, table in enumerate(["user_table", "product_table"]):
+            mapping = torch.ravel(
+                self.mcc._managed_collision_modules[table]._hash_zch_identities
+            )
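+            # _hash_zch_identities records, at each remapped slot, the original id
+            # assigned there, so mapping[remapped_inds] should reproduce the inputs.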
+            original_inds = kjt_with_weights.values()[
+                kjt_with_weights.offset_per_key()[
+                    i
+                ] : kjt_with_weights.offset_per_key()[i + 1]
+            ]
+            remapped_inds = output.values()[
+                kjt_with_weights.offset_per_key()[
+                    i
+                ] : kjt_with_weights.offset_per_key()[i + 1]
+            ]
+            self.assertTrue(
+                torch.equal(original_inds, mapping[remapped_inds]),
+                f"ID remapping incorrect for {table}",
+            )
+
+        # Verify all other attributes (relevant to VBE) are preserved
+        self.assertTrue(
+            torch.equal(kjt_with_weights.lengths(), output.lengths()),
+            "Lengths should be preserved",
+        )
+        self.assertTrue(
+            torch.equal(kjt_with_weights.weights(), output.weights()),
+            "Weights should be preserved",
+        )
+        self.assertEqual(
+            kjt_with_weights.stride(), output.stride(), "Stride should be preserved"
+        )
+        self.assertEqual(
+            kjt_with_weights.stride_per_key(),
+            output.stride_per_key(),
+            "stride_per_key should be preserved",
+        )
+        self.assertEqual(
+            kjt_with_weights.stride_per_key_per_rank(),
+            output.stride_per_key_per_rank(),
+            "stride_per_key_per_rank should be preserved",
+        )
+
+        # Verify inverse_indices are preserved (VBE support)
+        input_inverse_indices = kjt_with_weights.inverse_indices()
+        output_inverse_indices = output.inverse_indices()
+
+        self.assertEqual(
+            input_inverse_indices[0],
+            output_inverse_indices[0],
+            "inverse_indices keys should be preserved",
+        )
+        self.assertTrue(
+            torch.equal(input_inverse_indices[1], output_inverse_indices[1]),
+            "inverse_indices tensor should be preserved",
+        )
+
+    def test_mcebc_with_vbe(self) -> None:
+        """Test that MCEBC correctly handles VBE using inverse_indices."""
+        # Set up MCEBC
+        ebc = EmbeddingBagCollection(
+            device="cuda",
+            tables=self.embedding_configs,
+        )
+        mcebc = ManagedCollisionEmbeddingBagCollection(
+            embedding_bag_collection=ebc,
+            managed_collision_collection=self.mcc,
+        )
+
+        # Run forward pass
+        actual_output, _ = mcebc(self.kjt)
+
+        # Manually compute results on hard-coded VBE example
+        tables = {
+            "user_table": ebc.embedding_bags["user_table"].weight,
+            "product_table": ebc.embedding_bags["product_table"].weight,
+        }
+
+        pooled_embeddings = {
+            "user_table": torch.zeros((2, 3)),
+            "product_table": torch.zeros((1, 2)),
+        }
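+        # Shapes follow the deduplicated VBE batch: 2 pooled user rows of dim 3 and
+        # 1 pooled product row of dim 2, to be expanded via inverse_indices below.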
+
+        i_length = 0
+        for i_table, table in enumerate(["user_table", "product_table"]):
+            stride_per_key = self.kjt.stride_per_key()
+            mcc_table = mcebc._managed_collision_collection._managed_collision_modules[
+                table
+            ]
+            remapped_indices = torch.ravel(mcc_table._hash_zch_identities)
+
+            original_inds_per_key = self.kjt.values()[
+                self.kjt.offset_per_key()[i_table] : self.kjt.offset_per_key()[
+                    i_table + 1
+                ]
+            ]
+
+            # Process each unique pooled group
+            offset_per_key_per_pool = 0
+            for i_pooled in range(stride_per_key[i_table]):
+                length_of_pool = self.kjt.lengths()[i_length]
+
+                pooled_original_indices = original_inds_per_key[
+                    offset_per_key_per_pool : offset_per_key_per_pool + length_of_pool
+                ]
+
+                # Get the new indices from hash-map
+                new_indices = torch.tensor(
+                    [
+                        torch.where(remapped_indices == idx)[0].item()
+                        for idx in pooled_original_indices
+                    ]
+                )
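+                # torch.where finds the slot in _hash_zch_identities that stores each
+                # original id, i.e. the remapped index the MCC assigned to it.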
+
+                # Sum embeddings for the pooled group from new_indices
+                pooled_embeddings[table][i_pooled] = (
+                    tables[table][new_indices, :].sum(axis=0).to("cpu")
+                )
+
+                i_length += 1
+                offset_per_key_per_pool += length_of_pool
+
+        # Use inverse_indices to expand pooled embeddings to final output
+        inverse_keys, inverse_tensor = self.kjt.inverse_indices()
+
+        user_inverse = inverse_tensor[inverse_keys.index("user")].to("cpu")
+        expected_user = pooled_embeddings["user_table"][user_inverse]
+
+        prod_inverse = inverse_tensor[inverse_keys.index("product")].to("cpu")
+        expected_prod = pooled_embeddings["product_table"][prod_inverse]
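+        # With user_inverse == tensor([0, 1, 0]) the two pooled user rows expand to
+        # three per-sample rows [row0, row1, row0]; prod_inverse == tensor([0, 0, 0])
+        # repeats the single pooled product row for all three samples.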
+
+        # Verify actual output matches expected output
+        self.assertTrue(torch.equal(expected_user, actual_output["user"].to("cpu")))
+        self.assertTrue(torch.equal(expected_prod, actual_output["product"].to("cpu")))