[Auto Parallel] add align_mode supporting #68354

Merged
merged 6 commits on Oct 14, 2024
@@ -338,6 +338,11 @@ def reduce_gradients(self, parameter_list, hcg):
)
g_var.scale_(1.0 / sharding_nrank)
reduce_op = ReduceOp.SUM

# In align mode, we scale the grad in advance, so we need a SUM here
if paddle.distributed.in_auto_parallel_align_mode():
Contributor comment: It is likewise suggested to add a NOTE comment here.

reduce_op = ReduceOp.SUM

param_rank = self._param2rank[param.name]

need_check = strtobool(
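For context on this hunk: in align mode the gradient is already scaled by 1.0 / sharding_nrank before the collective, so the collective must use SUM rather than AVG to avoid dividing twice. A minimal, framework-free sketch of that equivalence (the rank count and gradient values below are made up for illustration):

import numpy as np

nranks = 4  # stand-in for sharding_nrank
grads = [np.array([1.0, 2.0]), np.array([3.0, 4.0]),
         np.array([5.0, 6.0]), np.array([7.0, 8.0])]

# Baseline: reduce with AVG semantics (sum, then divide by the rank count).
avg_reduce = sum(grads) / nranks

# Align mode: each rank pre-scales its gradient, then the reduce uses SUM.
sum_reduce = sum(g * (1.0 / nranks) for g in grads)

assert np.allclose(avg_reduce, sum_reduce)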
13 changes: 13 additions & 0 deletions python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -45,6 +45,7 @@
else:
from .pp_utils import p2p_communication as p2p

from paddle.distributed import fleet
from paddle.distributed.fleet.utils.tensor_fusion_helper import (
HOOK_ACTION,
FusedCommBuffer,
@@ -66,6 +67,15 @@ def get_action(is_dp, shard_split_param=False):
return HOOK_ACTION.REDUCE


def _get_align_mode_scale():
hcg = fleet.get_hybrid_communicate_group()
data_parallel_world_size = hcg.get_data_parallel_world_size()
sharding_parallel_world_size = hcg.get_sharding_parallel_world_size()
return max(data_parallel_world_size, 1) * max(
sharding_parallel_world_size, 1
)


# assume only the first stage and last stage need data, and data consumption is ordered
# to be replaced by real micro dataset from reader
class FakeMicroDataset:
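The new _get_align_mode_scale helper above multiplies the data-parallel and sharding-parallel degrees to get the total number of gradient replicas that will later be summed. A self-contained sketch of the same arithmetic, with a stand-in object in place of fleet.get_hybrid_communicate_group() and made-up degrees:

from types import SimpleNamespace

# Stand-in for the hybrid communicate group; the real code obtains it via
# fleet.get_hybrid_communicate_group().
hcg = SimpleNamespace(
    get_data_parallel_world_size=lambda: 2,
    get_sharding_parallel_world_size=lambda: 4,
)

def _get_align_mode_scale_sketch(hcg):
    dp = hcg.get_data_parallel_world_size()
    sharding = hcg.get_sharding_parallel_world_size()
    # max(..., 1) guards against an axis reported as 0 when that parallelism is unused.
    return max(dp, 1) * max(sharding, 1)

print(_get_align_mode_scale_sketch(hcg))  # 8 gradient replicas in total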
@@ -997,6 +1007,9 @@ def _backward_step(
)
if self.is_pipeline_last_stage():
assert output_tensor_grad is None
# In align mode, we scale the grad directly after forward
if paddle.distributed.in_auto_parallel_align_mode():
output_tensor = output_tensor / _get_align_mode_scale()
if self.scaler:
paddle.autograd.backward(self.scaler.scale(output_tensor))
else:
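Dividing the last stage's output tensor by the align-mode scale before calling backward relies on the chain rule: every parameter gradient ends up divided by the same factor, so a plain SUM reduction across the replicas later reproduces the usual mean. A small standalone sketch of that property (the scale and tensor values are illustrative only):

import paddle

scale = 8.0  # stand-in for _get_align_mode_scale()

x = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)
loss = (x * x).sum()

# Scaling the tensor that backward starts from scales every gradient it produces.
(loss / scale).backward()

# The unscaled gradient would be 2 * x; here each entry is divided by `scale`.
print(x.grad)  # [0.25, 0.50, 0.75]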
3 changes: 3 additions & 0 deletions python/paddle/distributed/fleet/utils/tensor_fusion_helper.py
@@ -687,6 +687,9 @@ def _comm_grads(self):
)

elif self._act == HOOK_ACTION.REDUCE_SCATTER:
# In align mode, we scale the grad in advance, so we need a SUM here
if paddle.distributed.in_auto_parallel_align_mode():
reduce_op = paddle.distributed.ReduceOp.SUM
shard_size = self.grad_storage._numel() // self._comm_group.nranks
begin = shard_size * self._comm_group.rank
end = begin + shard_size
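In the REDUCE_SCATTER branch each rank keeps one contiguous shard of the fused gradient storage, and the shard boundaries follow from the rank id and the group size. A standalone sketch of that index arithmetic (the element count and group size are made up):

numel = 1024   # stand-in for self.grad_storage._numel()
nranks = 4     # stand-in for self._comm_group.nranks

shard_size = numel // nranks
for rank in range(nranks):
    begin = shard_size * rank
    end = begin + shard_size
    print(f"rank {rank} owns elements [{begin}, {end})")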