diff --git a/deepspeed/compile/custom_ops/all_to_all.py b/deepspeed/compile/custom_ops/all_to_all.py
new file mode 100644
index 000000000000..6f4b5172b67a
--- /dev/null
+++ b/deepspeed/compile/custom_ops/all_to_all.py
@@ -0,0 +1,64 @@
+import torch
+import torch.distributed as dist
+
+@torch.library.custom_op("autosp::all_to_all", mutates_args=())
+def all_to_all(
+    input: torch.Tensor,
+    scatter_idx: int,
+    gather_idx: int,
+    world_size: int,
+    name: str,
+) -> torch.Tensor:
+    B, dim1, dim2, H = input.shape
+
+    if scatter_idx == 1:
+        N, local_S = dim1, dim2
+        input_t = input.reshape(B, world_size, N // world_size, local_S, H)
+        input_t = input_t.permute(1, 0, 2, 3, 4).contiguous()
+
+        output = torch.empty_like(input_t)
+        dist.all_to_all_single(output, input_t, group=dist.group.WORLD)
+
+        output = output.permute(1, 2, 0, 3, 4).contiguous()
+        output = output.reshape(B, N // world_size, world_size * local_S, H)
+    else:
+        local_N, S = dim1, dim2
+        input_t = input.reshape(B, local_N, world_size, S // world_size, H)
+        input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()
+
+        output = torch.empty_like(input_t)
+        dist.all_to_all_single(output, input_t, group=dist.group.WORLD)
+
+        output = output.permute(1, 0, 2, 3, 4).contiguous()
+        output = output.reshape(B, world_size * local_N, S // world_size, H)
+
+    return output
+
+
+@torch.library.register_fake("autosp::all_to_all")
+def all_to_all_fake(input: torch.Tensor, scatter_idx: int, gather_idx: int, world_size: int, name: str):
+    B, dim1, dim2, H = input.shape
+    if scatter_idx == 1:
+        return input.new_empty(B, dim1 // world_size, dim2 * world_size, H)
+    else:
+        return input.new_empty(B, dim1 * world_size, dim2 // world_size, H)
+
+
+def _all_to_all_backward_setup(ctx, inputs, output):
+    _, scatter_idx, gather_idx, world_size, name = inputs
+    ctx.scatter_idx = gather_idx
+    ctx.gather_idx = scatter_idx
+    ctx.world_size = world_size
+    ctx.name = name + "_grad"
+
+
+def _all_to_all_backward(ctx, grad):
+    return (
+        all_to_all(grad, ctx.scatter_idx, ctx.gather_idx, ctx.world_size, ctx.name),
+        None, None, None, None,
+    )
+
+
+torch.library.register_autograd(
+    "autosp::all_to_all", _all_to_all_backward, setup_context=_all_to_all_backward_setup
+)
diff --git a/deepspeed/compile/fx.py b/deepspeed/compile/fx.py
index 7b3408b56afe..d745bbda4624 100644
--- a/deepspeed/compile/fx.py
+++ b/deepspeed/compile/fx.py
@@ -3,11 +3,11 @@
 
 # DeepSpeed Team
 
-from typing import Callable, Any, List, Dict
+from typing import Callable, Any, List, Dict, Optional
 from collections import defaultdict
 
 import torch
-from torch.fx import Node, Graph
+from torch.fx import Node, Graph, GraphModule
 
 from .util import get_last_uses
@@ -138,3 +138,30 @@ def free_tensors(tensors: List[torch.Tensor]):
 
     # Python version for debugging
     # graph.create_node('call_function', free_tensors, args, {}, name=node_name)
+
+def find_node_by_name(gm: GraphModule, name: str) -> Optional[Node]:
+    for node in gm.graph.nodes:
+        if node.name == name:
+            return node
+    return None
+
+def get_node_shape_meta(node: Node) -> Optional[torch.Tensor]:
+    # Fall back via dict default, not `or`: `or` evaluates bool() on the
+    # stored tensor, which raises for multi-element (fake) tensors.
+    return node.meta.get("val", node.meta.get("example_value"))
+
+def find_node_by_tag(gm: GraphModule, tag: str) -> Optional[Node]:
+    input_id_node = None
+    for node in gm.graph.nodes:
+        # https://github.com/pytorch/pytorch/blob/085b71eab05cbc7d474a173884269c62d2778f77/torch/_dynamo/utils.py#L5048
+        tensor_dict = node.meta.get('tensor_dict')
+        if tensor_dict and tensor_dict.get('tag') == tag:
+            input_id_node = node
+            break
+    return input_id_node
+
+def replace_node_users(node: Node, replacement: Node, exclude: Optional[List[Node]] = None):
+    exclude = exclude or []
+    to_replace = [u for u in node.users if u not in exclude]
+    for user in to_replace:
+        user.replace_input_with(node, replacement)
diff --git a/deepspeed/compile/init_sp.py b/deepspeed/compile/init_sp.py
new file mode 100644
index 000000000000..7862420a2006
--- /dev/null
+++ b/deepspeed/compile/init_sp.py
@@ -0,0 +1,14 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+from torch.fx import GraphModule
+from .passes.sp_compile import apply_autosp
+
+def init_autosp():
+    def backend_fn(gm: GraphModule, real_inputs):
+        apply_autosp(gm, real_inputs, debug=False)
+        return torch._inductor.compile(gm, real_inputs)
+    return backend_fn
diff --git a/deepspeed/compile/passes/sp_compile.py b/deepspeed/compile/passes/sp_compile.py
new file mode 100644
index 000000000000..19c1588d273e
--- /dev/null
+++ b/deepspeed/compile/passes/sp_compile.py
@@ -0,0 +1,237 @@
+"""AutoSP: Automatic Sequence Parallel (Ulysses) pass for graph modules.
+
+Ulysses Transformation:
+    Input: [B, N, S/P, H] (all heads, partitioned sequence)
+    After A2A on QKV: [B, N/P, S, H] (partitioned heads, full sequence)
+    After SDPA: [B, N/P, S, H]
+    After A2A on O: [B, N, S/P, H] (all heads, partitioned sequence)
+
+Where:
+    B = batch size, N = num heads, S = full sequence length, H = head dim, P = world size
+"""
+
+import operator
+from typing import Optional, List, Callable
+
+import torch
+import torch.distributed as dist
+from torch._subclasses.fake_tensor import FakeTensorMode
+from torch.fx import GraphModule, Node
+from torch.fx.passes.fake_tensor_prop import FakeTensorProp
+from torch.fx.experimental.symbolic_shapes import ShapeEnv
+
+from deepspeed.runtime import constants
+
+from ..custom_ops import all_to_all
+from ..fx import find_node_by_name, get_node_shape_meta
+from ..util import get_input_id_node, get_label_id_node, get_position_id_node, shard_tensor_node, get_sdpa_nodes, ShardingConfig
+
+def prepare_autosp_inputs(input_id: torch.Tensor, label_id: torch.Tensor, position_id: torch.Tensor = None, attention_mask: torch.Tensor = None, seq_dim: int = 1):
+    """
+    Prepare inputs for AutoSP by marking dynamic dimensions and tagging tensors.
+
+    Args:
+        input_id: Token IDs tensor (required)
+        label_id: Label IDs tensor (required)
+        position_id: Position IDs tensor (optional)
+        attention_mask: Attention mask tensor (optional)
+        seq_dim: Sequence dimension index to mark as dynamic (default: 1)
+    """
+
+    if input_id is None:
+        raise ValueError("input_id is required")
+    if label_id is None:
+        raise ValueError("label_id is required")
+
+    if seq_dim < 0 or seq_dim >= input_id.ndim:
+        raise ValueError(f"seq_dim {seq_dim} must be a valid index for input_id with shape {input_id.shape}")
+
+    if position_id is not None:
+        if seq_dim >= position_id.ndim:
+            raise ValueError(f"seq_dim {seq_dim} is out of bounds for position_id with shape {position_id.shape}")
+
+    if attention_mask is not None:
+        if seq_dim >= attention_mask.ndim:
+            raise ValueError(f"seq_dim {seq_dim} is out of bounds for attention_mask with shape {attention_mask.shape}")
+
+    torch._dynamo.decorators.mark_dynamic(input_id, seq_dim)
+    torch._dynamo.decorators.mark_dynamic(label_id, seq_dim)
+    if position_id is not None:
+        torch._dynamo.decorators.mark_dynamic(position_id, seq_dim)
+    if attention_mask is not None:
+        torch._dynamo.decorators.mark_dynamic(attention_mask, seq_dim)
+
+    input_id.tag = constants.INPUT_ID_KEY
+    label_id.tag = constants.LABEL_ID_KEY
+    if position_id is not None:
+        position_id.tag = constants.POSITION_ID_KEY
+
+    return input_id, label_id, position_id, attention_mask
+
+def pass_shard_seq_dim(gm: GraphModule, example_inputs):
+    """
+    Finds all direct and indirect consumers of the input sequence, label and position ids.
+    Shard the sequence dimension used by all such consumers.
+    """
+    world_size = dist.get_world_size()
+
+    input_ids_node = get_input_id_node(gm)
+    val = get_node_shape_meta(input_ids_node)
+    seq_symint = val.shape[1]
+    assert isinstance(seq_symint, torch.SymInt), f"expected sequence dimension to be of type `torch.SymInt` but found `{type(seq_symint)}`"
+
+    sym_seq_dim_node = find_node_by_name(gm, str(seq_symint))
+    if sym_seq_dim_node is None:
+        print(f"WARNING: Could not find the symbolic node for the sequence dimension")
+        return
+
+    with gm.graph.inserting_after(sym_seq_dim_node):
+        sharded_node = gm.graph.call_function(
+            operator.floordiv,
+            args=(sym_seq_dim_node, world_size)
+        )
+
+    sharded_input_nodes = set()
+    label_ids_node = get_label_id_node(gm)
+    position_ids_node = get_position_id_node(gm)
+
+    if input_ids_node is not None:
+        sharded_input_nodes.add(input_ids_node)
+    if label_ids_node is not None:
+        sharded_input_nodes.add(label_ids_node)
+    if position_ids_node is not None:
+        sharded_input_nodes.add(position_ids_node)
+
+    # find all consumers of the sharded inputs
+    consumer_nodes = set()
+    worklist = list(sharded_input_nodes)
+    visited = set()
+
+    while worklist:
+        node = worklist.pop(0)
+        if node in visited:
+            continue
+        visited.add(node)
+        consumer_nodes.add(node)
+
+        for user in node.users:
+            if user not in visited:
+                worklist.append(user)
+
+    to_replace = []
+    for node in consumer_nodes:
+        if sym_seq_dim_node in node.all_input_nodes:
+            to_replace.append(node)
+
+    for user in to_replace:
+        user.replace_input_with(sym_seq_dim_node, sharded_node)
+
+
+def pass_shard_input_ids(gm: GraphModule, example_inputs):
+    config = ShardingConfig.from_distributed()
+    input_ids_node = get_input_id_node(gm)
+    shard_tensor_node(gm, input_ids_node, config)
+
+
+def pass_shard_label_ids(gm: GraphModule, example_inputs):
+    config = ShardingConfig.from_distributed()
+    label_ids_node = get_label_id_node(gm)
+    shard_tensor_node(gm, label_ids_node, config)
+
+def pass_shard_position_ids(gm: GraphModule, example_inputs):
+    config = ShardingConfig.from_distributed()
+    position_ids_node = get_position_id_node(gm)
+    if position_ids_node is None:
+        print("[WARNING] position id node not found. Skipping sharding of position ids.")
+        return
+    shard_tensor_node(gm, position_ids_node, config)
+
+
+def pass_insert_attention_all_to_all(gm: GraphModule, real_inputs):
+    """
+    Insert all-to-all collectives around SDPA for Ulysses parallelism.
+
+    For each SDPA:
+    - Before Q, K, V: scatter heads (dim=1), gather sequence (dim=2)
+    - After O: scatter sequence (dim=2), gather heads (dim=1)
+    """
+    world_size = dist.get_world_size()
+    attention_nodes = get_sdpa_nodes(gm)
+
+    def insert_a2a(node: Node, scatter_idx: int, gather_idx: int, name: str) -> Node:
+        with gm.graph.inserting_after(node):
+            a2a_node = gm.graph.call_function(
+                torch.ops.autosp.all_to_all.default,
+                args=(node, scatter_idx, gather_idx, world_size, name),
+            )
+        a2a_node.name = f"a2a_{name}"
+        node.replace_all_uses_with(a2a_node)
+        a2a_node.update_arg(0, node)
+        return a2a_node
+
+    for idx, attn_node in enumerate(attention_nodes):
+        q, k, v = attn_node.args[:3]
+        suffix = f"_{idx}" if len(attention_nodes) > 1 else ""
+
+        # QKV: [B, N, S/P, H] -> [B, N/P, S, H]
+        insert_a2a(q, scatter_idx=1, gather_idx=2, name=f"q{suffix}")
+        insert_a2a(k, scatter_idx=1, gather_idx=2, name=f"k{suffix}")
+        insert_a2a(v, scatter_idx=1, gather_idx=2, name=f"v{suffix}")
+
+        # O: [B, N/P, S, H] -> [B, N, S/P, H]
+        insert_a2a(attn_node, scatter_idx=2, gather_idx=1, name=f"o{suffix}")
+
+
+def pass_canonicalize(gm: GraphModule, real_inputs):
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+
+def pass_propagate_shapes(gm: torch.fx.GraphModule, real_inputs):
+    shape_env = ShapeEnv()
+    fake_mode = FakeTensorMode(shape_env=shape_env)
+    fake_inputs = []
+    for t in real_inputs:
+        if isinstance(t, torch.Tensor):
+            fake_inputs.append(fake_mode.from_tensor(t))
+        else:
+            fake_inputs.append(t)
+    # Run propagation under the same fake mode that fakeified the inputs;
+    # the inputs are already fake, so skip re-conversion.
+    FakeTensorProp(gm, mode=fake_mode).propagate_dont_convert_inputs(*fake_inputs)
+
+
+def apply_autosp(
+    gm: GraphModule,
+    real_inputs,
+    debug: bool = False,
+    passes: Optional[List[Callable]] = None,
+):
+    AUTOSP_PASSES = [
+        pass_shard_seq_dim,
+        pass_shard_input_ids,
+        pass_shard_label_ids,
+        pass_shard_position_ids,
+        pass_insert_attention_all_to_all,
+        pass_propagate_shapes,
+        pass_canonicalize,
+    ]
+
+    passes = passes or AUTOSP_PASSES
+    rank = dist.get_rank()
+
+    for p in passes:
+        if debug and rank == 0:
+            print(f"\n{'='*60}")
+            print(f" BEFORE: {p.__name__}")
+            print(f"{'='*60}\n")
+            print(gm.print_readable(print_output=False))
+
+        p(gm, real_inputs)
+
+        if debug and rank == 0:
+            print(f"\n{'='*60}")
+            print(f" AFTER: {p.__name__}")
+            print(f"{'='*60}\n")
+            print(gm.print_readable(print_output=False))
+
diff --git a/deepspeed/compile/util.py b/deepspeed/compile/util.py
index e8abcc2c8b3c..9e3150e1a221 100644
--- a/deepspeed/compile/util.py
+++ b/deepspeed/compile/util.py
@@ -6,11 +6,13 @@
 import functools
 import operator
 from typing import List, Tuple, Dict, Optional
+from dataclasses import dataclass
 from collections import defaultdict
 
 import torch
-from torch.fx import Node, Graph
+from torch.fx import Node, Graph, GraphModule
 from torch.fx.node import map_aggregate, Argument, map_arg
+import torch.nn.functional as F
 
 try:
     from torch._subclasses.fake_tensor import unset_fake_temporarily
@@ -22,7 +24,7 @@
 from deepspeed.accelerator import get_accelerator
 from deepspeed.utils.torch import required_torch_version
 from deepspeed.ops.op_builder.dc import DeepCompileBuilder
-
+from deepspeed.runtime import constants
 
 def is_deepcompile_supported() -> bool:
     return required_torch_version(min_version=2.6, max_version=2.9) and get_accelerator().device_name() == "cuda"
@@ -521,3 +523,96 @@ def pad_tensors(specs: List[Tuple[torch.Tensor, int, int]]) -> List[torch.Tensor
         padded.append(out)
 
     return padded
+
+@dataclass
+class ShardingConfig:
+    world_size: int
+    rank: int
+
+    @classmethod
+    def from_distributed(cls) -> "ShardingConfig":
+        return cls(
+            world_size=dist.get_world_size(),
+            rank=dist.get_rank(),
+        )
+
+def get_sdpa_nodes(gm: GraphModule) -> List[Node]:
+    return list(gm.graph.find_nodes(
+        op="call_function",
+        target=F.scaled_dot_product_attention,
+    ))
+
+def get_input_id_node(gm: GraphModule) -> Node:
+    from .fx import find_node_by_tag
+    node = find_node_by_tag(gm, constants.INPUT_ID_KEY)
+    if node is None:
+        raise RuntimeError("Failed to find a node for the input sequence.")
+    return node
+
+def get_label_id_node(gm: GraphModule) -> Node:
+    from .fx import find_node_by_tag
+    node = find_node_by_tag(gm, constants.LABEL_ID_KEY)
+    if node is None:
+        raise RuntimeError("Failed to find a node for the label.")
+    return node
+
+def get_position_id_node(gm: GraphModule) -> Node:
+    from .fx import find_node_by_tag
+    node = find_node_by_tag(gm, constants.POSITION_ID_KEY)
+    return node
+
+def create_shard_offsets(
+    gm: GraphModule,
+    sym_seq_dim_node: Node,
+    world_size: int,
+    rank: int
+) -> Tuple[Node, Node]:
+    with gm.graph.inserting_after(sym_seq_dim_node):
+        chunk_size_node = gm.graph.call_function(operator.floordiv, args=(sym_seq_dim_node, world_size))
+    with gm.graph.inserting_after(chunk_size_node):
+        start_node = gm.graph.call_function(operator.mul, args=(rank, chunk_size_node))
+    with gm.graph.inserting_after(start_node):
+        end_node = gm.graph.call_function(operator.add, args=(start_node, chunk_size_node))
+
+    return start_node, end_node
+
+def create_symbolic_slice_indices(
+    gm: GraphModule,
+    sym_seq_dim_node: Node,
+    config: ShardingConfig
+) -> Tuple[Node, Node]:
+    start_node, end_node = create_shard_offsets(gm, sym_seq_dim_node, config.world_size, config.rank)
+
+    with gm.graph.inserting_after(end_node):
+        slice_all = gm.graph.call_function(slice, args=(None, None, None))
+    with gm.graph.inserting_after(slice_all):
+        slice_range = gm.graph.call_function(slice, args=(start_node, end_node, None))
+
+    return slice_all, slice_range
+
+def shard_tensor_node(
+    gm: GraphModule,
+    tensor_node: Node,
+    config: ShardingConfig
+):
+    from .fx import find_node_by_name, get_node_shape_meta, replace_node_users
+    val = get_node_shape_meta(tensor_node)
+    assert val is not None, f"Node {tensor_node.name} has no shape metadata"
+
+    seq_len = val.shape[1]
+
+    assert isinstance(seq_len, torch.SymInt), f"Expected sequence dimension to be `torch.SymInt` but instead found `{type(seq_len)}`"
+
+    symb_seq_int_node = find_node_by_name(gm, str(seq_len))
+    assert symb_seq_int_node, f"Unable to find symbolic placeholder for {seq_len}"
+
+    slice_all, slice_range = create_symbolic_slice_indices(gm, symb_seq_int_node, config)
+    indices = (slice_all, slice_range)
+
+    with gm.graph.inserting_after(tensor_node):
+        sliced_node = gm.graph.call_function(
+            operator.getitem,
+            args=(tensor_node, indices),
+        )
+
+    replace_node_users(tensor_node, sliced_node, exclude=[sliced_node])
diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py
index 9e73bad73376..a916befc76f4 100755
--- a/deepspeed/runtime/constants.py
+++ b/deepspeed/runtime/constants.py
@@ -501,3 +501,10 @@ class ValidationMode:
 #########################################
 USE_DATA_BEFORE_EXPERT_PARALLEL = "use_data_before_expert_parallelism"
 USE_DATA_BEFORE_EXPERT_PARALLEL_DEFAULT = False
+
+#########################################
+# AUTOSP
+#########################################
+INPUT_ID_KEY = "input_id"
+LABEL_ID_KEY = "label_id"
+POSITION_ID_KEY = "position_id"
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index e6d838df5adf..eb71a7c51765 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -127,6 +127,7 @@
 from deepspeed.compile.passes import zero3_compile, prefetch, selective_gather, offload_adam_states
 from deepspeed.compile.init_z1 import init_z1
 from deepspeed.compile.init_z3 import init_z3
+from deepspeed.compile.init_sp import init_autosp
 
 MEMORY_OPT_ALLREDUCE_SIZE = 500000000
 
@@ -4361,7 +4362,8 @@ def compile(self,
         enable_deepcompile = self.is_deepcompile_enabled()
         if enable_deepcompile and self.zero_optimization_stage() != ZeroStageEnum.optimizer_states \
             and self.zero_optimization_stage() != ZeroStageEnum.weights \
-            and self.zero_optimization_stage() != ZeroStageEnum.gradients:
+            and self.zero_optimization_stage() != ZeroStageEnum.gradients \
+            and self.zero_optimization_stage() != ZeroStageEnum.disabled:
             logger.info(
                 f"Currently DeepCompile supports ZeRO stage 1, 2, or 3 only, but ZeRO stage is set to {self.zero_optimization_stage()}. Falling back to the torch compiler."
             )
@@ -4396,6 +4398,8 @@ def passes_name_to_fn(passes):
                 "DeepCompile with ZeRO stage 3 is not currently supported on PyTorch >= 2.9. "
                 "Please use ZeRO stage 1 or 2 with DeepCompile, or disable DeepCompile for ZeRO stage 3.")
             backend = init_z3(self, backend, compile_config, compile_kwargs, schedule)
+        elif self.zero_optimization_stage() == ZeroStageEnum.disabled:
+            backend = init_autosp()
 
         # Hook state must align with whether DeepCompile is active.
         self._set_deepcompile_active(enable_deepcompile)