[fsdp, megatron] Refactor fully-async training to support multiple checkpoint engine backends #5029
```diff
@@ -23,7 +23,7 @@
 from ray.util.collective import collective

 from verl.single_controller.base.decorator import Dispatch, register
-from verl.utils.device import get_torch_device
+from verl.utils.device import get_device_name, get_torch_device

 logger = logging.getLogger(__file__)
 logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
```
```diff
@@ -120,16 +120,35 @@ def __del__(self):
         self._bg_thread.join(timeout=1.0)

     @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
-    def init_checkpoint_engine(self, rank_offset: int, actor_num: int, rollout_num: int):
-        from .checkpoint_engine import CheckpointEngine
+    def init_checkpoint_engine(self):
+        # from .checkpoint_engine import CheckpointEngine
+        from verl.checkpoint_engine import CheckpointEngineRegistry

-        current_rank = torch.distributed.get_rank() + rank_offset
-        actor_ranks = list(range(actor_num))
-        rollout_ranks = [rank + actor_num for rank in range(rollout_num)]
-        assert rank_offset == 0 or rank_offset == actor_num
-
-        self.checkpoint_engine = CheckpointEngine(
-            current_rank, actor_ranks, rollout_ranks, self.config.checkpoint_engine.device_buffer_size_M
-        )
+        backends_candi = {
+            "cuda": "nccl",
+            "npu": "hccl",
+        }
+        checkpoint_backend = backends_candi[get_device_name()]
+        checkpoint_kwargs = {
+            "bucket_size": 2 * 1024 * 1024 * 1024,  # 2GB
+            "rebuild_group": False,
+        }
+        if torch.distributed.get_rank() == 0 and self._is_actor:
+            checkpoint_kwargs["is_master"] = True
+        self.checkpoint_engine = CheckpointEngineRegistry.new(checkpoint_backend, **checkpoint_kwargs)
+
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
+    def prepare(self):
+        metadata = self.checkpoint_engine.prepare()
+        return metadata
+
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
+    def init_ckpt_engine_process_group(self, rank_offset: int, world_size: int, master_metadata: dict):
+        current_rank = torch.distributed.get_rank() + rank_offset
+        if self._is_actor:
+            current_rank = 0 if current_rank == 0 else -1
+        self.checkpoint_engine.init_process_group(
+            rank=current_rank, world_size=world_size, master_metadata=master_metadata
+        )

     @staticmethod
```

**Contributor:** The `checkpoint_backend` is retrieved from `backends_candi` using `get_device_name()`. If `get_device_name()` returns a value not present as a key in `backends_candi`, this will result in a `KeyError`. It would be more robust to handle this case, perhaps by raising a more informative error or falling back to a default if applicable.
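The suggestion above amounts to a guarded lookup. A minimal sketch, assuming `get_device_name()` returns short device strings such as `"cuda"` or `"npu"` as the diff implies; the helper name `resolve_checkpoint_backend` is hypothetical:

```python
from verl.utils.device import get_device_name

# Hypothetical helper implementing the reviewer's suggestion: raise an
# informative error instead of a bare KeyError for unsupported devices.
SUPPORTED_BACKENDS = {
    "cuda": "nccl",
    "npu": "hccl",
}


def resolve_checkpoint_backend() -> str:
    device = get_device_name()
    try:
        return SUPPORTED_BACKENDS[device]
    except KeyError:
        raise ValueError(
            f"No checkpoint-engine backend registered for device '{device}'; "
            f"supported devices: {sorted(SUPPORTED_BACKENDS)}"
        ) from None
```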
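For context on the new API shape: the refactor splits the old one-shot `init_checkpoint_engine(rank_offset, actor_num, rollout_num)` into a three-phase handshake. A rough driver-side sketch under stated assumptions; `actor_wg`, `rollout_wg`, and the world sizes are invented for illustration and are not names from this PR:

```python
import ray

# Assumed Ray worker-group handles and sizes, purely for illustration.
actor_world_size, rollout_world_size = 8, 8

# Phase 1: construct the backend-specific engine on every worker.
actor_wg.init_checkpoint_engine()
rollout_wg.init_checkpoint_engine()

# Phase 2: prepare() returns per-worker metadata (e.g. a rendezvous id for
# NCCL/HCCL); the actor master's copy is what the others need to join.
master_metadata = ray.get(actor_wg.prepare())[0]
ray.get(rollout_wg.prepare())

# Phase 3: form the weight-sync group. Only the actor master joins (other
# actor ranks map themselves to rank=-1 inside init_ckpt_engine_process_group),
# so the group holds the master plus all rollout workers.
world_size = 1 + rollout_world_size
actor_wg.init_ckpt_engine_process_group(
    rank_offset=0, world_size=world_size, master_metadata=master_metadata
)
rollout_wg.init_ckpt_engine_process_group(
    rank_offset=1, world_size=world_size, master_metadata=master_metadata
)
```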
```diff
@@ -133,46 +133,78 @@ def cache_actor_weights_to_cpu(self):
         get_torch_device().synchronize()

     @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
-    def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"):
+    async def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"):
         assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine
         assert hasattr(self, "_weights_info") and self._weights_info is not None
+        do_prof = True
```

**Contributor:** The `do_prof` flag is hardcoded to `True`, which leaves debug profiling enabled on every call.
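One minimal way to address that, sketched with an assumed environment variable; `VERL_CKPT_SYNC_PROFILE` is not an existing knob:

```python
import os

import torch

# Assumed env-var gate so the profiler is opt-in rather than always on.
do_prof = os.getenv("VERL_CKPT_SYNC_PROFILE", "0") == "1"
if do_prof:
    prof = torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ]
    )
    prof.start()
```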
```diff
+        if do_prof:
+            prof = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ]
+            )
+            prof.start()

         # Load model to GPU
         load_start_time = time.time()
         if self._is_actor and self._is_offload_param:
             load_fsdp_model_to_gpu(self.actor_module_fsdp)
         load_duration = time.time() - load_start_time

-        from ray.util.collective import collective
-
         # Cache actor weights to CPU and measure the time taken
         cache_start_time = time.time()
-        self.cache_actor_weights_to_cpu()
+        # self.cache_actor_weights_to_cpu()
         cache_end_time = time.time()
         cache_duration = cache_end_time - cache_start_time

         # Register the cached weights into the checkpoint engine
-        self.checkpoint_engine.register_checkpoint(self._weights_info, self.cpu_named_params)
+        # self.checkpoint_engine.register_checkpoint(self._weights_info, self.cpu_named_params)
         register_end_time = time.time()
         register_duration = register_end_time - cache_end_time
         self.cpu_named_params = {}

-        collective.barrier(group_name=sync_group_name)
+        # collective.barrier(group_name=sync_group_name)
         update_start_time = time.time()

         inference_model = None
         if self._is_rollout:
+            # import asyncio
+            # def async_generator_to_sync(async_gen):
+            #     """
+            #     Convert an async generator into a plain synchronous generator
+            #     :param async_gen: the async generator object
+            #     :return: a plain synchronous generator
+            #     """
+            #     # Inner async function: iterate the async generator and collect its items
+            #     async def consume_async_gen():
+            #         result = []
+            #         async for item in async_gen:  # asynchronously iterate the async generator
+            #             result.append(item)
+            #         return result
+
+            #     # Synchronous generator logic: run the event loop, then yield items one by one
+            #     for item in asyncio.run(consume_async_gen()):
+            #         yield item

             inference_model = BaseDetachNcclSync.get_inference_model(self.rollout)
             from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader

             patch_vllm_moe_model_weight_loader(inference_model)

-            # Update the checkpoint with the inference model and broadcast weights
-            self.checkpoint_engine.update_checkpoint(
-                inference_model=inference_model,
-                group_name=sync_group_name,
-                overlap_broadcast_and_consume=self.config.checkpoint_engine.overlap_broadcast_and_consume,
-            )
+            # inference_model.load_weights(async_generator_to_sync(self.checkpoint_engine.receive_weights()))
+            async for name, weight in self.checkpoint_engine.receive_weights():
+                inference_model.load_weights([(name, weight)])
+        else:
+
+            def actor_params_to_full(actor_params):
+                for key, param in actor_params.items():
+                    if hasattr(param, "full_tensor"):
+                        yield key, param.full_tensor()
+
+            actor_params = self._get_actor_params()
+            # print(f'[debug] actor_params["model.embed_tokens.weight"]: {actor_params["model.embed_tokens.weight"]}')
+            await self.checkpoint_engine.send_weights(actor_params_to_full(actor_params))

         update_end_time = time.time()
         update_duration = update_end_time - update_start_time
```
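The commented-out `async_generator_to_sync` helper above was one way to feed the async `receive_weights()` stream into a single synchronous `load_weights()` call. A cleaned-up version of that idea, for reference:

```python
import asyncio


def async_generator_to_sync(async_gen):
    """Convert an async generator into a plain synchronous generator."""

    async def consume_async_gen():
        # Drain the async generator into a list on a fresh event loop.
        # Note: asyncio.run() fails if a loop is already running on this thread.
        result = []
        async for item in async_gen:
            result.append(item)
        return result

    # Yield the collected items one by one, synchronously.
    yield from asyncio.run(consume_async_gen())
```

Note that this drains the whole stream into memory before yielding anything, so every received weight tensor is alive at once; the `async for` loop the PR actually uses loads one tensor at a time, which is presumably why the helper was abandoned.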
```diff
@@ -194,6 +226,9 @@ def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"):
             f"sync_rollout_weights_by_checkpoint load model to gpu cost {load_duration} seconds,"
             f" offload model to cpu cost {offload_duration} seconds"
         )
+        if do_prof:
+            prof.stop()
+            prof.export_chrome_trace(f"/home/tiger/ckpt_engine_prof_{torch.distributed.get_rank()}.json")


 class DetachActorWorker(DetachNcclSync):
```

**Contributor:** The profiling trace file path is hardcoded to a user-specific directory (`/home/tiger/...`); it should be configurable.
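Continuing the profiling sketch from earlier, the hardcoded path could be made configurable; `VERL_PROF_TRACE_DIR` is an assumed name, not an existing knob:

```python
import os

import torch

# Assumed knob (not from this PR): write traces under a configurable
# directory instead of the author's home directory; defaults to /tmp.
trace_dir = os.getenv("VERL_PROF_TRACE_DIR", "/tmp")
if do_prof:  # `do_prof` and `prof` come from the earlier profiling sketch
    prof.stop()
    prof.export_chrome_trace(
        os.path.join(trace_dir, f"ckpt_engine_prof_{torch.distributed.get_rank()}.json")
    )
```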