From ec7db1d4cf4efa5b32ed579109a177012e4bba7a Mon Sep 17 00:00:00 2001
From: ooo <21421006@buaa.edu.cn>
Date: Wed, 14 Aug 2024 14:23:21 +0800
Subject: [PATCH 1/3] fixed

---
 .../paddle/distributed/auto_parallel/api.py | 297 ++++++++++++------
 python/paddle/distributed/rpc/rpc.py        |  48 ++-
 python/paddle/distributed/spawn.py          |  50 ++-
 3 files changed, 270 insertions(+), 125 deletions(-)

diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py
index 7bdc2d9ec4f14..7ef49180459e6 100644
--- a/python/paddle/distributed/auto_parallel/api.py
+++ b/python/paddle/distributed/auto_parallel/api.py
@@ -15,7 +15,7 @@

 import copy
 from types import MethodType
-from typing import Callable
+from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence

 import numpy as np

@@ -63,6 +63,30 @@
 from .placement_type import check_placements_equal, get_shard_spec
 from .random import determinate_rng, rng_state

+if TYPE_CHECKING:
+    from typing_extensions import TypeAlias
+
+    from paddle.amp import GradScaler
+    from paddle.autograd import PyLayerContext
+    from paddle.base.framework import Block
+    from paddle.metric import Metric
+
+    # Literal[...] only accepts literal values, so spell the option names out.
+    FusePasses_attr: TypeAlias = Literal[
+        'enable', 'gemm_epilogue', 'dropout_add'
+    ]
+    ConfigAttr: TypeAlias = Literal[
+        'sharding',
+        'gradient_merge',
+        'pipeline',
+        'amp',
+        'fused_passes',
+        'recompute',
+        'mp_optimization',
+        'dp_optimization',
+        'sp_optimization',
+    ]
+
 # There are the auto parallel API of the unified version of dynamic and static mode.
 # Some APIs have the same name with the previous APIs implementation, which are
 # a temporary state, and the APIs here will eventually be used.
@@ -91,7 +115,9 @@ class DistAttr(core.TensorDistAttr):
     """

-    def __init__(self, mesh, sharding_specs):
+    def __init__(
+        self, mesh: ProcessMesh, sharding_specs: list[str | None]
+    ) -> None:
         # 1. inputs checking
         if not isinstance(mesh, core.ProcessMesh):
             raise ValueError(
@@ -132,8 +158,13 @@ def sharding_specs(self):


 def shard_tensor(
-    data, mesh, placements, dtype=None, place=None, stop_gradient=None
-):
+    data: tuple | list | np.ndarray | paddle.Tensor,
+    mesh: ProcessMesh,
+    placements: list[dist.Placement],
+    dtype: str | np.dtype | None = None,
+    place: paddle.Place | str | None = None,
+    stop_gradient: bool | None = None,
+) -> paddle.Tensor:
     """
     Creates a distributed Tensor (i.e., Tensor with distributed attributes or DistTensor for short)
     from the input data, which can be a scalar, tuple, list, numpy.ndarray, or paddle.Tensor.
@@ -274,14 +305,14 @@ def _init_func(var, block):
 class _moe_global_mesh_tensor(PyLayer):
     @staticmethod
     def forward(
-        ctx,
-        local_tensor_list,
-        local_mesh_list,
-        idx,
-        global_dims,
-        mesh,
-        placements,
-    ):
+        ctx: PyLayerContext,
+        local_tensor_list: list[paddle.Tensor],
+        local_mesh_list: list[ProcessMesh],
+        idx: int,
+        global_dims: list[int],
+        mesh: ProcessMesh,
+        placements: list[dist.Placement],
+    ) -> paddle.Tensor:
         local_tensor = local_tensor_list[idx]
         if local_tensor.is_dist():
             local_mesh = local_tensor.process_mesh
@@ -312,7 +343,9 @@ def forward(
         return global_tensor

     @staticmethod
-    def backward(ctx, grad_tensor):
+    def backward(
+        ctx: PyLayerContext, grad_tensor: paddle.Tensor
+    ) -> paddle.Tensor | list[paddle.Tensor]:
         if ctx.local_mesh_list is None:
             return grad_tensor._local_value()
         else:
@@ -334,8 +367,10 @@ def backward(ctx, grad_tensor):


 def get_sub_meshes_from_global_mesh(
-    global_mesh, global_placements, local_mesh_dim
-):
+    global_mesh: ProcessMesh,
+    global_placements: list[dist.Placement],
+    local_mesh_dim: int,
+) -> tuple[list[ProcessMesh], list[dist.Placement]]:
     if (
         global_mesh is not None
         and local_mesh_dim is not None
@@ -371,8 +406,11 @@ def get_sub_meshes_from_global_mesh(


 def moe_global_mesh_tensor(
-    local_tensor_list, mesh, placements, local_mesh_dim=-1
-):
+    local_tensor_list: list[paddle.Tensor],
+    mesh: ProcessMesh,
+    placements: list[dist.Placement],
+    local_mesh_dim: int = -1,
+) -> paddle.Tensor:
     # assume the each rank has the same tensor shape for now, just use the local shape to calculate the global shape
     local_mesh_list, local_placements = get_sub_meshes_from_global_mesh(
         mesh, placements, local_mesh_dim
@@ -431,13 +469,13 @@ def moe_global_mesh_tensor(
 class _moe_sub_mesh_tensors(PyLayer):
     @staticmethod
     def forward(
-        ctx,
-        dist_tensor,
-        local_mesh_list=None,
-        local_placements=None,
-        global_mesh=None,
-        global_placements=None,
-    ):
+        ctx: PyLayerContext,
+        dist_tensor: paddle.Tensor,
+        local_mesh_list: list[ProcessMesh] | None = None,
+        local_placements: list[dist.Placement] | None = None,
+        global_mesh: ProcessMesh | None = None,
+        global_placements: list[dist.Placement] | None = None,
+    ) -> list[paddle.Tensor]:
         ctx.local_mesh_list = copy.deepcopy(local_mesh_list)
         ctx.local_placements = local_placements
         ctx.global_mesh = copy.deepcopy(global_mesh)
@@ -485,7 +523,9 @@ def forward(
         return local_tensor_list

     @staticmethod
-    def backward(ctx, *grad_tensor):
+    def backward(
+        ctx: PyLayerContext, *grad_tensor: paddle.Tensor
+    ) -> paddle.Tensor:
         place = paddle.framework._current_expected_place()
         place = paddle.framework._get_paddle_place(place)
         idx = ctx.global_mesh.process_ids.index(dist.get_rank())
@@ -501,8 +541,11 @@ def backward(ctx, *grad_tensor):


 def moe_sub_mesh_tensors(
-    dist_tensor, global_mesh=None, local_mesh_dim=None, global_placements=None
-):
+    dist_tensor: paddle.Tensor,
+    global_mesh: ProcessMesh | None = None,
+    local_mesh_dim: int | None = None,
+    global_placements: list[dist.Placement] | None = None,
+) -> list[paddle.Tensor]:
     """
     Get the local part of the ``dist_tensor`` on the specific ``local_mesh_dim``.
""" @@ -536,7 +579,11 @@ def moe_sub_mesh_tensors( ) -def dtensor_from_local(local_tensor, mesh, placements): +def dtensor_from_local( + local_tensor: list[paddle.Tensor], + mesh: ProcessMesh, + placements: list[dist.Placement], +) -> paddle.Tensor: # assume the each rank has the same tensor shape for now, just use the local shape to calculate the global shape global_dims = list(local_tensor.shape) for idx, placement in enumerate(placements): @@ -582,7 +629,13 @@ def dtensor_from_local(local_tensor, mesh, placements): ) -def dtensor_from_fn(fn, mesh, placements, *args, **kwargs): +def dtensor_from_fn( + fn: Callable, + mesh: ProcessMesh, + placements: list[dist.Placement], + *args: tuple, + **kwargs: dict, +) -> paddle.Tensor: """ Construct a Distributed Tensor from a function of arguments. @@ -616,7 +669,11 @@ def dtensor_from_fn(fn, mesh, placements, *args, **kwargs): # Part3: Data conversion related APIs -def reshard(dist_tensor, mesh, placements): +def reshard( + dist_tensor: paddle.Tensor, + mesh: ProcessMesh, + placements: list[dist.Placement], +) -> paddle.Tensor: """ Reshard a distributed ``paddle.Tensor`` with given distributed attributes. @@ -866,7 +923,9 @@ def replicate_layer_params_and_buffers( ) -def get_placement_with_sharding(param, sharding_mesh_axis): +def get_placement_with_sharding( + param: paddle.Tensor, sharding_mesh_axis: int +) -> list[dist.Placement]: shard_axis = -1 for placement in param.placements: if isinstance(placement, dist.Shard): @@ -891,7 +950,9 @@ def get_placement_with_sharding(param, sharding_mesh_axis): class _ShardOptimizer(Optimizer): - def __init__(self, optimizer, shard_fn=None): + def __init__( + self, optimizer: Optimizer, shard_fn: Callable | None = None + ) -> None: assert ( optimizer is not None ), "The argument `optimizer` cannot be empty." 
@@ -933,7 +994,7 @@ def __init__(self, optimizer, shard_fn=None): for param in self._inner_opt._parameter_list: self._shard_fn._shard_parameter(param) - def _set_and_check_sharding_prop_from_param(self): + def _set_and_check_sharding_prop_from_param(self) -> None: if (self._shard_fn._mesh is not None) and ( len(self._shard_fn._mesh._shape) == 1 ): @@ -974,7 +1035,7 @@ def _set_and_check_sharding_prop_from_param(self): self._sharding_degree is not None ), "The sharding degree is None in ShardOptimizer" - def _shard_accumulator(self, param): + def _shard_accumulator(self, param: paddle.Tensor) -> None: target_name = param.name if param.name in self._inner_opt._master_weights.keys(): target_name = self._inner_opt._master_weights[param.name].name @@ -1014,7 +1075,7 @@ def _shard_accumulator(self, param): target_name + "_" + key ) - def _reset_placements(self, param): + def _reset_placements(self, param: paddle.Tensor) -> None: if param.is_dist() and isinstance( self._shard_fn, (ShardingStage1, ShardingStage2) ): @@ -1027,14 +1088,20 @@ def _reset_placements(self, param): ) param.get_tensor()._share_data_with(out_param.get_tensor()) - def _create_accumulators(self, block, parameters): + def _create_accumulators( + self, block: Block, parameters: list[paddle.Tensor] + ) -> None: self._inner_opt._create_accumulators(block, parameters) if isinstance(parameters, dict): parameters = parameters.get('params') for p in parameters: self._shard_accumulator(p) - def _finish_update(self, block, parameters_and_grads): + def _finish_update( + self, + block: Block, + parameters_and_grads: list[tuple[paddle.Tensor, paddle.Tensor]], + ) -> None: self._inner_opt._finish_update(block, parameters_and_grads) if isinstance(parameters_and_grads, list): for p, _ in parameters_and_grads: @@ -1044,7 +1111,7 @@ def _finish_update(self, block, parameters_and_grads): for p, _ in parameters_and_grads['params']: self._reset_placements(p) - def state_dict(self): + def state_dict(self) -> dict[str, dict[str, paddle.Tensor]]: """ Create and shard the optimizer states e.g., accumulators and master_weights before load_state_dict. If training has already started or the optimizer states are already created and sharded, do nothing. 
@@ -1121,10 +1188,12 @@ def state_dict(self):

         return self._inner_opt.state_dict()

-    def _append_optimize_op(self, block, param_and_grad):
+    def _append_optimize_op(
+        self, block: Block, param_and_grad: tuple[paddle.Tensor, paddle.Tensor]
+    ) -> paddle.Tensor:
         return self._inner_opt._append_optimize_op(block, param_and_grad)

-    def __getattr__(self, item):
+    def __getattr__(self, item: str) -> Any:
         if "_inner_opt" in self.__dict__:
             if item == "_inner_opt":
                 return self.__dict__[item]
@@ -1132,7 +1201,7 @@ def __getattr__(self, item):
         else:
             raise AttributeError

-    def __setattr__(self, item, value):
+    def __setattr__(self, item: str, value: Any) -> None:
         if item == '_inner_opt':
             msg = f'{type(self).__name__}._inner_opt is READ ONLY'
             raise AttributeError(msg)
@@ -1140,11 +1209,11 @@ class _ShardingStageBase:
-    def __init__(self, mesh):
+    def __init__(self, mesh: ProcessMesh | None) -> None:
         self._mesh = mesh
         self._sharding_mesh_axis = None

-    def _set_sharding_mesh_axis(self, sharding_mesh_axis):
+    def _set_sharding_mesh_axis(self, sharding_mesh_axis: int) -> None:
         self._sharding_mesh_axis = sharding_mesh_axis

@@ -1186,10 +1255,12 @@ class ShardingStage1(_ShardingStageBase):
            >>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py
     """

-    def __init__(self, mesh=None):
+    def __init__(self, mesh: ProcessMesh | None = None) -> None:
         super().__init__(mesh)

-    def __call__(self, key, param, accumulator):
+    def __call__(
+        self, key: str, param: paddle.Tensor, accumulator: paddle.Tensor
+    ) -> paddle.Tensor:
         if param.is_dist():
             # Only deal with momentum in optimizer, beta should be replicated cross param's mesh
             if 'beta' not in key:
@@ -1247,10 +1318,12 @@ class ShardingStage2(_ShardingStageBase):
            >>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py
     """

-    def __init__(self, mesh=None):
+    def __init__(self, mesh: ProcessMesh | None = None) -> None:
         super().__init__(mesh)

-    def __call__(self, key, param, accumulator):
+    def __call__(
+        self, key: str, param: paddle.Tensor, accumulator: paddle.Tensor
+    ) -> paddle.Tensor:
         if param.is_dist():
             # Only deal with momentum in optimizer, beta should be replicated cross param's mesh
             if 'beta' not in key:
@@ -1270,7 +1343,7 @@ def __call__(self, key, param, accumulator):
         return accumulator

     @staticmethod
-    def _grad_hook(grad):
+    def _grad_hook(grad: paddle.Tensor) -> paddle.Tensor:
         # do reshard only if the grad is dist tensor and in partial status
         if grad.is_dist():
             partial_mesh_axis = None
@@ -1285,7 +1358,7 @@ def _grad_hook(grad):

         return grad

-    def _register_hook_for_param_grad(self, param):
+    def _register_hook_for_param_grad(self, param: paddle.Tensor) -> None:
         if param.is_dense() and self._mesh is not None:
             placements = []
             for _ in range(len(self._mesh.shape)):
@@ -1333,10 +1406,10 @@ class ShardingStage3(_ShardingStageBase):
            >>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py
     """

-    def __init__(self, mesh=None):
+    def __init__(self, mesh: ProcessMesh | None = None) -> None:
         super().__init__(mesh)

-    def _shard_parameter(self, param):
+    def _shard_parameter(self, param: paddle.Tensor) -> None:
         if param.is_dense() and self._mesh is not None:
             placements = []
             for _ in range(len(self._mesh.shape)):
@@ -1352,7 +1425,7 @@ def _shard_parameter(self, param):
         # change the holder of param to new shard_param
         param.get_tensor()._share_data_with(shard_param.get_tensor())

-    def _unshard_parameter(self, param):
+    def _unshard_parameter(self, param: paddle.Tensor) -> None:
         if param.is_dist():
             new_placements = param.placements
             if isinstance(new_placements[self._sharding_mesh_axis], dist.Shard):
@@ -1361,7 +1434,9 @@ def _unshard_parameter(self, param):
             new_param = dist.reshard(param, param.process_mesh, new_placements)
             param.get_tensor()._share_data_with(new_param.get_tensor())

-    def __call__(self, key, param, accumulator):
+    def __call__(
+        self, key: str, param: paddle.Tensor, accumulator: paddle.Tensor
+    ) -> paddle.Tensor:
         if param.is_dist():
             # Only deal with momentum in optimizer, beta should be replicated cross param's mesh
             if 'beta' not in key:
@@ -1387,7 +1462,9 @@ def __call__(self, key, param, accumulator):
         return accumulator


-def shard_optimizer(optimizer, shard_fn=None):
+def shard_optimizer(
+    optimizer: Optimizer, shard_fn: Callable | None = None
+) -> Optimizer:
     """
     Warp the global view optimizer to distributed view.

@@ -1434,7 +1511,7 @@ def shard_fn(accumulator_name, param, accumulator) -> sharded_accumulator

     return _ShardOptimizer(optimizer, shard_fn)


-def shard_scaler(scaler):
+def shard_scaler(scaler: GradScaler) -> GradScaler:
     """
     Warp the global view grad_scaler to distributed view.

@@ -1479,7 +1556,7 @@

     """

-    def unscale_method(self, optimizer):
+    def unscale_method(self, optimizer: Optimizer) -> None:
         if not self._enable:
             return

@@ -1605,10 +1682,12 @@ class FusePasses:
     A helper class for users to configure the fuse passes.
     """

-    def __init__(self, config_dict=None):
+    def __init__(
+        self, config_dict: dict[FusePasses_attr, Any] | None = None
+    ) -> None:
         self.enable = False
         self.gemm_epilogue = False
         self.dropout_add = False
         if config_dict is not None:
             for key, value in config_dict.items():
                 if hasattr(self, key):
@@ -1662,7 +1741,7 @@ class Strategy(auto_strategy.BaseConfig):
            >>> strategy.pipeline.micro_batch_size = 2
     """

-    def __init__(self, config=None):
+    def __init__(self, config: dict[ConfigAttr, Any] | None = None) -> None:
         if config is not None:
             if isinstance(config, dict):
                 self._config_dict = copy.deepcopy(config)
@@ -1719,7 +1798,9 @@ def __init__(self, config=None):
         )
         self._sp_optimization = auto_strategy.SPOptimizationConfig(config_dict)

-    def _from_legacy_strategy(self, legacy_strategy):
+    def _from_legacy_strategy(
+        self, legacy_strategy: auto_strategy.Strategy
+    ) -> None:
         """
         NOTE(lizhiyu): This is a template function to get `dist.Strategy` from
         `fleet.auto.Strategy`.
""" @@ -1754,7 +1835,7 @@ def _from_legacy_strategy(self, legacy_strategy): self._sp_optimization = copy.deepcopy(legacy_strategy.sp_optimization) @property - def sharding(self): + def sharding(self) -> auto_strategy.ShardingConfig: """ ``sharding`` is used to configure the sharding states of the optimizer, containing following configs: @@ -1782,7 +1863,7 @@ def sharding(self): return self._sharding @property - def gradient_merge(self): + def gradient_merge(self) -> auto_strategy.GradientMergeConfig: """ ``gradient_merge`` is used to configure the gradient merge strategy in training, containing following configs: @@ -1808,7 +1889,7 @@ def gradient_merge(self): return self._gradient_merge @property - def fused_passes(self): + def fused_passes(self) -> FusePasses: """ ``fused_passes`` is used to configure the fusion of the computation in the model, containing following configs: @@ -1835,7 +1916,7 @@ def fused_passes(self): return self._fused_passes @property - def pipeline(self): + def pipeline(self) -> auto_strategy.PipelineConfig: """ ``pipeline`` is used to configure the pipeline parallelism, containing following configs: @@ -1862,7 +1943,7 @@ def pipeline(self): return self._pipeline @property - def amp(self): + def amp(self) -> auto_strategy.AMPConfig: """ ``amp`` is used to configure the amp, containing following configs: @@ -1937,12 +2018,12 @@ class DistModel: def __init__( self, - layer, - loader, - loss=None, - optimizer=None, - strategy=None, - metrics=None, + layer: paddle.nn.Layer, + loader: ShardDataloader | paddle.io.DataLoader, + loss: Loss | Callable | None = None, + optimizer: paddle.optimizer.Optimizer | None = None, + strategy: Strategy | None = None, + metrics: dict[str, Metric] | None = None, ): self._feed_name_list = [] self._inner_strategy = self.__convert_strategy(strategy) @@ -1990,7 +2071,7 @@ def __init__( else: self.predict() - def train(self): + def train(self) -> None: """ Set the DistModel to "train" mode. In "train" mode, executing ``__call__`` method will update the @@ -2003,7 +2084,7 @@ def train(self): self._engine.to_mode("train") paddle.disable_static() - def eval(self): + def eval(self) -> None: """ Set the mode of DistModel to "eval". In "eval" mode, executing ``__call__`` will return the loss. @@ -2015,7 +2096,7 @@ def eval(self): self._engine.to_mode("eval") paddle.disable_static() - def predict(self): + def predict(self) -> None: """ Set the mode of DistModel to "predict". In "predict" mode, executing ``__call__`` returns a dict that contains the @@ -2033,7 +2114,7 @@ def predict(self): self._engine.to_mode("predict") paddle.disable_static() - def __validate_mode(self, mode): + def __validate_mode(self, mode: str | None) -> str: if mode is None and self._mode is None: raise ValueError( "Please set the mode or call train()/eval()/predict() first." @@ -2044,7 +2125,9 @@ def __validate_mode(self, mode): raise ValueError("mode can only be 'train', 'eval' or 'predict'.") return mode - def dist_main_program(self, mode=None): + def dist_main_program( + self, mode: str | None = None + ) -> paddle.static.Program: """ Get the distributed main program of the specified ``mode``. 
        Each 'mode' has its own distributed main program, ``dist_main_program``
@@ -2064,7 +2147,9 @@
         mode = self.__validate_mode(mode)
         return self._engine.get_dist_main_program(mode)

-    def dist_startup_program(self, mode=None):
+    def dist_startup_program(
+        self, mode: str | None = None
+    ) -> paddle.static.Program:
         """
         Get the corresponding distributed startup program of ``mode``,
         which is used for initializing the parameters.
@@ -2083,7 +2168,9 @@
         mode = self.__validate_mode(mode)
         return self._engine.get_dist_startup_program(mode)

-    def serial_main_program(self, mode=None):
+    def serial_main_program(
+        self, mode: str | None = None
+    ) -> paddle.static.Program:
         """
         Get the corresponding serial main program of ``mode``, containing
         the whole variables and operators of the given ``layer``.
@@ -2102,7 +2189,9 @@
         mode = self.__validate_mode(mode)
         return self._engine.get_serial_main_program(mode)

-    def serial_startup_program(self, mode=None):
+    def serial_startup_program(
+        self, mode: str | None = None
+    ) -> paddle.static.Program:
         """
         Get the corresponding serial startup program of ``mode``.

@@ -2120,7 +2209,9 @@
         mode = self.__validate_mode(mode)
         return self._engine.get_serial_startup_program(mode)

-    def _make_feeds(self, data_list):
+    def _make_feeds(
+        self, data_list: list[paddle.Tensor | core.LoDTensor]
+    ) -> dict[str, paddle.Tensor]:
         # TODO (2024-Q2): formula make feed
         if self._in_pir_mode:
             self._feed_name_list[self._mode] = ['input0', 'label0']
@@ -2171,7 +2262,9 @@ def _to_lodtensor(tensor: paddle.Tensor):
             feed_name_list_with_data.append(feed_name)
         return dict(zip(feed_name_list_with_data, feed_list))

-    def __convert_strategy(self, strategy):
+    def __convert_strategy(
+        self, strategy: Strategy | None
+    ) -> auto_strategy.Strategy | None:
         import copy

         if strategy is None:
@@ -2217,7 +2310,9 @@ def __convert_strategy(self, strategy):
         return inner_strategy

     @switch_to_static_graph
-    def __call__(self, *args):
+    def __call__(
+        self, *args: paddle.Tensor | core.LoDTensor
+    ) -> paddle.Tensor | None:
         if self._mode is None:
             raise ValueError("Please call train()/eval()/predict() first.")
         if self._mode == "train":
@@ -2255,7 +2350,7 @@ def __call__(self, *args):
         else:
             return None

-    def _fetch_value(self, value, name=None):
+    def _fetch_value(self, value: pir.Value, name: str | None = None) -> None:
         """
         Get the value of the variable with the given name.

@@ -2270,7 +2365,7 @@ def _fetch_value(self, value, name=None):
             name = len(self._engine._pir_fetch_values) - 1
         self._engine._pir_user_defined_fetch_names.append(name)

-    def state_dict(self, mode="all"):
+    def state_dict(self, mode: str = "all") -> dict[str, paddle.Tensor]:
         """
         Get the state dict of model and optimizer.

@@ -2307,7 +2402,9 @@ def state_dict(self, mode="all"):
         )
         return dist_state_dict

-    def _build_distributed_state_dict(self, local_state_dict):
+    def _build_distributed_state_dict(
+        self, local_state_dict: dict[str, paddle.Tensor]
+    ) -> dict[str, paddle.Tensor]:
         """
         Args:
            local_state_dict(Dict[str, libpaddle.Tensor]): The state dict from program.
@@ -2358,7 +2455,7 @@ def build_distributed_tensor(local_tensor, dist_attr):
             )
         return global_state_dict

-    def set_state_dict(self, state_dict):
+    def set_state_dict(self, state_dict: dict[str, paddle.Tensor]) -> None:
         local_state_dict = {}
         dist_main_program = self.dist_main_program(mode=self._engine._mode)
         cur_state_dict = self.state_dict()
@@ -2382,11 +2479,11 @@ def set_state_dict(self, state_dict):

 def to_static(
     layer: paddle.nn.Layer,
-    loader=None,
-    loss=None,
-    optimizer=None,
-    strategy=None,
-):
+    loader: ShardDataloader | paddle.io.DataLoader | None = None,
+    loss: paddle.nn.Layer | Callable | None = None,
+    optimizer: paddle.optimizer.Optimizer | _ShardOptimizer | None = None,
+    strategy: Strategy | None = None,
+) -> DistModel:
     """
     Converts the ``layer`` with distributed tensor (constructed from
     ``paddle.distributed.shard_tensor``) to a static graph. ``to_static``
@@ -2541,7 +2638,7 @@ def to_static(
     return dist_model


-def unshard_dtensor(dist_tensor):
+def unshard_dtensor(dist_tensor: paddle.Tensor) -> paddle.Tensor:
     """
     Converts a distributed tensor to a dense tensor. ``unshard_dtensor``
     first make the ``dist_tensor`` be ``Replicate`` state on all processes and
@@ -2634,11 +2731,11 @@ class ShardDataloader:
     def __init__(
         self,
         dataloader: paddle.io.DataLoader,
-        meshes: ProcessMesh | list[ProcessMesh] | tuple[ProcessMesh],
-        input_keys: list[str] | tuple[str] | None = None,
-        shard_dims: list | tuple | str | int | None = None,
+        meshes: ProcessMesh | Sequence[ProcessMesh],
+        input_keys: Sequence[str] | None = None,
+        shard_dims: Sequence[str | int] | str | int | None = None,
         is_dataset_splitted: bool = False,
-    ):
+    ) -> None:
         # do some check
         if is_dataset_splitted is True and shard_dims is None:
             raise ValueError(
@@ -2711,7 +2808,9 @@ def __init__(
         # Note(lizhiyu): In dygraph mode, the flag "pin_memory" is defualt "True", but it decrease the speed of `AutoParallel`
         self._dataloader.pin_memory = False

-    def _process_shard_dims(self, shard_dims):
+    def _process_shard_dims(
+        self, shard_dims: Sequence[str | int] | str | int | None
+    ) -> list:
         if isinstance(shard_dims, (int, str)) or shard_dims is None:
             res = []
             for i in range(len(self._meshes)):
@@ -2727,7 +2826,7 @@ def _process_shard_dims(self, shard_dims):
                 )
             return shard_dims

-    def _get_mesh_and_shard_dim(self, process_id):
+    def _get_mesh_and_shard_dim(self, process_id: int) -> tuple:
         for i in range(len(self._meshes)):
             if isinstance(self._meshes[i], (list, tuple)):
                 for j in range(len(self._meshes[i])):
@@ -2890,7 +2989,7 @@ def shard_dataloader(
     dataloader: paddle.io.DataLoader,
     meshes: ProcessMesh | list[ProcessMesh] | tuple[ProcessMesh],
     input_keys: list[str] | tuple[str] | None = None,
-    shard_dims: list | tuple | str | int | None = None,
+    shard_dims: Sequence[str | int] | str | int | None = None,
     is_dataset_splitted: bool = False,
 ) -> ShardDataloader:
     """
diff --git a/python/paddle/distributed/rpc/rpc.py b/python/paddle/distributed/rpc/rpc.py
index 8018fd0741da5..7406a2b86dffc 100644
--- a/python/paddle/distributed/rpc/rpc.py
+++ b/python/paddle/distributed/rpc/rpc.py
@@ -11,6 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
+
+from typing import Any, Callable

 import datetime
 import os
@@ -35,7 +38,7 @@
 _barrier_count = 0


-def _set_barrier_store(store):
+def _set_barrier_store(store: core.TCPStore) -> None:
     global _barrier_store
     _barrier_store = store

@@ -45,12 +48,12 @@ def _del_barrier_store():
     del _barrier_store


-def _set_self_info(name, rank, ip, port):
+def _set_self_info(name: str, rank: int, ip: str, port: int) -> None:
     self_info = pickle.dumps(WorkerInfo(name, rank, ip, port))
     _barrier_store.set(str(rank), self_info)


-def _exchange_all_service_infos(world_size):
+def _exchange_all_service_infos(world_size: int) -> list[WorkerInfo]:
     all_infos = []
     s = set()
     for rank in range(world_size):
@@ -63,14 +66,19 @@
     return all_infos


-def _gen_endpoint():
+def _gen_endpoint() -> str:
     node = Node()
     ip = node.get_host_ip()
     free_port = node.get_free_port()
     return f"{ip}:{free_port}"


-def init_rpc(name, rank=None, world_size=None, master_endpoint=None):
+def init_rpc(
+    name: str,
+    rank: int | None = None,
+    world_size: int | None = None,
+    master_endpoint: str | None = None,
+) -> None:
     """
     init rpc.

@@ -140,7 +148,13 @@ def init_rpc(name, rank=None, world_size=None, master_endpoint=None):
     logger.info(f"Trainer {rank}: Init RPC done!")


-def rpc_sync(to, fn, args=None, kwargs=None, timeout=_DEFAULT_RPC_TIMEOUT):
+def rpc_sync_framework(
+    to: str,
+    fn: Callable,
+    args: tuple | None = None,
+    kwargs: dict | None = None,
+    timeout: int = _DEFAULT_RPC_TIMEOUT,
+) -> Any:
     """
     Make a blocking RPC call to run function ``fn`` on worker ``to``.
     Attention: Users must use this API in a secure network environment.
@@ -180,7 +194,13 @@
     return fut.wait()


-def rpc_async(to, fn, args=None, kwargs=None, timeout=_DEFAULT_RPC_TIMEOUT):
+def rpc_async_framework(
+    to: str,
+    fn: Callable,
+    args: tuple | None = None,
+    kwargs: dict | None = None,
+    timeout: int = _DEFAULT_RPC_TIMEOUT,
+) -> core.FutureWrapper:
     """
     Make a non-blocking RPC call to run function ``fn`` on worker ``to``.
     Attention: Users must use this API in a secure network environment.
@@ -224,7 +244,9 @@
     return _invoke_rpc(to, fn, args, kwargs, timeout)


-def _invoke_rpc(to, fn, args, kwargs, timeout):
+def _invoke_rpc(
+    to: str, fn: Callable, args: tuple, kwargs: dict, timeout: int
+) -> core.FutureWrapper:
     args = args if args else ()
     kwargs = kwargs if kwargs else {}
     serial_obj = _serialize(PythonFunc(fn, args, kwargs))
@@ -234,7 +256,7 @@
     return future


-def _barrier_never_timeout(global_rank, global_world_size):
+def _barrier_never_timeout(global_rank: int, global_world_size: int) -> None:
     # max timeout
     timeout = datetime.timedelta(days=_BARRIER_TIMEOUT_MAX_DAYS)

@@ -273,7 +295,7 @@ def _check_keys_ready(wait_keys):
     _barrier_store.add(barrier_prefix + str(global_rank), 1)


-def shutdown():
+def shutdown() -> None:
     """
     Perform a shutdown of the RPC agent, stop the worker and destroy the agent.
     This will block until all local and remote RPC processes reach this method
@@ -304,7 +326,7 @@
     logger.info(f"Trainer {rank}: rpc shutdown!")


-def get_worker_info(name):
+def get_worker_info(name: str) -> WorkerInfo:
     """
     Get worker information by worker name.

@@ -334,7 +356,7 @@ class `WorkerInfo` with attribute `name`, `rank`, `ip` and `port`.
    return core.rpc_get_worker_info(name)


-def get_all_worker_infos():
+def get_all_worker_infos() -> list[WorkerInfo]:
     """
     Get all worker informations.

@@ -361,7 +383,7 @@ def get_all_worker_infos():
     return core.rpc_get_all_worker_infos()


-def get_current_worker_info():
+def get_current_worker_info() -> WorkerInfo:
     """
     Get current worker information.

diff --git a/python/paddle/distributed/spawn.py b/python/paddle/distributed/spawn.py
index f1b77cfaf91a6..0e1c74fb99646 100644
--- a/python/paddle/distributed/spawn.py
+++ b/python/paddle/distributed/spawn.py
@@ -11,6 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
+
+from typing import Any, Callable

 import multiprocessing
 import os
@@ -44,7 +47,7 @@


 class ParallelEnvArgs:
-    def __init__(self):
+    def __init__(self) -> None:
         # Paddle cluster nodes ips, such as 192.168.0.16,192.168.0.17..
         self.cluster_node_ips = None

@@ -68,7 +71,7 @@ def __init__(self):
         self.selected_devices = None


-def _options_valid_check(options):
+def _options_valid_check(options: dict[str, Any]) -> None:
     # `print_config` keeped as a debug options, not show to users
     supported_options = [
         'start_method',
@@ -99,7 +102,7 @@
     )


-def _get_default_nprocs():
+def _get_default_nprocs() -> int:
     device = get_device()
     if 'gpu' in device:
         return core.get_cuda_device_count()
@@ -115,7 +118,7 @@
     )


-def _get_default_backend():
+def _get_default_backend() -> str:
     device = get_device()
     if 'gpu' in device:
         return 'nccl'
@@ -131,7 +134,7 @@
     )


-def _get_node_ip(ips):
+def _get_node_ip(ips: str) -> str:
     node_ip = None
     node_ips = [x.strip() for x in ips.split(',')]
     if len(node_ips) == 1:
@@ -141,7 +144,9 @@ def _get_node_ip(ips):
     return node_ip


-def _get_subprocess_env_list(nprocs, options):
+def _get_subprocess_env_list(
+    nprocs: int, options: dict[str, Any]
+) -> list[dict[str, str]]:
     # NOTE (xiongkun03) Why put backend deduction here ?
     # Because _get_subprocess_env_list is used by many testcases.
     # So for compatibility, we put backend deduction here
@@ -330,14 +335,14 @@
     return processes_env_list


-def _remove_risky_env():
+def _remove_risky_env() -> None:
     # remove useless env vars
     # no copy, each process will hold env vars itself
     os.environ.pop("http_proxy", None)
     os.environ.pop("https_proxy", None)


-def _set_trainer_env(env_dict, backend):
+def _set_trainer_env(env_dict: dict[str, str], backend: str) -> None:
     # NOTE(chenweihang): [ Why need set FLAGS_selected_gpus or FLAGS_selected_xpus here? ]
     # When the child process starts, it will inherit the configuration of the
     # main process and set the FLAGS once, but the environment variable has
@@ -361,7 +366,14 @@
     os.environ[var_name] = env_dict[var_name]


-def _func_wrapper(func, args, error_queue, return_queue, env_dict, backend):
+def _func_wrapper(
+    func: Callable,
+    args: tuple,
+    error_queue: multiprocessing.SimpleQueue,
+    return_queue: multiprocessing.SimpleQueue,
+    env_dict: dict[str, str],
+    backend: str,
+) -> None:
     try:
         # config subprocess environment variables
         _remove_risky_env()
@@ -380,7 +392,12 @@


 class MultiprocessContext:
-    def __init__(self, processes, error_queues, return_queues):
+    def __init__(
+        self,
+        processes: list[multiprocessing.Process],
+        error_queues: list[multiprocessing.SimpleQueue],
+        return_queues: list[multiprocessing.SimpleQueue],
+    ) -> None:
         self.error_queues = error_queues
         # NOTE(chenweihang): The `spawn` method is mainly used
         # to wrap the outermost execution function of the program for
@@ -393,7 +410,7 @@ def __init__(self, processes, error_queues, return_queues):
             process.sentinel: index
             for index, process in enumerate(processes)
         }

-    def join(self, timeout=None):
+    def join(self, timeout: float | None = None) -> bool:
         if len(self.sentinels) == 0:
             return True

@@ -420,7 +437,7 @@

         self._throw_exception(error_index)

-    def _throw_exception(self, error_index):
+    def _throw_exception(self, error_index: int) -> None:
         if self.error_queues[error_index].empty():
             exitcode = self.processes[error_index].exitcode
             if exitcode < 0:
@@ -445,7 +462,14 @@
     raise Exception(msg)


-def spawn(func, args=(), nprocs=-1, join=True, daemon=False, **options):
+def spawn(
+    func: Callable,
+    args: list | tuple = (),
+    nprocs: int = -1,
+    join: bool = True,
+    daemon: bool = False,
+    **options: Any,
+) -> MultiprocessContext:
     """
     Start multiple processes with ``spawn`` method for parallel training.

From ee3eed041350573394f42cc2ef360bd82c96a074 Mon Sep 17 00:00:00 2001
From: ooo <21421006@buaa.edu.cn>
Date: Wed, 14 Aug 2024 15:07:12 +0800
Subject: [PATCH 2/3] fixed

---
 python/paddle/distributed/rpc/rpc.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/distributed/rpc/rpc.py b/python/paddle/distributed/rpc/rpc.py
index 7406a2b86dffc..7122b4e201a6c 100644
--- a/python/paddle/distributed/rpc/rpc.py
+++ b/python/paddle/distributed/rpc/rpc.py
@@ -194,7 +194,7 @@ def rpc_sync_framework(
     return fut.wait()


-def rpc_async_framework(
+def rpc_async(
     to: str,
     fn: Callable,
     args: tuple | None = None,

From 57515773860b30895732bc5c129c59ddf16b668c Mon Sep 17 00:00:00 2001
From: ooo <21421006@buaa.edu.cn>
Date: Wed, 14 Aug 2024 15:47:15 +0800
Subject: [PATCH 3/3] fixed

---
 python/paddle/distributed/rpc/rpc.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/distributed/rpc/rpc.py b/python/paddle/distributed/rpc/rpc.py
index 7122b4e201a6c..d27f03f47b202 100644
--- a/python/paddle/distributed/rpc/rpc.py
+++ b/python/paddle/distributed/rpc/rpc.py
@@ -148,7 +148,7 @@ def init_rpc(
     logger.info(f"Trainer {rank}: Init RPC done!")


-def rpc_sync_framework(
+def rpc_sync(
     to: str,
     fn: Callable,
     args: tuple | None = None,
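
For context, a minimal usage sketch (not part of the patches) of the auto-parallel APIs whose signatures PATCH 1/3 annotates. It assumes two visible GPUs, a launch via `python -m paddle.distributed.launch --gpus=0,1 demo.py`, and an illustrative Linear layer; `demo.py` and the optimizer choice are hypothetical:

    import paddle
    import paddle.distributed as dist

    # A 1-D process mesh over ranks 0 and 1; shard_tensor(data, mesh,
    # placements) returns a paddle.Tensor, matching the annotated types.
    mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
    layer = paddle.nn.Linear(8, 8)
    weight = dist.shard_tensor(layer.weight, mesh, [dist.Shard(0)])

    # shard_optimizer(optimizer, shard_fn) -> Optimizer; ShardingStage1
    # shards the optimizer accumulators across the mesh.
    opt = paddle.optimizer.AdamW(parameters=layer.parameters())
    opt = dist.shard_optimizer(opt, dist.ShardingStage1(mesh))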
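Likewise for the rpc module: PATCH 2/3 restore the public names `rpc_sync` and `rpc_async` that PATCH 1 had accidentally renamed. A minimal single-worker sketch, assuming local port 8001 is free and using a hypothetical `add` function:

    import paddle.distributed.rpc as rpc

    def add(a: int, b: int) -> int:
        return a + b

    rpc.init_rpc("worker0", rank=0, world_size=1,
                 master_endpoint="127.0.0.1:8001")
    # Blocking call: returns the result itself (hence the `-> Any` hint).
    assert rpc.rpc_sync("worker0", add, args=(2, 3)) == 5
    # Non-blocking call: returns a future; wait() yields the result.
    fut = rpc.rpc_async("worker0", add, args=(4, 5))
    assert fut.wait() == 9
    rpc.shutdown()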
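And for spawn.py, a sketch of the entry point with the default `args=()` that the annotations keep; the `train` function is hypothetical and `nprocs=2` assumes two devices are available:

    import paddle.distributed as dist

    def train(print_result: bool = False) -> None:
        # Each spawned worker process initializes its own parallel env.
        dist.init_parallel_env()
        if print_result:
            print(f"rank {dist.get_rank()} started")

    if __name__ == "__main__":
        # Spawns two worker processes and joins them.
        dist.spawn(train, args=(True,), nprocs=2)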