Commit 4c88aee

Shuangping Liu authored and facebook-github-bot committed
Add input KJT validation in EBC input_dist (#3328)
Summary:
Pull Request resolved: #3328

Validates the input KJT in EBC `input_dist`. Validation is executed when `input_dist` of `ShardedEmbeddingBagCollection` is initialized, so input features are validated exactly **once** per EBC per rank, assuming the shape of the first batch is representative of all subsequent batches.

Reviewed By: TroyGarden

Differential Revision: D71752961
1 parent b811c5e commit 4c88aee
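
To make the new check concrete, here is a small standalone sketch (not part of the commit) of the kind of malformed input it rejects. The duplicate-key KJT mirrors the fixture used in the tests below; the expectation that `validate_keyed_jagged_tensor` raises a ValueError mentioning "keys must be unique" is likewise taken from those tests:

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.sparse.jagged_tensor_validator import validate_keyed_jagged_tensor

# Duplicate keys make this KJT invalid. Its other fields are self-consistent:
# lengths [1, 2, 0, 2] sum to 5 == len(values), and offsets [0, 1, 3, 3, 5]
# are the exclusive prefix sum of lengths.
kjt = KeyedJaggedTensor(
    keys=["my_feature", "my_feature"],
    values=torch.tensor([1, 2, 3, 4, 5]),
    lengths=torch.tensor([1, 2, 0, 2]),
    offsets=torch.tensor([0, 1, 3, 3, 5]),
)

try:
    validate_keyed_jagged_tensor(kjt)
except ValueError as err:
    print(err)  # the tests below expect "keys must be unique"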

File tree: 2 files changed, +87 −13 lines changed


torchrec/distributed/embeddingbag.py

Lines changed: 13 additions & 0 deletions
@@ -8,6 +8,7 @@
 # pyre-strict

 import copy
+import logging
 from collections import defaultdict, OrderedDict
 from dataclasses import dataclass, field
 from functools import partial
@@ -109,6 +110,7 @@
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
 from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.jagged_tensor_validator import validate_keyed_jagged_tensor
 from torchrec.sparse.tensor_dict import maybe_td_to_kjt

 try:
@@ -119,6 +121,9 @@
     pass


+logger: logging.Logger = logging.getLogger(__name__)
+
+
 def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
     return (
         tensor
@@ -1515,13 +1520,21 @@ def input_dist(
         features = maybe_td_to_kjt(features, feature_keys)  # pyre-ignore[6]
         ctx.variable_batch_per_feature = features.variable_stride_per_key()
         ctx.inverse_indices = features.inverse_indices_or_none()
+
         if self._has_uninitialized_input_dist:
+            if torch._utils_internal.justknobs_check(
+                "pytorch/torchrec:enable_kjt_validation"
+            ):
+                logger.info("Validating input features...")
+                validate_keyed_jagged_tensor(features)
+
             self._create_input_dist(features.keys())
             self._has_uninitialized_input_dist = False
             if ctx.variable_batch_per_feature:
                 self._create_inverse_indices_permute_indices(ctx.inverse_indices)
             if self._has_mean_pooling_callback:
                 self._init_mean_pooling_callback(features.keys(), ctx.inverse_indices)
+
         with torch.no_grad():
             if self._has_features_permute:
                 features = features.permute(
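
The check piggybacks on the lazy initialization of `input_dist`: it runs only while `_has_uninitialized_input_dist` is still True, which is what bounds validation to once per EBC per rank. Below is a stripped-down sketch of that validate-once pattern (illustration only, not TorchRec code; `torch._utils_internal.justknobs_check` is the killswitch call used in the diff, and it is assumed to return its default of True in OSS builds):

import logging

import torch

logger = logging.getLogger(__name__)


class LazyInputDist:
    """Illustration: validation is tied to one-time initialization."""

    def __init__(self) -> None:
        self._has_uninitialized_input_dist = True

    def input_dist(self, features: object) -> None:
        if self._has_uninitialized_input_dist:
            # Killswitch: validation can be turned off without a code change.
            if torch._utils_internal.justknobs_check(
                "pytorch/torchrec:enable_kjt_validation"
            ):
                logger.info("Validating input features...")
                # validate_keyed_jagged_tensor(features) runs here in EBC
            self._has_uninitialized_input_dist = False
        # ... subsequent batches skip straight to distribution ...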

torchrec/distributed/test_utils/test_model_parallel_base.py

Lines changed: 74 additions & 13 deletions
@@ -20,6 +20,7 @@
     PartiallyMaterializedTensor,
 )
 from hypothesis import assume, given, settings, strategies as st, Verbosity
+from pyjk import PyPatchJustKnobs
 from torch import distributed as dist
 from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed._tensor import DTensor
@@ -28,7 +29,10 @@
     EmbeddingComputeKernel,
     EmbeddingTableConfig,
 )
-from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection
+from torchrec.distributed.embeddingbag import (
+    logger as embeddingbag_logger,
+    ShardedEmbeddingBagCollection,
+)
 from torchrec.distributed.fused_embeddingbag import ShardedFusedEmbeddingBagCollection
 from torchrec.distributed.model_parallel import DistributedModelParallel
 from torchrec.distributed.planner import (
@@ -65,6 +69,7 @@
 from torchrec.modules.embedding_modules import EmbeddingBagCollection
 from torchrec.modules.fused_embedding_modules import FusedEmbeddingBagCollection
 from torchrec.optim.rowwise_adagrad import RowWiseAdagrad
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 from torchrec.test_utils import get_free_port, seed_and_log


@@ -205,13 +210,23 @@ def setUp(self, backend: str = "nccl") -> None:

         dist.init_process_group(backend=self.backend)

+    @classmethod
+    def setUpClass(cls) -> None:
+        super().setUpClass()
+        cls.patcher = PyPatchJustKnobs()
+
     def tearDown(self) -> None:
         dist.destroy_process_group()

     def test_sharding_ebc_as_top_level(self) -> None:
+        model = self._create_sharded_model()
+
+        self.assertTrue(isinstance(model.module, ShardedEmbeddingBagCollection))
+
+    def test_sharding_fused_ebc_as_top_level(self) -> None:
         embedding_dim = 128
         num_embeddings = 256
-        ebc = EmbeddingBagCollection(
+        ebc = FusedEmbeddingBagCollection(
             device=torch.device("meta"),
             tables=[
                 EmbeddingBagConfig(
@@ -222,16 +237,67 @@ def test_sharding_ebc_as_top_level(self) -> None:
                     pooling=PoolingType.SUM,
                 ),
             ],
+            optimizer_type=torch.optim.SGD,
+            optimizer_kwargs={"lr": 0.02},
         )

         model = DistributedModelParallel(ebc, device=self.device)

-        self.assertTrue(isinstance(model.module, ShardedEmbeddingBagCollection))
+        self.assertTrue(isinstance(model.module, ShardedFusedEmbeddingBagCollection))

-    def test_sharding_fused_ebc_as_top_level(self) -> None:
-        embedding_dim = 128
-        num_embeddings = 256
-        ebc = FusedEmbeddingBagCollection(
+    def test_sharding_ebc_input_validation_enabled(self) -> None:
+        model = self._create_sharded_model()
+        kjt = KeyedJaggedTensor(
+            keys=["my_feature", "my_feature"],
+            values=torch.tensor([1, 2, 3, 4, 5]),
+            lengths=torch.tensor([1, 2, 0, 2]),
+            offsets=torch.tensor([0, 1, 3, 3, 5]),
+        )
+
+        with self.patcher.patch("pytorch/torchrec:enable_kjt_validation", True):
+            with self.assertRaisesRegex(ValueError, "keys must be unique"):
+                model(kjt)
+
+    def test_sharding_ebc_validate_input_only_once(self) -> None:
+        model = self._create_sharded_model()
+        kjt = KeyedJaggedTensor(
+            keys=["my_feature"],
+            values=torch.tensor([1, 2, 3, 4, 5]),
+            lengths=torch.tensor([1, 2, 0, 2]),
+            offsets=torch.tensor([0, 1, 3, 3, 5]),
+        ).to(self.device)
+
+        with self.patcher.patch("pytorch/torchrec:enable_kjt_validation", True):
+            with self.assertLogs(embeddingbag_logger, level="INFO") as logs:
+                model(kjt)
+                model(kjt)
+                model(kjt)
+
+        matched_logs = list(
+            filter(lambda s: "Validating input features..." in s, logs.output)
+        )
+        self.assertEqual(1, len(matched_logs))
+
+    def test_sharding_ebc_input_validation_disabled(self) -> None:
+        model = self._create_sharded_model()
+        kjt = KeyedJaggedTensor(
+            keys=["my_feature", "my_feature"],
+            values=torch.tensor([1, 2, 3, 4, 5]),
+            lengths=torch.tensor([1, 2, 0, 2]),
+            offsets=torch.tensor([0, 1, 3, 3, 5]),
+        ).to(self.device)
+
+        # Without KJT validation, input_dist will not raise exceptions
+        with self.patcher.patch("pytorch/torchrec:enable_kjt_validation", False):
+            try:
+                model(kjt)
+            except ValueError:
+                self.fail("Input validation should not be enabled.")
+
+    def _create_sharded_model(
+        self, embedding_dim: int = 128, num_embeddings: int = 256
+    ) -> DistributedModelParallel:
+        ebc = EmbeddingBagCollection(
             device=torch.device("meta"),
             tables=[
                 EmbeddingBagConfig(
@@ -242,13 +308,8 @@ def test_sharding_fused_ebc_as_top_level(self) -> None:
                     pooling=PoolingType.SUM,
                 ),
             ],
-            optimizer_type=torch.optim.SGD,
-            optimizer_kwargs={"lr": 0.02},
         )
-
-        model = DistributedModelParallel(ebc, device=self.device)
-
-        self.assertTrue(isinstance(model.module, ShardedFusedEmbeddingBagCollection))
+        return DistributedModelParallel(ebc, device=self.device)


 class ModelParallelSingleRankBase(unittest.TestCase):
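
The tests drive the killswitch through `pyjk`, a Meta-internal JustKnobs patching utility, used as a context manager. A condensed sketch of the toggle pattern, using only the calls that appear in the diff above (the `model(kjt)` placeholders stand in for the real forward passes):

from pyjk import PyPatchJustKnobs

patcher = PyPatchJustKnobs()

# Knob forced on: the first forward pass through the sharded EBC
# validates the input KJT and raises ValueError on a malformed one.
with patcher.patch("pytorch/torchrec:enable_kjt_validation", True):
    ...  # model(kjt)

# Knob forced off: malformed inputs pass through input_dist unvalidated.
with patcher.patch("pytorch/torchrec:enable_kjt_validation", False):
    ...  # model(kjt)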
