From 0c9b893261a2e8e7fb37f5ea5e236d7665a9cb5b Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Mon, 12 Aug 2024 15:35:58 +0800 Subject: [PATCH 01/10] 2 api --- .../infer_symbolic_shape/binary_infer_sym.cc | 62 +++++++++++++++++-- .../infer_symbolic_shape/binary_infer_sym.h | 2 +- .../multiary_infer_sym.cc | 60 ++++++++++++++++-- .../infer_symbolic_shape/multiary_infer_sym.h | 2 +- paddle/phi/ops/yaml/ops.yaml | 2 + 5 files changed, 115 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index 43c35dc905ada..c03876e0959fd 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -180,11 +180,63 @@ bool Binomial_OpInferSymbolicShape( // return true; // } -// bool BmmOpInferSymbolicShape(pir::Operation *op, -// pir::InferSymbolicShapeContext *infer_context) { -// // pass -// return true; -// } +bool BmmOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { + const auto &x_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const auto &y_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(1)); + + const std::vector &x_dims = x_shape_or_data.shape(); + const std::vector &y_dims = y_shape_or_data.shape(); + std::size_t x_ndims = x_dims.size(); + std::size_t y_ndims = y_dims.size(); + + PADDLE_ENFORCE_EQ( + x_ndims, + 3, + common::errors::InvalidArgument("Input(X) of BmmOp must be 3-dimensional " + "in BmmOp, but received X's shape: [%d].", + x_ndims)); + PADDLE_ENFORCE_EQ( + y_ndims, + 3, + common::errors::InvalidArgument("Input(Y) of BmmOp must be 3-dimensional " + "in BmmOp, but received Y's shape: [%d].", + y_ndims)); + + auto cal_shape_fn = [](const symbol::DimExpr &x, + const symbol::DimExpr &y, + const std::string &error_str) -> symbol::DimExpr { + if (x.is_dynamic()) { + return y; + } else if (y.is_dynamic()) { + return x; + } + PADDLE_ENFORCE_EQ(x, y, common::errors::InvalidArgument(error_str, x, y)); + return x; + }; + + cal_shape_fn(x_dims[2], + y_dims[1], + "Input(X)'s width must be equal with Input(Y)'s height in " + "BmmOp, but receive X's width: [%d], Y's height: [%d]."); + symbol::DimExpr batch_size = cal_shape_fn( + x_dims[0], + y_dims[0], + "Input(X) and Input(Y) must have the same batch size in BmmOp, but " + "received X's batch size: [%d], Y's batch size [%d]"); + symbol::DimExpr out_height = x_dims[1]; + symbol::DimExpr out_width = y_dims[2]; + + std::vector out_dims = {batch_size, out_height, out_width}; + + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)}); + + return true; +} // bool CholeskySolveOpInferSymbolicShape(pir::Operation *op, // pir::InferSymbolicShapeContext diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h index e7ee88b249029..f4ca8dbcc637b 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h @@ -25,7 +25,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(BceLoss_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Binomial) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Binomial_) // 
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bincount) -// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bmm) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bmm) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(CholeskySolve) OP_DECLARE_INFER_SYMBOLIC_SHAPE(CtcAlign) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv2d) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index c7cff62df9e2f..41df322077d42 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -622,12 +622,60 @@ bool BilinearOpInferSymbolicShape( // return true; // } -// bool BroadcastTensorsOpInferSymbolicShape(pir::Operation *op, -// pir::InferSymbolicShapeContext -// *infer_context) { -// // pass -// return true; -// } +bool BroadcastTensorsOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + std::vector> input_shapes; + for (size_t i = 0; i < op->num_operands(); ++i) { + const auto &input_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(i)); + input_shapes.push_back(input_shape_or_data.shape()); + } + + int target_rank = 0; + + // 1. Find Output rank = max(Inputs rank) + for (const auto &input_shape : input_shapes) { + target_rank = std::max(target_rank, static_cast(input_shape.size())); + } + + std::vector target_dims(target_rank, symbol::DimExpr(1)); + + // 2. Output dim(axis=x) = max(Inputs dim(axis=x)) + for (int index = 0; index < target_rank; ++index) { + symbol::DimExpr target_dim_size(1); + for (const auto &input_shape : input_shapes) { + int axis = static_cast(input_shape.size()) - index - 1; + symbol::DimExpr dim_size(1); + if (axis >= 0) { + dim_size = input_shape[axis]; + } + + if (!target_dim_size.is_dynamic() && !dim_size.is_dynamic() && + target_dim_size != dim_size && dim_size != 1 && + target_dim_size != 1) { + PADDLE_THROW(errors::InvalidArgument( + "BroadcastTensorsOp inputs do not satisfy broadcast semantics, " + "please check axis = %d in reverse order", + index)); + } + + if (dim_size != 1) { + target_dim_size = dim_size; + } + } + target_dims[target_rank - index - 1] = target_dim_size; + } + + // 3. 
Set Output Dim + for (size_t i = 0; i < op->num_results(); ++i) { + infer_context->SetShapeOrDataForValue( + op->result(i), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(target_dims)}); + } + + return true; +} bool BilinearInterpOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index dccefa7e149d4..8df9b1f01cfeb 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -24,7 +24,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Addmm_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(AddN) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Auc) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(AssignPos) -// OP_DECLARE_INFER_SYMBOLIC_SHAPE(BroadcastTensors) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(BroadcastTensors) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchFc) OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchNorm) OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchNorm_) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index dd30e85fc84b0..8e07b6a067ef9 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -644,6 +644,7 @@ kernel : func : bmm backward : bmm_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : box_clip args: (Tensor input, Tensor im_info) @@ -671,6 +672,7 @@ func: broadcast_tensors data_type : input backward: broadcast_tensors_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : c_allgather args : (Tensor x, int ring_id, int nranks, bool use_calc_stream) From 0ec200290ec7f4da2c52e7e62281993e9adb3be3 Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Tue, 13 Aug 2024 10:32:14 +0800 Subject: [PATCH 02/10] update broadcast tensor --- .../interface/infer_symbolic_shape/multiary_infer_sym.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 41df322077d42..1b9d21e1b715a 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -638,7 +638,7 @@ bool BroadcastTensorsOpInferSymbolicShape( target_rank = std::max(target_rank, static_cast(input_shape.size())); } - std::vector target_dims(target_rank, symbol::DimExpr(1)); + std::vector target_dims(target_rank, symbol::DimExpr(0)); // 2. 
Output dim(axis=x) = max(Inputs dim(axis=x)) for (int index = 0; index < target_rank; ++index) { @@ -650,8 +650,7 @@ bool BroadcastTensorsOpInferSymbolicShape( dim_size = input_shape[axis]; } - if (!target_dim_size.is_dynamic() && !dim_size.is_dynamic() && - target_dim_size != dim_size && dim_size != 1 && + if (target_dim_size != dim_size && dim_size != 1 && target_dim_size != 1) { PADDLE_THROW(errors::InvalidArgument( "BroadcastTensorsOp inputs do not satisfy broadcast semantics, " From 8511068b679c00829a30b6a62cf621117c54bb8a Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Wed, 14 Aug 2024 17:55:53 +0800 Subject: [PATCH 03/10] typehint in 2 files --- .../distributed/fleet/base/role_maker.py | 90 ++++++------ .../paddle/distributed/fleet/base/topology.py | 132 ++++++++++-------- 2 files changed, 126 insertions(+), 96 deletions(-) diff --git a/python/paddle/distributed/fleet/base/role_maker.py b/python/paddle/distributed/fleet/base/role_maker.py index c7ead667d9f97..dab5a5ccd0d1b 100755 --- a/python/paddle/distributed/fleet/base/role_maker.py +++ b/python/paddle/distributed/fleet/base/role_maker.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. """Definition of Role Makers.""" +from __future__ import annotations + import os import re import time import warnings from multiprocessing import Manager, Process +from typing import TYPE_CHECKING, Any import numpy as np @@ -28,6 +31,9 @@ from ...backup_env import getenv_or_backup +if TYPE_CHECKING: + from paddle.base.core import task + __all__ = [] @@ -563,7 +569,7 @@ class PaddleCloudRoleMaker(RoleMakerBase): """ - def __init__(self, is_collective=False, **kwargs): + def __init__(self, is_collective: bool = False, **kwargs: Any) -> None: super().__init__() self._is_collective = is_collective self._non_distributed = False @@ -589,16 +595,18 @@ def __init__(self, is_collective=False, **kwargs): self._gloo = Gloo() # gloo instance - def _barrier(self, comm_world): + def _barrier(self, comm_world: str) -> None: self._gloo.barrier(comm_world) - def _all_gather(self, input, comm_world="worker"): + def _all_gather(self, input: Any, comm_world="worker") -> task: return self._gloo.all_gather(input, comm_world) - def _all_reduce(self, input, mode="sum", comm_world="worker"): + def _all_reduce( + self, input: Any, mode: str = "sum", comm_world: str = "worker" + ) -> task: return self._gloo.all_reduce(input, mode, comm_world) - def _heter_device(self): + def _heter_device(self) -> str: """ return the heter device that current heter worker is using """ @@ -606,7 +614,7 @@ def _heter_device(self): self._generate_role() return self._heter_trainer_device - def _heter_device_type(self): + def _heter_device_type(self) -> str: """ return the heter device type that current heter worker is using """ @@ -614,7 +622,7 @@ def _heter_device_type(self): self._generate_role() return self._heter_trainer_device_type - def _get_stage_id(self): + def _get_stage_id(self) -> int | str: """ return stage id of current heter worker """ @@ -622,7 +630,7 @@ def _get_stage_id(self): self._generate_role() return self._stage_id - def _get_stage_trainers(self): + def _get_stage_trainers(self) -> list: """ return trainer num of all stages """ @@ -630,7 +638,7 @@ def _get_stage_trainers(self): self._generate_role() return self._stage_trainers - def _get_num_stage(self): + def _get_num_stage(self) -> int: """ return stage num """ @@ -638,7 +646,7 @@ def _get_num_stage(self): self._generate_role() return 
self._stage_num - def _is_worker(self): + def _is_worker(self) -> bool: """ whether current process is worker """ @@ -646,7 +654,7 @@ def _is_worker(self): self._generate_role() return self._role == Role.WORKER - def _is_server(self): + def _is_server(self) -> bool: """ whether current process is server """ @@ -654,12 +662,12 @@ def _is_server(self): self._generate_role() return self._role == Role.SERVER - def _is_coordinator(self): + def _is_coordinator(self) -> bool: if not self._role_is_generated: self._generate_role() return self._role == Role.COORDINATOR - def _is_first_worker(self): + def _is_first_worker(self) -> bool: """ whether current process is worker of rank 0 """ @@ -667,7 +675,7 @@ def _is_first_worker(self): self._generate_role() return self._role == Role.WORKER and self._current_id == 0 - def _worker_index(self): + def _worker_index(self) -> int | str: """ get index of current worker """ @@ -675,7 +683,7 @@ def _worker_index(self): self._generate_role() return self._current_id - def _server_index(self): + def _server_index(self) -> int | str: """ get index of current server """ @@ -683,7 +691,7 @@ def _server_index(self): self._generate_role() return self._current_id - def _role_id(self): + def _role_id(self) -> int | str: """ get index of current node """ @@ -691,7 +699,7 @@ def _role_id(self): self._generate_role() return self._current_id - def _worker_num(self): + def _worker_num(self) -> int: """ return the current number of worker """ @@ -699,7 +707,7 @@ def _worker_num(self): self._generate_role() return self._trainers_num - def _server_num(self): + def _server_num(self) -> int: """ return the current number of server """ @@ -711,7 +719,7 @@ def _server_num(self): else 0 ) - def _node_num(self): + def _node_num(self) -> int: """ return the training node number """ @@ -719,7 +727,7 @@ def _node_num(self): self._generate_role() return self._nodes_num - def _get_node_num(self): + def _get_node_num(self) -> int: """ return the training node number """ @@ -727,22 +735,22 @@ def _get_node_num(self): self._generate_role() return self._nodes_num - def _get_local_rank(self): + def _get_local_rank(self) -> str | None: if not self._role_is_generated: self._generate_role() return self._local_rank - def _get_local_device_ids(self): + def _get_local_device_ids(self) -> str | None: if not self._role_is_generated: self._generate_role() return self._local_device_ids - def _get_world_device_ids(self): + def _get_world_device_ids(self) -> str | None: if not self._role_is_generated: self._generate_role() return self._world_device_ids - def _get_trainer_endpoints(self): + def _get_trainer_endpoints(self) -> str | list[str]: """ get endpoint of all trainers """ @@ -750,7 +758,7 @@ def _get_trainer_endpoints(self): self._generate_role() return self._worker_endpoints - def _get_trainer_endpoint(self): + def _get_trainer_endpoint(self) -> str | None | list[str]: if not self._role_is_generated: self._generate_role() assert ( @@ -758,7 +766,7 @@ def _get_trainer_endpoint(self): ), "get_trainer_endpoint should be called by trainer" return self._cur_endpoint - def _get_heter_worker_endpoints(self): + def _get_heter_worker_endpoints(self) -> list[str] | None: """ Returns: string: all heter_trainers'endpoints @@ -770,10 +778,10 @@ def _get_heter_worker_endpoints(self): ), "Heter Worker Endpoints Not initialized" return self._heter_trainer_endpoints - def _get_heter_worker_endpoint(self): + def _get_heter_worker_endpoint(self) -> str | None: """ Returns: - int: corresponding heter_trainer's 
endpoint + str: corresponding heter_trainer's endpoint """ if not self._role_is_generated: self._generate_role() @@ -782,7 +790,7 @@ def _get_heter_worker_endpoint(self): ), "_get_heter_worker_endpoint should be invoked by heter worker" return self._cur_endpoint - def _get_pserver_endpoints(self): + def _get_pserver_endpoints(self) -> str | list[str] | None: """ get endpoint of all pservers """ @@ -790,7 +798,7 @@ def _get_pserver_endpoints(self): self._generate_role() return self._server_endpoints - def _get_coordinator_endpoints(self): + def _get_coordinator_endpoints(self) -> str | list[str] | None: if not self._role_is_generated: self._generate_role() return self._coordinator_endpoints @@ -807,7 +815,7 @@ def _get_previous_trainers(self): ), "_get_previous_trainers should be invoked by trainer or heter worker" return self._previous_heter_trainer_endpoints - def _get_next_trainers(self): + def _get_next_trainers(self) -> list[str]: """ invoked by heter worker """ @@ -819,7 +827,7 @@ def _get_next_trainers(self): ), "_get_next_trainers should be invoked by trainer or heter worker" return self._next_heter_trainer_endpoints - def _is_non_distributed(self): + def _is_non_distributed(self) -> bool: """ Return True if indispensable environment for fleetrun is not found (use python-run to launch fleet-code directly) @@ -828,7 +836,7 @@ def _is_non_distributed(self): self._generate_role() return self._non_distributed - def _heter_worker_num(self): + def _heter_worker_num(self) -> int: """ get heter worker nums """ @@ -836,7 +844,7 @@ def _heter_worker_num(self): self._generate_role() return self._heter_trainers_num - def _is_heter_worker(self): + def _is_heter_worker(self) -> bool: """ whether current process is heter worker """ @@ -844,7 +852,7 @@ def _is_heter_worker(self): self._generate_role() return self._role == Role.HETER_WORKER - def _ps_env(self): # each role will execute it + def _ps_env(self) -> None: # each role will execute it # Environment variable PADDLE_PSERVERS_IP_PORT_LIST must be set # format: string(ip:port,ip:port), eg. 
127.0.0.1:6001,127.0.0.1:6002 self._server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST", None) @@ -1084,7 +1092,7 @@ def _ps_env(self): # each role will execute it self._current_id = current_id self._nodes_num = len({x.split(':')[0] for x in self._worker_endpoints}) - def _collective_env(self): + def _collective_env(self) -> None: self._current_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) self._training_role = os.getenv("PADDLE_TRAINING_ROLE", "TRAINER") assert self._training_role == "TRAINER" @@ -1107,7 +1115,7 @@ def _collective_env(self): self._local_device_ids = os.getenv("PADDLE_LOCAL_DEVICE_IDS") self._world_device_ids = os.getenv("PADDLE_WORLD_DEVICE_IDS") - def _gloo_init(self): + def _gloo_init(self) -> None: # PADDLE_WITH_GLOO 1: trainer barrier, 2: all barrier use_gloo = int(os.getenv("PADDLE_WITH_GLOO", "0")) if use_gloo not in [1, 2]: @@ -1186,7 +1194,7 @@ def _gloo_init(self): if rendezvous_type == Gloo.RENDEZVOUS.HTTP: http_server_d['running'] = False - def _generate_role(self): + def _generate_role(self) -> None: """ generate role for role maker """ @@ -1223,7 +1231,7 @@ def __init__(self, is_collective=False, init_gloo=False, **kwargs): ) self._init_gloo = init_gloo - def _user_defined_ps_env(self): + def _user_defined_ps_env(self) -> None: self._server_endpoints = self._kwargs.get("server_endpoints") self._worker_endpoints = self._kwargs.get("worker_endpoints", []) self._trainers_num = self._kwargs.get("worker_num", 0) @@ -1244,14 +1252,14 @@ def _user_defined_ps_env(self): self._cur_endpoint = self._server_endpoints[self._current_id] self._nodes_num = len({x.split(':')[0] for x in self._worker_endpoints}) - def _user_defined_collective_env(self): + def _user_defined_collective_env(self) -> None: self._worker_endpoints = self._kwargs.get("worker_endpoints") self._current_id = self._kwargs.get("current_id") self._trainers_num = len(self._worker_endpoints) self._training_role = Role.WORKER self._nodes_num = len({x.split(':')[0] for x in self._worker_endpoints}) - def _generate_role(self): + def _generate_role(self) -> None: """ generate role for role maker """ diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py index 8105e2672c87f..b1d1e8f885a00 100644 --- a/python/paddle/distributed/fleet/base/topology.py +++ b/python/paddle/distributed/fleet/base/topology.py @@ -11,17 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations import collections import os from functools import reduce from itertools import product +from typing import TYPE_CHECKING, Any, Literal import paddle from paddle.distributed.utils.nccl_utils import check_nccl_version_for_p2p from ..utils.log_util import logger +if TYPE_CHECKING: + from paddle.base.core import task + from paddle.distributed.collective import Group + __all__ = ['CommunicateTopology', 'HybridCommunicateGroup'] _HYBRID_PARALLEL_GROUP = None @@ -65,9 +71,15 @@ class ParallelMode: class CommunicateTopology: def __init__( self, - hybrid_group_names=["data", "pipe", "sharding", "sep", "model"], - dims=[1, 1, 1, 1, 1], - ): + hybrid_group_names: list[str] = [ + "data", + "pipe", + "sharding", + "sep", + "model", + ], + dims: list[int] = [1, 1, 1, 1, 1], + ) -> None: self._parallel_names = hybrid_group_names self._dims = dims self.coordinate = collections.namedtuple( @@ -83,27 +95,27 @@ def __init__( zip(self._coord2rank.values(), self._coord2rank.keys()) ) - def get_hybrid_group_names(self): + def get_hybrid_group_names(self) -> list[str]: return self._parallel_names - def get_dim(self, axis_name): + def get_dim(self, axis_name: str) -> list[str]: return self._dims[self._parallel_names.index(axis_name)] - def world_size(self): + def world_size(self) -> int: return self._world_size - def get_rank(self, **args): + def get_rank(self, **args: Any) -> dict[list[str], int]: assert len(args) == len(self._dims) key = self.coordinate(**args) assert key in self._coord2rank.keys() return self._coord2rank[key] - def get_coord(self, rank): + def get_coord(self, rank: str) -> int | list[str] | None: assert rank < self._world_size assert rank in self._rank2coord.keys() return self._rank2coord[rank] - def get_axis_list(self, axis_name, index): + def get_axis_list(self, axis_name: str, index: str) -> list[int]: axis = self._parallel_names.index(axis_name) ranks = [ self._coord2rank[coord] @@ -113,11 +125,11 @@ def get_axis_list(self, axis_name, index): ranks.sort() return ranks - def get_dim_size(self, axis_name): + def get_dim_size(self, axis_name: str) -> int: assert axis_name in self._parallel_names return self._dims[self._parallel_names.index(axis_name)] - def get_fused_ranks(self, fused_axis): + def get_fused_ranks(self, fused_axis: str) -> list[int]: non_fused_axis = list(set(self._parallel_names).difference(fused_axis)) non_fused_ranges = [] for axis_name in non_fused_axis: @@ -144,7 +156,7 @@ def get_fused_ranks(self, fused_axis): return rank_list - def get_comm_list(self, axis_name): + def get_comm_list(self, axis_name: str) -> list[list[int | list[str]]]: assert axis_name in self._parallel_names other_axis_names = [ name for name in self._parallel_names if name != axis_name @@ -169,14 +181,16 @@ def get_comm_list(self, axis_name): return all_result - def get_rank_from_stage(self, global_rank, **kwargs): + def get_rank_from_stage( + self, global_rank: int | list[str] | None, **kwargs: Any + ) -> task: coord = self.get_coord(global_rank) tf = coord._replace(**kwargs)._asdict() return self.get_rank(**tf) class HybridCommunicateGroup: - def __init__(self, topology): + def __init__(self, topology: dict[int]) -> None: self.nranks = paddle.distributed.get_world_size() self.global_rank = paddle.distributed.get_rank() self._topo = topology @@ -281,7 +295,7 @@ def __init__(self, topology): global _HYBRID_PARALLEL_GROUP _HYBRID_PARALLEL_GROUP = self - def get_parallel_mode(self): + def get_parallel_mode(self) -> Literal[0, 1, 2, 3, 4]: # there are five modes : 
DataParallel / TensorParallel / PipelineParallel / ShardingParallel / SepParallel # NOTE when sharding conjugates with other parallel, sharding should act like a optimizer and # adding its parallel logic within that parallelism @@ -319,7 +333,7 @@ def get_parallel_mode(self): # pp may coexist with mp、sep、dp and sharding return ParallelMode.PIPELINE_PARALLEL - def _check_valid_topo(self): + def _check_valid_topo(self) -> bool: return ( self._dp_degree * self._mp_degree @@ -329,10 +343,12 @@ def _check_valid_topo(self): == self.nranks ) - def _check_sep_exist(self): + def _check_sep_exist(self) -> None: assert self._sep_degree > 1, "sep not exist" - def _set_comm_group(self, parallel_method="data"): + def _set_comm_group( + self, parallel_method: str = "data" + ) -> None | tuple[list[int], Group]: parallel_group = [] parallel_comm_group = None parallel_groups = self._topo.get_comm_list(parallel_method) @@ -359,7 +375,9 @@ def _set_comm_group(self, parallel_method="data"): ) return parallel_group, parallel_comm_group - def _set_check_group(self, parallel_method="data"): + def _set_check_group( + self, parallel_method: str = "data" + ) -> tuple[list[int], Group] | None: parallel_group = [] parallel_comm_group = None parallel_size = self._topo.get_dim(parallel_method) @@ -375,15 +393,15 @@ def _set_check_group(self, parallel_method="data"): return parallel_group, parallel_comm_group - def _get_p2p_next_rank(self): + def _get_p2p_next_rank(self) -> int | None: assert hasattr(self, 'next_rank'), "next_rank has not been inited" return self.next_rank - def _get_p2p_prev_rank(self): + def _get_p2p_prev_rank(self) -> int | None: assert hasattr(self, 'prev_rank'), "prev_rank has not been inited" return self.prev_rank - def _set_p2p_prev_next(self): + def _set_p2p_prev_next(self) -> None: comm_lists = self._topo.get_comm_list('pipe') for comm_ranks in comm_lists: @@ -397,7 +415,7 @@ def _set_p2p_prev_next(self): self.next_rank = next_rank self.prev_rank = prev_rank - def _set_four_directions_p2p_group(self): + def _set_four_directions_p2p_group(self) -> None: comm_lists = self._topo.get_comm_list('pipe') self.send_next_group = None @@ -434,75 +452,75 @@ def _set_four_directions_p2p_group(self): assert self.recv_next_group is not None assert self.recv_prev_group is not None - def topology(self): + def topology(self) -> dict[int]: return self._topo - def get_global_rank(self): + def get_global_rank(self) -> int: return self.global_rank # data parallel message: - def _get_data_parallel_id(self): + def _get_data_parallel_id(self) -> int: return self._topo.get_coord(self.global_rank).data - def get_data_parallel_rank(self): + def get_data_parallel_rank(self) -> int: return self._data_parallel_id - def get_data_parallel_world_size(self): + def get_data_parallel_world_size(self) -> int: return self._dp_degree - def get_data_parallel_group(self): + def get_data_parallel_group(self) -> list[int]: return self._dp_comm_group - def get_data_parallel_group_src_rank(self): + def get_data_parallel_group_src_rank(self) -> int: return self._dp_comm_group.ranks[0] # model parallel message: - def _get_model_parallel_id(self): + def _get_model_parallel_id(self) -> str: return self._topo.get_coord(self.global_rank).model - def get_model_parallel_rank(self): + def get_model_parallel_rank(self) -> int: return self._model_parallel_id - def get_model_parallel_world_size(self): + def get_model_parallel_world_size(self) -> int: return self._mp_degree - def get_model_parallel_group(self): + def get_model_parallel_group(self) -> 
Group: return self._mp_comm_group - def get_model_parallel_group_src_rank(self): + def get_model_parallel_group_src_rank(self) -> int: return self._mp_comm_group.ranks[0] # pipeline parallel message - def _get_pipe_parallel_id(self): + def _get_pipe_parallel_id(self) -> int: return self._topo.get_coord(self.global_rank).pipe - def get_stage_id(self): + def get_stage_id(self) -> int: return self.stage_id - def get_pipe_parallel_world_size(self): + def get_pipe_parallel_world_size(self) -> int: return self._pp_degree - def _get_sep_parallel_id(self): + def _get_sep_parallel_id(self) -> int: return self._topo.get_coord(self.global_rank).sep - def get_sep_parallel_rank(self): + def get_sep_parallel_rank(self) -> int: return self._sep_parallel_id - def get_sep_parallel_world_size(self): + def get_sep_parallel_world_size(self) -> int: return self._sep_degree - def get_sep_parallel_group(self): + def get_sep_parallel_group(self) -> Group: self._check_sep_exist() return self._sep_comm_group - def get_sep_parallel_group_src_rank(self): + def get_sep_parallel_group_src_rank(self) -> int: self._check_sep_exist() return self._sep_comm_group.ranks[0] - def get_pipe_parallel_group(self): + def get_pipe_parallel_group(self) -> Group: return self._pp_comm_group - def get_p2p_groups(self): + def get_p2p_groups(self) -> tuple[Group, Group, Group, Group]: assert ( _use_four_directions ), "If you want to use four directions p2p group, set the environment variable PADDLE_USE_FOUR_DIRECTIONS_P2P to True." @@ -514,44 +532,48 @@ def get_p2p_groups(self): ) # sharding parallel message: - def _get_sharding_parallel_id(self): + def _get_sharding_parallel_id(self) -> int: return self._topo.get_coord(self.global_rank).sharding - def get_sharding_parallel_rank(self): + def get_sharding_parallel_rank(self) -> int: return self._sharding_parallel_id - def get_sharding_parallel_world_size(self): + def get_sharding_parallel_world_size(self) -> int: return self._sharding_degree - def get_sharding_parallel_group(self): + def get_sharding_parallel_group(self) -> Group: return self._sharding_comm_group - def get_sharding_parallel_group_src_rank(self): + def get_sharding_parallel_group_src_rank(self) -> int: # TODO should the src rank related to the shard rank for each parameter ? return self._sharding_comm_group.ranks[0] # check parallel group - def get_check_parallel_group(self, sharding=False): + def get_check_parallel_group( + self, sharding: bool = False + ) -> list[int] | Group: if sharding: return self.sharding_check_comm_group else: return self._check_comm_group - def get_rank_from_stage(self, stage_id, **kwargs): + def get_rank_from_stage(self, stage_id: int, **kwargs: Any) -> int: return self._topo.get_rank_from_stage( self.global_rank, pipe=stage_id, **kwargs ) # fuse comm group message - def get_dp_sep_parallel_group(self): + def get_dp_sep_parallel_group(self) -> list[int]: self._check_sep_exist() return self._dp_sep_comm_group - def get_pp_mp_parallel_group(self): + def get_pp_mp_parallel_group(self) -> list[int]: self._check_sep_exist() return self._pp_mp_comm_group - def create_fuse_group(self, fused_strategy_list): + def create_fuse_group( + self, fused_strategy_list: list[int] + ) -> list[int] | int: assert ( len(fused_strategy_list) > 0 ), "the length of fused_strategy_list must be greater than 0." 
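
Patch 03 above annotates `CommunicateTopology`, whose coordinate-to-rank bookkeeping is built from `itertools.product` over the per-axis dims. The standalone sketch below (plain Python, no Paddle import) mirrors that mapping under the assumption that ranks enumerate the Cartesian product in row-major order, which is what the `product` usage in the diff suggests; the axis names, example dims, and `Coord` namedtuple are illustrative only, not the exact Paddle internals.

# Standalone sketch (no paddle import) of the coordinate <-> rank mapping that
# CommunicateTopology builds; the row-major product order and the example dims
# are assumptions for illustration.
import collections
from itertools import product

names = ["data", "pipe", "sharding", "sep", "model"]
dims = [2, 2, 1, 1, 1]
Coord = collections.namedtuple("Coord", names)
coord2rank = {
    Coord(*c): r
    for r, c in enumerate(product(*(range(d) for d in dims)))
}
rank2coord = {r: c for c, r in coord2rank.items()}

# With these dims there are 2 * 2 = 4 ranks; rank 3 sits at data=1, pipe=1.
assert coord2rank[Coord(data=1, pipe=1, sharding=0, sep=0, model=0)] == 3
assert rank2coord[0] == Coord(0, 0, 0, 0, 0)
print(coord2rank)

This is the lookup that the newly typed `get_rank(**args) -> int` and `get_coord(rank)` methods expose in either direction.
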
From 321614d2eb531e0542583c7ad17e896ec9d6167e Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Wed, 14 Aug 2024 18:02:33 +0800 Subject: [PATCH 04/10] undo --- .../infer_symbolic_shape/binary_infer_sym.cc | 62 ++----------------- .../multiary_infer_sym.cc | 59 ++---------------- .../infer_symbolic_shape/multiary_infer_sym.h | 2 +- paddle/phi/ops/yaml/ops.yaml | 2 - 4 files changed, 12 insertions(+), 113 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index 567211f1c82ab..fa178b03467e2 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -207,63 +207,11 @@ bool Binomial_OpInferSymbolicShape( // return true; // } -bool BmmOpInferSymbolicShape(pir::Operation *op, - pir::InferSymbolicShapeContext *infer_context) { - const auto &x_shape_or_data = - infer_context->GetShapeOrDataForValue(op->operand_source(0)); - const auto &y_shape_or_data = - infer_context->GetShapeOrDataForValue(op->operand_source(1)); - - const std::vector &x_dims = x_shape_or_data.shape(); - const std::vector &y_dims = y_shape_or_data.shape(); - std::size_t x_ndims = x_dims.size(); - std::size_t y_ndims = y_dims.size(); - - PADDLE_ENFORCE_EQ( - x_ndims, - 3, - common::errors::InvalidArgument("Input(X) of BmmOp must be 3-dimensional " - "in BmmOp, but received X's shape: [%d].", - x_ndims)); - PADDLE_ENFORCE_EQ( - y_ndims, - 3, - common::errors::InvalidArgument("Input(Y) of BmmOp must be 3-dimensional " - "in BmmOp, but received Y's shape: [%d].", - y_ndims)); - - auto cal_shape_fn = [](const symbol::DimExpr &x, - const symbol::DimExpr &y, - const std::string &error_str) -> symbol::DimExpr { - if (x.is_dynamic()) { - return y; - } else if (y.is_dynamic()) { - return x; - } - PADDLE_ENFORCE_EQ(x, y, common::errors::InvalidArgument(error_str, x, y)); - return x; - }; - - cal_shape_fn(x_dims[2], - y_dims[1], - "Input(X)'s width must be equal with Input(Y)'s height in " - "BmmOp, but receive X's width: [%d], Y's height: [%d]."); - symbol::DimExpr batch_size = cal_shape_fn( - x_dims[0], - y_dims[0], - "Input(X) and Input(Y) must have the same batch size in BmmOp, but " - "received X's batch size: [%d], Y's batch size [%d]"); - symbol::DimExpr out_height = x_dims[1]; - symbol::DimExpr out_width = y_dims[2]; - - std::vector out_dims = {batch_size, out_height, out_width}; - - infer_context->SetShapeOrDataForValue( - op->result(0), - symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)}); - - return true; -} +// bool BmmOpInferSymbolicShape(pir::Operation *op, +// pir::InferSymbolicShapeContext *infer_context) { +// // pass +// return true; +// } // bool CholeskySolveOpInferSymbolicShape(pir::Operation *op, // pir::InferSymbolicShapeContext diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index ccc1e3ac1c550..894678f2c4d1e 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -622,59 +622,12 @@ bool BilinearOpInferSymbolicShape( // return true; // } -bool BroadcastTensorsOpInferSymbolicShape( - pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) 
{ - std::vector> input_shapes; - for (size_t i = 0; i < op->num_operands(); ++i) { - const auto &input_shape_or_data = - infer_context->GetShapeOrDataForValue(op->operand_source(i)); - input_shapes.push_back(input_shape_or_data.shape()); - } - - int target_rank = 0; - - // 1. Find Output rank = max(Inputs rank) - for (const auto &input_shape : input_shapes) { - target_rank = std::max(target_rank, static_cast(input_shape.size())); - } - - std::vector target_dims(target_rank, symbol::DimExpr(0)); - - // 2. Output dim(axis=x) = max(Inputs dim(axis=x)) - for (int index = 0; index < target_rank; ++index) { - symbol::DimExpr target_dim_size(1); - for (const auto &input_shape : input_shapes) { - int axis = static_cast(input_shape.size()) - index - 1; - symbol::DimExpr dim_size(1); - if (axis >= 0) { - dim_size = input_shape[axis]; - } - - if (target_dim_size != dim_size && dim_size != 1 && - target_dim_size != 1) { - PADDLE_THROW(errors::InvalidArgument( - "BroadcastTensorsOp inputs do not satisfy broadcast semantics, " - "please check axis = %d in reverse order", - index)); - } - - if (dim_size != 1) { - target_dim_size = dim_size; - } - } - target_dims[target_rank - index - 1] = target_dim_size; - } - - // 3. Set Output Dim - for (size_t i = 0; i < op->num_results(); ++i) { - infer_context->SetShapeOrDataForValue( - op->result(i), - symbol::ShapeOrDataDimExprs{ - symbol::TensorShapeOrDataDimExprs(target_dims)}); - } - - return true; -} +// bool BroadcastTensorsOpInferSymbolicShape(pir::Operation *op, +// pir::InferSymbolicShapeContext +// *infer_context) { +// // pass +// return true; +// } bool BilinearInterpOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index f4857392c6a42..095590eca991d 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -24,7 +24,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Addmm_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(AddN) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Auc) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(AssignPos) -OP_DECLARE_INFER_SYMBOLIC_SHAPE(BroadcastTensors) +// OP_DECLARE_INFER_SYMBOLIC_SHAPE(BroadcastTensors) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchFc) OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchNorm) OP_DECLARE_INFER_SYMBOLIC_SHAPE(BatchNorm_) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 2cf0640f3a6f0..683800a5cdcfc 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -644,7 +644,6 @@ kernel : func : bmm backward : bmm_grad - interfaces : paddle::dialect::InferSymbolicShapeInterface - op : box_clip args: (Tensor input, Tensor im_info) @@ -673,7 +672,6 @@ func: broadcast_tensors data_type : input backward: broadcast_tensors_grad - interfaces : paddle::dialect::InferSymbolicShapeInterface - op : c_allgather args : (Tensor x, int ring_id, int nranks, bool use_calc_stream) From 47211e550864141f97eff9c23a1b215b56afa76e Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Wed, 14 Aug 2024 18:03:17 +0800 Subject: [PATCH 05/10] undo --- .../operator/interface/infer_symbolic_shape/binary_infer_sym.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h 
b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h index e7e3d9afd8b2e..418cac250ff5b 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h @@ -26,7 +26,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(BoxClip) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Binomial) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Binomial_) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bincount) -OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bmm) +// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bmm) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(CholeskySolve) OP_DECLARE_INFER_SYMBOLIC_SHAPE(CtcAlign) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv2d) From 1c8ed16e2006c65920ab6a96b4f220dfd4003e09 Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Thu, 15 Aug 2024 14:05:34 +0800 Subject: [PATCH 06/10] update --- .../distributed/fleet/base/role_maker.py | 45 +++++++++++-------- .../paddle/distributed/fleet/base/topology.py | 43 ++++++++---------- 2 files changed, 45 insertions(+), 43 deletions(-) diff --git a/python/paddle/distributed/fleet/base/role_maker.py b/python/paddle/distributed/fleet/base/role_maker.py index dab5a5ccd0d1b..5ae38d1a2c0e6 100755 --- a/python/paddle/distributed/fleet/base/role_maker.py +++ b/python/paddle/distributed/fleet/base/role_maker.py @@ -19,7 +19,7 @@ import time import warnings from multiprocessing import Manager, Process -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar, Literal import numpy as np @@ -38,11 +38,11 @@ class Role: - WORKER = 1 - SERVER = 2 - HETER_WORKER = 3 - ALL = 4 - COORDINATOR = 5 + WORKER: ClassVar[Literal[1]] = 1 + SERVER: ClassVar[Literal[2]] = 2 + HETER_WORKER: ClassVar[Literal[3]] = 3 + ALL: ClassVar[Literal[4]] = 4 + COORDINATOR: ClassVar[Literal[5]] = 5 class Gloo: @@ -598,7 +598,9 @@ def __init__(self, is_collective: bool = False, **kwargs: Any) -> None: def _barrier(self, comm_world: str) -> None: self._gloo.barrier(comm_world) - def _all_gather(self, input: Any, comm_world="worker") -> task: + def _all_gather( + self, input: Any, comm_world: str = "worker" + ) -> np.NDArray[Any]: return self._gloo.all_gather(input, comm_world) def _all_reduce( @@ -622,7 +624,7 @@ def _heter_device_type(self) -> str: self._generate_role() return self._heter_trainer_device_type - def _get_stage_id(self) -> int | str: + def _get_stage_id(self) -> int: """ return stage id of current heter worker """ @@ -630,7 +632,7 @@ def _get_stage_id(self) -> int | str: self._generate_role() return self._stage_id - def _get_stage_trainers(self) -> list: + def _get_stage_trainers(self) -> list[int]: """ return trainer num of all stages """ @@ -675,7 +677,7 @@ def _is_first_worker(self) -> bool: self._generate_role() return self._role == Role.WORKER and self._current_id == 0 - def _worker_index(self) -> int | str: + def _worker_index(self) -> int: """ get index of current worker """ @@ -683,7 +685,7 @@ def _worker_index(self) -> int | str: self._generate_role() return self._current_id - def _server_index(self) -> int | str: + def _server_index(self) -> int: """ get index of current server """ @@ -691,7 +693,7 @@ def _server_index(self) -> int | str: self._generate_role() return self._current_id - def _role_id(self) -> int | str: + def _role_id(self) -> int: """ get index of current node """ @@ -750,7 +752,7 @@ def _get_world_device_ids(self) -> str | None: self._generate_role() return self._world_device_ids - def _get_trainer_endpoints(self) -> str | list[str]: + def 
_get_trainer_endpoints(self) -> list[str]: """ get endpoint of all trainers """ @@ -758,7 +760,7 @@ def _get_trainer_endpoints(self) -> str | list[str]: self._generate_role() return self._worker_endpoints - def _get_trainer_endpoint(self) -> str | None | list[str]: + def _get_trainer_endpoint(self) -> str: if not self._role_is_generated: self._generate_role() assert ( @@ -766,7 +768,7 @@ def _get_trainer_endpoint(self) -> str | None | list[str]: ), "get_trainer_endpoint should be called by trainer" return self._cur_endpoint - def _get_heter_worker_endpoints(self) -> list[str] | None: + def _get_heter_worker_endpoints(self) -> list[str]: """ Returns: string: all heter_trainers'endpoints @@ -778,7 +780,7 @@ def _get_heter_worker_endpoints(self) -> list[str] | None: ), "Heter Worker Endpoints Not initialized" return self._heter_trainer_endpoints - def _get_heter_worker_endpoint(self) -> str | None: + def _get_heter_worker_endpoint(self) -> str: """ Returns: str: corresponding heter_trainer's endpoint @@ -790,7 +792,7 @@ def _get_heter_worker_endpoint(self) -> str | None: ), "_get_heter_worker_endpoint should be invoked by heter worker" return self._cur_endpoint - def _get_pserver_endpoints(self) -> str | list[str] | None: + def _get_pserver_endpoints(self) -> list[str]: """ get endpoint of all pservers """ @@ -798,7 +800,7 @@ def _get_pserver_endpoints(self) -> str | list[str] | None: self._generate_role() return self._server_endpoints - def _get_coordinator_endpoints(self) -> str | list[str] | None: + def _get_coordinator_endpoints(self) -> list[str]: if not self._role_is_generated: self._generate_role() return self._coordinator_endpoints @@ -1225,7 +1227,12 @@ class UserDefinedRoleMaker(PaddleCloudRoleMaker): ... server_endpoints=["127.0.0.1:36011", "127.0.0.1:36012"]) """ - def __init__(self, is_collective=False, init_gloo=False, **kwargs): + def __init__( + self, + is_collective: bool = False, + init_gloo: bool = False, + **kwargs: Any, + ) -> None: super().__init__( is_collective=is_collective, init_gloo=init_gloo, **kwargs ) diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py index b1d1e8f885a00..1ff54f383b019 100644 --- a/python/paddle/distributed/fleet/base/topology.py +++ b/python/paddle/distributed/fleet/base/topology.py @@ -25,7 +25,6 @@ from ..utils.log_util import logger if TYPE_CHECKING: - from paddle.base.core import task from paddle.distributed.collective import Group __all__ = ['CommunicateTopology', 'HybridCommunicateGroup'] @@ -98,24 +97,24 @@ def __init__( def get_hybrid_group_names(self) -> list[str]: return self._parallel_names - def get_dim(self, axis_name: str) -> list[str]: + def get_dim(self, axis_name: str) -> int: return self._dims[self._parallel_names.index(axis_name)] def world_size(self) -> int: return self._world_size - def get_rank(self, **args: Any) -> dict[list[str], int]: + def get_rank(self, **args: Any) -> int: assert len(args) == len(self._dims) key = self.coordinate(**args) assert key in self._coord2rank.keys() return self._coord2rank[key] - def get_coord(self, rank: str) -> int | list[str] | None: + def get_coord(self, rank: int) -> Any: assert rank < self._world_size assert rank in self._rank2coord.keys() return self._rank2coord[rank] - def get_axis_list(self, axis_name: str, index: str) -> list[int]: + def get_axis_list(self, axis_name: str, index: int) -> list[int]: axis = self._parallel_names.index(axis_name) ranks = [ self._coord2rank[coord] @@ -129,7 +128,7 @@ def get_dim_size(self, 
axis_name: str) -> int: assert axis_name in self._parallel_names return self._dims[self._parallel_names.index(axis_name)] - def get_fused_ranks(self, fused_axis: str) -> list[int]: + def get_fused_ranks(self, fused_axis: list[int]) -> list[list[int]]: non_fused_axis = list(set(self._parallel_names).difference(fused_axis)) non_fused_ranges = [] for axis_name in non_fused_axis: @@ -156,7 +155,7 @@ def get_fused_ranks(self, fused_axis: str) -> list[int]: return rank_list - def get_comm_list(self, axis_name: str) -> list[list[int | list[str]]]: + def get_comm_list(self, axis_name: str) -> list[list[int]]: assert axis_name in self._parallel_names other_axis_names = [ name for name in self._parallel_names if name != axis_name @@ -181,16 +180,14 @@ def get_comm_list(self, axis_name: str) -> list[list[int | list[str]]]: return all_result - def get_rank_from_stage( - self, global_rank: int | list[str] | None, **kwargs: Any - ) -> task: + def get_rank_from_stage(self, global_rank: int, **kwargs: Any) -> int: coord = self.get_coord(global_rank) tf = coord._replace(**kwargs)._asdict() return self.get_rank(**tf) class HybridCommunicateGroup: - def __init__(self, topology: dict[int]) -> None: + def __init__(self, topology: CommunicateTopology) -> None: self.nranks = paddle.distributed.get_world_size() self.global_rank = paddle.distributed.get_rank() self._topo = topology @@ -348,7 +345,7 @@ def _check_sep_exist(self) -> None: def _set_comm_group( self, parallel_method: str = "data" - ) -> None | tuple[list[int], Group]: + ) -> tuple[list[int], Group]: parallel_group = [] parallel_comm_group = None parallel_groups = self._topo.get_comm_list(parallel_method) @@ -377,7 +374,7 @@ def _set_comm_group( def _set_check_group( self, parallel_method: str = "data" - ) -> tuple[list[int], Group] | None: + ) -> tuple[list[int], Group]: parallel_group = [] parallel_comm_group = None parallel_size = self._topo.get_dim(parallel_method) @@ -393,11 +390,11 @@ def _set_check_group( return parallel_group, parallel_comm_group - def _get_p2p_next_rank(self) -> int | None: + def _get_p2p_next_rank(self) -> int: assert hasattr(self, 'next_rank'), "next_rank has not been inited" return self.next_rank - def _get_p2p_prev_rank(self) -> int | None: + def _get_p2p_prev_rank(self) -> int: assert hasattr(self, 'prev_rank'), "prev_rank has not been inited" return self.prev_rank @@ -452,7 +449,7 @@ def _set_four_directions_p2p_group(self) -> None: assert self.recv_next_group is not None assert self.recv_prev_group is not None - def topology(self) -> dict[int]: + def topology(self) -> CommunicateTopology: return self._topo def get_global_rank(self) -> int: @@ -468,7 +465,7 @@ def get_data_parallel_rank(self) -> int: def get_data_parallel_world_size(self) -> int: return self._dp_degree - def get_data_parallel_group(self) -> list[int]: + def get_data_parallel_group(self) -> Group: return self._dp_comm_group def get_data_parallel_group_src_rank(self) -> int: @@ -549,9 +546,7 @@ def get_sharding_parallel_group_src_rank(self) -> int: return self._sharding_comm_group.ranks[0] # check parallel group - def get_check_parallel_group( - self, sharding: bool = False - ) -> list[int] | Group: + def get_check_parallel_group(self, sharding: bool = False) -> Group: if sharding: return self.sharding_check_comm_group else: @@ -563,17 +558,17 @@ def get_rank_from_stage(self, stage_id: int, **kwargs: Any) -> int: ) # fuse comm group message - def get_dp_sep_parallel_group(self) -> list[int]: + def get_dp_sep_parallel_group(self) -> Group: 
self._check_sep_exist() return self._dp_sep_comm_group - def get_pp_mp_parallel_group(self) -> list[int]: + def get_pp_mp_parallel_group(self) -> Group: self._check_sep_exist() return self._pp_mp_comm_group def create_fuse_group( - self, fused_strategy_list: list[int] - ) -> list[int] | int: + self, fused_strategy_list: list[str] + ) -> tuple[list[list[int]], list[Group]] | tuple[list[int], Group]: assert ( len(fused_strategy_list) > 0 ), "the length of fused_strategy_list must be greater than 0." From 4d9ad68978c05e95247181ed990c2d0d0604740e Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Fri, 16 Aug 2024 12:45:25 +0800 Subject: [PATCH 07/10] fixed 2 --- python/paddle/distributed/fleet/base/role_maker.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/python/paddle/distributed/fleet/base/role_maker.py b/python/paddle/distributed/fleet/base/role_maker.py index 5ae38d1a2c0e6..84cfc370f8949 100755 --- a/python/paddle/distributed/fleet/base/role_maker.py +++ b/python/paddle/distributed/fleet/base/role_maker.py @@ -19,7 +19,7 @@ import time import warnings from multiprocessing import Manager, Process -from typing import TYPE_CHECKING, Any, ClassVar, Literal +from typing import Any, ClassVar, Literal import numpy as np @@ -31,9 +31,6 @@ from ...backup_env import getenv_or_backup -if TYPE_CHECKING: - from paddle.base.core import task - __all__ = [] @@ -600,12 +597,12 @@ def _barrier(self, comm_world: str) -> None: def _all_gather( self, input: Any, comm_world: str = "worker" - ) -> np.NDArray[Any]: + ) -> list[float]: return self._gloo.all_gather(input, comm_world) def _all_reduce( self, input: Any, mode: str = "sum", comm_world: str = "worker" - ) -> task: + ) -> np.ndarray[Any]: return self._gloo.all_reduce(input, mode, comm_world) def _heter_device(self) -> str: From 14a96b9df706ba077ee8514e84376b2d4bf4b5e9 Mon Sep 17 00:00:00 2001 From: Whsjrczr <123729598+Whsjrczr@users.noreply.github.com> Date: Fri, 16 Aug 2024 15:23:51 +0800 Subject: [PATCH 08/10] Update python/paddle/distributed/fleet/base/role_maker.py Co-authored-by: megemini --- python/paddle/distributed/fleet/base/role_maker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distributed/fleet/base/role_maker.py b/python/paddle/distributed/fleet/base/role_maker.py index 84cfc370f8949..c16d82c1a3865 100755 --- a/python/paddle/distributed/fleet/base/role_maker.py +++ b/python/paddle/distributed/fleet/base/role_maker.py @@ -602,7 +602,7 @@ def _all_gather( def _all_reduce( self, input: Any, mode: str = "sum", comm_world: str = "worker" - ) -> np.ndarray[Any]: + ) -> npt.NDArray[Any]: return self._gloo.all_reduce(input, mode, comm_world) def _heter_device(self) -> str: From 91e35ecf341bf56050505bbb786ad38d933e64e9 Mon Sep 17 00:00:00 2001 From: Whsjrczr Date: Fri, 16 Aug 2024 16:34:54 +0800 Subject: [PATCH 09/10] npt --- python/paddle/distributed/fleet/base/role_maker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/distributed/fleet/base/role_maker.py b/python/paddle/distributed/fleet/base/role_maker.py index c16d82c1a3865..ca66c385c3dee 100755 --- a/python/paddle/distributed/fleet/base/role_maker.py +++ b/python/paddle/distributed/fleet/base/role_maker.py @@ -22,6 +22,7 @@ from typing import Any, ClassVar, Literal import numpy as np +import numpy.typing as npt import paddle from paddle.base import core From b184de5b12d4d8fc9e8f796915859f4d428b6b7a Mon Sep 17 00:00:00 2001 From: Whsjrczr <123729598+Whsjrczr@users.noreply.github.com> Date: Mon, 19 Aug 2024 21:05:25 
+0800 Subject: [PATCH 10/10] Update typechecking --- python/paddle/distributed/fleet/base/role_maker.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/paddle/distributed/fleet/base/role_maker.py b/python/paddle/distributed/fleet/base/role_maker.py index ca66c385c3dee..f79dd4c11bdd6 100755 --- a/python/paddle/distributed/fleet/base/role_maker.py +++ b/python/paddle/distributed/fleet/base/role_maker.py @@ -19,10 +19,9 @@ import time import warnings from multiprocessing import Manager, Process -from typing import Any, ClassVar, Literal +from typing import TYPE_CHECKING, Any, ClassVar, Literal import numpy as np -import numpy.typing as npt import paddle from paddle.base import core @@ -32,6 +31,9 @@ from ...backup_env import getenv_or_backup +if TYPE_CHECKING: + import numpy.typing as npt + __all__ = []
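
The last three patches converge on the usual pattern for NumPy return annotations: with `from __future__ import annotations` already in effect, `numpy.typing` is only needed by static checkers, so patch 10 moves the import under `TYPE_CHECKING`. A minimal sketch of that pattern follows; the `all_reduce` helper in it is illustrative, not the fleet role-maker method.

# Minimal sketch of the TYPE_CHECKING pattern applied in patch 10; the
# all_reduce helper is a made-up example, not Paddle's fleet API.
from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    import numpy.typing as npt


def all_reduce(values: list[float]) -> npt.NDArray[Any]:
    # numpy is only imported at runtime where an array is actually produced;
    # the annotation stays a string thanks to postponed evaluation.
    import numpy as np

    return np.asarray(values).sum(keepdims=True)


print(all_reduce([1.0, 2.0, 3.0]))  # -> [6.]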