diff --git a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
index a6f5d4b9e6af9..c44b3faa07763 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/dygraph_sharding_optimizer.py
@@ -338,6 +338,11 @@ def reduce_gradients(self, parameter_list, hcg):
                         )
                         g_var.scale_(1.0 / sharding_nrank)
                         reduce_op = ReduceOp.SUM
+
+                    # In align mode, we scale the grad in advance, so we need a SUM here
+                    if paddle.distributed.in_auto_parallel_align_mode():
+                        reduce_op = ReduceOp.SUM
+
                     param_rank = self._param2rank[param.name]
 
                     need_check = strtobool(
diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
index 5f8f512d384bb..c3d47240cdb9a 100644
--- a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -45,6 +45,7 @@
 else:
     from .pp_utils import p2p_communication as p2p
 
+from paddle.distributed import fleet
 from paddle.distributed.fleet.utils.tensor_fusion_helper import (
     HOOK_ACTION,
     FusedCommBuffer,
@@ -66,6 +67,15 @@ def get_action(is_dp, shard_split_param=False):
     return HOOK_ACTION.REDUCE
 
 
+def _get_align_mode_scale():
+    hcg = fleet.get_hybrid_communicate_group()
+    data_parallel_world_size = hcg.get_data_parallel_world_size()
+    sharding_parallel_world_size = hcg.get_sharding_parallel_world_size()
+    return max(data_parallel_world_size, 1) * max(
+        sharding_parallel_world_size, 1
+    )
+
+
 # assume only the first stage and last stage need data, and data consumption is ordered
 # to be replaced by real micro dataset from reader
 class FakeMicroDataset:
@@ -997,6 +1007,9 @@ def _backward_step(
             )
             if self.is_pipeline_last_stage():
                 assert output_tensor_grad is None
+                # In align mode, we scale the grad directly after forward
+                if paddle.distributed.in_auto_parallel_align_mode():
+                    output_tensor = output_tensor / _get_align_mode_scale()
                 if self.scaler:
                     paddle.autograd.backward(self.scaler.scale(output_tensor))
                 else:
diff --git a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
index 80c69dd87a41a..8f5b07799e9b8 100644
--- a/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
+++ b/python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
@@ -687,6 +687,9 @@ def _comm_grads(self):
             )
 
         elif self._act == HOOK_ACTION.REDUCE_SCATTER:
+            # In align mode, we scale the grad in advance, so we need a SUM here
+            if paddle.distributed.in_auto_parallel_align_mode():
+                reduce_op = paddle.distributed.ReduceOp.SUM
             shard_size = self.grad_storage._numel() // self._comm_group.nranks
             begin = shard_size * self._comm_group.rank
             end = begin + shard_size
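
Taken together, the three hunks apply one rule: in auto-parallel align mode the gradients (or the last-stage output tensor that seeds backward) are divided by data_parallel_world_size * sharding_parallel_world_size up front, so every subsequent collective must use SUM instead of AVG, otherwise the scaling would be applied twice. Below is a minimal, framework-free sketch of that equivalence; the world size and gradient values are illustrative assumptions, not taken from the PR.

# Minimal sketch (plain Python, no Paddle dependency) of why the reduce op
# switches to SUM once gradients are pre-scaled in align mode. The concrete
# numbers are illustrative assumptions, not values from the PR.

world_size = 4                    # stands in for dp_world_size * sharding_world_size
grads = [1.0, 2.0, 3.0, 4.0]      # one gradient value per rank

# Default path: ranks keep their raw grads and the collective averages them.
avg_reduce = sum(grads) / world_size

# Align-mode path: each rank scales its grad by 1/world_size ahead of time
# (g_var.scale_(1.0 / sharding_nrank), output_tensor / _get_align_mode_scale()),
# so the collective must SUM, or the division would happen twice.
pre_scaled = [g / world_size for g in grads]
sum_reduce = sum(pre_scaled)

assert abs(avg_reduce - sum_reduce) < 1e-12   # both paths produce 2.5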