121 changes: 121 additions & 0 deletions tests/checkpoint_engine/test_kimi_checkpoint_engine.py
@@ -0,0 +1,121 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import pytest
import ray

from tests.checkpoint_engine.test_utils import create_rollout_worker_group, create_trainer_worker_group
from verl.single_controller.ray.base import (
RayResourcePool,
split_resource_pool,
)
from verl.utils.device import get_device_name


@pytest.mark.parametrize("rebuild_group", [False, True])
@pytest.mark.parametrize("num_trainer, num_rollout", [(2, 6)])
def test_kimi_checkpoint_engine(
rebuild_group,
num_trainer,
num_rollout,
num_nodes=1,
num_gpus_per_node=8,
check_allclose=True,
model_path="~/models/Qwen/Qwen3-8B-Base",
):
model_path = os.path.expanduser(model_path)
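    # Env vars propagated to every Ray worker: UCX transport selection, NCCL IB HCA selection,
    # Ascend connection mode, and verl logging verbosity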
ray.init(
runtime_env={
"env_vars": {
"UCX_TLS": "rc,tcp,cuda",
"UCX_MAX_RNDV_RAILS": "4",
"UCX_LOG_LEVEL": "INFO",
"VERL_LOGGING_LEVEL": "DEBUG",
"NCCL_IB_HCA": "mlx5",
"ASCEND_USE_SHORT_CONNECTION": "1",
}
}
)

resource_pool = RayResourcePool(process_on_nodes=[num_gpus_per_node] * num_nodes, max_colocate_count=3)
resource_pool.get_placement_groups(device_name=get_device_name())
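    # Split the shared placement groups into disjoint trainer and rollout resource pools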
trainer_pool, rollout_pool = split_resource_pool(resource_pool, [num_trainer, num_rollout])
checkpoint_kwargs = {
"train_world_size": num_trainer,
"rollout_world_size": num_rollout,
"bucket_size": 2 * 1024 * 1024 * 1024, # 2GB
"rebuild_group": rebuild_group,
}

trainer = create_trainer_worker_group(model_path, trainer_pool, "kimi_ckpt_engine", checkpoint_kwargs)
trainer.reset()
rollout = create_rollout_worker_group(
model_path,
rollout_pool,
"kimi_ckpt_engine",
checkpoint_kwargs,
check_allclose=check_allclose,
)

world_size = trainer.world_size + rollout.world_size
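    # Run three full weight-synchronization rounds end-to-end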
for _ in range(3):
# 1. prepare all workers
metadata = ray.get(
trainer.execute_checkpoint_engine(["prepare"] * trainer.world_size)
+ rollout.execute_checkpoint_engine(["prepare"] * rollout.world_size)
)
trainer_kwargs = {
"method": ["init_process_group"] * trainer.world_size,
"rank": list(range(0, trainer.world_size)),
"world_size": [world_size] * trainer.world_size,
"master_metadata": [metadata[0]] * trainer.world_size,
}
rollout_kwargs = {
"method": ["init_process_group"] * rollout.world_size,
"rank": list(range(trainer.world_size, world_size)),
"world_size": [world_size] * rollout.world_size,
"master_metadata": [metadata[0]] * rollout.world_size,
}

# 2. init process group between all workers
ray.get(
trainer.execute_checkpoint_engine(**trainer_kwargs) + rollout.execute_checkpoint_engine(**rollout_kwargs)
)

# 3. update weights of all workers
ray.get(trainer.update_weights() + rollout.update_weights())

# 4. finish all workers
ray.get(
trainer.execute_checkpoint_engine(["finish"] * trainer.world_size)
+ rollout.execute_checkpoint_engine(["finish"] * rollout.world_size)
)

# 5. check weights of rollout workers
rollout.check_weights()

ray.shutdown()


if __name__ == "__main__":
test_kimi_checkpoint_engine(
rebuild_group=False,
num_trainer=4,
num_rollout=28,
num_nodes=2,
num_gpus_per_node=16,
check_allclose=False,
model_path=os.environ["HDFS_ROOT"] + "/model/Qwen3-30B-A3B-Base",
)
2 changes: 1 addition & 1 deletion tests/checkpoint_engine/test_utils.py
@@ -29,7 +29,7 @@ class TrainingWorkerTest(TrainingWorker):
def __init__(self, config: TrainingWorkerConfig, checkpoint_backend: str, checkpoint_kwargs: dict) -> None:
copy_to_local(config.model_config.path)
super().__init__(config)
if torch.distributed.get_rank() == 0 and checkpoint_backend in ["nccl", "hccl"]:
if torch.distributed.get_rank() == 0 and checkpoint_backend in ["nccl", "hccl", "kimi_ckpt_engine"]:
checkpoint_kwargs["is_master"] = True
self.checkpoint_engine = CheckpointEngineRegistry.new(checkpoint_backend, **checkpoint_kwargs)

9 changes: 7 additions & 2 deletions verl/checkpoint_engine/README.md
@@ -18,16 +18,20 @@ Checkpoint Engine is a unified abstraction layer to synchronize weights between va
|nccl|NCCL|all_gather+broadcast|NVIDIA GPU & NCCL|Very High|Low: rebuild nccl group|Off-policy training<br>- Trainer/rollout disaggregated<br>- Fixed clusters
|hccl|HCCL|all_gather+broadcast|Ascend NPU & HCCL| High|Low: rebuild hccl group|Off-policy training<br>- Trainer/rollout disaggregated<br>- Fixed clusters
|nixl|NIXL|all_gather+ring p2p|Various transport backends (D2D, H2H, H2D, etc)<br>- UCX<br>- UCCL<br>- Mooncake|Medium/High|High: dynamically adjust ring topology|Off-policy training<br>- Trainer/rollout disaggregated<br>- Elastic rollout<br>- Rollout fault tolerance<br>- Heterogeneous hardware rollout
|kimi_ckpt_engine|Mooncake+NCCL/HCCL|p2p+broadcast|NVIDIA GPU/Ascend NPU|High|Low: rebuild communication group|Off-policy training<br>- Trainer/rollout disaggregated<br>- Saves a checkpoint on each update

PS: kimi_ckpt_engine first offloads all weights to CPU. The weights are then transferred P2P via the Mooncake transfer engine to a designated rollout worker, which broadcasts them to all other rollout workers.
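
For illustration only (not part of this diff), below is a minimal sketch of constructing the backend through the registry; the kwargs mirror `tests/checkpoint_engine/test_kimi_checkpoint_engine.py`, and the `CheckpointEngineRegistry` import path is an assumption.

```python
# Minimal sketch, assuming CheckpointEngineRegistry is exported from verl.checkpoint_engine.
from verl.checkpoint_engine import CheckpointEngineRegistry

checkpoint_kwargs = {
    "train_world_size": 2,                  # number of trainer workers
    "rollout_world_size": 6,                # number of rollout workers
    "bucket_size": 2 * 1024 * 1024 * 1024,  # 2GB transfer buckets
    "rebuild_group": False,                 # rebuild the communication group on each update?
    "is_master": True,                      # set only on trainer rank 0 (see test_utils.py)
}
engine = CheckpointEngineRegistry.new("kimi_ckpt_engine", **checkpoint_kwargs)
```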

### Benchmark
1. benchmark setup
- model: Qwen/Qwen3-30B-A3B-Base
- trainer: fsdp world_size=2
- trainer: fsdp world_size=2 (on Ascend 910C we set world_size=4, since each 910C has 64GB of HBM)
- rollout: num_rollout=30 (workers only receive weights; no CUDA IPC handoff to vLLM/SGLang)
```bash
python3 tests/checkpoint_engine/test_nixl_checkpoint_engine.py
python3 tests/checkpoint_engine/test_nccl_checkpoint_engine.py
python3 tests/checkpoint_engine/test_hccl_checkpoint_engine.py
python3 tests/checkpoint_engine/test_kimi_checkpoint_engine.py
```

2. benchmark result
@@ -36,4 +40,5 @@ python3 tests/checkpoint_engine/test_hccl_checkpoint_engine.py
|----|----|----|----|
|4*8 H100, ConnectX-7 400 Gbps (InfiniBand)| NCCL | ~7 | 8.25|
|4*8 H100, ConnectX-7 400 Gbps (InfiniBand)| NIXL | ~7 | 8.25|
|2*16 Ascend 910C, intra-supernode| HCCL | ~11 | 5.3|
|2*16 Ascend 910C, intra-supernode| HCCL | ~11 | 5.3|
|2*16 Ascend 910C, intra-supernode| kimi_ckpt_engine | offload: 7, update: 3.5 | 16.5|
7 changes: 7 additions & 0 deletions verl/checkpoint_engine/__init__.py
@@ -37,3 +37,10 @@
__all__ += ["NIXLCheckpointEngine"]
except ImportError:
NIXLCheckpointEngine = None

try:
from .kimi_checkpoint_engine import KIMICheckpointEngine

__all__ += ["KIMICheckpointEngine"]
except ImportError:
KIMICheckpointEngine = None