diff --git a/tests/checkpoint_engine/test_kimi_checkpoint_engine.py b/tests/checkpoint_engine/test_kimi_checkpoint_engine.py new file mode 100644 index 00000000000..4eb461e4a14 --- /dev/null +++ b/tests/checkpoint_engine/test_kimi_checkpoint_engine.py @@ -0,0 +1,121 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import pytest +import ray + +from tests.checkpoint_engine.test_utils import create_rollout_worker_group, create_trainer_worker_group +from verl.single_controller.ray.base import ( + RayResourcePool, + split_resource_pool, +) +from verl.utils.device import get_device_name + + +@pytest.mark.parametrize("rebuild_group", [False, True]) +@pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)]) +def test_kimi_checkpoint_engine( + rebuild_group, + num_trainer, + num_rollout, + num_nodes=1, + num_gpus_per_node=8, + check_allclose=True, + model_path="~/models/Qwen/Qwen3-8B-Base", +): + model_path = os.path.expanduser(model_path) + ray.init( + runtime_env={ + "env_vars": { + "UCX_TLS": "rc,tcp,cuda", + "UCX_MAX_RNDV_RAILS": "4", + "UCX_LOG_LEVEL": "INFO", + "VERL_LOGGING_LEVEL": "DEBUG", + "NCCL_IB_HCA": "mlx5", + "ASCEND_USE_SHORT_CONNECTION": "1", + } + } + ) + + resource_pool = RayResourcePool(process_on_nodes=[num_gpus_per_node] * num_nodes, max_colocate_count=3) + resource_pool.get_placement_groups(device_name=get_device_name()) + trainer_pool, rollout_pool = split_resource_pool(resource_pool, [num_trainer, num_rollout]) + checkpoint_kwargs = { + "train_world_size": num_trainer, + "rollout_world_size": num_rollout, + "bucket_size": 2 * 1024 * 1024 * 1024, # 2GB + "rebuild_group": rebuild_group, + } + + trainer = create_trainer_worker_group(model_path, trainer_pool, "kimi_ckpt_engine", checkpoint_kwargs) + trainer.reset() + rollout = create_rollout_worker_group( + model_path, + rollout_pool, + "kimi_ckpt_engine", + checkpoint_kwargs, + check_allclose=check_allclose, + ) + + world_size = trainer.world_size + rollout.world_size + for _ in range(3): + # 1. prepare all workers + metadata = ray.get( + trainer.execute_checkpoint_engine(["prepare"] * trainer.world_size) + + rollout.execute_checkpoint_engine(["prepare"] * rollout.world_size) + ) + trainer_kwargs = { + "method": ["init_process_group"] * trainer.world_size, + "rank": list(range(0, trainer.world_size)), + "world_size": [world_size] * trainer.world_size, + "master_metadata": [metadata[0]] * trainer.world_size, + } + rollout_kwargs = { + "method": ["init_process_group"] * rollout.world_size, + "rank": list(range(trainer.world_size, world_size)), + "world_size": [world_size] * rollout.world_size, + "master_metadata": [metadata[0]] * rollout.world_size, + } + + # 2. init process group between all workers + ray.get( + trainer.execute_checkpoint_engine(**trainer_kwargs) + rollout.execute_checkpoint_engine(**rollout_kwargs) + ) + + # 3. update weights of all workers + ray.get(trainer.update_weights() + rollout.update_weights()) + + # 4. 
finish all workers + ray.get( + trainer.execute_checkpoint_engine(["finish"] * trainer.world_size) + + rollout.execute_checkpoint_engine(["finish"] * rollout.world_size) + ) + + # 5. check weights of rollout workers + rollout.check_weights() + + ray.shutdown() + + +if __name__ == "__main__": + test_kimi_checkpoint_engine( + rebuild_group=False, + num_trainer=4, + num_rollout=28, + num_nodes=2, + num_gpus_per_node=16, + check_allclose=False, + model_path=os.environ["HDFS_ROOT"] + "/model/Qwen3-30B-A3B-Base", + ) diff --git a/tests/checkpoint_engine/test_utils.py b/tests/checkpoint_engine/test_utils.py index 4e18b227d09..ec2178d484d 100644 --- a/tests/checkpoint_engine/test_utils.py +++ b/tests/checkpoint_engine/test_utils.py @@ -29,7 +29,7 @@ class TrainingWorkerTest(TrainingWorker): def __init__(self, config: TrainingWorkerConfig, checkpoint_backend: str, checkpoint_kwargs: dict) -> None: copy_to_local(config.model_config.path) super().__init__(config) - if torch.distributed.get_rank() == 0 and checkpoint_backend in ["nccl", "hccl"]: + if torch.distributed.get_rank() == 0 and checkpoint_backend in ["nccl", "hccl", "kimi_ckpt_engine"]: checkpoint_kwargs["is_master"] = True self.checkpoint_engine = CheckpointEngineRegistry.new(checkpoint_backend, **checkpoint_kwargs) diff --git a/verl/checkpoint_engine/README.md b/verl/checkpoint_engine/README.md index 2318dd9477d..ed335e1cbba 100644 --- a/verl/checkpoint_engine/README.md +++ b/verl/checkpoint_engine/README.md @@ -18,16 +18,20 @@ Checkpoint Engine is an unified abstract layer to synchronize weights between va |nccl|NCCL|all_gather+broadcast|NVIDIA GPU & NCCL|Very High|Low: rebuild nccl group|Off-policy training
- Trainer/rollout disaggregated
- Fixed clusters |hccl|HCCL|all_gather+broadcast|Ascend NPU & HCCL| High|Low: rebuild hccl group|Off-policy training
- Trainer/rollout disaggregated
- Fixed clusters |nixl|NIXL|all_gather+ring p2p|Various transport backends (D2D, H2H, H2D, etc)
- UCX
- UCCL
- Mooncake|Medium/High|High: dynamic adjust ring topology|Off-policy training
- Trainer/rollout disaggregated
- Elastic rollout
- Rollout fault tolerance
- Heterogeneous hardware rollout
+|kimi_ckpt_engine|Mooncake+NCCL/HCCL|p2p+broadcast|NVIDIA GPU/Ascend NPU|High|Low: rebuild communication group|Off-policy training
- Trainer/rollout disaggregated
- Saves a CPU checkpoint on each update
+
+Note: kimi_ckpt_engine first offloads all weights to the CPU. The Mooncake transfer engine then transmits these weights via P2P to a designated rollout worker, which broadcasts them to all other rollout workers.

### Benchmark
1. benchmark setup
- model: Qwen/Qwen3-30B-A3B-Base
-- trainer: fsdp world_size=2
+- trainer: fsdp world_size=2 (world_size=4 on Ascend 910C, which has 64GB of HBM)
- rollout: num_rollout=30 (only receive weight without cuda ipc to vllm/sglang)
```bash
python3 tests/checkpoint_engine/test_nixl_checkpoint_engine.py
python3 tests/checkpoint_engine/test_nccl_checkpoint_engine.py
python3 tests/checkpoint_engine/test_hccl_checkpoint_engine.py
+python3 tests/checkpoint_engine/test_kimi_checkpoint_engine.py
```
2. benchmark result
@@ -36,4 +40,5 @@
|----|----|----|----|
|4*8 H100, ConnectX-7 400 Gbps (InfiniBand)| NCCL | ~7 | 8.25|
|4*8 H100, ConnectX-7 400 Gbps (InfiniBand)| NIXL | ~7 | 8.25|
-|2*16 Ascend 910C, inner suppernode| HCCL | ~11 | 5.3|
\ No newline at end of file
+|2*16 Ascend 910C, intra-supernode| HCCL | ~11 | 5.3|
+|2*16 Ascend 910C, intra-supernode| kimi_ckpt_engine | offload: 7, update: 3.5 | 16.5|
diff --git a/verl/checkpoint_engine/__init__.py b/verl/checkpoint_engine/__init__.py
index 3c20eb35401..c36292621fd 100644
--- a/verl/checkpoint_engine/__init__.py
+++ b/verl/checkpoint_engine/__init__.py
@@ -37,3 +37,10 @@
    __all__ += ["NIXLCheckpointEngine"]
except ImportError:
    NIXLCheckpointEngine = None
+
+try:
+    from .kimi_checkpoint_engine import KIMICheckpointEngine
+
+    __all__ += ["KIMICheckpointEngine"]
+except ImportError:
+    KIMICheckpointEngine = None
diff --git a/verl/checkpoint_engine/kimi_checkpoint_engine.py b/verl/checkpoint_engine/kimi_checkpoint_engine.py
new file mode 100644
index 00000000000..be6f92f35d2
--- /dev/null
+++ b/verl/checkpoint_engine/kimi_checkpoint_engine.py
@@ -0,0 +1,357 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
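+
+"""Kimi checkpoint engine.
+
+Weight-sync backend built on the ``checkpoint_engine`` package: trainer ranks
+cast their weight shards to the rollout dtype and offload them to CPU, register
+them with a Mooncake-backed ParameterServer, and designated rollout ranks then
+pull the buckets via P2P and broadcast them to the rest of the rollout group
+over NCCL/HCCL.
+"""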
+import asyncio +import concurrent.futures +import logging +import os +import time +import types +from collections import defaultdict +from dataclasses import dataclass +from typing import AsyncGenerator, Generator + +import checkpoint_engine.distributed as dist +import ray +import torch +from checkpoint_engine.ps import H2DBucket, ParameterMeta, ParameterServer, _gen_h2d_buckets, _to_named_tensor + +from verl.checkpoint_engine.base import CheckpointEngine, CheckpointEngineRegistry +from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device +from verl.utils.net_utils import get_free_port + +logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +def ckpt_get_named_tensor_buckets( + iterable: Generator[tuple[str, torch.Tensor], None, None], + bucket_bytes: int, + world_size: int, + rank_id: int, + rollout_dtype: torch.dtype = torch.bfloat16, +) -> dict[str, torch.Tensor]: + if bucket_bytes <= 0: + raise ValueError(f"bucket_bytes must be greater than 0, got {bucket_bytes}") + + current_bucket = {} + current_size = 0 + for tensor_idx, (name, tensor) in enumerate(iterable): + tensor = tensor.to(rollout_dtype) + if tensor_idx % world_size == rank_id: + tensor_size = tensor.element_size() * tensor.numel() + if current_size + tensor_size > bucket_bytes: + if current_bucket: + yield current_bucket + current_bucket = {} + current_size = 0 + + current_bucket[name] = tensor + current_size += tensor_size + + if current_bucket: + yield current_bucket + + +async def receive_tensor( + self, + checkpoint_name: str, + ranks_group: int, + ranks: list[int] | None = None, + bucket_size: int = 2 << 30, + disable_h2d_buffer: bool = False, +) -> AsyncGenerator[tuple[str, torch.Tensor], None]: + assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty" + assert dist.is_initialized(), "process group is not initialized" + assert self._p2p_store is not None, "p2p store is not initialized" + assert ranks, "ranks should be set" + + # first execute a barrier to avoid subsequent device oom + dist.barrier(group=ranks_group) + buckets = _gen_h2d_buckets( + self._current_global_parameter_metas, + bucket_size, + self._local_rdma_devices, + self._remote_rdma_devices, + ranks, + ) + h2d_buffer: torch.Tensor | None = ( + None + if disable_h2d_buffer + else torch.empty(bucket_size, dtype=torch.uint8, device=self.device_manager.device_type) + ) + # p2p store need to register h2d_buffer to let other ranks read + if ranks: + h2d_buffer_name = "__h2d_buffer__" + if h2d_buffer is not None and self._p2p_store is not None: + self._p2p_store.register_named_tensors({h2d_buffer_name: h2d_buffer}) + receiver_rank_buckets: list[tuple[int, H2DBucket]] = [] + for receiver_rank, owner_rank, bucket in buckets: + if receiver_rank != self._rank: + continue + receiver_rank_buckets.append((owner_rank, bucket)) + buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type) + buckets_by_receiver_rank: dict[int, list[H2DBucket]] = defaultdict(list) + + max_len = 0 + for receiver_rank, _, bucket in buckets: + buckets_by_receiver_rank[receiver_rank].append(bucket) + if len(buckets_by_receiver_rank[receiver_rank]) > max_len: + max_len = len(buckets_by_receiver_rank[receiver_rank]) + gidx = 0 + metadata: list[ParameterMeta] + try: + for i in range(max_len): + if i < len(receiver_rank_buckets) and not disable_h2d_buffer: + self._copy_to_buffer( + checkpoint_name, + receiver_rank_buckets[i][1], + h2d_buffer, + 
receiver_rank_buckets[i][0] if ranks else None, + ) + for receiver_rank, _buckets in buckets_by_receiver_rank.items(): + if i >= len(_buckets): + continue + bucket = _buckets[i] + start = gidx % 2 * bucket_size + buffer_b: torch.Tensor = buffer[start : start + bucket.size] + if receiver_rank == self._rank: + if disable_h2d_buffer: + self._copy_to_buffer(checkpoint_name, bucket, buffer_b) + else: + buffer_b.data.copy_(h2d_buffer[: bucket.size]) + broadcast_op = BroadcastOperation( + rank=receiver_rank, + ranks_group=ranks_group, + bucket=buffer_b, + metadata=bucket.items, + ) + if gidx == 0: + metadata = await broadcast_op.wait_for_complete() + gidx += 1 + continue + meta_list = _to_named_tensor(metadata, (gidx - 1) % 2 * bucket_size) + for item in meta_list: + shape = item["shape"] + if isinstance(shape, list | tuple): + shape = torch.Size(shape) + assert isinstance(shape, torch.Size) + dtype, offset = item["dtype"], item["offset"] + size = dtype.itemsize * shape.numel() + tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape) + yield item["name"], tensor + metadata = await broadcast_op.wait_for_complete() + self.device_manager.device_module.synchronize() + gidx += 1 + + meta_list = _to_named_tensor(metadata, (gidx - 1) % 2 * bucket_size) + for item in meta_list: + shape = item["shape"] + if isinstance(shape, list | tuple): + shape = torch.Size(shape) + assert isinstance(shape, torch.Size) + dtype, offset = item["dtype"], item["offset"] + size = dtype.itemsize * shape.numel() + tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape) + yield item["name"], tensor + + finally: + dist.barrier(group=ranks_group) + if ranks and h2d_buffer is not None: + self._p2p_store.unregister_named_tensors([h2d_buffer_name]) + self.device_manager.device_module.empty_cache() + + +@dataclass +class MasterMetadata: + ip: str + port: int + + +class BroadcastOperation: + """Async broadcast operation with NCCL in separate thread. + + Args: + rank (int): The rank of the current process. + ranks_group (int): The process group's value. + bucket (torch.Tensor): The tensor to broadcast. + metadata (list[ParameterMeta]): The metadata of the tensor. + """ + + def __init__( + self, + rank: int, + ranks_group: int, + bucket: torch.Tensor, + metadata: list[ParameterMeta], + ) -> None: + self.rank = rank + self.ranks_group = ranks_group + self.bucket = bucket + self.metadata = metadata + + loop = asyncio.get_running_loop() + self._task = loop.run_in_executor(None, self._run) + + def _run(self): + # broadcast tensor + dist.broadcast(self.bucket, src=self.rank, group=self.ranks_group) + + async def wait_for_complete(self) -> list[ParameterMeta]: + """Wait for the broadcast operation to complete. + + Returns: + list[ParameterMeta]: The bucket meta after broadcast. + """ + await self._task + return self.metadata + + +@CheckpointEngineRegistry.register("kimi_ckpt_engine") +class KIMICheckpointEngine(CheckpointEngine): + """NCCL checkpoint engine with collective communication. + + Args: + bucket_size (int): Bucket size in bytes to transfer multiple weights at one time. Note that we use + two buffer to send and recv weights at same time, so the device memory overhead is 2 * bucket_size. + rebuild_group (bool): Whether to rebuild the NCCL process group in each update. Defaults to False. + is_master (bool): Whether the current process is the master process. Defaults to False. + rollout_dtype (torch.dtype): The dtype of the weights received from rollout workers. Defaults to torch.bfloat16. 
+ """ + + def __init__( + self, + train_world_size: int, + rollout_world_size: int, + bucket_size: int, + rebuild_group: bool = False, + is_master: bool = False, + rollout_dtype: torch.dtype = torch.bfloat16, + ) -> None: + self.train_world_size = train_world_size + self.rollout_world_size = rollout_world_size + self.world_size = train_world_size + rollout_world_size + + self.bucket_size = bucket_size + self.rebuild_group = rebuild_group + self.rollout_dtype = rollout_dtype + self.is_master = is_master + self.initialized = False + self.checkpoint_name = "kimi_checkpoint_engine" + + def prepare(self) -> MasterMetadata: + if self.is_master: + self.ip = ray.util.get_node_ip_address().strip("[]") + self.listen_port, _ = get_free_port(self.ip) + + return MasterMetadata(ip=self.ip, port=self.listen_port) if self.is_master else None + + def finish(self): + """Destroy the ckpt engine process group if rebuild_group is True.""" + if self.rebuild_group: + dist.destroy_process_group() + self.rank = None + self.world_size = None + self.initialized = False + + def init_process_group(self, rank: int, world_size: int, master_metadata: MasterMetadata): + """Initialize the ckpt engine process group. + + Args: + rank (int): The rank of the current process. + world_size (int): The total number of processes. + """ + self.rank = rank + # unregister_memory in transfer engine is not supported on NPU, + # so we have to initialize ParameterServer each time + if get_device_name() == "npu" or not self.initialized: + self.parameter_server = ParameterServer(rank=rank, world_size=world_size, auto_pg=False, custom_dist=True) + self.parameter_server.receive_tensor = types.MethodType(receive_tensor, self.parameter_server) + if not self.initialized: + dist.init_process_group( + host=master_metadata.ip, + port=master_metadata.port, + rank=rank, + world_size=world_size, + backend=get_nccl_backend(), + ) + + self.rollout_ranks = list(range(self.train_world_size, world_size)) + self.rollout_group = dist.new_group(self.rollout_ranks) + self.initialized = True + + @torch.no_grad() + async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]): + """Send the weights of the model. + + Args: + weights: A generator that yields the name of the weight tensor and the tensor itself. 
+        """
+
+        def offload_cpu(named_tensors: dict[str, torch.Tensor], name: str, tensor: torch.Tensor):
+            named_tensors[name] = tensor.to("cpu", non_blocking=True)
+
+        start_time = time.time()
+        named_tensors = {}
+        # Each trainer rank takes a strided subset of the weights (cast to rollout_dtype)
+        # and offloads it to CPU, using a thread pool to overlap the D2H copies.
+        for named_tensors_gpu in ckpt_get_named_tensor_buckets(
+            weights, self.bucket_size, self.train_world_size, self.rank, self.rollout_dtype
+        ):
+            with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
+                futures = [
+                    executor.submit(
+                        offload_cpu,
+                        named_tensors,
+                        name,
+                        tensor,
+                    )
+                    for name, tensor in named_tensors_gpu.items()
+                ]
+                for future in concurrent.futures.as_completed(futures):
+                    future.result()
+
+        get_torch_device().synchronize()
+
+        # Register the CPU copy with the parameter server so rollout ranks can fetch it via P2P.
+        self.parameter_server.register_checkpoint(self.checkpoint_name, named_tensors=named_tensors)
+        named_tensors = {}
+        get_torch_device().empty_cache()
+        logger.info(f"Rank {self.rank} offload and register, time cost: {time.time() - start_time:.2f}s")
+
+        self.parameter_server.gather_metas(self.checkpoint_name)
+        dist.barrier()
+        self.parameter_server.unregister_checkpoint(self.checkpoint_name)
+        logger.info(f"Rank {self.rank} send weights done, time cost: {time.time() - start_time:.2f}s")
+
+    @torch.no_grad()
+    async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]:
+        """Receive the weights of the model.
+
+        Yields:
+            A tuple of the name of the weight tensor and the tensor itself.
+        """
+        self.parameter_server.gather_metas(self.checkpoint_name)
+
+        start_time = time.time()
+        total_bytes, total_params = 0, 0
+        async for name, tensor in self.parameter_server.receive_tensor(
+            self.checkpoint_name, self.rollout_group, self.rollout_ranks, self.bucket_size
+        ):
+            total_bytes += tensor.element_size() * tensor.nelement()
+            total_params += 1
+            yield name, tensor
+        dist.barrier()
+        time_cost = time.time() - start_time
+        bandwidth = total_bytes / time_cost / (1024 * 1024 * 1024)
+        logger.info(
+            f"Rank {self.rank} receive weights done, total_params: {total_params}, "
+            f"time cost: {time_cost:.2f}s, bandwidth: {bandwidth:.2f} GB/s"
+        )
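
For readers integrating the backend outside the Ray test harness above, here is a minimal per-rank sketch of one weight-sync round. It only mirrors the call sequence the test drives through `execute_checkpoint_engine`; `broadcast_master_metadata` and `load_weights` are hypothetical placeholders for however the surrounding framework shares the master's `MasterMetadata` with every rank and feeds tensors into the inference engine.

```python
import torch

from verl.checkpoint_engine.base import CheckpointEngineRegistry

TRAIN_WORLD_SIZE, ROLLOUT_WORLD_SIZE = 2, 6
WORLD_SIZE = TRAIN_WORLD_SIZE + ROLLOUT_WORLD_SIZE


def make_engine(is_master: bool = False):
    # Same constructor arguments the test passes via `checkpoint_kwargs`.
    return CheckpointEngineRegistry.new(
        "kimi_ckpt_engine",
        train_world_size=TRAIN_WORLD_SIZE,
        rollout_world_size=ROLLOUT_WORLD_SIZE,
        bucket_size=2 * 1024 * 1024 * 1024,  # 2GB
        rebuild_group=False,
        is_master=is_master,
    )


async def trainer_step(global_rank: int, model: torch.nn.Module, broadcast_master_metadata):
    # Trainer ranks occupy global ranks [0, TRAIN_WORLD_SIZE); rank 0 is the rendezvous master.
    engine = make_engine(is_master=(global_rank == 0))
    metadata = engine.prepare()                      # MasterMetadata on the master, None elsewhere
    metadata = broadcast_master_metadata(metadata)   # hypothetical: share it with every rank
    engine.init_process_group(rank=global_rank, world_size=WORLD_SIZE, master_metadata=metadata)
    # Offload this rank's share of the weights to CPU and serve it to the rollout side.
    await engine.send_weights(model.named_parameters())
    engine.finish()


async def rollout_step(global_rank: int, load_weights, broadcast_master_metadata):
    # Rollout ranks occupy global ranks [TRAIN_WORLD_SIZE, WORLD_SIZE).
    engine = make_engine()
    engine.prepare()
    metadata = broadcast_master_metadata(None)
    engine.init_process_group(rank=global_rank, world_size=WORLD_SIZE, master_metadata=metadata)
    # Stream (name, tensor) pairs into the inference engine as they arrive.
    async for name, tensor in engine.receive_weights():
        load_weights([(name, tensor)])
    engine.finish()
```

Every trainer and rollout rank has to run its step concurrently within the same round, since `send_weights`/`receive_weights` synchronize through barriers and metadata gathering on the shared process group.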