From b07423c84b7324abc7fe1a1ac123f4616ea01bfd Mon Sep 17 00:00:00 2001 From: Whsjrczr <123729598+Whsjrczr@users.noreply.github.com> Date: Tue, 20 Aug 2024 02:32:37 +0800 Subject: [PATCH] [Typing][C-41, C-42][BUAA] Add type annotations for `python/paddle/distributed/fleet/base/*` (#67439) --------- Co-authored-by: megemini --- .../distributed/fleet/base/role_maker.py | 109 ++++++++------- .../paddle/distributed/fleet/base/topology.py | 127 ++++++++++-------- 2 files changed, 134 insertions(+), 102 deletions(-) diff --git a/python/paddle/distributed/fleet/base/role_maker.py b/python/paddle/distributed/fleet/base/role_maker.py index c7ead667d9f97..f79dd4c11bdd6 100755 --- a/python/paddle/distributed/fleet/base/role_maker.py +++ b/python/paddle/distributed/fleet/base/role_maker.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. """Definition of Role Makers.""" +from __future__ import annotations + import os import re import time import warnings from multiprocessing import Manager, Process +from typing import TYPE_CHECKING, Any, ClassVar, Literal import numpy as np @@ -28,15 +31,18 @@ from ...backup_env import getenv_or_backup +if TYPE_CHECKING: + import numpy.typing as npt + __all__ = [] class Role: - WORKER = 1 - SERVER = 2 - HETER_WORKER = 3 - ALL = 4 - COORDINATOR = 5 + WORKER: ClassVar[Literal[1]] = 1 + SERVER: ClassVar[Literal[2]] = 2 + HETER_WORKER: ClassVar[Literal[3]] = 3 + ALL: ClassVar[Literal[4]] = 4 + COORDINATOR: ClassVar[Literal[5]] = 5 class Gloo: @@ -563,7 +569,7 @@ class PaddleCloudRoleMaker(RoleMakerBase): """ - def __init__(self, is_collective=False, **kwargs): + def __init__(self, is_collective: bool = False, **kwargs: Any) -> None: super().__init__() self._is_collective = is_collective self._non_distributed = False @@ -589,16 +595,20 @@ def __init__(self, is_collective=False, **kwargs): self._gloo = Gloo() # gloo instance - def _barrier(self, comm_world): + def _barrier(self, comm_world: str) -> None: self._gloo.barrier(comm_world) - def _all_gather(self, input, comm_world="worker"): + def _all_gather( + self, input: Any, comm_world: str = "worker" + ) -> list[float]: return self._gloo.all_gather(input, comm_world) - def _all_reduce(self, input, mode="sum", comm_world="worker"): + def _all_reduce( + self, input: Any, mode: str = "sum", comm_world: str = "worker" + ) -> npt.NDArray[Any]: return self._gloo.all_reduce(input, mode, comm_world) - def _heter_device(self): + def _heter_device(self) -> str: """ return the heter device that current heter worker is using """ @@ -606,7 +616,7 @@ def _heter_device(self): self._generate_role() return self._heter_trainer_device - def _heter_device_type(self): + def _heter_device_type(self) -> str: """ return the heter device type that current heter worker is using """ @@ -614,7 +624,7 @@ def _heter_device_type(self): self._generate_role() return self._heter_trainer_device_type - def _get_stage_id(self): + def _get_stage_id(self) -> int: """ return stage id of current heter worker """ @@ -622,7 +632,7 @@ def _get_stage_id(self): self._generate_role() return self._stage_id - def _get_stage_trainers(self): + def _get_stage_trainers(self) -> list[int]: """ return trainer num of all stages """ @@ -630,7 +640,7 @@ def _get_stage_trainers(self): self._generate_role() return self._stage_trainers - def _get_num_stage(self): + def _get_num_stage(self) -> int: """ return stage num """ @@ -638,7 +648,7 @@ def _get_num_stage(self): self._generate_role() return self._stage_num 
- def _is_worker(self): + def _is_worker(self) -> bool: """ whether current process is worker """ @@ -646,7 +656,7 @@ def _is_worker(self): self._generate_role() return self._role == Role.WORKER - def _is_server(self): + def _is_server(self) -> bool: """ whether current process is server """ @@ -654,12 +664,12 @@ def _is_server(self): self._generate_role() return self._role == Role.SERVER - def _is_coordinator(self): + def _is_coordinator(self) -> bool: if not self._role_is_generated: self._generate_role() return self._role == Role.COORDINATOR - def _is_first_worker(self): + def _is_first_worker(self) -> bool: """ whether current process is worker of rank 0 """ @@ -667,7 +677,7 @@ def _is_first_worker(self): self._generate_role() return self._role == Role.WORKER and self._current_id == 0 - def _worker_index(self): + def _worker_index(self) -> int: """ get index of current worker """ @@ -675,7 +685,7 @@ def _worker_index(self): self._generate_role() return self._current_id - def _server_index(self): + def _server_index(self) -> int: """ get index of current server """ @@ -683,7 +693,7 @@ def _server_index(self): self._generate_role() return self._current_id - def _role_id(self): + def _role_id(self) -> int: """ get index of current node """ @@ -691,7 +701,7 @@ def _role_id(self): self._generate_role() return self._current_id - def _worker_num(self): + def _worker_num(self) -> int: """ return the current number of worker """ @@ -699,7 +709,7 @@ def _worker_num(self): self._generate_role() return self._trainers_num - def _server_num(self): + def _server_num(self) -> int: """ return the current number of server """ @@ -711,7 +721,7 @@ def _server_num(self): else 0 ) - def _node_num(self): + def _node_num(self) -> int: """ return the training node number """ @@ -719,7 +729,7 @@ def _node_num(self): self._generate_role() return self._nodes_num - def _get_node_num(self): + def _get_node_num(self) -> int: """ return the training node number """ @@ -727,22 +737,22 @@ def _get_node_num(self): self._generate_role() return self._nodes_num - def _get_local_rank(self): + def _get_local_rank(self) -> str | None: if not self._role_is_generated: self._generate_role() return self._local_rank - def _get_local_device_ids(self): + def _get_local_device_ids(self) -> str | None: if not self._role_is_generated: self._generate_role() return self._local_device_ids - def _get_world_device_ids(self): + def _get_world_device_ids(self) -> str | None: if not self._role_is_generated: self._generate_role() return self._world_device_ids - def _get_trainer_endpoints(self): + def _get_trainer_endpoints(self) -> list[str]: """ get endpoint of all trainers """ @@ -750,7 +760,7 @@ def _get_trainer_endpoints(self): self._generate_role() return self._worker_endpoints - def _get_trainer_endpoint(self): + def _get_trainer_endpoint(self) -> str: if not self._role_is_generated: self._generate_role() assert ( @@ -758,7 +768,7 @@ def _get_trainer_endpoint(self): ), "get_trainer_endpoint should be called by trainer" return self._cur_endpoint - def _get_heter_worker_endpoints(self): + def _get_heter_worker_endpoints(self) -> list[str]: """ Returns: string: all heter_trainers'endpoints @@ -770,10 +780,10 @@ def _get_heter_worker_endpoints(self): ), "Heter Worker Endpoints Not initialized" return self._heter_trainer_endpoints - def _get_heter_worker_endpoint(self): + def _get_heter_worker_endpoint(self) -> str: """ Returns: - int: corresponding heter_trainer's endpoint + str: corresponding heter_trainer's endpoint """ if not 
self._role_is_generated: self._generate_role() @@ -782,7 +792,7 @@ def _get_heter_worker_endpoint(self): ), "_get_heter_worker_endpoint should be invoked by heter worker" return self._cur_endpoint - def _get_pserver_endpoints(self): + def _get_pserver_endpoints(self) -> list[str]: """ get endpoint of all pservers """ @@ -790,7 +800,7 @@ def _get_pserver_endpoints(self): self._generate_role() return self._server_endpoints - def _get_coordinator_endpoints(self): + def _get_coordinator_endpoints(self) -> list[str]: if not self._role_is_generated: self._generate_role() return self._coordinator_endpoints @@ -807,7 +817,7 @@ def _get_previous_trainers(self): ), "_get_previous_trainers should be invoked by trainer or heter worker" return self._previous_heter_trainer_endpoints - def _get_next_trainers(self): + def _get_next_trainers(self) -> list[str]: """ invoked by heter worker """ @@ -819,7 +829,7 @@ def _get_next_trainers(self): ), "_get_next_trainers should be invoked by trainer or heter worker" return self._next_heter_trainer_endpoints - def _is_non_distributed(self): + def _is_non_distributed(self) -> bool: """ Return True if indispensable environment for fleetrun is not found (use python-run to launch fleet-code directly) @@ -828,7 +838,7 @@ def _is_non_distributed(self): self._generate_role() return self._non_distributed - def _heter_worker_num(self): + def _heter_worker_num(self) -> int: """ get heter worker nums """ @@ -836,7 +846,7 @@ def _heter_worker_num(self): self._generate_role() return self._heter_trainers_num - def _is_heter_worker(self): + def _is_heter_worker(self) -> bool: """ whether current process is heter worker """ @@ -844,7 +854,7 @@ def _is_heter_worker(self): self._generate_role() return self._role == Role.HETER_WORKER - def _ps_env(self): # each role will execute it + def _ps_env(self) -> None: # each role will execute it # Environment variable PADDLE_PSERVERS_IP_PORT_LIST must be set # format: string(ip:port,ip:port), eg. 127.0.0.1:6001,127.0.0.1:6002 self._server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST", None) @@ -1084,7 +1094,7 @@ def _ps_env(self): # each role will execute it self._current_id = current_id self._nodes_num = len({x.split(':')[0] for x in self._worker_endpoints}) - def _collective_env(self): + def _collective_env(self) -> None: self._current_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) self._training_role = os.getenv("PADDLE_TRAINING_ROLE", "TRAINER") assert self._training_role == "TRAINER" @@ -1107,7 +1117,7 @@ def _collective_env(self): self._local_device_ids = os.getenv("PADDLE_LOCAL_DEVICE_IDS") self._world_device_ids = os.getenv("PADDLE_WORLD_DEVICE_IDS") - def _gloo_init(self): + def _gloo_init(self) -> None: # PADDLE_WITH_GLOO 1: trainer barrier, 2: all barrier use_gloo = int(os.getenv("PADDLE_WITH_GLOO", "0")) if use_gloo not in [1, 2]: @@ -1186,7 +1196,7 @@ def _gloo_init(self): if rendezvous_type == Gloo.RENDEZVOUS.HTTP: http_server_d['running'] = False - def _generate_role(self): + def _generate_role(self) -> None: """ generate role for role maker """ @@ -1217,13 +1227,18 @@ class UserDefinedRoleMaker(PaddleCloudRoleMaker): ... 
server_endpoints=["127.0.0.1:36011", "127.0.0.1:36012"]) """ - def __init__(self, is_collective=False, init_gloo=False, **kwargs): + def __init__( + self, + is_collective: bool = False, + init_gloo: bool = False, + **kwargs: Any, + ) -> None: super().__init__( is_collective=is_collective, init_gloo=init_gloo, **kwargs ) self._init_gloo = init_gloo - def _user_defined_ps_env(self): + def _user_defined_ps_env(self) -> None: self._server_endpoints = self._kwargs.get("server_endpoints") self._worker_endpoints = self._kwargs.get("worker_endpoints", []) self._trainers_num = self._kwargs.get("worker_num", 0) @@ -1244,14 +1259,14 @@ def _user_defined_ps_env(self): self._cur_endpoint = self._server_endpoints[self._current_id] self._nodes_num = len({x.split(':')[0] for x in self._worker_endpoints}) - def _user_defined_collective_env(self): + def _user_defined_collective_env(self) -> None: self._worker_endpoints = self._kwargs.get("worker_endpoints") self._current_id = self._kwargs.get("current_id") self._trainers_num = len(self._worker_endpoints) self._training_role = Role.WORKER self._nodes_num = len({x.split(':')[0] for x in self._worker_endpoints}) - def _generate_role(self): + def _generate_role(self) -> None: """ generate role for role maker """ diff --git a/python/paddle/distributed/fleet/base/topology.py b/python/paddle/distributed/fleet/base/topology.py index 8105e2672c87f..1ff54f383b019 100644 --- a/python/paddle/distributed/fleet/base/topology.py +++ b/python/paddle/distributed/fleet/base/topology.py @@ -11,17 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import collections import os from functools import reduce from itertools import product +from typing import TYPE_CHECKING, Any, Literal import paddle from paddle.distributed.utils.nccl_utils import check_nccl_version_for_p2p from ..utils.log_util import logger +if TYPE_CHECKING: + from paddle.distributed.collective import Group + __all__ = ['CommunicateTopology', 'HybridCommunicateGroup'] _HYBRID_PARALLEL_GROUP = None @@ -65,9 +70,15 @@ class ParallelMode: class CommunicateTopology: def __init__( self, - hybrid_group_names=["data", "pipe", "sharding", "sep", "model"], - dims=[1, 1, 1, 1, 1], - ): + hybrid_group_names: list[str] = [ + "data", + "pipe", + "sharding", + "sep", + "model", + ], + dims: list[int] = [1, 1, 1, 1, 1], + ) -> None: self._parallel_names = hybrid_group_names self._dims = dims self.coordinate = collections.namedtuple( @@ -83,27 +94,27 @@ def __init__( zip(self._coord2rank.values(), self._coord2rank.keys()) ) - def get_hybrid_group_names(self): + def get_hybrid_group_names(self) -> list[str]: return self._parallel_names - def get_dim(self, axis_name): + def get_dim(self, axis_name: str) -> int: return self._dims[self._parallel_names.index(axis_name)] - def world_size(self): + def world_size(self) -> int: return self._world_size - def get_rank(self, **args): + def get_rank(self, **args: Any) -> int: assert len(args) == len(self._dims) key = self.coordinate(**args) assert key in self._coord2rank.keys() return self._coord2rank[key] - def get_coord(self, rank): + def get_coord(self, rank: int) -> Any: assert rank < self._world_size assert rank in self._rank2coord.keys() return self._rank2coord[rank] - def get_axis_list(self, axis_name, index): + def get_axis_list(self, axis_name: str, index: int) -> list[int]: axis = 
self._parallel_names.index(axis_name) ranks = [ self._coord2rank[coord] @@ -113,11 +124,11 @@ def get_axis_list(self, axis_name, index): ranks.sort() return ranks - def get_dim_size(self, axis_name): + def get_dim_size(self, axis_name: str) -> int: assert axis_name in self._parallel_names return self._dims[self._parallel_names.index(axis_name)] - def get_fused_ranks(self, fused_axis): + def get_fused_ranks(self, fused_axis: list[int]) -> list[list[int]]: non_fused_axis = list(set(self._parallel_names).difference(fused_axis)) non_fused_ranges = [] for axis_name in non_fused_axis: @@ -144,7 +155,7 @@ def get_fused_ranks(self, fused_axis): return rank_list - def get_comm_list(self, axis_name): + def get_comm_list(self, axis_name: str) -> list[list[int]]: assert axis_name in self._parallel_names other_axis_names = [ name for name in self._parallel_names if name != axis_name @@ -169,14 +180,14 @@ def get_comm_list(self, axis_name): return all_result - def get_rank_from_stage(self, global_rank, **kwargs): + def get_rank_from_stage(self, global_rank: int, **kwargs: Any) -> int: coord = self.get_coord(global_rank) tf = coord._replace(**kwargs)._asdict() return self.get_rank(**tf) class HybridCommunicateGroup: - def __init__(self, topology): + def __init__(self, topology: CommunicateTopology) -> None: self.nranks = paddle.distributed.get_world_size() self.global_rank = paddle.distributed.get_rank() self._topo = topology @@ -281,7 +292,7 @@ def __init__(self, topology): global _HYBRID_PARALLEL_GROUP _HYBRID_PARALLEL_GROUP = self - def get_parallel_mode(self): + def get_parallel_mode(self) -> Literal[0, 1, 2, 3, 4]: # there are five modes : DataParallel / TensorParallel / PipelineParallel / ShardingParallel / SepParallel # NOTE when sharding conjugates with other parallel, sharding should act like a optimizer and # adding its parallel logic within that parallelism @@ -319,7 +330,7 @@ def get_parallel_mode(self): # pp may coexist with mp、sep、dp and sharding return ParallelMode.PIPELINE_PARALLEL - def _check_valid_topo(self): + def _check_valid_topo(self) -> bool: return ( self._dp_degree * self._mp_degree @@ -329,10 +340,12 @@ def _check_valid_topo(self): == self.nranks ) - def _check_sep_exist(self): + def _check_sep_exist(self) -> None: assert self._sep_degree > 1, "sep not exist" - def _set_comm_group(self, parallel_method="data"): + def _set_comm_group( + self, parallel_method: str = "data" + ) -> tuple[list[int], Group]: parallel_group = [] parallel_comm_group = None parallel_groups = self._topo.get_comm_list(parallel_method) @@ -359,7 +372,9 @@ def _set_comm_group(self, parallel_method="data"): ) return parallel_group, parallel_comm_group - def _set_check_group(self, parallel_method="data"): + def _set_check_group( + self, parallel_method: str = "data" + ) -> tuple[list[int], Group]: parallel_group = [] parallel_comm_group = None parallel_size = self._topo.get_dim(parallel_method) @@ -375,15 +390,15 @@ def _set_check_group(self, parallel_method="data"): return parallel_group, parallel_comm_group - def _get_p2p_next_rank(self): + def _get_p2p_next_rank(self) -> int: assert hasattr(self, 'next_rank'), "next_rank has not been inited" return self.next_rank - def _get_p2p_prev_rank(self): + def _get_p2p_prev_rank(self) -> int: assert hasattr(self, 'prev_rank'), "prev_rank has not been inited" return self.prev_rank - def _set_p2p_prev_next(self): + def _set_p2p_prev_next(self) -> None: comm_lists = self._topo.get_comm_list('pipe') for comm_ranks in comm_lists: @@ -397,7 +412,7 @@ def 
_set_p2p_prev_next(self): self.next_rank = next_rank self.prev_rank = prev_rank - def _set_four_directions_p2p_group(self): + def _set_four_directions_p2p_group(self) -> None: comm_lists = self._topo.get_comm_list('pipe') self.send_next_group = None @@ -434,75 +449,75 @@ def _set_four_directions_p2p_group(self): assert self.recv_next_group is not None assert self.recv_prev_group is not None - def topology(self): + def topology(self) -> CommunicateTopology: return self._topo - def get_global_rank(self): + def get_global_rank(self) -> int: return self.global_rank # data parallel message: - def _get_data_parallel_id(self): + def _get_data_parallel_id(self) -> int: return self._topo.get_coord(self.global_rank).data - def get_data_parallel_rank(self): + def get_data_parallel_rank(self) -> int: return self._data_parallel_id - def get_data_parallel_world_size(self): + def get_data_parallel_world_size(self) -> int: return self._dp_degree - def get_data_parallel_group(self): + def get_data_parallel_group(self) -> Group: return self._dp_comm_group - def get_data_parallel_group_src_rank(self): + def get_data_parallel_group_src_rank(self) -> int: return self._dp_comm_group.ranks[0] # model parallel message: - def _get_model_parallel_id(self): + def _get_model_parallel_id(self) -> str: return self._topo.get_coord(self.global_rank).model - def get_model_parallel_rank(self): + def get_model_parallel_rank(self) -> int: return self._model_parallel_id - def get_model_parallel_world_size(self): + def get_model_parallel_world_size(self) -> int: return self._mp_degree - def get_model_parallel_group(self): + def get_model_parallel_group(self) -> Group: return self._mp_comm_group - def get_model_parallel_group_src_rank(self): + def get_model_parallel_group_src_rank(self) -> int: return self._mp_comm_group.ranks[0] # pipeline parallel message - def _get_pipe_parallel_id(self): + def _get_pipe_parallel_id(self) -> int: return self._topo.get_coord(self.global_rank).pipe - def get_stage_id(self): + def get_stage_id(self) -> int: return self.stage_id - def get_pipe_parallel_world_size(self): + def get_pipe_parallel_world_size(self) -> int: return self._pp_degree - def _get_sep_parallel_id(self): + def _get_sep_parallel_id(self) -> int: return self._topo.get_coord(self.global_rank).sep - def get_sep_parallel_rank(self): + def get_sep_parallel_rank(self) -> int: return self._sep_parallel_id - def get_sep_parallel_world_size(self): + def get_sep_parallel_world_size(self) -> int: return self._sep_degree - def get_sep_parallel_group(self): + def get_sep_parallel_group(self) -> Group: self._check_sep_exist() return self._sep_comm_group - def get_sep_parallel_group_src_rank(self): + def get_sep_parallel_group_src_rank(self) -> int: self._check_sep_exist() return self._sep_comm_group.ranks[0] - def get_pipe_parallel_group(self): + def get_pipe_parallel_group(self) -> Group: return self._pp_comm_group - def get_p2p_groups(self): + def get_p2p_groups(self) -> tuple[Group, Group, Group, Group]: assert ( _use_four_directions ), "If you want to use four directions p2p group, set the environment variable PADDLE_USE_FOUR_DIRECTIONS_P2P to True." 
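For reference, a minimal usage sketch of the CommunicateTopology API annotated above (not part of the patch). It assumes a hypothetical 2-way data-parallel by 2-way model-parallel layout on 4 ranks and only exercises the coordinate arithmetic, so no distributed initialization is needed; the group layout shown in the comments is illustrative, not asserted by the patch.

    from paddle.distributed.fleet.base.topology import CommunicateTopology

    # Assumed layout: 2-way data parallel x 2-way model parallel over 4 ranks.
    topo = CommunicateTopology(
        hybrid_group_names=["data", "pipe", "sharding", "sep", "model"],
        dims=[2, 1, 1, 1, 2],
    )

    assert topo.world_size() == 4                        # -> int
    rank = topo.get_rank(data=1, pipe=0, sharding=0, sep=0, model=1)   # -> int
    coord = topo.get_coord(rank)                         # coordinate namedtuple
    model_groups = topo.get_comm_list("model")           # -> list[list[int]], e.g. [[0, 1], [2, 3]]
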
@@ -514,44 +529,46 @@ def get_p2p_groups(self): ) # sharding parallel message: - def _get_sharding_parallel_id(self): + def _get_sharding_parallel_id(self) -> int: return self._topo.get_coord(self.global_rank).sharding - def get_sharding_parallel_rank(self): + def get_sharding_parallel_rank(self) -> int: return self._sharding_parallel_id - def get_sharding_parallel_world_size(self): + def get_sharding_parallel_world_size(self) -> int: return self._sharding_degree - def get_sharding_parallel_group(self): + def get_sharding_parallel_group(self) -> Group: return self._sharding_comm_group - def get_sharding_parallel_group_src_rank(self): + def get_sharding_parallel_group_src_rank(self) -> int: # TODO should the src rank related to the shard rank for each parameter ? return self._sharding_comm_group.ranks[0] # check parallel group - def get_check_parallel_group(self, sharding=False): + def get_check_parallel_group(self, sharding: bool = False) -> Group: if sharding: return self.sharding_check_comm_group else: return self._check_comm_group - def get_rank_from_stage(self, stage_id, **kwargs): + def get_rank_from_stage(self, stage_id: int, **kwargs: Any) -> int: return self._topo.get_rank_from_stage( self.global_rank, pipe=stage_id, **kwargs ) # fuse comm group message - def get_dp_sep_parallel_group(self): + def get_dp_sep_parallel_group(self) -> Group: self._check_sep_exist() return self._dp_sep_comm_group - def get_pp_mp_parallel_group(self): + def get_pp_mp_parallel_group(self) -> Group: self._check_sep_exist() return self._pp_mp_comm_group - def create_fuse_group(self, fused_strategy_list): + def create_fuse_group( + self, fused_strategy_list: list[str] + ) -> tuple[list[list[int]], list[Group]] | tuple[list[int], Group]: assert ( len(fused_strategy_list) > 0 ), "the length of fused_strategy_list must be greater than 0."
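
Similarly, a hedged sketch of how the annotated PaddleCloudRoleMaker accessors read after this change (not part of the patch). It assumes the process was started by a distributed launcher (e.g. paddle.distributed.launch) so that the PADDLE_* environment variables consumed by _generate_role() are available; the variable annotations on the left-hand side simply restate the return types declared above.

    from paddle.distributed.fleet.base.role_maker import PaddleCloudRoleMaker

    # Assumes a launched collective job; without the launcher-provided
    # environment the role maker falls back to non-distributed defaults.
    role_maker = PaddleCloudRoleMaker(is_collective=True)

    if role_maker._is_worker():                               # -> bool
        rank: int = role_maker._worker_index()
        endpoints: list[str] = role_maker._get_trainer_endpoints()
        node_num: int = role_maker._node_num()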