@@ -20,6 +20,7 @@
     PartiallyMaterializedTensor,
 )
 from hypothesis import assume, given, settings, strategies as st, Verbosity
+from pyjk import PyPatchJustKnobs
 from torch import distributed as dist
 from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed._tensor import DTensor
@@ -28,7 +29,10 @@
     EmbeddingComputeKernel,
     EmbeddingTableConfig,
 )
-from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
+from torchrec.distributed.embeddingbag import (
+    logger as embeddingbag_logger,
+    ShardedEmbeddingBagCollection,
+)
 from torchrec.distributed.fused_embeddingbag import ShardedFusedEmbeddingBagCollection
 from torchrec.distributed.model_parallel import DistributedModelParallel
 from torchrec.distributed.planner import (
@@ -65,6 +69,7 @@
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.modules.fused_embedding_modules import FusedEmbeddingBagCollection
 from torchrec.optim.rowwise_adagrad import RowWiseAdagrad
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 from torchrec.test_utils import get_free_port, seed_and_log
 
 
@@ -205,13 +210,23 @@ def setUp(self, backend: str = "nccl") -> None:
 
         dist.init_process_group(backend=self.backend)
 
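+    # Class-level JustKnobs patcher; the KJT validation tests below use it to
+    # toggle the "pytorch/torchrec:enable_kjt_validation" knob per test.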
+    @classmethod
+    def setUpClass(cls) -> None:
+        super().setUpClass()
+        cls.patcher = PyPatchJustKnobs()
+
     def tearDown(self) -> None:
         dist.destroy_process_group()
 
     def test_sharding_ebc_as_top_level(self) -> None:
+        model = self._create_sharded_model()
+
+        self.assertTrue(isinstance(model.module, ShardedEmbeddingBagCollection))
+
+    def test_sharding_fused_ebc_as_top_level(self) -> None:
         embedding_dim = 128
         num_embeddings = 256
-        ebc = EmbeddingBagCollection(
+        ebc = FusedEmbeddingBagCollection(
             device=torch.device("meta"),
             tables=[
                 EmbeddingBagConfig(
@@ -222,16 +237,67 @@ def test_sharding_ebc_as_top_level(self) -> None:
                     pooling=PoolingType.SUM,
                 ),
             ],
+            optimizer_type=torch.optim.SGD,
+            optimizer_kwargs={"lr": 0.02},
         )
 
         model = DistributedModelParallel(ebc, device=self.device)
 
-        self.assertTrue(isinstance(model.module, ShardedEmbeddingBagCollection))
+        self.assertTrue(isinstance(model.module, ShardedFusedEmbeddingBagCollection))
 
-    def test_sharding_fused_ebc_as_top_level(self) -> None:
-        embedding_dim = 128
-        num_embeddings = 256
-        ebc = FusedEmbeddingBagCollection(
+    def test_sharding_ebc_input_validation_enabled(self) -> None:
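+        # Duplicate "my_feature" keys: with the validation knob enabled, the
+        # forward pass should be rejected with "keys must be unique".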
+        model = self._create_sharded_model()
+        kjt = KeyedJaggedTensor(
+            keys=["my_feature", "my_feature"],
+            values=torch.tensor([1, 2, 3, 4, 5]),
+            lengths=torch.tensor([1, 2, 0, 2]),
+            offsets=torch.tensor([0, 1, 3, 3, 5]),
+        )
+
+        with self.patcher.patch("pytorch/torchrec:enable_kjt_validation", True):
+            with self.assertRaisesRegex(ValueError, "keys must be unique"):
+                model(kjt)
+
+    def test_sharding_ebc_validate_input_only_once(self) -> None:
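+        # Input validation should run only on the first forward pass, so the
+        # "Validating input features..." log line must appear exactly once.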
+        model = self._create_sharded_model()
+        kjt = KeyedJaggedTensor(
+            keys=["my_feature"],
+            values=torch.tensor([1, 2, 3, 4, 5]),
+            lengths=torch.tensor([1, 2, 0, 2]),
+            offsets=torch.tensor([0, 1, 3, 3, 5]),
+        ).to(self.device)
+
+        with self.patcher.patch("pytorch/torchrec:enable_kjt_validation", True):
+            with self.assertLogs(embeddingbag_logger, level="INFO") as logs:
+                model(kjt)
+                model(kjt)
+                model(kjt)
+
+        matched_logs = list(
+            filter(lambda s: "Validating input features..." in s, logs.output)
+        )
+        self.assertEqual(1, len(matched_logs))
+
+    def test_sharding_ebc_input_validation_disabled(self) -> None:
+        model = self._create_sharded_model()
+        kjt = KeyedJaggedTensor(
+            keys=["my_feature", "my_feature"],
+            values=torch.tensor([1, 2, 3, 4, 5]),
+            lengths=torch.tensor([1, 2, 0, 2]),
+            offsets=torch.tensor([0, 1, 3, 3, 5]),
+        ).to(self.device)
+
+        # Without KJT validation, input_dist will not raise exceptions
+        with self.patcher.patch("pytorch/torchrec:enable_kjt_validation", False):
+            try:
+                model(kjt)
+            except ValueError:
+                self.fail("Input validation should not be enabled.")
+
+    def _create_sharded_model(
+        self, embedding_dim: int = 128, num_embeddings: int = 256
+    ) -> DistributedModelParallel:
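+        # Shared helper: builds an EmbeddingBagCollection on the meta device
+        # and shards it via DistributedModelParallel.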
+        ebc = EmbeddingBagCollection(
             device=torch.device("meta"),
             tables=[
                 EmbeddingBagConfig(
@@ -242,13 +308,8 @@ def test_sharding_fused_ebc_as_top_level(self) -> None:
                     pooling=PoolingType.SUM,
                 ),
             ],
-            optimizer_type=torch.optim.SGD,
-            optimizer_kwargs={"lr": 0.02},
         )
-
-        model = DistributedModelParallel(ebc, device=self.device)
-
-        self.assertTrue(isinstance(model.module, ShardedFusedEmbeddingBagCollection))
+        return DistributedModelParallel(ebc, device=self.device)
 
 
 class ModelParallelSingleRankBase(unittest.TestCase):