Commit

update
Whsjrczr committed Aug 15, 2024
1 parent 47211e5 commit 1c8ed16
Showing 2 changed files with 45 additions and 43 deletions.
45 changes: 26 additions & 19 deletions python/paddle/distributed/fleet/base/role_maker.py
@@ -19,7 +19,7 @@
import time
import warnings
from multiprocessing import Manager, Process
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, ClassVar, Literal

import numpy as np

@@ -38,11 +38,11 @@


class Role:
WORKER = 1
SERVER = 2
HETER_WORKER = 3
ALL = 4
COORDINATOR = 5
WORKER: ClassVar[Literal[1]] = 1
SERVER: ClassVar[Literal[2]] = 2
HETER_WORKER: ClassVar[Literal[3]] = 3
ALL: ClassVar[Literal[4]] = 4
COORDINATOR: ClassVar[Literal[5]] = 5


class Gloo:
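
The hunk above narrows the `Role` constants with `ClassVar[Literal[...]]`. A minimal, self-contained sketch of that pattern (not part of this commit; the `dispatch` helper is illustrative only):

```python
from typing import ClassVar, Literal


class Role:
    # Each constant is a class-level attribute pinned to one literal value.
    WORKER: ClassVar[Literal[1]] = 1
    SERVER: ClassVar[Literal[2]] = 2


def dispatch(role: Literal[1, 2]) -> str:
    # A static type checker only accepts the literals 1 and 2 here.
    return "worker" if role == 1 else "server"


print(dispatch(Role.WORKER))  # worker
```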
@@ -598,7 +598,9 @@ def __init__(self, is_collective: bool = False, **kwargs: Any) -> None:
def _barrier(self, comm_world: str) -> None:
self._gloo.barrier(comm_world)

def _all_gather(self, input: Any, comm_world="worker") -> task:
def _all_gather(
self, input: Any, comm_world: str = "worker"
) -> np.NDArray[Any]:
return self._gloo.all_gather(input, comm_world)

def _all_reduce(
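
For the `_all_gather` return annotation above, a minimal sketch of annotating an array-returning function with `numpy.typing.NDArray`, the spelling exported by NumPy's typing module; the stub below is illustrative and not Paddle code:

```python
from __future__ import annotations

from typing import Any

import numpy as np
import numpy.typing as npt


def all_gather_stub(values: list[int]) -> npt.NDArray[Any]:
    # Stand-in for a collective all-gather: packs the inputs into one array.
    return np.asarray(values)


print(all_gather_stub([1, 2, 3]).shape)  # (3,)
```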
@@ -622,15 +624,15 @@ def _heter_device_type(self) -> str:
self._generate_role()
return self._heter_trainer_device_type

def _get_stage_id(self) -> int | str:
def _get_stage_id(self) -> int:
"""
return stage id of current heter worker
"""
if not self._role_is_generated:
self._generate_role()
return self._stage_id

def _get_stage_trainers(self) -> list:
def _get_stage_trainers(self) -> list[int]:
"""
return trainer num of all stages
"""
Expand Down Expand Up @@ -675,23 +677,23 @@ def _is_first_worker(self) -> bool:
self._generate_role()
return self._role == Role.WORKER and self._current_id == 0

def _worker_index(self) -> int | str:
def _worker_index(self) -> int:
"""
get index of current worker
"""
if not self._role_is_generated:
self._generate_role()
return self._current_id

def _server_index(self) -> int | str:
def _server_index(self) -> int:
"""
get index of current server
"""
if not self._role_is_generated:
self._generate_role()
return self._current_id

def _role_id(self) -> int | str:
def _role_id(self) -> int:
"""
get index of current node
"""
@@ -750,23 +752,23 @@ def _get_world_device_ids(self) -> str | None:
self._generate_role()
return self._world_device_ids

def _get_trainer_endpoints(self) -> str | list[str]:
def _get_trainer_endpoints(self) -> list[str]:
"""
get endpoint of all trainers
"""
if not self._role_is_generated:
self._generate_role()
return self._worker_endpoints

def _get_trainer_endpoint(self) -> str | None | list[str]:
def _get_trainer_endpoint(self) -> str:
if not self._role_is_generated:
self._generate_role()
assert (
self._role == Role.WORKER
), "get_trainer_endpoint should be called by trainer"
return self._cur_endpoint

def _get_heter_worker_endpoints(self) -> list[str] | None:
def _get_heter_worker_endpoints(self) -> list[str]:
"""
Returns:
string: all heter_trainers'endpoints
@@ -778,7 +780,7 @@ def _get_heter_worker_endpoints(self) -> list[str] | None:
), "Heter Worker Endpoints Not initialized"
return self._heter_trainer_endpoints

def _get_heter_worker_endpoint(self) -> str | None:
def _get_heter_worker_endpoint(self) -> str:
"""
Returns:
str: corresponding heter_trainer's endpoint
@@ -790,15 +792,15 @@ def _get_heter_worker_endpoint(self) -> str | None:
), "_get_heter_worker_endpoint should be invoked by heter worker"
return self._cur_endpoint

def _get_pserver_endpoints(self) -> str | list[str] | None:
def _get_pserver_endpoints(self) -> list[str]:
"""
get endpoint of all pservers
"""
if not self._role_is_generated:
self._generate_role()
return self._server_endpoints

def _get_coordinator_endpoints(self) -> str | list[str] | None:
def _get_coordinator_endpoints(self) -> list[str]:
if not self._role_is_generated:
self._generate_role()
return self._coordinator_endpoints
@@ -1225,7 +1227,12 @@ class UserDefinedRoleMaker(PaddleCloudRoleMaker):
... server_endpoints=["127.0.0.1:36011", "127.0.0.1:36012"])
"""

def __init__(self, is_collective=False, init_gloo=False, **kwargs):
def __init__(
self,
is_collective: bool = False,
init_gloo: bool = False,
**kwargs: Any,
) -> None:
super().__init__(
is_collective=is_collective, init_gloo=init_gloo, **kwargs
)
43 changes: 19 additions & 24 deletions python/paddle/distributed/fleet/base/topology.py
@@ -25,7 +25,6 @@
from ..utils.log_util import logger

if TYPE_CHECKING:
from paddle.base.core import task
from paddle.distributed.collective import Group

__all__ = ['CommunicateTopology', 'HybridCommunicateGroup']
@@ -98,24 +97,24 @@ def __init__(
def get_hybrid_group_names(self) -> list[str]:
return self._parallel_names

def get_dim(self, axis_name: str) -> list[str]:
def get_dim(self, axis_name: str) -> int:
return self._dims[self._parallel_names.index(axis_name)]

def world_size(self) -> int:
return self._world_size

def get_rank(self, **args: Any) -> dict[list[str], int]:
def get_rank(self, **args: Any) -> int:
assert len(args) == len(self._dims)
key = self.coordinate(**args)
assert key in self._coord2rank.keys()
return self._coord2rank[key]

def get_coord(self, rank: str) -> int | list[str] | None:
def get_coord(self, rank: int) -> Any:
assert rank < self._world_size
assert rank in self._rank2coord.keys()
return self._rank2coord[rank]

def get_axis_list(self, axis_name: str, index: str) -> list[int]:
def get_axis_list(self, axis_name: str, index: int) -> list[int]:
axis = self._parallel_names.index(axis_name)
ranks = [
self._coord2rank[coord]
@@ -129,7 +128,7 @@ def get_dim_size(self, axis_name: str) -> int:
assert axis_name in self._parallel_names
return self._dims[self._parallel_names.index(axis_name)]

def get_fused_ranks(self, fused_axis: str) -> list[int]:
def get_fused_ranks(self, fused_axis: list[int]) -> list[list[int]]:
non_fused_axis = list(set(self._parallel_names).difference(fused_axis))
non_fused_ranges = []
for axis_name in non_fused_axis:
@@ -156,7 +155,7 @@ def get_fused_ranks(self, fused_axis: str) -> list[int]:

return rank_list

def get_comm_list(self, axis_name: str) -> list[list[int | list[str]]]:
def get_comm_list(self, axis_name: str) -> list[list[int]]:
assert axis_name in self._parallel_names
other_axis_names = [
name for name in self._parallel_names if name != axis_name
@@ -181,16 +180,14 @@ def get_comm_list(self, axis_name: str) -> list[list[int | list[str]]]:

return all_result

def get_rank_from_stage(
self, global_rank: int | list[str] | None, **kwargs: Any
) -> task:
def get_rank_from_stage(self, global_rank: int, **kwargs: Any) -> int:
coord = self.get_coord(global_rank)
tf = coord._replace(**kwargs)._asdict()
return self.get_rank(**tf)


class HybridCommunicateGroup:
def __init__(self, topology: dict[int]) -> None:
def __init__(self, topology: CommunicateTopology) -> None:
self.nranks = paddle.distributed.get_world_size()
self.global_rank = paddle.distributed.get_rank()
self._topo = topology
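
The topology hunks above annotate coordinate/rank lookups. A toy mapping in the same spirit as `CommunicateTopology` (illustrative only, not Paddle's implementation; the axis names and sizes are assumptions):

```python
from collections import namedtuple
from itertools import product

# Three parallel axes of size 2 each -> 8 global ranks.
names = ["pp", "dp", "mp"]
dims = [2, 2, 2]
Coord = namedtuple("Coord", names)

coords = [Coord(*c) for c in product(*(range(d) for d in dims))]
coord2rank = {c: r for r, c in enumerate(coords)}
rank2coord = {r: c for c, r in coord2rank.items()}


def get_rank_from_stage(global_rank: int, **kwargs: int) -> int:
    # Replace some axes of the rank's coordinate, keep the rest, map back to a rank.
    coord = rank2coord[global_rank]
    return coord2rank[coord._replace(**kwargs)]


print(get_rank_from_stage(0, pp=1))  # 4: same dp/mp slice, next pipeline stage
```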
@@ -348,7 +345,7 @@ def _check_sep_exist(self) -> None:

def _set_comm_group(
self, parallel_method: str = "data"
) -> None | tuple[list[int], Group]:
) -> tuple[list[int], Group]:
parallel_group = []
parallel_comm_group = None
parallel_groups = self._topo.get_comm_list(parallel_method)
@@ -377,7 +374,7 @@ def _set_check_group(

def _set_check_group(
self, parallel_method: str = "data"
) -> tuple[list[int], Group] | None:
) -> tuple[list[int], Group]:
parallel_group = []
parallel_comm_group = None
parallel_size = self._topo.get_dim(parallel_method)
@@ -393,11 +390,11 @@ def _set_check_group(

return parallel_group, parallel_comm_group

def _get_p2p_next_rank(self) -> int | None:
def _get_p2p_next_rank(self) -> int:
assert hasattr(self, 'next_rank'), "next_rank has not been inited"
return self.next_rank

def _get_p2p_prev_rank(self) -> int | None:
def _get_p2p_prev_rank(self) -> int:
assert hasattr(self, 'prev_rank'), "prev_rank has not been inited"
return self.prev_rank

@@ -452,7 +449,7 @@ def _set_four_directions_p2p_group(self) -> None:
assert self.recv_next_group is not None
assert self.recv_prev_group is not None

def topology(self) -> dict[int]:
def topology(self) -> CommunicateTopology:
return self._topo

def get_global_rank(self) -> int:
@@ -468,7 +465,7 @@ def get_data_parallel_rank(self) -> int:
def get_data_parallel_world_size(self) -> int:
return self._dp_degree

def get_data_parallel_group(self) -> list[int]:
def get_data_parallel_group(self) -> Group:
return self._dp_comm_group

def get_data_parallel_group_src_rank(self) -> int:
@@ -549,9 +546,7 @@ def get_sharding_parallel_group_src_rank(self) -> int:
return self._sharding_comm_group.ranks[0]

# check parallel group
def get_check_parallel_group(
self, sharding: bool = False
) -> list[int] | Group:
def get_check_parallel_group(self, sharding: bool = False) -> Group:
if sharding:
return self.sharding_check_comm_group
else:
@@ -563,17 +558,17 @@ def get_rank_from_stage(self, stage_id: int, **kwargs: Any) -> int:
)

# fuse comm group message
def get_dp_sep_parallel_group(self) -> list[int]:
def get_dp_sep_parallel_group(self) -> Group:
self._check_sep_exist()
return self._dp_sep_comm_group

def get_pp_mp_parallel_group(self) -> list[int]:
def get_pp_mp_parallel_group(self) -> Group:
self._check_sep_exist()
return self._pp_mp_comm_group

def create_fuse_group(
self, fused_strategy_list: list[int]
) -> list[int] | int:
self, fused_strategy_list: list[str]
) -> tuple[list[list[int]], list[Group]] | tuple[list[int], Group]:
assert (
len(fused_strategy_list) > 0
), "the length of fused_strategy_list must be greater than 0."
