Commit f975c49

2025-12-12 nightly release (fcc5643)

Author: pytorchbot
1 parent: aa46144

File tree: 15 files changed (+435, -44 lines)


.ci/docker/requirements.txt

Lines changed: 1 addition & 0 deletions

@@ -8,3 +8,4 @@ fsspec
 tyro
 tokenizers >= 0.15.0
 safetensors
+psutil

README.md

Lines changed: 7 additions & 2 deletions

@@ -40,9 +40,14 @@ The Guiding Principles when building `torchtitan`
 * Minimal changes to the model code when applying multi-dimensional parallelism.
 * Bias towards a clean, minimal codebase while providing basic reusable / swappable components.
 
-`torchtitan` has been showcasing PyTorch's latest distributed training features, via pretraining Llama 3.1 LLMs of various sizes.
-To accelerate contributions to and innovations around torchtitan, we host an [`experiments`](torchtitan/experiments) folder. We look forward to your contributions!
+`torchtitan` has been showcasing PyTorch's latest distributed training features, via support for pretraining Llama 3.1 LLMs of various sizes.
 
+## Contributing
+
+We look forward to your contributions!
+
+* To accelerate contributions to and innovations around torchtitan, we host an [`experiments`](torchtitan/experiments) folder. New ideas should start there. To contribute, follow the [`experiments guidelines`](torchtitan/experiments/README.md).
+* For fixes and contributions to core, follow these [`guidelines`](CONTRIBUTING.md).
 
 ## Llama 3.1 training

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -24,6 +24,7 @@ dependencies = [
     "fsspec",
     "tyro",
     "tensorboard",
+    "psutil",
 ]
 dynamic = ["version"]

tests/integration_tests/models.py

Lines changed: 15 additions & 0 deletions

@@ -32,6 +32,21 @@ def build_model_tests_list() -> list[OverrideDefinitions]:
             "deepseek_v3_fsdp+ep+compile",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--model.name deepseek_v3",
+                    "--parallelism.pipeline_parallel_degree 2",
+                    "--parallelism.expert_parallel_degree 2",
+                    "--parallelism.pipeline_parallel_schedule DualPipeV",
+                    # AC is not supported for DualPipeV yet
+                    "--activation_checkpoint.mode 'none'",
+                ],
+            ],
+            "PP dual pipe v schedule test",
+            "pp_dualpipev",
+            ngpu=4,
+        ),
         OverrideDefinitions(
             [
                 [

torchtitan/config/job_config.py

Lines changed: 8 additions & 0 deletions

@@ -373,6 +373,14 @@ class Parallelism:
     The global training batch size must be evenly divisible by pipeline_parallel_microbatch_size.
     """
 
+    pipeline_parallel_expert_parallel_overlap: bool = True
+    """Whether to turn on the optimization to overlap expert parallel and pipeline parallel
+    communication. This is only effective when the pipeline parallel schedule is DualPipeV and
+    pipeline_parallel_degree > 1 and expert_parallel_degree > 1.
+
+    TODO: Does not support activation_checkpoint, set mode="none"
+    """
+
     context_parallel_degree: int = 1
     """Context parallelism degree. 1 means disabled."""

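For orientation, here is a minimal sketch of when this new option actually takes effect, paraphrasing the docstring above and the get_dual_pipe_v_flag helper defined in the new module later in this commit. The SimpleNamespace objects are illustrative stand-ins for the job config and parallel-dims objects, not torchtitan's real API.

# Illustrative stand-ins only: they mirror just the attributes the gating logic reads.
from types import SimpleNamespace

parallel_dims = SimpleNamespace(pp_enabled=True, ep_enabled=True)
parallelism = SimpleNamespace(
    pipeline_parallel_expert_parallel_overlap=True,
    pipeline_parallel_schedule="DualPipeV",
)
activation_checkpoint = SimpleNamespace(mode="none")

overlap_enabled = (
    parallel_dims.pp_enabled
    and parallel_dims.ep_enabled
    and parallelism.pipeline_parallel_expert_parallel_overlap
    and parallelism.pipeline_parallel_schedule.lower() == "dualpipev"
)
assert overlap_enabled
# If activation_checkpoint.mode were not "none", the real helper raises
# NotImplementedError rather than silently disabling the overlap.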
Lines changed: 309 additions & 0 deletions

@@ -0,0 +1,309 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import threading
from typing import cast, Optional

import torch
import torch.nn as nn
from torch import Tensor

from torch.distributed.pipelining.schedules import (
    _Action,
    _PipelineContext,
    _PipelineScheduleRuntime,
    _wait_batch_p2p,
)
from torch.distributed.pipelining.stage import _PipelineStageBase
from torch.distributed.tensor import DeviceMesh, distribute_module
from torch.profiler import record_function

from torchtitan.distributed.expert_parallel import BaseExpertParallel

from torchtitan.tools.utils import get_device_info

"""
Below are optimizations related to pipeline parallelism with expert parallelism
"""


def get_dual_pipe_v_flag(job_config, parallel_dims) -> bool:
    """
    Determine if DualPipeV should be enabled based on config and
    validates that incompatible features (EP + DualPipeV + AC) are not used together.
    """
    if not parallel_dims.ep_enabled or not parallel_dims.pp_enabled:
        return False

    dual_pipe_v = (
        job_config.parallelism.pipeline_parallel_expert_parallel_overlap
        and job_config.parallelism.pipeline_parallel_schedule.lower() == "dualpipev"
    )

    if dual_pipe_v and job_config.activation_checkpoint.mode != "none":
        raise NotImplementedError(
            "Expert Parallel with DualPipeV and Activation Checkpointing "
            "cannot be used together. Please disable one of them."
        )

    return dual_pipe_v


class DualPipeExpertParallel(BaseExpertParallel):
    """
    Wrapper that adds dual-pipe synchronization hooks to any BaseExpertParallel.
    Wraps dispatch/combine with sync hooks for overlapping EP communication
    with PP computation in DualPipe scheduling.

    The execution order becomes:
    A -> dispatch -> B -> module -> C -> combine -> D
    """

    def __init__(self, inner_ep: BaseExpertParallel):
        super().__init__()
        self.inner_ep = inner_ep

    def _partition_fn(self, name: str, mod: nn.Module, device_mesh: DeviceMesh) -> None:
        return self.inner_ep._partition_fn(name, mod, device_mesh)

    def _token_dispatch(
        self, mod: nn.Module, inputs: tuple, device_mesh: DeviceMesh
    ) -> tuple[Tensor, Tensor]:
        """A -> dispatch -> B"""
        inputs = (cast(Tensor, SyncHook.apply(inputs[0], "A")),) + inputs[1:]
        outputs = self.inner_ep._token_dispatch(mod, inputs, device_mesh)
        outputs = (cast(Tensor, SyncHook.apply(outputs[0], "B")),) + outputs[1:]
        return outputs

    def _token_combine(
        self, mod: nn.Module, routed_output: Tensor, device_mesh: DeviceMesh
    ) -> Tensor:
        """C -> combine -> D"""
        routed_output = cast(Tensor, SyncHook.apply(routed_output, "C"))
        combine_output = self.inner_ep._token_combine(mod, routed_output, device_mesh)
        combine_output = cast(Tensor, SyncHook.apply(combine_output, "D"))
        return combine_output

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            partition_fn=self._partition_fn,
            input_fn=self._token_dispatch,
            output_fn=self._token_combine,
        )


class HookCoordinator:
    def __init__(self):
        # Barrier for 2 threads (forward and backward) to synchronize
        # This ensures that we always alternate at executing one compute and one comm op together
        self._execution_barrier = threading.Barrier(2)

        self._coordination_enabled = False
        self._cycle_count = 0
        self._num_layers = None

    def barrier(self):
        """Barrier for 2 threads to synchronize"""
        if not self.is_coordination_enabled():
            return

        try:
            self._execution_barrier.wait()
        except threading.BrokenBarrierError:
            pass

    def enable_coordination(self, num_layers: Optional[int] = None):
        if num_layers is not None and num_layers > 0:
            self._coordination_enabled = True
            self._cycle_count = 0

            # Reset barrier
            self._execution_barrier = threading.Barrier(2)
            self._num_layers = num_layers

    def disable_coordination(self):
        self._coordination_enabled = False
        self._cycle_count = 0
        self._execution_barrier.abort()  # Break barrier to unblock threads

    def check_should_continue_coordination(self):
        if self._num_layers is not None and self._cycle_count >= self._num_layers:
            return False
        return True

    def is_coordination_enabled(self):
        return self._coordination_enabled


# Global coordinator
_hook_coordinator = HookCoordinator()


class SyncHook(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, hook_name=""):
        ctx.hook_name = hook_name
        # handle edge case for transformer level boundary
        if _hook_coordinator._coordination_enabled and hook_name == "D":
            _hook_coordinator._cycle_count += 1
            if not _hook_coordinator.check_should_continue_coordination():
                _hook_coordinator.disable_coordination()
                return x

        _hook_coordinator.barrier()
        return x

    @staticmethod
    def backward(ctx, grad_output):
        hook_name = ctx.hook_name

        # Edge case, skip initial barrier, all subsequent backward hooks will acquire
        if hook_name == "D" and _hook_coordinator._cycle_count == 0:
            return grad_output, None

        _hook_coordinator.barrier()
        return grad_output, None


def _count_moe_modules(model):
    """Count MoE modules directly"""
    from torchtitan.models.moe import MoE

    moe_count = 0
    for _, module in model.named_modules():
        if isinstance(module, MoE):
            moe_count += 1
    return moe_count


device_type, device_module = get_device_info()


def overlap_callback(action: _Action, ctx: _PipelineContext):
    """
    Custom callback for OVERLAP_F_B computation that allows expert parallel communication
    and pipeline parallel computation to overlap.
    """
    schedule = ctx.schedule_ref
    assert isinstance(schedule, _PipelineScheduleRuntime)
    stage_index_to_stage: dict[int, _PipelineStageBase] = {
        stage.stage_index: stage for stage in schedule._stages
    }
    assert action.sub_actions is not None
    fwd_action = action.sub_actions[0]
    bwd_action = action.sub_actions[1]

    # Get stages
    forward_stage_index = fwd_action.stage_index
    forward_mb_index = fwd_action.microbatch_index
    assert forward_mb_index is not None
    backward_stage_index = bwd_action.stage_index
    backward_stage = stage_index_to_stage[backward_stage_index]

    # Forward setup
    arg_mbs = ctx.arg_mbs
    kwarg_mbs = ctx.kwarg_mbs
    assert arg_mbs is not None and kwarg_mbs is not None
    fwd_recv_ops = schedule.fwd_recv_ops
    forward_stage = stage_index_to_stage[forward_stage_index]
    forward_is_next_stage_on_this_rank = forward_stage_index + 1 in stage_index_to_stage
    forward_is_prev_stage_on_this_rank = forward_stage_index - 1 in stage_index_to_stage

    # Backward setup
    backward_is_next_stage_on_this_rank = (
        backward_stage.stage_index + 1 in stage_index_to_stage
    )
    backward_is_prev_stage_on_this_rank = (
        backward_stage.stage_index - 1 in stage_index_to_stage
    )
    backward_mb_index = bwd_action.microbatch_index
    assert backward_mb_index is not None
    bwd_recv_ops = schedule.bwd_recv_ops

    # Fwd receives
    if (
        not forward_stage.is_first
        # no recv op expected for V-schedule special case
        and not forward_is_prev_stage_on_this_rank
    ):
        assert (
            forward_stage_index,
            forward_mb_index,
        ) in fwd_recv_ops, f"Computing {action=} before receiving input"
        _wait_batch_p2p(fwd_recv_ops.pop((forward_stage_index, forward_mb_index)))

    # Bwd receives
    if (
        not backward_stage.is_last
        # no recv op expected for V-schedule special case
        and not backward_is_next_stage_on_this_rank
    ):
        assert (
            backward_stage_index,
            backward_mb_index,
        ) in bwd_recv_ops, f"Attempted to run compute {action=} before receiving input"
        _wait_batch_p2p(bwd_recv_ops.pop((backward_stage_index, backward_mb_index)))

    # We count num layers in case the stage layers differ
    # If they differ then we only want coordination to happen for the min amount of layers
    min_num_layers = min(
        _count_moe_modules(forward_stage.submod),
        _count_moe_modules(backward_stage.submod),
    )
    # PP computation ========================================================
    _hook_coordinator.enable_coordination(num_layers=min_num_layers)
    main_stream = torch.accelerator.current_stream(device_module)

    # Backward for this microbatch runs on a helper thread (started below)
    def run_backward():
        schedule._assert_unsharded(backward_stage)
        # Set the backward thread to use the same stream as forward
        device_module.set_stream(main_stream)
        with record_function(
            f"backward_stage_{backward_stage_index}_mb_{backward_mb_index}"
        ):
            loss = schedule._maybe_get_loss(backward_stage, backward_mb_index)
            schedule.backward_counter[backward_stage_index] += 1
            last_backward = (
                schedule.backward_counter[backward_stage_index]
                == schedule._n_microbatches
            )
            backward_stage.backward_one_chunk(
                backward_mb_index,
                loss=loss,
                full_backward=True,
                last_backward=last_backward,
            )

            if backward_is_prev_stage_on_this_rank:
                stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input(
                    backward_stage.get_local_bwd_output(backward_mb_index),
                    backward_mb_index,
                )

    def run_forward():
        schedule._assert_unsharded(forward_stage)
        output = forward_stage.forward_one_chunk(
            forward_mb_index,
            arg_mbs[forward_mb_index],
            kwarg_mbs[forward_mb_index],
        )
        schedule._maybe_compute_loss(
            forward_stage, output, ctx.target_mbs, forward_mb_index
        )
        if forward_is_next_stage_on_this_rank:
            stage_index_to_stage[forward_stage_index + 1].set_local_fwd_input(
                output, forward_mb_index
            )

    # Run forward and backward in parallel
    thread = threading.Thread(target=run_backward, daemon=True)
    thread.start()
    run_forward()
    thread.join()

    _hook_coordinator.disable_coordination()
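As an aside, the pairing that HookCoordinator enforces can be illustrated in isolation: two threads share a threading.Barrier(2), so neither can get more than one step ahead of the other, which is how one compute step and one communication step end up executing together. The sketch below uses only the standard library; the names forward_like and backward_like are illustrative stand-ins, not part of this commit.

import threading

barrier = threading.Barrier(2)
log = []

def forward_like(steps):
    for i in range(steps):
        log.append(f"fwd step {i}")  # stands in for PP compute on the main thread
        barrier.wait()               # rendezvous with the other thread

def backward_like(steps):
    for i in range(steps):
        log.append(f"bwd step {i}")  # stands in for EP dispatch/combine comms
        barrier.wait()

t = threading.Thread(target=backward_like, args=(3,), daemon=True)
t.start()
forward_like(3)
t.join()
# The two threads advance in lockstep rounds: after each round, both have
# completed the same number of steps.
print(log)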

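SyncHook itself is an identity autograd.Function: it leaves values and gradients untouched and only provides forward/backward call sites where the coordinator's barrier can run, which is why wrapping dispatch/combine with it does not change numerics. Below is a minimal stand-alone sketch of that pattern, with the barrier replaced by a print so it runs anywhere on CPU; IdentityHook is a hypothetical name, not part of this commit.

import torch

class IdentityHook(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, hook_name=""):
        ctx.hook_name = hook_name
        print(f"forward hook {hook_name}")  # a real hook would call a barrier here
        return x                            # value passes through unchanged

    @staticmethod
    def backward(ctx, grad_output):
        print(f"backward hook {ctx.hook_name}")
        return grad_output, None            # gradient passes through; none for hook_name

x = torch.ones(3, requires_grad=True)
y = IdentityHook.apply(x * 2, "A").sum()
y.backward()
print(x.grad)  # tensor([2., 2., 2.]): the hook did not alter the gradient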
0 commit comments