Commit
[Typing][C-41, C-42][BUAA] Add type annotations for `python/paddle/distributed/fleet/base/*` (#67439)

---------

Co-authored-by: megemini <[email protected]>
Whsjrczr and megemini authored Aug 19, 2024
1 parent 4c44259 commit b07423c
Showing 2 changed files with 134 additions and 102 deletions.
109 changes: 62 additions & 47 deletions python/paddle/distributed/fleet/base/role_maker.py
@@ -12,11 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definition of Role Makers."""
+from __future__ import annotations

import os
import re
import time
import warnings
from multiprocessing import Manager, Process
+from typing import TYPE_CHECKING, Any, ClassVar, Literal

import numpy as np

@@ -28,15 +31,18 @@

from ...backup_env import getenv_or_backup

+if TYPE_CHECKING:
+    import numpy.typing as npt
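
Note on the pattern: the `if TYPE_CHECKING:` guard keeps `numpy.typing` out of the runtime import graph, and `from __future__ import annotations` makes every annotation a lazily evaluated string, so the guarded import is enough. A minimal sketch of the same idiom (`as_array` is an invented helper, not part of this file):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to mypy/pyright only; never executed at runtime.
    import numpy.typing as npt


def as_array(values: list[float]) -> npt.NDArray:  # invented example
    # With lazy annotations, "npt.NDArray" stays a string at runtime,
    # so numpy.typing does not need to be importable here.
    import numpy as np

    return np.asarray(values)
```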

__all__ = []


class Role:
-    WORKER = 1
-    SERVER = 2
-    HETER_WORKER = 3
-    ALL = 4
-    COORDINATOR = 5
+    WORKER: ClassVar[Literal[1]] = 1
+    SERVER: ClassVar[Literal[2]] = 2
+    HETER_WORKER: ClassVar[Literal[3]] = 3
+    ALL: ClassVar[Literal[4]] = 4
+    COORDINATOR: ClassVar[Literal[5]] = 5
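
`ClassVar[Literal[N]]` tells a checker two things: the attribute belongs to the class rather than to instances, and its value is exactly `N`. A small sketch of what that buys (assuming mypy/pyright semantics):

```python
from typing import ClassVar, Literal


class Role:
    WORKER: ClassVar[Literal[1]] = 1
    SERVER: ClassVar[Literal[2]] = 2


current: Literal[1, 2] = Role.WORKER  # OK: Role.WORKER is exactly 1

# A type checker rejects both of the following:
# Role().WORKER = 2   # error: cannot assign to a ClassVar via an instance
# Role.WORKER = 2     # error: Literal[2] is incompatible with Literal[1]
```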


class Gloo:
@@ -563,7 +569,7 @@ class PaddleCloudRoleMaker(RoleMakerBase):
"""

-    def __init__(self, is_collective=False, **kwargs):
+    def __init__(self, is_collective: bool = False, **kwargs: Any) -> None:
super().__init__()
self._is_collective = is_collective
self._non_distributed = False
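
In the new signature, `**kwargs: Any` annotates the values of the forwarded keyword arguments; inside the function, `kwargs` itself is a `dict[str, Any]`. A minimal sketch with an invented `configure` helper:

```python
from __future__ import annotations

from typing import Any


def configure(is_collective: bool = False, **kwargs: Any) -> dict[str, Any]:
    # kwargs is dict[str, Any]: string keys, arbitrary option values.
    return {"is_collective": is_collective, **kwargs}


print(configure(worker_num=4, server_num=2))
```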
@@ -589,117 +595,121 @@ def __init__(self, is_collective=False, **kwargs):

self._gloo = Gloo() # gloo instance

-    def _barrier(self, comm_world):
+    def _barrier(self, comm_world: str) -> None:
self._gloo.barrier(comm_world)

-    def _all_gather(self, input, comm_world="worker"):
+    def _all_gather(
+        self, input: Any, comm_world: str = "worker"
+    ) -> list[float]:
return self._gloo.all_gather(input, comm_world)

-    def _all_reduce(self, input, mode="sum", comm_world="worker"):
+    def _all_reduce(
+        self, input: Any, mode: str = "sum", comm_world: str = "worker"
+    ) -> npt.NDArray[Any]:
return self._gloo.all_reduce(input, mode, comm_world)
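
`npt.NDArray[Any]` is the `numpy.typing` alias for `np.ndarray` with an unconstrained dtype, a reasonable return type for a reduction whose element type depends on the input. A hedged stand-in for what an all-reduce produces (`reduce_sum` is invented for illustration):

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np

if TYPE_CHECKING:
    import numpy.typing as npt


def reduce_sum(inputs: list[Any]) -> npt.NDArray[Any]:
    # Stand-in for a gloo all_reduce with mode="sum": every rank
    # would receive the same element-wise sum.
    return np.sum(np.asarray(inputs), axis=0)


print(reduce_sum([[1, 2], [3, 4]]))  # -> [4 6]
```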

-    def _heter_device(self):
+    def _heter_device(self) -> str:
"""
return the heter device that current heter worker is using
"""
if not self._role_is_generated:
self._generate_role()
return self._heter_trainer_device

-    def _heter_device_type(self):
+    def _heter_device_type(self) -> str:
"""
return the heter device type that current heter worker is using
"""
if not self._role_is_generated:
self._generate_role()
return self._heter_trainer_device_type

-    def _get_stage_id(self):
+    def _get_stage_id(self) -> int:
"""
return stage id of current heter worker
"""
if not self._role_is_generated:
self._generate_role()
return self._stage_id

-    def _get_stage_trainers(self):
+    def _get_stage_trainers(self) -> list[int]:
"""
return trainer num of all stages
"""
if not self._role_is_generated:
self._generate_role()
return self._stage_trainers

-    def _get_num_stage(self):
+    def _get_num_stage(self) -> int:
"""
return stage num
"""
if not self._role_is_generated:
self._generate_role()
return self._stage_num

-    def _is_worker(self):
+    def _is_worker(self) -> bool:
"""
whether current process is worker
"""
if not self._role_is_generated:
self._generate_role()
return self._role == Role.WORKER

-    def _is_server(self):
+    def _is_server(self) -> bool:
"""
whether current process is server
"""
if not self._role_is_generated:
self._generate_role()
return self._role == Role.SERVER

-    def _is_coordinator(self):
+    def _is_coordinator(self) -> bool:
if not self._role_is_generated:
self._generate_role()
return self._role == Role.COORDINATOR

-    def _is_first_worker(self):
+    def _is_first_worker(self) -> bool:
"""
whether current process is worker of rank 0
"""
if not self._role_is_generated:
self._generate_role()
return self._role == Role.WORKER and self._current_id == 0

-    def _worker_index(self):
+    def _worker_index(self) -> int:
"""
get index of current worker
"""
if not self._role_is_generated:
self._generate_role()
return self._current_id

-    def _server_index(self):
+    def _server_index(self) -> int:
"""
get index of current server
"""
if not self._role_is_generated:
self._generate_role()
return self._current_id

-    def _role_id(self):
+    def _role_id(self) -> int:
"""
get index of current node
"""
if not self._role_is_generated:
self._generate_role()
return self._current_id

-    def _worker_num(self):
+    def _worker_num(self) -> int:
"""
return the current number of worker
"""
if not self._role_is_generated:
self._generate_role()
return self._trainers_num

-    def _server_num(self):
+    def _server_num(self) -> int:
"""
return the current number of server
"""
@@ -711,54 +721,54 @@ def _server_num(self):
else 0
)

-    def _node_num(self):
+    def _node_num(self) -> int:
"""
return the training node number
"""
if not self._role_is_generated:
self._generate_role()
return self._nodes_num

-    def _get_node_num(self):
+    def _get_node_num(self) -> int:
"""
return the training node number
"""
if not self._role_is_generated:
self._generate_role()
return self._nodes_num

-    def _get_local_rank(self):
+    def _get_local_rank(self) -> str | None:
if not self._role_is_generated:
self._generate_role()
return self._local_rank

-    def _get_local_device_ids(self):
+    def _get_local_device_ids(self) -> str | None:
if not self._role_is_generated:
self._generate_role()
return self._local_device_ids

-    def _get_world_device_ids(self):
+    def _get_world_device_ids(self) -> str | None:
if not self._role_is_generated:
self._generate_role()
return self._world_device_ids
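
These three getters surface raw `os.getenv` results, so `str | None` is the honest return type, and the `from __future__ import annotations` import added at the top keeps the PEP 604 `X | Y` spelling working on pre-3.10 interpreters. A minimal sketch (the environment-variable name is illustrative, not taken from this diff):

```python
from __future__ import annotations  # "str | None" is fine before Python 3.10

import os


def get_local_rank() -> str | None:
    # None when the launcher did not export the variable.
    return os.getenv("PADDLE_RANK_IN_NODE")  # illustrative variable name


rank = get_local_rank()
if rank is not None:
    print(f"local rank: {int(rank)}")
```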

-    def _get_trainer_endpoints(self):
+    def _get_trainer_endpoints(self) -> list[str]:
"""
get endpoint of all trainers
"""
if not self._role_is_generated:
self._generate_role()
return self._worker_endpoints

-    def _get_trainer_endpoint(self):
+    def _get_trainer_endpoint(self) -> str:
if not self._role_is_generated:
self._generate_role()
assert (
self._role == Role.WORKER
), "get_trainer_endpoint should be called by trainer"
return self._cur_endpoint

-    def _get_heter_worker_endpoints(self):
+    def _get_heter_worker_endpoints(self) -> list[str]:
"""
Returns:
string: all heter_trainers'endpoints
@@ -770,10 +780,10 @@ def _get_heter_worker_endpoints(self):
), "Heter Worker Endpoints Not initialized"
return self._heter_trainer_endpoints

-    def _get_heter_worker_endpoint(self):
+    def _get_heter_worker_endpoint(self) -> str:
"""
Returns:
-            int: corresponding heter_trainer's endpoint
+            str: corresponding heter_trainer's endpoint
"""
if not self._role_is_generated:
self._generate_role()
Expand All @@ -782,15 +792,15 @@ def _get_heter_worker_endpoint(self):
), "_get_heter_worker_endpoint should be invoked by heter worker"
return self._cur_endpoint

-    def _get_pserver_endpoints(self):
+    def _get_pserver_endpoints(self) -> list[str]:
"""
get endpoint of all pservers
"""
if not self._role_is_generated:
self._generate_role()
return self._server_endpoints

-    def _get_coordinator_endpoints(self):
+    def _get_coordinator_endpoints(self) -> list[str]:
if not self._role_is_generated:
self._generate_role()
return self._coordinator_endpoints
@@ -807,7 +817,7 @@ def _get_previous_trainers(self):
), "_get_previous_trainers should be invoked by trainer or heter worker"
return self._previous_heter_trainer_endpoints

-    def _get_next_trainers(self):
+    def _get_next_trainers(self) -> list[str]:
"""
invoked by heter worker
"""
@@ -819,7 +829,7 @@ def _get_next_trainers(self):
), "_get_next_trainers should be invoked by trainer or heter worker"
return self._next_heter_trainer_endpoints

-    def _is_non_distributed(self):
+    def _is_non_distributed(self) -> bool:
"""
Return True if indispensable environment for fleetrun is not found
(use python-run to launch fleet-code directly)
@@ -828,23 +838,23 @@ def _is_non_distributed(self):
self._generate_role()
return self._non_distributed

-    def _heter_worker_num(self):
+    def _heter_worker_num(self) -> int:
"""
get heter worker nums
"""
if not self._role_is_generated:
self._generate_role()
return self._heter_trainers_num

-    def _is_heter_worker(self):
+    def _is_heter_worker(self) -> bool:
"""
whether current process is heter worker
"""
if not self._role_is_generated:
self._generate_role()
return self._role == Role.HETER_WORKER

-    def _ps_env(self):  # each role will execute it
+    def _ps_env(self) -> None:  # each role will execute it
# Environment variable PADDLE_PSERVERS_IP_PORT_LIST must be set
# format: string(ip:port,ip:port), eg. 127.0.0.1:6001,127.0.0.1:6002
self._server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST", None)
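
Per the comment, the variable holds a comma-separated `ip:port` list; a standalone sketch of how such a value decomposes (not the commit's code):

```python
import os

# As a launcher would export it (format from the comment above):
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:6001,127.0.0.1:6002"

raw = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST")
endpoints = raw.split(",") if raw else []
print(endpoints)  # ['127.0.0.1:6001', '127.0.0.1:6002']
```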
@@ -1084,7 +1094,7 @@ def _ps_env(self): # each role will execute it
self._current_id = current_id
self._nodes_num = len({x.split(':')[0] for x in self._worker_endpoints})
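
The `_nodes_num` expression above counts distinct hosts: strip the port from each endpoint, then deduplicate with a set comprehension. For example:

```python
worker_endpoints = [
    "10.0.0.1:6001",
    "10.0.0.1:6002",  # same host, second port
    "10.0.0.2:6001",
]

nodes_num = len({x.split(':')[0] for x in worker_endpoints})
print(nodes_num)  # -> 2 distinct machines
```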

-    def _collective_env(self):
+    def _collective_env(self) -> None:
self._current_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
self._training_role = os.getenv("PADDLE_TRAINING_ROLE", "TRAINER")
assert self._training_role == "TRAINER"
@@ -1107,7 +1117,7 @@ def _collective_env(self):
self._local_device_ids = os.getenv("PADDLE_LOCAL_DEVICE_IDS")
self._world_device_ids = os.getenv("PADDLE_WORLD_DEVICE_IDS")

-    def _gloo_init(self):
+    def _gloo_init(self) -> None:
# PADDLE_WITH_GLOO 1: trainer barrier, 2: all barrier
use_gloo = int(os.getenv("PADDLE_WITH_GLOO", "0"))
if use_gloo not in [1, 2]:
@@ -1186,7 +1196,7 @@ def _gloo_init(self):
if rendezvous_type == Gloo.RENDEZVOUS.HTTP:
http_server_d['running'] = False

-    def _generate_role(self):
+    def _generate_role(self) -> None:
"""
generate role for role maker
"""
@@ -1217,13 +1227,18 @@ class UserDefinedRoleMaker(PaddleCloudRoleMaker):
... server_endpoints=["127.0.0.1:36011", "127.0.0.1:36012"])
"""

-    def __init__(self, is_collective=False, init_gloo=False, **kwargs):
+    def __init__(
+        self,
+        is_collective: bool = False,
+        init_gloo: bool = False,
+        **kwargs: Any,
+    ) -> None:
super().__init__(
is_collective=is_collective, init_gloo=init_gloo, **kwargs
)
self._init_gloo = init_gloo
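
With the annotated constructor, a checker validates the collective flags while role-specific options still travel through `**kwargs`. A hedged usage sketch mirroring the class docstring's example (endpoint values are placeholders):

```python
import paddle.distributed.fleet as fleet

role = fleet.UserDefinedRoleMaker(
    is_collective=False,
    init_gloo=False,
    current_id=0,
    role=fleet.Role.SERVER,
    worker_num=2,
    server_endpoints=["127.0.0.1:36011", "127.0.0.1:36012"],
)
```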

-    def _user_defined_ps_env(self):
+    def _user_defined_ps_env(self) -> None:
self._server_endpoints = self._kwargs.get("server_endpoints")
self._worker_endpoints = self._kwargs.get("worker_endpoints", [])
self._trainers_num = self._kwargs.get("worker_num", 0)
@@ -1244,14 +1259,14 @@ def _user_defined_ps_env(self):
self._cur_endpoint = self._server_endpoints[self._current_id]
self._nodes_num = len({x.split(':')[0] for x in self._worker_endpoints})

-    def _user_defined_collective_env(self):
+    def _user_defined_collective_env(self) -> None:
self._worker_endpoints = self._kwargs.get("worker_endpoints")
self._current_id = self._kwargs.get("current_id")
self._trainers_num = len(self._worker_endpoints)
self._training_role = Role.WORKER
self._nodes_num = len({x.split(':')[0] for x in self._worker_endpoints})

-    def _generate_role(self):
+    def _generate_role(self) -> None:
"""
generate role for role maker
"""
