From 1c6879e182711e60f919d4efc9d7f44b1c30740c Mon Sep 17 00:00:00 2001 From: megemini Date: Wed, 14 Aug 2024 15:49:16 +0800 Subject: [PATCH] [Add] typing --- .../paddle/distributed/auto_parallel/api.py | 171 ++++++++++++------ .../distributed/auto_parallel/constants.py | 164 +++++++++++++++++ .../distributed/auto_parallel/strategy.py | 34 ++++ python/paddle/distributed/rpc/rpc.py | 43 ++++- python/paddle/distributed/spawn.py | 24 ++- 5 files changed, 376 insertions(+), 60 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py index 7bdc2d9ec4f14..389fd19437cbe 100644 --- a/python/paddle/distributed/auto_parallel/api.py +++ b/python/paddle/distributed/auto_parallel/api.py @@ -15,7 +15,7 @@ import copy from types import MethodType -from typing import Callable +from typing import TYPE_CHECKING, Any, Literal, TypedDict import numpy as np @@ -63,6 +63,51 @@ from .placement_type import check_placements_equal, get_shard_spec from .random import determinate_rng, rng_state +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + from typing_extensions import TypeAlias + + from paddle import Tensor + from paddle._typing import ( + DTypeLike, + NestedNumbericSequence, + PlaceLike, + TensorLike, + ) + from paddle.amp import GradScaler + from paddle.base.framework import Program + from paddle.distributed import Placement + from paddle.io import DataLoader + from paddle.metric import Metric + from paddle.nn import Layer + + from .constants import ( + _AMPConfig, + _DPOptimizationConfig, + _FusedPassesConfig, + _GradientMergeConfig, + _MPOptimizationConfig, + _PipelineConfig, + _RecomputeConfig, + _ShardingConfig, + _SPOptimizationConfig, + ) + + _Mode: TypeAlias = Literal['train', 'eval', 'predict'] + + class _Config(TypedDict, total=False): + sharding: _ShardingConfig + fused_passes: _FusedPassesConfig + gradient_merge: _GradientMergeConfig + pipeline: _PipelineConfig + amp: _AMPConfig + recompute: _RecomputeConfig + mp_optimization: _MPOptimizationConfig + dp_optimization: _DPOptimizationConfig + sp_optimization: _SPOptimizationConfig + + # There are the auto parallel API of the unified version of dynamic and static mode. # Some APIs have the same name with the previous APIs implementation, which are # a temporary state, and the APIs here will eventually be used. @@ -132,8 +177,13 @@ def sharding_specs(self): def shard_tensor( - data, mesh, placements, dtype=None, place=None, stop_gradient=None -): + data: Tensor | TensorLike | NestedNumbericSequence, + mesh: ProcessMesh, + placements: list[Placement], + dtype: DTypeLike | None = None, + place: PlaceLike | None = None, + stop_gradient: bool | None = None, +) -> Tensor: """ Creates a distributed Tensor (i.e., Tensor with distributed attributes or DistTensor for short) from the input data, which can be a scalar, tuple, list, numpy.ndarray, or paddle.Tensor. @@ -582,7 +632,13 @@ def dtensor_from_local(local_tensor, mesh, placements): ) -def dtensor_from_fn(fn, mesh, placements, *args, **kwargs): +def dtensor_from_fn( + fn: Callable[..., Tensor], + mesh: ProcessMesh, + placements: list[Placement], + *args: Any, + **kwargs: Any, +) -> Tensor: """ Construct a Distributed Tensor from a function of arguments. 
@@ -616,7 +672,9 @@ def dtensor_from_fn(fn, mesh, placements, *args, **kwargs): # Part3: Data conversion related APIs -def reshard(dist_tensor, mesh, placements): +def reshard( + dist_tensor: Tensor, mesh: ProcessMesh, placements: list[Placement] +) -> Tensor: """ Reshard a distributed ``paddle.Tensor`` with given distributed attributes. @@ -716,12 +774,12 @@ def reshard(dist_tensor, mesh, placements): def shard_layer( - layer: nn.Layer, + layer: Layer, process_mesh: ProcessMesh, - shard_fn: Callable | None = None, - input_fn: Callable | None = None, - output_fn: Callable | None = None, -) -> nn.Layer: + shard_fn: Callable[[str, Layer, ProcessMesh], None] | None = None, + input_fn: Callable[[Any, ProcessMesh], list[Tensor]] | None = None, + output_fn: Callable[[Any, ProcessMesh], list[Tensor]] | None = None, +) -> Layer: """ Converts all layer's parameters to DistTensor parameters according to the `shard_fn` specified. It could also control the conversion of input @@ -1186,10 +1244,10 @@ class ShardingStage1(_ShardingStageBase): >>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py """ - def __init__(self, mesh=None): + def __init__(self, mesh: ProcessMesh | None = None) -> None: super().__init__(mesh) - def __call__(self, key, param, accumulator): + def __call__(self, key: str, param: Tensor, accumulator: Tensor) -> Tensor: if param.is_dist(): # Only deal with momentum in optimizer, beta should be replicated cross param's mesh if 'beta' not in key: @@ -1247,10 +1305,10 @@ class ShardingStage2(_ShardingStageBase): >>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py """ - def __init__(self, mesh=None): + def __init__(self, mesh: ProcessMesh | None = None) -> None: super().__init__(mesh) - def __call__(self, key, param, accumulator): + def __call__(self, key: str, param: Tensor, accumulator: Tensor) -> Tensor: if param.is_dist(): # Only deal with momentum in optimizer, beta should be replicated cross param's mesh if 'beta' not in key: @@ -1333,7 +1391,7 @@ class ShardingStage3(_ShardingStageBase): >>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py """ - def __init__(self, mesh=None): + def __init__(self, mesh: ProcessMesh | None = None) -> None: super().__init__(mesh) def _shard_parameter(self, param): @@ -1361,7 +1419,7 @@ def _unshard_parameter(self, param): new_param = dist.reshard(param, param.process_mesh, new_placements) param.get_tensor()._share_data_with(new_param.get_tensor()) - def __call__(self, key, param, accumulator): + def __call__(self, key: str, param: Tensor, accumulator: Tensor) -> Tensor: if param.is_dist(): # Only deal with momentum in optimizer, beta should be replicated cross param's mesh if 'beta' not in key: @@ -1387,7 +1445,10 @@ def __call__(self, key, param, accumulator): return accumulator -def shard_optimizer(optimizer, shard_fn=None): +def shard_optimizer( + optimizer: Optimizer, + shard_fn: Callable[[str, Layer, ProcessMesh], None] | None = None, +) -> _ShardOptimizer: """ Warp the global view optimizer to distributed view. @@ -1434,7 +1495,7 @@ def shard_fn(accumulator_name, param, accumulator) -> sharded_accumulator return _ShardOptimizer(optimizer, shard_fn) -def shard_scaler(scaler): +def shard_scaler(scaler: GradScaler) -> GradScaler: """ Warp the global view grad_scaler to distributed view. @@ -1605,6 +1666,10 @@ class FusePasses: A helper class for users to configure the fuse passes. 
""" + enable: bool + gemm_epilogue: bool + dropout_add: bool + def __init__(self, config_dict=None): self.enable = False self.gemm_epilogue = False @@ -1662,7 +1727,7 @@ class Strategy(auto_strategy.BaseConfig): >>> strategy.pipeline.micro_batch_size = 2 """ - def __init__(self, config=None): + def __init__(self, config: _Config | None = None) -> None: if config is not None: if isinstance(config, dict): self._config_dict = copy.deepcopy(config) @@ -1754,7 +1819,7 @@ def _from_legacy_strategy(self, legacy_strategy): self._sp_optimization = copy.deepcopy(legacy_strategy.sp_optimization) @property - def sharding(self): + def sharding(self) -> auto_strategy.ShardingConfig: """ ``sharding`` is used to configure the sharding states of the optimizer, containing following configs: @@ -1782,7 +1847,7 @@ def sharding(self): return self._sharding @property - def gradient_merge(self): + def gradient_merge(self) -> auto_strategy.GradientMergeConfig: """ ``gradient_merge`` is used to configure the gradient merge strategy in training, containing following configs: @@ -1808,7 +1873,7 @@ def gradient_merge(self): return self._gradient_merge @property - def fused_passes(self): + def fused_passes(self) -> FusePasses: """ ``fused_passes`` is used to configure the fusion of the computation in the model, containing following configs: @@ -1835,7 +1900,7 @@ def fused_passes(self): return self._fused_passes @property - def pipeline(self): + def pipeline(self) -> auto_strategy.PipelineConfig: """ ``pipeline`` is used to configure the pipeline parallelism, containing following configs: @@ -1862,7 +1927,7 @@ def pipeline(self): return self._pipeline @property - def amp(self): + def amp(self) -> auto_strategy.AMPConfig: """ ``amp`` is used to configure the amp, containing following configs: @@ -1937,13 +2002,13 @@ class DistModel: def __init__( self, - layer, - loader, - loss=None, - optimizer=None, - strategy=None, - metrics=None, - ): + layer: Layer, + loader: ShardDataloader | DataLoader, + loss: Layer | Callable[..., Any] | None = None, + optimizer: Optimizer | None = None, + strategy: Strategy | None = None, + metrics: list[Metric] | None = None, + ) -> None: self._feed_name_list = [] self._inner_strategy = self.__convert_strategy(strategy) self._structured_to_parameter_name = { @@ -1990,7 +2055,7 @@ def __init__( else: self.predict() - def train(self): + def train(self) -> None: """ Set the DistModel to "train" mode. In "train" mode, executing ``__call__`` method will update the @@ -2003,7 +2068,7 @@ def train(self): self._engine.to_mode("train") paddle.disable_static() - def eval(self): + def eval(self) -> None: """ Set the mode of DistModel to "eval". In "eval" mode, executing ``__call__`` will return the loss. @@ -2015,7 +2080,7 @@ def eval(self): self._engine.to_mode("eval") paddle.disable_static() - def predict(self): + def predict(self) -> None: """ Set the mode of DistModel to "predict". In "predict" mode, executing ``__call__`` returns a dict that contains the @@ -2044,7 +2109,7 @@ def __validate_mode(self, mode): raise ValueError("mode can only be 'train', 'eval' or 'predict'.") return mode - def dist_main_program(self, mode=None): + def dist_main_program(self, mode: _Mode | None = None) -> Program: """ Get the distributed main program of the specified ``mode``. 
Each 'mode' has its own distributed main program, ``dist_main_program`` @@ -2064,7 +2129,7 @@ def dist_main_program(self, mode=None): mode = self.__validate_mode(mode) return self._engine.get_dist_main_program(mode) - def dist_startup_program(self, mode=None): + def dist_startup_program(self, mode: _Mode | None = None) -> Program: """ Get the corresponding distributed startup program of ``mode``, which is used for initializing the parameters. @@ -2083,7 +2148,7 @@ def dist_startup_program(self, mode=None): mode = self.__validate_mode(mode) return self._engine.get_dist_startup_program(mode) - def serial_main_program(self, mode=None): + def serial_main_program(self, mode: _Mode | None = None) -> Program: """ Get the corresponding serial main program of ``mode``, containing the whole variables and operators of the given ``layer``. @@ -2102,7 +2167,7 @@ def serial_main_program(self, mode=None): mode = self.__validate_mode(mode) return self._engine.get_serial_main_program(mode) - def serial_startup_program(self, mode=None): + def serial_startup_program(self, mode: _Mode | None = None) -> Program: """ Get the corresponding serial startup program of ``mode``. @@ -2217,7 +2282,7 @@ def __convert_strategy(self, strategy): return inner_strategy @switch_to_static_graph - def __call__(self, *args): + def __call__(self, *args: Sequence[Any] | Tensor) -> Any: if self._mode is None: raise ValueError("Please call train()/eval()/predict() first.") if self._mode == "train": @@ -2270,7 +2335,9 @@ def _fetch_value(self, value, name=None): name = len(self._engine._pir_fetch_values) - 1 self._engine._pir_user_defined_fetch_names.append(name) - def state_dict(self, mode="all"): + def state_dict( + self, mode: Literal['opt', 'param', 'all'] = "all" + ) -> dict[str, Tensor]: """ Get the state dict of model and optimizer. @@ -2358,7 +2425,7 @@ def build_distributed_tensor(local_tensor, dist_attr): ) return global_state_dict - def set_state_dict(self, state_dict): + def set_state_dict(self, state_dict: dict[str, Tensor]) -> None: local_state_dict = {} dist_main_program = self.dist_main_program(mode=self._engine._mode) cur_state_dict = self.state_dict() @@ -2381,12 +2448,12 @@ def set_state_dict(self, state_dict): def to_static( - layer: paddle.nn.Layer, - loader=None, - loss=None, - optimizer=None, - strategy=None, -): + layer: Layer, + loader: ShardDataloader | DataLoader | None = None, + loss: Layer | Callable[..., Any] | None = None, + optimizer: Optimizer | None = None, + strategy: Strategy | None = None, +) -> DistModel: """ Converts the ``layer`` with distributed tensor (constructed from ``paddle.distributed.shard_tensor``) to a static graph. ``to_static`` @@ -2541,7 +2608,7 @@ def to_static( return dist_model -def unshard_dtensor(dist_tensor): +def unshard_dtensor(dist_tensor: Tensor) -> Tensor: """ Converts a distributed tensor to a dense tensor. 
``unshard_dtensor`` first make the ``dist_tensor`` be ``Replicate`` state on all processes and @@ -2887,10 +2954,10 @@ def __call__(self): def shard_dataloader( - dataloader: paddle.io.DataLoader, - meshes: ProcessMesh | list[ProcessMesh] | tuple[ProcessMesh], - input_keys: list[str] | tuple[str] | None = None, - shard_dims: list | tuple | str | int | None = None, + dataloader: DataLoader, + meshes: ProcessMesh | Sequence[ProcessMesh], + input_keys: Sequence[str] | None = None, + shard_dims: Sequence[str] | Sequence[int] | str | int | None = None, is_dataset_splitted: bool = False, ) -> ShardDataloader: """ diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py index 1be52f0e3594f..c433a3395d0ce 100644 --- a/python/paddle/distributed/auto_parallel/constants.py +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -12,7 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License +from __future__ import annotations + from collections import defaultdict +from typing import TYPE_CHECKING, TypedDict + +if TYPE_CHECKING: + from paddle import Tensor + from paddle._typing.dtype_like import _DTypeLiteral # _g_default_config[category][field] = default_value _g_default_config = defaultdict(dict) @@ -50,6 +57,20 @@ def set_field_default_config(category, field, default_value): set_field_default_config(BASE, "seed", None) set_field_default_config(BASE, "reinit", False) # Only for debug +if TYPE_CHECKING: + + class _BaseConfig(TypedDict, total=False): # noqa: PYI049 + auto_mode: str + gradient_scale: bool + gradient_scale_using_allreduce_avg: bool + use_cache: bool + return_numpy: bool + all_ranks: bool + split_data: bool + seed: int | None + reinit: bool + + ######################################### # recompute configuration ######################################### @@ -61,6 +82,23 @@ def set_field_default_config(category, field, default_value): set_field_default_config(RECOMPUTE, "refined_ops_patterns", []) # List[Dict] set_field_default_config(RECOMPUTE, "enable_tuning", False) +if TYPE_CHECKING: + + class _RefinedOpsPatterns(TypedDict, total=False): + main_ops: list[str] + num: int + pre_ops: list[str] + suf_ops: list[str] + + class _RecomputeConfig(TypedDict, total=False): # noqa: PYI049 + enable: bool + checkpoints: list[Tensor] + no_recompute_segments: list[int] + sr: int + refined_ops_patterns: list[_RefinedOpsPatterns] + enable_tuning: bool + + ######################################### # AMP configuration ######################################### @@ -82,6 +120,27 @@ def set_field_default_config(category, field, default_value): set_field_default_config(AMP, "use_master_grad", False) set_field_default_config(AMP, "use_promote", False) +if TYPE_CHECKING: + + class _AMPConfig(TypedDict, total=False): # noqa: PYI049 + enable: bool + dtype: _DTypeLiteral + level: str + init_loss_scaling: float + incr_every_n_steps: int + decr_every_n_nan_or_inf: int + incr_ratio: float + decr_ratio: float + use_dynamic_loss_scaling: bool + custom_white_list: list[str] + custom_black_list: list[str] + custom_black_varnames: list[str] + use_fp16_guard: bool + use_bf16_guard: bool + use_master_grad: bool + use_promote: bool + + ######################################### # sharding configuration ######################################### @@ -99,6 +158,23 @@ def set_field_default_config(category, field, default_value): set_field_default_config(SHARDING, "enable_tuning", False) 
set_field_default_config(SHARDING, "tuning_range", []) +if TYPE_CHECKING: + + class _ShardingConfig(TypedDict, total=False): # noqa: PYI049 + enable: bool + stage: int + degree: int + enable_overlap: bool + param_comm_stream_num: int + grad_comm_stream_num: int + param_bucket_size_numel: int + grad_bucket_size_numel: int + enable_hierarchical_comm: bool + partition_algor: str + enable_tuning: bool + tuning_range: list[int] | tuple[int, int] + + ######################################### # gradient merge configuration ######################################### @@ -107,6 +183,14 @@ def set_field_default_config(category, field, default_value): set_field_default_config(GRADIENT_MERGE, "k_steps", 1) set_field_default_config(GRADIENT_MERGE, "avg", True) +if TYPE_CHECKING: + + class _GradientMergeConfig(TypedDict, total=False): # noqa: PYI049 + enable: bool + k_steps: int + avg: bool + + ######################################### # pipeline configuration ######################################### @@ -124,6 +208,23 @@ def set_field_default_config(category, field, default_value): set_field_default_config(PIPELINE, "job_schedule_profiler_stop", -1) set_field_default_config(PIPELINE, "split_backward", False) +if TYPE_CHECKING: + + class _PipelineConfig(TypedDict, total=False): # noqa: PYI049 + enable: bool + schedule_mode: str + pp_degree: int + vpp_degree: int + vpp_seg_method: str + micro_batch_size: int + accumulate_steps: int + generation_batch_size: int + enable_send_recv_overlap: bool + job_schedule_profiler_start: int + job_schedule_profiler_stop: int + split_backward: bool + + ######################################### # quantization configuration ######################################### @@ -136,6 +237,18 @@ def set_field_default_config(category, field, default_value): set_field_default_config(QAT, "algo", None) set_field_default_config(QAT, "onnx_format", True) +if TYPE_CHECKING: + + class _QATConfig(TypedDict, total=False): # noqa: PYI049 + enable: bool + channel_wise_abs_max: bool + weight_bits: int + activation_bits: int + not_quant_pattern: list[str] + algo: str | None + onnx_format: bool + + ######################################### # auto tuning configuration ######################################### @@ -146,6 +259,16 @@ def set_field_default_config(category, field, default_value): set_field_default_config(TUNING, "run_after_tuning", True) set_field_default_config(TUNING, "debug", False) +if TYPE_CHECKING: + + class _TuningConfig(TypedDict, total=False): # noqa: PYI049 + enable: bool + profile_start_step: int + profile_end_step: int + run_after_tuning: bool + debug: bool + + ######################################### # dataset configuration ######################################### @@ -153,12 +276,25 @@ def set_field_default_config(category, field, default_value): set_field_default_config(DATASET, "enable", False) set_field_default_config(DATASET, "num_shards", 1) +if TYPE_CHECKING: + + class _DatasetConfig(TypedDict, total=False): # noqa: PYI049 + enable: bool + num_shards: int + + # ######################################### # # offload configuration # ######################################### FUSEDLINEARPROMOTION = "fused_linear_promotion" set_field_default_config(FUSEDLINEARPROMOTION, "enable", False) +if TYPE_CHECKING: + + class _FusedLinearPromotionConfig(TypedDict, total=False): # noqa: PYI049 + enable: bool + + ######################################### # fused passes configuration ######################################### @@ -166,6 +302,13 @@ def set_field_default_config(category, 
field, default_value): set_field_default_config(FUSED_PASSES, "enable", False) set_field_default_config(FUSED_PASSES, "fused_passes_list", []) +if TYPE_CHECKING: + + class _FusedPassesConfig(TypedDict, total=False): # noqa: PYI049 + enable: bool + fused_passes_list: list[str] + + ######################################### # data parallel configuration ######################################### @@ -178,6 +321,16 @@ def set_field_default_config(category, field, default_value): DP_OPTIMIZATION, "gradient_sync_after_accumulate", False ) +if TYPE_CHECKING: + + class _DPOptimizationConfig(TypedDict, total=False): # noqa: PYI049 + enable: bool + fuse_all_reduce_ops: bool + fuse_grad_size_in_MB: int + overlap_comm_cacl: bool + gradient_sync_after_accumulate: bool + + ######################################### # model parallel configuration ######################################### @@ -186,8 +339,19 @@ def set_field_default_config(category, field, default_value): MP_OPTIMIZATION, "allreduce_matmul_grad_overlapping", False ) +if TYPE_CHECKING: + + class _MPOptimizationConfig(TypedDict, total=False): # noqa: PYI049 + allreduce_matmul_grad_overlapping: bool + + ######################################### # sequence parallel configuration ######################################### SP_OPTIMIZATION = "sp_optimization" set_field_default_config(SP_OPTIMIZATION, "enable", True) + +if TYPE_CHECKING: + + class _SPOptimizationConfig(TypedDict, total=False): # noqa: PYI049 + enable: bool diff --git a/python/paddle/distributed/auto_parallel/strategy.py b/python/paddle/distributed/auto_parallel/strategy.py index 7f026827d3446..b256757869b9e 100644 --- a/python/paddle/distributed/auto_parallel/strategy.py +++ b/python/paddle/distributed/auto_parallel/strategy.py @@ -11,11 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License +from __future__ import annotations import copy +from typing import TYPE_CHECKING from . 
import constants +if TYPE_CHECKING: + from paddle._typing.dtype_like import _DTypeLiteral + class BaseConfig: def __init__(self, category, config_dict=None): @@ -89,24 +94,53 @@ def __init__(self, config_dict=None): class AMPConfig(BaseConfig): + enable: bool + dtype: _DTypeLiteral + level: str + init_loss_scaling: float + incr_every_n_steps: int + decr_every_n_nan_or_inf: int + incr_ratio: float + decr_ratio: float + use_dynamic_loss_scaling: bool + custom_white_list: list[str] + custom_black_list: list[str] + custom_black_varnames: list[str] + use_fp16_guard: bool + use_bf16_guard: bool + use_master_grad: bool + def __init__(self, config_dict=None): category = constants.AMP super().__init__(category, config_dict) class ShardingConfig(BaseConfig): + enable: bool + stage: int + degree: int + def __init__(self, config_dict=None): category = constants.SHARDING super().__init__(category, config_dict) class GradientMergeConfig(BaseConfig): + enable: bool + k_steps: int + avg: bool + def __init__(self, config_dict=None): category = constants.GRADIENT_MERGE super().__init__(category, config_dict) class PipelineConfig(BaseConfig): + enable: bool + schedule_mode: str + micro_batch_size: int + accumulate_steps: int + def __init__(self, config_dict=None): category = constants.PIPELINE super().__init__(category, config_dict) diff --git a/python/paddle/distributed/rpc/rpc.py b/python/paddle/distributed/rpc/rpc.py index 8018fd0741da5..4b6e80c8320df 100644 --- a/python/paddle/distributed/rpc/rpc.py +++ b/python/paddle/distributed/rpc/rpc.py @@ -12,17 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import datetime import os import pickle import time from collections import namedtuple +from typing import TYPE_CHECKING, Any, Protocol, TypeVar from paddle.base import core from paddle.distributed.launch.context import Node from paddle.distributed.rpc.internal import PythonFunc, _serialize from paddle.distributed.utils.launch_utils import logger +if TYPE_CHECKING: + from collections.abc import Callable + + _RetT = TypeVar("_RetT", covariant=True) + + class _FutureWrapper(Protocol[_RetT]): + def wait(self) -> _RetT: ... + + WorkerInfo = namedtuple("WorkerInfo", ["name", "rank", "ip", "port"]) _DEFAULT_RPC_TIMEOUT = -1 @@ -70,7 +82,12 @@ def _gen_endpoint(): return f"{ip}:{free_port}" -def init_rpc(name, rank=None, world_size=None, master_endpoint=None): +def init_rpc( + name: str, + rank: int | None = None, + world_size: int | None = None, + master_endpoint: str | None = None, +) -> None: """ init rpc. @@ -140,7 +157,13 @@ def init_rpc(name, rank=None, world_size=None, master_endpoint=None): logger.info(f"Trainer {rank}: Init RPC done!") -def rpc_sync(to, fn, args=None, kwargs=None, timeout=_DEFAULT_RPC_TIMEOUT): +def rpc_sync( + to: str, + fn: Callable[..., _RetT], + args: tuple[Any, ...] | None = None, + kwargs: dict[str, Any] | None = None, + timeout: int = _DEFAULT_RPC_TIMEOUT, +) -> _RetT: """ Make a blocking RPC call to run function ``fn`` on worker ``to``. Attention: Users must use this API in a secure network environment. @@ -180,7 +203,13 @@ def rpc_sync(to, fn, args=None, kwargs=None, timeout=_DEFAULT_RPC_TIMEOUT): return fut.wait() -def rpc_async(to, fn, args=None, kwargs=None, timeout=_DEFAULT_RPC_TIMEOUT): +def rpc_async( + to: str, + fn: Callable[..., _RetT], + args: tuple[Any, ...] 
| None = None, + kwargs: dict[str, Any] | None = None, + timeout: int = _DEFAULT_RPC_TIMEOUT, +) -> _FutureWrapper[_RetT]: """ Make a non-blocking RPC call to run function ``fn`` on worker ``to``. Attention: Users must use this API in a secure network environment. @@ -273,7 +302,7 @@ def _check_keys_ready(wait_keys): _barrier_store.add(barrier_prefix + str(global_rank), 1) -def shutdown(): +def shutdown() -> None: """ Perform a shutdown of the RPC agent, stop the worker and destroy the agent. This will block until all local and remote RPC processes reach this method @@ -304,7 +333,7 @@ def shutdown(): logger.info(f"Trainer {rank}: rpc shutdown!") -def get_worker_info(name): +def get_worker_info(name: str) -> WorkerInfo: """ Get worker information by worker name. @@ -334,7 +363,7 @@ class `WorkerInfo` with attribute `name`, `rank`, `ip` and `port`. return core.rpc_get_worker_info(name) -def get_all_worker_infos(): +def get_all_worker_infos() -> list[WorkerInfo]: """ Get all worker informations. @@ -361,7 +390,7 @@ def get_all_worker_infos(): return core.rpc_get_all_worker_infos() -def get_current_worker_info(): +def get_current_worker_info() -> WorkerInfo: """ Get current worker information. diff --git a/python/paddle/distributed/spawn.py b/python/paddle/distributed/spawn.py index f1b77cfaf91a6..9fd9bfe56cd39 100644 --- a/python/paddle/distributed/spawn.py +++ b/python/paddle/distributed/spawn.py @@ -12,11 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import multiprocessing import os import signal import sys import warnings +from typing import TYPE_CHECKING, Any, Literal, TypedDict # deprecated module import # (TODO: GhostScreaming) It will be removed later. @@ -40,6 +43,18 @@ ) from paddle.framework import set_flags +if TYPE_CHECKING: + from collections.abc import Callable, Iterable + + from typing_extensions import NotRequired, Unpack + + class _SpawnOptions(TypedDict): + start_method: NotRequired[Literal['spawn', 'fork', 'forkserver']] + gpus: NotRequired[str | None] + xpus: NotRequired[str | None] + ips: NotRequired[str] + + __all__ = [] @@ -445,7 +460,14 @@ def _throw_exception(self, error_index): raise Exception(msg) -def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options): +def spawn( + func: Callable[..., None], + args: Iterable[Any] = (), + nprocs: int = -1, + join: bool = True, + daemon: bool = False, + **options: Unpack[_SpawnOptions], +) -> MultiprocessContext: """ Start multiple processes with ``spawn`` method for parallel training.
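
A minimal usage sketch (not part of this patch, and assuming `Strategy` remains re-exported as `paddle.distributed.Strategy`): the nested dict below follows the new `_Config` / `_ShardingConfig` / `_PipelineConfig` TypedDicts introduced in `constants.py`, so a static type checker can validate the option names passed to `Strategy(config)`, while the attribute-style access shown in the existing `Strategy` docstring keeps working.

    import paddle.distributed as dist

    # Keys mirror the _Config / _ShardingConfig / _PipelineConfig TypedDicts
    # added above; the concrete values here are placeholders for illustration.
    strategy = dist.Strategy(
        {
            "sharding": {"enable": True, "stage": 1, "degree": 8},
            "pipeline": {"enable": True, "micro_batch_size": 2},
        }
    )

    # The same fields stay readable and writable as attributes, matching the
    # docstring example (e.g. strategy.pipeline.micro_batch_size = 2).
    micro = strategy.pipeline.micro_batch_size
    strategy.gradient_merge.enable = True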