Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Typing][C-41, C-42][BUAA] Add type annotations for python/paddle/distributed/fleet/base/* #67439

Merged
merged 11 commits into from
Aug 19, 2024
107 changes: 60 additions & 47 deletions python/paddle/distributed/fleet/base/role_maker.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个文件的 Role 需要标注

class Role:
    WORKER: ClassVar[Literal[1]] = 1
    SERVER: ClassVar[Literal[2]]  = 2
    HETER_WORKER: ClassVar[Literal[3]]  = 3
    ALL: ClassVar[Literal[4]]  = 4
    COORDINATOR: ClassVar[Literal[5]]  = 5

另外,UserDefinedRoleMaker__init__ 没有标注

Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definition of Role Makers."""
from __future__ import annotations

import os
import re
import time
import warnings
from multiprocessing import Manager, Process
from typing import Any, ClassVar, Literal

import numpy as np
import numpy.typing as npt
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

放到 TYPE_CHECKING 里面

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改


import paddle
from paddle.base import core
Expand All @@ -32,11 +36,11 @@


class Role:
WORKER = 1
SERVER = 2
HETER_WORKER = 3
ALL = 4
COORDINATOR = 5
WORKER: ClassVar[Literal[1]] = 1
SERVER: ClassVar[Literal[2]] = 2
HETER_WORKER: ClassVar[Literal[3]] = 3
ALL: ClassVar[Literal[4]] = 4
COORDINATOR: ClassVar[Literal[5]] = 5


class Gloo:
Expand Down Expand Up @@ -563,7 +567,7 @@ class PaddleCloudRoleMaker(RoleMakerBase):

"""

def __init__(self, is_collective=False, **kwargs):
def __init__(self, is_collective: bool = False, **kwargs: Any) -> None:
super().__init__()
self._is_collective = is_collective
self._non_distributed = False
Expand All @@ -589,117 +593,121 @@ def __init__(self, is_collective=False, **kwargs):

self._gloo = Gloo() # gloo instance

def _barrier(self, comm_world):
def _barrier(self, comm_world: str) -> None:
self._gloo.barrier(comm_world)

def _all_gather(self, input, comm_world="worker"):
def _all_gather(
self, input: Any, comm_world: str = "worker"
) -> list[float]:
return self._gloo.all_gather(input, comm_world)

def _all_reduce(self, input, mode="sum", comm_world="worker"):
def _all_reduce(
self, input: Any, mode: str = "sum", comm_world: str = "worker"
) -> npt.NDArray[Any]:
return self._gloo.all_reduce(input, mode, comm_world)

def _heter_device(self):
def _heter_device(self) -> str:
"""
return the heter device that current heter worker is using
"""
if not self._role_is_generated:
self._generate_role()
return self._heter_trainer_device

def _heter_device_type(self):
def _heter_device_type(self) -> str:
"""
return the heter device type that current heter worker is using
"""
if not self._role_is_generated:
self._generate_role()
return self._heter_trainer_device_type

def _get_stage_id(self):
def _get_stage_id(self) -> int:
"""
return stage id of current heter worker
"""
if not self._role_is_generated:
self._generate_role()
return self._stage_id

def _get_stage_trainers(self):
def _get_stage_trainers(self) -> list[int]:
"""
return trainer num of all stages
"""
if not self._role_is_generated:
self._generate_role()
return self._stage_trainers

def _get_num_stage(self):
def _get_num_stage(self) -> int:
"""
return stage num
"""
if not self._role_is_generated:
self._generate_role()
return self._stage_num

def _is_worker(self):
def _is_worker(self) -> bool:
"""
whether current process is worker
"""
if not self._role_is_generated:
self._generate_role()
return self._role == Role.WORKER

def _is_server(self):
def _is_server(self) -> bool:
"""
whether current process is server
"""
if not self._role_is_generated:
self._generate_role()
return self._role == Role.SERVER

def _is_coordinator(self):
def _is_coordinator(self) -> bool:
if not self._role_is_generated:
self._generate_role()
return self._role == Role.COORDINATOR

def _is_first_worker(self):
def _is_first_worker(self) -> bool:
"""
whether current process is worker of rank 0
"""
if not self._role_is_generated:
self._generate_role()
return self._role == Role.WORKER and self._current_id == 0

def _worker_index(self):
def _worker_index(self) -> int:
"""
get index of current worker
"""
if not self._role_is_generated:
self._generate_role()
return self._current_id

def _server_index(self):
def _server_index(self) -> int:
"""
get index of current server
"""
if not self._role_is_generated:
self._generate_role()
return self._current_id

def _role_id(self):
def _role_id(self) -> int:
"""
get index of current node
"""
if not self._role_is_generated:
self._generate_role()
return self._current_id

def _worker_num(self):
def _worker_num(self) -> int:
"""
return the current number of worker
"""
if not self._role_is_generated:
self._generate_role()
return self._trainers_num

def _server_num(self):
def _server_num(self) -> int:
"""
return the current number of server
"""
Expand All @@ -711,54 +719,54 @@ def _server_num(self):
else 0
)

def _node_num(self):
def _node_num(self) -> int:
"""
return the training node number
"""
if not self._role_is_generated:
self._generate_role()
return self._nodes_num

def _get_node_num(self):
def _get_node_num(self) -> int:
"""
return the training node number
"""
if not self._role_is_generated:
self._generate_role()
return self._nodes_num

def _get_local_rank(self):
def _get_local_rank(self) -> str | None:
if not self._role_is_generated:
self._generate_role()
return self._local_rank

def _get_local_device_ids(self):
def _get_local_device_ids(self) -> str | None:
if not self._role_is_generated:
self._generate_role()
return self._local_device_ids

def _get_world_device_ids(self):
def _get_world_device_ids(self) -> str | None:
if not self._role_is_generated:
self._generate_role()
return self._world_device_ids

def _get_trainer_endpoints(self):
def _get_trainer_endpoints(self) -> list[str]:
"""
get endpoint of all trainers
"""
if not self._role_is_generated:
self._generate_role()
return self._worker_endpoints

def _get_trainer_endpoint(self):
def _get_trainer_endpoint(self) -> str:
if not self._role_is_generated:
self._generate_role()
assert (
self._role == Role.WORKER
), "get_trainer_endpoint should be called by trainer"
return self._cur_endpoint

def _get_heter_worker_endpoints(self):
def _get_heter_worker_endpoints(self) -> list[str]:
"""
Returns:
string: all heter_trainers'endpoints
Expand All @@ -770,10 +778,10 @@ def _get_heter_worker_endpoints(self):
), "Heter Worker Endpoints Not initialized"
return self._heter_trainer_endpoints

def _get_heter_worker_endpoint(self):
def _get_heter_worker_endpoint(self) -> str:
"""
Returns:
int: corresponding heter_trainer's endpoint
str: corresponding heter_trainer's endpoint
"""
if not self._role_is_generated:
self._generate_role()
Expand All @@ -782,15 +790,15 @@ def _get_heter_worker_endpoint(self):
), "_get_heter_worker_endpoint should be invoked by heter worker"
return self._cur_endpoint

def _get_pserver_endpoints(self):
def _get_pserver_endpoints(self) -> list[str]:
"""
get endpoint of all pservers
"""
if not self._role_is_generated:
self._generate_role()
return self._server_endpoints

def _get_coordinator_endpoints(self):
def _get_coordinator_endpoints(self) -> list[str]:
if not self._role_is_generated:
self._generate_role()
return self._coordinator_endpoints
Expand All @@ -807,7 +815,7 @@ def _get_previous_trainers(self):
), "_get_previous_trainers should be invoked by trainer or heter worker"
return self._previous_heter_trainer_endpoints

def _get_next_trainers(self):
def _get_next_trainers(self) -> list[str]:
"""
invoked by heter worker
"""
Expand All @@ -819,7 +827,7 @@ def _get_next_trainers(self):
), "_get_next_trainers should be invoked by trainer or heter worker"
return self._next_heter_trainer_endpoints

def _is_non_distributed(self):
def _is_non_distributed(self) -> bool:
"""
Return True if indispensable environment for fleetrun is not found
(use python-run to launch fleet-code directly)
Expand All @@ -828,23 +836,23 @@ def _is_non_distributed(self):
self._generate_role()
return self._non_distributed

def _heter_worker_num(self):
def _heter_worker_num(self) -> int:
"""
get heter worker nums
"""
if not self._role_is_generated:
self._generate_role()
return self._heter_trainers_num

def _is_heter_worker(self):
def _is_heter_worker(self) -> bool:
"""
whether current process is heter worker
"""
if not self._role_is_generated:
self._generate_role()
return self._role == Role.HETER_WORKER

def _ps_env(self): # each role will execute it
def _ps_env(self) -> None: # each role will execute it
# Environment variable PADDLE_PSERVERS_IP_PORT_LIST must be set
# format: string(ip:port,ip:port), eg. 127.0.0.1:6001,127.0.0.1:6002
self._server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST", None)
Expand Down Expand Up @@ -1084,7 +1092,7 @@ def _ps_env(self): # each role will execute it
self._current_id = current_id
self._nodes_num = len({x.split(':')[0] for x in self._worker_endpoints})

def _collective_env(self):
def _collective_env(self) -> None:
self._current_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
self._training_role = os.getenv("PADDLE_TRAINING_ROLE", "TRAINER")
assert self._training_role == "TRAINER"
Expand All @@ -1107,7 +1115,7 @@ def _collective_env(self):
self._local_device_ids = os.getenv("PADDLE_LOCAL_DEVICE_IDS")
self._world_device_ids = os.getenv("PADDLE_WORLD_DEVICE_IDS")

def _gloo_init(self):
def _gloo_init(self) -> None:
# PADDLE_WITH_GLOO 1: trainer barrier, 2: all barrier
use_gloo = int(os.getenv("PADDLE_WITH_GLOO", "0"))
if use_gloo not in [1, 2]:
Expand Down Expand Up @@ -1186,7 +1194,7 @@ def _gloo_init(self):
if rendezvous_type == Gloo.RENDEZVOUS.HTTP:
http_server_d['running'] = False

def _generate_role(self):
def _generate_role(self) -> None:
"""
generate role for role maker
"""
Expand Down Expand Up @@ -1217,13 +1225,18 @@ class UserDefinedRoleMaker(PaddleCloudRoleMaker):
... server_endpoints=["127.0.0.1:36011", "127.0.0.1:36012"])
"""

def __init__(self, is_collective=False, init_gloo=False, **kwargs):
def __init__(
self,
is_collective: bool = False,
init_gloo: bool = False,
**kwargs: Any,
) -> None:
super().__init__(
is_collective=is_collective, init_gloo=init_gloo, **kwargs
)
self._init_gloo = init_gloo

def _user_defined_ps_env(self):
def _user_defined_ps_env(self) -> None:
self._server_endpoints = self._kwargs.get("server_endpoints")
self._worker_endpoints = self._kwargs.get("worker_endpoints", [])
self._trainers_num = self._kwargs.get("worker_num", 0)
Expand All @@ -1244,14 +1257,14 @@ def _user_defined_ps_env(self):
self._cur_endpoint = self._server_endpoints[self._current_id]
self._nodes_num = len({x.split(':')[0] for x in self._worker_endpoints})

def _user_defined_collective_env(self):
def _user_defined_collective_env(self) -> None:
self._worker_endpoints = self._kwargs.get("worker_endpoints")
self._current_id = self._kwargs.get("current_id")
self._trainers_num = len(self._worker_endpoints)
self._training_role = Role.WORKER
self._nodes_num = len({x.split(':')[0] for x in self._worker_endpoints})

def _generate_role(self):
def _generate_role(self) -> None:
"""
generate role for role maker
"""
Expand Down
Loading