37 changes: 28 additions & 9 deletions verl/experimental/fully_async_policy/base_detach_sync.py
@@ -23,7 +23,7 @@
from ray.util.collective import collective

from verl.single_controller.base.decorator import Dispatch, register
from verl.utils.device import get_torch_device
from verl.utils.device import get_device_name, get_torch_device

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
@@ -120,16 +120,35 @@ def __del__(self):
self._bg_thread.join(timeout=1.0)

@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
def init_checkpoint_engine(self, rank_offset: int, actor_num: int, rollout_num: int):
from .checkpoint_engine import CheckpointEngine
def init_checkpoint_engine(self):
# from .checkpoint_engine import CheckpointEngine
from verl.checkpoint_engine import CheckpointEngineRegistry

backends_candi = {
"cuda": "nccl",
"npu": "hccl",
}
checkpoint_backend = backends_candi[get_device_name()]
Contributor (high):
The checkpoint_backend is retrieved from backends_candi using get_device_name(). If get_device_name() returns a value not present as a key in backends_candi, this will result in a KeyError. It would be more robust to handle this case, perhaps by raising a more informative error or falling back to a default if applicable.
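A minimal sketch of the guard this comment asks for, reusing get_device_name and backends_candi from the diff above; the error message wording is illustrative only:

    from verl.utils.device import get_device_name

    backends_candi = {
        "cuda": "nccl",
        "npu": "hccl",
    }
    device_name = get_device_name()
    if device_name not in backends_candi:
        # Fail with an explicit message instead of a bare KeyError on unsupported devices.
        raise ValueError(
            f"No checkpoint-engine backend registered for device '{device_name}'; "
            f"expected one of {sorted(backends_candi)}"
        )
    checkpoint_backend = backends_candi[device_name]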

checkpoint_kwargs = {
"bucket_size": 2 * 1024 * 1024 * 1024, # 2GB
Contributor (high):
The bucket_size is hardcoded to 2GB. While this might be a reasonable default, it limits flexibility. Consider making this value configurable through the config object, similar to other parameters, to allow for optimization in different environments or with varying model sizes.
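One possible shape for the configurable value, assuming the same config.checkpoint_engine section used elsewhere in this PR; the bucket_size_M key is hypothetical:

    # bucket_size_M is a hypothetical config key; the default stays at 2 GiB.
    bucket_size_mb = getattr(self.config.checkpoint_engine, "bucket_size_M", 2048)
    checkpoint_kwargs = {
        "bucket_size": bucket_size_mb * 1024 * 1024,
        "rebuild_group": False,
    }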

"rebuild_group": False,
}
if torch.distributed.get_rank() == 0 and self._is_actor:
checkpoint_kwargs["is_master"] = True
self.checkpoint_engine = CheckpointEngineRegistry.new(checkpoint_backend, **checkpoint_kwargs)

current_rank = torch.distributed.get_rank() + rank_offset
actor_ranks = list(range(actor_num))
rollout_ranks = [rank + actor_num for rank in range(rollout_num)]
assert rank_offset == 0 or rank_offset == actor_num
@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
def prepare(self):
metadata = self.checkpoint_engine.prepare()
return metadata

self.checkpoint_engine = CheckpointEngine(
current_rank, actor_ranks, rollout_ranks, self.config.checkpoint_engine.device_buffer_size_M
@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
def init_ckpt_engine_process_group(self, rank_offset: int, world_size: int, master_metadata: dict):
current_rank = torch.distributed.get_rank() + rank_offset
if self._is_actor:
current_rank = 0 if current_rank == 0 else -1
self.checkpoint_engine.init_process_group(
rank=current_rank, world_size=world_size, master_metadata=master_metadata
)

@staticmethod
59 changes: 47 additions & 12 deletions verl/experimental/fully_async_policy/fsdp_workers.py
@@ -133,46 +133,78 @@ def cache_actor_weights_to_cpu(self):
get_torch_device().synchronize()

@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"):
async def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"):
assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine
assert hasattr(self, "_weights_info") and self._weights_info is not None
do_prof = True
Contributor (high):
The do_prof flag is hardcoded to True. This means profiling will always be enabled. For production environments, profiling should typically be controlled by a configuration setting or an environment variable to avoid unnecessary overhead. Please make this configurable or remove it if it's only for temporary debugging.
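A sketch of gating the profiler behind an environment variable; VERL_CKPT_ENGINE_PROFILE is a hypothetical name, not something this PR or verl defines:

    import os

    import torch

    # Profiling stays off unless explicitly requested, avoiding overhead in production runs.
    do_prof = os.getenv("VERL_CKPT_ENGINE_PROFILE", "0") == "1"
    if do_prof:
        prof = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ]
        )
        prof.start()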

if do_prof:
prof = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
]
)
prof.start()

# Load model to GPU
load_start_time = time.time()
if self._is_actor and self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
load_duration = time.time() - load_start_time

from ray.util.collective import collective

# Cache actor weights to CPU and measure the time taken
cache_start_time = time.time()
self.cache_actor_weights_to_cpu()
# self.cache_actor_weights_to_cpu()
cache_end_time = time.time()
cache_duration = cache_end_time - cache_start_time

# Register the cached weights into the checkpoint engine
self.checkpoint_engine.register_checkpoint(self._weights_info, self.cpu_named_params)
# self.checkpoint_engine.register_checkpoint(self._weights_info, self.cpu_named_params)
register_end_time = time.time()
register_duration = register_end_time - cache_end_time
self.cpu_named_params = {}

collective.barrier(group_name=sync_group_name)
# collective.barrier(group_name=sync_group_name)
update_start_time = time.time()

inference_model = None
if self._is_rollout:
# import asyncio
# def async_generator_to_sync(async_gen):
# """
# Convert an async generator into a plain synchronous generator
# :param async_gen: the async generator object
# :return: a plain synchronous generator
# """
# # Inner async helper: iterate the async generator and collect its items
# async def consume_async_gen():
# result = []
# async for item in async_gen: # iterate the async generator asynchronously
# result.append(item)
# return result

# # Synchronous generator logic: run the event loop and yield the items one by one
# for item in asyncio.run(consume_async_gen()):
# yield item

inference_model = BaseDetachNcclSync.get_inference_model(self.rollout)
from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader

patch_vllm_moe_model_weight_loader(inference_model)

# Update the checkpoint with the inference model and broadcast weights
self.checkpoint_engine.update_checkpoint(
inference_model=inference_model,
group_name=sync_group_name,
overlap_broadcast_and_consume=self.config.checkpoint_engine.overlap_broadcast_and_consume,
)
# inference_model.load_weights(async_generator_to_sync(self.checkpoint_engine.receive_weights()))
async for name, weight in self.checkpoint_engine.receive_weights():
inference_model.load_weights([(name, weight)])
else:

def actor_params_to_full(actor_params):
for key, param in actor_params.items():
if hasattr(param, "full_tensor"):
yield key, param.full_tensor()

actor_params = self._get_actor_params()
# print(f'[debug] actor_params["model.embed_tokens.weight"]: {actor_params["model.embed_tokens.weight"]}')
await self.checkpoint_engine.send_weights(actor_params_to_full(actor_params))

update_end_time = time.time()
update_duration = update_end_time - update_start_time
@@ -194,6 +226,9 @@ def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"):
f"sync_rollout_weights_by_checkpoint load model to gpu cost {load_duration} seconds,"
f" offload model to cpu cost {offload_duration} seconds"
)
if do_prof:
prof.stop()
prof.export_chrome_trace(f"/home/tiger/ckpt_engine_prof_{torch.distributed.get_rank()}.json")
Contributor (high):
The profiling trace file path /home/tiger/ckpt_engine_prof_{torch.distributed.get_rank()}.json is hardcoded. This path is specific to a user's home directory and is not suitable for general deployment. Please make this path configurable (e.g., via a config object or environment variable) or ensure it's written to a temporary or designated logging directory.
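A sketch of a configurable trace destination; VERL_PROF_DIR is a hypothetical environment variable and the temp-directory fallback is only one option:

    import os
    import tempfile

    import torch

    if do_prof:
        prof.stop()
        # Write the trace to a configurable directory instead of a user's home path.
        trace_dir = os.getenv("VERL_PROF_DIR", tempfile.gettempdir())
        trace_path = os.path.join(
            trace_dir, f"ckpt_engine_prof_{torch.distributed.get_rank()}.json"
        )
        prof.export_chrome_trace(trace_path)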



class DetachActorWorker(DetachNcclSync):
23 changes: 13 additions & 10 deletions verl/experimental/fully_async_policy/param_sync.py
@@ -51,7 +51,7 @@ def __init__(self, config, trainer, rollouter, mq):
self.current_version = 0

self._init_weights_info()
self._init_sync_group()
# self._init_sync_group()

if self.config.async_training.checkpoint_engine.enable:
self._init_actor_rollout_checkpoint_engine()
@@ -80,18 +80,21 @@ def _init_sync_group(self):
)

def _init_actor_rollout_checkpoint_engine(self):
ray.get(self.actor_wg.init_checkpoint_engine())
ray.get(self.rollout_wg.init_checkpoint_engine())

metadata = ray.get(self.actor_wg.prepare() + self.rollout_wg.prepare())
print(f"metadata: {metadata}")
ray.get(
self.actor_wg.init_checkpoint_engine(
self.actor_wg.init_ckpt_engine_process_group(
rank_offset=0,
actor_num=len(self.actor_wg.workers),
rollout_num=len(self.rollout_wg.workers),
world_size=1 + len(self.rollout_wg.workers),
master_metadata=metadata[0],
)
)
ray.get(
self.rollout_wg.init_checkpoint_engine(
rank_offset=len(self.actor_wg.workers),
actor_num=len(self.actor_wg.workers),
rollout_num=len(self.rollout_wg.workers),
+ self.rollout_wg.init_ckpt_engine_process_group(
rank_offset=1,
world_size=1 + len(self.rollout_wg.workers),
master_metadata=metadata[0],
)
)
