-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Typing][C-41, C-42][BUAA] Add type annotations for python/paddle/distributed/fleet/base/*
#67439
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
@megemini 顺师傅再看下 |
self._gloo.barrier(comm_world) | ||
|
||
def _all_gather(self, input, comm_world="worker"): | ||
def _all_gather(self, input: Any, comm_world="worker") -> task: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
comm_world
没有标注,另外,为什么返回 task
? list[float]
def _all_reduce(self, input, mode="sum", comm_world="worker"): | ||
def _all_reduce( | ||
self, input: Any, mode: str = "sum", comm_world: str = "worker" | ||
) -> task: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
) -> task: | |
) -> npt.NDArray[Any]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Whsjrczr 这里没修改
另外,上面是 list[float]
""" | ||
return the heter device type that current heter worker is using | ||
""" | ||
if not self._role_is_generated: | ||
self._generate_role() | ||
return self._heter_trainer_device_type | ||
|
||
def _get_stage_id(self): | ||
def _get_stage_id(self) -> int | str: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def _get_stage_id(self) -> int | str: | |
def _get_stage_id(self) -> int: |
""" | ||
return stage id of current heter worker | ||
""" | ||
if not self._role_is_generated: | ||
self._generate_role() | ||
return self._stage_id | ||
|
||
def _get_stage_trainers(self): | ||
def _get_stage_trainers(self) -> list: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def _get_stage_trainers(self) -> list: | |
def _get_stage_trainers(self) -> list[int]: |
""" | ||
whether current process is worker of rank 0 | ||
""" | ||
if not self._role_is_generated: | ||
self._generate_role() | ||
return self._role == Role.WORKER and self._current_id == 0 | ||
|
||
def _worker_index(self): | ||
def _worker_index(self) -> int | str: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def _worker_index(self) -> int | str: | |
def _worker_index(self) -> int: |
return self._dp_degree | ||
|
||
def get_data_parallel_group(self): | ||
def get_data_parallel_group(self) -> list[int]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def get_data_parallel_group(self) -> list[int]: | |
def get_data_parallel_group(self) -> Group: |
def get_check_parallel_group(self, sharding=False): | ||
def get_check_parallel_group( | ||
self, sharding: bool = False | ||
) -> list[int] | Group: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
) -> list[int] | Group: | |
) -> Group: |
def create_fuse_group( | ||
self, fused_strategy_list: list[int] | ||
) -> list[int] | int: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def create_fuse_group( | |
self, fused_strategy_list: list[int] | |
) -> list[int] | int: | |
def create_fuse_group( | |
self, fused_strategy_list: list[str] | |
) -> tuple[list[list[int]], list[Group]] | tuple[list[int], Group]: |
return self._topo.get_rank_from_stage( | ||
self.global_rank, pipe=stage_id, **kwargs | ||
) | ||
|
||
# fuse comm group message | ||
def get_dp_sep_parallel_group(self): | ||
def get_dp_sep_parallel_group(self) -> list[int]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def get_dp_sep_parallel_group(self) -> list[int]: | |
def get_dp_sep_parallel_group(self) -> Group: |
self._check_sep_exist() | ||
return self._dp_sep_comm_group | ||
|
||
def get_pp_mp_parallel_group(self): | ||
def get_pp_mp_parallel_group(self) -> list[int]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def get_pp_mp_parallel_group(self) -> list[int]: | |
def get_pp_mp_parallel_group(self) -> Group: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
谢谢,麻烦了,已更新
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
上次有地方没修改完,再检查一下
python/paddle/distributed/fleet/base/
python/paddle/distributed/fleet/base/*
Co-authored-by: megemini <[email protected]>
|
||
import numpy as np | ||
import numpy.typing as npt |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
放到 TYPE_CHECKING
里面
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM ~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR Category
User Experience
PR Types
Improvements
Description
类型标注:
Related links
@SigureMo @megemini