diff --git a/tests/checkpoint_engine/test_kimi_checkpoint_engine.py b/tests/checkpoint_engine/test_kimi_checkpoint_engine.py
new file mode 100644
index 00000000000..4eb461e4a14
--- /dev/null
+++ b/tests/checkpoint_engine/test_kimi_checkpoint_engine.py
@@ -0,0 +1,121 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+import pytest
+import ray
+
+from tests.checkpoint_engine.test_utils import create_rollout_worker_group, create_trainer_worker_group
+from verl.single_controller.ray.base import (
+ RayResourcePool,
+ split_resource_pool,
+)
+from verl.utils.device import get_device_name
+
+
+@pytest.mark.parametrize("rebuild_group", [False, True])
+@pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)])
+def test_kimi_checkpoint_engine(
+ rebuild_group,
+ num_trainer,
+ num_rollout,
+ num_nodes=1,
+ num_gpus_per_node=8,
+ check_allclose=True,
+ model_path="~/models/Qwen/Qwen3-8B-Base",
+):
+ model_path = os.path.expanduser(model_path)
+ ray.init(
+ runtime_env={
+ "env_vars": {
+ "UCX_TLS": "rc,tcp,cuda",
+ "UCX_MAX_RNDV_RAILS": "4",
+ "UCX_LOG_LEVEL": "INFO",
+ "VERL_LOGGING_LEVEL": "DEBUG",
+ "NCCL_IB_HCA": "mlx5",
+ "ASCEND_USE_SHORT_CONNECTION": "1",
+ }
+ }
+ )
+
+ resource_pool = RayResourcePool(process_on_nodes=[num_gpus_per_node] * num_nodes, max_colocate_count=3)
+ resource_pool.get_placement_groups(device_name=get_device_name())
+ trainer_pool, rollout_pool = split_resource_pool(resource_pool, [num_trainer, num_rollout])
+ checkpoint_kwargs = {
+ "train_world_size": num_trainer,
+ "rollout_world_size": num_rollout,
+ "bucket_size": 2 * 1024 * 1024 * 1024, # 2GB
+ "rebuild_group": rebuild_group,
+ }
+
+ trainer = create_trainer_worker_group(model_path, trainer_pool, "kimi_ckpt_engine", checkpoint_kwargs)
+ trainer.reset()
+ rollout = create_rollout_worker_group(
+ model_path,
+ rollout_pool,
+ "kimi_ckpt_engine",
+ checkpoint_kwargs,
+ check_allclose=check_allclose,
+ )
+
+ world_size = trainer.world_size + rollout.world_size
+ for _ in range(3):
+ # 1. prepare all workers
+ metadata = ray.get(
+ trainer.execute_checkpoint_engine(["prepare"] * trainer.world_size)
+ + rollout.execute_checkpoint_engine(["prepare"] * rollout.world_size)
+ )
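+        # metadata[0] is the MasterMetadata returned by trainer rank 0 (the master); prepare()
+        # returns None on every other worker.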
+ trainer_kwargs = {
+ "method": ["init_process_group"] * trainer.world_size,
+ "rank": list(range(0, trainer.world_size)),
+ "world_size": [world_size] * trainer.world_size,
+ "master_metadata": [metadata[0]] * trainer.world_size,
+ }
+ rollout_kwargs = {
+ "method": ["init_process_group"] * rollout.world_size,
+ "rank": list(range(trainer.world_size, world_size)),
+ "world_size": [world_size] * rollout.world_size,
+ "master_metadata": [metadata[0]] * rollout.world_size,
+ }
+
+ # 2. init process group between all workers
+ ray.get(
+ trainer.execute_checkpoint_engine(**trainer_kwargs) + rollout.execute_checkpoint_engine(**rollout_kwargs)
+ )
+
+ # 3. update weights of all workers
+ ray.get(trainer.update_weights() + rollout.update_weights())
+
+ # 4. finish all workers
+ ray.get(
+ trainer.execute_checkpoint_engine(["finish"] * trainer.world_size)
+ + rollout.execute_checkpoint_engine(["finish"] * rollout.world_size)
+ )
+
+ # 5. check weights of rollout workers
+ rollout.check_weights()
+
+ ray.shutdown()
+
+
+if __name__ == "__main__":
+ test_kimi_checkpoint_engine(
+ rebuild_group=False,
+ num_trainer=4,
+ num_rollout=28,
+ num_nodes=2,
+ num_gpus_per_node=16,
+ check_allclose=False,
+ model_path=os.environ["HDFS_ROOT"] + "/model/Qwen3-30B-A3B-Base",
+ )
diff --git a/tests/checkpoint_engine/test_utils.py b/tests/checkpoint_engine/test_utils.py
index 4e18b227d09..ec2178d484d 100644
--- a/tests/checkpoint_engine/test_utils.py
+++ b/tests/checkpoint_engine/test_utils.py
@@ -29,7 +29,7 @@ class TrainingWorkerTest(TrainingWorker):
def __init__(self, config: TrainingWorkerConfig, checkpoint_backend: str, checkpoint_kwargs: dict) -> None:
copy_to_local(config.model_config.path)
super().__init__(config)
- if torch.distributed.get_rank() == 0 and checkpoint_backend in ["nccl", "hccl"]:
+ if torch.distributed.get_rank() == 0 and checkpoint_backend in ["nccl", "hccl", "kimi_ckpt_engine"]:
checkpoint_kwargs["is_master"] = True
self.checkpoint_engine = CheckpointEngineRegistry.new(checkpoint_backend, **checkpoint_kwargs)
diff --git a/verl/checkpoint_engine/README.md b/verl/checkpoint_engine/README.md
index 2318dd9477d..ed335e1cbba 100644
--- a/verl/checkpoint_engine/README.md
+++ b/verl/checkpoint_engine/README.md
@@ -18,16 +18,20 @@ Checkpoint Engine is an unified abstract layer to synchronize weights between va
|nccl|NCCL|all_gather+broadcast|NVIDIA GPU & NCCL|Very High|Low: rebuild nccl group|Off-policy training
- Trainer/rollout disaggregated
- Fixed clusters
|hccl|HCCL|all_gather+broadcast|Ascend NPU & HCCL| High|Low: rebuild hccl group|Off-policy training
- Trainer/rollout disaggregated
- Fixed clusters
|nixl|NIXL|all_gather+ring p2p|Various transport backends (D2D, H2H, H2D, etc)
- UCX
- UCCL
- Mooncacke|Medium/High|High: dynamic adjust ring topology|Off-policy training
- Trainer/rollout disaggregated
- Elastic rollout
- Rollout fault tolerance
- Heterogeneous hardware rollout
+|kimi_ckpt_engine|Mooncake+NCCL/HCCL|p2p+broadcast|NVIDIA GPU/Ascend NPU|High|Low: rebuild communication group|Off-policy training
 - Trainer/rollout disaggregated
 - Saving a checkpoint on each update (weights are already offloaded to CPU)
+
+Note: kimi_ckpt_engine first offloads all weights to CPU. The Mooncake transfer engine then transmits each weight bucket via P2P to a designated rollout worker, which broadcasts it to all other rollout workers.
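+
+A minimal per-update sketch (mirroring `tests/checkpoint_engine/test_kimi_checkpoint_engine.py`; the world sizes are examples, and `weights_iter`, `load_weight`, and the distribution of the master's rendezvous metadata to all workers are left to the caller):
+
+```python
+from verl.checkpoint_engine.base import CheckpointEngineRegistry
+
+
+async def sync_weights(rank, world_size, master_metadata, weights_iter=None, load_weight=None):
+    engine = CheckpointEngineRegistry.new(
+        "kimi_ckpt_engine", train_world_size=2, rollout_world_size=6,
+        bucket_size=2 * 1024**3, is_master=(rank == 0),
+    )
+    # trainer rank 0 (is_master=True) allocates its rendezvous (ip, port) here; it must be
+    # shared with every worker as `master_metadata` (the test gathers it via ray)
+    engine.prepare()
+    engine.init_process_group(rank=rank, world_size=world_size, master_metadata=master_metadata)
+    if rank < 2:  # trainer ranks: offload weights to CPU and publish them via the parameter server
+        await engine.send_weights(weights_iter)
+    else:  # rollout ranks: pull each bucket via p2p + broadcast and hand tensors to the caller
+        async for name, tensor in engine.receive_weights():
+            load_weight(name, tensor)
+    engine.finish()  # destroys the process group when rebuild_group=True
+```
+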
### Benchmark
1. benchmark setup
- model: Qwen/Qwen3-30B-A3B-Base
-- trainer: fsdp world_size=2
+- trainer: fsdp world_size=2 (world_size=4 on Ascend 910C, which has 64GB of HBM per device)
- rollout: num_rollout=30 (only receive weight without cuda ipc to vllm/sglang)
```bash
python3 tests/checkpoint_engine/test_nixl_checkpoint_engine.py
python3 tests/checkpoint_engine/test_nccl_checkpoint_engine.py
python3 tests/checkpoint_engine/test_hccl_checkpoint_engine.py
+python3 tests/checkpoint_engine/test_kimi_checkpoint_engine.py
```
2. benchmark result
@@ -36,4 +40,5 @@ python3 tests/checkpoint_engine/test_hccl_checkpoint_engine.py
|----|----|----|----|
|4*8 H100, ConnectX-7 400 Gbps (InfiniBand)| NCCL | ~7 | 8.25|
|4*8 H100, ConnectX-7 400 Gbps (InfiniBand)| NIXL | ~7 | 8.25|
-|2*16 Ascend 910C, inner suppernode| HCCL | ~11 | 5.3|
\ No newline at end of file
+|2*16 Ascend 910C, inner supernode| HCCL | ~11 | 5.3|
+|2*16 Ascend 910C, inner supernode| kimi_ckpt_engine | offload: 7, update: 3.5 | 16.5|
diff --git a/verl/checkpoint_engine/__init__.py b/verl/checkpoint_engine/__init__.py
index 3c20eb35401..c36292621fd 100644
--- a/verl/checkpoint_engine/__init__.py
+++ b/verl/checkpoint_engine/__init__.py
@@ -37,3 +37,10 @@
__all__ += ["NIXLCheckpointEngine"]
except ImportError:
NIXLCheckpointEngine = None
+
+try:
+ from .kimi_checkpoint_engine import KIMICheckpointEngine
+
+ __all__ += ["KIMICheckpointEngine"]
+except ImportError:
+ KIMICheckpointEngine = None
diff --git a/verl/checkpoint_engine/kimi_checkpoint_engine.py b/verl/checkpoint_engine/kimi_checkpoint_engine.py
new file mode 100644
index 00000000000..be6f92f35d2
--- /dev/null
+++ b/verl/checkpoint_engine/kimi_checkpoint_engine.py
@@ -0,0 +1,357 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import asyncio
+import concurrent.futures
+import logging
+import os
+import time
+import types
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import AsyncGenerator, Generator
+
+import checkpoint_engine.distributed as dist
+import ray
+import torch
+from checkpoint_engine.ps import H2DBucket, ParameterMeta, ParameterServer, _gen_h2d_buckets, _to_named_tensor
+
+from verl.checkpoint_engine.base import CheckpointEngine, CheckpointEngineRegistry
+from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device
+from verl.utils.net_utils import get_free_port
+
+logger = logging.getLogger(__name__)
+logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
+
+
+def ckpt_get_named_tensor_buckets(
+    iterable: Generator[tuple[str, torch.Tensor], None, None],
+    bucket_bytes: int,
+    world_size: int,
+    rank_id: int,
+    rollout_dtype: torch.dtype = torch.bfloat16,
+) -> Generator[dict[str, torch.Tensor], None, None]:
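+    """Round-robin partition the weights across trainer ranks and bucket this rank's share.
+
+    Tensors whose index satisfies ``tensor_idx % world_size == rank_id`` are cast to
+    ``rollout_dtype`` and grouped greedily into dicts of roughly ``bucket_bytes`` bytes.
+    """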
+ if bucket_bytes <= 0:
+ raise ValueError(f"bucket_bytes must be greater than 0, got {bucket_bytes}")
+
+ current_bucket = {}
+ current_size = 0
+ for tensor_idx, (name, tensor) in enumerate(iterable):
+ tensor = tensor.to(rollout_dtype)
+ if tensor_idx % world_size == rank_id:
+ tensor_size = tensor.element_size() * tensor.numel()
+ if current_size + tensor_size > bucket_bytes:
+ if current_bucket:
+ yield current_bucket
+ current_bucket = {}
+ current_size = 0
+
+ current_bucket[name] = tensor
+ current_size += tensor_size
+
+ if current_bucket:
+ yield current_bucket
+
+
+async def receive_tensor(
+ self,
+ checkpoint_name: str,
+ ranks_group: int,
+ ranks: list[int] | None = None,
+ bucket_size: int = 2 << 30,
+ disable_h2d_buffer: bool = False,
+) -> AsyncGenerator[tuple[str, torch.Tensor], None]:
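+    """Replacement for ``ParameterServer.receive_tensor`` bound onto the parameter server.
+
+    For each bucket, the designated receiver rank pulls the data from the trainers' CPU copy
+    through the Mooncake p2p store into ``h2d_buffer`` and broadcasts it to the other rollout
+    ranks; the two halves of ``buffer`` are used alternately so that yielding tensors from one
+    bucket overlaps with broadcasting the next.
+
+    Yields:
+        Tuples of (parameter name, tensor view into the broadcast buffer).
+    """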
+ assert len(self._current_global_parameter_metas) != 0, "parameter metas is empty"
+ assert dist.is_initialized(), "process group is not initialized"
+ assert self._p2p_store is not None, "p2p store is not initialized"
+ assert ranks, "ranks should be set"
+
+ # first execute a barrier to avoid subsequent device oom
+ dist.barrier(group=ranks_group)
+ buckets = _gen_h2d_buckets(
+ self._current_global_parameter_metas,
+ bucket_size,
+ self._local_rdma_devices,
+ self._remote_rdma_devices,
+ ranks,
+ )
+ h2d_buffer: torch.Tensor | None = (
+ None
+ if disable_h2d_buffer
+ else torch.empty(bucket_size, dtype=torch.uint8, device=self.device_manager.device_type)
+ )
+    # the p2p store needs to register h2d_buffer so that other ranks can read from it
+ if ranks:
+ h2d_buffer_name = "__h2d_buffer__"
+ if h2d_buffer is not None and self._p2p_store is not None:
+ self._p2p_store.register_named_tensors({h2d_buffer_name: h2d_buffer})
+ receiver_rank_buckets: list[tuple[int, H2DBucket]] = []
+ for receiver_rank, owner_rank, bucket in buckets:
+ if receiver_rank != self._rank:
+ continue
+ receiver_rank_buckets.append((owner_rank, bucket))
+ buffer = torch.empty(bucket_size * 2, dtype=torch.uint8, device=self.device_manager.device_type)
+ buckets_by_receiver_rank: dict[int, list[H2DBucket]] = defaultdict(list)
+
+ max_len = 0
+ for receiver_rank, _, bucket in buckets:
+ buckets_by_receiver_rank[receiver_rank].append(bucket)
+ if len(buckets_by_receiver_rank[receiver_rank]) > max_len:
+ max_len = len(buckets_by_receiver_rank[receiver_rank])
+ gidx = 0
+ metadata: list[ParameterMeta]
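+    # Pipeline the rounds: in round i every receiver rank broadcasts its i-th bucket from a
+    # background thread while the tensors of round i-1 (held in the other half of `buffer`)
+    # are yielded to the caller.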
+ try:
+ for i in range(max_len):
+ if i < len(receiver_rank_buckets) and not disable_h2d_buffer:
+ self._copy_to_buffer(
+ checkpoint_name,
+ receiver_rank_buckets[i][1],
+ h2d_buffer,
+ receiver_rank_buckets[i][0] if ranks else None,
+ )
+ for receiver_rank, _buckets in buckets_by_receiver_rank.items():
+ if i >= len(_buckets):
+ continue
+ bucket = _buckets[i]
+ start = gidx % 2 * bucket_size
+ buffer_b: torch.Tensor = buffer[start : start + bucket.size]
+ if receiver_rank == self._rank:
+ if disable_h2d_buffer:
+ self._copy_to_buffer(checkpoint_name, bucket, buffer_b)
+ else:
+ buffer_b.data.copy_(h2d_buffer[: bucket.size])
+ broadcast_op = BroadcastOperation(
+ rank=receiver_rank,
+ ranks_group=ranks_group,
+ bucket=buffer_b,
+ metadata=bucket.items,
+ )
+ if gidx == 0:
+ metadata = await broadcast_op.wait_for_complete()
+ gidx += 1
+ continue
+ meta_list = _to_named_tensor(metadata, (gidx - 1) % 2 * bucket_size)
+ for item in meta_list:
+ shape = item["shape"]
+ if isinstance(shape, list | tuple):
+ shape = torch.Size(shape)
+ assert isinstance(shape, torch.Size)
+ dtype, offset = item["dtype"], item["offset"]
+ size = dtype.itemsize * shape.numel()
+ tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape)
+ yield item["name"], tensor
+ metadata = await broadcast_op.wait_for_complete()
+ self.device_manager.device_module.synchronize()
+ gidx += 1
+
+ meta_list = _to_named_tensor(metadata, (gidx - 1) % 2 * bucket_size)
+ for item in meta_list:
+ shape = item["shape"]
+ if isinstance(shape, list | tuple):
+ shape = torch.Size(shape)
+ assert isinstance(shape, torch.Size)
+ dtype, offset = item["dtype"], item["offset"]
+ size = dtype.itemsize * shape.numel()
+ tensor = buffer[offset : offset + size].view(dtype=dtype).view(shape)
+ yield item["name"], tensor
+
+ finally:
+ dist.barrier(group=ranks_group)
+ if ranks and h2d_buffer is not None:
+ self._p2p_store.unregister_named_tensors([h2d_buffer_name])
+ self.device_manager.device_module.empty_cache()
+
+
+@dataclass
+class MasterMetadata:
+ ip: str
+ port: int
+
+
+class BroadcastOperation:
+ """Async broadcast operation with NCCL in separate thread.
+
+ Args:
+ rank (int): The rank of the current process.
+ ranks_group (int): The process group's value.
+ bucket (torch.Tensor): The tensor to broadcast.
+ metadata (list[ParameterMeta]): The metadata of the tensor.
+ """
+
+ def __init__(
+ self,
+ rank: int,
+ ranks_group: int,
+ bucket: torch.Tensor,
+ metadata: list[ParameterMeta],
+ ) -> None:
+ self.rank = rank
+ self.ranks_group = ranks_group
+ self.bucket = bucket
+ self.metadata = metadata
+
+ loop = asyncio.get_running_loop()
+ self._task = loop.run_in_executor(None, self._run)
+
+ def _run(self):
+ # broadcast tensor
+ dist.broadcast(self.bucket, src=self.rank, group=self.ranks_group)
+
+ async def wait_for_complete(self) -> list[ParameterMeta]:
+ """Wait for the broadcast operation to complete.
+
+ Returns:
+ list[ParameterMeta]: The bucket meta after broadcast.
+ """
+ await self._task
+ return self.metadata
+
+
+@CheckpointEngineRegistry.register("kimi_ckpt_engine")
+class KIMICheckpointEngine(CheckpointEngine):
+ """NCCL checkpoint engine with collective communication.
+
+ Args:
+ bucket_size (int): Bucket size in bytes to transfer multiple weights at one time. Note that we use
+ two buffer to send and recv weights at same time, so the device memory overhead is 2 * bucket_size.
+ rebuild_group (bool): Whether to rebuild the NCCL process group in each update. Defaults to False.
+ is_master (bool): Whether the current process is the master process. Defaults to False.
+ rollout_dtype (torch.dtype): The dtype of the weights received from rollout workers. Defaults to torch.bfloat16.
+ """
+
+ def __init__(
+ self,
+ train_world_size: int,
+ rollout_world_size: int,
+ bucket_size: int,
+ rebuild_group: bool = False,
+ is_master: bool = False,
+ rollout_dtype: torch.dtype = torch.bfloat16,
+ ) -> None:
+ self.train_world_size = train_world_size
+ self.rollout_world_size = rollout_world_size
+ self.world_size = train_world_size + rollout_world_size
+
+ self.bucket_size = bucket_size
+ self.rebuild_group = rebuild_group
+ self.rollout_dtype = rollout_dtype
+ self.is_master = is_master
+ self.initialized = False
+ self.checkpoint_name = "kimi_checkpoint_engine"
+
+ def prepare(self) -> MasterMetadata:
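+        """Allocate a rendezvous endpoint on the master rank.
+
+        Returns:
+            MasterMetadata with the master's (ip, port) if this process is the master, otherwise None.
+        """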
+ if self.is_master:
+ self.ip = ray.util.get_node_ip_address().strip("[]")
+ self.listen_port, _ = get_free_port(self.ip)
+
+ return MasterMetadata(ip=self.ip, port=self.listen_port) if self.is_master else None
+
+ def finish(self):
+ """Destroy the ckpt engine process group if rebuild_group is True."""
+ if self.rebuild_group:
+ dist.destroy_process_group()
+ self.rank = None
+ self.world_size = None
+ self.initialized = False
+
+ def init_process_group(self, rank: int, world_size: int, master_metadata: MasterMetadata):
+ """Initialize the ckpt engine process group.
+
+ Args:
+ rank (int): The rank of the current process.
+ world_size (int): The total number of processes.
+ """
+ self.rank = rank
+ # unregister_memory in transfer engine is not supported on NPU,
+ # so we have to initialize ParameterServer each time
+ if get_device_name() == "npu" or not self.initialized:
+ self.parameter_server = ParameterServer(rank=rank, world_size=world_size, auto_pg=False, custom_dist=True)
+ self.parameter_server.receive_tensor = types.MethodType(receive_tensor, self.parameter_server)
+ if not self.initialized:
+ dist.init_process_group(
+ host=master_metadata.ip,
+ port=master_metadata.port,
+ rank=rank,
+ world_size=world_size,
+ backend=get_nccl_backend(),
+ )
+
+ self.rollout_ranks = list(range(self.train_world_size, world_size))
+ self.rollout_group = dist.new_group(self.rollout_ranks)
+ self.initialized = True
+
+ @torch.no_grad()
+ async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None, None]):
+ """Send the weights of the model.
+
+ Args:
+ weights: A generator that yields the name of the weight tensor and the tensor itself.
+ """
+
+ def offload_cpu(named_tensors: dict[str, torch.Tensor], name: str, tensor: torch.Tensor):
+ named_tensors[name] = tensor.to("cpu", non_blocking=True)
+
+ start_time = time.time()
+ named_tensors = {}
+ for named_tensors_gpu in ckpt_get_named_tensor_buckets(
+ weights, self.bucket_size, self.train_world_size, self.rank, self.rollout_dtype
+ ):
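+        # Offload this bucket to CPU with a thread pool; the device synchronize below ensures
+        # the non_blocking copies have finished before the CPU tensors are registered with the
+        # parameter server.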
+ with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
+ futures = [
+ executor.submit(
+ offload_cpu,
+ named_tensors,
+ name,
+ tensor,
+ )
+ for name, tensor in named_tensors_gpu.items()
+ ]
+ for future in concurrent.futures.as_completed(futures):
+ future.result()
+
+ get_torch_device().synchronize()
+
+ self.parameter_server.register_checkpoint(self.checkpoint_name, named_tensors=named_tensors)
+ named_tensors = {}
+            get_torch_device().empty_cache()
+ logger.info(f"Rank {self.rank} offload and register, time cost: {time.time() - start_time:.2f}s")
+
+ self.parameter_server.gather_metas(self.checkpoint_name)
+ dist.barrier()
+ self.parameter_server.unregister_checkpoint(self.checkpoint_name)
+ logger.info(f"Rank {self.rank} send weights done, time cost: {time.time() - start_time:.2f}s")
+
+ @torch.no_grad()
+ async def receive_weights(self) -> AsyncGenerator[tuple[str, torch.Tensor], None]:
+ """Receive the weights of the model.
+
+ Yields:
+ A tuple of the name of the weight tensor and the tensor itself.
+ """
+ self.parameter_server.gather_metas(self.checkpoint_name)
+
+ start_time = time.time()
+ total_bytes, total_params = 0, 0
+ async for name, tensor in self.parameter_server.receive_tensor(
+ self.checkpoint_name, self.rollout_group, self.rollout_ranks, self.bucket_size
+ ):
+ total_bytes += tensor.element_size() * tensor.nelement()
+ total_params += 1
+ yield name, tensor
+ dist.barrier()
+ time_cost = time.time() - start_time
+ bandwidth = total_bytes / time_cost / (1024 * 1024 * 1024)
+ logger.info(
+ f"Rank {self.rank} receive weights done, total_params: {total_params}, "
+ f"time cost: {time_cost:.2f}s, bandwidth: {bandwidth:.2f} GB/s"
+ )