
[Typing][C-41, C-42][BUAA] Add type annotations for python/paddle/distributed/fleet/base/* #67439

Merged
11 commits merged into PaddlePaddle:develop on Aug 19, 2024

Conversation

Whsjrczr
Contributor

@Whsjrczr Whsjrczr commented Aug 14, 2024

PR Category

User Experience

PR Types

Improvements

Description

Type annotations for:

  • python/paddle/distributed/fleet/base/role_maker.py
  • python/paddle/distributed/fleet/base/topology.py

Related links

@SigureMo @megemini


paddle-bot bot commented Aug 14, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

@luotao1 luotao1 added contributor External developers HappyOpenSource Pro 进阶版快乐开源活动,更具挑战性的任务 labels Aug 14, 2024
risemeup1
risemeup1 previously approved these changes Aug 15, 2024
@luotao1
Contributor

luotao1 commented Aug 15, 2024

@megemini could you take another look?

self._gloo.barrier(comm_world)

def _all_gather(self, input, comm_world="worker"):
def _all_gather(self, input: Any, comm_world="worker") -> task:

`comm_world` is not annotated. Also, why return `task`? It is `list[float]`.

def _all_reduce(self, input, mode="sum", comm_world="worker"):
def _all_reduce(
self, input: Any, mode: str = "sum", comm_world: str = "worker"
) -> task:

Suggested change
) -> task:
) -> npt.NDArray[Any]:


@Whsjrczr this was not updated.

Also, the one above is `list[float]`.

"""
return the heter device type that current heter worker is using
"""
if not self._role_is_generated:
self._generate_role()
return self._heter_trainer_device_type

def _get_stage_id(self):
def _get_stage_id(self) -> int | str:

Suggested change
def _get_stage_id(self) -> int | str:
def _get_stage_id(self) -> int:

"""
return stage id of current heter worker
"""
if not self._role_is_generated:
self._generate_role()
return self._stage_id

def _get_stage_trainers(self):
def _get_stage_trainers(self) -> list:

Suggested change
def _get_stage_trainers(self) -> list:
def _get_stage_trainers(self) -> list[int]:

"""
whether current process is worker of rank 0
"""
if not self._role_is_generated:
self._generate_role()
return self._role == Role.WORKER and self._current_id == 0

def _worker_index(self):
def _worker_index(self) -> int | str:

Suggested change
def _worker_index(self) -> int | str:
def _worker_index(self) -> int:

return self._dp_degree

def get_data_parallel_group(self):
def get_data_parallel_group(self) -> list[int]:

Suggested change
def get_data_parallel_group(self) -> list[int]:
def get_data_parallel_group(self) -> Group:

def get_check_parallel_group(self, sharding=False):
def get_check_parallel_group(
self, sharding: bool = False
) -> list[int] | Group:

Suggested change
) -> list[int] | Group:
) -> Group:

Comment on lines 574 to 576
def create_fuse_group(
self, fused_strategy_list: list[int]
) -> list[int] | int:

Suggested change
def create_fuse_group(
self, fused_strategy_list: list[int]
) -> list[int] | int:
def create_fuse_group(
self, fused_strategy_list: list[str]
) -> tuple[list[list[int]], list[Group]] | tuple[list[int], Group]:
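The suggested union return can be exercised with a stand-alone sketch. `Group` below is a hypothetical placeholder, not paddle's real communication group, and the rank assignment is invented for illustration:

```python
from __future__ import annotations


class Group:
    """Hypothetical stand-in for paddle's communication Group."""

    def __init__(self, ranks: list[int]) -> None:
        self.ranks = ranks


def create_fuse_group_sketch(
    fused_strategy_list: list[str],
) -> tuple[list[list[int]], list[Group]] | tuple[list[int], Group]:
    # Hypothetical: a single fused strategy collapses to one
    # (ranks, Group) pair; several strategies yield parallel lists.
    # Both shapes match the suggested union above.
    rank_lists = [[0, 1] for _ in fused_strategy_list]
    groups = [Group(r) for r in rank_lists]
    if len(fused_strategy_list) == 1:
        return rank_lists[0], groups[0]
    return rank_lists, groups
```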

return self._topo.get_rank_from_stage(
self.global_rank, pipe=stage_id, **kwargs
)

# fuse comm group message
def get_dp_sep_parallel_group(self):
def get_dp_sep_parallel_group(self) -> list[int]:

Suggested change
def get_dp_sep_parallel_group(self) -> list[int]:
def get_dp_sep_parallel_group(self) -> Group:

self._check_sep_exist()
return self._dp_sep_comm_group

def get_pp_mp_parallel_group(self):
def get_pp_mp_parallel_group(self) -> list[int]:

Suggested change
def get_pp_mp_parallel_group(self) -> list[int]:
def get_pp_mp_parallel_group(self) -> Group:

Contributor Author

Thanks, sorry for the trouble; updated.

Contributor

@megemini megemini left a comment


Some spots from last round were not fully fixed; please check again.

@Whsjrczr Whsjrczr changed the title [Typing][C-41, C-42][BUAA] Add type annotations for 2 files in python/paddle/distributed/fleet/base/ [Typing][C-41, C-42][BUAA] Add type annotations for python/paddle/distributed/fleet/base/* Aug 16, 2024


import numpy as np
import numpy.typing as npt

Move these into the `TYPE_CHECKING` block.
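The pattern the reviewer is asking for can be sketched like this (the `as_rows` function is a hypothetical example, not code from the PR):

```python
from __future__ import annotations  # annotations become strings, so names
# used only in annotations need not exist at runtime

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Typing-only imports: numpy.typing is needed only by type checkers,
    # so moving it here avoids any runtime import cost.
    import numpy.typing as npt


def as_rows(x: Any) -> npt.NDArray[Any]:  # resolves only for the checker
    import numpy as np  # runtime import stays where it is actually used

    return np.atleast_2d(x)
```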

Contributor Author

Fixed.


@megemini megemini left a comment


LGTM ~


@SigureMo SigureMo left a comment


LGTMeow 🐾

@SigureMo SigureMo merged commit b07423c into PaddlePaddle:develop Aug 19, 2024
27 of 28 checks passed