[fsdp, megatron] feat: Support fully-async training on Ascend NPU #5043
Conversation
Code Review
This pull request adds support for fully-async training on Ascend NPU by replacing Ray's collective communication with StatelessProcessGroup and PyHcclCommunicator. The changes look good overall, but I've identified a few critical issues related to potential race conditions and fragile implementation details in the new NPU communication path. Addressing these will improve the robustness and correctness of the implementation.
```python
if is_npu_available:
    self._weight_sync_group.broadcast(tensor, src=0, stream=get_torch_device().current_stream())
else:
    collective.broadcast(tensor, src_rank=0, group_name=sync_group_name)
```
The condition is_npu_available could lead to a runtime AttributeError. The _weight_sync_group is initialized in param_sync.py only when self.config.trainer.device == "npu". If is_npu_available is true but the device in the config is not 'npu', self._weight_sync_group will not exist, causing a crash here. To make this more robust, you should check for the existence of _weight_sync_group instead.
Suggested change:

```diff
-if is_npu_available:
+if hasattr(self, "_weight_sync_group"):
     self._weight_sync_group.broadcast(tensor, src=0, stream=get_torch_device().current_stream())
 else:
     collective.broadcast(tensor, src_rank=0, group_name=sync_group_name)
```
```python
self.actor_wg.create_weight_sync_group(
    master_address,
    master_port,
    0,
    n_workers,
)
ray.get(
    self.rollout_wg.create_weight_sync_group(
        master_address,
        master_port,
        len(self.actor_wg.workers),
        n_workers,
    )
)
```
There's a race condition in the initialization of the weight sync group. The call to self.actor_wg.create_weight_sync_group is non-blocking, and the returned futures are not awaited. The code immediately proceeds to initialize the rollout workers. Since the master for the process group (rank 0) is an actor worker, it may not have finished setting up the TCPStore when the rollout workers attempt to connect, which would cause the initialization to fail. You should collect the futures from both calls and wait for all of them to complete.
Suggested change:

```python
actor_futures = self.actor_wg.create_weight_sync_group(
    master_address,
    master_port,
    0,
    n_workers,
)
rollout_futures = self.rollout_wg.create_weight_sync_group(
    master_address,
    master_port,
    len(self.actor_wg.workers),
    n_workers,
)
ray.get(actor_futures + rollout_futures)
```
```python
n_workers = len(self.actor_wg.workers + self.rollout_wg.workers)
if self.config.trainer.device == "npu":
    master_address = ray.get(self.actor_wg.workers[0]._get_node_ip.remote()).strip("[]")
```
The method for retrieving master_address is fragile. It relies on a private method _get_node_ip and then uses .strip("[]") on the result. This assumes a very specific string format (e.g., '[1.2.3.4]') and can easily break if the format changes, for example if the remote method returns a list object or a plain IP string. A more robust solution would be for the remote method to return a clean IP address string directly, removing the need for string manipulation here.
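One possible direction, shown as an illustrative sketch only (the `WeightSyncWorker` class and its `get_master_address` helper are hypothetical and not part of this PR), is to have the remote method return a plain IP string via `ray.util.get_node_ip_address()` so the driver needs no format-specific post-processing:

```python
# Illustrative sketch: the worker class and helper name are hypothetical,
# not code from this PR.
import ray


@ray.remote
class WeightSyncWorker:
    def get_master_address(self) -> str:
        # Returns a plain IP string (e.g. "10.0.0.3"), so the driver can use
        # it directly without .strip("[]") or other format assumptions.
        return ray.util.get_node_ip_address()


# Driver side:
# master_address = ray.get(workers[0].get_master_address.remote())
```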
What does this PR do?
Since Ray's collective communication interface does not support the HCCL backend, this PR uses StatelessProcessGroup and PyHcclCommunicator instead of Ray's create_collective_group to support fully-async training on Ascend NPU.
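For context, a minimal sketch of how such a weight-sync group can be wired up from these two pieces is given below. The import paths and the `build_weight_sync_group` helper are assumptions for illustration rather than the exact verl code; only the `broadcast(tensor, src=0, stream=...)` call pattern is taken from the diff above.

```python
# Minimal sketch, assuming vLLM's StatelessProcessGroup and vllm-ascend's
# PyHcclCommunicator; import paths and helper names are illustrative.
import torch
from vllm.distributed.utils import StatelessProcessGroup
from vllm_ascend.distributed.device_communicators.pyhccl import PyHcclCommunicator


def build_weight_sync_group(master_address: str, master_port: int,
                            rank: int, world_size: int, device: torch.device):
    # TCPStore-backed group that is independent of the default
    # torch.distributed world, so actor and rollout workers can join it
    # without touching their training/inference process groups.
    pg = StatelessProcessGroup.create(
        host=master_address, port=master_port, rank=rank, world_size=world_size
    )
    # HCCL communicator over that group; it exposes broadcast(tensor, src, stream),
    # which the workers call during weight sync (rank 0 sends, the rest receive).
    return PyHcclCommunicator(pg, device=device)
```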
Checklist Before Starting
- Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`; if multiple modules are involved, list them like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If the PR breaks any API, add `[BREAKING]` to the beginning of the title, e.g. `[BREAKING][fsdp, megatron] feat: dynamic batching`

Test
API and Usage Example
Add the following configuration to the script:
```
trainer.device=npu \
async_training.checkpoint_engine.enable=False
```

Design & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Apply pre-commit checks: `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- Once the PR is ready for CI, send a message in the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)
- If the PR modifies the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.