|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | +import threading |
| 7 | +from typing import cast, Optional |
| 8 | + |
| 9 | +import torch |
| 10 | +import torch.nn as nn |
| 11 | +from torch import Tensor |
| 12 | + |
| 13 | +from torch.distributed.pipelining.schedules import ( |
| 14 | + _Action, |
| 15 | + _PipelineContext, |
| 16 | + _PipelineScheduleRuntime, |
| 17 | + _wait_batch_p2p, |
| 18 | +) |
| 19 | +from torch.distributed.pipelining.stage import _PipelineStageBase |
| 20 | +from torch.distributed.tensor import DeviceMesh, distribute_module |
| 21 | +from torch.profiler import record_function |
| 22 | + |
| 23 | +from torchtitan.distributed.expert_parallel import BaseExpertParallel |
| 24 | + |
| 25 | +from torchtitan.tools.utils import get_device_info |
| 26 | + |
| 27 | +""" |
| 28 | +Below are optimizations related to pipeline parallelism with expert parallelism |
| 29 | +""" |
| 30 | + |
| 31 | + |
| 32 | +def get_dual_pipe_v_flag(job_config, parallel_dims) -> bool: |
| 33 | + """ |
| 34 | + Determine if DualPipeV should be enabled based on config and |
| 35 | + validates that incompatible features (EP + DualPipeV + AC) are not used together. |
| 36 | + """ |
| 37 | + if not parallel_dims.ep_enabled or not parallel_dims.pp_enabled: |
| 38 | + return False |
| 39 | + |
| 40 | + dual_pipe_v = ( |
| 41 | + job_config.parallelism.pipeline_parallel_expert_parallel_overlap |
| 42 | + and job_config.parallelism.pipeline_parallel_schedule.lower() == "dualpipev" |
| 43 | + ) |
| 44 | + |
| 45 | + if dual_pipe_v and job_config.activation_checkpoint.mode != "none": |
| 46 | + raise NotImplementedError( |
| 47 | + "Expert Parallel with DualPipeV and Activation Checkpointing " |
| 48 | + "cannot be used together. Please disable one of them." |
| 49 | + ) |
| 50 | + |
| 51 | + return dual_pipe_v |
| 52 | + |
| 53 | + |
| 54 | +class DualPipeExpertParallel(BaseExpertParallel): |
| 55 | + """ |
| 56 | + Wrapper that adds dual-pipe synchronization hooks to any BaseExpertParallel. |
| 57 | + Wraps dispatch/combine with sync hooks for overlapping EP communication |
| 58 | + with PP computation in DualPipe scheduling. |
| 59 | +
|
| 60 | + The execution order becomes: |
| 61 | + A -> dispatch -> B -> module -> C -> combine -> D |
| 62 | + """ |
| 63 | + |
| 64 | + def __init__(self, inner_ep: BaseExpertParallel): |
| 65 | + super().__init__() |
| 66 | + self.inner_ep = inner_ep |
| 67 | + |
| 68 | + def _partition_fn(self, name: str, mod: nn.Module, device_mesh: DeviceMesh) -> None: |
| 69 | + return self.inner_ep._partition_fn(name, mod, device_mesh) |
| 70 | + |
| 71 | + def _token_dispatch( |
| 72 | + self, mod: nn.Module, inputs: tuple, device_mesh: DeviceMesh |
| 73 | + ) -> tuple[Tensor, Tensor]: |
| 74 | + """A -> dispatch -> B""" |
| 75 | + inputs = (cast(Tensor, SyncHook.apply(inputs[0], "A")),) + inputs[1:] |
| 76 | + outputs = self.inner_ep._token_dispatch(mod, inputs, device_mesh) |
| 77 | + outputs = (cast(Tensor, SyncHook.apply(outputs[0], "B")),) + outputs[1:] |
| 78 | + return outputs |
| 79 | + |
| 80 | + def _token_combine( |
| 81 | + self, mod: nn.Module, routed_output: Tensor, device_mesh: DeviceMesh |
| 82 | + ) -> Tensor: |
| 83 | + """C -> combine -> D""" |
| 84 | + routed_output = cast(Tensor, SyncHook.apply(routed_output, "C")) |
| 85 | + combine_output = self.inner_ep._token_combine(mod, routed_output, device_mesh) |
| 86 | + combine_output = cast(Tensor, SyncHook.apply(combine_output, "D")) |
| 87 | + return combine_output |
| 88 | + |
| 89 | + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: |
| 90 | + return distribute_module( |
| 91 | + module, |
| 92 | + device_mesh, |
| 93 | + partition_fn=self._partition_fn, |
| 94 | + input_fn=self._token_dispatch, |
| 95 | + output_fn=self._token_combine, |
| 96 | + ) |
| 97 | + |
| 98 | + |
| 99 | +class HookCoordinator: |
| 100 | + def __init__(self): |
| 101 | + # Barrier for 2 threads (forward and backward) to synchronize |
| 102 | + # This ensures that we always alternate at executing one compute and one comm op together |
| 103 | + self._execution_barrier = threading.Barrier(2) |
| 104 | + |
| 105 | + self._coordination_enabled = False |
| 106 | + self._cycle_count = 0 |
| 107 | + self._num_layers = None |
| 108 | + |
| 109 | + def barrier(self): |
| 110 | + """Barrier for 2 threads to synchronize""" |
| 111 | + if not self.is_coordination_enabled(): |
| 112 | + return |
| 113 | + |
| 114 | + try: |
| 115 | + self._execution_barrier.wait() |
| 116 | + except threading.BrokenBarrierError: |
| 117 | + pass |
| 118 | + |
| 119 | + def enable_coordination(self, num_layers: Optional[int] = None): |
| 120 | + if num_layers is not None and num_layers > 0: |
| 121 | + self._coordination_enabled = True |
| 122 | + self._cycle_count = 0 |
| 123 | + |
| 124 | + # Reset barrier |
| 125 | + self._execution_barrier = threading.Barrier(2) |
| 126 | + self._num_layers = num_layers |
| 127 | + |
| 128 | + def disable_coordination(self): |
| 129 | + self._coordination_enabled = False |
| 130 | + self._cycle_count = 0 |
| 131 | + self._execution_barrier.abort() # Break barrier to unblock threads |
| 132 | + |
| 133 | + def check_should_continue_coordination(self): |
| 134 | + if self._num_layers is not None and self._cycle_count >= self._num_layers: |
| 135 | + return False |
| 136 | + return True |
| 137 | + |
| 138 | + def is_coordination_enabled(self): |
| 139 | + return self._coordination_enabled |
| 140 | + |
| 141 | + |
| 142 | +# Global coordinator |
| 143 | +_hook_coordinator = HookCoordinator() |
| 144 | + |
| 145 | + |
| 146 | +class SyncHook(torch.autograd.Function): |
| 147 | + @staticmethod |
| 148 | + def forward(ctx, x, hook_name=""): |
| 149 | + ctx.hook_name = hook_name |
| 150 | + # handle edge case for transformer level boundary |
| 151 | + if _hook_coordinator._coordination_enabled and hook_name == "D": |
| 152 | + _hook_coordinator._cycle_count += 1 |
| 153 | + if not _hook_coordinator.check_should_continue_coordination(): |
| 154 | + _hook_coordinator.disable_coordination() |
| 155 | + return x |
| 156 | + |
| 157 | + _hook_coordinator.barrier() |
| 158 | + return x |
| 159 | + |
| 160 | + @staticmethod |
| 161 | + def backward(ctx, grad_output): |
| 162 | + hook_name = ctx.hook_name |
| 163 | + |
| 164 | + # Edge case, skip initial barrier, all subsequent backward hooks will acquire |
| 165 | + if hook_name == "D" and _hook_coordinator._cycle_count == 0: |
| 166 | + return grad_output, None |
| 167 | + |
| 168 | + _hook_coordinator.barrier() |
| 169 | + return grad_output, None |
| 170 | + |
| 171 | + |
| 172 | +def _count_moe_modules(model): |
| 173 | + """Count MoE modules directly""" |
| 174 | + from torchtitan.models.moe import MoE |
| 175 | + |
| 176 | + moe_count = 0 |
| 177 | + for _, module in model.named_modules(): |
| 178 | + if isinstance(module, MoE): |
| 179 | + moe_count += 1 |
| 180 | + return moe_count |
| 181 | + |
| 182 | + |
| 183 | +device_type, device_module = get_device_info() |
| 184 | + |
| 185 | + |
| 186 | +def overlap_callback(action: _Action, ctx: _PipelineContext): |
| 187 | + """ |
| 188 | + Custom callback for OVERLAP_F_B computation that allows expert parallel communication |
| 189 | + and pipeline parallel computation to overlap. |
| 190 | + """ |
| 191 | + schedule = ctx.schedule_ref |
| 192 | + assert isinstance(schedule, _PipelineScheduleRuntime) |
| 193 | + stage_index_to_stage: dict[int, _PipelineStageBase] = { |
| 194 | + stage.stage_index: stage for stage in schedule._stages |
| 195 | + } |
| 196 | + assert action.sub_actions is not None |
| 197 | + fwd_action = action.sub_actions[0] |
| 198 | + bwd_action = action.sub_actions[1] |
| 199 | + |
| 200 | + # Get stages |
| 201 | + forward_stage_index = fwd_action.stage_index |
| 202 | + forward_mb_index = fwd_action.microbatch_index |
| 203 | + assert forward_mb_index is not None |
| 204 | + backward_stage_index = bwd_action.stage_index |
| 205 | + backward_stage = stage_index_to_stage[backward_stage_index] |
| 206 | + |
| 207 | + # Forward setup |
| 208 | + arg_mbs = ctx.arg_mbs |
| 209 | + kwarg_mbs = ctx.kwarg_mbs |
| 210 | + assert arg_mbs is not None and kwarg_mbs is not None |
| 211 | + fwd_recv_ops = schedule.fwd_recv_ops |
| 212 | + forward_stage = stage_index_to_stage[forward_stage_index] |
| 213 | + forward_is_next_stage_on_this_rank = forward_stage_index + 1 in stage_index_to_stage |
| 214 | + forward_is_prev_stage_on_this_rank = forward_stage_index - 1 in stage_index_to_stage |
| 215 | + |
| 216 | + # Backward setup |
| 217 | + backward_is_next_stage_on_this_rank = ( |
| 218 | + backward_stage.stage_index + 1 in stage_index_to_stage |
| 219 | + ) |
| 220 | + backward_is_prev_stage_on_this_rank = ( |
| 221 | + backward_stage.stage_index - 1 in stage_index_to_stage |
| 222 | + ) |
| 223 | + backward_mb_index = bwd_action.microbatch_index |
| 224 | + assert backward_mb_index is not None |
| 225 | + bwd_recv_ops = schedule.bwd_recv_ops |
| 226 | + |
| 227 | + # Fwd receives |
| 228 | + if ( |
| 229 | + not forward_stage.is_first |
| 230 | + # no recv op expected for V-schedule special case |
| 231 | + and not forward_is_prev_stage_on_this_rank |
| 232 | + ): |
| 233 | + assert ( |
| 234 | + forward_stage_index, |
| 235 | + forward_mb_index, |
| 236 | + ) in fwd_recv_ops, f"Computing {action=} before receiving input" |
| 237 | + _wait_batch_p2p(fwd_recv_ops.pop((forward_stage_index, forward_mb_index))) |
| 238 | + |
| 239 | + # Bwd receives |
| 240 | + if ( |
| 241 | + not backward_stage.is_last |
| 242 | + # no recv op expected for V-schedule special case |
| 243 | + and not backward_is_next_stage_on_this_rank |
| 244 | + ): |
| 245 | + assert ( |
| 246 | + backward_stage_index, |
| 247 | + backward_mb_index, |
| 248 | + ) in bwd_recv_ops, f"Attempted to run compute {action=} before receiving input" |
| 249 | + _wait_batch_p2p(bwd_recv_ops.pop((backward_stage_index, backward_mb_index))) |
| 250 | + |
| 251 | + # We count num layers in case the stage layers differ |
| 252 | + # If they differ than we only want coordination to happen for the min amount of layers |
| 253 | + min_num_layers = min( |
| 254 | + _count_moe_modules(forward_stage.submod), |
| 255 | + _count_moe_modules(backward_stage.submod), |
| 256 | + ) |
| 257 | + # PP computation ======================================================== |
| 258 | + _hook_coordinator.enable_coordination(num_layers=min_num_layers) |
| 259 | + main_stream = torch.accelerator.current_stream(device_module) |
| 260 | + |
| 261 | + # Shared container for exception from backward thread |
| 262 | + def run_backward(): |
| 263 | + schedule._assert_unsharded(backward_stage) |
| 264 | + # Set the backward thread to use the same stream as forward |
| 265 | + device_module.set_stream(main_stream) |
| 266 | + with record_function( |
| 267 | + f"backward_stage_{backward_stage_index}_mb_{backward_mb_index}" |
| 268 | + ): |
| 269 | + loss = schedule._maybe_get_loss(backward_stage, backward_mb_index) |
| 270 | + schedule.backward_counter[backward_stage_index] += 1 |
| 271 | + last_backward = ( |
| 272 | + schedule.backward_counter[backward_stage_index] |
| 273 | + == schedule._n_microbatches |
| 274 | + ) |
| 275 | + backward_stage.backward_one_chunk( |
| 276 | + backward_mb_index, |
| 277 | + loss=loss, |
| 278 | + full_backward=True, |
| 279 | + last_backward=last_backward, |
| 280 | + ) |
| 281 | + |
| 282 | + if backward_is_prev_stage_on_this_rank: |
| 283 | + stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input( |
| 284 | + backward_stage.get_local_bwd_output(backward_mb_index), |
| 285 | + backward_mb_index, |
| 286 | + ) |
| 287 | + |
| 288 | + def run_forward(): |
| 289 | + schedule._assert_unsharded(forward_stage) |
| 290 | + output = forward_stage.forward_one_chunk( |
| 291 | + forward_mb_index, |
| 292 | + arg_mbs[forward_mb_index], |
| 293 | + kwarg_mbs[forward_mb_index], |
| 294 | + ) |
| 295 | + schedule._maybe_compute_loss( |
| 296 | + forward_stage, output, ctx.target_mbs, forward_mb_index |
| 297 | + ) |
| 298 | + if forward_is_next_stage_on_this_rank: |
| 299 | + stage_index_to_stage[forward_stage_index + 1].set_local_fwd_input( |
| 300 | + output, forward_mb_index |
| 301 | + ) |
| 302 | + |
| 303 | + # Run forward and backward in parallel |
| 304 | + thread = threading.Thread(target=run_backward, daemon=True) |
| 305 | + thread.start() |
| 306 | + run_forward() |
| 307 | + thread.join() |
| 308 | + |
| 309 | + _hook_coordinator.disable_coordination() |
0 commit comments