[fsdp, megatron] Refactor fully-async training to support multiple checkpoint engine backends #5029
base: main
Conversation
Code Review
This pull request refactors the fully-async training to support multiple checkpoint engine backends, such as NCCL and HCCL. The changes introduce a more modular CheckpointEngineRegistry approach, updating the initialization and preparation of the checkpoint engine and its process groups. The weight synchronization logic has also been updated to be asynchronous, which is a significant architectural improvement. However, there are a few areas that could be improved for maintainability and robustness, such as making hardcoded values configurable and removing commented-out code.
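For readers unfamiliar with the pattern, a registry-style backend selection can be pictured roughly as below. This is a hypothetical sketch only; the class layout and the `register`/`create` method names are assumptions and may not match the PR's actual `CheckpointEngineRegistry`.

```python
from typing import Callable, Dict


class CheckpointEngineRegistry:
    """Hypothetical registry mapping backend names to checkpoint engine classes."""

    _backends: Dict[str, Callable[..., object]] = {}

    @classmethod
    def register(cls, name: str):
        def decorator(engine_cls):
            cls._backends[name] = engine_cls
            return engine_cls

        return decorator

    @classmethod
    def create(cls, name: str, **kwargs):
        if name not in cls._backends:
            raise ValueError(f"Unknown checkpoint engine backend: {name!r}")
        return cls._backends[name](**kwargs)


@CheckpointEngineRegistry.register("nccl")
class NcclCheckpointEngine:
    """Placeholder engine; a real NCCL engine would hold process groups, buffers, etc."""

    def __init__(self, bucket_size: int):
        self.bucket_size = bucket_size


# Usage: select the backend by name and construct it with backend-specific kwargs.
engine = CheckpointEngineRegistry.create("nccl", bucket_size=2 * 1024**3)
```

The benefit of routing construction through a single registry is that adding another device backend (e.g., HCCL on NPU) does not require touching the trainer-side call sites.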
| "cuda": "nccl", | ||
| "npu": "hccl", | ||
| } | ||
| checkpoint_backend = backends_candi[get_device_name()] |
The checkpoint_backend is retrieved from backends_candi using get_device_name(). If get_device_name() returns a value not present as a key in backends_candi, this will result in a KeyError. It would be more robust to handle this case, perhaps by raising a more informative error or falling back to a default if applicable.
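For illustration, a defensive variant could look like the sketch below; the helper name and error message are illustrative, not part of the PR.

```python
def select_checkpoint_backend(device: str, backends: dict[str, str]) -> str:
    """Resolve the collective backend for a device, failing loudly on unknown devices."""
    if device not in backends:
        raise ValueError(
            f"No checkpoint engine backend registered for device {device!r}. "
            f"Known devices: {sorted(backends)}"
        )
    return backends[device]


# Example with the mapping from the snippet above.
backend = select_checkpoint_backend("cuda", {"cuda": "nccl", "npu": "hccl"})
```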
    }
    checkpoint_backend = backends_candi[get_device_name()]
    checkpoint_kwargs = {
        "bucket_size": 2 * 1024 * 1024 * 1024,  # 2GB
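The 2 GB `bucket_size` in this snippet is one of the hardcoded values the review summary above suggests making configurable. A minimal sketch of pulling it from config instead, assuming a hypothetical `checkpoint_bucket_size` attribute:

```python
def build_checkpoint_kwargs(config) -> dict:
    """Assemble checkpoint engine kwargs, defaulting to a 2 GB bucket when unset."""
    default_bucket_size = 2 * 1024 * 1024 * 1024  # 2 GB
    return {"bucket_size": getattr(config, "checkpoint_bucket_size", default_bucket_size)}
```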
    async def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"):
        assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine
        assert hasattr(self, "_weights_info") and self._weights_info is not None
        do_prof = True
The do_prof flag is hardcoded to True. This means profiling will always be enabled. For production environments, profiling should typically be controlled by a configuration setting or an environment variable to avoid unnecessary overhead. Please make this configurable or remove it if it's only for temporary debugging.
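One possible direction, sketched below, is to gate the profiler behind an environment variable; `VERL_CKPT_ENGINE_PROFILE` is an assumed variable name, not something defined by the PR.

```python
import os

import torch

# Enable profiling only when explicitly requested via the environment.
do_prof = os.environ.get("VERL_CKPT_ENGINE_PROFILE", "0") == "1"
if do_prof:
    prof = torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
    )
    prof.start()
```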
    )
    if do_prof:
        prof.stop()
        prof.export_chrome_trace(f"/home/tiger/ckpt_engine_prof_{torch.distributed.get_rank()}.json")
The profiling trace file path /home/tiger/ckpt_engine_prof_{torch.distributed.get_rank()}.json is hardcoded. This path is specific to a user's home directory and is not suitable for general deployment. Please make this path configurable (e.g., via a config object or environment variable) or ensure it's written to a temporary or designated logging directory.
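A sketch of one alternative, reusing `do_prof` and `prof` from the snippet above: write the trace to a directory taken from an assumed `VERL_PROFILE_DIR` environment variable, falling back to a temporary directory.

```python
import os

import torch.distributed as dist

if do_prof:
    prof.stop()
    # Resolve the output directory from the environment instead of a fixed home path.
    trace_dir = os.environ.get("VERL_PROFILE_DIR", "/tmp")
    os.makedirs(trace_dir, exist_ok=True)
    prof.export_chrome_trace(os.path.join(trace_dir, f"ckpt_engine_prof_{dist.get_rank()}.json"))
```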
What does this PR do?
Refactor fully-async training to support multiple checkpoint engine backends
Checklist Before Starting
- Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI).
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, like `[megatron, fsdp, doc]`.
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`.
  - If the PR introduces a breaking change, add `[BREAKING]` to the beginning of the title, e.g. `[BREAKING][fsdp, megatron] feat: dynamic batching`.

Test
| Training backend | Vanilla cost | Current cost |
| --- | --- | --- |
| FSDP2 | 2.64s | 0.06s |
| FSDP | TBA | TBA |
| Megatron | TBA | TBA |
API and Usage Example
# Add code snippet or script demonstrating how to use this

Design & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Run pre-commit checks: `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`.
- Request CI in the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)
- If the PR touches the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.