
Commit 03a0d93

peterfu0 authored and facebook-github-bot committed
add prefetch in customized order pipeline (#3349)
Summary: Pull Request resolved: #3349
Reviewed By: TroyGarden
Differential Revision: D79404930
fbshipit-source-id: 763afb0afa51d089dc75f7ef17a3262748c76feb
1 parent 1cbe55d · commit 03a0d93

2 files changed (+156, -0 lines)

torchrec/distributed/train_pipeline/runtime_forwards.py

Lines changed: 62 additions & 0 deletions
```diff
@@ -223,6 +223,7 @@ def detach_embeddings(
 class PrefetchPipelinedForward(BaseForward[PrefetchTrainPipelineContext]):
     """
     This pipeline is used in PrefetchTrainPipelineSparseDist
+    OR in TrainPipelineCustomizedOrderSparseDist, when prefetch is enabled but pipeline_embedding_lookup_fwd is disabled.
     """

     def __init__(
@@ -267,6 +268,67 @@ def __call__(self, *input, **kwargs) -> Awaitable:
         return self._module.compute_and_output_dist(ctx, data)


+class PrefetchEmbeddingPipelinedForward(PrefetchPipelinedForward):
+    """
+    This pipeline is used in TrainPipelineCustomizedOrderSparseDist when
+    prefetch is enabled and pipelined_sparse_lookup_fwd is enabled.
+    compute_and_output_dist for batch N is called at the end of step N - 1.
+    """
+
+    def __init__(
+        self,
+        name: str,
+        args: CallArgs,
+        module: ShardedModule,
+        context: PrefetchTrainPipelineContext,
+        prefetch_stream: Optional[torch.Stream] = None,
+    ) -> None:
+        super().__init__(
+            name=name,
+            args=args,
+            module=module,
+            context=context,
+            prefetch_stream=prefetch_stream,
+        )
+        self._compute_and_output_dist_awaitable: Optional[
+            Awaitable[Multistreamable]
+        ] = None
+
+    def compute_and_output_dist(self) -> None:
+        assert (
+            self._name in self._context.module_input_post_prefetch
+        ), "Invalid PrefetchEmbeddingPipelinedForward usage, please do not directly call model.forward()"
+        data = self._context.module_input_post_prefetch.pop(self._name)
+        ctx = self._context.module_contexts_post_prefetch.pop(self._name)
+
+        # Make sure that both the result of input_dist and the context
+        # are properly transferred to the current stream.
+        if self._stream is not None:
+            torch.get_device_module(self._device).current_stream().wait_stream(
+                self._stream
+            )
+            cur_stream = torch.get_device_module(self._device).current_stream()
+
+            assert isinstance(
+                data, (torch.Tensor, Multistreamable)
+            ), f"{type(data)} must implement Multistreamable interface"
+            data.record_stream(cur_stream)
+
+            ctx.record_stream(cur_stream)
+
+        self._compute_and_output_dist_awaitable = self._module.compute_and_output_dist(
+            ctx, data
+        )
+
+    # pyre-ignore [2, 24]
+    def __call__(self, *input, **kwargs) -> Awaitable:
+        if not self._compute_and_output_dist_awaitable:
+            raise Exception(
+                "compute_and_output_dist must be called before __call__",
+            )
+        return self._compute_and_output_dist_awaitable
+
+
 class KJTAllToAllForward:
     def __init__(
         self, pg: dist.ProcessGroup, splits: List[int], stagger: int = 1
```
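To make the two-phase call order concrete, here is a minimal, CPU-only sketch that drives the new forward by hand, standing in a MagicMock for the real ShardedModule exactly as the new unit tests do. In real training it is TrainPipelineCustomizedOrderSparseDist that stages the post-prefetch input and issues these calls; the module name and staged objects below are illustrative.

```python
# Minimal sketch of the PrefetchEmbeddingPipelinedForward protocol.
# MagicMock stands in for a real ShardedModule; in production the pipeline,
# not user code, populates the post-prefetch staging dicts.
from unittest.mock import MagicMock

from torchrec.distributed.train_pipeline.pipeline_context import (
    PrefetchTrainPipelineContext,
)
from torchrec.distributed.train_pipeline.runtime_forwards import (
    PrefetchEmbeddingPipelinedForward,
)
from torchrec.distributed.train_pipeline.types import CallArgs

context = PrefetchTrainPipelineContext()
fwd = PrefetchEmbeddingPipelinedForward(
    name="sparse_arch",  # illustrative module name
    args=CallArgs(args=[], kwargs={}),
    module=MagicMock(),
    context=context,
)

# End of step N - 1: batch N's prefetched input and context are staged,
# so the sparse compute/output_dist can be kicked off ahead of time.
context.module_input_post_prefetch = {"sparse_arch": MagicMock()}
context.module_contexts_post_prefetch = {"sparse_arch": MagicMock()}
fwd.compute_and_output_dist()

# Step N: the patched forward just hands back the stored awaitable, so the
# sparse work overlapped with whatever dense work ran in between.
awaitable = fwd()
```

Calling fwd() before compute_and_output_dist() raises the "compute_and_output_dist must be called before __call__" exception, which the new tests below cover.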
Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@ (new file)

```python
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest
from unittest.mock import MagicMock

from torchrec.distributed.train_pipeline.pipeline_context import (
    PrefetchTrainPipelineContext,
)
from torchrec.distributed.train_pipeline.runtime_forwards import (
    PrefetchEmbeddingPipelinedForward,
    PrefetchPipelinedForward,
)
from torchrec.distributed.train_pipeline.types import CallArgs


class TestPrefetchEmbeddingPipelinedForward(unittest.TestCase):
    """Test PrefetchEmbeddingPipelinedForward key functionality."""

    def setUp(self) -> None:
        """Set up test fixtures."""
        self.mock_module = MagicMock()
        self.prefetch_context = PrefetchTrainPipelineContext()
        self.mock_args = CallArgs(args=[], kwargs={})

    def test_is_prefetch_pipelined_forward(self) -> None:
        """Test that the forward is an instance of PrefetchPipelinedForward."""
        forward = PrefetchEmbeddingPipelinedForward(
            name="test_prefetch",
            args=self.mock_args,
            module=self.mock_module,
            context=self.prefetch_context,
        )

        # The subclass must stay usable wherever PrefetchPipelinedForward is expected
        self.assertIsInstance(forward, PrefetchPipelinedForward)

    def test_call_fails_without_compute_and_output_dist(self) -> None:
        """Test that __call__ fails if compute_and_output_dist is not called first."""
        forward = PrefetchEmbeddingPipelinedForward(
            name="test_call_error",
            args=self.mock_args,
            module=self.mock_module,
            context=self.prefetch_context,
        )

        # Should raise an exception when called without compute_and_output_dist
        with self.assertRaises(Exception) as context:
            forward()

        self.assertIn(
            "compute_and_output_dist must be called before __call__",
            str(context.exception),
        )

    def test_call_succeeds_after_compute_and_output_dist(self) -> None:
        """Test that __call__ succeeds when compute_and_output_dist is called first."""
        forward = PrefetchEmbeddingPipelinedForward(
            name="test_call_success",
            args=self.mock_args,
            module=self.mock_module,
            context=self.prefetch_context,
        )

        # Set up mock data in the context
        test_data = MagicMock()
        test_ctx = MagicMock()
        self.prefetch_context.module_input_post_prefetch = {
            "test_call_success": test_data
        }
        self.prefetch_context.module_contexts_post_prefetch = {
            "test_call_success": test_ctx
        }

        # Mock the module's compute_and_output_dist method
        mock_awaitable = MagicMock()
        self.mock_module.compute_and_output_dist.return_value = mock_awaitable

        # Call compute_and_output_dist first
        forward.compute_and_output_dist()

        # Now __call__ should succeed and return the awaitable
        result = forward()
        self.assertEqual(result, mock_awaitable)


if __name__ == "__main__":
    unittest.main()
```
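Design note: because these tests never pass prefetch_stream (and so, presumably via the parent class, self._stream stays None), compute_and_output_dist skips the stream-synchronization branch entirely. That is what lets plain MagicMock objects stand in for the sharded module, the staged input, and the context, so the call-order protocol can be exercised on CPU with no GPU or process group required.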
